trans_classifier
Create a trans_classifier object for machine-learning-based model prediction. This class is a wrapper for machine-learning-based classification and regression methods, including data pre-processing, feature selection, data splitting, model training, prediction, confusion matrix, and ROC (Receiver Operating Characteristic) or PR (Precision-Recall) curves.
Author(s): Felipe Mansoldo and Chi Liu
new()
Create a trans_classifier object.
trans_classifier$new( dataset, x.predictors = "Genus", y.response = NULL, n.cores = 1 )
dataset
an object of class microtable.
x.predictors
default "Genus"; character string or data.frame. A character string selects the corresponding data from microtable$taxa_abund; a data.frame denotes other customized input. The available options are:
"Genus": use the Genus level table in microtable$taxa_abund, or another specific taxonomic rank, e.g., 'Phylum'. If the input level (e.g., ASV) is not found in the names of the taxa_abund list, the function uses otu_table to calculate the relative abundance of features.
Use all the levels stored in microtable$taxa_abund.
data.frame: must be a data.frame object with the same format as the tables in microtable$taxa_abund, i.e., rows are features and columns are samples, with the same sample names as in sample_table.
y.response
default NULL; the response variable in the sample_table of the input microtable object.
n.cores
default 1; the number of CPU threads used.
data_feature and data_response stored in the object.
\donttest{
data(dataset)
t1 <- trans_classifier$new(
  dataset = dataset,
  x.predictors = "Genus",
  y.response = "Group")
}
cal_preProcess()
Pre-process (centering, scaling, etc.) the feature data using the caret::preProcess function. See https://topepo.github.io/caret/pre-processing.html for more details.
trans_classifier$cal_preProcess(...)
...
parameters passed to the preProcess function of the caret package.
the pre-processed data_feature in the object.
\dontrun{
# "nzv" removes near zero variance predictors
t1$cal_preProcess(method = c("center", "scale", "nzv"))
}
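For intuition about the "center" and "scale" methods, the base-R sketch below standardizes each feature column to mean 0 and standard deviation 1. It assumes samples in rows and features in columns and applies only a simple zero-variance filter; caret's "nzv" also removes near-zero-variance predictors using frequency-based cut-offs, so it can drop more columns. The helper center_scale is hypothetical and is not the code behind cal_preProcess.

```r
# Hypothetical helper (not part of microeco): center and scale feature columns.
# Assumption: samples are rows, features are columns.
center_scale <- function(x) {
  x <- as.matrix(x)
  # drop zero-variance columns, which cannot be scaled (their sd is 0)
  keep <- apply(x, 2, stats::var) > 0
  x <- x[, keep, drop = FALSE]
  # subtract each column mean, then divide by each column standard deviation
  scale(x, center = TRUE, scale = TRUE)
}

# made-up relative abundances for three taxa in four samples
abund <- data.frame(
  taxon_a = c(0.10, 0.20, 0.30, 0.40),
  taxon_b = c(0, 0, 0, 0),             # absent in every sample: removed
  taxon_c = c(0.50, 0.10, 0.20, 0.60)
)
z <- center_scale(abund)
colnames(z)             # "taxon_a" "taxon_c"
round(colMeans(z), 10)  # 0 for every column
apply(z, 2, sd)         # 1 for every column
```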
cal_feature_sel()
Perform feature selection with the Boruta function of the Boruta package. See https://topepo.github.io/caret/feature-selection-overview.html for more details.
trans_classifier$cal_feature_sel( boruta.maxRuns = 300, boruta.pValue = 0.01, boruta.repetitions = 4, ... )
boruta.maxRuns
default 300; maximal number of importance source runs; passed to the maxRuns parameter of the Boruta function in the Boruta package.
boruta.pValue
default 0.01; p-value passed to the pValue parameter of the Boruta function in the Boruta package.
boruta.repetitions
default 4; number of repeated runs of the feature selection.
...
parameters passed to the Boruta function of the Boruta package.
the optimized data_feature in the object.
\dontrun{ t1$cal_feature_sel(boruta.maxRuns = 300, boruta.pValue = 0.01) }
cal_split()
Split data for training and testing.
trans_classifier$cal_split(prop.train = 3/4)
prop.train
default 3/4; the proportion of the data used for training.
data_train and data_test in the object.
\dontrun{ t1$cal_split(prop.train = 3/4) }
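The meaning of prop.train can be illustrated with a stratified random split in base R, where each group contributes floor(prop.train * n) samples to the training set. The helper split_by_group is hypothetical and not part of microeco; whether cal_split stratifies by the response in exactly this way is an assumption here.

```r
# Hypothetical helper (not part of microeco): stratified train/test split.
split_by_group <- function(y, prop.train = 3/4, seed = 123) {
  set.seed(seed)
  idx <- seq_along(y)
  train <- unlist(lapply(split(idx, y), function(i) {
    # index into i so that groups with a single sample are handled correctly
    i[sample.int(length(i), size = floor(length(i) * prop.train))]
  }), use.names = FALSE)
  train <- sort(train)
  list(train = train, test = setdiff(idx, train))
}

group <- rep(c("A", "B", "C"), each = 8)  # made-up response with 3 groups
s <- split_by_group(group, prop.train = 3/4)
table(group[s$train])  # 6 training samples per group
length(s$test)         # 6
```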
set_trainControl()
Set the control parameters for the subsequent model training. See the trainControl function of the caret package for details.
trans_classifier$set_trainControl( method = "repeatedcv", classProbs = TRUE, savePredictions = TRUE, ... )
method
default 'repeatedcv'; 'repeatedcv' means repeated k-fold cross-validation; see the method parameter of the trainControl function in the caret package for the available options.
classProbs
default TRUE; whether class probabilities should be computed for classification models; see the classProbs parameter of the caret::trainControl function.
savePredictions
default TRUE; see the savePredictions parameter of the caret::trainControl function.
...
parameters passed to the trainControl function of the caret package.
trainControl in the object.
\dontrun{ t1$set_trainControl(method = 'repeatedcv') }
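Repeated k-fold cross-validation splits the training data into k folds, trains on k - 1 folds and evaluates on the held-out fold, and repeats the whole procedure with different random fold assignments. The hypothetical helper repeated_kfold below only sketches the fold assignment; caret generates its own resampling indices.

```r
# Hypothetical helper (not part of caret or microeco): fold ids for repeated k-fold CV.
repeated_kfold <- function(n, k = 10, repeats = 3, seed = 1) {
  set.seed(seed)
  # one fold-assignment vector per repeat; each sample gets a fold id in 1..k,
  # and fold sizes differ by at most one sample
  lapply(seq_len(repeats), function(r) sample(rep(seq_len(k), length.out = n)))
}

folds <- repeated_kfold(n = 20, k = 5, repeats = 3)

# repeat 1, fold 2: held-out samples for evaluation vs. samples used for training
held_out <- which(folds[[1]] == 2)
training <- which(folds[[1]] != 2)
length(held_out)  # 4
length(training)  # 16
# total number of model fits per candidate parameter set: k * repeats
5 * 3             # 15
```

With trans_classifier, k and the number of repeats can be passed through ... to caret::trainControl, for example t1$set_trainControl(method = "repeatedcv", number = 5, repeats = 3).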
cal_train()
Train the model. See https://topepo.github.io/caret/available-models.html for the available models.
trans_classifier$cal_train(method = "rf", max.mtry = 2, ntree = 500, ...)
method
default "rf"; "rf" means random forest; see the method parameter of the train function in the caret package for other options.
For method = "rf", the tuneGrid is set to expand.grid(mtry = seq(from = 1, to = max.mtry)).
max.mtry
default 2; used when method = "rf"; the maximum mtry in the tuneGrid used for hyperparameter tuning of the model.
ntree
default 500; used when method = "rf"; the number of trees to grow.
The default of 500 is the same as the ntree parameter of the randomForest function in the randomForest package.
When ntree is a vector with more than one element, such as c(100, 500, 1000), the function tries to optimize the model and select the best one.
...
parameters passed to the caret::train function.
res_train in the object.
\dontrun{
# random forest
t1$cal_train(method = "rf")
# Support Vector Machines with Radial Basis Function Kernel
t1$cal_train(method = "svmRadial", tuneLength = 15)
}
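For method = "rf", the tuning grid described for max.mtry can be reproduced directly; each row is one candidate mtry value that caret evaluates by resampling before keeping the best one. The value max.mtry = 4 is only an example.

```r
# tuneGrid as described for method = "rf"; max.mtry = 4 is an example value
max.mtry <- 4
grid <- expand.grid(mtry = seq(from = 1, to = max.mtry))
grid  # one row per candidate mtry: 1, 2, 3, 4
```

The corresponding call would be t1$cal_train(method = "rf", max.mtry = 4, ntree = c(100, 500, 1000)).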
cal_feature_imp()
Get feature importance from the training model.
trans_classifier$cal_feature_imp(rf_feature_sig = FALSE, ...)
rf_feature_sig
default FALSE; whether to calculate feature significance for the 'rf' model using the rfPermute package; only available when method = "rf" in the cal_train function.
...
parameters passed to the varImp function of the caret package.
If rf_feature_sig is TRUE and the training method is "rf", the parameters are passed to the rfPermute function of the rfPermute package.
res_feature_imp in the object, with one row for each predictor variable and one column for each importance measure.
For method 'rf' with rf_feature_sig = FALSE, the measure is MeanDecreaseGini (classification) or IncNodePurity (regression).
\dontrun{ t1$cal_feature_imp() }
plot_feature_imp()
Bar plot for feature importance.
trans_classifier$plot_feature_imp( rf_sig_show = NULL, show_sig_group = FALSE, ... )
rf_sig_show
default NULL; the importance measure to show: "MeanDecreaseAccuracy" (default) or "MeanDecreaseGini" for random forest classification; "%IncMSE" (default) or "IncNodePurity" for random forest regression.
Only available when rf_feature_sig = TRUE in the cal_feature_imp function, which generates "MeanDecreaseGini" (and "MeanDecreaseAccuracy") or "%IncMSE" (and "IncNodePurity") in the column names of res_feature_imp. That function also generates a "Significance" column according to the p value.
show_sig_group
default FALSE; whether to show the features in different significance groups; only available when "Significance" is found in the data.
...
parameters passed to the plot_diff_bar function of the trans_diff class.
ggplot2 object.
\dontrun{ t1$plot_feature_imp(use_number = 1:20, coord_flip = FALSE) }
cal_predict()
Run the prediction.
trans_classifier$cal_predict(positive_class = NULL)
positive_class
default NULL; see the positive parameter of the confusionMatrix function in the caret package.
If positive_class is NULL, the first group in the data is used as the positive class.
res_predict, res_confusion_fit and res_confusion_stats stored in the object.
res_predict is the prediction result for data_test.
Several evaluation metrics in res_confusion_fit are defined as follows:
Accuracy = \frac{TP + TN}{TP + TN + FP + FN}
Sensitivity = Recall = TPR = \frac{TP}{TP + FN}
Specificity = TNR = 1 - FPR = \frac{TN}{TN + FP}
Precision = \frac{TP}{TP + FP}
Prevalence = \frac{TP + FN}{TP + TN + FP + FN}
F1-score = \frac{2 \times Precision \times Recall}{Precision + Recall}
Kappa = \frac{Accuracy - Pe}{1 - Pe}
where TP is true positive, TN is true negative, FP is false positive, and FN is false negative; FPR is the false positive rate, TPR is the true positive rate, and TNR is the true negative rate; Pe is the hypothetical probability of chance agreement between the reference and the prediction in the confusion matrix.
Accuracy is the proportion of correct predictions. Precision is the proportion of predicted positives that are actually positive. Recall (sensitivity) is the proportion of actual positives that the model correctly identifies. The F1-score is the harmonic mean of precision and recall; 1 indicates the best performance and 0 the worst. Prevalence is the proportion of positive cases in the data. Kappa measures the agreement between prediction and reference after correcting for the agreement expected by chance.
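The metric formulas can be checked with a small worked example. The function binary_metrics below is a hypothetical helper that takes the four cells of a binary confusion matrix; Pe is computed as the sum, over the positive and negative classes, of the product of the observed and predicted class proportions, the usual chance-agreement term for Cohen's kappa. cal_predict itself relies on caret::confusionMatrix for these values.

```r
# Hypothetical helper (not part of microeco): metrics from a 2 x 2 confusion matrix.
binary_metrics <- function(TP, TN, FP, FN) {
  N <- TP + TN + FP + FN
  accuracy    <- (TP + TN) / N
  sensitivity <- TP / (TP + FN)   # recall, TPR
  specificity <- TN / (TN + FP)   # TNR = 1 - FPR
  precision   <- TP / (TP + FP)
  prevalence  <- (TP + FN) / N
  f1          <- 2 * precision * sensitivity / (precision + sensitivity)
  # chance agreement: (observed positives * predicted positives
  #                  + observed negatives * predicted negatives) / N^2
  pe    <- ((TP + FN) * (TP + FP) + (FP + TN) * (FN + TN)) / N^2
  kappa <- (accuracy - pe) / (1 - pe)
  c(Accuracy = accuracy, Sensitivity = sensitivity, Specificity = specificity,
    Precision = precision, Prevalence = prevalence, F1 = f1, Kappa = kappa)
}

m <- binary_metrics(TP = 40, TN = 45, FP = 5, FN = 10)
round(m, 4)
# Accuracy 0.85, Sensitivity 0.8, Specificity 0.9, Precision 0.8889,
# Prevalence 0.5, F1 0.8421, Kappa 0.7 (here Pe = 0.5)
```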
\dontrun{ t1$cal_predict() }
plot_confusionMatrix()
Plot the cross-tabulation of observed and predicted classes, with the associated statistics, based on the results of the cal_predict function.
trans_classifier$plot_confusionMatrix( plot_confusion = TRUE, plot_statistics = TRUE )
plot_confusion
default TRUE; whether to plot the confusion matrix.
plot_statistics
default TRUE; whether to plot the statistics.
ggplot object.
\dontrun{ t1$plot_confusionMatrix() }
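The cross-tabulation drawn by plot_confusionMatrix can be built in base R with table(). The labels below are made-up examples; as in caret::confusionMatrix, predictions are placed in rows and the reference (observed) classes in columns, so correct predictions lie on the diagonal.

```r
# made-up observed and predicted classes for eight samples
lev <- c("A", "B", "C")
observed  <- factor(c("A", "A", "A", "B", "B", "B", "C", "C"), levels = lev)
predicted <- factor(c("A", "A", "B", "B", "B", "C", "C", "C"), levels = lev)

# rows: predicted classes; columns: observed (reference) classes
cm <- table(Prediction = predicted, Reference = observed)
cm

# overall accuracy = correct predictions on the diagonal / all predictions
accuracy <- sum(diag(cm)) / sum(cm)
accuracy  # 0.75
```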
cal_ROC()
Get the ROC (Receiver Operating Characteristic) curve data and the performance data.
trans_classifier$cal_ROC(input = "pred")
input
default "pred"; 'pred' or 'train'; 'pred' uses the prediction results and 'train' uses the training results.
a list res_ROC stored in the object, containing two tables: res_roc and res_pr. AUC: Area Under the ROC Curve.
For the definitions of the metrics, see the return value of the cal_predict function.
\dontrun{ t1$cal_ROC() }
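The ROC data for a binary problem can be sketched as follows: at each threshold on the predicted probability of the positive class, TPR and FPR are computed, and the AUC is the area under the resulting curve by the trapezoidal rule. The helpers roc_points and auc_trapezoid are hypothetical and only illustrate the idea; they are not how cal_ROC builds res_roc.

```r
# Hypothetical helpers (not part of microeco): binary ROC points and trapezoidal AUC.
roc_points <- function(labels, scores) {
  # labels: 1 = positive, 0 = negative; scores: predicted positive-class probability
  thresholds <- sort(unique(scores), decreasing = TRUE)
  tpr <- vapply(thresholds, function(t) mean(scores[labels == 1] >= t), numeric(1))
  fpr <- vapply(thresholds, function(t) mean(scores[labels == 0] >= t), numeric(1))
  # start the curve at (0, 0); the lowest threshold ends it at (1, 1)
  data.frame(FPR = c(0, fpr), TPR = c(0, tpr))
}

auc_trapezoid <- function(roc) {
  n <- nrow(roc)
  widths  <- diff(roc$FPR)
  heights <- (roc$TPR[-n] + roc$TPR[-1]) / 2
  sum(widths * heights)
}

labels <- c(1, 0, 1, 0)
scores <- c(0.9, 0.8, 0.7, 0.1)
roc <- roc_points(labels, scores)
auc_trapezoid(roc)  # 0.75
```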
plot_ROC()
Plot the ROC or PR curve.
trans_classifier$plot_ROC( plot_type = c("ROC", "PR")[1], plot_group = "all", color_values = RColorBrewer::brewer.pal(8, "Dark2"), add_AUC = TRUE, plot_method = FALSE, ... )
plot_type
default c("ROC", "PR")[1]; 'ROC' plots the ROC (Receiver Operating Characteristic) curve; 'PR' plots the PR (Precision-Recall) curve.
plot_group
default "all"; 'all' represents all the classes in the model; 'add' represents all the classes plus the micro-average and macro-average results, see https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html; other options should be one or more class names, matching the names in the Group column of res_ROC$res_roc from the cal_ROC function.
color_values
default RColorBrewer::brewer.pal(8, "Dark2"); colors used in the plot.
add_AUC
default TRUE; whether to add AUC to the legend.
plot_method
default FALSE; if TRUE, show the method in the legend even when only one method is found.
...
parameters passed to the geom_path function of the ggplot2 package.
ggplot2 object.
\dontrun{ t1$plot_ROC(size = 1, alpha = 0.7) }
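For plot_group = "add", the averages follow the one-vs-rest scheme of the scikit-learn example referenced for plot_group: the macro-average can be taken as the unweighted mean of the per-class AUCs, while the micro-average pools every (sample, class) pair into one binary problem. The sketch below assumes a matrix prob with one column of predicted probabilities per class; the helper names are hypothetical, the AUC uses the rank (Mann-Whitney) formula with average ranks for ties, and how plot_ROC computes its averages internally is not shown here.

```r
# Hypothetical helpers (not part of microeco): one-vs-rest AUC with averaging.
auc_rank <- function(labels, scores) {
  # Mann-Whitney form of the AUC; tied scores receive average ranks
  r <- rank(scores)
  n_pos <- sum(labels == 1)
  n_neg <- sum(labels == 0)
  (sum(r[labels == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

multiclass_auc <- function(truth, prob) {
  classes <- colnames(prob)
  # one-vs-rest indicator matrix: 1 where the sample belongs to that class
  onehot <- sapply(classes, function(cl) as.integer(truth == cl))
  per_class <- sapply(classes, function(cl) auc_rank(onehot[, cl], prob[, cl]))
  list(per_class = per_class,
       macro = mean(per_class),                              # mean of per-class AUCs
       micro = auc_rank(as.vector(onehot), as.vector(prob)))  # pooled pairs
}

# made-up probabilities for six samples and three classes
truth <- c("A", "A", "B", "B", "C", "C")
prob <- rbind(c(0.7, 0.2, 0.1), c(0.4, 0.5, 0.1), c(0.2, 0.6, 0.2),
              c(0.1, 0.7, 0.2), c(0.1, 0.3, 0.6), c(0.3, 0.2, 0.5))
colnames(prob) <- c("A", "B", "C")
res <- multiclass_auc(truth, prob)
res$macro  # 1: each class alone is perfectly ranked
res$micro  # 70.5 / 72, about 0.979: pooled scores overlap across classes
```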
cal_caretList()
Use the caretList function of the caretEnsemble package to run multiple models. For the available models, run names(getModelInfo()).
trans_classifier$cal_caretList(...)
...
parameters passed to the caretList function of the caretEnsemble package.
res_caretList_models in the object.
\dontrun{ t1$cal_caretList(methodList = c('rf', 'svmRadial')) }
cal_caretList_resamples()
Use the resamples function of the caret package to collect the metric values from the res_caretList_models data.
trans_classifier$cal_caretList_resamples(...)
...
parameters passed to the resamples function of the caret package.
the res_caretList_resamples list and the res_caretList_resamples_reshaped table in the object.
\dontrun{ t1$cal_caretList_resamples() }
plot_caretList_resamples()
Visualize the metric values based on the res_caretList_resamples_reshaped data.
trans_classifier$plot_caretList_resamples( color_values = RColorBrewer::brewer.pal(8, "Dark2"), ... )
color_values
default RColorBrewer::brewer.pal(8, "Dark2"); color palette for the boxes.
...
parameters passed to the geom_boxplot function of the ggplot2 package.
ggplot object.
\dontrun{ t1$plot_caretList_resamples() }
clone()
The objects of this class are cloneable with this method.
trans_classifier$clone(deep = FALSE)
deep
Whether to make a deep clone.
## ------------------------------------------------
## Method `trans_classifier$new`
## ------------------------------------------------
data(dataset)
t1 <- trans_classifier$new(
dataset = dataset,
x.predictors = "Genus",
y.response = "Group")
## ------------------------------------------------
## Method `trans_classifier$cal_preProcess`
## ------------------------------------------------
## Not run:
# "nzv" removes near zero variance predictors
t1$cal_preProcess(method = c("center", "scale", "nzv"))
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$cal_feature_sel`
## ------------------------------------------------
## Not run:
t1$cal_feature_sel(boruta.maxRuns = 300, boruta.pValue = 0.01)
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$cal_split`
## ------------------------------------------------
## Not run:
t1$cal_split(prop.train = 3/4)
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$set_trainControl`
## ------------------------------------------------
## Not run:
t1$set_trainControl(method = 'repeatedcv')
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$cal_train`
## ------------------------------------------------
## Not run:
# random forest
t1$cal_train(method = "rf")
# Support Vector Machines with Radial Basis Function Kernel
t1$cal_train(method = "svmRadial", tuneLength = 15)
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$cal_feature_imp`
## ------------------------------------------------
## Not run:
t1$cal_feature_imp()
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$plot_feature_imp`
## ------------------------------------------------
## Not run:
t1$plot_feature_imp(use_number = 1:20, coord_flip = FALSE)
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$cal_predict`
## ------------------------------------------------
## Not run:
t1$cal_predict()
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$plot_confusionMatrix`
## ------------------------------------------------
## Not run:
t1$plot_confusionMatrix()
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$cal_ROC`
## ------------------------------------------------
## Not run:
t1$cal_ROC()
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$plot_ROC`
## ------------------------------------------------
## Not run:
t1$plot_ROC(size = 1, alpha = 0.7)
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$cal_caretList`
## ------------------------------------------------
## Not run:
t1$cal_caretList(methodList = c('rf', 'svmRadial'))
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$cal_caretList_resamples`
## ------------------------------------------------
## Not run:
t1$cal_caretList_resamples()
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$plot_caretList_resamples`
## ------------------------------------------------
## Not run:
t1$plot_caretList_resamples()
## End(Not run)