#' @title
#' Create \code{trans_classifier} object for machine-learning-based model prediction.
#'
#' @description
#' This class is a wrapper for methods of machine-learning-based classification or regression models, including data pre-processing, feature selection,
#' data split, model training, prediction, confusionMatrix and ROC (Receiver Operator Characteristic) or PR (Precision-Recall) curve.
#'
#' Author(s): Felipe Mansoldo and Chi Liu
#'
#' @export
trans_classifier <- R6::R6Class(classname = "trans_classifier",
public = list(
#' @description
#' Create a trans_classifier object.
#'
#' @param dataset an object of \code{\link{microtable}} class.
#' @param x.predictors default "Genus"; character string or data.frame; a character string represents selecting the corresponding data from \code{microtable$taxa_abund};
#' data.frame denotes other customized input. See the following available options:
#' \describe{
#' \item{\strong{'Genus'}}{use Genus level table in \code{microtable$taxa_abund}, or other specific taxonomic rank, e.g., 'Phylum'.
#' If an input level (e.g., ASV) is not found in the names of taxa_abund list, the function will use \code{otu_table} to calculate relative abundance of features.}
#' \item{\strong{'all'}}{use all the levels stored in \code{microtable$taxa_abund}.}
#' \item{\strong{other input}}{must be a data.frame object. It should have the same format with the tables in microtable$taxa_abund, i.e. rows are features;
#' columns are samples with same names in sample_table.}
#' }
#' @param y.response default NULL; the response variable in \code{sample_table} of input \code{microtable} object.
#' @param n.cores default 1; the number of CPU threads to use.
#' @return \code{data_feature} and \code{data_response} stored in the object.
#' @examples
#' \donttest{
#' data(dataset)
#' t1 <- trans_classifier$new(
#' dataset = dataset,
#' x.predictors = "Genus",
#' y.response = "Group")
#' }
initialize = function(dataset,
	x.predictors = "Genus",
	y.response = NULL,
	n.cores = 1
	){
	# Build the two core inputs of the class:
	#   self$data_response : response vector taken from dataset$sample_table[, y.response]
	#   self$data_feature  : sample-by-feature data.frame (rows = samples, columns = features)
	# Also sets self$type ("Regression" or "Classification") and, for classification,
	# self$data_MapNames (original group names vs. syntactically valid class names).
	check_microtable(dataset)
	sampleinfo <- dataset$sample_table
	if(is.null(y.response)){
		stop("No y.response provided!")
	}
	# y.response must be a single existing column name
	if(length(y.response) != 1 || !y.response %in% colnames(sampleinfo)){
		stop("Input y.response must be a column name of dataset$sample_table!")
	}
	response_data <- sampleinfo[, y.response]
	if(is.numeric(response_data)){
		self$type <- "Regression"
		self$data_response <- response_data
		message("Regression for ", y.response)
	}else{
		self$type <- "Classification"
		class_number <- nlevels(as.factor(response_data))
		if(class_number < 2){
			stop("Response variable must have at least 2 factors!")
		}
		ClassificationCase <- if(class_number == 2) "2 groups" else "2+ groups"
		message("Classification type = ", ClassificationCase)
		# caret requires class labels that are valid R names when class probabilities are computed
		ClassNames <- make.names(response_data, unique = FALSE)
		MapNames <- data.frame(OriginalNames = response_data, ClassNames = ClassNames)
		# compare as character: a factor response would otherwise never be identical() to the
		# character ClassNames and the message would be emitted even when nothing changed
		if(!identical(as.character(MapNames$OriginalNames), as.character(MapNames$ClassNames))){
			message("Factor names are non-standard. A correction was made and the change map was saved in object$data_MapNames ...")
		}
		self$data_MapNames <- MapNames
		self$data_response <- ClassNames
	}
	# x.predictors must be character or data.frame
	if(is.character(x.predictors)){
		if(length(x.predictors) != 1){
			stop("Input x.predictors must be a single character string or a data.frame!")
		}
		# NOTE(review): cal_abund() and add_rownames2taxonomy() are invoked on the object that was passed in;
		# presumably this also changes the caller's microtable (reference semantics) -- confirm.
		if(is.null(dataset$taxa_abund)){
			message("No taxa_abund found in the dataset. Calculate the relative abundance ...")
			dataset$cal_abund()
		}
		# exact (case-insensitive) match, so that a rank name merely containing "all" is not hijacked
		if(grepl("^all$", x.predictors, ignore.case = TRUE)){
			abund_table <- do.call(rbind, unname(dataset$taxa_abund))
		}else{
			if(! x.predictors %in% names(dataset$taxa_abund)){
				message("x.predictors: ", x.predictors, " is not found in the names of taxa_abund list. Use the features in otu_table ...")
				dataset$add_rownames2taxonomy(use_name = x.predictors)
				message("Calculate relative abundance of features ...")
				suppressMessages(dataset$cal_abund())
			}
			abund_table <- dataset$taxa_abund[[x.predictors]]
		}
	}else{
		# first check the data.frame
		if(!is.data.frame(x.predictors)){
			stop("Provided x.predictors is neither character nor data.frame !")
		}
		abund_table <- x.predictors
		if(! any(colnames(abund_table) %in% rownames(sampleinfo))){
			stop("Please make sure the column names in provided x.predictors are sample names!")
		}
		# every sample of the response must have a feature column, otherwise rows and labels cannot be paired
		missing_samples <- setdiff(rownames(sampleinfo), colnames(abund_table))
		if(length(missing_samples) > 0){
			stop("The following samples of dataset$sample_table are not found in the column names of x.predictors: ",
				paste(missing_samples, collapse = ", "))
		}
	}
	# Put the feature columns in the row order of sample_table, so that row i of data_feature and
	# element i of data_response always refer to the same sample.
	sample_names <- rownames(sampleinfo)
	if(all(sample_names %in% colnames(abund_table))){
		abund_table <- abund_table[, sample_names, drop = FALSE]
	}
	# remove meaningless things (unclassified / uncultured / incertae sedis / "_sp" entries);
	# drop = FALSE keeps a data.frame even when there is a single sample column
	abund_table %<>% {
		.[!grepl("__$|uncultured$|Incertae..edis$|_sp$", rownames(.), ignore.case = TRUE), , drop = FALSE]
	}
	# transpose: samples become rows, features become columns
	DataX <- abund_table %>% t() %>% as.data.frame(check.names = FALSE)
	message("Total feature numbers: ", ncol(DataX))
	if(n.cores > 1){
		message("Registering cores = ", n.cores)
		doParallel::registerDoParallel(n.cores)
	}
	# use data_feature to make it easily remember and search
	self$data_feature <- DataX
	message("The feature table is stored in object$data_feature ...")
	message("The response variable is stored in object$data_response ...")
},
#' @description
#' Pre-process (centering, scaling etc.) of the feature data based on the caret::preProcess function.
#' See \href{https://topepo.github.io/caret/pre-processing.html}{https://topepo.github.io/caret/pre-processing.html} for more details.
#'
#' @param ... parameters pass to \code{preProcess} function of caret package.
#' @return preprocessed \code{data_feature} in the object.
#' @examples
#' \dontrun{
#' # "nzv" removes near zero variance predictors
#' t1$cal_preProcess(method = c("center", "scale", "nzv"))
#' }
cal_preProcess = function(...){
	# Estimate the caret pre-processing transformation on the current feature table,
	# apply it to the same table and overwrite self$data_feature with the result.
	# ... is forwarded unchanged to caret::preProcess.
	feature_table <- self$data_feature
	fitted_transform <- caret::preProcess(feature_table, ...)
	self$data_feature <- predict(fitted_transform, newdata = feature_table)
	message("The converted feature table is stored in object$data_feature ...")
},
#' @description
#' Perform feature selection.
#' See \href{https://topepo.github.io/caret/feature-selection-overview.html}{https://topepo.github.io/caret/feature-selection-overview.html} for more details.
#'
#' @param boruta.maxRuns default 300; maximal number of importance source runs; passed to the \code{maxRuns} parameter in \code{Boruta} function of Boruta package.
#' @param boruta.pValue default 0.01; p value passed to the pValue parameter in \code{Boruta} function of Boruta package.
#' @param boruta.repetitions default 4; repetition runs for the feature selection.
#' @param ... parameters pass to \code{Boruta} function of Boruta package.
#' @return optimized \code{data_feature} in the object.
#' @examples
#' \dontrun{
#' t1$cal_feature_sel(boruta.maxRuns = 300, boruta.pValue = 0.01)
#' }
cal_feature_sel = function(
	boruta.maxRuns = 300,
	boruta.pValue = 0.01,
	boruta.repetitions = 4,
	...
	){
	# Repeat Boruta feature selection boruta.repetitions times and keep the features that are
	# 'Confirmed' in at least two runs (or in the single run when boruta.repetitions is 1).
	# The reduced table replaces self$data_feature.
	if(length(boruta.repetitions) != 1 || boruta.repetitions < 1){
		stop("boruta.repetitions must be a single number >= 1!")
	}
	data_response <- self$data_response
	if(self$type == "Classification"){
		data_response <- factor(data_response)
	}
	DataX <- self$data_feature
	# one Boruta run; returns the names of the confirmed features
	# (... is resolved from the enclosing method call)
	boruta_run <- function(i){
		boruta.res <- Boruta::Boruta(x = DataX, y = data_response,
			maxRuns = boruta.maxRuns, pValue = boruta.pValue, ...)
		boruta.stats <- data.frame(Boruta::attStats(boruta.res))
		rownames(boruta.stats[boruta.stats$decision == 'Confirmed', ])
	}
	message("Running Feature Selection (Boruta) ...")
	boruta.list <- parallel::mclapply(seq_len(boruta.repetitions), boruta_run)
	# a failed run may come back as a 'try-error' object; without this check its
	# error text would be counted as if it were a feature name
	failed <- vapply(boruta.list, function(x) inherits(x, "try-error"), logical(1))
	if(any(failed)){
		stop("Boruta failed in ", sum(failed), " of ", boruta.repetitions, " repetition(s): ",
			as.character(boruta.list[[which(failed)[1]]]))
	}
	# number of runs in which each feature was confirmed
	feature_counts <- table(unlist(boruta.list))
	# a feature must be confirmed twice, unless only one repetition was requested
	min_confirmed <- min(2, boruta.repetitions)
	boruta.list.top <- as.character(names(feature_counts)[as.vector(feature_counts) >= min_confirmed])
	boruta.n.features <- length(boruta.list.top)
	message("End of Feature Selection - Total of selected features = ", boruta.n.features)
	if(boruta.n.features == 0){
		warning("No feature was confirmed by Boruta; object$data_feature has no columns!")
	}
	# drop = FALSE keeps a data.frame when exactly one feature is selected
	self$data_feature <- DataX[, boruta.list.top, drop = FALSE]
	message("The selected features is reassigned to object$data_feature ...")
},
#' @description
#' Split data for training and testing.
#'
#' @param prop.train default 3/4; the ratio of the data used for the training.
#' @return \code{data_train} and \code{data_test} in the object.
#' @examples
#' \dontrun{
#' t1$cal_split(prop.train = 3/4)
#' }
cal_split = function(prop.train = 3/4){
	# Combine response and features into one table (first column 'Response'), split it with
	# rsample::initial_split stratified on 'Response', and store both parts in the object.
	message("Creating training set and testing set ...")
	response <- self$data_response
	if(self$type == "Classification"){
		# classification needs a factor response
		response <- factor(response)
	}
	combined <- data.frame(Response = response, self$data_feature, check.names = FALSE)
	partition <- rsample::initial_split(combined, prop = prop.train, strata = "Response")
	training_part <- rsample::training(partition)
	testing_part <- rsample::testing(partition)
	message("Stratified sampling with the proportion of ", prop.train*100 ,"% for the training set ...")
	self$data_train <- training_part
	self$data_test <- testing_part
	message("Training and testing data are stored in object$data_train and object$data_test respectively ...")
},
#' @description
#' Control parameters for the following training. Please see \code{trainControl} function of caret package for details.
#'
#' @param method default 'repeatedcv'; 'repeatedcv': Repeated k-Fold cross validation;
#' see method parameter in \code{trainControl} function of \code{caret} package for available options.
#' @param classProbs default TRUE; should class probabilities be computed for classification models?;
#' see classProbs parameter in \code{caret::trainControl} function.
#' @param savePredictions default TRUE; see \code{savePredictions} parameter in \code{caret::trainControl} function.
#' @param ... parameters pass to \code{trainControl} function of caret package.
#' @return \code{trainControl} in the object.
#' @examples
#' \dontrun{
#' t1$set_trainControl(method = 'repeatedcv')
#' }
set_trainControl = function(
	method = 'repeatedcv',
	classProbs = TRUE,
	savePredictions = TRUE,
	...
	){
	# Create a caret::trainControl object and store it in self$trainControl.
	# Class probabilities are switched off for a regression task, whatever the caller asked for.
	if(classProbs){
		classProbs <- if(self$type == "Regression") FALSE else classProbs
	}
	control_setting <- caret::trainControl(
		method = method,
		classProbs = classProbs,
		savePredictions = savePredictions,
		...
	)
	message('Generating trainControl setting stored in object$trainControl ...')
	self$trainControl <- control_setting
},
#' @description
#' Run the model training. Please see \href{https://topepo.github.io/caret/available-models.html}{https://topepo.github.io/caret/available-models.html} for available models.
#'
#' @param method default "rf"; "rf": random forest; see method in \code{train} function of caret package for other options.
#' For method = "rf", the \code{tuneGrid} is set: \code{expand.grid(mtry = seq(from = 1, to = max.mtry))}
#' @param max.mtry default 2; for method = "rf"; maximum mtry used in the \code{tuneGrid} to do hyperparameter tuning to optimize the model.
#' @param ntree default 500; for method = "rf"; Number of trees to grow.
#' The default 500 is same with the \code{ntree} parameter in \code{randomForest} function in randomForest package.
#' When it is a vector with more than one element, the function will try to optimize the model to select a best one, such as \code{c(100, 500, 1000)}.
#' @param ... parameters pass to \code{caret::train} function.
#' @return \code{res_train} in the object.
#' @examples
#' \dontrun{
#' # random forest
#' t1$cal_train(method = "rf")
#' # Support Vector Machines with Radial Basis Function Kernel
#' t1$cal_train(method = "svmRadial", tuneLength = 15)
#' }
cal_train = function(
	method = "rf",
	max.mtry = 2,
	ntree = 500,
	...
	){
	# Train a caret model on self$data_train and store it in self$res_train.
	# For random forest classification, mtry (1..max.mtry) and optionally ntree are tuned first;
	# the chosen values are kept in self$train_params. All other cases call caret::train directly.
	train_data <- self$data_train
	if(is.null(train_data)){
		message("No training data is found! The reason is function cal_split is not performed! Use all the samples for the training ...")
		train_data <- data.frame(Response = self$data_response, self$data_feature, check.names = FALSE)
		if(self$type == "Classification"){
			train_data$Response %<>% as.factor
		}
		self$data_train <- train_data
	}
	trControl <- self$trainControl
	if(is.null(trControl)){
		# fall back to the caret defaults when set_trainControl was not run
		trControl <- caret::trainControl()
	}
	# scalar condition: use && (short-circuit), not the vectorised &
	if(method == "rf" && self$type == "Classification"){
		# Optimization of RF parameters
		message("Optimization of Random Forest parameters ...")
		tuneGrid <- expand.grid(mtry = seq_len(max.mtry))
		if(length(ntree) > 1){
			# one tuned model per candidate ntree, keyed by the ntree value
			modellist <- list()
			for(test_ntree in ntree){
				modellist[[toString(test_ntree)]] <- caret::train(Response ~ ., data = train_data, method = method,
					tuneGrid = tuneGrid, trControl = trControl, ntree = test_ntree, ...)
			}
			# compare the resampled Accuracy of the candidates and keep the best one
			res.tune1 <- as.data.frame(summary(caret::resamples(modellist))$statistics$Accuracy)
			best_key <- rownames(res.tune1)[which.max(res.tune1$Mean)]
			ntree <- as.numeric(best_key)
			# reuse the model that was actually compared instead of fitting it a second time
			fit <- modellist[[best_key]]
			message("ntree used:", ntree)
		}else{
			fit <- caret::train(Response ~ ., data = train_data, method = method,
				tuneGrid = tuneGrid, trControl = trControl, ntree = ntree, ...)
		}
		message("best mtry:", fit$bestTune$mtry)
		# final model: refit with the single selected mtry
		tuneGrid <- expand.grid(.mtry = fit$bestTune$mtry)
		res_train <- caret::train(Response ~ ., data = train_data, method = method,
			tuneGrid = tuneGrid, trControl = trControl, ntree = ntree, ...)
		self$train_params <- list(mtry = fit$bestTune$mtry, ntree = ntree)
	}else{
		res_train <- caret::train(Response ~ ., data = train_data, method = method, trControl = trControl, ...)
	}
	self$res_train <- res_train
	message('The training result is stored in object$res_train ...')
	self$train_method <- method
},
#' @description
#' Get feature importance from the training model.
#' @param rf_feature_sig default FALSE; whether calculate feature significance in 'rf' model using \code{rfPermute} package;
#' only available for \code{method = "rf"} in \code{cal_train} function.
#' @param ... parameters pass to \code{varImp} function of caret package.
#' If \code{rf_feature_sig} is TRUE and \code{train_method} is "rf", the parameters will be passed to \code{rfPermute} function of rfPermute package.
#' @return \code{res_feature_imp} in the object. One row for each predictor variable. The column(s) are different importance measures.
#' For the method 'rf', it is MeanDecreaseGini (classification) or IncNodePurity (regression) when \code{rf_feature_sig = FALSE}.
#' @examples
#' \dontrun{
#' t1$cal_feature_imp()
#' }
cal_feature_imp = function(rf_feature_sig = FALSE, ...){
	# Store the feature importance of the trained model in self$res_feature_imp.
	# Only the combination train_method == "rf" and rf_feature_sig = TRUE uses rfPermute;
	# every other case falls back to caret::varImp on the final model.
	if(is.null(self$res_train)){
		stop("Please first run cal_train to train the model !")
	}
	if(self$train_method == "rf" && rf_feature_sig){
		train_data <- self$data_train
		# the formula interface needs simple column names: use r1, r2, ... temporarily
		feature_idx <- 2:ncol(train_data)
		match_table <- data.frame(rawname = colnames(train_data)[feature_idx], replacename = paste0("r", feature_idx - 1))
		colnames(train_data)[feature_idx] <- match_table$replacename
		tuned <- self$train_params
		if(is.null(tuned)){
			rfp_res <- rfPermute::rfPermute(Response ~ ., data = train_data, ...)
		}else{
			# reuse the ntree and mtry selected in cal_train
			rfp_res <- rfPermute::rfPermute(Response ~ ., data = train_data, ntree = tuned[["ntree"]], mtry = tuned[["mtry"]], ...)
		}
		res_feature_imp <- as.data.frame(rfPermute::importance(rfp_res, scale = TRUE), check.names = FALSE)
		# translate the temporary names back to the raw feature names
		rownames(res_feature_imp) <- match_table[match(rownames(res_feature_imp), match_table[, 2]), 1]
	}else{
		res_feature_imp <- caret::varImp(self$res_train$finalModel, ...)
	}
	self$res_feature_imp <- res_feature_imp
	message('The feature importance is stored in object$res_feature_imp ...')
},
#' @description
#' Bar plot for feature importance.
#' @param rf_sig_show default NULL; "MeanDecreaseAccuracy" (Default) or "MeanDecreaseGini" for random forest classification;
#' "\%IncMSE" (Default) or "IncNodePurity" for random forest regression;
#' Only available when \code{rf_feature_sig = TRUE} in function \code{cal_feature_imp},
#' which generate "MeanDecreaseGini" (and "MeanDecreaseAccuracy") or "\%IncMSE" (and "IncNodePurity") in the column names of \code{res_feature_imp};
#' Function can also generate "Significance" according to the p value.
#' @param show_sig_group default FALSE; whether show the features with different significant groups;
#' Only available when "Significance" is found in the data.
#' @param ... parameters passed to the \code{plot_diff_bar} function of the \code{trans_diff} class.
#' @return \code{ggplot2} object.
#' @examples
#' \dontrun{
#' t1$plot_feature_imp(use_number = 1:20, coord_flip = FALSE)
#' }
plot_feature_imp = function(rf_sig_show = NULL, show_sig_group = FALSE, ...){
	# Draw the feature importance as a bar plot through trans_diff$plot_diff_bar.
	# The importance table is first normalised so that the plotted measure is in a column 'Value'.
	if(is.null(self$res_feature_imp)){
		stop("Please first run function cal_feature_imp !")
	}
	tmp <- data.frame(Taxa = rownames(self$res_feature_imp), self$res_feature_imp, check.names = FALSE)
	# strip backticks and turn '.x__' separators into '|x__'
	taxa_labels <- gsub("`", "", tmp$Taxa, fixed = TRUE)
	tmp$Taxa <- gsub("\\.(.__)", "\\|\\1", taxa_labels)
	rf_measures <- c("MeanDecreaseAccuracy", "MeanDecreaseGini", "%IncMSE", "IncNodePurity")
	if(! "Value" %in% colnames(tmp)){
		if("Overall" %in% colnames(tmp)){
			# caret::varImp style output
			colnames(tmp)[colnames(tmp) == "Overall"] <- "Value"
		}else if(any(rf_measures %in% colnames(tmp))){
			# random forest output with several importance measures: pick one of them
			is_classification <- self$type == "Classification"
			if(is.null(rf_sig_show)){
				rf_sig_show <- if(is_classification) "MeanDecreaseAccuracy" else "%IncMSE"
			}else if(is_classification){
				if(! rf_sig_show %in% c("MeanDecreaseAccuracy", "MeanDecreaseGini")){
					stop("Provided rf_sig_show must be one of 'MeanDecreaseAccuracy' and 'MeanDecreaseGini'!")
				}
			}else{
				if(! rf_sig_show %in% c("%IncMSE", "IncNodePurity")){
					stop("Provided rf_sig_show must be one of '%IncMSE' and 'IncNodePurity'!")
				}
			}
			colnames(tmp)[colnames(tmp) == rf_sig_show] <- "Value"
			colnames(tmp)[colnames(tmp) == paste0(rf_sig_show, ".pval")] <- "pvalue"
			tmp$Significance <- generate_p_siglabel(tmp$pvalue, nonsig = "ns")
			tmp %<>% .[, c("Taxa", "Value", "pvalue", "Significance")]
		}else if(is.numeric(tmp[, 2])){
			# unknown layout: take the first numeric importance column
			colnames(tmp)[2] <- "Value"
		}else{
			stop("The res_feature_imp format can not be correctly recognized!")
		}
	}
	if(show_sig_group){
		if("Significance" %in% colnames(tmp)){
			# group the bars by significance label
			tmp$Group <- factor(tmp$Significance, levels = c("***", "**", "*", "ns"))
		}
	}
	# borrow the bar plot of trans_diff by handing it the table as res_diff
	suppressMessages(trans_diff_tmp <- trans_diff$new(dataset = NULL))
	trans_diff_tmp$res_diff <- tmp
	trans_diff_tmp$method <- "rf"
	g1 <- trans_diff_tmp$plot_diff_bar(...)
	axis_title <- if(is.null(rf_sig_show)) "Value" else rf_sig_show
	g1 + ylab(axis_title)
},
#' @description
#' Run the prediction.
#'
#' @param positive_class default NULL; see positive parameter in \code{confusionMatrix} function of caret package;
#' If positive_class is NULL, use the first group in data as the positive class automatically.
#' @return \code{res_predict}, \code{res_confusion_fit} and \code{res_confusion_stats} stored in the object.
#' The \code{res_predict} is the predicted result for \code{data_test}.
#' Several evaluation metrics in \code{res_confusion_fit} are defined as follows:
#' \deqn{Accuracy = \frac{TP + TN}{TP + TN + FP + FN}}
#' \deqn{Sensitivity = Recall = TPR = \frac{TP}{TP + FN}}
#' \deqn{Specificity = TNR = 1 - FPR = \frac{TN}{TN + FP}}
#' \deqn{Precision = \frac{TP}{TP + FP}}
#' \deqn{Prevalence = \frac{TP + FN}{TP + TN + FP + FN}}
#' \deqn{F1-Score = \frac{2 * Precision * Recall}{Precision + Recall}}
#' \deqn{Kappa = \frac{Accuracy - Pe}{1 - Pe}}
#' where TP is true positive; TN is true negative; FP is false positive; and FN is false negative;
#' FPR is False Positive Rate; TPR is True Positive Rate; TNR is True Negative Rate;
#' Pe is the hypothetical probability of chance agreement on the classes for reference and prediction in the confusion matrix.
#' Accuracy represents the ratio of correct predictions.
#' Precision identifies how the model accurately predicted the positive classes.
#' Recall (sensitivity) measures the ratio of actual positives that are correctly identified by the model.
#' F1-score is the weighted average score of recall and precision. The value at 1 is the best performance and at 0 is the worst.
#' Prevalence represents how often positive events occurred.
#' Kappa identifies how well the model is predicting.
#' @examples
#' \dontrun{
#' t1$cal_predict()
#' }
cal_predict = function(positive_class = NULL){
	# Predict self$data_test with the trained model (stored in self$res_predict). For classification,
	# additionally compute caret::confusionMatrix (self$res_confusion_fit) and a formatted table of
	# its overall statistics (self$res_confusion_stats).
	# positive_class may be given either as the class name used by the model or as the original
	# group name recorded in self$data_MapNames.
	if(is.null(self$res_train)){
		stop("Please first run cal_train to train the model !")
	}
	fit.best <- self$res_train
	test_data <- self$data_test
	if(is.null(test_data)){
		stop("No testing data is found! Please first run cal_split function!")
	}
	# the first column is the response; drop = FALSE keeps a data.frame when there is a single feature
	fit.best.predict <- predict(fit.best, test_data[, -1, drop = FALSE])
	self$res_predict <- fit.best.predict
	message('The result of model prediction is stored in object$res_predict ...')
	if(self$type == "Classification"){
		map_names <- self$data_MapNames
		if(is.null(positive_class)){
			positive_class <- levels(test_data[, 1])[1]
		}else if(length(positive_class) == 1 &&
				!positive_class %in% map_names$ClassNames &&
				positive_class %in% map_names$OriginalNames){
			# an original (possibly non-syntactic) group name was supplied: translate it to the class name
			positive_class <- map_names$ClassNames[match(positive_class, map_names$OriginalNames)]
		}
		# show the original group name(s) to the user
		positive_class.display <- unique(as.character(map_names$OriginalNames[map_names$ClassNames %in% positive_class]))
		message('Calculating confusionMatrix with positive class = ', paste(positive_class.display, collapse = ", "), " ...")
		confusion.fit.best <- caret::confusionMatrix(as.factor(fit.best.predict),
			as.factor(test_data[, 1]),
			positive = positive_class)
		self$res_confusion_fit <- confusion.fit.best
		message('The result of confusionMatrix is stored in object$res_confusion_fit ...')
		# overall statistics formatted as percentages, one row per statistic
		confusion.data.sts <- data.frame(confusion.fit.best$overall)
		Confusion.Sts <- data.frame("Overall Statistics" = paste0(round(confusion.data.sts[, 1], 2) * 100, "%"))
		rownames(Confusion.Sts) <- rownames(confusion.data.sts)
		self$res_confusion_stats <- Confusion.Sts
		message('The statistics of confusionMatrix is stored in object$res_confusion_stats ...')
		message('Model prediction Accuracy = ', Confusion.Sts$Overall.Statistics[1])
	}
},
#' @description
#' Plot the cross-tabulation of observed and predicted classes with associated statistics based on the results of function \code{cal_predict}.
#'
#' @param plot_confusion default TRUE; whether plot the confusion matrix.
#' @param plot_statistics default TRUE; whether plot the statistics.
#' @return \code{ggplot} object.
#' @examples
#' \dontrun{
#' t1$plot_confusionMatrix()
#' }
plot_confusionMatrix = function(
	plot_confusion = TRUE,
	plot_statistics = TRUE
	){
	# Draw the confusion matrix heat map and/or the table of overall statistics from cal_predict,
	# arranged with gridExtra::grid.arrange. Returns the arranged object.
	if(self$type == "Regression"){
		stop("The function can only be available for the Classification !")
	}
	if(is.null(self$res_confusion_fit)){
		stop("Please first run cal_predict to get the prediction performance !")
	}
	# nothing to draw: fail with a clear message instead of an 'object not found' error
	if(!plot_confusion && !plot_statistics){
		stop("At least one of plot_confusion and plot_statistics must be TRUE!")
	}
	# heat map: fill encodes log(Freq), the label shows the raw count
	p1 <- ggplot(data = as.data.frame(self$res_confusion_fit$table),
			aes(x = Reference, y = Prediction)) +
		geom_tile(aes(fill = log(Freq)), colour = "white") +
		scale_fill_gradient(low = "white", high = "steelblue") +
		geom_text(aes(x = Reference, y = Prediction, label = Freq)) +
		theme(legend.position = "none")
	p2 <- gridExtra::tableGrob(self$res_confusion_stats)
	title_gp <- grid::gpar(fontsize = 15, font = 0.5)
	if(plot_confusion && plot_statistics){
		p3 <- gridExtra::grid.arrange(p1, p2, nrow = 1, ncol = 2,
			top = grid::textGrob("Confusion Matrix and Statistics", gp = title_gp))
	}else if(plot_confusion){
		p3 <- gridExtra::grid.arrange(p1, nrow = 1, ncol = 1,
			top = grid::textGrob("Confusion Matrix", gp = title_gp))
	}else{
		p3 <- gridExtra::grid.arrange(p2, nrow = 1, ncol = 1,
			top = grid::textGrob("Statistics", gp = title_gp))
	}
	p3
},
#' @description
#' Get ROC (Receiver Operator Characteristic) curve data and the performance data.
#'
#' @param input default "pred"; 'pred' or 'train'; 'pred' represents using prediction results;
#' 'train' represents using training results.
#' @return a list \code{res_ROC} stored in the object. It has two tables: \code{res_roc} and \code{res_pr}. AUC: Area Under the ROC Curve.
#' For the definition of metrics, please refer to the return part of function \code{cal_predict}.
#' @examples
#' \dontrun{
#' t1$cal_ROC()
#' }
cal_ROC = function(input = "pred"){
	# Compute ROC and Precision-Recall curve data with the multiROC package, either from the
	# predictions on self$data_test (input = "pred") or from the resampling predictions saved in
	# the training result (input = "train"). The result is stored in self$res_ROC.
	if(self$type == "Regression"){
		stop("The function can only be available for the Classification !")
	}
	input <- match.arg(input, c("pred", "train"))
	if(is.null(self$res_train)){
		stop("Please first run cal_train to train the model !")
	}
	fit.best <- self$res_train
	train_method <- fit.best$method
	if(input == "pred"){
		test_data <- self$data_test
		if(is.null(test_data)){
			stop("No testing data is found! Please first run cal_split function!")
		}
		# the first column is the response; drop = FALSE keeps a data.frame when there is a single feature
		prediction_prob <- predict(fit.best, test_data[, -1, drop = FALSE], type = "prob")
		true_label <- test_data[, 1]
		# drop because sometimes there is empty classes
		class_names <- levels(droplevels(true_label))
	}else{
		# use the prediction data in the training part
		class_names <- fit.best$levels
		if(is.null(fit.best$pred) || !all(class_names %in% colnames(fit.best$pred))){
			stop("No class probabilities of resampling are found in the training result! ",
				"Please use savePredictions = TRUE and classProbs = TRUE in set_trainControl before cal_train!")
		}
		prediction_prob <- fit.best$pred[, class_names]
		true_label <- fit.best$pred$obs
	}
	# multiROC input layout: one 0/1 column '<class>_true' per class plus
	# one probability column '<class>_pred_<method>' per class
	label_df <- lapply(class_names, function(x){ifelse(true_label == x, 1, 0)}) %>%
		do.call(cbind, .) %>%
		as.data.frame %>%
		`colnames<-`(paste0(class_names, "_true"))
	prob_df <- prediction_prob %>% `colnames<-`(paste0(colnames(.), "_pred_", train_method))
	use_df <- cbind(label_df, prob_df)
	roc_res <- multiROC::multi_roc(use_df, force_diag = TRUE)
	pr_res <- multiROC::multi_pr(use_df, force_diag = TRUE)
	# store the results in plotting format
	res_ROC <- list()
	res_ROC$res_roc <- multiROC::plot_roc_data(roc_res)
	res_ROC$res_pr <- multiROC::plot_pr_data(pr_res)
	self$res_ROC <- res_ROC
	message('Specificity-sensitivity data is stored in object$res_ROC$res_roc ...')
	message('Recall-Precision is stored in object$res_ROC$res_pr ...')
},
#' @description
#' Plot ROC curve.
#'
#' @param plot_type default c("ROC", "PR")[1]; 'ROC' represents ROC (Receiver Operator Characteristic) curve;
#' 'PR' represents PR (Precision-Recall) curve.
#' @param plot_group default "all"; 'all' represents all the classes in the model;
#' 'add' represents all adding micro-average and macro-average results, see
#' \href{https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html}{https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html};
#' other options should be one or more class names, same with the names in Group column of res_ROC$res_roc from cal_ROC function.
#' @param color_values default RColorBrewer::brewer.pal(8, "Dark2"); colors used in the plot.
#' @param add_AUC default TRUE; whether add AUC in the legend.
#' @param plot_method default FALSE; If TRUE, show the method in the legend though only one method is found.
#' @param ... parameters pass to \code{geom_path} function of ggplot2 package.
#' @return \code{ggplot2} object.
#' @examples
#' \dontrun{
#' t1$plot_ROC(size = 1, alpha = 0.7)
#' }
plot_ROC = function(
	plot_type = c("ROC", "PR")[1],
	plot_group = "all",
	color_values = RColorBrewer::brewer.pal(8, "Dark2"),
	add_AUC = TRUE,
	plot_method = FALSE,
	...
	){
	# Plot the ROC or Precision-Recall curves stored in self$res_ROC by cal_ROC.
	# plot_group: "all" (every class, no averages), "add" (classes plus Micro/Macro averages),
	# or a vector of one or more names found in the Group column.
	if(is.null(self$res_ROC)){
		stop("Please first run cal_ROC to get the data for ROC curve !")
	}
	plot_type <- match.arg(plot_type, c("ROC", "PR"))
	if(plot_type == "ROC"){
		plot_data <- self$res_ROC$res_roc
	}else{
		plot_data <- self$res_ROC$res_pr
	}
	# plot_group may hold several class names, so the keywords are tested on length-one input only;
	# a bare `plot_group != "add"` would give a non-scalar condition
	keep_everything <- length(plot_group) == 1 && plot_group == "add"
	all_classes <- length(plot_group) == 1 && plot_group == "all"
	if(!keep_everything){
		if(all_classes){
			plot_data %<>% .[! .$Group %in% c("Micro", "Macro"), ]
		}else{
			if(!any(plot_data$Group %in% plot_group)){
				stop("Please input the correct plot_group !")
			}
			plot_data %<>% .[.$Group %in% plot_group, ]
		}
	}
	if(add_AUC){
		# append the AUC to the legend label of each group
		plot_data$Group <- paste0(plot_data$Group, "\n AUC = ", round(plot_data$AUC, 2))
	}
	if(plot_type == "ROC"){
		# dashed diagonal marks the reference line from (0, 0) to (1, 1)
		p <- ggplot(plot_data, aes(x = 1-Specificity, y = Sensitivity)) +
			geom_segment(aes(x = 0, y = 0, xend = 1, yend = 1), colour = 'grey', linetype = 'dashed')
	}else{
		p <- ggplot(plot_data, aes(x = Recall, y = Precision))
	}
	# distinguish methods by line type when there are several, or when explicitly requested
	if(length(unique(plot_data$Method)) > 1 || plot_method){
		p <- p + geom_path(aes(color = Group, linetype = Method), ...)
	}else{
		p <- p + geom_path(aes(color = Group), ...)
	}
	p <- p + theme_bw() +
		coord_equal() +
		xlim(0, 1) +
		ylim(0, 1) +
		scale_color_manual(values = color_values) +
		theme(panel.grid.major = element_blank(), panel.grid.minor = element_blank()) +
		theme(legend.title = element_blank())
	p
},
#' @description
#' Use \code{caretList} function of caretEnsemble package to run multiple models. For the available models, please run \code{names(getModelInfo())}.
#'
#' @param ... parameters pass to \code{caretList} function of \code{caretEnsemble} package.
#' @return \code{res_caretList_models} in the object.
#' @examples
#' \dontrun{
#' t1$cal_caretList(methodList = c('rf', 'svmRadial'))
#' }
cal_caretList = function(...){
	# Train several caret models at once with caretEnsemble::caretList on self$data_train,
	# using self$trainControl, and store the list in self$res_caretList_models.
	# requireNamespace() checks for the optional package without attaching it to the search path.
	if(!requireNamespace("caretEnsemble", quietly = TRUE)){
		stop("Please first install caretEnsemble package from CRAN!")
	}
	use_trainControl <- self$trainControl
	train_data <- self$data_train
	if(is.null(train_data)){
		message("No training data is found! The reason is function cal_split is not performed! Use all the samples for the training ...")
		train_data <- data.frame(Response = self$data_response, self$data_feature, check.names = FALSE)
		if(self$type == "Classification"){
			train_data$Response %<>% as.factor
		}
		self$data_train <- train_data
	}
	models <- caretEnsemble::caretList(Response ~ ., data = train_data, trControl = use_trainControl, ...)
	self$res_caretList_models <- models
	message('Models are stored in object$res_caretList_models ...')
},
#' @description
#' Use \code{resamples} function of caret package to collect the metric values based on the \code{res_caretList_models} data.
#'
#' @param ... parameters pass to \code{resamples} function of \code{caret} package.
#' @return \code{res_caretList_resamples} list and \code{res_caretList_resamples_reshaped} table in the object.
#' @examples
#' \dontrun{
#' t1$cal_caretList_resamples()
#' }
cal_caretList_resamples = function(...){
	# Collect the resampling metrics of the models from cal_caretList with caret::resamples and
	# additionally store them in long format (columns Model, Metric, Value).
	model_list <- self$res_caretList_models
	if(is.null(model_list)){
		stop("Please first run cal_caretList function!")
	}
	resample_res <- caret::resamples(model_list, ...)
	self$res_caretList_resamples <- resample_res
	message('Raw resamples results are stored in object$res_caretList_resamples ...')
	value_table <- resample_res$values
	# value columns are named '<model>~<metric>'; 'Resample' is the fold identifier
	value_cols <- colnames(value_table)
	value_cols <- value_cols[value_cols != "Resample"]
	model_names <- unique(gsub("~.*", "", value_cols))
	metric_names <- unique(gsub(".*~", "", value_cols))
	# one table per model, with the metric names as column names
	per_model <- lapply(model_names, function(x){
		model_values <- value_table[, paste0(x, "~", metric_names), drop = FALSE]
		colnames(model_values) <- metric_names
		data.frame(Model = x, model_values)
	})
	stacked <- as.data.frame(do.call(rbind, per_model), check.names = FALSE)
	long_table <- reshape2::melt(stacked, id.vars = "Model")
	colnames(long_table) <- c("Model", "Metric", "Value")
	self$res_caretList_resamples_reshaped <- long_table
	message('Reshaped metric values are stored in object$res_caretList_resamples_reshaped ...')
},
#' @description
#' Visualize the metric values based on the \code{res_caretList_resamples_reshaped} data.
#'
#' @param color_values default \code{RColorBrewer::brewer.pal}(8, "Dark2"); colors palette for the box.
#' @param ... parameters pass to \code{geom_boxplot} function of \code{ggplot2} package.
#' @return ggplot object.
#' @examples
#' \dontrun{
#' t1$plot_caretList_resamples()
#' }
plot_caretList_resamples = function(color_values = RColorBrewer::brewer.pal(8, "Dark2"), ...){
	# Box plot of the resampled metric values of each model, one facet row per metric.
	# ... is forwarded to geom_boxplot.
	if(is.null(self$res_caretList_resamples_reshaped)){
		stop("Please first run cal_caretList_resamples function!")
	}
	reshaped_data <- self$res_caretList_resamples_reshaped
	p <- ggplot(reshaped_data, aes(x = Model, y = Value, colour = Model)) +
		geom_boxplot(...) +
		scale_color_manual(values = color_values) +
		# spell out 'scales' instead of relying on partial matching of 'scale'
		facet_grid(Metric ~ ., drop = TRUE, scales = "free", space = "fixed")
	p
}
),
lock_class = FALSE,
lock_objects = FALSE
)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.