#' Plot predictions from a randomForest model
#'
#' Reshapes a wide table of per-class probabilities (one column per class)
#' into long format and draws, for every library, its probability for each
#' class as points joined by a line, coloured by the library's prior label.
#'
#' @param rf_predictions A data frame of predictions from a randomForest
#'   model, such as the output of `rf_predictions()`, that additionally has a
#'   `prior_label` column. It must contain the columns `rna_seq_lib`,
#'   `prediction` and `prior_label`; every other column is treated as a
#'   class-probability column.
#'
#' @return p A ggplot plot.
#' @export
#'
plot_rf_predictions <- function(rf_predictions) {
  # Convert to long format: one row per (library, class) pair. Column names
  # are given as strings because tidyr only accepts a symbol or a string for
  # the output names, and `.data$` is deprecated inside tidyselect.
  rf_predictions.long <- tidyr::pivot_longer(
    rf_predictions,
    cols = -c("rna_seq_lib", "prediction", "prior_label"),
    names_to = "predicted_class",
    values_to = "probability"
  )
  # Generate the plot; each library is one group so geom_line() joins its
  # probabilities across the classes.
  p <- ggplot2::ggplot(
    rf_predictions.long,
    ggplot2::aes(
      x = factor(.data$predicted_class),
      y = .data$probability,
      group = .data$rna_seq_lib,
      colour = .data$prior_label,
      # Not drawn by geom_point()/geom_line(); presumably kept for
      # interactive tooltips (e.g. plotly) -- TODO confirm.
      label = .data$rna_seq_lib
    )
  )
  p <- p + ggplot2::geom_point() + ggplot2::geom_line()
  p <- p + ggplot2::scale_colour_brewer(palette = "Set1")
  p
}
#' From a random forest model generated by caret, generate summary plots
#'
#' Prints the caret training summary, the final randomForest model and the
#' resampled confusion matrix, then draws a variable-importance plot, an MDS
#' plot of the proximity matrix and a margin plot on the current graphics
#' device.
#'
#' @param caret.rf A randomForest model generated by caret (a `train` object
#'   whose `finalModel` is a randomForest fit). The MDS plot needs the forest
#'   to have been trained with `proximity = TRUE`.
#' @param yvals Class labels of the training samples (presumably a factor, in
#'   the same order as the training data -- TODO confirm with callers), used to
#'   colour the MDS plot and as the observed classes for the margins.
#'
#' @return caret.forest.importance A data frame containing variable importance for the random forest model,
#'   with an added `gene` column holding the variable names.
#' @export
#'
random_forest_summary <- function(caret.rf, yvals) {
  # Training summary. Expressions inside a function body are not
  # auto-printed, so the confusion matrix must be printed explicitly or it
  # is computed and silently discarded.
  print(caret.rf)
  print(caret.rf$finalModel)
  print(caret::confusionMatrix(caret.rf))
  # Extract the final forest
  caret.forest <- caret.rf$finalModel
  # Plot the variable importances
  randomForest::varImpPlot(caret.forest)
  # Make the MDS plot from the forest's proximity matrix
  randomForest::MDSplot(caret.forest, yvals)
  # And the margin plot; a dashed line marks zero margin
  caret.forest.margins <- randomForest::margin(caret.forest, observed = yvals)
  graphics::plot(caret.forest.margins, main = "Margin Plot")
  graphics::abline(h = 0, lty = 2)
  # Return a data frame of the variable importances, keyed by gene
  caret.forest.importance <- as.data.frame(caret.forest$importance)
  caret.forest.importance$gene <- rownames(caret.forest.importance)
  caret.forest.importance
}
#' Extract predictions from a random forest model
#'
#' Builds a table of the final forest's out-of-bag predictions for the
#' training samples: one probability column per class, plus the sample name
#' and the predicted class.
#'
#' @param caret.rf A randomForest model generated by caret
#' @param yvals Feature in the experimental design data frame to use as data labels.
#'   Currently unused; kept for interface compatibility.
#'
#' @return rf.predict A tibble with one row per training sample: one column
#'   of predicted probabilities per class, `rna_seq_lib` (the sample's row
#'   name) and `prediction` (the predicted class, as a factor).
#' @export
#'
rf_predictions <- function(caret.rf, yvals) {
  # Extract the final forest
  caret.forest <- caret.rf$finalModel
  # Without `newdata`, predict() returns out-of-bag results for the training
  # samples: a class factor and a class-probability matrix.
  oob_class <- stats::predict(caret.forest)
  oob_prob <- stats::predict(caret.forest, type = "prob")
  # Build the table directly instead of via broom::tidy(), whose vector and
  # matrix methods are deprecated and whose column names were relied upon.
  rf.predict <- dplyr::as_tibble(as.data.frame(unclass(oob_prob)))
  rf.predict$rna_seq_lib <- rownames(oob_prob)
  rf.predict$prediction <- as.factor(unname(oob_class))
  rf.predict
}
#' Train a random forest model with caret
#'
#' Subsets the count matrix to the requested genes and trains a random forest
#' classifier with 5-fold cross-validation, keeping the proximity matrix so
#' MDS plots can be drawn later.
#'
#' @param gene_set A vector containing gene names to subset
#' @param yvals Class labels for the samples, one per column of `counts.mat`,
#'   used as the outcome to train on.
#' @param exp.design Experimental design data frame. Currently unused; kept
#'   for interface compatibility.
#' @param counts.mat A matrix of gene count data, with genes as rows
#'   (matched against `gene_set` by row name) and samples as columns.
#' @param ntree Number of trees grown in each forest (default 5000).
#'
#' @return caret.rf An object containing a caret random forest (`train` object)
#' @export
#'
train_caret_rf <- function(gene_set, yvals, exp.design, counts.mat,
                           ntree = 5000) {
  # Find rows (genes) to keep
  rows_to_keep <- rownames(counts.mat) %in% gene_set
  if (!any(rows_to_keep)) {
    stop("None of the genes in `gene_set` are rows of `counts.mat`.",
         call. = FALSE)
  }
  # drop = FALSE keeps a matrix when only one gene matches; otherwise the
  # subset collapses to a vector and t() turns it into a 1 x n_samples row.
  counts.subset.mat <- counts.mat[rows_to_keep, , drop = FALSE]
  # caret expects samples as rows and features (genes) as columns
  counts.subset.transpose.mat <- t(counts.subset.mat)
  # Use caret to train a RF classifier. allowParallel is a trainControl()
  # option; passed to train() it would only be forwarded to randomForest()
  # and ignored.
  caret::train(
    x = counts.subset.transpose.mat,
    y = yvals,
    method = "rf",
    trControl = caret::trainControl(method = "cv", number = 5,
                                    allowParallel = TRUE),
    proximity = TRUE,
    ntree = ntree
  )
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.