devtools::install_github('mlr-org/mlr3extralearners')
devtools::install_github('a-hanf/mlr3automl', dependencies = TRUE)
mlr3automl creates classification and regression models on tabular data sets. The entry point is the AutoML function, which requires an mlr3::TaskClassif or mlr3::TaskRegr object. Creating a model on the iris data set works like this:
library(mlr3)
library(mlr3automl)
iris_task = tsk("iris")
iris_model = AutoML(iris_task)
This creates a robust machine learning pipeline with Logistic Regression, Random Forest and Gradient Boosting models. To use your own data set, convert it into an mlr3 Task first.
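For instance, with recent mlr3 versions a plain data.frame can be turned into a classification task via as_task_classif() (the data and column names below are illustrative):

```r
library(mlr3)

# assume a data.frame with a factor column "label" to predict
my_data = data.frame(
  x1 = rnorm(100),
  x2 = runif(100),
  label = factor(sample(c("a", "b"), 100, replace = TRUE))
)
my_task = as_task_classif(my_data, target = "label", id = "my_data")
# my_model = AutoML(my_task)
```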
In the above code, our model has not been trained yet. Model training works much like training and prediction in mlr3: we use the train() method, which takes a vector of training indices as an additional argument.
train_indices = sample(1:iris_task$nrow, 2/3*iris_task$nrow)
iris_model$train(row_ids = train_indices)
To make predictions, we use the predict() method. We can supply a held-out data set here, or alternatively use indices to specify observations that were not used during training.
predict_indices = setdiff(1:iris_task$nrow, train_indices)
predictions = iris_model$predict(row_ids = predict_indices)
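The returned object is a standard mlr3 Prediction, so it can be scored with any matching measure. A short sketch, assuming the model above has been trained:

```r
# accuracy on the held-out rows
predictions$score(msr("classif.acc"))

# confusion matrix of predicted vs. true classes
predictions$confusion
```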
To make this a bit more convenient, mlr3automl also provides a resampling endpoint. Instead of performing the train-test split ourselves, we can let mlr3automl do it for us.
resampling_result = iris_model$resample()
We obtain an mlr3::ResampleResult. For more information on how to analyze this result, see the resampling section in the mlr3book.
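A quick sketch of how such a result is typically inspected with the standard mlr3 API:

```r
# performance aggregated over all resampling iterations
resampling_result$aggregate(msr("classif.acc"))

# per-iteration scores for a closer look at variance across folds
resampling_result$score(msr("classif.acc"))
```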
mlr3automl models can be easily interpreted using the popular iml or DALEX packages. Let's interpret our model trained on the iris data set from the previous example:
dalex_explainer = iris_model$explain(iml_package = "DALEX")
iml_explainer = iris_model$explain(iml_package = "iml")
# compute and plot feature permutation importance using DALEX
dalex_importance = DALEX::model_parts(dalex_explainer)
plot(dalex_importance)
# partial dependency plot using iml package
iml_pdp = iml::FeatureEffect$new(iml_explainer, feature="Sepal.Width", method="pdp")
plot(iml_pdp)
mlr3automl offers the following customization options in the AutoML function:
# limit the total tuning runtime (in seconds)
model = AutoML(task = iris_task, runtime = 180)
# optimize a custom performance measure
model = AutoML(task = iris_task, measure = msr("classif.logloss"))
# restrict the set of candidate learners
model = AutoML(task = iris_task, learner_list = c("classif.svm", "classif.ranger"))
# use a custom resampling strategy
model = AutoML(task = iris_task, resampling = rsmp("cv"))
# disable the built-in preprocessing
model = AutoML(task = iris_task, learner_list = c("classif.ranger"), preprocessing = "none")
Example #1: we create a regression model with custom learners and a fixed time budget. Every hyperparameter evaluation is stopped after 10 seconds. After 300 seconds of training, tuning is stopped and the best result obtained so far is returned:
automl_model = AutoML(
task=tsk("mtcars"),
learner_list=c("regr.ranger", "regr.lm"),
learner_timeout=10,
runtime=300)
Example #2: we create a pipeline of preprocessing operators using mlr3pipelines. This pipeline replaces the default preprocessing in mlr3automl. You can tune the choice of custom preprocessing operators and their associated hyperparameters by supplying additional parameters (see example #3).
library(mlr3pipelines)
imbalanced_preproc = po("imputemean") %>>%
po("smote") %>>%
po("classweights", minor_weight=2)
automl_model = AutoML(task=tsk("pima"),
preprocessing = imbalanced_preproc)
Example #3: we add a k-nearest-neighbors classifier to the learners, which has no pre-defined hyperparameter search space in mlr3automl. To perform hyperparameter tuning, we supply a parameter set using paradox. A parameter transformation is supplied in order to sample the hyperparameter on an exponential scale.
library(paradox)
new_params = ParamSet$new(list(
ParamInt$new("classif.kknn.k",
lower = 1, upper = 5, default = 3, tags = "kknn")))
my_trafo = function(x, param_set) {
if ("classif.kknn.k" %in% names(x)) {
x[["classif.kknn.k"]] = 2^x[["classif.kknn.k"]]
}
return(x)
}
automl_model = AutoML(
task=tsk("iris"),
learner_list="classif.kknn",
additional_params=new_params,
custom_trafo=my_trafo)
mlr3automl tackles the challenge of Automated Machine Learning from multiple angles. We tested mlr3automl on 39 challenging data sets from the AutoML Benchmark. By including up to 12 preprocessing steps, mlr3automl is robust in the presence of missing data, categorical and high-cardinality features, large data sets and constrained time budgets. We evaluated many learning algorithms and implementations in order to find the most stable and accurate learners for mlr3automl, and settled on the Logistic Regression, Random Forest and Gradient Boosting learners mentioned above.
When selecting the best model for a data set, up to 8 predefined pipelines are evaluated first. These are our most robust, fast and accurate pipelines, which provide us with a strong baseline even on the most challenging tasks.
We use Hyperband (Li et al. 2017) to tune the hyperparameters of our machine learning pipeline. Hyperband speeds up random search through adaptive resource allocation and early stopping. When tuning the hyperparameters in mlr3automl, many learners are first evaluated on small subsets of the data set, which is quick. Later on, fewer models are trained on larger subsets or on the full data set, which is computationally more expensive. This allows us to find promising pipelines at little computational cost.
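The idea can be illustrated with a small calculation. This is a sketch of the successive-halving logic that Hyperband builds on, with illustrative numbers (not mlr3automl's internal settings): after each rung, only the best 1/eta of the configurations survive, while the data budget grows by a factor of eta.

```r
eta = 3        # reduction factor between rungs
n = 27         # configurations evaluated at the first rung
budget = 1/27  # fraction of the data each configuration sees initially

while (n >= 1) {
  cat(sprintf("%2d configs trained on %3.0f%% of the data\n", n, 100 * budget))
  n = floor(n / eta)           # keep roughly the best third
  budget = min(1, budget * eta) # triple the data budget, capped at 100%
}
```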
We benchmarked mlr3automl on the AutoML Benchmark, which contains 39 challenging classification data sets. Under a restrictive time budget of at most 1 hour per task, mlr3automl successfully completed every single task.

| Framework        | avg. rank (binary tasks) | avg. rank (multi-class) | Failures |
|:-----------------|:-------------------------|:------------------------|---------:|
| AutoGluon        | 2.32                     | 2.09                    |        0 |
| autosklearn v0.8 | 3.34                     | 3.15                    |        2 |
| autosklearn v2.0 | 3.57                     | 4.06                    |        9 |
| H2O AutoML       | 3.18                     | 3.24                    |        2 |
| mlr3automl       | 4.55                     | 3.79                    |        0 |
| TPOT             | 4.05                     | 4.68                    |        6 |