R/embedded_ensemble_fselect.R

#' @title Embedded Ensemble Feature Selection
#'
#' @include CallbackBatchFSelect.R
#'
#' @description
#' Ensemble feature selection using multiple learners with embedded feature selection.
#' The method aims to identify the most predictive features of a dataset by training several learners on repeated resamplings of the data and recording which features each trained learner selects.
#' Returns an [EnsembleFSResult].
#'
#' @details
#' The method first applies the user-specified initial resampling (`init_resampling`) to create **multiple subsamples** (train/test splits) of the original dataset.
#' Resampling yields diverse subsets of the data, which makes the resulting feature selection more robust.
#'
#' Each learner must support **embedded feature selection** and is trained on every train set.
#' For each combination of subsample and learner, the features selected during training are stored, and the trained learner is scored with `measure` on the corresponding test set.
#'
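#' Results are stored in an [EnsembleFSResult], whose result table has one row per combination of learner and resampling iteration, containing the test-set score, the selected `features` and their number `n_features`.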
#'
#' @param learners (list of [mlr3::Learner])\cr
#'  The learners to be used for feature selection.
#'  All learners must have the `selected_features` property, i.e. implement
#'  embedded feature selection (e.g. regularized models).
#' @param init_resampling ([mlr3::Resampling])\cr
#'  The initial resampling strategy of the data.
#'  Each train set is used to train the learners, and each corresponding test
#'  set is used to score them with `measure`.
#'  Must be a [mlr3::ResamplingSubsampling] or [mlr3::ResamplingBootstrap].
#' @param measure ([mlr3::Measure])\cr
#'  The measure used to score each trained learner on the test sets generated
#'  by `init_resampling`.
#'  If `NULL`, the default measure for the task type is used.
#' @param store_benchmark_result (`logical(1)`)\cr
#'  Whether to store the [mlr3::BenchmarkResult] in the returned [EnsembleFSResult].
#'
#' @template param_task
#'
#' @returns an [EnsembleFSResult] object.
#'
#' @source
#' `r format_bib("meinshausen2010", "hedou2024")`
#' @export
#' @examples
#' \donttest{
#'   eefsr = embedded_ensemble_fselect(
#'     task = tsk("sonar"),
#'     learners = lrns(c("classif.rpart", "classif.featureless")),
#'     init_resampling = rsmp("subsampling", repeats = 5),
#'     measure = msr("classif.ce")
#'   )
#'   eefsr
#' }
embedded_ensemble_fselect = function(
  task,
  learners,
  init_resampling,
  measure,
  store_benchmark_result = TRUE
  ) {
  assert_task(task)
  assert_learners(as_learners(learners), task = task, properties = "selected_features")
  assert_resampling(init_resampling)
  assert_choice(class(init_resampling)[1], choices = c("ResamplingBootstrap", "ResamplingSubsampling"))
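  # fall back to the default measure of the task type if `measure` is NULL, as documented
  if (is.null(measure)) measure = default_measures(task$task_type)[[1L]]
  assert_measure(measure, task = task)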
  assert_flag(store_benchmark_result)

  init_resampling$instantiate(task)

  design = benchmark_grid(
    tasks = task,
    learners = learners,
    resamplings = init_resampling
  )

  # models must be stored so that the selected features can be extracted from them
  bmr = benchmark(design, store_models = TRUE)

  # score each trained learner on its test set once; every column extracted
  # below comes from this single table, so the row order is consistent
  scores = bmr$score(measure)
  trained_learners = scores$learner

  # extract selected features
  features = map(trained_learners, function(learner) {
    learner$selected_features()
  })

  # extract n_features
  n_features = map_int(features, length)

  # remove `bmr_score` class
  class(scores) = c("data.table", "data.frame")

  set(scores, j = "features", value = features)
  set(scores, j = "n_features", value = n_features)
  setnames(scores, "iteration", "resampling_iteration")

  # remove R6 objects and bookkeeping columns from the result table
  set(scores, j = c("learner", "task", "resampling", "prediction_test",
    "task_id", "nr", "resampling_id", "uhash"), value = NULL)

  EnsembleFSResult$new(
    result = scores,
    features = task$feature_names,
    benchmark_result = if (store_benchmark_result) bmr,
    measure = measure
  )
}
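
# A minimal base-R sketch of the procedure described in @details. It is for
# illustration only: it does not call mlr3 and is not the mlr3fselect
# implementation. Assumptions:
# - the embedded selector is a hand-written lasso fitted by coordinate descent,
#   and each value of `lambda` plays the role of one learner;
# - the measure is the test-set MSE;
# - `method` mimics subsampling (without replacement) or bootstrap (test set =
#   out-of-bag rows);
# - every helper name below (soft_threshold, fit_lasso_sketch,
#   predict_lasso_sketch, ensemble_fselect_sketch) is hypothetical.

# Soft-thresholding operator of lasso coordinate descent.
soft_threshold = function(z, lambda) sign(z) * pmax(abs(z) - lambda, 0)

# Hypothetical lasso on standardized features; minimizes
# (1 / 2n) * ||y - y_mean - Z beta||^2 + lambda * ||beta||_1.
fit_lasso_sketch = function(X, y, lambda, n_iter = 100L) {
  mu = colMeans(X)
  sdev = apply(X, 2, sd)
  Z = scale(X, center = mu, scale = sdev)
  n = nrow(Z)
  y_mean = mean(y)
  r = y - y_mean                             # residual of the all-zero model
  beta = setNames(numeric(ncol(Z)), colnames(X))
  for (it in seq_len(n_iter)) {
    for (j in seq_along(beta)) {
      r = r + Z[, j] * beta[j]               # partial residual without feature j
      beta[j] = soft_threshold(sum(Z[, j] * r) / n, lambda) / (sum(Z[, j]^2) / n)
      r = r - Z[, j] * beta[j]
    }
  }
  list(beta = beta, mu = mu, sdev = sdev, y_mean = y_mean)
}

predict_lasso_sketch = function(fit, X) {
  Z = scale(X, center = fit$mu, scale = fit$sdev)
  drop(Z %*% fit$beta) + fit$y_mean
}

# One row per combination of "learner" (lambda) and resampling iteration,
# holding the test-set MSE and the features with a non-zero coefficient.
ensemble_fselect_sketch = function(X, y, lambdas = c(0.05, 0.2), repeats = 5L,
  method = c("subsampling", "bootstrap"), ratio = 2 / 3) {
  method = match.arg(method)
  n = nrow(X)
  learner_id = character(0)
  iteration = integer(0)
  mse = numeric(0)
  features = list()
  for (i in seq_len(repeats)) {
    # one train/test split, shared by all "learners" in this iteration
    train = if (method == "subsampling") {
      sample(n, size = floor(ratio * n))     # without replacement
    } else {
      sample(n, size = n, replace = TRUE)    # bootstrap sample
    }
    test = setdiff(seq_len(n), train)        # for bootstrap: out-of-bag rows
    for (lambda in lambdas) {
      fit = fit_lasso_sketch(X[train, , drop = FALSE], y[train], lambda)
      pred = predict_lasso_sketch(fit, X[test, , drop = FALSE])
      learner_id = c(learner_id, paste0("lasso_", lambda))
      iteration = c(iteration, i)
      mse = c(mse, mean((y[test] - pred)^2))
      features = c(features, list(names(fit$beta)[fit$beta != 0]))
    }
  }
  result = data.frame(learner_id = learner_id, resampling_iteration = iteration, mse = mse)
  result$features = features
  result$n_features = lengths(features)
  result
}

# Example (kept as comments so that nothing runs at top level):
# set.seed(1)
# X = matrix(rnorm(100 * 5), ncol = 5, dimnames = list(NULL, paste0("x", 1:5)))
# y = 3 * X[, "x1"] - 2 * X[, "x2"] + rnorm(100)
# res = ensemble_fselect_sketch(X, y, repeats = 5L)
# # selection frequency of each feature over all learner/subsample combinations
# table(unlist(res$features)) / nrow(res)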

mlr3fselect documentation built on April 3, 2025, 7:49 p.m.