#' Score kmer importance through a random forest
#'
#' Kmer selection is performed in one step (if kmer candidates are supplied)
#' or in two steps. Without candidates, the first step calls `kmer_freq`,
#' which uses odds ratios and p-values from a Fisher test to pick kmers that
#' are significantly associated with mutation probability. The candidate
#' kmers, together with trinucleotide patterns, are then incorporated in a
#' random forest model. The mean decrease in Gini from this forest, together
#' with the odds ratios and p-values from the Fisher test, can be used to
#' estimate kmer importance.
#'
#' @param dataset GRanges object with a `sequence.pyr` column containing the
#'   sequence region and a `mut.pyr` column containing mutations
#' @param ks Int. Size of kmers to be used in the model
#' @param kmers Character vector of candidate kmers (optional). All kmers must
#'   have the same length. If supplied, `ks` is inferred from their length and
#'   `n_keep` is ignored; `pval_cutoff` is then only used to annotate the
#'   random forest results with the strongest significant mutation class.
#' @param pval_cutoff Numeric. P-value cutoff for the Fisher test
#' @param n_keep Positive Int. Number of kmers to keep after preselection
#' @param maxnodes Maximum number of terminal nodes of each decision tree in
#'   the random forest (controls tree depth)
#' @param cores Number of cores to use for parallelization
#' @param n_trees Number of trees in the forest
#' @param include_fit Bool. Include the fitted random forest in the output
#' @return A list containing: (1) `random_forest_results`, MeanDecreaseGini
#'   information on kmers merged with the Fisher test annotation, (2)
#'   `fisher_results`, a 3D array of p-values and odds ratios of kmers, and
#'   optionally (3) `fit`, the random forest fit. If no kmers are available
#'   after preselection, an empty string `""` is returned instead.
#' @export
kmer_random_forest <- function(dataset, ks = 5, kmers = NULL, pval_cutoff = 0.001, n_keep = 80, maxnodes = 20, cores = NULL, n_trees = 720, include_fit = FALSE) {
  if (is.null(kmers)) {
    log_info("`kmer_random_forest` : kmers was not supplied, will call `kmer_freq` to get candidate kmers")
    base_kmers = kmer_freq(dataset, ks = ks, pval_cutoff = pval_cutoff, n_keep = n_keep)
    kmers = rownames(base_kmers)
  } else if (length(kmers) > 0) {
    ks = unique(nchar(kmers))
    if (length(ks) != 1) {
      msg = "`kmer_random_forest` : kmers must be of the same length for argument `kmers`"
      log_error(msg)
      # Abort explicitly instead of relying on the logger to stop execution;
      # continuing would count mutations for several kmer sizes at once.
      stop(msg, call. = FALSE)
    }
    mut_counts = count_muts(dataset, ks)
    fisher_test = fisher_test_mutations(mut_counts)
    # drop = FALSE keeps the 3D (kmer x class x statistic) shape even when
    # only one supplied kmer matches.
    base_kmers = fisher_test[rownames(fisher_test) %in% kmers, , , drop = FALSE]
  }
  # Covers both an empty preselection and an empty supplied `kmers`.
  if (length(kmers) < 1) {
    log_warning("No kmers passed feature pre-selection")
    return("")
  }
  dat_enc = encode_seqs(dataset, kmers)
  log_debug("Starting random forest...")
  rf = train_forest(dat_enc, maxnodes = maxnodes, cores = cores, n_trees = n_trees)
  kmer_importance = forest_kmer_importance_by_gini(rf)
  result = list(
    random_forest_results = merge_rf_result(kmer_importance, base_kmers, pval_cutoff),
    fisher_results = base_kmers
  )
  if (include_fit) {
    result$fit = rf
  }
  result
}
#' Plot Feature Selection Results
#'
#' Draw a scatterplot of the `random_forest_results` element returned by
#' `kmer_random_forest`. The x and y axes show the mean decrease in Gini for
#' cytosines and thymines respectively, point color shows the log odds ratio
#' (`lor_value`), point shape shows the mutation type (`mut_type`) and every
#' point is labelled with its kmer.
#' @param rf_res Output from a `kmer_random_forest` run
#' @param rm_trinucs Remove trinucleotides in output (implemented by dropping
#'   rows whose `mut_type` is missing)
#' @return A ggplot panel showing MeanDecreaseGini on cytosines and thymines
#' @export
plot_kmer_forest = function(rf_res, rm_trinucs = FALSE) {
  forest_df = rf_res$random_forest_results
  if (rm_trinucs) {
    has_mut_type = !is.na(forest_df$mut_type)
    forest_df = forest_df[has_mut_type, ]
  }
  gini_axes = ggplot2::aes(
    x = MeanDecreaseGini.cytosines,
    y = MeanDecreaseGini.thymines,
    color = lor_value
  )
  panel = ggplot2::ggplot(forest_df, gini_axes)
  panel = panel + ggplot2::geom_point(ggplot2::aes(shape = mut_type))
  panel = panel + ggplot2::geom_text(ggplot2::aes(label = kmer))
  panel + ggplot2::scale_color_gradient2(mid = 'green', high = 'red', low = 'blue')
}
# Merge random forest importance scores with Fisher test annotations.
#
# `importance` is a data frame with a `kmer` column; its missing values are
# replaced by 0. `preselection` is a 3D array indexed as
# [kmer, mutation class, statistic] that must contain the statistics
# 'odds_ratio' and 'pvalue'. For each kmer, consider the mutation classes
# whose p-value is <= `pval_cutoff`. The class(es) with the largest absolute
# log odds ratio go into `mut_type`, with ties joined by '&'. The signed log
# odds ratio of the first such class goes into `lor_value`. Kmers without any
# significant class get NA in both columns. Returns the full outer merge of
# both tables on `kmer`.
merge_rf_result = function(importance, preselection, pval_cutoff) {
  log_info("Merging fisher statistics with random forest results")
  obj = importance
  obj[is.na(obj)] = 0
  log_debug(paste("`merge_rf_result` : dim(preselection) =", paste(dim(preselection), collapse = ', ')))
  n_kmers = nrow(preselection)
  n_classes = ncol(preselection)
  mut_classes = colnames(preselection)
  # Extract one statistic as a kmer x class matrix. matrix() restores the 2D
  # shape that `[` drops when there is a single kmer or a single class.
  stat_matrix = function(stat) {
    matrix(preselection[, , stat], nrow = n_kmers, ncol = n_classes)
  }
  lor = log(stat_matrix('odds_ratio'))
  pvals = stat_matrix('pvalue')
  mut_type = rep(NA_character_, n_kmers)
  lor_value = rep(NA_real_, n_kmers)
  for (i in seq_len(n_kmers)) {
    # which() drops NA p-values; NaN log odds ratios cannot be ranked
    significant = which(pvals[i, ] <= pval_cutoff & !is.na(lor[i, ]))
    if (length(significant) == 0) {
      next
    }
    abs_lor = abs(lor[i, significant])
    best = significant[abs_lor == max(abs_lor)]
    mut_type[i] = paste(mut_classes[best], collapse = '&')
    # On ties in |lor| report the signed value of the first class
    lor_value[i] = lor[i, best[1]]
  }
  lor_annotation = data.frame(
    kmer = as.character(rownames(preselection)),
    mut_type = mut_type,
    lor_value = lor_value
  )
  merge(obj, lor_annotation, by = 'kmer', all = TRUE)
}
# Split `n_trees` into at most `cores` positive chunk sizes that sum exactly
# to `n_trees` (chunks differ in size by at most one tree).
split_tree_counts = function(n_trees, cores) {
  chunks = rep(n_trees %/% cores, cores) + (seq_len(cores) <= n_trees %% cores)
  chunks[chunks > 0]
}

# Train random forests on `dat` through `on_pyrs`.
#
# The trainer reads `x$contexts` as predictors and `x$mutations` as response
# for each element `on_pyrs` hands it (presumably one element per pyrimidine;
# TODO confirm against `on_pyrs`). With `cores` NULL or <= 1, a single forest
# of `n_trees` trees is grown. Otherwise the trees are split across `cores`
# foreach workers (doParallel backend, registered on every call). The partial
# forests are joined with `randomForest::combine`, so the result has exactly
# `n_trees` trees.
train_forest = function(dat, maxnodes = 20, cores = NULL, n_trees = 720) {
  if (is.null(cores) || cores <= 1) {
    train_aux = function(x) {
      randomForest::randomForest(x$contexts, x$mutations, maxnodes = maxnodes, ntree = n_trees)
    }
  } else {
    `%dopar%` = foreach::`%dopar%`
    # Exact split: ceiling(n_trees / cores) per worker would overshoot n_trees
    tree_chunks = split_tree_counts(n_trees, cores)
    doParallel::registerDoParallel(cores = cores)
    train_aux = function(x) {
      foreach::foreach(ntree = tree_chunks,
                       .combine = randomForest::combine,
                       .multicombine = TRUE,
                       .packages = "randomForest") %dopar% randomForest::randomForest(x$contexts, x$mutations, maxnodes = maxnodes, ntree = ntree)
    }
  }
  on_pyrs(dat, train_aux)
}
# Collect the per-feature importance of both pyrimidine forests into one table.
#
# `forest` is what `train_forest` returns; `on_pyrs` applies each step to its
# elements, which are then read as `$cytosines` and `$thymines`. Returns a
# data frame with columns `kmer`, `MeanDecreaseGini.cytosines` and
# `MeanDecreaseGini.thymines`, full-outer-merged on `kmer` (a feature missing
# from one forest gets NA there). Assumes each importance table has a single
# (MeanDecreaseGini) column besides `kmer` — the renaming relies on that.
forest_kmer_importance_by_gini = function(forest) {
  gini = on_pyrs(forest, randomForest::importance)
  gini = on_pyrs(gini, as.data.frame)
  gini = on_pyrs(gini, function(tbl) {
    tbl$kmer = rownames(tbl)
    tbl
  })
  merged = merge(x = gini$cytosines,
                 y = gini$thymines,
                 by = "kmer",
                 all = TRUE)
  colnames(merged) = c("kmer", paste0("MeanDecreaseGini.", c("cytosines", "thymines")))
  merged
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.