Tuning

This worked example demonstrates how to optimize the hyperparameters of a learner for a given task, and how to explore the effect of each hyperparameter on the resulting performance.

library("mlr3")
library("mlr3tuning") # fitnessfunction and tuning
library("mlr3learners") # needed to get the random forest learner
library("paradox") # needed to define parameter spaces
# Fetch the "spam" task and stratify any resampling on its target column.
task <- mlr_tasks$get("spam")
task$col_roles$stratum <- task$target_names

# Random forest classifier, evaluated with 5-fold cross-validation.
learner <- mlr_learners$get("classif.ranger")
resampling <- mlr_resamplings$get("cv", folds = 5)

# Printing the learner's full parameter set shows that it is unconstrained and
# a bit too big for effective tuning.
learner$param_set

# Build a smaller search space instead: a few forest sizes, a bounded `mtry`,
# and the `splitrule` parameter taken over from the learner's own param set.
param_set <- ParamSet$new(
  params = list(
    ParamFct$new("num.trees", levels = c("100", "250", "500", "1000")),
    ParamInt$new("mtry", lower = 1, upper = ceiling(task$ncol / 2))
  )
)
param_set$add(learner$param_set$params$splitrule)

# `num.trees` is declared as a factor over string levels, so the
# transformation converts the selected level to an integer.
param_set$trafo <- function(x, param_set) {
  x$num.trees <- as.integer(x$num.trees)
  x
}

# Define the tuning objective from the task, learner, resampling and search
# space set up above.
fitness_function <- FitnessFunction$new(
  task = task,
  learner = learner,
  resampling = resampling,
  param_set = param_set
)

# We want 50 random evaluations, but no more than 5 minutes of runtime.
terminator1 <- TerminatorEvaluations$new(50)
terminator2 <- TerminatorRuntime$new(max_time = 5, units = "mins")
terminator <- TerminatorMultiplexer$new(
  terminators = list(terminator1, terminator2)
)

# Put everything together: random search, evaluated in batches of 5.
tuner <- TunerRandomSearch$new(
  ff = fitness_function,
  terminator = terminator,
  batch_size = 5
)

# Run the tuning in parallel.
# The "multisession" strategy is used because the future package has
# deprecated the former "multiprocess" strategy.
set.seed(1)
future::plan("multisession")
tuner$tune()

# Let's look at the results.
tuner$tune_result()

library(mlr3viz)
# `fitness_function$bmr` is the same object as `tuner$ff$bmr` (pointer!).
autoplot(fitness_function$bmr)

library(ggplot2)
# Reshape the aggregated results to long format, with one row per
# (configuration, hyperparameter) pair, keeping `splitrule` as an id column.
# melt() is namespace-qualified because this script never attaches
# data.table, so an unqualified call would not be found.
md <- data.table::melt(
  tuner$aggregate(),
  id.vars = c("hash", "classif.ce", "splitrule"),
  measure.vars = setdiff(param_set$ids(), "splitrule")
)

# Plot the classification error against each remaining hyperparameter,
# coloured by split rule, with one facet per hyperparameter.
g <- ggplot(md, aes(y = classif.ce, x = value, color = splitrule))
g <- g + geom_point()
g + facet_grid(~variable, scales = "free")
# Keep only the configuration hash, the classification error and the tuned
# hyperparameters from the aggregated tuning results.
agg <- tuner$aggregate()[
  ,
  c("hash", "classif.ce", param_set$ids()),
  with = FALSE
]

# Wrap the results in a regression task with the classification error as the
# target and `hash` as the primary key.
bmr_task <- TaskRegr$new(
  id = "bmr",
  target = "classif.ce",
  backend = DataBackendDataTable$new(data = agg, primary_key = "hash")
)

# Fit a linear model on that task and summarise it.
lin_lrn <- mlr_learners$get("regr.lm")
library(mlr3pipelines)
lin_lrn$train(task = bmr_task)
summary(lin_lrn$model)


nguyenngocbinh/mlr3_book_vi documentation built on Jan. 23, 2020, 12:28 p.m.