# tests/testthat/helper-postprocessing.R

library(parsnip)
library(workflows)
library(tune)
library(rsample)
library(recipes)
library(yardstick)
library(dplyr)
library(dials)
library(rlang)
library(tailor)
library(tibble)

# ------------------------------------------------------------------------------

# C5.0 classification tree whose minimum node size (`min_n`) is tuned.
dt_spec <- parsnip::decision_tree(
  mode = "classification",
  engine = "C5.0",
  min_n = tune()
)

# Nearest-neighbor classifier whose number of neighbors is tuned.
knn_cls_spec <-
  parsnip::nearest_neighbor(mode = "classification", neighbors = tune())


# Classification tailors. None of these objects is created unless the
# probably package is installed.
if (rlang::is_installed("probably")) {
  # Probability calibration with no `method` supplied.
  cls_cal <- tailor::tailor() |>
    tailor::adjust_probability_calibration()

  # Logistic probability calibration; nothing to tune.
  cls_est_post <- tailor::tailor() |>
    tailor::adjust_probability_calibration(method = "logistic")

  # Logistic calibration followed by a threshold tuned under the id "cut".
  cls_cal_tune_post <- tailor::tailor() |>
    tailor::adjust_probability_calibration(method = "logistic") |>
    tailor::adjust_probability_threshold(threshold = tune("cut"))

  # Fixed probability threshold of one tenth.
  cls_tenth <- tailor::tailor() |>
    tailor::adjust_probability_threshold(threshold = 1 / 10)

  # Threshold alone, tuned under the id "cut".
  cls_post <- tailor::tailor() |>
    tailor::adjust_probability_threshold(threshold = tune("cut"))
}

# Zero-row prototypes of two-class prediction tibbles. The factor columns
# carry their levels but hold no values.

# Levels "Class1"/"Class2", outcome column `Class`.
fac_2c <- factor(character(0), levels = c("Class1", "Class2"))
cls_two_class_plist <- tibble::tibble(
  Class = fac_2c,
  .pred_class = fac_2c,
  .pred_Class1 = double(0),
  .pred_Class2 = double(0),
  .row = integer(0)
)

# Levels "class_1"/"class_2", outcome column `class`.
sim_2c <- factor(character(0), levels = c("class_1", "class_2"))
cls_sim_plist <- tibble::tibble(
  class = sim_2c,
  .pred_class = sim_2c,
  .pred_class_1 = double(0),
  .pred_class_2 = double(0),
  .row = integer(0)
)

# ------------------------------------------------------------------------------

# Small tuning grids; each column is named after the parameter it tunes.
dt_grid <- tibble::tibble(min_n = c(2, 4))
knn_grid <- tibble::tibble(neighbors = seq_len(3))
svm_grid <- tibble::tibble(degree = seq_len(2))

# Build a data set and Monte Carlo resamples for post-processing tests.
#
# `mode` selects the data:
#   * "classification": modeldata::sim_classification(1000); outcome `class`.
#   * "regression":     modeldata::sim_regression(1000); outcome `outcome`.
#   * "censored":       the `item*` columns of modeldata::deliveries plus a
#                       survival::Surv() outcome named `outcome` built from
#                       `time_to_delivery` (every 10th row is censored).
# Any other value raises an error.
#
# Side effects: calls set.seed(1), which changes the global RNG state. For
# "censored", it also attaches the survival package.
#
# Returns a list with:
#   data  - the data frame,
#   rs    - rsample::mc_cv() resamples (2 of them),
#   split - the first split of `rs`,
#   args  - the value of rsample::.get_split_args(rs),
#   y     - the name of the outcome column.
make_post_data <- function(mode = "classification") {
  set.seed(1)
  if (mode == "classification") {
    dat <- modeldata::sim_classification(1000)
    nm <- "class"
  } else if (mode == "regression") {
    dat <- modeldata::sim_regression(1000)
    nm <- "outcome"
  } else if (mode == "censored") {
    # library() errors immediately when survival is missing; require() would
    # only return FALSE with a warning and fail later.
    library(survival)
    dat <- modeldata::deliveries |>
      dplyr::select(time_to_delivery, starts_with("item"))
    # Deterministic event indicator: nine events, then one censored row,
    # recycled to the number of rows.
    evt <- rep_len(c(rep(1, 9), 0), nrow(dat))
    dat$outcome <- survival::Surv(dat$time_to_delivery, evt)
    dat$time_to_delivery <- NULL
    nm <- "outcome"
  } else {
    # cli exports `cli_abort()`, not `abort()`; calling `cli::abort()` would
    # fail with a "not an exported object" error instead of this message.
    cli::cli_abort(
      "Only have modes for classification, regression, and censored regression so far"
    )
  }
  rs <- rsample::mc_cv(dat, times = 2)
  rs_split <- rs$splits[[1]]
  rs_args <- rsample::.get_split_args(rs)
  list(data = dat, rs = rs, split = rs_split, args = rs_args, y = nm)
}

# ------------------------------------------------------------------------------

# The `Puromycin` data as a tibble.
puromycin <- tibble::as_tibble(Puromycin)

# Recipe predicting `rate` from all other columns, with `state` converted to
# dummy variables.
puromycin_rec <-
  recipes::recipe(rate ~ ., data = puromycin) |>
  recipes::step_dummy(state)

# Same recipe plus a polynomial expansion of `conc` whose degree is tuned.
puromycin_tune_rec <-
  puromycin_rec |>
  recipes::step_poly(conc, degree = tune())

# Nearest-neighbor regression whose number of neighbors is tuned.
knn_reg_spec <-
  parsnip::nearest_neighbor(mode = "regression", neighbors = tune())

# Polynomial-kernel SVM regression with cost fixed at 1 and a tuned degree.
svm_spec <- parsnip::svm_poly(
  mode = "regression",
  cost = 1,
  degree = tune()
)

# Tailor with a single custom adjustment: add 10000 to `.pred`.
reg_post <-
  tailor::tailor() |>
  tailor::adjust_predictions_custom(.pred = .pred + 10000)


# Regression tailors. None of these objects is created unless the probably
# package is installed.
if (rlang::is_installed("probably")) {
  # Numeric calibration with no `method` supplied.
  reg_cal <- tailor::tailor() |>
    tailor::adjust_numeric_calibration()

  # Numeric calibration whose `method` is tuned.
  reg_cal_tune <- tailor::tailor() |>
    tailor::adjust_numeric_calibration(method = tune())

  # Range restriction only, with a tuned upper limit.
  reg_max <- tailor::tailor() |>
    tailor::adjust_numeric_range(upper_limit = tune())

  # Calibration followed by a range restriction with a tuned upper limit.
  reg_cal_max <- tailor::tailor() |>
    tailor::adjust_numeric_calibration() |>
    tailor::adjust_numeric_range(upper_limit = tune())
}

# glmnet linear regression; both `penalty` and `mixture` are tuned.
glmn_spec <-
  parsnip::linear_reg(mixture = tune(), penalty = tune()) |>
  parsnip::set_engine("glmnet")

# Zero-row prototypes of regression prediction tibbles.

# Simulated data: numeric outcome column `outcome`.
reg_sim_plist <- tibble::tibble(
  outcome = double(0),
  .pred = double(0),
  .row = integer(0)
)

# Puromycin data: `rate` and `.pred` take their (empty) type from
# `puromycin$rate`.
puromycin_plist <- tibble::tibble(
  rate = puromycin[["rate"]][0],
  .pred = puromycin[["rate"]][0],
  .row = integer(0)
)

# ------------------------------------------------------------------------------

# Empty right-censored `Surv` object: a 0 x 2 numeric matrix with columns
# `time` and `status`.
surv_0 <- matrix(
  numeric(0),
  nrow = 0L,
  ncol = 2L,
  dimnames = list(NULL, c("time", "status"))
)
attr(surv_0, "type") <- "right"
class(surv_0) <- "Surv"

# Zero-row survival prediction tibble.
pred_0 <- tibble::tibble(
  .eval_time = numeric(0),
  .pred_survival = numeric(0)
)

# Zero-row survival prediction tibble that also has a censoring-weight column.
pred_dyn_0 <- tibble::tibble(
  .eval_time = numeric(0),
  .pred_survival = numeric(0),
  .weight_censored = numeric(0)
)

# Try the tune package in your browser
#
# Any scripts or data that you put into this service are public.
#
# tune documentation built on Sept. 1, 2025, 5:10 p.m.