#' @title Abstract Surrogate Model Filtering Base Class
#'
#' @include Filtor.R
#'
#' @description
#' Abstract base class for surrogate model filtering.
#'
#' A *surrogate model* is a regression model, based on an [`mlr3::Learner`], which predicts the approximate performance of newly sampled configurations
#' given the empirical performance of already evaluated configurations. The surrogate model can be used to propose points that have, according to the
#' surrogate model, a relatively high chance of performing well.
#'
#' The `FiltorSurrogate` base class can be inherited from to create different [`Filtor`]s that filter based on a surrogate model, for example tournament
#' filtering or progressive filtering.
#'
#' @section Configuration Parameters:
#' `FiltorSurrogate`'s configuration parameters are the hyperparameters of the `surrogate_learner` [`Learner`][mlr3::Learner], as well as
#' the configuration parameters of the `surrogate_selector` [`Selector`].
#'
#' @section Supported Operand Types:
#'
#' Supported [`Domain`][paradox::Domain] classes depend on the supported feature types of the `surrogate_learner`, as reported
#' by `surrogate_learner$feature_types`: `"ParamInt"` requires
#' `"integer"`, `"ParamDbl"` requires `"numeric"`, `"ParamLgl"` requires `"logical"`, and `"ParamFct"` requires `"factor"`.
#'
#' @template param_surrogate_learner
#' @template param_surrogate_selector
#' @param param_set ([`ParamSet`][paradox::ParamSet])\cr
#' [`ParamSet`][paradox::ParamSet] of the method implemented in the inheriting class with configuration parameters that go beyond the
#' parameters of the `surrogate_learner` and `surrogate_selector`.
#' @template param_packages
#' @template param_dict_entry
#'
#' @family base classes
#' @family filtors
#'
FiltorSurrogate = R6Class("FiltorSurrogate",
inherit = Filtor,
public = list(
#' @description
#' Initialize the base class components of the `FiltorSurrogate`.
#' @template param_surrogate_learner
#' @template param_surrogate_selector
#' @template param_param_set
#' @template param_packages
#' @template param_dict_entry
initialize = function(surrogate_learner, surrogate_selector = SelectorBest$new(), param_set = ps(), packages = character(0), dict_entry = NULL) {
private$.surrogate_learner = mlr3::as_learner(surrogate_learner, clone = TRUE)
# can't assert LearnerRegr because GraphLearner doesn't announce that. Instead, we check $task_type
assert_true(private$.surrogate_learner$task_type == "regr", .var.name = 'surrogate_learner$task_type == "regr"')
private$.surrogate_selector = assert_r6(surrogate_selector, "Selector")$clone(deep = TRUE)
private$.own_param_set = param_set
if (!paradox_s3) {
private$.surrogate_selector$param_set$set_id = "select"
private$.own_param_set$set_id = "filter"
}
private$.own_param_set_id = "filter"
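      # supported Domain classes follow from the feature types the surrogate learner can handle,
      # intersected with the classes supported by the surrogate selector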
param_classes = c("ParamInt", "ParamDbl", "ParamLgl", "ParamFct")
param_classes = param_classes[c("integer", "numeric", "logical", "factor") %in% surrogate_learner$feature_types]
param_classes = intersect(param_classes, surrogate_selector$param_classes)
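      # expose the filtor's own configuration parameters ("filter"), the selector's configuration
      # parameters ("select"), and the surrogate learner's hyperparameters through this operator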
super$initialize(param_classes, alist(filter = private$.own_param_set,
select = private$.surrogate_selector$param_set, private$.surrogate_learner$param_set),
supported = surrogate_selector$supported,
packages = c("mlr3", surrogate_selector$packages, surrogate_learner$packages, packages),
dict_entry = dict_entry, own_param_set = quote(private$.own_param_set)
)
},
#' @description
    #' See [`MiesOperator`] method. Primes both this operator and the [`Selector`] given to
    #' `surrogate_selector` during construction.
#' @param param_set ([`ParamSet`][paradox::ParamSet])\cr
#' Passed to [`MiesOperator`]`$prime()`.
#' @return [invisible] `self`.
prime = function(param_set) {
private$.surrogate_selector$prime(param_set)
if (param_set$has_deps && "missings" %nin% private$.surrogate_learner$properties) {
stop("Surrogate learner %s needs to handle missing values for search space with dependencies", private$.surrogate_learner$id)
}
super$prime(param_set)
invisible(self)
}
),
active = list(
#' @field surrogate_learner ([`mlr3::LearnerRegr`])\cr
#' Regression learner for the surrogate model filtering algorithm.
surrogate_learner = function(rhs) {
if (!missing(rhs) && !identical(rhs, private$.surrogate_learner)) {
stop("surrogate_learner is read-only.")
}
private$.surrogate_learner
},
#' @field surrogate_selector ([`Selector`])\cr
    #' [`Selector`] with which to select using surrogate-predicted performance.
surrogate_selector = function(rhs) {
if (!missing(rhs) && !identical(rhs, private$.surrogate_selector)) {
stop("surrogate_selector is read-only.")
}
private$.surrogate_selector
}
),
private = list(
.filter = function(values, known_values, fitnesses, n_filter) {
params = private$.own_param_set$get_values()
primed = self$primed_ps
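      # only consider as many candidate values as this filtor needs; if there is no surplus,
      # all candidates are kept and no surrogate model needs to be fitted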
values = first(values, self$needed_input(n_filter))
if (nrow(values) == n_filter) return(seq_len(n_filter))
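      # pick a name for the temporary fitness target column that does not clash with any
      # column already present in known_values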
fcolname = "fitnesses"
while (fcolname %in% colnames(known_values)) {
fcolname = paste0(".", fcolname)
}
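      # fit one surrogate model per objective (column of `fitnesses`) on the already evaluated
      # configurations, then predict the performance of the candidate `values` with it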
surrogate_prediction = apply(fitnesses, 2, function(f) {
known_values[[fcolname]] = f
self$surrogate_learner$train(
mlr3::TaskRegr$new("surrogate", with_factor_cols(known_values, primed), target = fcolname)
)$predict_newdata(with_factor_cols(values, primed))$data$response
})
# when things are one-dimensional they cease to be a matrix, so we force it here.
surrogate_prediction = matrix(surrogate_prediction, nrow = nrow(values), ncol = ncol(fitnesses))
private$.filter_surrogate(values, surrogate_prediction, known_values, fitnesses, n_filter)
},
.filter_surrogate = function(values, surrogate_prediction, known_values, fitnesses, n_filter) stop("abstract."),
.surrogate_learner = NULL,
.surrogate_selector = NULL,
.own_param_set = NULL
)
)
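# Helper: convert the "ParamFct" columns of `table` to factors that carry the full level sets
# declared in `param_set`, so the surrogate learner sees consistent factor levels for training
# and prediction data.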
with_factor_cols = function(table, param_set) {
table = copy(table)
pclass = param_set$class
fcols = names(pclass)[pclass == "ParamFct"]
plevels = param_set$levels
for (col in fcols) {
set(table, , col, factor(table[[col]], plevels[[col]]))
}
table
}
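
# The following is a minimal illustrative sketch of a concrete subclass, not part of the package
# API: it keeps the `n_filter` candidates with the largest surrogate-predicted value of the first
# objective. The `.needed_input()` hook and its doubling policy are assumptions about the `Filtor`
# base class interface; real subclasses, such as progressive or tournament filtering, implement
# more elaborate strategies in `.filter_surrogate()`.
FiltorSurrogateExampleSketch = R6Class("FiltorSurrogateExampleSketch",
  inherit = FiltorSurrogate,
  private = list(
    # request twice as many candidate values as will eventually be kept
    .needed_input = function(output_size) output_size * 2,
    # keep the candidates with the best predicted performance on the first objective (maximizing)
    .filter_surrogate = function(values, surrogate_prediction, known_values, fitnesses, n_filter) {
      order(surrogate_prediction[, 1], decreasing = TRUE)[seq_len(n_filter)]
    }
  )
)
# Hypothetical usage (requires the 'rpart' package for the example learner):
# fs = FiltorSurrogateExampleSketch$new(mlr3::lrn("regr.rpart"))
# fs$prime(ps(x = p_dbl(-1, 1), y = p_dbl(-1, 1)))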