R/FeatureSampler.R

#' @title Feature Sampler Class
#'
#' @description Base class for implementing different sampling strategies
#' for feature importance methods like PFI and CFI
#'
#' @export
FeatureSampler = R6Class(
	"FeatureSampler",
	public = list(
		#' @field task ([mlr3::Task]) Original task.
		task = NULL,
		#' @field label (`character(1)`) Name of the sampler.
		label = NULL,
		#' @field feature_types (`character()`) Feature types supported by the sampler.
		#'   Will be checked against the provided [mlr3::Task] to ensure compatibility.
		feature_types = c(
			"numeric",
			"factor",
			"ordered",
			"integer",
			"logical",
			"Date",
			"POSIXct",
			"character"
		),
		#' @field param_set ([paradox::ParamSet]) Parameter set for the sampler.
		param_set = NULL,

		#' @description
		#' Creates a new instance of the FeatureSampler class
		#' @param task ([mlr3::Task]) Task to sample from
		initialize = function(task) {
			self$task = mlr3::assert_task(task, feature_types = self$feature_types)
			# Initialize empty param_set - subclasses should define their own
			self$param_set = paradox::ps()
		},

		#' @description
		#' Sample values for feature(s) from stored task
		#' @param feature (`character`) Feature name(s) to sample (can be single or multiple). Must match those in the stored [Task][mlr3::Task].
		#' @param row_ids (`integer()`: `NULL`) Row IDs of the stored [Task][mlr3::Task] to use as basis for sampling.
		#' @return Modified copy of the input features with the feature(s) sampled:
		#'   A [data.table][data.table::data.table] with same number of columns and one row matching the supplied `row_ids`
		sample = function(feature, row_ids = NULL) {
			cli::cli_abort(c(
				"Abtract method",
				i = "Use a concrete implementation."
			))
		},
		#' @description
		#' Sample values for feature(s) using external data
		#' @param feature (`character`) Feature name(s) to sample (can be single or multiple)
		#' @param newdata ([`data.table`][data.table::data.table] ) External data to use for sampling.
		sample_newdata = function(feature, newdata) {
			cli::cli_abort(c(
				"Abtract method",
				i = "Use a concrete implementation."
			))
		},

		#' @description
		#' Print sampler
		#'
		#' @param ... Ignored.
		print = function(...) {
			cli::cli_h1(self$label)
			cli::cli_ul()
			cli::cli_li(
				"Task: {.val {self$task$id}} ({.strong {self$task$nrow}x{self$task$n_features}})"
			)
			cli::cli_li("Supported feature type{?s}: {.val {self$feature_types}}")
			cli::cli_end()
		}
	),
	private = list(
		.get_task_data_by_row_id = function(row_ids) {
			if (!checkmate::test_subset(row_ids, self$task$row_ids)) {
				cli::cli_abort(c(
					x = "Requested {.val {length(setdiff(row_ids, self$task$row_ids))}} row_id{?s} not in stored {.cls Task}.",
					"!" = "For {.code $sample}, the row_ids must match those of the stored {.cls Task}."
				))
			}
			self$task$data(rows = row_ids)
		},

		# Ensure feature types match task specification, fixing what can be fixed
		# and asserting correctness. Only modifies data when necessary.
		# mlr3 feature types correspond to R's class(), not typeof()
		.ensure_feature_types = function(data) {
			feature_types <- self$task$feature_types

			for (feat in self$task$feature_names) {
				expected_type <- feature_types[id == feat, type]
				actual_class <- class(data[[feat]])[1]

				if (expected_type == actual_class) {
					next
				}

				# Attempt to fix known fixable cases
				if (expected_type == "integer" && actual_class == "numeric") {
					set(data, j = feat, value = as.integer(round(data[[feat]])))
					next
				}

				# Type mismatch that cannot be automatically fixed
				cli::cli_abort(c(
					x = "Feature {.val {feat}} has class {.val {actual_class}} but expected {.val {expected_type}}.",
					i = "This indicates a bug in the sampler implementation."
				))
			}

			data
		}
	)
)

Try the xplainfi package in your browser

Any scripts or data that you put into this service are public.

xplainfi documentation built on Feb. 27, 2026, 1:08 a.m.