knitr::opts_chunk$set( collapse = TRUE, comment = "#>", eval = FALSE )
devtools::install_github('https://github.com/mlr-org/mlr3extralearners')
devtools::install_github('https://github.com/a-hanf/mlr3automl', dependencies = TRUE)
mlr3automl creates classification and regression models on tabular data sets. The entry point is the AutoML function, which requires an mlr3::TaskClassif or mlr3::TaskRegr object. Creating a model on the iris data set works like this:
library(mlr3)
library(mlr3automl)

iris_task = tsk("iris")
iris_model = AutoML(iris_task)
This creates a stable Machine Learning pipeline with Logistic Regression, Random Forest and Gradient Boosting models. To supply your own data set, convert it into an mlr3 Task.
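If your data lives in a plain data.frame, the mlr3 converters will wrap it for you. A minimal sketch, assuming a classification problem (the data and target column here are only illustrative):

```r
library(mlr3)

# Stand-in for your own data.frame; swap in your data and target column
my_data = iris
my_task = as_task_classif(my_data, target = "Species", id = "my_data")
```

For regression problems, use as_task_regr() with a numeric target column instead.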
In the code above, the model has not been trained yet. Training works much like in mlr3: we use the train() method, which takes a vector of training row indices as an additional argument.
train_indices = sample(1:iris_task$nrow, 2/3 * iris_task$nrow)
iris_model$train(row_ids = train_indices)
To make predictions we use the predict() method. We can supply a held-out data set here, or alternatively use row indices to specify observations that were not used during training.
predict_indices = setdiff(1:iris_task$nrow, train_indices)
predictions = iris_model$predict(row_ids = predict_indices)
To make this more convenient, mlr3automl also provides a resampling endpoint. Instead of performing the train-test split ourselves, we can let mlr3automl do it for us.
resampling_result = iris_model$resample()
We obtain an mlr3::ResampleResult. For more information on how to analyze this result, see the resampling section in the mlr3book.
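As a quick sketch, the returned object supports the standard mlr3 ResampleResult methods; the accuracy measure below is an assumed choice, not something mlr3automl prescribes:

```r
# Average performance across the resampling iterations
resampling_result$aggregate(msr("classif.acc"))

# Per-iteration scores, and the pooled predictions
resampling_result$score(msr("classif.acc"))
resampling_result$prediction()
```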
mlr3automl models can be easily interpreted using the popular iml or DALEX packages. Let's interpret the model trained on the iris data set in the previous example:
dalex_explainer = iris_model$explain(iml_package = "DALEX")
iml_explainer = iris_model$explain(iml_package = "iml")

# compute and plot permutation feature importance using DALEX
dalex_importance = DALEX::model_parts(dalex_explainer)
plot(dalex_importance)

# partial dependence plot using the iml package
iml_pdp = iml::FeatureEffect$new(iml_explainer, feature = "Sepal.Width", method = "pdp")
plot(iml_pdp)
mlr3automl offers the following customization options in the AutoML function:
# custom runtime budget (in seconds)
model = AutoML(task = iris_task, runtime = 180)

# custom performance measure for tuning
model = AutoML(task = iris_task, measure = msr("classif.logloss"))

# custom set of learners
model = AutoML(task = iris_task, learner_list = c("classif.svm", "classif.ranger"))

# custom resampling strategy
model = AutoML(task = iris_task, resampling = rsmp("cv"))

# disable the preprocessing pipeline
model = AutoML(task = iris_task, learner_list = c("classif.ranger"), preprocessing = "none")
Example #1: we create a regression model with custom learners and a fixed time budget. Every hyperparameter evaluation is stopped after 10 seconds. After 300 seconds of training, tuning is stopped and the best result obtained so far is returned:
automl_model = AutoML(
  task = tsk("mtcars"),
  learner_list = c("regr.ranger", "regr.lm"),
  learner_timeout = 10,
  runtime = 300)
Example #2: we create a pipeline of preprocessing operators using mlr3pipelines. This pipeline replaces the default preprocessing pipeline in mlr3automl. You can tune the choice of custom preprocessing operators and their associated hyperparameters by supplying additional parameters (see example #3).
library(mlr3pipelines)

imbalanced_preproc = po("imputemean") %>>%
  po("smote") %>>%
  po("classweights", minor_weight = 2)

automl_model = AutoML(task = tsk("pima"), preprocessing = imbalanced_preproc)
Example #3: we add a k-nearest-neighbors classifier to the learners, which has no pre-defined hyperparameter search space in mlr3automl. To perform hyperparameter tuning, we supply a parameter set using paradox. A parameter transformation is supplied in order to sample the hyperparameter on an exponential scale.
library(paradox)

new_params = ParamSet$new(list(
  ParamInt$new("classif.kknn.k", lower = 1, upper = 5, default = 3, tags = "kknn")))

# sample k on an exponential scale: k in {2, 4, 8, 16, 32}
my_trafo = function(x, param_set) {
  if ("classif.kknn.k" %in% names(x)) {
    x[["classif.kknn.k"]] = 2^x[["classif.kknn.k"]]
  }
  return(x)
}

automl_model = AutoML(
  task = tsk("iris"),
  learner_list = "classif.kknn",
  additional_params = new_params,
  custom_trafo = my_trafo)
mlr3automl tackles the challenge of Automated Machine Learning from multiple angles.
We tested mlr3automl on 39 challenging data sets in the AutoML Benchmark. By including up to 12 preprocessing steps, mlr3automl is stable in the presence of missing data, categorical and high-cardinality features, huge data sets and constrained time budgets.
We evaluated many learning algorithms and implementations in order to find the most stable and accurate learners for mlr3automl, and settled on the Logistic Regression, Random Forest and Gradient Boosting models named above.
When selecting the best model for a data set, up to 8 predefined pipelines are evaluated first. These are our most robust, fast and accurate pipelines, which provide us with a strong baseline even on the most challenging tasks.
We use Hyperband [@li2017hyperband] to tune the hyperparameters of our machine learning pipeline. Hyperband speeds up random search through adaptive resource allocation and early stopping. When tuning the hyperparameters in mlr3automl, many learners are at first evaluated on small subsets of the data set, which is quick. Later on, fewer models are trained on larger subsets or the full data set, which is computationally more expensive. This allows us to find promising pipelines at little computational cost.
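The resulting budget schedule can be sketched in a few lines of base R. This is a simplified illustration of the successive-halving step inside one Hyperband bracket, not the package's actual implementation; the halving rate eta and the starting budget are assumed values:

```r
# Successive halving: start many configurations on a small data fraction,
# keep the best 1/eta at each rung and grow their budget by a factor of eta.
eta = 3            # assumed halving rate
n_configs = 27     # configurations entering the bracket
budget = 1/27      # assumed initial fraction of the data set

schedule = data.frame(rung = integer(), configs = integer(), budget = numeric())
rung = 0
while (n_configs >= 1) {
  schedule = rbind(schedule,
                   data.frame(rung = rung, configs = n_configs, budget = budget))
  n_configs = floor(n_configs / eta)
  budget = min(1, budget * eta)
  rung = rung + 1
}
schedule
# 27 configs see 1/27 of the data, 9 see 1/9, 3 see 1/3, 1 sees the full data set
```

Note how most configurations are discarded after only a cheap evaluation, which is where the computational savings come from.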
knitr::kable(
  data.table::data.table(
    Framework = c("AutoGluon", "autosklearn v0.8", "autosklearn v2.0",
                  "H2O AutoML", "mlr3automl", "TPOT"),
    "avg. rank (binary tasks)" = c("2.32", "3.34", "3.57", "3.18", "4.55", "4.05"),
    "avg. rank (multi-class)" = c("2.09", "3.15", "4.06", "3.24", "3.79", "4.68"),
    Failures = c(0, 2, 9, 2, 0, 6)))
We benchmarked mlr3automl on the AutoML benchmark, which contains 39 challenging classification data sets. Under a restrictive time budget of at most 1 hour per task, mlr3automl successfully completed every single task.