Nothing
#' @title (experimental) Conditional Inference Tree Conditional Sampler
#'
#' @description Implements conditional sampling using conditional inference trees (ctree).
#' Builds a tree predicting target features from conditioning features, then samples
#' from the terminal node corresponding to each test observation.
#'
#' @details
#' This sampler approximates the conditional distribution \eqn{P(X_B | X_A = x_A)} by:
#' 1. Building a conditional inference tree with \eqn{X_B} as response and \eqn{X_A} as predictors
#' 2. For each test observation, finding its terminal (leaf) node in the tree
#' 3. Sampling uniformly from training observations in that same terminal node
#'
#' Conditional inference trees (ctree) use permutation tests to determine splits,
#' which helps avoid overfitting and handles mixed feature types naturally. The tree
#' partitions the feature space based on the conditioning variables, creating local
#' neighborhoods that respect the conditional distribution structure.
#'
#' **Key advantages over other samplers:**
#' - Handles mixed feature types (continuous and categorical)
#' - Non-parametric (no distributional assumptions)
#' - Automatic feature selection (splits only on informative features)
#' - Can capture non-linear conditional relationships
#' - Statistically principled splitting criteria
#'
#' **Hyperparameters** control tree complexity:
#' - `mincriterion`: Significance level for splits (higher = fewer splits)
#' - `minsplit`: Minimum observations required for a split
#' - `minbucket`: Minimum observations in terminal nodes
#'
#' This implementation is inspired by shapr's ctree approach but simplified for our
#' use case (we build trees on-demand rather than pre-computing all subsets).
#'
#' **Advantages:**
#' - Works with any feature types
#' - Robust to outliers
#' - Interpretable tree structure
#' - Handles high-dimensional conditioning
#'
#' **Limitations:**
#' - Requires model fitting (slower than kNN)
#' - Can produce duplicates if terminal nodes are small
#' - Tree building time increases with data size
#'
#' @examples
#' \donttest{
#' library(mlr3)
#' task = tgen("friedman1")$generate(n = 100)
#'
#' # Create sampler with default parameters
#' sampler = ConditionalCtreeSampler$new(task)
#'
#' # Sample features conditioned on others
#' test_data = task$data(rows = 1:5)
#' sampled = sampler$sample_newdata(
#' feature = c("important2", "important3"),
#' newdata = test_data,
#' conditioning_set = "important1"
#' )
#' }
#'
#' @references `r print_bib("hothorn_2006", "aas_2021")`
#'
#' @export
ConditionalCtreeSampler = R6Class(
"ConditionalCtreeSampler",
inherit = ConditionalSampler,
public = list(
#' @field feature_types (`character()`) Feature types supported by the sampler.
#' Will be checked against the provided [mlr3::Task] to ensure compatibility.
feature_types = c("numeric", "integer", "factor", "ordered"),
#' @field tree_cache (`environment`) Cache for fitted ctree models.
tree_cache = NULL,
#' @description
#' Creates a new ConditionalCtreeSampler.
#' @param task ([mlr3::Task]) Task to sample from.
#' @param conditioning_set (`character` | `NULL`) Default conditioning set to use in `$sample()`.
#' @param mincriterion (`numeric(1)`: `0.95`) Significance level threshold for splitting (1 - p-value).
#' Higher values result in fewer splits (simpler trees).
#' @param minsplit (`integer(1)`: `20L`) Minimum number of observations required for a split.
#' @param minbucket (`integer(1)`: `7L`) Minimum number of observations in terminal nodes.
#' @param use_cache (`logical(1)`: `TRUE`) Whether to cache fitted trees.
initialize = function(
task,
conditioning_set = NULL,
mincriterion = 0.95,
minsplit = 20L,
minbucket = 7L,
use_cache = TRUE
) {
require_package("partykit")
super$initialize(task, conditioning_set = conditioning_set)
# Initialize tree cache
self$tree_cache = new.env(parent = emptyenv())
# Extend param_set with ctree-specific parameters
self$param_set = c(
self$param_set,
paradox::ps(
mincriterion = paradox::p_dbl(lower = 0, upper = 1, default = 0.95),
minsplit = paradox::p_int(lower = 1L, default = 20L),
minbucket = paradox::p_int(lower = 1L, default = 7L),
use_cache = paradox::p_lgl(default = TRUE)
)
)
self$param_set$set_values(
mincriterion = mincriterion,
minsplit = minsplit,
minbucket = minbucket,
use_cache = use_cache
)
self$label = "Conditional Inference Tree Conditional Sampler"
}
),
private = list(
# Core ctree sampling logic implementing conditional inference tree sampling
.sample_conditional = function(data, feature, conditioning_set, ...) {
# Get training data from task
training_data = self$task$data(cols = self$task$feature_names)
# Handle marginal case (no conditioning)
if (length(conditioning_set) == 0) {
# Simple random sampling (with replacement) from training data
data[, (feature) := lapply(.SD, sample, replace = TRUE), .SDcols = feature]
return(data[, .SD, .SDcols = c(self$task$target_names, self$task$feature_names)])
}
# Get or build tree for this feature/conditioning set combination
tree = private$.get_or_build_tree(feature, conditioning_set)
# Add temporary ID to track original rows
data[, ..eval_id := seq_len(.N)]
# Predict terminal node for each observation
# Use type = "node" to get node IDs
node_ids = predict(tree, newdata = data, type = "node")
data[, ..node_id := node_ids]
# Get terminal node memberships for training data
train_nodes = predict(tree, newdata = training_data, type = "node")
train_data_with_nodes = copy(training_data)
train_data_with_nodes[, ..node_id := train_nodes]
# For each evaluation observation, sample from training observations in same node
result = data[,
{
# Get training observations in this node
node_obs = train_data_with_nodes[..node_id == ..node_id[1]]
if (nrow(node_obs) == 0) {
cli::cli_warn("No training observations in terminal node, using marginal sample")
sampled_values = lapply(feature, function(feat) {
sample(training_data[[feat]], 1)
})
} else {
# Sample one observation from this node
sampled_idx = sample(seq_len(nrow(node_obs)), 1)
sampled_values = lapply(feature, function(feat) {
node_obs[[feat]][sampled_idx]
})
}
# Create result row
result_row = .SD[1]
for (i in seq_along(feature)) {
set(result_row, j = feature[i], value = sampled_values[[i]])
}
result_row
},
by = ..eval_id
]
# Clean up temporary columns
result[, ..eval_id := NULL]
result[, ..node_id := NULL]
result[, .SD, .SDcols = c(self$task$target_names, self$task$feature_names)]
},
# Get cached tree or build a new one
.get_or_build_tree = function(feature, conditioning_set) {
# Create cache key from feature names and conditioning set
cache_key = paste(
paste(sort(match(feature, self$task$feature_names)), collapse = "+"),
"|",
paste(sort(match(conditioning_set, self$task$feature_names)), collapse = "+")
)
use_cache = self$param_set$values$use_cache
# Check cache
if (use_cache && exists(cache_key, envir = self$tree_cache)) {
return(get(cache_key, envir = self$tree_cache))
}
# Build new tree
tree = private$.build_tree(feature, conditioning_set)
# Cache if enabled
if (use_cache) {
assign(cache_key, tree, envir = self$tree_cache)
}
tree
},
# Build a ctree model
.build_tree = function(feature, conditioning_set) {
# Get training data from task
training_data = self$task$data(cols = self$task$feature_names)
# Create formula: features ~ conditioning_set
# For multiple features, use cbind on LHS
if (length(feature) == 1) {
formula_str = sprintf(
"`%s` ~ %s",
feature,
paste(sprintf("`%s`", conditioning_set), collapse = " + ")
)
} else {
formula_str = sprintf(
"cbind(%s) ~ %s",
paste(sprintf("`%s`", feature), collapse = ", "),
paste(sprintf("`%s`", conditioning_set), collapse = " + ")
)
}
formula = as.formula(formula_str)
# Extract control parameters
mincriterion = self$param_set$values$mincriterion
minsplit = self$param_set$values$minsplit
minbucket = self$param_set$values$minbucket
# Build tree with partykit::ctree
tree = partykit::ctree(
formula = formula,
data = training_data,
control = partykit::ctree_control(
mincriterion = mincriterion,
minsplit = minsplit,
minbucket = minbucket
)
)
tree
}
)
)
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.