R/random_forest_helpers.r

Defines functions train_caret_rf rf_predictions random_forest_summary plot_rf_predictions

Documented in plot_rf_predictions random_forest_summary rf_predictions train_caret_rf

#' Plot per-class prediction probabilities from a randomForest model
#'
#' Reshapes the prediction table to long format (one row per library and
#' predicted class) and draws each library as a connected line of points
#' across the classes, coloured by its prior label.
#'
#' @param rf_predictions A data frame of predictions. It must contain the
#'   identifier columns `rna_seq_lib`, `prediction` and `prior_label`; every
#'   other column is treated as a per-class probability.
#'
#' @return A ggplot object.
#' @export
#'
plot_rf_predictions <- function(rf_predictions){
  # Columns that describe a library rather than hold a class probability
  id_cols <- c("rna_seq_lib", "prediction", "prior_label")
  # Convert to long for plotting. The names of the new key/value columns must
  # be given as strings: gather() does not accept `.data$name` there.
  rf_predictions.long <- tidyr::gather(rf_predictions,
                                       key = "predicted_class",
                                       value = "probability",
                                       -dplyr::all_of(id_cols))
  # Generate the plot: one line per library, classes along the x axis
  p <- ggplot2::ggplot(rf_predictions.long,
                       ggplot2::aes(x = factor(.data$predicted_class),
                                    y = .data$probability,
                                    group = .data$rna_seq_lib,
                                    colour = .data$prior_label,
                                    label = .data$rna_seq_lib))
  p <- p + ggplot2::geom_point() + ggplot2::geom_line()
  p <- p + ggplot2::scale_colour_brewer(palette = "Set1")
  return(p)
}

#' From a random forest model generated by caret, generate summary plots
#'
#' Prints the caret training summary, the final forest and the resampled
#' confusion matrix, then draws the variable-importance plot, the MDS plot
#' and the margin plot on the current graphics device.
#'
#' @param caret.rf A randomForest model generated by caret
#' @param yvals Class labels for the training samples. They are passed to
#'   `randomForest::MDSplot()` as the grouping factor and to
#'   `randomForest::margin()` as the observed classes.
#'
#' @return A data frame containing the variable importance for the random
#'   forest model, with the row names copied into an extra `gene` column.
#' @export
#'
random_forest_summary <- function(caret.rf, yvals){
  # Training summary
  print(caret.rf)
  print(caret.rf$finalModel)
  # Values are not auto-printed inside a function, so print explicitly;
  # otherwise the confusion matrix is computed and silently discarded.
  print(caret::confusionMatrix(caret.rf))
  # Extract the final forest
  caret.forest <- caret.rf$finalModel
  # Plot the variable importances
  randomForest::varImpPlot(caret.forest)
  # Make the MDSplot
  # NOTE(review): presumably needs a forest trained with proximities
  # (train_caret_rf passes prox=TRUE) -- confirm for other models.
  randomForest::MDSplot(caret.forest, yvals)
  # And the margin plot, with a reference line at a margin of zero
  caret.forest.margins <- randomForest::margin(caret.forest,
                                               observed = yvals)
  graphics::plot(caret.forest.margins, main = "Margin Plot")
  graphics::abline(h = 0, lty = 2)
  # Return a data-frame of the selected variables
  caret.forest.importance <- as.data.frame(caret.forest$importance)
  caret.forest.importance$gene <- rownames(caret.forest.importance)
  return(caret.forest.importance)
}

#' Extract predictions from a random forest model
#'
#' Collects the final forest's predicted class and per-class probabilities
#' into a single table with one row per library. `predict()` is called
#' without new data, so the predictions are those stored for the samples the
#' forest was fitted on.
#'
#' @param caret.rf A randomForest model generated by caret
#' @param yvals Feature in the experimental design data frame to use as data
#'   labels. Currently unused; kept so existing callers keep working.
#'
#' @return A data frame with one probability column per class, followed by
#'   `rna_seq_lib` (the library identifier taken from the row names of the
#'   probability matrix) and `prediction` (the predicted class as a factor).
#' @export
#'
rf_predictions <- function(caret.rf, yvals){
  # Extract the final forest
  caret.forest <- caret.rf$finalModel
  # Predicted class per library and the matrix of class probabilities
  predicted_class <- stats::predict(caret.forest)
  class_probs <- stats::predict(caret.forest, type = "prob")
  # Library identifiers: prefer the matrix row names, then the names of the
  # predicted classes, and fall back to the row index if neither is present
  lib_ids <- rownames(class_probs)
  if (is.null(lib_ids)) {
    lib_ids <- names(predicted_class)
  }
  if (is.null(lib_ids)) {
    lib_ids <- as.character(seq_len(nrow(class_probs)))
  }
  # Build the table directly instead of via broom::tidy(), which has no
  # tidier for factors/matrices in current broom releases.
  # unclass() strips any extra S3 class so the plain matrix method is used.
  rf.predict <- as.data.frame(unclass(class_probs), stringsAsFactors = FALSE)
  rf.predict$rna_seq_lib <- lib_ids
  rf.predict$prediction <- as.factor(unname(predicted_class))
  # The identifiers now live in rna_seq_lib; drop the duplicate row names
  rownames(rf.predict) <- NULL
  return(rf.predict)
}

#' Train a random forest model with caret
#'
#' Subsets the count matrix to the requested genes, transposes it so that
#' samples are rows and genes are columns, and fits a random forest with
#' 5-fold cross-validation, 5000 trees and proximities enabled.
#'
#' @param gene_set A vector containing gene names to subset
#' @param yvals Class labels passed to `caret::train()` as `y`; presumably
#'   one label per column (sample) of `counts.mat`.
#' @param exp.design Experimental design data frame. Currently unused; kept
#'   so existing callers keep working.
#' @param counts.mat A matrix of gene count data with genes as row names
#'
#' @return An object containing a caret random forest
#' @export
#'
train_caret_rf <- function(gene_set, yvals, exp.design, counts.mat){
  # Find rows to keep
  rows_to_keep <- rownames(counts.mat) %in% gene_set
  if (!any(rows_to_keep)) {
    stop("None of the genes in `gene_set` are present in rownames(counts.mat)",
         call. = FALSE)
  }
  # Subset the count matrix to the genes of interest.
  # drop = FALSE keeps a matrix even when a single gene matches; otherwise
  # the subset collapses to a vector and t() would put samples in columns.
  counts.subset.mat <- counts.mat[rows_to_keep, , drop = FALSE]
  # Transpose the matrix so samples are rows and genes are columns
  counts.subset.transpose.mat <- t(counts.subset.mat)
  # Use caret to train a RF classifier
  caret.rf <- caret::train(x = counts.subset.transpose.mat,
                           y = yvals,
                           method = "rf",
                           trControl = caret::trainControl(method = "cv",
                                                           number = 5),
                           prox = TRUE,
                           allowParallel = TRUE,
                           ntree = 5000)
  return(caret.rf)
}
rdocking/amlpmpsupport documentation built on Jan. 4, 2021, 7:09 a.m.