#' Run Random Forest Classification on Cytokine Data
#'
#' This function trains and evaluates a Random Forest classification model on
#' cytokine data. It includes feature importance visualization, cross-
#' validation for feature selection, and performance metrics such as accuracy,
#' sensitivity, and specificity. Optionally, for binary classification, the
#' function also plots the ROC curve and computes the AUC.
#'
#' @param data A data frame containing the cytokine data, with one column as
#' the grouping variable and the rest as numerical features.
#' @param group_col A string representing the name of the column with the
#' grouping variable (the target variable for classification).
#' @param ntree An integer specifying the number of trees to grow in the forest
#' (default is 500).
#' @param mtry An integer specifying the number of variables randomly selected
#' at each split (default is 5).
#' @param train_fraction A numeric value between 0 and 1 representing the
#' proportion of data to use for training (default is 0.7).
#' @param plot_roc A logical value indicating whether to plot the ROC curve and
#' compute the AUC for binary classification (default is FALSE).
#' @param k_folds An integer specifying the number of folds for
#' cross-validation (default is 5).
#' @param step A numeric value specifying the fraction of variables to remove
#' at each step during cross-validation for feature selection (default is
#' 0.5).
#' @param run_rfcv A logical value indicating whether to run Random Forest
#' cross-validation for feature selection (default is TRUE).
#' @param verbose A logical value indicating whether to print additional
#'   informational output to the console. When \code{TRUE}, the function
#'   displays progress messages and intermediate results; when \code{FALSE}
#'   (the default), it runs quietly.
#' @param seed An integer specifying the seed for reproducibility (default is 123).
#'
#' @return A list containing:
#' \item{model}{The trained Random Forest model.}
#' \item{confusion_matrix}{The confusion matrix of the test set predictions.}
#' \item{importance_plot}{A ggplot object showing the variable importance
#' plot based on Mean Decrease Gini.}
#' \item{rfcv_result}{Results from Random Forest cross-validation for feature
#' selection (if \code{run_rfcv} is TRUE).}
#' \item{importance_data}{A data frame containing the variable importance
#' based on the Gini index.}
#'
#' @details
#' The function fits a Random Forest model to the provided data by splitting it
#' into training and test sets. It calculates performance metrics such as
#' accuracy, sensitivity, and specificity for both sets. For binary
#' classification, it can also plot the ROC curve and compute the AUC. If
#' \code{run_rfcv} is TRUE, cross-validation is performed to select the optimal
#' number of features.
#' If \code{verbose} is TRUE, the function prints additional information to the
#' console, including training results, test results, and plots.
#'
#' @examples
#' data.df0 <- ExampleData1
#' data.df <- data.frame(data.df0[, 1:3], log2(data.df0[, -c(1:3)]))
#' data.df <- data.df[, -c(2:3)]
#' data.df <- dplyr::filter(data.df, Group != "ND")
#'
#' cyt_rf(
#' data = data.df, group_col = "Group", k_folds = 5, ntree = 1000,
#' mtry = 4, run_rfcv = TRUE, plot_roc = TRUE, verbose = FALSE
#' )
#'
#' @importFrom randomForest randomForest rfcv importance
#' @importFrom caret createDataPartition confusionMatrix
#' @import ggplot2
#' @importFrom pROC roc auc ggroc
#' @importFrom utils capture.output
#' @importFrom stats as.formula predict reorder
#' @export
#'
cyt_rf <- function(data, group_col, ntree = 500, mtry = 5,
                   train_fraction = 0.7, plot_roc = FALSE,
                   k_folds = 5, step = 0.5, run_rfcv = TRUE,
                   verbose = FALSE,
                   seed = 123) {
  # Ensure the grouping variable is a factor (required for classification)
  data[[group_col]] <- as.factor(data[[group_col]])
  # Split the data into stratified training and testing sets
  set.seed(seed) # For reproducibility
  train_indices <- caret::createDataPartition(data[[group_col]], p = train_fraction, list = FALSE)
  train_data <- data[train_indices, ]
  test_data <- data[-train_indices, ]
  # Prepare the formula for the random forest model
  predictors <- setdiff(colnames(data), group_col)
  formula_rf <- stats::as.formula(paste(group_col, "~", paste(predictors, collapse = "+")))
  # Fit the Random Forest model on training data
  rf_model <- randomForest::randomForest(formula_rf, data = train_data, ntree = ntree,
                                         mtry = mtry, importance = TRUE, do.trace = FALSE)
  # Print basic Random Forest results for training set if verbose is TRUE
  if (verbose) {
    cat("\n### RANDOM FOREST RESULTS ON TRAINING SET ###\n")
    cat(paste(capture.output(rf_model), collapse = "\n"), "\n")
  }
  # Calculate training metrics from the OOB confusion matrix
  # (drop the trailing "class.error" column before computing accuracy)
  train_confusion <- rf_model$confusion[, 1:(ncol(rf_model$confusion) - 1)]
  accuracy_train <- sum(diag(train_confusion)) / sum(train_confusion)
  if (verbose) {
    cat("\nAccuracy on training set: ", round(accuracy_train, 3), "\n")
  }
  # Per-class one-vs-rest sensitivity/specificity on the training set
  for (i in seq_len(nrow(train_confusion))) {
    true_positives <- train_confusion[i, i]
    false_negatives <- sum(train_confusion[i, ]) - true_positives
    false_positives <- sum(train_confusion[, i]) - true_positives
    true_negatives <- sum(train_confusion) - (true_positives + false_negatives + false_positives)
    sensitivity <- true_positives / (true_positives + false_negatives)
    specificity <- true_negatives / (true_negatives + false_positives)
    if (verbose) {
      cat("\nTraining set - Class '", rownames(train_confusion)[i], "' metrics:\n", sep = "")
      cat("  Sensitivity: ", round(sensitivity, 3), "\n")
      cat("  Specificity: ", round(specificity, 3), "\n")
    }
  }
  # Make predictions on the test set and calculate confusion matrix
  rf_pred <- predict(rf_model, newdata = test_data)
  if (verbose) {
    cat("\n### PREDICTIONS ON TEST SET ###\n")
  }
  confusion_mat <- caret::confusionMatrix(rf_pred, test_data[[group_col]])
  if (verbose) {
    cat(paste(capture.output(print(confusion_mat$table)), collapse = "\n"), "\n")
    cat("\nAccuracy on test set: ", round(confusion_mat$overall["Accuracy"], 3), "\n")
  }
  # Report sensitivity and specificity on the test set
  if (length(levels(data[[group_col]])) == 2) {
    # Two-class case: caret reports Sensitivity/Specificity for the positive
    # class (the first factor level). For the second level the roles swap:
    # its sensitivity equals the first level's specificity, and vice versa.
    sensitivity_val <- confusion_mat$byClass["Sensitivity"]
    specificity_val <- confusion_mat$byClass["Specificity"]
    if (verbose) {
      cat("\nSensitivity by class:\n")
      cat(paste0("Class: ", levels(data[[group_col]])[1], ": ", round(sensitivity_val, 3), "\n"))
      cat(paste0("Class: ", levels(data[[group_col]])[2], ": ", round(specificity_val, 3), "\n"))
      cat("\nSpecificity by class:\n")
      cat(paste0("Class: ", levels(data[[group_col]])[1], ": ", round(specificity_val, 3), "\n"))
      cat(paste0("Class: ", levels(data[[group_col]])[2], ": ", round(sensitivity_val, 3), "\n"))
    }
  } else {
    # Multi-class case: Sensitivity and Specificity are per-class vectors
    sensitivity_vec <- confusion_mat$byClass[, "Sensitivity"]
    specificity_vec <- confusion_mat$byClass[, "Specificity"]
    if (verbose) {
      cat("\nSensitivity by class:\n")
      for (i in seq_along(sensitivity_vec)) {
        cat(paste0("Class: ", levels(data[[group_col]])[i], ": ", round(sensitivity_vec[i], 3), "\n"))
      }
      cat("\nSpecificity by class:\n")
      for (i in seq_along(specificity_vec)) {
        cat(paste0("Class: ", levels(data[[group_col]])[i], ": ", round(specificity_vec[i], 3), "\n"))
      }
    }
  }
  # ROC curve and AUC for binary classification, if requested
  roc_plot <- NULL
  if (plot_roc && length(levels(data[[group_col]])) == 2) {
    # Probability of the second factor level is used as the positive score
    rf_prob <- predict(rf_model, newdata = test_data, type = "prob")[, 2]
    roc_obj <- pROC::roc(test_data[[group_col]], rf_prob, quiet = TRUE)
    auc_value <- pROC::auc(roc_obj)
    if (verbose) {
      cat("\nAUC: ", round(auc_value, 3), "\n")
    }
    roc_plot <- pROC::ggroc(roc_obj, color = "blue", linewidth = 1.5, legacy.axes = TRUE) +
      ggplot2::geom_abline(linetype = "dashed", color = "red", linewidth = 1) +
      ggplot2::labs(
        title = "ROC Curve (Test Set)",
        x = "1 - Specificity",
        y = "Sensitivity"
      ) +
      ggplot2::annotate("text", x = 0.75, y = 0.25, label = paste("AUC =", round(auc_value, 3)),
                        size = 5, color = "blue") +
      ggplot2::theme_minimal() +
      ggplot2::theme(
        panel.background = ggplot2::element_rect(fill = "white", color = NA),
        plot.background = ggplot2::element_rect(fill = "white", color = NA),
        panel.grid.major = ggplot2::element_line(color = "grey90"),
        panel.grid.minor = ggplot2::element_line(color = "grey95")
      )
    print(roc_plot)
  }
  # Extract variable importance once and generate the importance plot
  imp_matrix <- randomForest::importance(rf_model)
  importance_data <- data.frame(Variable = rownames(imp_matrix),
                                Gini = imp_matrix[, "MeanDecreaseGini"])
  vip_plot <- ggplot2::ggplot(importance_data, ggplot2::aes(x = stats::reorder(Variable, Gini), y = Gini)) +
    ggplot2::geom_bar(stat = "identity", fill = "red2") +
    ggplot2::coord_flip() +
    ggplot2::ggtitle("Variable Importance Plot (Mean Decrease in Gini)") +
    ggplot2::xlab("Features") +
    ggplot2::ylab("Importance (Gini Index)") +
    ggplot2::theme_minimal() +
    ggplot2::theme(
      legend.position = "none",
      panel.background = ggplot2::element_rect(fill = "white", colour = "white"),
      plot.background = ggplot2::element_rect(fill = "white", colour = "white"),
      legend.background = ggplot2::element_rect(fill = "white", colour = "white"),
      axis.title = ggplot2::element_text(color = "black", size = 12, face = "bold"),
      legend.title = ggplot2::element_text(color = "black", size = 10, face = "bold"),
      legend.text = ggplot2::element_text(color = "black")
    )
  print(vip_plot)
  # Optional: Random Forest Cross-Validation for Feature Selection
  rfcv_data <- NULL
  rfcv_result <- NULL
  rfcv_plot <- NULL
  if (run_rfcv) {
    if (verbose) {
      cat("\n### RANDOM FOREST CROSS-VALIDATION FOR FEATURE SELECTION ###\n")
    }
    x_train <- train_data[, predictors]
    y_train <- train_data[[group_col]]
    rfcv_result <- randomForest::rfcv(x_train, y_train, cv.fold = k_folds, step = step, do.trace = FALSE)
    rfcv_data <- data.frame(Variables = rfcv_result$n.var, Error = rfcv_result$error.cv)
    rfcv_plot <- ggplot2::ggplot(rfcv_data, ggplot2::aes(x = Variables, y = Error)) +
      ggplot2::geom_line(color = "blue") +
      ggplot2::geom_point(color = "blue") +
      ggplot2::ggtitle("Cross-Validation Error vs. Number of Variables") +
      ggplot2::xlab("Number of Variables") +
      ggplot2::ylab("Cross-Validation Error") +
      ggplot2::scale_x_continuous(breaks = 1:(ncol(train_data) - 1)) +
      ggplot2::theme_minimal() +
      ggplot2::theme(
        panel.background = ggplot2::element_rect(fill = "white", color = NA),
        plot.background = ggplot2::element_rect(fill = "white", color = NA),
        panel.grid.major = ggplot2::element_line(color = "grey90"),
        panel.grid.minor = ggplot2::element_line(color = "grey95")
      )
    print(rfcv_plot)
    if (verbose) {
      cat("Random Forest CV completed for feature selection. Check the plot for error vs. number of variables.\n")
    }
  }
  # Return the documented result list (previously commented out, which made
  # the function's actual return value contradict its @return documentation)
  return(list(
    model = rf_model,
    confusion_matrix = confusion_mat,
    importance_plot = vip_plot,
    rfcv_result = rfcv_result,
    rfcv_plot = rfcv_plot,
    importance_data = importance_data,
    rfcv_data = rfcv_data,
    roc_plot = roc_plot
  ))
}