# R/kmer_random_forest.R
#
# Defines functions forest_kmer_importance_by_gini train_forest merge_rf_result plot_kmer_forest kmer_random_forest
#
# Documented in kmer_random_forest plot_kmer_forest

#' Score kmer-importance through a random forest
#'
#' Kmer "selection" is performed in one step (if kmer candidates are supplied)
#' or in two steps. The first step uses odds ratios and p-values from a Fisher
#' test (see `kmer_freq`) to identify kmers significantly associated with
#' mutation probability. These kmers, together with trinucleotide patterns, are
#' incorporated in a random forest model. The mean decrease in Gini from this
#' forest, together with the odds ratios and p-values from the Fisher test, can
#' be used to estimate kmer importance.
#'
#' @param dataset GRanges object with a `sequence.pyr` column containing the sequence region and a `mut.pyr` column containing mutations
#' @param ks Int. Size of kmers to be used in the model
#' @param kmers Character vector of candidate kmers (optional). All kmers must have the same length. If supplied, `ks` and `n_keep` are ignored (the kmer size is taken from `kmers`); `pval_cutoff` is still used when annotating the results with Fisher statistics
#' @param pval_cutoff Numeric. P-value cutoff for the Fisher test
#' @param n_keep Positive Int. Number of kmers to keep after pre-selection
#' @param maxnodes Parameter controlling the depth of the decision trees in the random forest
#' @param cores Number of cores to use for parallelization
#' @param n_trees Number of trees in the forest
#' @param include_fit Bool. Include the resulting fit from random forest training
#' @return A list containing: (1) `random_forest_results`, MeanDecreaseGini information on kmers merged with Fisher annotations, (2) `fisher_results`, a 3D array of p-values and odds ratios of kmers, and
#' optionally (3) `fit`, the random forest fit. If no kmers pass pre-selection, a warning is logged and `""` is returned.
#' @export
kmer_random_forest <- function(dataset, ks = 5, kmers = NULL, pval_cutoff = 0.001,  n_keep = 80, maxnodes = 20, cores = NULL, n_trees = 720, include_fit = FALSE) {
  if (is.null(kmers)) {
    log_info("`kmer_random_forest` : kmers was not supplied, will call `kmer_freq` to get candidate kmers")
    base_kmers <- kmer_freq(dataset, ks = ks, pval_cutoff = pval_cutoff, n_keep = n_keep)
    kmers <- rownames(base_kmers)
  } else if (length(kmers) > 0) {
    ks <- unique(nchar(kmers))
    if (length(ks) != 1) {
      msg <- "`kmer_random_forest` : kmers must be of the same length for argument `kmers`"
      log_error(msg)
      # Logging alone does not halt execution; stop so count_muts is never
      # called with several kmer sizes at once.
      stop(msg, call. = FALSE)
    }
    mut_counts <- count_muts(dataset, ks)
    fisher_test <- fisher_test_mutations(mut_counts)
    # drop = FALSE keeps the 3D layout (indexed [kmer, mutation class,
    # statistic] by merge_rf_result) even when only one kmer matches.
    base_kmers <- fisher_test[rownames(fisher_test) %in% kmers, , , drop = FALSE]
  }

  # Covers both "nothing passed kmer_freq" and an empty `kmers` argument.
  if (length(kmers) < 1) {
    log_warning("No kmers passed feature pre-selection")
    return("")
  }

  dat_enc <- encode_seqs(dataset, kmers)
  log_debug("Starting random forest...")
  rf <- train_forest(dat_enc, maxnodes = maxnodes, cores = cores, n_trees = n_trees)
  kmer_importance <- forest_kmer_importance_by_gini(rf)

  result <- list(
    random_forest_results = merge_rf_result(kmer_importance, base_kmers, pval_cutoff),
    fisher_results        = base_kmers
  )
  if (include_fit) {
    result$fit <- rf
  }
  result
}

#' Plot Feature Selection Results
#'
#' Draws a scatterplot of the results returned by `kmer_random_forest`. The
#' x axis shows the mean decrease in Gini for cytosines, the y axis the mean
#' decrease in Gini for thymines, and the color encodes the log odds ratio of
#' each kmer. Point shape follows `mut_type`, and every point is labelled with
#' its kmer.
#' @param rf_res Output from a `kmer_random_forest` run
#' @param rm_trinucs Logical. If `TRUE`, drop rows whose `mut_type` is `NA`
#'   (kmers without Fisher annotation, such as the trinucleotides)
#' @return A ggplot panel showing MeanDecreaseGini on cytosines and thymines
#' @export
plot_kmer_forest <- function(rf_res, rm_trinucs = FALSE) {
  plot_data <- rf_res$random_forest_results
  if (rm_trinucs) {
    # Trinucleotide features carry no Fisher annotation, hence NA mut_type.
    has_annotation <- !is.na(plot_data$mut_type)
    plot_data <- plot_data[has_annotation, ]
  }

  axis_mapping <- ggplot2::aes(
    x     = MeanDecreaseGini.cytosines,
    y     = MeanDecreaseGini.thymines,
    color = lor_value
  )
  color_scale <- ggplot2::scale_color_gradient2(low = "blue", mid = "green", high = "red")

  ggplot2::ggplot(plot_data, axis_mapping) +
    ggplot2::geom_point(ggplot2::aes(shape = mut_type)) +
    ggplot2::geom_text(ggplot2::aes(label = kmer)) +
    color_scale
}


# Annotate random forest importances with Fisher-test statistics.
#
# importance:   data.frame with a `kmer` column (as produced by
#               forest_kmer_importance_by_gini); NA values are replaced by 0.
# preselection: 3D array indexed [kmer, mutation class, statistic] with
#               'pvalue' and 'odds_ratio' slices; rownames are kmers and
#               colnames are mutation classes.
# pval_cutoff:  a mutation class counts as significant for a kmer when its
#               p-value is <= pval_cutoff.
#
# For each kmer, the significant class(es) with the largest |log odds ratio|
# are reported in `mut_type` (ties joined with "&"), and the corresponding log
# odds ratio in `lor_value`. Kmers with no significant class get NA in both.
# Returns the full outer join of the two tables on `kmer`.
merge_rf_result <- function(importance, preselection, pval_cutoff) {
  log_info("Merging fisher statistics with random forest results")
  obj <- importance
  obj[is.na(obj)] <- 0

  log_debug(paste("`merge_rf_result` : dim(preselection) =", paste(dim(preselection), collapse = ', ')))

  n_kmers     <- nrow(preselection)
  kmers       <- as.character(rownames(preselection))
  mut_classes <- colnames(preselection)
  mut_type    <- rep(NA_character_, n_kmers)
  lor_value   <- rep(NA_real_, n_kmers)

  for (i in seq_len(n_kmers)) {
    # Per-row slices are always vectors, even with a single mutation class.
    pvals   <- preselection[i, , "pvalue"]
    lor     <- log(preselection[i, , "odds_ratio"])
    lor_abs <- abs(lor)
    # NA p-values / odds ratios never count as significant.
    significant <- !is.na(pvals) & pvals <= pval_cutoff & !is.na(lor)
    if (!any(significant)) {
      next
    }
    best <- significant & lor_abs == max(lor_abs[significant])
    mut_type[i]  <- paste(mut_classes[best], collapse = "&")
    lor_value[i] <- lor[best][1]
  }

  lor_annotation <- data.frame(
    kmer      = kmers,
    mut_type  = mut_type,
    lor_value = lor_value
  )

  merge(obj, lor_annotation, by = "kmer", all = TRUE)
}
  

# Train one random forest per pyrimidine via `on_pyrs`.
#
# `dat` is passed to `on_pyrs`; each element handed to the training function
# must expose `contexts` (predictors) and `mutations` (response). Presumably
# on_pyrs applies the function to the cytosine and thymine subsets — confirm
# against its definition.
#
# With `cores` > 1 the n_trees trees are split as evenly as possible across
# workers, and the partial forests are merged with randomForest::combine, so
# the final forest has exactly n_trees trees. The implicit cluster created by
# doParallel is stopped when this function returns.
train_forest <- function(dat, maxnodes = 20, cores = NULL, n_trees = 720) {
  if (is.null(cores) || cores <= 1) {
    train_aux <- function(x) {
      randomForest::randomForest(x$contexts, x$mutations, maxnodes = maxnodes, ntree = n_trees)
    }
  } else {
    `%dopar%` <- foreach::`%dopar%`
    n_workers <- as.integer(cores)
    # Exact split: the first (n_trees %% n_workers) workers get one extra tree.
    chunk_sizes <- rep(n_trees %/% n_workers, n_workers) +
      (seq_len(n_workers) <= n_trees %% n_workers)
    # randomForest cannot grow a 0-tree forest (when n_trees < cores).
    chunk_sizes <- chunk_sizes[chunk_sizes > 0]

    doParallel::registerDoParallel(cores = cores)
    # Release the implicit cluster (relevant on Windows) once training is done.
    on.exit(doParallel::stopImplicitCluster(), add = TRUE)

    train_aux <- function(x) {
      foreach::foreach(ntree = chunk_sizes,
                       .combine = randomForest::combine,
                       .multicombine = TRUE,
                       .packages = "randomForest") %dopar%
        randomForest::randomForest(x$contexts, x$mutations, maxnodes = maxnodes, ntree = ntree)
    }
  }
  on_pyrs(dat, train_aux)
}


# Extract the mean decrease in Gini for every feature from the per-pyrimidine
# forests returned by `train_forest`.
#
# `forest` is the `on_pyrs` result and must have `cytosines` and `thymines`
# elements. Returns a data.frame with columns `kmer`,
# `MeanDecreaseGini.cytosines` and `MeanDecreaseGini.thymines` (outer join
# on `kmer`; a kmer absent from one forest gets NA for that forest).
forest_kmer_importance_by_gini <- function(forest) {
  gini_table <- function(fit) {
    # type = 2 keeps only the node-impurity (Gini) column, so forests fitted
    # with importance = TRUE (which carry extra accuracy columns) also work.
    imp <- randomForest::importance(fit, type = 2)
    data.frame(
      kmer             = rownames(imp),
      MeanDecreaseGini = imp[, 1],
      row.names        = NULL
    )
  }
  per_pyr <- on_pyrs(forest, gini_table)
  merge(per_pyr$cytosines,
        per_pyr$thymines,
        by = "kmer",
        all = TRUE,
        suffixes = c(".cytosines", ".thymines"))
}
# lindberg-m/contextendR documentation built on Jan. 8, 2022, 3:16 a.m.