mlr3summary playground

knitr::opts_chunk$set(echo = TRUE)

options(crayon.enabled = TRUE)
ansi_aware_handler <- function(x, options) {
  paste0(
    "<pre class=\"r-output\"><code>",
    fansi::sgr_to_html(x = x, warn = FALSE, term.cap = "256"),
    "</code></pre>"
  )
}

knitr::knit_hooks$set(
  output = ansi_aware_handler, 
  message = ansi_aware_handler, 
  warning = ansi_aware_handler,
  error = ansi_aware_handler
)

Load the necessary packages

library(mlr3)
library(mlr3learners)
library(mlr3filters)
library(mlr3pipelines)

DEVELOP = TRUE

# Developer mode?
if (!DEVELOP) {
  library(mlr3summary)
} else {
  library(devtools)
  load_all()
}

set.seed(1812L)

Regression example

The first example is based on the mtcars task available in mlr3. We train a regression tree on it.

tsk_cars = tsk("mtcars")
lrn_rpart = lrn("regr.rpart")
lrn_rpart$train(task = tsk_cars)

We can obtain an overview of the model with summary().

summary(lrn_rpart)

Because no hold-out (test) data is available, the output omits performance results and related sections: evaluating performance on the training data can be overly optimistic due to overfitting. To obtain performance results, we run a resampling so that the evaluation is based on hold-out data, and pass the resulting ResampleResult (here rr_1) to summary() as an additional argument.

rsmp_cv3 = rsmp("cv", folds = 3)
rr_1 = resample(tsk_cars, lrn_rpart, rsmp_cv3, store_model = TRUE)
summary(lrn_rpart, rr_1)
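The gap between training and hold-out performance can be checked directly. The following is a minimal sketch that uses only mlr3 and compares the RMSE on the training data with the cross-validated RMSE; the object names (task, learner, rr) are chosen for this illustration, and the exact numbers depend on the seed and the folds.

```r
library(mlr3)

set.seed(1812L)
task = tsk("mtcars")
learner = lrn("regr.rpart")

# Score on the same data the model was trained on (tends to be optimistic).
learner$train(task)
train_rmse = learner$predict(task)$score(msr("regr.rmse"))

# Score on hold-out data via 3-fold cross-validation.
rr = resample(task, learner, rsmp("cv", folds = 3))
cv_rmse = rr$aggregate(msr("regr.rmse"))

c(train = unname(train_rmse), cv = unname(cv_rmse))
```

The training score is typically the smaller of the two, which is why summary() reports performance only when a ResampleResult is supplied.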

Multiple performance measures can be displayed by passing them via mlr3::msrs():

summary(lrn_rpart, rr_1, control = summary_control(measures = msrs(c("regr.bias", "regr.mae"))))

The model can also include pre-processing steps built with the mlr3pipelines package; these are shown in the summary() output as well. In the following, features are filtered by their variance, keeping the top 50% (filter.frac = 0.5). Note that the summary() output now contains an additional paragraph on the pipeline structure.

graph_learner = as_learner(po("filter", filter = mlr3filters::flt("variance"), filter.frac = 0.5) %>>%
  po("learner", mlr3::lrn("regr.rpart")))
graph_learner$train(tsk_cars)
rr_2 = resample(tsk_cars, graph_learner, rsmp_cv3, store_model = TRUE)
summary(graph_learner, rr_2)
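To make the filter step concrete, the following base-R sketch mimics what a variance filter with filter.frac = 0.5 amounts to on the mtcars data. It assumes that the fraction is turned into a feature count by rounding; the internals of mlr3filters may differ in detail, and the variable names are chosen for this illustration.

```r
# All columns except the target mpg are candidate features.
features = setdiff(names(mtcars), "mpg")

# Variance of each feature (all mtcars columns are numeric).
variances = sapply(mtcars[features], var)

# Keep the top 50% of features, ranked by variance.
n_keep = round(0.5 * length(features))
kept = names(sort(variances, decreasing = TRUE))[seq_len(n_keep)]
kept
```

Because the variance is scale dependent, features measured in large units (such as disp and hp) rank highest.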

Here, graph_learner is a GraphLearner object created with as_learner(). Calling as_learner() is strictly necessary, because trained Graphs currently cannot be processed:

graph = po("filter", filter = mlr3filters::flt("variance"), filter.frac = 0.5) %>>%
  po("learner", mlr3::lrn("regr.rpart"))
graph$train(tsk_cars)
rr_3 = resample(tsk_cars, graph, rsmp_cv3, store_model = TRUE)
summary(graph, rr_3)
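The difference between the two object types can be inspected with inherits(). This is a small sketch under the assumption that mlr3, mlr3pipelines, and mlr3filters are installed; the names graph_plain and graph_wrapped are chosen for this illustration.

```r
library(mlr3)
library(mlr3filters)
library(mlr3pipelines)

# A plain Graph: a pipeline object, but not a Learner.
graph_plain = po("filter", filter = flt("variance"), filter.frac = 0.5) %>>%
  po("learner", lrn("regr.rpart"))
inherits(graph_plain, "Graph")
inherits(graph_plain, "Learner")

# Wrapping the Graph yields a GraphLearner, which is a Learner.
graph_wrapped = as_learner(graph_plain)
inherits(graph_wrapped, "GraphLearner")
inherits(graph_wrapped, "Learner")
```

Only the wrapped object is a Learner, which is the kind of object the examples above pass to summary().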

Currently, only linear pipelines can be displayed in summary(); more complex, non-linear structures are displayed only as <suppressed>.

set.seed(1234L)
graph_complex = po("scale", center = TRUE, scale = FALSE) %>>%
  gunion(list(
    po("missind"),
    po("imputemedian")
  )) %>>%
  po("featureunion") %>>%
  po("learner", mlr3::lrn("regr.rpart"))
graph_complex = as_learner(graph_complex)
graph_complex$train(tsk_cars)
rr_4 = resample(tsk_cars, graph_complex, rsmp_cv3, store_model = TRUE)
summary(graph_complex, rr_4)

Multiple importance measures are also possible:

summary(lrn_rpart, rr_1, control = summary_control(importance_measures = c("pfi.rmse", "pdp")))
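For intuition about what a permutation feature importance based on RMSE measures, here is a minimal sketch using rpart directly. It is not the mlr3summary implementation: for brevity it evaluates on the training data, whereas summary() works with the hold-out data of the ResampleResult, and the helper rmse() and all variable names are assumptions made for this illustration.

```r
library(rpart)

set.seed(1812L)
fit = rpart(mpg ~ ., data = mtcars)

# Root mean squared error between observed and predicted values.
rmse = function(truth, response) sqrt(mean((truth - response)^2))
baseline = rmse(mtcars$mpg, predict(fit, mtcars))

# Importance of a feature: increase in RMSE after shuffling its column.
features = setdiff(names(mtcars), "mpg")
pfi = sapply(features, function(f) {
  permuted = mtcars
  permuted[[f]] = sample(permuted[[f]])
  rmse(mtcars$mpg, predict(fit, permuted)) - baseline
})
sort(pfi, decreasing = TRUE)
```

Features that the tree does not split on receive an importance of zero, because shuffling them leaves the predictions unchanged.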

Display only the n_important = 3L most important features:

summary(lrn_rpart, rr_1, control = summary_control(importance_measures = c("pdp", "pfi.rmse"), n_important = 3L))

Micro and macro versions of the performance measures can also be computed by setting the average argument of the measures passed to summary_control().

perfms = c(msr("regr.bias", id = "regr.bias.mi", average = "micro"),
  msr("regr.bias", id = "regr.bias.ma", average = "macro"))
summary(lrn_rpart, rr_1, control = summary_control(measures = perfms))
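The difference between the two averaging modes is easiest to see on a toy example. The sketch below uses made-up hold-out predictions from three folds of unequal size and assumes the bias is defined as the mean of truth minus response; macro averaging scores each fold and then averages the scores, while micro averaging pools all predictions and scores once.

```r
# Hypothetical hold-out predictions from three folds of unequal size.
truth    = c(10, 12, 14, 20, 22, 30)
response = c(11, 13, 15, 18, 21, 33)
fold     = c(1, 1, 1, 2, 2, 3)

# Bias, assumed here to be the mean of (truth - response).
bias = function(truth, response) mean(truth - response)

# Macro: one score per fold, then the mean of the fold scores.
per_fold = tapply(seq_along(truth), fold, function(i) bias(truth[i], response[i]))
macro = mean(per_fold)

# Micro: pool all predictions, then compute a single score.
micro = bias(truth, response)

c(micro = micro, macro = macro)
```

Here the micro average is -0.5 and the macro average is about -0.83; the two coincide for the bias only when all folds have the same size.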

Classification example

The second example is based on the iris task available in mlr3 (stored here as tsk_peng). We train an xgboost model on it.

tsk_peng = tsk("iris")
lrn_xgboost = lrn("classif.xgboost", predict_type = "prob")
lrn_xgboost$train(task = tsk_peng)

We can obtain an overview of the model with summary().

summary(lrn_xgboost)

To obtain performance results, we run a resampling so that the evaluation is based on hold-out data; here we use bootstrap resampling. We pass the resulting ResampleResult (here rr_5) to summary() as an additional argument.

rsmp_bs = rsmp("bootstrap", repeats = 5L)
rr_5 = resample(tsk_peng, lrn_xgboost, rsmp_bs, store_model = TRUE)
summary(lrn_xgboost, rr_5)

Binary classification

The third example is based on the breast_cancer task available in mlr3. We train a ranger model on it.

tsk_bc = tsk("breast_cancer")
lrn_ranger = lrn("classif.ranger", predict_type = "prob")
lrn_ranger$train(task = tsk_bc)
rsmp_ss = rsmp("subsampling", repeats = 2L, ratio = 0.5)
rr_6 = resample(tsk_bc, lrn_ranger, rsmp_ss, store_model = TRUE)
summary(lrn_ranger, rr_6, control = summary_control(importance_measures = "shap"))

Fairness task

The following demonstrates how to obtain fairness metrics. For this, a protected attribute (pta) needs to be specified. The fairness measure can be changed via control.

library("mlr3fairness")
tsk_peng = tsk("penguins")
tsk_peng$set_col_roles("sex", add_to = "pta")
lrn_rpart = lrn("classif.rpart", predict_type = "prob")
lrn_rpart$train(task = tsk_peng)
rsmp_cv5 = rsmp("cv", folds = 5L)
rr_9 = resample(tsk_peng, lrn_rpart, rsmp_cv5, store_model = TRUE)
summary(lrn_rpart, rr_9, control = summary_control(
  complexity_measures = NULL,
  effect_measures = NULL,
  fairness_measures = msr("fairness", operation = groupdiff_absdiff, base_measure = msr("classif.acc"))
))

An additional paragraph was added in the output for fairness assessment.
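As a rough illustration of what a fairness measure built from an absolute group difference and accuracy expresses, the following base-R sketch computes the accuracy separately for two groups of a protected attribute and takes the absolute difference. The data are made up, and the sketch only covers the two-group case; it is not the mlr3fairness implementation.

```r
# Hypothetical labels, predictions, and protected attribute.
truth    = c("a", "a", "b", "b", "a", "b", "a", "b")
response = c("a", "b", "b", "b", "a", "a", "a", "a")
group    = c("female", "female", "female", "female",
             "male", "male", "male", "male")

# Accuracy within each group of the protected attribute.
acc_by_group = tapply(truth == response, group, mean)

# Absolute difference between the two group accuracies.
abs_diff = abs(acc_by_group[["female"]] - acc_by_group[["male"]])
abs_diff
```

In this toy data the accuracy is 0.75 for one group and 0.5 for the other, so the difference is 0.25; a value of 0 would mean both groups are classified equally well.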

Error handling

Input checks ensure that the ML algorithm and the task of the model (object) match those of the resample_result. As an example, we try to summarize lrn_xgboost, which was trained on tsk_peng, using rr_1 as an additional input; rr_1, however, is based on the regression tree (regr.rpart) trained on tsk_cars.

summary(lrn_xgboost, rr_1)
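In a script, such a failing call can be wrapped in tryCatch() so that the error is captured instead of stopping execution. The sketch below is self-contained and assumes that mlr3 and mlr3summary are installed; it uses a classification tree on iris in place of xgboost, and the names rr_cars, lrn_iris, and result are chosen for this illustration.

```r
library(mlr3)
library(mlr3summary)

# Resample a regression tree on mtcars ...
rr_cars = resample(tsk("mtcars"), lrn("regr.rpart"),
  rsmp("cv", folds = 3), store_model = TRUE)

# ... but train a classification tree on iris.
lrn_iris = lrn("classif.rpart", predict_type = "prob")
lrn_iris$train(tsk("iris"))

# Capture the condition raised by the input checks.
result = tryCatch(
  summary(lrn_iris, rr_cars),
  error = function(e) e
)
if (inherits(result, "error")) {
  message("summary() rejected the inputs: ", conditionMessage(result))
}
```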



mlr3summary documentation built on May 29, 2024, 2:44 a.m.