# tests/testthat/test-fit-resample-edge.R

test_that("fit_resample validates custom learner specifications", {
  # Malformed custom_learners arguments must be rejected up front.
  dat <- make_class_df(10)
  plan <- make_split_plan_quiet(
    dat,
    outcome = "outcome",
    mode = "subject_grouped",
    group = "subject",
    v = 2,
    seed = 1
  )

  # A bare string is not a named list of learners.
  expect_error(
    fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                       learner = "glm", custom_learners = "bad"),
    "custom_learners must be a named list"
  )

  # An unnamed list is rejected for the same reason.
  expect_error(
    fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                       learner = "glm", custom_learners = list(list())),
    "custom_learners must be a named list"
  )

  # Each entry must supply both a fit and a predict function.
  incomplete <- list(glm = list(fit = function(...) NULL))
  expect_error(
    fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                       learner = "glm", custom_learners = incomplete),
    "fit.*predict"
  )
})

test_that("fit_resample validates class weights and positive class", {
  dat <- make_class_df(12)
  plan <- make_split_plan_quiet(
    dat,
    outcome = "outcome",
    mode = "subject_grouped",
    group = "subject",
    v = 3,
    seed = 1
  )
  learners <- make_custom_learners()

  # Weights that do not cover every outcome level should error.
  expect_error(
    fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                       learner = "glm", custom_learners = learners,
                       class_weights = c(a = 1),
                       metrics = "auc", refit = FALSE),
    "missing levels"
  )

  # positive_class must be a single scalar value.
  expect_error(
    fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                       learner = "glm", custom_learners = learners,
                       positive_class = c(0, 1),
                       metrics = "auc", refit = FALSE),
    "single value"
  )
})

test_that("fit_resample supports class weights with parsnip learners", {
  skip_if_not_installed("parsnip")
  dat <- make_class_df(12)
  plan <- make_split_plan_quiet(
    dat,
    outcome = "outcome",
    mode = "subject_grouped",
    group = "subject",
    v = 3,
    seed = 1
  )
  # Passing class_weights alongside a parsnip spec should still produce metrics.
  model_spec <- parsnip::set_engine(parsnip::logistic_reg(), "glm")
  res <- fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                            learner = model_spec,
                            class_weights = c("0" = 1, "1" = 1),
                            metrics = "auc", refit = FALSE)
  expect_true(nrow(res@metrics) > 0)
})

test_that("fit_resample drops invalid metrics with warnings", {
  dat <- make_class_df(10)
  plan <- make_split_plan_quiet(
    dat,
    outcome = "outcome",
    mode = "subject_grouped",
    group = "subject",
    v = 2,
    seed = 1
  )
  learners <- make_custom_learners()

  # rmse does not apply to this task; it should be dropped with a warning,
  # while the remaining valid metric still produces results.
  res <- expect_warning_match(
    fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                       learner = "glm", custom_learners = learners,
                       metrics = c("auc", "rmse"), refit = FALSE),
    "Dropping metrics"
  )
  expect_true(nrow(res@metrics) > 0)
})

test_that("fit_resample summarizes metrics with all-NA columns", {
  dat <- make_class_df(10)
  plan <- make_split_plan_quiet(
    dat,
    outcome = "outcome",
    mode = "subject_grouped",
    group = "subject",
    v = 2,
    seed = 1
  )
  learners <- make_custom_learners()
  # A metric that always yields NA; the summary should still include its column.
  always_na <- function(y, pred) NA_real_

  res <- fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                            learner = "glm", custom_learners = learners,
                            metrics = list(auc = "auc", na_metric = always_na),
                            refit = FALSE)
  expect_true(nrow(res@metric_summary) > 0)
  expect_true("na_metric_mean" %in% colnames(res@metric_summary))
})

test_that("fit_resample handles gaussian tasks and ignores classification options", {
  dat <- make_regression_df(12)
  plan <- make_split_plan_quiet(
    dat,
    outcome = "y",
    mode = "subject_grouped",
    group = "subject",
    v = 3,
    seed = 1
  )
  learners <- make_custom_learners()

  # Classification-only options on a regression outcome should warn (not
  # error) and the task should be detected as gaussian.
  res <- expect_warning_match(
    fit_resample_quiet(dat, outcome = "y", splits = plan,
                       learner = "glm", custom_learners = learners,
                       class_weights = c(a = 1, b = 2),
                       positive_class = "a",
                       metrics = c("rmse", "cindex"), refit = FALSE),
    "ignored",
    all = TRUE
  )
  expect_equal(res@task, "gaussian")
})

# NOTE(review): disabled test — it expects fit_resample to error on a
# three-level factor outcome. Re-enable (or delete) once multiclass outcome
# validation behavior is settled; document the decision either way.
# test_that("fit_resample errors on unsupported outcomes", {
#   df <- make_class_df(10)
#   df$outcome <- factor(c("a", "b", "c", "a", "b", "c", "a", "b", "c", "a"))
#   splits <- make_split_plan_quiet(df, outcome = "outcome",
#                               mode = "subject_grouped", group = "subject",
#                               v = 2, seed = 1)
#   custom <- make_custom_learners()
#   expect_error(fit_resample_quiet(df, outcome = "outcome", splits = splits,
#                                   learner = "glm", custom_learners = custom),
#                "No successful folds were completed. Check learner and preprocessing settings.")
# })

test_that("fit_resample respects positive_class releveling", {
  dat <- make_class_df(12)
  # Recode to labeled levels where "yes" is NOT the reference level.
  dat$outcome <- factor(ifelse(dat$outcome == 1, "yes", "no"),
                        levels = c("no", "yes"))
  plan <- make_split_plan_quiet(
    dat,
    outcome = "outcome",
    mode = "subject_grouped",
    group = "subject",
    v = 3,
    seed = 1
  )
  learners <- make_custom_learners()
  res <- fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                            learner = "glm", custom_learners = learners,
                            positive_class = "yes", metrics = "auc",
                            refit = FALSE)
  expect_equal(res@info$positive_class, "yes")
})

test_that("fit_resample applies classification_threshold for binomial tasks", {
  dat <- make_class_df(12)
  dat$x1 <- seq(0.05, 0.95, length.out = nrow(dat))
  # Deterministic learner: predicted probability is simply the x1 column,
  # so the threshold alone controls the hard class labels.
  learners <- list(
    fixed_prob = list(
      fit = function(x, y, task, weights, ...) list(),
      predict = function(object, newdata, task, ...) as.numeric(newdata$x1)
    )
  )
  plan <- make_split_plan_quiet(
    dat,
    outcome = "outcome",
    mode = "subject_grouped",
    group = "subject",
    v = 3,
    seed = 2
  )

  # Run the identical resampling at a given threshold.
  run_at <- function(threshold) {
    fit_resample_quiet(
      dat, outcome = "outcome", splits = plan,
      learner = "fixed_prob", custom_learners = learners,
      metrics = "accuracy", classification_threshold = threshold,
      refit = FALSE
    )
  }
  fit_low <- run_at(0.2)
  fit_high <- run_at(0.8)

  pred_low <- do.call(rbind, fit_low@predictions)
  pred_high <- do.call(rbind, fit_high@predictions)
  # Align rows by (fold, id) before comparing labels across the two runs.
  ord_low <- order(pred_low$fold, pred_low$id)
  ord_high <- order(pred_high$fold, pred_high$id)
  expect_false(identical(as.character(pred_low$pred_class[ord_low]),
                         as.character(pred_high$pred_class[ord_high])))
  expect_false(isTRUE(all.equal(fit_low@metrics$accuracy,
                                fit_high@metrics$accuracy)))
  expect_equal(fit_low@info$classification_threshold, 0.2)
  expect_equal(fit_high@info$classification_threshold, 0.8)

  # Thresholds outside [0, 1] are rejected.
  expect_error(
    fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                       learner = "fixed_prob", custom_learners = learners,
                       classification_threshold = 1.2, refit = FALSE),
    "classification_threshold must be"
  )
})

test_that("fit_resample surfaces custom learner prediction length errors", {
  dat <- make_class_df(10)
  plan <- make_split_plan_quiet(
    dat,
    outcome = "outcome",
    mode = "subject_grouped",
    group = "subject",
    v = 2,
    seed = 1
  )
  # predict() deliberately returns one prediction too many, so every fold
  # fails and the run aborts with the "no successful folds" error.
  learners <- list(
    bad = list(
      fit = function(x, y, task, weights, ...) suppressWarnings(stats::glm(
        y ~ ., data = data.frame(y = y, x), family = stats::binomial()
      )),
      predict = function(object, newdata, task, ...) rep(0.5, nrow(newdata) + 1)
    )
  )
  expect_warning_match(
    expect_error(fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                                    learner = "bad", custom_learners = learners,
                                    metrics = "auc", refit = FALSE),
                 "No successful folds were completed"),
    "Custom learner 'bad' returned"
  )
})

test_that("fit_resample warns when a fold lacks both classes", {
  dat <- make_class_df(8)
  dat$outcome <- rep(c(0, 1), each = 4)
  # Fold 1 trains on rows 1:4 (all class 0) and should be skipped with a
  # warning; fold 2 sees both classes and should succeed.
  manual_folds <- list(
    list(train = 1:4, test = 5:6, fold = 1, repeat_id = 1),
    list(train = c(1:3, 5:8), test = 4, fold = 2, repeat_id = 1)
  )
  plan <- bioLeak:::LeakSplits(mode = "custom", indices = manual_folds,
                               info = list(outcome = "outcome", coldata = dat))
  learners <- make_custom_learners()
  res <- expect_warning_match(
    fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                       learner = "glm", custom_learners = learners,
                       metrics = "accuracy", refit = FALSE),
    "only one class"
  )
  expect_true(nrow(res@metrics) > 0)
  status_tbl <- res@info$fold_status
  expect_true(is.data.frame(status_tbl))
  expect_equal(nrow(status_tbl), length(plan@indices))
  expect_true(any(status_tbl$status == "success"))
  expect_true(any(status_tbl$status == "skipped"))
})

test_that("fit_resample errors for compact time_series without time metadata", {
  dat <- make_class_df(10)
  plan <- make_split_plan_quiet(
    dat,
    outcome = "outcome",
    mode = "time_series",
    time = "time",
    v = 2,
    seed = 1,
    compact = TRUE
  )
  # Strip the stored time metadata to simulate a corrupted split object.
  plan@info$time <- NULL
  learners <- make_custom_learners()
  expect_error(
    fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                       learner = "glm", custom_learners = learners,
                       metrics = "auc", refit = FALSE),
    "time_series compact splits"
  )
})

test_that("fit_resample handles compact splits with repeats", {
  dat <- make_class_df(30)
  plan <- make_split_plan_quiet(
    dat,
    outcome = "outcome",
    mode = "subject_grouped",
    group = "subject",
    v = 3,
    repeats = 2,
    compact = TRUE,
    seed = 1
  )
  res <- fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                            learner = "glm",
                            custom_learners = make_custom_learners(),
                            metrics = "auc", refit = FALSE, seed = 1)
  all_preds <- do.call(rbind, res@predictions)
  per_fold <- table(all_preds$fold)
  # Every expanded (fold x repeat) slot should contribute predictions.
  expect_equal(length(per_fold), length(plan@indices))
  expect_true(all(per_fold > 0))
})

test_that("fit_resample resolves compact time_series splits with purge and embargo", {
  dat <- make_class_df(12)
  plan <- make_split_plan_quiet(
    dat,
    outcome = "outcome",
    mode = "time_series",
    time = "time",
    v = 3,
    compact = TRUE,
    seed = 1,
    purge = 2,
    embargo = 5
  )
  res <- fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                            learner = "glm",
                            custom_learners = make_custom_learners(),
                            metrics = "accuracy", refit = FALSE, seed = 1)

  # Purge/embargo trim the expanded training windows to these exact sizes,
  # and the split metadata must round-trip through the fit object.
  expect_equal(res@audit$n_train, c(3, 7))
  expect_equal(res@splits@info$purge, 2)
  expect_equal(res@splits@info$embargo, 5)
})

test_that("fit_resample supports parsnip learners when available", {
  skip_if_not_installed("parsnip")
  dat <- make_class_df(12)
  plan <- make_split_plan_quiet(
    dat,
    outcome = "outcome",
    mode = "subject_grouped",
    group = "subject",
    v = 3,
    seed = 1
  )
  model_spec <- parsnip::set_engine(parsnip::logistic_reg(), "glm")
  # learner_args do not apply to parsnip specs, so a warning is expected
  # while the fit still completes and records the learner label.
  res <- expect_warning_match(
    fit_resample_quiet(dat, outcome = "outcome", splits = plan,
                       learner = model_spec,
                       custom_learners = make_custom_learners(),
                       learner_args = list(alpha = 1),
                       metrics = "auc", refit = FALSE, seed = 1),
    "ignored"
  )
  expect_true(nrow(res@metrics) > 0)
  expect_true(any(nzchar(res@metrics$learner)))
})

# --- Website scrape residue (rdrr.io boilerplate), commented out so the file
# --- remains valid R. Not part of the test suite.
# Try the bioLeak package in your browser
# Any scripts or data that you put into this service are public.
# bioLeak documentation built on March 6, 2026, 1:06 a.m.