trans_classifier | R Documentation |
This class is a wrapper for machine-learning-based classification and regression methods, covering data pre-processing, feature selection, data splitting, model training, prediction, the confusion matrix, and ROC (Receiver Operating Characteristic) or PR (Precision-Recall) curves.
Author(s): Felipe Mansoldo and Chi Liu
new()
Create the trans_classifier object.
trans_classifier$new( dataset = NULL, x.predictors = "all", y.response = NULL, n.cores = 1 )
dataset
an object of the microtable class.
x.predictors
default "all"; character string or data.frame; a character string selects the corresponding data from microtable$taxa_abund; a data.frame supplies other customized input. The available options are:
'all': use all the taxa stored in microtable$taxa_abund
'Genus': use the Genus level table in microtable$taxa_abund, or another specific taxonomic rank, e.g. 'Phylum'
other input: must be a data.frame; it should have the same format as the data.frame in microtable$taxa_abund, i.e. rows are features and columns are samples, with the same sample names as in sample_table
y.response
default NULL; the response variable in sample_table.
n.cores
default 1; the number of CPU threads used.
data_feature and data_response in the object.
\donttest{
data(dataset)
t1 <- trans_classifier$new(
  dataset = dataset,
  x.predictors = "Genus",
  y.response = "Group")
}
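As a rough illustration of the customized input described for x.predictors, the following sketch builds a small data.frame with features in rows and samples in columns, then checks that its column names match the sample names. The objects custom_features and sample_table, and all values in them, are hypothetical toy data and not part of the package.

```r
# Hypothetical toy data: 3 features (rows) x 4 samples (columns).
custom_features <- data.frame(
  S1 = c(0.10, 0.50, 0.40),
  S2 = c(0.20, 0.30, 0.50),
  S3 = c(0.60, 0.10, 0.30),
  S4 = c(0.25, 0.25, 0.50),
  row.names = c("Feature_A", "Feature_B", "Feature_C")
)

# Hypothetical sample table: row names are the sample names.
sample_table <- data.frame(
  Group = c("CK", "CK", "Treat", "Treat"),
  row.names = c("S1", "S2", "S3", "S4")
)

# The columns of the feature table should match the samples in sample_table.
names_match <- all(colnames(custom_features) %in% rownames(sample_table))
print(names_match)
```

A table of this shape could then be supplied as x.predictors = custom_features, provided the sample names agree with those in the microtable object.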
cal_preProcess()
Pre-process (centering, scaling etc.) of the feature data based on the caret::preProcess function. See https://topepo.github.io/caret/pre-processing.html for more details.
trans_classifier$cal_preProcess(...)
...
parameters passed to the preProcess function of the caret package.
converted data_feature in the object.
\dontrun{ t1$cal_preProcess(method = c("center", "scale", "nzv")) }
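To make the effect of centering and scaling concrete, here is a minimal base R sketch. It is not the caret implementation: the zero-variance filter is only a simplified stand-in for the "nzv" method, the toy matrix x is hypothetical, and the samples-in-rows orientation is an assumption.

```r
# Hypothetical toy feature matrix: rows are samples, columns are features.
x <- data.frame(
  f1 = c(1, 2, 3, 4, 5),
  f2 = c(10, 20, 30, 40, 50),
  f3 = c(7, 7, 7, 7, 7)  # constant column, carries no information
)

# Drop zero-variance columns (a simplified stand-in for the "nzv" filter).
keep <- vapply(x, function(col) stats::var(col) > 0, logical(1))
x_kept <- x[, keep, drop = FALSE]

# Center each column to mean 0 and scale it to standard deviation 1.
x_scaled <- as.data.frame(scale(x_kept, center = TRUE, scale = TRUE))

print(round(colMeans(x_scaled), 10))
print(apply(x_scaled, 2, sd))
```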
cal_feature_sel()
Perform feature selection. See https://topepo.github.io/caret/feature-selection-overview.html for more details.
trans_classifier$cal_feature_sel( boruta.maxRuns = 300, boruta.pValue = 0.01, boruta.repetitions = 4, ... )
boruta.maxRuns
default 300; maximal number of importance source runs; passed to the maxRuns parameter in Boruta function of Boruta package.
boruta.pValue
default 0.01; p value passed to the pValue parameter in Boruta function of Boruta package.
boruta.repetitions
default 4; repetition runs for the feature selection.
...
parameters passed to the Boruta function of the Boruta package.
optimized data_feature in the object.
\dontrun{ t1$cal_feature_sel(boruta.maxRuns = 300, boruta.pValue = 0.01) }
cal_split()
Split data for training and testing.
trans_classifier$cal_split(prop.train = 3/4)
prop.train
default 3/4; the proportion of the dataset used for training.
data_train and data_test in the object.
\dontrun{ t1$cal_split(prop.train = 3/4) }
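The idea behind prop.train can be sketched in base R as a simple random split. This is an illustration only: the data frame dat is hypothetical, and the method itself may split the data differently (for example with stratification).

```r
set.seed(123)  # make the random split reproducible

# Hypothetical data: 20 samples with one feature and a two-class response.
dat <- data.frame(
  feature = rnorm(20),
  Response = rep(c("A", "B"), each = 10)
)

prop_train <- 3/4

# Randomly pick 3/4 of the row indices for training; the rest is for testing.
n_train <- floor(prop_train * nrow(dat))
train_idx <- sample(seq_len(nrow(dat)), size = n_train)

data_train <- dat[train_idx, , drop = FALSE]
data_test <- dat[-train_idx, , drop = FALSE]

print(c(train = nrow(data_train), test = nrow(data_test)))
```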
set_trainControl()
Set the control parameters for the subsequent training. See the trainControl function of the caret package for details.
trans_classifier$set_trainControl( method = "repeatedcv", classProbs = TRUE, savePredictions = TRUE, ... )
method
default 'repeatedcv'; 'repeatedcv': repeated k-fold cross-validation; see the method parameter in the trainControl function of the caret package for available options.
classProbs
default TRUE; should class probabilities be computed for classification models? See the classProbs parameter in the caret::trainControl function.
savePredictions
default TRUE; see the savePredictions parameter in the caret::trainControl function.
...
parameters passed to the trainControl function of the caret package.
trainControl in the object.
\dontrun{ t1$set_trainControl(method = 'repeatedcv') }
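For readers unfamiliar with repeated k-fold cross-validation, the resampling scheme can be sketched in base R as follows. This only illustrates how samples are assigned to folds. It is not how caret builds its resamples internally, and n_samples, k and n_repeats are hypothetical values.

```r
set.seed(42)

n_samples <- 12   # hypothetical number of training samples
k <- 3            # number of folds
n_repeats <- 2    # number of repetitions

# For each repeat, shuffle the samples and assign them to k folds.
fold_ids <- lapply(seq_len(n_repeats), function(r) {
  sample(rep(seq_len(k), length.out = n_samples))
})

# Each sample is held out exactly once per repeat.
for (r in seq_len(n_repeats)) {
  for (fold in seq_len(k)) {
    held_out <- which(fold_ids[[r]] == fold)
    cat("repeat", r, "fold", fold, "held-out samples:", held_out, "\n")
  }
}
```

In caret::trainControl, the number of folds and the number of repetitions are controlled by the number and repeats arguments, which can be supplied through the ... argument here.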
cal_train()
Run the model training.
trans_classifier$cal_train(method = "rf", max.mtry = 2, max.ntree = 200, ...)
method
default "rf"; "rf": random forest; see method in caret::train function for other options.
max.mtry
default 2; for method = "rf"; maximum mtry value included in the tuning grid used for hyperparameter tuning.
max.ntree
default 200; for method = "rf"; maximum number of trees used to optimize the model.
...
parameters passed to the caret::train function.
res_train in the object.
\dontrun{
# random forest
t1$cal_train(method = "rf")
# Support Vector Machines with Radial Basis Function Kernel
t1$cal_train(method = "svmRadial", tuneLength = 15)
}
cal_feature_imp()
Get feature importance from the training model.
trans_classifier$cal_feature_imp(...)
...
parameters passed to the varImp function of the caret package.
res_feature_imp in the object. One row for each predictor variable. The column(s) are different importance measures. For the method 'rf', it is MeanDecreaseGini (classification) or IncNodePurity (regression).
\dontrun{ t1$cal_feature_imp() }
plot_feature_imp()
Bar plot for feature importance.
trans_classifier$plot_feature_imp(...)
...
parameters passed to the plot_diff_bar function of the trans_diff class.
ggplot2 object.
\dontrun{ t1$plot_feature_imp(use_number = 1:20, coord_flip = FALSE) }
cal_predict()
Run the prediction.
trans_classifier$cal_predict(positive_class = NULL)
positive_class
default NULL; see positive parameter in confusionMatrix function of caret package; If positive_class is NULL, use the first group in data as the positive class automatically.
res_predict, res_confusion_fit and res_confusion_stats stored in the object.
\dontrun{ t1$cal_predict() }
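The role of the positive class in the confusion matrix can be illustrated with a small base R sketch. It assumes a two-class problem and takes the first factor level as the positive class, which is one reading of "the first group in data". The observed and predicted vectors are hypothetical, and the statistics are computed by hand rather than with caret::confusionMatrix.

```r
# Hypothetical observed and predicted classes for 8 test samples.
observed <- factor(c("A", "A", "A", "A", "B", "B", "B", "B"), levels = c("A", "B"))
predicted <- factor(c("A", "A", "A", "B", "B", "B", "A", "B"), levels = c("A", "B"))

# Cross-tabulate predictions (rows) against observations (columns).
confusion <- table(Prediction = predicted, Reference = observed)
print(confusion)

# Treat the first level as the positive class (assumed default behaviour).
positive <- levels(observed)[1]
negative <- levels(observed)[2]

tp <- confusion[positive, positive]  # predicted positive, truly positive
fn <- confusion[negative, positive]  # predicted negative, truly positive
tn <- confusion[negative, negative]  # predicted negative, truly negative
fp <- confusion[positive, negative]  # predicted positive, truly negative

accuracy <- (tp + tn) / sum(confusion)
sensitivity <- tp / (tp + fn)
specificity <- tn / (tn + fp)

print(c(accuracy = accuracy, sensitivity = sensitivity, specificity = specificity))
```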
plot_confusionMatrix()
Plot the cross-tabulation of observed and predicted classes with associated statistics.
trans_classifier$plot_confusionMatrix( plot_confusion = TRUE, plot_statistics = TRUE )
plot_confusion
default TRUE; whether to plot the confusion matrix.
plot_statistics
default TRUE; whether to plot the statistics.
ggplot object.
\dontrun{ t1$plot_confusionMatrix() }
cal_ROC()
Get the ROC (Receiver Operating Characteristic) curve data and the performance data.
trans_classifier$cal_ROC(input = "pred")
input
default "pred"; 'pred' or 'train'; 'pred' represents using prediction results; 'train' represents using training results.
a list res_ROC stored in the object.
\dontrun{ t1$cal_ROC() }
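To show what ROC curve data consist of, the following base R sketch computes true and false positive rates for a binary problem and the area under the curve by the trapezoidal rule. It is a minimal illustration that assumes no tied scores. The labels and scores vectors are hypothetical, and the method may rely on other packages for the actual calculation.

```r
# Hypothetical true labels (1 = positive) and predicted probabilities.
labels <- c(1, 1, 0, 1, 0, 0)
scores <- c(0.9, 0.8, 0.7, 0.6, 0.4, 0.2)

# Sort by decreasing score and sweep the decision threshold over each score.
ord <- order(scores, decreasing = TRUE)
sorted_labels <- labels[ord]

# Cumulative true/false positive rates, starting from the origin (0, 0).
tpr <- c(0, cumsum(sorted_labels == 1) / sum(labels == 1))
fpr <- c(0, cumsum(sorted_labels == 0) / sum(labels == 0))

# Area under the curve by the trapezoidal rule.
auc <- sum(diff(fpr) * (head(tpr, -1) + tail(tpr, -1)) / 2)

print(data.frame(fpr = fpr, tpr = tpr))
print(auc)
```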
plot_ROC()
Plot ROC curve.
trans_classifier$plot_ROC( plot_type = c("ROC", "PR")[1], plot_group = "all", color_values = RColorBrewer::brewer.pal(8, "Dark2"), add_AUC = TRUE, plot_method = FALSE, ... )
plot_type
default c("ROC", "PR")[1]; 'ROC' represents the ROC (Receiver Operating Characteristic) curve; 'PR' represents the PR (Precision-Recall) curve.
plot_group
default "all"; 'all' represents all the classes in the model; 'add' represents all the classes plus the micro-average and macro-average results, see https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html; other options should be one or more class names, the same as the names in the Group column of res_ROC$res_roc from the cal_ROC function.
color_values
default RColorBrewer::brewer.pal(8, "Dark2"); colors used in the plot.
add_AUC
default TRUE; whether to add the AUC to the legend.
plot_method
default FALSE; if TRUE, show the method in the legend even when only one method is found.
...
parameters passed to the geom_path function of the ggplot2 package.
ggplot2 object.
\dontrun{ t1$plot_ROC(size = 1, alpha = 0.7) }
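The difference between micro-average and macro-average results for a multi-class model can be sketched with AUC values in base R. This is a simplified illustration under stated assumptions: the macro-average here is the mean of the per-class one-vs-rest AUCs rather than an averaged curve, the helper auc_rank is hypothetical, and the truth and prob data are toy values.

```r
# Rank-based AUC (Mann-Whitney U statistic); tied scores get average ranks.
auc_rank <- function(labels, scores) {
  n_pos <- sum(labels == 1)
  n_neg <- sum(labels == 0)
  r <- rank(scores)
  (sum(r[labels == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Hypothetical class probabilities for 6 samples and 3 classes.
classes <- c("A", "B", "C")
truth <- c("A", "A", "B", "B", "C", "C")
prob <- matrix(
  c(0.7, 0.2, 0.1,
    0.5, 0.3, 0.2,
    0.2, 0.6, 0.2,
    0.4, 0.3, 0.3,
    0.1, 0.2, 0.7,
    0.3, 0.3, 0.4),
  ncol = 3, byrow = TRUE, dimnames = list(NULL, classes)
)

# One-vs-rest indicator matrix: 1 where the sample belongs to the class.
onehot <- sapply(classes, function(cl) as.integer(truth == cl))

# Macro-average: mean of the per-class one-vs-rest AUCs.
per_class <- sapply(classes, function(cl) auc_rank(onehot[, cl], prob[, cl]))
macro_auc <- mean(per_class)

# Micro-average: pool all (sample, class) pairs before computing one AUC.
micro_auc <- auc_rank(as.vector(onehot), as.vector(prob))

print(per_class)
print(c(macro = macro_auc, micro = micro_auc))
```

The macro-average weights every class equally, while the micro-average weights every sample-class decision equally, so the two can diverge when classes are imbalanced.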
clone()
The objects of this class are cloneable with this method.
trans_classifier$clone(deep = FALSE)
deep
Whether to make a deep clone.
## ------------------------------------------------
## Method `trans_classifier$new`
## ------------------------------------------------
data(dataset)
t1 <- trans_classifier$new(
dataset = dataset,
x.predictors = "Genus",
y.response = "Group")
## ------------------------------------------------
## Method `trans_classifier$cal_preProcess`
## ------------------------------------------------
## Not run:
t1$cal_preProcess(method = c("center", "scale", "nzv"))
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$cal_feature_sel`
## ------------------------------------------------
## Not run:
t1$cal_feature_sel(boruta.maxRuns = 300, boruta.pValue = 0.01)
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$cal_split`
## ------------------------------------------------
## Not run:
t1$cal_split(prop.train = 3/4)
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$set_trainControl`
## ------------------------------------------------
## Not run:
t1$set_trainControl(method = 'repeatedcv')
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$cal_train`
## ------------------------------------------------
## Not run:
# random forest
t1$cal_train(method = "rf")
# Support Vector Machines with Radial Basis Function Kernel
t1$cal_train(method = "svmRadial", tuneLength = 15)
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$cal_feature_imp`
## ------------------------------------------------
## Not run:
t1$cal_feature_imp()
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$plot_feature_imp`
## ------------------------------------------------
## Not run:
t1$plot_feature_imp(use_number = 1:20, coord_flip = FALSE)
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$cal_predict`
## ------------------------------------------------
## Not run:
t1$cal_predict()
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$plot_confusionMatrix`
## ------------------------------------------------
## Not run:
t1$plot_confusionMatrix()
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$cal_ROC`
## ------------------------------------------------
## Not run:
t1$cal_ROC()
## End(Not run)
## ------------------------------------------------
## Method `trans_classifier$plot_ROC`
## ------------------------------------------------
## Not run:
t1$plot_ROC(size = 1, alpha = 0.7)
## End(Not run)