Nothing
#' Fit and evaluate with leakage guards over predefined splits
#'
#' Performs cross-validated model training and evaluation using
#' leakage-protected preprocessing (.guard_fit) and user-specified learners.
#'
#' @param x SummarizedExperiment or matrix/data.frame
#' @param outcome outcome column name (if x is SE or data.frame), or a length-2
#' character vector of time/event column names for survival outcomes.
#' @param splits LeakSplits object from make_split_plan(), or an `rsample` rset/rsplit.
#' @param split_cols Optional named list/character vector or `"auto"` (default)
#' overriding group/batch/study/time column names when `splits` is an rsample
#' object and its attributes are missing. `"auto"` falls back to common
#' metadata column names (e.g., `group`, `subject`, `batch`, `study`, `time`).
#' Supported names are `group`, `batch`, `study`, and `time`.
#' @param store_refit_data Logical; when TRUE (default), stores the original
#' data and learner configuration inside the fit to enable refit-based
#' permutation tests without manual `perm_refit_spec` setup.
#' @param preprocess list(impute, normalize, filter=list(...), fs) or a
#' `recipes::recipe` object. When a recipe is supplied, the guarded preprocessing
#' pipeline is bypassed and the recipe is prepped on training data only.
#' Recipe/workflow leakage guardrails run before fitting; configure policy via
#' \code{options(bioLeak.validation_mode = "warn" | "error" | "off")}.
#' @param learner parsnip model_spec (or list of model_spec objects) describing
#' the model(s) to fit, or a `workflows::workflow`. For legacy use, a character
#' vector of learner names (e.g., "glmnet", "ranger") or custom learner IDs is
#' still supported.
#' @param learner_args list of additional arguments passed to legacy learners
#' (ignored when `learner` is a parsnip model_spec).
#' @param custom_learners named list of custom learner definitions used only
#' with legacy character learners. Each entry
#' must contain \code{fit} and \code{predict} functions. The \code{fit} function
#' should accept \code{x}, \code{y}, \code{task}, and \code{weights}, and return
#' a model object. The \code{predict} function should accept \code{object},
#' \code{newdata}, and \code{task}. For binomial/regression/survival tasks it
#' should return a numeric vector; for multiclass tasks it should return either
#' class labels or a matrix/data.frame of class probabilities.
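#' @param metrics named list of metric functions, vector of metric names, or a
#' `yardstick::metric_set`. Recognised metric names are `"auc"`, `"pr_auc"`,
#' and `"accuracy"` for binomial tasks; `"accuracy"`, `"macro_f1"`, and
#' `"log_loss"` for multiclass tasks; `"rmse"` and `"cindex"` for gaussian
#' tasks; and `"cindex"` for survival tasks. Names that do not apply to the
#' detected task are dropped with a warning; if none remain (or `metrics` is
#' `NULL`), the task defaults are used. When a yardstick metric set (or list of
#' yardstick metric functions) is supplied, metrics are computed using yardstick
#' with the positive class set to the second factor level; yardstick metrics are
#' not supported for survival tasks.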
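#' @param class_weights optional named numeric vector of weights for binomial or
#' multiclass outcomes, with one entry per outcome level. An unnamed vector is
#' accepted when its length equals the number of outcome levels; its entries are
#' then matched to the levels in order. Ignored with a warning for gaussian and
#' survival tasks.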
#' @param positive_class optional value indicating the positive class for binomial outcomes.
#' When set, the outcome levels are reordered so that \code{positive_class} is treated
#' as the positive class (level 2). If NULL, the second factor level is used.
#' @param classification_threshold Numeric threshold in \code{[0, 1]} used to
#' convert binomial probabilities into class predictions for
#' \code{pred_class} and accuracy metrics. Ignored for non-binomial tasks.
#' @param parallel logical, use future.apply for multicore execution
#' @param refit logical, if TRUE retrain final model on full data
#' @param seed integer, for reproducibility
#' @return A \code{\linkS4class{LeakFit}} S4 object containing:
#' \describe{
#' \item{\code{splits}}{The \code{LeakSplits} object used for resampling.}
#' \item{\code{metrics}}{Data.frame of per-fold, per-learner performance
#' metrics with columns \code{fold}, \code{learner}, and one column per
#' requested metric.}
#' \item{\code{metric_summary}}{Data.frame summarizing metrics across folds
#' for each learner with columns \code{learner}, and \code{<metric>_mean}
#' and \code{<metric>_sd} for each requested metric.}
#' \item{\code{audit}}{Data.frame with per-fold audit information including
#' \code{fold}, \code{n_train}, \code{n_test}, \code{learner}, and
#' \code{features_final} (number of features after preprocessing).}
#' \item{\code{predictions}}{List of data.frames containing out-of-fold
#' predictions with columns \code{id} (sample identifier), \code{truth}
#' (true outcome), \code{pred} (predicted value or probability), \code{fold},
#' and \code{learner}. For classification tasks, includes \code{pred_class}.
#' For multiclass, includes per-class probability columns.}
#' \item{\code{preprocess}}{List of preprocessing state objects from each fold,
#' storing imputation parameters, normalization statistics, and feature
#' selection results.}
#' \item{\code{learners}}{List of fitted model objects from each fold.}
#' \item{\code{outcome}}{Character string naming the outcome variable.}
#' \item{\code{task}}{Character string indicating the task type
#' (\code{"binomial"}, \code{"multiclass"}, \code{"gaussian"}, or
#' \code{"survival"}).}
#' \item{\code{feature_names}}{Character vector of feature names after
#' preprocessing.}
#' \item{\code{info}}{List of additional metadata including \code{hash},
#' \code{metrics_used}, \code{class_weights}, \code{positive_class},
#' \code{sample_ids}, \code{fold_status}, \code{refit}, \code{final_model} (refitted model if
#' \code{refit = TRUE}), \code{final_preprocess}, \code{learner_names},
#' and \code{perm_refit_spec} (for permutation-based audits).}
#' }
#' Use \code{summary()} to print a formatted report, or access slots directly
#' with \code{@}.
#' @details
#' Preprocessing is fit on the training fold and applied to the test fold,
#' preventing leakage from global imputation, scaling, or feature selection.
#' When a `recipes::recipe` or `workflows::workflow` is supplied, the recipe is
#' prepped on the training fold and baked on the test fold.
#' For data.frame or matrix inputs, columns used to define splits
#' (outcome, group, batch, study, time) are excluded from the predictor matrix.
#' Use \code{learner_args} to pass model-specific arguments, either as a named
#' list keyed by learner or a single list applied to all learners. For custom
#' learners, \code{learner_args[[name]]} may be a list with \code{fit} and
#' \code{predict} sublists to pass distinct arguments to each stage. For binomial
#' tasks, predictions and metrics assume the positive class is the second factor
#' level; use \code{positive_class} to control this. Use
#' \code{classification_threshold} to change the probability cutoff used for
#' class labels and accuracy. Parsnip learners must support
#' probability predictions for binomial metrics (AUC/PR-AUC/accuracy) and
#' multiclass log-loss when requested.
#' @examples
#' set.seed(1)
#' df <- data.frame(
#' subject = rep(1:10, each = 2),
#' outcome = rbinom(20, 1, 0.5),
#' x1 = rnorm(20),
#' x2 = rnorm(20)
#' )
#' splits <- make_split_plan(df, outcome = "outcome",
#' mode = "subject_grouped", group = "subject", v = 5)
#'
#' # glmnet learner (requires glmnet package)
#' fit <- fit_resample(df, outcome = "outcome", splits = splits,
#' learner = "glmnet", metrics = "auc")
#' summary(fit)
#'
#' # Custom learner (logistic regression) - no extra packages needed
#' custom <- list(
#' glm = list(
#' fit = function(x, y, task, weights, ...) {
#' stats::glm(y ~ ., data = as.data.frame(x),
#' family = stats::binomial(), weights = weights)
#' },
#' predict = function(object, newdata, task, ...) {
#' as.numeric(stats::predict(object, newdata = as.data.frame(newdata), type = "response"))
#' }
#' )
#' )
#' fit2 <- fit_resample(df, outcome = "outcome", splits = splits,
#' learner = "glm", custom_learners = custom,
#' metrics = "accuracy")
#'
#' summary(fit2)
#' @export
fit_resample <- function(x, outcome, splits,
preprocess = list(
impute = list(method = "median"),
normalize = list(method = "zscore"),
filter = list(var_thresh = 0, iqr_thresh = 0),
fs = list(method = "none")
),
learner = c("glmnet", "ranger"),
learner_args = list(),
custom_learners = list(),
metrics = c("auc", "pr_auc", "accuracy"),
class_weights = NULL,
positive_class = NULL,
classification_threshold = 0.5,
parallel = FALSE,
refit = TRUE,
seed = 1,
split_cols = "auto",
store_refit_data = TRUE) {
set.seed(seed)
.bio_strict_checks(context = "fit_resample", seed = seed)
classification_threshold_supplied <- !missing(classification_threshold)
learner_input <- learner
if (!is.numeric(classification_threshold) || length(classification_threshold) != 1L ||
!is.finite(classification_threshold) || classification_threshold < 0 ||
classification_threshold > 1) {
.bio_stop("classification_threshold must be a single numeric value in [0, 1].",
"bioLeak_input_error")
}
if (is.null(custom_learners)) custom_learners <- list()
if (!is.list(custom_learners)) {
.bio_stop("custom_learners must be a named list of learner definitions.",
"bioLeak_input_error")
}
if (length(custom_learners) && is.null(names(custom_learners))) {
.bio_stop("custom_learners must be a named list.",
"bioLeak_input_error")
}
if (length(custom_learners)) {
dup_names <- intersect(names(custom_learners), c("glmnet", "ranger"))
if (length(dup_names)) {
.bio_stop(sprintf("custom_learners cannot override built-in learners: %s",
paste(dup_names, collapse = ", ")),
"bioLeak_input_error")
}
bad <- vapply(custom_learners, function(def) {
!is.list(def) || !is.function(def$fit) || !is.function(def$predict)
}, logical(1))
if (any(bad)) {
.bio_stop("Each custom learner must be a list with `fit` and `predict` functions.",
"bioLeak_input_error")
}
}
# Predicate: does `obj` carry the parsnip "model_spec" class?
is_parsnip_spec <- function(obj) {
  inherits(x = obj, what = "model_spec")
}
# Predicate: does `obj` carry the "workflow" class?
is_workflow <- function(obj) {
  inherits(x = obj, what = "workflow")
}
use_parsnip <- FALSE
use_workflow <- FALSE
learner_specs <- NULL
learner_names <- NULL
if (is_workflow(learner)) {
use_workflow <- TRUE
learner_specs <- list(learner)
} else if (is.list(learner) && length(learner) &&
all(vapply(learner, is_workflow, logical(1)))) {
use_workflow <- TRUE
learner_specs <- learner
} else if (is_parsnip_spec(learner)) {
use_parsnip <- TRUE
learner_specs <- list(learner)
} else if (is.list(learner) && length(learner) &&
all(vapply(learner, is_parsnip_spec, logical(1)))) {
use_parsnip <- TRUE
learner_specs <- learner
}
if (use_workflow) {
if (!requireNamespace("workflows", quietly = TRUE)) {
.bio_stop("Package 'workflows' is required when passing a workflow to learner.",
"bioLeak_dependency_error")
}
if (length(custom_learners)) {
warning("custom_learners ignored when learner is a workflow.")
}
if (length(learner_args)) {
warning("learner_args ignored when learner is a workflow.")
}
learner_names <- names(learner_specs)
if (is.null(learner_names)) learner_names <- rep("", length(learner_specs))
missing_names <- !nzchar(learner_names)
if (any(missing_names)) {
fallback <- paste0("workflow_", seq_along(learner_specs))
learner_names[missing_names] <- fallback[missing_names]
}
} else if (use_parsnip) {
if (!requireNamespace("parsnip", quietly = TRUE)) {
.bio_stop("Package 'parsnip' is required when passing a model_spec to learner.",
"bioLeak_dependency_error")
}
if (length(custom_learners)) {
warning("custom_learners ignored when learner is a parsnip model_spec.")
}
if (length(learner_args)) {
warning("learner_args ignored when learner is a parsnip model_spec.")
}
# Build a display label of the form "<model class>" or "<model class>/<engine>"
# for a parsnip spec that was supplied without a name.
#
# spec:     object whose class vector and `engine` (or `method$engine`) fields
#           are inspected.
# fallback: label used when the class vector holds nothing besides
#           "model_spec".
# Returns a length-1 character label.
parsnip_label <- function(spec, fallback) {
  # A usable label part is a single, non-NA, non-empty value.
  usable <- function(v) length(v) == 1L && !is.na(v) && nzchar(v)

  model_class <- class(spec)
  model_class <- model_class[model_class != "model_spec"]
  # Indexing an empty vector with [1] gives NA_character_, and nzchar(NA) is
  # TRUE, so emptiness and NA must be checked explicitly; otherwise the
  # fallback is never reached and the label becomes "NA".
  label <- if (length(model_class) > 0L && usable(model_class[1])) {
    model_class[1]
  } else {
    fallback
  }

  engine <- NULL
  if (usable(spec$engine)) {
    engine <- spec$engine
  } else if (!is.null(spec$method) && usable(spec$method$engine)) {
    engine <- spec$method$engine
  }
  if (!is.null(engine)) label <- paste0(label, "/", engine)
  label
}
learner_names <- names(learner_specs)
if (is.null(learner_names)) learner_names <- rep("", length(learner_specs))
missing_names <- !nzchar(learner_names)
if (any(missing_names)) {
fallback <- paste0("spec_", seq_along(learner_specs))
spec_labels <- vapply(seq_along(learner_specs), function(i) {
parsnip_label(learner_specs[[i]], fallback[[i]])
}, character(1))
learner_names[missing_names] <- spec_labels[missing_names]
}
} else {
builtin_learners <- c("glmnet", "ranger")
all_learners <- c(builtin_learners, names(custom_learners))
if (!is.character(learner)) {
.bio_stop("learner must be a character vector of legacy learners, a parsnip model_spec, or a workflow.",
"bioLeak_input_error")
}
learner <- match.arg(learner, choices = all_learners, several.ok = TRUE)
learner_names <- learner
}
use_recipe <- .bio_is_recipe(preprocess)
if (use_recipe && !requireNamespace("recipes", quietly = TRUE)) {
.bio_stop("Package 'recipes' is required when preprocess is a recipe.",
"bioLeak_dependency_error")
}
if (use_workflow && isTRUE(use_recipe)) {
warning("Recipe preprocess ignored when learner is a workflow.")
use_recipe <- FALSE
}
preprocess_mode <- if (use_workflow) "workflow" else if (use_recipe) "recipe" else "guard"
validation_mode <- .bio_validation_mode()
if (use_recipe) {
.bio_validate_recipe_graph(
preprocess,
context = "fit_resample",
mode = validation_mode
)
}
if (use_workflow && length(learner_specs)) {
for (wf in learner_specs) {
.bio_validate_workflow_graph(
wf,
context = "fit_resample",
mode = validation_mode
)
}
}
Xall <- .bio_get_x(x)
Xall_raw <- Xall
yall <- .bio_get_y(x, outcome)
if (!inherits(splits, "LeakSplits")) {
if (.bio_is_rsample(splits)) {
coldata <- if (.bio_is_se(x)) {
as.data.frame(SummarizedExperiment::colData(x))
} else if (is.data.frame(x)) {
x
} else if (is.matrix(x)) {
data.frame(row_id = seq_len(nrow(Xall)))
} else {
NULL
}
splits <- .bio_as_leaksplits_from_rsample(splits, n = nrow(Xall), coldata = coldata,
split_cols = split_cols)
} else {
.bio_stop("splits must be a LeakSplits or rsample rset/rsplit.",
"bioLeak_input_error")
}
}
drop_cols <- outcome
if (inherits(splits, "LeakSplits")) {
split_info <- splits@info
drop_cols <- unique(c(drop_cols,
split_info$group,
split_info$batch,
split_info$study,
split_info$time))
}
drop_cols <- drop_cols[!is.na(drop_cols) & nzchar(drop_cols)]
if (length(drop_cols) && !is.null(colnames(Xall))) {
drop_cols <- intersect(colnames(Xall), drop_cols)
}
if (length(drop_cols)) {
Xall <- Xall[, setdiff(colnames(Xall), drop_cols), drop = FALSE]
}
sample_ids <- NULL
if (inherits(splits, "LeakSplits") && !is.null(splits@info$coldata)) {
cd <- splits@info$coldata
rn_cd <- rownames(cd)
if (!is.null(rn_cd) && !anyNA(rn_cd) && all(nzchar(rn_cd)) && !anyDuplicated(rn_cd)) {
sample_ids <- rn_cd
} else if ("row_id" %in% names(cd)) {
rid <- as.character(cd[["row_id"]])
if (length(rid) == nrow(cd) && !anyNA(rid) && !anyDuplicated(rid) && all(nzchar(rid))) {
sample_ids <- rid
}
}
}
if (is.null(sample_ids)) {
rn <- rownames(Xall)
if (!is.null(rn) && !anyNA(rn) && all(nzchar(rn)) && !anyDuplicated(rn)) {
sample_ids <- rn
}
}
if (is.null(sample_ids) || length(sample_ids) != nrow(Xall)) {
sample_ids <- as.character(seq_len(nrow(Xall)))
}
ids <- sample_ids
compact <- isTRUE(splits@info$compact)
fold_assignments <- splits@info$fold_assignments
split_mode <- splits@mode
split_time <- splits@info$time
split_horizon <- splits@info$horizon %||% 0
split_purge <- splits@info$purge %||% 0
split_embargo <- splits@info$embargo %||% 0
split_coldata <- splits@info$coldata
time_vec <- NULL
if (compact && identical(split_mode, "time_series")) {
if (is.null(split_coldata) || is.null(split_time) || !split_time %in% names(split_coldata)) {
stop("time_series compact splits require time column in coldata.")
}
time_vec <- split_coldata[[split_time]]
}
task <- if (.bio_is_survival(yall)) "survival"
else if (.bio_is_binomial(yall)) "binomial"
else if (.bio_is_multiclass(yall)) "multiclass"
else if (.bio_is_regression(yall)) "gaussian"
else if (is.factor(yall) && nlevels(yall) == 2) "binomial"
else if (is.factor(yall) && nlevels(yall) > 2) "multiclass"
else .bio_stop("Unsupported outcome type: require binomial/multiclass factor, numeric regression, or survival outcome.",
"bioLeak_input_error")
if (task == "binomial") {
if (!is.factor(yall)) yall <- factor(yall)
yall <- droplevels(yall)
if (anyNA(yall)) {
.bio_stop("Binomial outcome contains NA values. Remove or impute missing outcomes before fitting.",
"bioLeak_input_error")
}
if (nlevels(yall) != 2) {
.bio_stop("Binomial task requires exactly two outcome levels after preprocessing.",
"bioLeak_input_error")
}
if (!is.null(positive_class)) {
pos_chr <- as.character(positive_class)
if (length(pos_chr) != 1L) {
.bio_stop("positive_class must be a single value.",
"bioLeak_input_error")
}
levels_y <- levels(yall)
if (!pos_chr %in% levels_y) {
.bio_stop(sprintf("positive_class '%s' not found in outcome levels: %s",
pos_chr, paste(levels_y, collapse = ", ")),
"bioLeak_input_error")
}
if (!identical(pos_chr, levels_y[2])) {
levels_y <- c(setdiff(levels_y, pos_chr), pos_chr)
yall <- factor(yall, levels = levels_y)
}
}
class_levels <- levels(yall)
if (!is.null(class_weights)) {
if (!is.numeric(class_weights)) stop("class_weights must be numeric.")
if (is.null(names(class_weights))) {
if (length(class_weights) != length(class_levels)) {
stop("class_weights must align with outcome levels.")
}
names(class_weights) <- class_levels[seq_along(class_weights)]
}
missing_cw <- setdiff(class_levels, names(class_weights))
if (length(missing_cw)) {
stop(sprintf("class_weights missing levels: %s", paste(missing_cw, collapse = ", ")))
}
class_weights <- class_weights[class_levels]
}
} else if (task == "multiclass") {
if (!is.factor(yall)) yall <- factor(yall)
yall <- droplevels(yall)
if (anyNA(yall)) {
stop("Multiclass outcome contains NA values. Remove or impute missing outcomes before fitting.", call. = FALSE)
}
if (nlevels(yall) < 3) {
stop("Multiclass task requires 3 or more outcome levels after preprocessing.")
}
class_levels <- levels(yall)
if (!is.null(class_weights)) {
if (!is.numeric(class_weights)) stop("class_weights must be numeric.")
if (is.null(names(class_weights))) {
if (length(class_weights) != length(class_levels)) {
stop("class_weights must align with outcome levels.")
}
names(class_weights) <- class_levels[seq_along(class_weights)]
}
missing_cw <- setdiff(class_levels, names(class_weights))
if (length(missing_cw)) {
stop(sprintf("class_weights missing levels: %s", paste(missing_cw, collapse = ", ")))
}
class_weights <- class_weights[class_levels]
}
if (!is.null(positive_class)) {
warning("positive_class is ignored for multiclass tasks.")
}
if (classification_threshold_supplied) {
warning("classification_threshold is ignored for multiclass tasks.")
}
} else if (task == "gaussian") {
if (!is.numeric(yall)) {
yall <- as.numeric(yall)
if (anyNA(yall)) stop("Gaussian task requires numeric outcome values.")
}
class_levels <- NULL
if (!is.null(class_weights)) {
warning("class_weights is ignored for gaussian tasks.")
}
if (!is.null(positive_class)) {
warning("positive_class is ignored for gaussian tasks.")
}
if (classification_threshold_supplied) {
warning("classification_threshold is ignored for gaussian tasks.")
}
} else {
class_levels <- NULL
if (!inherits(yall, "Surv")) {
stop("Survival task requires a Surv outcome.")
}
if (anyNA(yall)) {
stop("Survival outcome contains NA values. Remove or impute missing outcomes before fitting.", call. = FALSE)
}
if (!is.null(class_weights)) {
warning("class_weights is ignored for survival tasks.")
}
if (!is.null(positive_class)) {
warning("positive_class is ignored for survival tasks.")
}
if (classification_threshold_supplied) {
warning("classification_threshold is ignored for survival tasks.")
}
}
if (use_workflow && task == "binomial" && !is.null(class_weights)) {
warning("class_weights are ignored for workflow learners unless explicitly handled in the workflow.")
}
if (use_workflow && task == "multiclass" && !is.null(class_weights)) {
warning("class_weights are ignored for workflow learners unless explicitly handled in the workflow.")
}
metrics_input <- metrics
metric_mode <- "legacy"
yardstick_set <- NULL
yardstick_metrics <- NULL
if (!is.null(metrics) && inherits(metrics, "metric_set")) {
metric_mode <- "yardstick"
yardstick_set <- metrics
} else if (!is.null(metrics) && .bio_is_yardstick_metric(metrics)) {
metric_mode <- "yardstick"
yardstick_metrics <- list(metrics)
} else if (is.list(metrics) && length(metrics) &&
all(vapply(metrics, .bio_is_yardstick_metric, logical(1)))) {
metric_mode <- "yardstick"
yardstick_metrics <- metrics
}
if (identical(metric_mode, "yardstick")) {
if (!requireNamespace("yardstick", quietly = TRUE)) {
stop("Package 'yardstick' is required when metrics are a yardstick set.", call. = FALSE)
}
if (is.null(yardstick_set)) {
yardstick_set <- do.call(yardstick::metric_set, yardstick_metrics)
}
metric_labels <- character(0)
} else {
if (is.null(metrics)) {
metrics <- if (task == "binomial") c("auc", "pr_auc", "accuracy")
else if (task == "multiclass") c("accuracy", "macro_f1")
else if (task == "survival") c("cindex")
else c("rmse")
}
if (is.character(metrics)) {
allowed <- if (task == "binomial") c("auc", "pr_auc", "accuracy")
else if (task == "multiclass") c("accuracy", "macro_f1", "log_loss")
else if (task == "survival") c("cindex")
else c("rmse", "cindex")
invalid <- setdiff(metrics, allowed)
if (length(invalid)) {
warning(sprintf("Dropping metrics not applicable to %s task: %s", task,
paste(invalid, collapse = ", ")))
metrics <- setdiff(metrics, invalid)
}
if (!length(metrics)) {
metrics <- if (task == "binomial") c("auc", "pr_auc", "accuracy")
else if (task == "multiclass") c("accuracy", "macro_f1")
else if (task == "survival") c("cindex")
else c("rmse")
}
}
metrics <- if (is.list(metrics)) metrics else as.list(metrics)
metric_labels <- vapply(seq_along(metrics), function(i) {
nm <- names(metrics)[i]
if (!is.null(nm) && !is.na(nm) && nzchar(nm)) return(nm)
mi <- metrics[[i]]
if (is.character(mi) && length(mi) == 1) return(mi)
paste0("metric_", i)
}, character(1))
}
learner_objs <- if (use_parsnip || use_workflow) learner_specs else as.list(learner_names)
# helper: safe metric computation -------------------------------------------
# Compute one legacy (non-yardstick) metric for a single set of predictions.
#
# name: a metric name ("auc", "pr_auc", "accuracy", "rmse", "cindex") or a
#       function, which is called as name(y, pred).
# y:    observed outcomes. For binomial tasks a factor (second level is the
#       positive class) or a numeric vector that is used as-is as the 0/1
#       indicator.
# pred: numeric predictions (positive-class scores for binomial tasks).
#
# Reads `task` and `classification_threshold` from the enclosing scope.
# Returns a single numeric value; NA_real_ when the name does not apply to the
# task, or when "pr_auc" is requested and the PRROC package is unavailable.
compute_metric <- function(name, y, pred) {
  if (is.function(name)) return(name(y, pred))
  yb <- NULL
  if (task == "binomial") {
    yb <- if (is.factor(y)) as.numeric(y) - 1 else as.numeric(y)
  }
  if (name == "auc" && task == "binomial") {
    if (requireNamespace("pROC", quietly = TRUE)) {
      # pROC's default direction = "auto" flips the comparison whenever that
      # raises the AUC, so an anti-predictive model would be reported above
      # 0.5. Pin the direction (controls < cases, i.e. higher score means
      # second level) so the value matches the pairwise fallback below.
      roc_obj <- pROC::roc(y, pred, quiet = TRUE, direction = "<")
      return(as.numeric(pROC::auc(roc_obj)))
    }
    # Fallback: share of (positive, negative) pairs ranked correctly, with
    # ties counted as one half.
    pos <- pred[yb == 1]
    neg <- pred[yb == 0]
    comp <- outer(pos, neg, function(a, b) (a > b) + 0.5 * (a == b))
    return(mean(comp))
  }
  if (name == "pr_auc" && task == "binomial") {
    if (requireNamespace("PRROC", quietly = TRUE)) {
      pr <- PRROC::pr.curve(scores.class0 = pred[yb == 1],
                            scores.class1 = pred[yb == 0],
                            curve = FALSE)
      return(pr$auc.integral)
    }
    return(NA_real_)
  }
  if (name == "accuracy" && task == "binomial") {
    return(mean((pred >= classification_threshold) == as.logical(yb)))
  }
  if (name == "rmse" && task == "gaussian") {
    return(sqrt(mean((as.numeric(y) - pred)^2)))
  }
  if (name == "cindex" && task == "gaussian") {
    return(.cindex_pairwise(pred, y))
  }
  if (name == "cindex" && task == "survival") {
    return(.cindex_survival(pred, y))
  }
  NA_real_
}
# Evaluate the yardstick metric set held in `yardstick_set` on one set of
# predictions and return the estimates as a numeric vector named by metric.
#
# y:          observed outcomes (stored as column `truth`).
# pred:       numeric predictions (stored as `.pred` for non-multiclass tasks).
# pred_class: predicted class labels, used as the class estimate.
# prob:       optional class-probability table for multiclass tasks.
#
# Reads `task`, `class_levels`, and `yardstick_set` from the enclosing scope.
# Stops for survival tasks, and re-raises any yardstick error with the prefix
# "Yardstick metrics failed: ".
compute_yardstick <- function(y, pred, pred_class, prob = NULL) {
  if (task == "survival") {
    stop("Yardstick metrics are not supported for survival tasks.", call. = FALSE)
  }
  # Shared handler: surface the underlying message under a uniform prefix.
  rethrow <- function(cond) {
    stop(sprintf("Yardstick metrics failed: %s", cond$message), call. = FALSE)
  }
  if (task == "multiclass") {
    eval_df <- data.frame(truth = y, pred_class = pred_class, stringsAsFactors = FALSE)
    if (!is.null(prob)) {
      prob <- as.data.frame(prob, check.names = FALSE)
      if (ncol(prob) == length(class_levels)) {
        names(prob) <- paste0(".pred_", make.names(class_levels))
      }
      eval_df <- cbind(eval_df, prob)
    }
    # Every ".pred_*" column is passed to the metric set as a bare symbol.
    prob_names <- grep("^\\.pred_", names(eval_df), value = TRUE)
    call_args <- c(
      list(eval_df, truth = quote(truth), estimate = quote(pred_class)),
      stats::setNames(lapply(prob_names, as.name), prob_names)
    )
    res <- tryCatch(do.call(yardstick_set, call_args), error = rethrow)
  } else if (task == "binomial") {
    eval_df <- data.frame(truth = y, .pred = as.numeric(pred), stringsAsFactors = FALSE)
    if (!is.null(pred_class)) {
      eval_df$.pred_class <- pred_class
    }
    res <- tryCatch(
      yardstick_set(eval_df, truth = truth, estimate = .pred_class,
                    .pred, event_level = "second"),
      error = rethrow
    )
  } else {
    eval_df <- data.frame(truth = y, .pred = as.numeric(pred), stringsAsFactors = FALSE)
    res <- tryCatch(
      yardstick_set(eval_df, truth = truth, estimate = .pred),
      error = rethrow
    )
  }
  stats::setNames(res$.estimate, res$.metric)
}
# Coerce class-probability predictions to a matrix whose columns are ordered
# and named after `class_levels`.
#
# prob:         matrix or data.frame of class probabilities.
# class_levels: character vector of outcome levels, in the desired order.
#
# Columns are matched by name, trying ".pred_<level>", "<level>", and
# make.names("<level>") in that order. Input without column names, or with
# unrecognised names, is accepted positionally only when the column count
# equals the number of levels; otherwise an error is raised.
align_probabilities <- function(prob, class_levels) {
  prob_mat <- as.matrix(prob)
  n_levels <- length(class_levels)
  found <- colnames(prob_mat)

  if (is.null(found)) {
    if (ncol(prob_mat) != n_levels) {
      stop("Probability predictions do not match class levels.", call. = FALSE)
    }
    colnames(prob_mat) <- class_levels
    return(prob_mat)
  }

  # Candidate naming schemes, in priority order; take the first that is fully
  # present among the supplied columns.
  schemes <- list(
    paste0(".pred_", make.names(class_levels)),
    class_levels,
    make.names(class_levels)
  )
  chosen <- Find(function(cols) all(cols %in% found), schemes)

  if (!is.null(chosen)) {
    prob_mat <- prob_mat[, chosen, drop = FALSE]
  } else if (ncol(prob_mat) != n_levels) {
    stop("Probability predictions do not align with class levels.", call. = FALSE)
  }
  colnames(prob_mat) <- class_levels
  prob_mat
}
# --- Robust design matrix builder ------------------------------------------
# Turn a predictor table into a numeric design matrix.
#
# X:        matrix or data.frame of predictors. Columns named "y" or "outcome"
#           are removed first.
# ref_cols: optional character vector of column names to conform to. When
#           supplied, the result has exactly these columns in this order:
#           columns absent from X are filled with 0 and extra columns are
#           dropped; no variance filtering is applied.
#
# All-numeric input is converted with as.matrix(); otherwise the matrix comes
# from stats::model.matrix(~ . - 1) on a model frame that keeps NA rows.
# Without ref_cols, columns whose standard deviation is zero or cannot be
# computed are dropped, and an error is raised when none remain.
#
# Returns list(matrix = <numeric matrix>, columns = <column names>).
make_design_matrix <- function(X, ref_cols = NULL) {
  X <- as.data.frame(X)
  X <- X[, !names(X) %in% c("y", "outcome"), drop = FALSE]
  if (all(vapply(X, is.numeric, logical(1)))) {
    mm <- as.matrix(X)
  } else {
    mf <- stats::model.frame(~ ., data = X, na.action = stats::na.pass)
    mm <- as.matrix(stats::model.matrix(~ . - 1, data = mf))
  }
  if (!is.null(ref_cols)) {
    missing_cols <- setdiff(ref_cols, colnames(mm))
    if (length(missing_cols)) {
      mm <- cbind(
        mm,
        matrix(0, nrow = nrow(mm), ncol = length(missing_cols),
               dimnames = list(NULL, missing_cols))
      )
    }
    # Selecting by ref_cols both drops extra columns and fixes the order.
    mm <- mm[, ref_cols, drop = FALSE]
    return(list(matrix = mm, columns = ref_cols))
  }
  if (!ncol(mm)) {
    stop("All predictors have zero variance after preprocessing.")
  }
  col_sd <- apply(mm, 2, stats::sd, na.rm = TRUE)
  # sd() is NA for a single row or an all-NA column. Treat those columns as
  # uninformative rather than letting NA reach `if ()` or the column index.
  keep <- !is.na(col_sd) & col_sd > 0
  if (!any(keep)) {
    stop("All predictors have zero variance after preprocessing.")
  }
  mm <- mm[, keep, drop = FALSE]
  list(matrix = mm, columns = colnames(mm))
}
# Expand per-class weights into one weight per observation.
#
# y:            outcome values; each is looked up by its character form.
# weights_spec: numeric class weights, named by outcome level, or unnamed with
#               one entry per level (matched to `class_levels` in order).
#
# Reads `task` and `class_levels` from the enclosing scope.
# Returns NULL when no weights are given or the task is neither binomial nor
# multiclass; otherwise a numeric vector with one entry per element of `y`.
resolve_weights <- function(y, weights_spec) {
  applies <- !is.null(weights_spec) && task %in% c("binomial", "multiclass")
  if (!applies) {
    return(NULL)
  }
  if (!is.numeric(weights_spec)) stop("class_weights must be numeric.")

  level_weights <- weights_spec
  if (is.null(names(level_weights))) {
    # Unnamed weights are positional, so the count must equal the level count.
    if (length(level_weights) != length(class_levels)) {
      stop("Provide class_weights as a named vector matching outcome levels.")
    }
    names(level_weights) <- class_levels[seq_along(level_weights)]
  }

  uncovered <- setdiff(class_levels, names(level_weights))
  if (length(uncovered)) {
    stop(sprintf("class_weights missing levels: %s", paste(uncovered, collapse = ", ")))
  }
  level_weights[as.character(y)]
}
# Merge user-supplied learner arguments over a learner's defaults.
#
# name:     learner name used to look up learner-specific arguments.
# defaults: list of default arguments for that learner.
#
# Reads `learner_args` and `learner` from the enclosing scope. When every name
# in `learner_args` is one of the requested learners, it is treated as keyed by
# learner and only the entry for `name` is used; otherwise the whole list is
# applied. Returns `defaults` with the overrides merged in via modifyList().
resolve_args <- function(name, defaults = list()) {
  overrides <- list()
  if (length(learner_args)) {
    arg_names <- names(learner_args)
    keyed_by_learner <- !is.null(arg_names) && all(arg_names %in% learner)
    overrides <- if (keyed_by_learner) {
      learner_args[[name]] %||% list()
    } else {
      learner_args
    }
  }
  modifyList(defaults, overrides)
}
# Resolve the extra arguments passed to a custom learner's fit and predict
# functions.
#
# name: custom learner name used to look up learner-specific arguments.
#
# Reads `learner_args` and `learner` from the enclosing scope. When every name
# in `learner_args` is one of the requested learners, only the entry for
# `name` is used; otherwise the whole list is used. If the selected arguments
# contain an element named exactly "fit" or "predict", those elements are
# routed to the matching stage (a missing one becomes an empty list);
# otherwise the same arguments go to both stages.
#
# Returns list(fit = <arguments>, predict = <arguments>).
resolve_custom_args <- function(name) {
  if (!length(learner_args)) return(list(fit = list(), predict = list()))
  if (!is.null(names(learner_args)) && all(names(learner_args) %in% learner)) {
    extras <- learner_args[[name]] %||% list()
  } else {
    extras <- learner_args
  }
  if (is.list(extras) && any(c("fit", "predict") %in% names(extras))) {
    # Use [[ ]] rather than $: `$` partially matches list names, so an entry
    # such as `fit_control` would be picked up as `fit` when no exact "fit"
    # element exists.
    return(list(fit = extras[["fit"]] %||% list(),
                predict = extras[["predict"]] %||% list()))
  }
  list(fit = extras, predict = extras)
}
# single learner wrapper ----------------------------------------------------
train_one_learner <- function(learner_obj, learner_label, Xtrg, ytr, Xteg, yte, weights = NULL) {
if (inherits(learner_obj, "model_spec")) {
if (!requireNamespace("parsnip", quietly = TRUE)) {
stop("Package 'parsnip' is required when learner is a model_spec.")
}
if (!is.null(weights) && !inherits(weights, "hardhat_case_weights")) {
if (!requireNamespace("hardhat", quietly = TRUE)) {
stop("Package 'hardhat' is required for case weights with parsnip learners.")
}
weights <- hardhat::frequency_weights(weights)
}
y_for_fit <- if (task %in% c("binomial", "multiclass")) {
factor(ytr, levels = class_levels)
} else if (task == "survival") {
ytr
} else {
as.numeric(ytr)
}
fit <- if (is.null(weights)) {
parsnip::fit_xy(learner_obj, x = Xtrg, y = y_for_fit)
} else {
parsnip::fit_xy(learner_obj, x = Xtrg, y = y_for_fit, case_weights = weights)
}
if (task == "binomial") {
prob <- try(stats::predict(fit, new_data = Xteg, type = "prob"), silent = TRUE)
if (inherits(prob, "try-error")) {
# Extract the actual error text from parsnip/xgboost
err_msg <- attr(prob, "condition")$message
stop(sprintf("Parsnip learner '%s' failed to predict: %s",
learner_label, err_msg))
}
prob_df <- as.data.frame(prob)
pos_col <- paste0(".pred_", make.names(class_levels[2]))
if (!pos_col %in% names(prob_df)) {
if (ncol(prob_df) >= 2L) {
pos_col <- names(prob_df)[2]
} else {
stop(sprintf("Parsnip learner '%s' did not return class probabilities.",
learner_label))
}
}
pred <- as.numeric(prob_df[[pos_col]])
pred_class <- factor(ifelse(pred >= classification_threshold, class_levels[2], class_levels[1]),
levels = class_levels)
return(list(pred = pred, pred_class = pred_class, fit = fit))
}
if (task == "multiclass") {
prob <- try(stats::predict(fit, new_data = Xteg, type = "prob"), silent = TRUE)
if (inherits(prob, "try-error")) {
err_msg <- attr(prob, "condition")$message
stop(sprintf("Parsnip learner '%s' failed to predict: %s",
learner_label, err_msg))
}
prob_mat <- align_probabilities(prob, class_levels)
class_pred <- try(stats::predict(fit, new_data = Xteg, type = "class"), silent = TRUE)
if (inherits(class_pred, "try-error")) {
err_msg <- attr(class_pred, "condition")$message
stop(sprintf("Parsnip learner '%s' failed to predict classes: %s",
learner_label, err_msg))
}
class_df <- as.data.frame(class_pred)
pred_class <- factor(as.character(class_df[[1]]), levels = class_levels)
return(list(pred = pred_class, pred_class = pred_class, prob = prob_mat, fit = fit))
}
if (task == "survival") {
pred_df <- try(stats::predict(fit, new_data = Xteg, type = "numeric"), silent = TRUE)
if (inherits(pred_df, "try-error")) {
pred_df <- try(stats::predict(fit, new_data = Xteg, type = "risk"), silent = TRUE)
}
if (inherits(pred_df, "try-error")) {
err_msg <- attr(pred_df, "condition")$message
stop(sprintf("Parsnip learner '%s' failed to predict: %s",
learner_label, err_msg))
}
pred <- as.numeric(as.data.frame(pred_df)[[1]])
return(list(pred = pred, fit = fit))
} else {
pred_df <- stats::predict(fit, new_data = Xteg, type = "numeric")
pred <- as.numeric(pred_df[[1]])
return(list(pred = pred, fit = fit))
}
}
if (learner_obj %in% names(custom_learners)) {
def <- custom_learners[[learner_obj]]
args <- resolve_custom_args(learner_obj)
fit_args <- c(list(x = Xtrg, y = ytr, task = task, weights = weights), args$fit)
model <- do.call(def$fit, fit_args)
pred_args <- c(list(object = model, newdata = Xteg, task = task), args$predict)
pred <- do.call(def$predict, pred_args)
if (task == "multiclass") {
if (is.data.frame(pred) || is.matrix(pred)) {
prob_mat <- align_probabilities(pred, class_levels)
pred_class <- factor(class_levels[max.col(prob_mat, ties.method = "first")],
levels = class_levels)
return(list(pred = pred_class, pred_class = pred_class, prob = prob_mat, fit = model))
}
if (is.factor(pred) || is.character(pred)) {
pred_class <- factor(as.character(pred), levels = class_levels)
if (length(pred_class) != nrow(Xteg)) {
stop(sprintf("Custom learner '%s' returned %d predictions for %d rows.",
learner_obj, length(pred_class), nrow(Xteg)))
}
return(list(pred = pred_class, pred_class = pred_class, fit = model))
}
stop(sprintf("Custom learner '%s' must return class labels or class probabilities for multiclass tasks.",
learner_obj))
}
pred <- as.numeric(pred)
if (length(pred) != nrow(Xteg)) {
stop(sprintf("Custom learner '%s' returned %d predictions for %d rows.",
learner_obj, length(pred), nrow(Xteg)))
}
return(list(pred = pred, fit = model))
}
if (learner_obj == "glmnet") {
if (!requireNamespace("glmnet", quietly = TRUE)) stop("Install 'glmnet'.")
fam <- if (task == "binomial") "binomial"
else if (task == "multiclass") "multinomial"
else if (task == "survival") "cox"
else "gaussian"
la <- resolve_args("glmnet", list(alpha = 0.9, standardize = FALSE))
Xtr_design <- make_design_matrix(Xtrg)
Xte_design <- make_design_matrix(Xteg, ref_cols = Xtr_design$columns)
y_for_fit <- if (task == "binomial") {
as.numeric(factor(ytr, levels = class_levels)) - 1
} else if (task == "multiclass") {
factor(ytr, levels = class_levels)
} else if (task == "survival") {
ytr
} else {
as.numeric(ytr)
}
cv_args <- c(list(x = Xtr_design$matrix,
y = y_for_fit,
family = fam,
alpha = la$alpha,
standardize = la$standardize %||% FALSE,
weights = weights),
la[setdiff(names(la), c("alpha", "standardize"))])
cvfit <- do.call(glmnet::cv.glmnet, cv_args)
if (task == "multiclass") {
pred_arr <- predict(cvfit, Xte_design$matrix, s = "lambda.min", type = "response")
if (length(dim(pred_arr)) == 3L) {
prob_mat <- pred_arr[, , 1, drop = FALSE][,,1]
} else {
prob_mat <- pred_arr
}
prob_mat <- align_probabilities(prob_mat, class_levels)
pred_class <- factor(class_levels[max.col(prob_mat, ties.method = "first")],
levels = class_levels)
return(list(pred = pred_class, pred_class = pred_class, prob = prob_mat, fit = cvfit))
}
pred_type <- if (task == "survival") "link" else "response"
pred <- as.numeric(predict(cvfit, Xte_design$matrix, s = "lambda.min",
type = pred_type))
return(list(pred = pred, fit = cvfit))
}
if (learner_obj == "ranger") {
if (!requireNamespace("ranger", quietly = TRUE)) stop("Install 'ranger'.")
if (task == "survival") {
stop("Learner 'ranger' does not support survival tasks in bioLeak; use parsnip/workflow or a custom learner.")
}
y_for_fit <- if (task %in% c("binomial", "multiclass")) {
factor(ytr, levels = class_levels)
} else {
as.numeric(ytr)
}
dftr <- data.frame(y = y_for_fit, Xtrg, check.names = FALSE)
frm <- stats::as.formula("y ~ .")
rng_args <- resolve_args("ranger", list())
if (task %in% c("binomial", "multiclass")) {
rng_args$class.weights <- rng_args$class.weights %||% class_weights
}
rg <- do.call(ranger::ranger, c(list(formula = frm, data = dftr,
probability = (task %in% c("binomial", "multiclass"))),
rng_args))
dfte <- data.frame(Xteg, check.names = FALSE)
pr <- predict(rg, dfte)
if (task == "binomial") {
pred <- pr$predictions[, 2]
pred_class <- factor(ifelse(pred >= classification_threshold, class_levels[2], class_levels[1]),
levels = class_levels)
return(list(pred = pred, pred_class = pred_class, fit = rg))
}
if (task == "multiclass") {
prob_mat <- pr$predictions
prob_mat <- align_probabilities(prob_mat, class_levels)
pred_class <- factor(class_levels[max.col(prob_mat, ties.method = "first")],
levels = class_levels)
return(list(pred = pred_class, pred_class = pred_class, prob = prob_mat, fit = rg))
}
pred <- pr$predictions
return(list(pred = pred, fit = rg))
}
stop("Unsupported learner.")
}
# Fit a workflow learner on one fold's training data and predict the held-out rows.
#
# learner_obj   : workflow object passed to generics::fit().
# learner_label : name used in error messages.
# dftr          : training data frame handed to generics::fit().
# dfte          : data frame passed as `new_data` when predicting.
# weights       : accepted so the call mirrors train_one_learner(); not used here.
#
# Returns a list holding `pred`, the fitted workflow in `fit`, plus `pred_class`
# for classification tasks and `prob` for multiclass tasks.
# Every fit/predict failure is re-raised with the learner label prepended.
train_one_workflow <- function(learner_obj, learner_label, dftr, dfte, weights = NULL) {
  if (!requireNamespace("workflows", quietly = TRUE)) {
    stop("Package 'workflows' is required when learner is a workflow.")
  }

  # Turn a try() failure into a learner-labelled error; pass successes through.
  stop_on_failure <- function(result, action) {
    if (inherits(result, "try-error")) {
      stop(sprintf("Workflow learner '%s' failed to %s: %s",
                   learner_label, action,
                   conditionMessage(attr(result, "condition"))),
           call. = FALSE)
    }
    result
  }
  # Predict on the held-out rows, capturing (not raising) any error.
  predict_quietly <- function(type) {
    try(stats::predict(fit, new_data = dfte, type = type), silent = TRUE)
  }

  fit <- stop_on_failure(
    try(generics::fit(learner_obj, data = dftr), silent = TRUE),
    "fit"
  )

  if (task == "binomial") {
    prob_df <- as.data.frame(stop_on_failure(predict_quietly("prob"), "predict"))
    pos_col <- paste0(".pred_", make.names(class_levels[2]))
    if (!pos_col %in% names(prob_df)) {
      # Fall back to the second probability column when the expected name is absent.
      if (ncol(prob_df) < 2L) {
        stop(sprintf("Workflow learner '%s' did not return class probabilities.",
                     learner_label))
      }
      pos_col <- names(prob_df)[2]
    }
    pred <- as.numeric(prob_df[[pos_col]])
    pred_class <- factor(
      ifelse(pred >= classification_threshold, class_levels[2], class_levels[1]),
      levels = class_levels
    )
    return(list(pred = pred, pred_class = pred_class, fit = fit))
  }

  if (task == "multiclass") {
    prob <- stop_on_failure(predict_quietly("prob"), "predict")
    prob_mat <- align_probabilities(prob, class_levels)
    class_pred <- stop_on_failure(predict_quietly("class"), "predict classes")
    class_df <- as.data.frame(class_pred)
    pred_class <- factor(as.character(class_df[[1]]), levels = class_levels)
    return(list(pred = pred_class, pred_class = pred_class, prob = prob_mat, fit = fit))
  }

  if (task == "survival") {
    # Try "numeric" predictions first, then fall back to "risk".
    pred_df <- predict_quietly("numeric")
    if (inherits(pred_df, "try-error")) {
      pred_df <- predict_quietly("risk")
    }
    pred_df <- stop_on_failure(pred_df, "predict")
    pred <- as.numeric(as.data.frame(pred_df)[[1]])
    return(list(pred = pred, fit = fit))
  }

  # Remaining tasks use numeric predictions; failures are now reported with
  # the learner label, consistent with the branches above.
  pred_df <- stop_on_failure(predict_quietly("numeric"), "predict")
  pred <- as.numeric(pred_df[[1]])
  list(pred = pred, fit = fit)
}
# Expand a compact fold descriptor into explicit train/test row indices.
#
# A fold is returned untouched unless compact storage is active and the fold
# carries no `train` element. Otherwise the test rows are looked up from the
# per-repeat fold assignment vector and the training rows are derived from them
# (time-aware for "time_series" splits, plain complement otherwise).
resolve_fold_indices <- function(fold) {
  needs_expansion <- isTRUE(compact) && is.null(fold$train)
  if (!needs_expansion) {
    return(fold)
  }
  if (is.null(fold_assignments) || !length(fold_assignments)) {
    stop("Compact splits require fold assignments to compute indices.")
  }

  # A missing or non-finite repeat id falls back to the first repeat.
  repeat_idx <- fold$repeat_id
  if (is.null(repeat_idx) || !is.finite(repeat_idx)) {
    repeat_idx <- 1L
  }
  assignment <- fold_assignments[[repeat_idx]]
  if (is.null(assignment)) {
    stop(sprintf("Missing fold assignments for repeat %s.", repeat_idx))
  }

  test_rows <- which(assignment == fold$fold)
  all_rows <- seq_len(nrow(Xall))
  train_rows <- if (identical(split_mode, "time_series")) {
    if (is.null(time_vec) || !length(time_vec)) {
      stop("time_series compact splits require time column values.")
    }
    .bio_time_series_train_indices(
      time_vec = time_vec,
      test_idx = test_rows,
      candidate_idx = all_rows,
      horizon = split_horizon,
      purge = split_purge,
      embargo = split_embargo
    )
  } else {
    setdiff(all_rows, test_rows)
  }

  list(
    train = train_rows,
    test = test_rows,
    fold = fold$fold,
    repeat_id = fold$repeat_id,
    fold_seq = fold$fold_seq %||% fold$fold
  )
}
# Build a modelling data frame from predictors `X` and outcome `y`, with the
# outcome column(s) placed first.
#
# When `outcome` names two columns (time, event), `y` must be a Surv object;
# its first and last matrix columns are stored under those two names.
# Otherwise `y` is stored under the single outcome name.
make_fold_df <- function(X, y) {
  fold_df <- as.data.frame(X, check.names = FALSE)
  two_column_outcome <- length(outcome) == 2L

  if (two_column_outcome) {
    if (!inherits(y, "Surv")) {
      stop("Survival tasks require a Surv outcome when building fold data.")
    }
    surv_mat <- as.matrix(y)
    n_surv_cols <- ncol(surv_mat)
    if (n_surv_cols < 2L) {
      stop("Survival outcome must include time and event columns.")
    }
    fold_df[[outcome[[1]]]] <- surv_mat[, 1]
    fold_df[[outcome[[2]]]] <- surv_mat[, n_surv_cols]
  } else {
    fold_df[[outcome]] <- y
  }

  predictor_cols <- setdiff(names(fold_df), outcome)
  fold_df[, c(outcome, predictor_cols), drop = FALSE]
}
# fold-level function -------------------------------------------------------
# Run one resampling fold for every configured learner.
#
# `fold` is one element of the split index list; it is first passed through
# resolve_fold_indices() to obtain explicit train/test row indices.
# Returns a list keyed by learner name. Each entry holds `metrics`, `pred`
# (per-sample prediction data.frame), `guard` (preprocessing state), `learner`
# (the fitted model) and `feat_names`.
# For classification tasks whose training rows contain fewer than two classes,
# a warning is raised and placeholder entries with empty predictions are
# returned instead of fitting anything.
do_fold <- function(fold) {
fold_full <- resolve_fold_indices(fold)
fold_id <- fold_full$fold_seq %||% fold_full$fold
# Per-fold seed derived from the base seed and the fold identifier.
set.seed(seed + fold_id)
tr <- fold_full$train
te <- fold_full$test
Xtr <- Xall[tr, , drop = FALSE]
Xte <- Xall[te, , drop = FALSE]
# The *_raw partitions are only consumed by the recipe branch below.
Xtr_raw <- Xall_raw[tr, , drop = FALSE]
Xte_raw <- Xall_raw[te, , drop = FALSE]
ytr <- yall[tr]
yte <- yall[te]
if (task %in% c("binomial", "multiclass")) {
# Use the same level set in both partitions so factor codes line up.
ytr <- factor(ytr, levels = class_levels)
yte <- factor(yte, levels = class_levels)
if (nlevels(droplevels(ytr)) < 2) {
warning(sprintf("Fold %s skipped: only one class in training data", fold_id))
# Placeholder metrics: empty vector in yardstick mode, otherwise one NA per
# requested metric.
empty_metrics <- if (identical(metric_mode, "yardstick")) {
numeric(0)
} else {
setNames(rep(NA_real_, length(metrics)), metric_labels)
}
skipped <- lapply(learner_names, function(ln) {
list(
metrics = empty_metrics,
pred = data.frame(
id = integer(0),
truth = factor(character(0), levels = class_levels),
pred = numeric(0),
fold = integer(0),
learner = character(0),
stringsAsFactors = FALSE
),
guard = list(state = NULL),
learner = NULL,
feat_names = colnames(Xtr)
)
})
names(skipped) <- learner_names
return(skipped)
}
# Weights are only resolved for classification tasks.
fold_weights <- resolve_weights(ytr, class_weights)
} else if (task == "survival") {
fold_weights <- NULL
} else {
ytr <- as.numeric(ytr)
yte <- as.numeric(yte)
fold_weights <- NULL
}
preprocess_state <- NULL
Xtrg <- Xtr
Xteg <- Xte
dftr <- NULL
dfte <- NULL
if (identical(preprocess_mode, "guard")) {
# Guarded preprocessing: fitted on the training rows only, then applied to
# both partitions.
guard <- .guard_fit(
X = Xtr,
y = ytr,
steps = if (exists("preprocess") && is.list(preprocess)) preprocess else list(),
task = task
)
Xtrg <- guard$transform(Xtr)
Xteg <- guard$transform(Xte)
preprocess_state <- guard$state
# Force syntactic column names on both partitions.
colnames(Xtrg) <- make.names(colnames(Xtrg))
colnames(Xteg) <- make.names(colnames(Xteg))
} else if (identical(preprocess_mode, "recipe")) {
# Recipe preprocessing: prep on the training fold, bake the test fold.
dftr <- make_fold_df(Xtr_raw, ytr)
dfte <- make_fold_df(Xte_raw, yte)
recipe_prep <- recipes::prep(preprocess, training = dftr, retain = TRUE)
Xtrg <- as.data.frame(recipes::juice(recipe_prep, recipes::all_predictors()),
check.names = FALSE)
Xteg <- as.data.frame(recipes::bake(recipe_prep, new_data = dfte,
recipes::all_predictors()),
check.names = FALSE)
if (length(drop_cols)) {
# Remove `drop_cols` from the processed predictors, then give the test set
# exactly the training columns; columns absent from the test set are
# filled with 0.
keep_cols <- setdiff(names(Xtrg), drop_cols)
Xtrg <- Xtrg[, keep_cols, drop = FALSE]
Xteg <- Xteg[, intersect(keep_cols, names(Xteg)), drop = FALSE]
missing_in_test <- setdiff(names(Xtrg), names(Xteg))
if (length(missing_in_test)) {
for (nm in missing_in_test) Xteg[[nm]] <- 0
}
Xteg <- Xteg[, names(Xtrg), drop = FALSE]
}
# NOTE(review): this fires whenever no predictor columns remain, whatever the
# cause; the message mentions zero variance only.
if (!ncol(Xtrg)) stop("All predictors have zero variance after preprocessing.")
preprocess_state <- list(type = "recipe", recipe = recipe_prep)
colnames(Xtrg) <- make.names(colnames(Xtrg))
colnames(Xteg) <- make.names(colnames(Xteg))
} else {
# Any other mode: build fold data frames with no preprocessing here; the
# state is tagged "workflow".
dftr <- make_fold_df(Xtr, ytr)
dfte <- make_fold_df(Xte, yte)
preprocess_state <- list(type = "workflow")
}
results <- list()
for (i in seq_along(learner_names)) {
ln <- learner_names[[i]]
learner_obj <- learner_objs[[i]]
# Train one learner; ensure consistent level matching
if (identical(preprocess_mode, "workflow")) {
# Outcome column(s) are removed from the data frame used for prediction.
dfte_pred <- dfte[, setdiff(names(dfte), outcome), drop = FALSE]
model <- train_one_workflow(learner_obj, ln, dftr, dfte_pred, weights = fold_weights)
feat_names <- setdiff(names(dftr), outcome)
} else {
model <- train_one_learner(learner_obj, ln, Xtrg, ytr, Xteg, yte, weights = fold_weights)
feat_names <- colnames(Xtrg)
}
# Hard class labels: prefer what the learner returned, otherwise derive them
# from label-like predictions or (binomial only) by thresholding scores.
pred_class <- if (task == "binomial") {
if (!is.null(model$pred_class)) {
model$pred_class
} else if (is.factor(model$pred) || is.character(model$pred)) {
factor(as.character(model$pred), levels = class_levels)
} else {
factor(ifelse(model$pred >= classification_threshold, class_levels[2], class_levels[1]),
levels = class_levels)
}
} else if (task == "multiclass") {
model$pred_class %||% {
if (is.factor(model$pred) || is.character(model$pred)) {
factor(as.character(model$pred), levels = class_levels)
} else {
NULL
}
}
} else {
NULL
}
prob_mat <- model$prob %||% NULL
# Metrics come from one of three paths: yardstick, the built-in multiclass
# metrics, or compute_metric() for the remaining tasks.
ms <- if (identical(metric_mode, "yardstick")) {
compute_yardstick(yte, model$pred, pred_class, prob = prob_mat)
} else if (task == "multiclass") {
vals <- vapply(seq_along(metrics), function(idx) {
mname <- metrics[[idx]]
# User-supplied metric functions receive truth and predicted classes.
if (is.function(mname)) return(mname(yte, pred_class))
if (identical(mname, "accuracy")) return(.multiclass_accuracy(yte, pred_class))
if (identical(mname, "macro_f1")) return(.multiclass_macro_f1(yte, pred_class))
if (identical(mname, "log_loss")) {
if (is.null(prob_mat)) {
stop("log_loss requires class probability predictions for multiclass tasks.")
}
return(.multiclass_log_loss(yte, prob_mat))
}
# Unrecognised metric names yield NA.
NA_real_
}, numeric(1))
names(vals) <- metric_labels
vals
} else {
vals <- vapply(seq_along(metrics), function(idx) {
compute_metric(metrics[[idx]], yte, model$pred)
}, numeric(1))
names(vals) <- metric_labels
vals
}
# Per-sample prediction table; survival tasks store time and event truth in
# separate columns.
pred_tbl <- if (task == "survival") {
yte_mat <- as.matrix(yte)
# Prefer columns named "time"/"status"; otherwise use the first/last column.
time_col <- if ("time" %in% colnames(yte_mat)) "time" else colnames(yte_mat)[1]
status_col <- if ("status" %in% colnames(yte_mat)) "status" else colnames(yte_mat)[ncol(yte_mat)]
data.frame(
id = ids[te],
truth_time = yte_mat[, time_col],
truth_event = yte_mat[, status_col],
pred = model$pred,
fold = fold_id,
learner = ln,
stringsAsFactors = FALSE
)
} else {
data.frame(
id = ids[te],
truth = yte,
pred = model$pred,
fold = fold_id,
learner = ln,
stringsAsFactors = FALSE
)
}
if (!is.null(pred_class)) pred_tbl$pred_class <- pred_class
if (!is.null(prob_mat) && task == "multiclass") {
# Append one ".pred_<class>" probability column per class level.
prob_df <- as.data.frame(prob_mat, check.names = FALSE)
names(prob_df) <- paste0(".pred_", make.names(class_levels))
pred_tbl <- cbind(pred_tbl, prob_df)
}
results[[ln]] <- list(
metrics = ms,
pred = pred_tbl,
guard = preprocess_state,
learner = model$fit,
feat_names = feat_names
)
}
results
}
folds <- splits@indices
nfold <- length(folds)
# One slot per fold; stays NA unless that fold raised an error.
fold_errors <- rep(NA_character_, nfold)
# progress bar --------------------------------------------------------------
pb <- utils::txtProgressBar(min = 0, max = nfold, style = 3)
pb_counter <- 0
# Run do_fold() for one fold, converting an error into a warning plus an empty
# result. The returned list always carries an "elapsed_sec" attribute, and an
# "error_msg" attribute when the fold failed. The message travels on the result
# itself because a `<<-` assignment made inside a parallel worker only changes
# that worker's copy of the enclosing variables.
progress_wrap <- function(f) {
  local_fold_id <- f$fold_seq %||% f$fold
  start_time <- proc.time()
  fold_error <- NULL
  res <- tryCatch(do_fold(f), error = function(e) {
    fold_error <<- conditionMessage(e)
    warning(sprintf("Fold %s failed: %s", local_fold_id, fold_error))
    NULL
  })
  elapsed <- (proc.time() - start_time)[["elapsed"]]
  if (is.null(res)) {
    res <- list()
  }
  attr(res, "elapsed_sec") <- elapsed
  if (!is.null(fold_error)) {
    attr(res, "error_msg") <- fold_error
  }
  # Progress display is best-effort; the counter is not shared across workers.
  pb_counter <<- pb_counter + 1
  utils::setTxtProgressBar(pb, pb_counter)
  res
}
# parallel or sequential execution -----------------------------------------
if (parallel && requireNamespace("future.apply", quietly = TRUE)) {
  out <- future.apply::future_lapply(seq_along(folds), function(i) {
    fold <- folds[[i]]
    fold$fold_seq <- i
    progress_wrap(fold)
  }, future.seed = TRUE)
} else {
  out <- lapply(seq_along(folds), function(i) {
    fold <- folds[[i]]
    fold$fold_seq <- i
    progress_wrap(fold)
  })
}
close(pb)
# Copy error messages returned with the fold results into `fold_errors`, so the
# real failure reason is available in both sequential and parallel runs.
# Results are positional: element i of `out` is fold i.
returned_errors <- lapply(out, function(fold_res) {
  attr(fold_res, "error_msg", exact = TRUE)
})
fold_failed <- !vapply(returned_errors, is.null, logical(1))
if (any(fold_failed)) {
  fold_errors[fold_failed] <- unlist(returned_errors[fold_failed], use.names = FALSE)
}
# collect results -----------------------------------------------------------
# Flatten the per-fold, per-learner results in `out` into metric rows,
# prediction tables, preprocessing states, fitted learners and audit rows, and
# build one status row per fold.
met_rows <- list()
preds <- list()
guards <- list()
lears <- list()
featn <- NULL
audit_rows <- list()
fold_status_rows <- vector("list", length(out))
for (fold_idx in seq_along(out)) {
fold_res <- out[[fold_idx]]
learner_total <- length(learner_names)
learner_success <- 0L
fold_elapsed <- attr(fold_res, "elapsed_sec") %||% NA_real_
# Folds with no named learner entries get no status row here; the loop after
# this one marks them as failed.
if (is.null(fold_res) || !length(names(fold_res))) next
fold_info <- resolve_fold_indices(folds[[fold_idx]])
fold_id <- fold_idx
for (ln in names(fold_res)) {
res <- fold_res[[ln]]
if (is.null(res)) next
m <- res$metrics
# Skip learners with no usable metrics; all(is.na(.)) is also TRUE for a
# zero-length metric vector.
if (all(is.na(m))) {
next
}
learner_success <- learner_success + 1L
metric_row <- c(list(fold = fold_id, learner = ln), as.list(m))
met_rows[[length(met_rows) + 1]] <- as.data.frame(metric_row,
row.names = NULL,
check.names = FALSE)
if (is.data.frame(res$pred) || is.matrix(res$pred)) {
preds[[length(preds) + 1]] <- res$pred
}
if (!is.null(res$guard) && is.list(res$guard)) {
guards[[length(guards) + 1]] <- res$guard
# Feature names are taken from the first result that provides them.
if (is.null(featn) && !is.null(res$feat_names)) featn <- res$feat_names
} else {
# NOTE(review): assigning NULL with `[[<-` does not append an element to a
# list, so this branch leaves `guards` unchanged rather than adding a
# placeholder. Confirm whether a placeholder was intended.
guards[[length(guards) + 1]] <- NULL
}
lears[[length(lears) + 1]] <- res$learner
audit_rows[[length(audit_rows) + 1]] <- data.frame(
fold = fold_id,
n_train = length(fold_info$train),
n_test = length(fold_info$test),
learner = ln,
# Prefer the count of features kept by the guard filter; otherwise the number
# of feature names reported by the fold.
features_final = if (!is.null(res$guard) &&
is.list(res$guard) &&
!is.null(res$guard$filter) &&
!is.null(res$guard$filter$keep)) {
sum(res$guard$filter$keep)
} else if (!is.null(res$feat_names)) {
length(res$feat_names)
} else NA_integer_,
row.names = NULL
)
}
# TRUE when every learner entry of this fold has a missing or zero-row
# prediction table.
all_preds_empty <- length(fold_res) > 0 && all(vapply(fold_res, function(res) {
is.null(res$pred) || nrow(res$pred) == 0L
}, logical(1)))
fold_status_rows[[fold_idx]] <- data.frame(
fold = fold_idx,
stage = if (learner_success > 0L) "fold_run" else "fold_precheck",
status = if (learner_success > 0L) "success" else "skipped",
reason = if (learner_success > 0L) NA_character_ else if (all_preds_empty) {
"single_class_training"
} else {
"no_valid_metrics"
},
notes = sprintf("Successful learners: %d/%d", learner_success, learner_total),
elapsed_sec = fold_elapsed,
stringsAsFactors = FALSE
)
}
# Any fold still without a status row is recorded as failed, using the stored
# error message when there is one.
for (fold_idx in seq_along(out)) {
if (!is.null(fold_status_rows[[fold_idx]])) next
err_note <- fold_errors[[fold_idx]]
if (is.na(err_note)) err_note <- "Fold failed before producing any learner results."
fold_elapsed <- attr(out[[fold_idx]], "elapsed_sec") %||% NA_real_
fold_status_rows[[fold_idx]] <- data.frame(
fold = fold_idx,
stage = "fold_run",
status = "failed",
reason = "fold_error",
notes = err_note,
elapsed_sec = fold_elapsed,
stringsAsFactors = FALSE
)
}
fold_status_df <- do.call(rbind, fold_status_rows)
# Abort when no fold/learner combination produced usable metrics.
if (!length(met_rows)) {
.bio_stop("No successful folds were completed. Check learner and preprocessing settings.",
"bioLeak_fit_error")
}
met_df <- do.call(rbind, met_rows)
audit_df <- do.call(rbind, audit_rows)
# Metric columns are everything except the fold and learner identifiers.
metrics_used <- setdiff(colnames(met_df), c("fold", "learner"))
# summarize metrics ---------------------------------------------------------
# Per-learner mean and sd of every metric column; the first column of met_df
# (fold) is dropped before aggregating. NA values are passed through to the
# summary function, which removes them.
metric_summary_raw <- aggregate(. ~ learner, data = met_df[, -1, drop = FALSE],
FUN = function(x) c(mean = mean(x, na.rm = TRUE),
sd = sd(x, na.rm = TRUE)),
na.action = stats::na.pass)
# flatten embedded matrices from aggregate() into separate columns
metric_summary <- data.frame(learner = metric_summary_raw$learner,
stringsAsFactors = FALSE)
for (col in setdiff(colnames(metric_summary_raw), "learner")) {
mat <- metric_summary_raw[[col]]
if (is.matrix(mat)) {
# Matrix-valued columns become "<metric>_mean" and "<metric>_sd".
metric_summary[[paste0(col, "_mean")]] <- mat[, "mean"]
metric_summary[[paste0(col, "_sd")]] <- mat[, "sd"]
} else {
metric_summary[[col]] <- mat
}
}
# confidence intervals -------------------------------------------------------
# Best-effort: any error raised while computing or merging the intervals is
# discarded and metric_summary is left without CI columns.
tryCatch({
avg_n_train <- mean(audit_df$n_train, na.rm = TRUE)
avg_n_test <- mean(audit_df$n_test, na.rm = TRUE)
ci_df <- cv_ci(met_df, method = "nadeau_bengio",
n_train = avg_n_train, n_test = avg_n_test)
# merge CI columns into metric_summary
ci_cols <- grep("_ci_lo$|_ci_hi$", names(ci_df), value = TRUE)
# NOTE(review): columns are copied positionally; only the row count is checked,
# so this assumes ci_df rows are in the same learner order - confirm.
if (length(ci_cols) && nrow(ci_df) == nrow(metric_summary)) {
for (cc in ci_cols) {
metric_summary[[cc]] <- ci_df[[cc]]
}
}
}, error = function(e) NULL)
# optional refit ------------------------------------------------------------
# When `refit` is set, train the first learner only on all rows, using the
# same preprocessing mode as the folds. The fitted model and preprocessing
# state are kept in final_model / final_guard; both stay NULL otherwise.
final_model <- NULL
final_guard <- NULL
if (refit) {
ln <- learner_names[[1]]
learner_obj <- learner_objs[[1]]
# Weights are only resolved for classification tasks.
full_weights <- if (task %in% c("binomial", "multiclass")) resolve_weights(yall, class_weights) else NULL
if (identical(preprocess_mode, "guard")) {
guard_full <- .guard_fit(Xall, yall, preprocess, task)
Xfullg <- guard_full$transform(Xall)
colnames(Xfullg) <- make.names(colnames(Xfullg))
# The full data is passed as both the training and the prediction set; only
# the fitted model is kept.
final_model <- train_one_learner(learner_obj, ln, Xfullg, yall, Xfullg, yall,
weights = full_weights)$fit
final_guard <- guard_full$state
} else if (identical(preprocess_mode, "recipe")) {
df_full <- make_fold_df(Xall_raw, yall)
recipe_prep <- recipes::prep(preprocess, training = df_full, retain = TRUE)
Xfullg <- as.data.frame(recipes::juice(recipe_prep, recipes::all_predictors()),
check.names = FALSE)
if (length(drop_cols)) {
keep_cols <- setdiff(names(Xfullg), drop_cols)
Xfullg <- Xfullg[, keep_cols, drop = FALSE]
}
colnames(Xfullg) <- make.names(colnames(Xfullg))
final_model <- train_one_learner(learner_obj, ln, Xfullg, yall, Xfullg, yall,
weights = full_weights)$fit
final_guard <- list(type = "recipe", recipe = recipe_prep)
} else {
# Any other mode: fit the learner object directly on the full data frame.
df_full <- make_fold_df(Xall, yall)
final_model <- generics::fit(learner_obj, data = df_full)
final_guard <- list(type = "workflow")
}
}
# assemble LeakFit object ---------------------------------------------------
# When store_refit_data is TRUE, keep the original inputs and learner
# configuration in perm_refit_spec; otherwise it stays NULL.
perm_refit_spec <- NULL
if (isTRUE(store_refit_data)) {
perm_refit_spec <- list(
x = x,
outcome = outcome,
preprocess = preprocess,
learner = learner_input,
learner_args = learner_args,
custom_learners = custom_learners,
class_weights = class_weights,
positive_class = positive_class,
classification_threshold = classification_threshold,
parallel = parallel
)
}
# Build the S4 result: per-fold metrics, their summary, the audit table,
# predictions, preprocessing states and fitted learners, plus run metadata
# in `info`.
new("LeakFit",
splits = splits,
metrics = met_df,
metric_summary = metric_summary,
audit = audit_df,
predictions = preds,
preprocess = guards,
learners = lears,
outcome = outcome,
task = task,
feature_names = featn,
info = list(hash = .bio_hash_indices(folds),
metrics_requested = metrics_input,
metrics_used = metrics_used,
class_weights = class_weights,
# positive_class and classification_threshold are only recorded for
# binomial tasks.
positive_class = if (task == "binomial") class_levels[2] else NULL,
classification_threshold = if (task == "binomial") classification_threshold else NULL,
sample_ids = ids,
# One seed per fold: base seed plus the fold's position in `folds`.
fold_seeds = setNames(seed + seq_along(folds),
paste0("fold", seq_along(folds))),
fold_status = fold_status_df,
refit = refit,
final_model = final_model,
final_preprocess = final_guard,
learner_names = learner_names,
perm_refit_spec = perm_refit_spec,
provenance = .bio_capture_provenance()))
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.