# tests/testthat/test-schedule.R

# `get_param_info()` -----------------------------------------------------

test_that("`get_param_info()` works for a workflow without tags for tuning", {
  # Nothing in this workflow is tagged with `tune()`, so the parameter info
  # should have the expected columns but zero rows.
  untuned_wflow <- workflow(mpg ~ ., parsnip::linear_reg())
  info <- get_param_info(untuned_wflow)

  expected_cols <- c("name", "id", "source", "has_submodel")
  expect_named(info, expected_cols)
  expect_identical(nrow(info), 0L)
})

test_that("`get_param_info()` works for a workflow with tags for tuning", {
  skip_if_not_installed("splines2")
  skip_if_not_installed("probably")

  # Preprocessor, model, and postprocessor fixtures all carry tuning tags
  tagged_wflow <- workflow(rec_tune, mod_tune_no_submodel, tlr_tune)
  info <- get_param_info(tagged_wflow)

  expect_named(info, c("name", "id", "source", "has_submodel"))

  expected_ids <- c("min_n", "threshold", "disp_df", "lower_limit")
  expect_identical(info$id, expected_ids)

  # None of the four parameters should be flagged as a submodel parameter
  expect_identical(info$has_submodel, rep(FALSE, 4))
})

test_that("`get_param_info()` works when there are submodel parameters", {
  skip_if_not_installed("probably")

  # Only the model spec carries tuning tags; the recipe has no steps
  bare_recipe <- recipes::recipe(mpg ~ ., data = mtcars)
  info <- get_param_info(workflow(bare_recipe, mod_tune, tlr_no_tune))

  expect_named(info, c("name", "id", "source", "has_submodel"))
  expect_identical(info$name, c("trees", "min_n"))
  # `trees` should be flagged as a submodel parameter, `min_n` should not
  expect_identical(info$has_submodel, c(TRUE, FALSE))
})


# `schedule_predict_stage_i()` -------------------------------------------

test_that("`schedule_predict_stage_i()` works with: no submodel, no post-processing", {
  skip_if_not_installed("probably")

  # Nothing is tuned, so the predict-stage grid is an empty tibble
  info <- get_param_info(workflow(mpg ~ ., mod_no_tune, tlr_no_tune))
  empty_grid <- tibble::tibble()

  res <- schedule_predict_stage_i(empty_grid, info)
  expect_named(res, "post_stage")
  expect_identical(nrow(res), 0L)
})

test_that("`schedule_predict_stage_i()` works with: no submodel, with post-processing", {
  skip_if_not_installed("probably")

  info <- get_param_info(workflow(mpg ~ ., mod_no_tune, tlr_tune))
  post_grid <- tibble::tibble(lower_limit = 1:2)

  res <- schedule_predict_stage_i(post_grid, info)
  expect_named(res, "post_stage")
  expect_identical(nrow(res), 1L)
  # The whole grid should end up nested in the single `post_stage` entry
  expect_identical(res$post_stage[[1]], post_grid)
})

test_that("`schedule_predict_stage_i()` works with: with submodel, no post-processing", {
  skip_if_not_installed("probably")

  info <- get_param_info(workflow(mpg ~ ., mod_tune_submodel, tlr_no_tune))
  submodel_grid <- tibble::tibble(trees = 1:2)

  res <- schedule_predict_stage_i(submodel_grid, info)
  expect_named(res, c("trees", "post_stage"))
  expect_identical(res$trees, submodel_grid$trees)

  # No postprocessing parameters: each nested tibble should have zero columns
  post_ncols <- purrr::map_int(res$post_stage, ncol)
  expect_identical(post_ncols, c(0L, 0L))
})

test_that("`schedule_predict_stage_i()` works with: with submodel, with post-processing", {
  skip_if_not_installed("probably")

  info <- get_param_info(workflow(mpg ~ ., mod_tune_submodel, tlr_tune))
  # semi-regular grid: `trees = 1` has one `lower_limit`, `trees = 2` has two
  semi_regular <- purrr::list_rbind(
    list(
      tibble::tibble(trees = 1L, lower_limit = 1L),
      tibble::tibble(trees = c(2L, 2L), lower_limit = 1:2)
    )
  )

  res <- schedule_predict_stage_i(semi_regular, info)
  expect_named(res, c("trees", "post_stage"))
  expect_identical(res$trees, 1:2)

  # Each `trees` value should carry only its own postprocessing candidates
  expected_post <- list(
    tibble::tibble(lower_limit = 1L),
    tibble::tibble(lower_limit = 1:2)
  )
  for (k in seq_along(expected_post)) {
    expect_identical(res$post_stage[[k]], expected_post[[k]])
  }
})


# `schedule_model_stage_i()` ---------------------------------------------

test_that("`schedule_model_stage_i()` works with: no tuning at all", {
  skip_if_not_installed("probably")

  # With nothing tuned, the model-stage grid is an empty tibble
  untuned_wflow <- workflow(mpg ~ ., mod_no_tune, tlr_no_tune)
  info <- get_param_info(untuned_wflow)
  empty_grid <- tibble::tibble()

  res <- schedule_model_stage_i(empty_grid, info, untuned_wflow)
  expect_named(res, "predict_stage")
  expect_identical(nrow(res), 0L)
})

test_that("`schedule_model_stage_i()` works with only non-submodel: with non-submodel, no submodel, no post", {
  skip_if_not_installed("probably")

  model_wflow <- workflow(mpg ~ ., mod_tune_no_submodel, tlr_no_tune)
  info <- get_param_info(model_wflow)
  min_n_grid <- tibble::tibble(min_n = 1:2)

  res <- schedule_model_stage_i(min_n_grid, info, model_wflow)

  # One model-stage row per `min_n` value
  expect_named(res, c("min_n", "predict_stage"))
  expect_identical(nrow(res), 2L)
  expect_identical(res$min_n, 1:2)

  # Every nested predict stage is a tibble holding only `post_stage`
  nested <- res$predict_stage
  first_class <- purrr::map_chr(nested, \(x) class(x)[1])
  expect_identical(first_class, rep("tbl_df", 2))

  stage_names <- purrr::map_chr(nested, names)
  expect_identical(stage_names, rep("post_stage", 2))
})

test_that("`schedule_model_stage_i()` works with only submodel: no non-submodel, with submodel, no post", {
  skip_if_not_installed("probably")

  submodel_wflow <- workflow(mpg ~ ., mod_tune_submodel, tlr_no_tune)
  info <- get_param_info(submodel_wflow)
  trees_grid <- tibble::tibble(trees = 1:2)

  res <- schedule_model_stage_i(trees_grid, info, submodel_wflow)

  # Both grid rows should collapse into a single model-stage row that carries
  # the largest `trees` value.
  expect_named(res, c("trees", "predict_stage"))
  expect_identical(nrow(res), 1L)
  expect_identical(res$trees, 2L)

  # The nested predict stage should still list every `trees` value
  nested <- res$predict_stage
  expect_identical(purrr::map_chr(nested, \(x) class(x)[1]), "tbl_df")
  expect_identical(purrr::map(nested, names), list(c("trees", "post_stage")))
  expect_identical(pull(nested[[1]], "trees"), 1:2)
})

test_that("`schedule_model_stage_i()` works with only post: no non-submodel, no submodel, with post", {
  skip_if_not_installed("probably")

  post_wflow <- workflow(mpg ~ ., mod_no_tune, tlr_tune)
  info <- get_param_info(post_wflow)
  post_grid <- tibble::tibble(lower_limit = 1:2)

  res <- schedule_model_stage_i(post_grid, info, post_wflow)

  # No model parameters: a single row with only the nested predict stage
  expect_named(res, "predict_stage")
  expect_identical(nrow(res), 1L)

  nested <- res$predict_stage
  expect_identical(purrr::map_chr(nested, \(x) class(x)[1]), "tbl_df")
  expect_identical(purrr::map(nested, names), list("post_stage"))

  # Both `lower_limit` candidates should sit together in one post-stage tibble
  expect_identical(
    pull(nested[[1]], "post_stage"),
    list(tibble::tibble(lower_limit = 1:2))
  )
})

test_that("`schedule_model_stage_i()` works with both model types only: with non-submodel, with submodel, no post", {
  skip_if_not_installed("probably")

  model_wflow <- workflow(mpg ~ ., mod_tune, tlr_no_tune)
  info <- get_param_info(model_wflow)
  # irregular grid: each `min_n` value is paired with a different set of `trees`
  irregular_grid <- purrr::list_rbind(
    list(
      tibble::tibble(trees = 1:2, min_n = 1L),
      tibble::tibble(trees = 2L, min_n = 2L),
      tibble::tibble(trees = 2:3, min_n = 3L)
    )
  )

  res <- schedule_model_stage_i(irregular_grid, info, model_wflow)

  expect_named(res, c("trees", "min_n", "predict_stage"))
  expect_identical(nrow(res), 3L)
  # The model stage should carry the largest `trees` value for each `min_n`
  expect_identical(res$trees, c(2L, 2L, 3L))
  expect_identical(res$min_n, 1:3)

  nested <- res$predict_stage
  expect_identical(unique(purrr::map_chr(nested, \(x) class(x)[1])), "tbl_df")
  expect_identical(
    purrr::map(nested, names),
    rep(list(c("trees", "post_stage")), 3)
  )

  # `trees` values scheduled for prediction under a given `min_n`
  predicted_trees <- function(min_n_value) {
    res |>
      filter(min_n == min_n_value) |>
      pull("predict_stage") |>
      purrr::pluck(1) |>
      pull("trees")
  }

  expect_identical(predicted_trees(1), 1:2)
  expect_identical(predicted_trees(2), 2L)
  expect_identical(predicted_trees(3), 2:3)
})

test_that("`schedule_model_stage_i()` works without submodel: with non-submodel, no submodel, with post", {
  skip_if_not_installed("probably")

  wflow <- workflow(mpg ~ ., mod_tune_no_submodel, tlr_tune)
  info <- get_param_info(wflow)
  # semi-regular grid: every `min_n` has its own set of `lower_limit` values
  semi_regular <- purrr::list_rbind(
    list(
      tibble::tibble(min_n = 1L, lower_limit = 1:2),
      tibble::tibble(min_n = 2L, lower_limit = 1:3),
      tibble::tibble(min_n = 3L, lower_limit = 3L)
    )
  )

  res <- schedule_model_stage_i(semi_regular, info, wflow)
  expect_named(res, c("min_n", "predict_stage"))
  expect_identical(nrow(res), 3L)
  expect_identical(res$min_n, 1:3)

  nested <- res$predict_stage
  expect_identical(unique(purrr::map_chr(nested, \(x) class(x)[1])), "tbl_df")
  expect_identical(
    unique(purrr::list_c(purrr::map(nested, names))),
    "post_stage"
  )

  # `lower_limit` values scheduled for postprocessing under a given `min_n`
  post_limits <- function(min_n_value) {
    res |>
      filter(min_n == min_n_value) |>
      pull("predict_stage") |>
      purrr::pluck(1) |>
      pull("post_stage") |>
      purrr::pluck(1) |>
      pull("lower_limit")
  }

  expect_identical(post_limits(1), 1:2)
  expect_identical(post_limits(2), 1:3)
  expect_identical(post_limits(3), 3L)
})

test_that("`schedule_model_stage_i()` works everything: with non-submodel, with submodel, with post", {
  skip_if_not_installed("probably")

  wflow <- workflow(mpg ~ ., mod_tune, tlr_tune)
  info <- get_param_info(wflow)
  # semi-regular grid
  semi_regular <- purrr::list_rbind(
    list(
      tibble::tibble(
        min_n = 1L,
        trees = rep(1:2, each = 2),
        lower_limit = rep(1:2, times = 2)
      ),
      tibble::tibble(
        min_n = 2L,
        trees = c(1L, 2L, 2L, 2L),
        lower_limit = c(1L, 1:3)
      ),
      # another row to be combined under min_n = 1
      tibble::tibble(min_n = 1L, trees = 3L, lower_limit = 4L)
    )
  )

  res <- schedule_model_stage_i(semi_regular, info, wflow)
  expect_named(res, c("trees", "min_n", "predict_stage"))
  expect_identical(nrow(res), 2L)
  expect_identical(res$min_n, 1:2)

  nested <- res$predict_stage
  expect_identical(unique(purrr::map_chr(nested, \(x) class(x)[1])), "tbl_df")
  expect_identical(
    unique(purrr::list_c(purrr::map(nested, names))),
    c("trees", "post_stage")
  )

  # Model-stage row (without the nested column) for a given `min_n`
  model_row <- function(min_n_value) {
    res |>
      filter(min_n == min_n_value) |>
      select(-predict_stage)
  }
  # Nested predict-stage tibble for a given `min_n`
  predict_tbl <- function(min_n_value) {
    res |>
      filter(min_n == min_n_value) |>
      pull("predict_stage") |>
      purrr::pluck(1)
  }
  # All `lower_limit` values for a given `min_n`, stacked across `trees`
  stacked_limits <- function(min_n_value) {
    predict_tbl(min_n_value) |>
      pull("post_stage") |>
      purrr::list_rbind() |>
      pull("lower_limit")
  }

  # min_n = 1: fit with 3 trees, predict for 1:3
  expect_identical(model_row(1), tibble::tibble(trees = 3L, min_n = 1L))
  expect_identical(pull(predict_tbl(1), "trees"), 1:3)
  expect_identical(stacked_limits(1), c(1:2, 1:2, 4L))

  # min_n = 2: fit with 2 trees, predict for 1:2
  expect_identical(model_row(2), tibble::tibble(trees = 2L, min_n = 2L))
  expect_identical(pull(predict_tbl(2), "trees"), 1:2)
  expect_identical(stacked_limits(2), c(1L, 1:3))
})


# `schedule_stages()` ----------------------------------------------------

test_that("`schedule_stages()` works without preprocessing", {
  skip_if_not_installed("probably")

  # Nothing is tuned, so the grid is an empty tibble
  untuned_wflow <- workflow(mpg ~ ., mod_no_tune, tlr_no_tune)
  res <- schedule_stages(tibble::tibble(), untuned_wflow)

  expect_named(res, "model_stage")
  expect_identical(nrow(res), 0L)
})

test_that("`schedule_stages()` works with preprocessing", {
  skip_if_not_installed("splines2")
  skip_if_not_installed("probably")

  recipe_wflow <- workflow(rec_tune, mod_no_tune, tlr_no_tune)
  # Full 2 x 2 crossing of `threshold` and `disp_df`
  recipe_grid <- tibble::tibble(
    threshold = rep(1:2, each = 2),
    disp_df = rep(1:2, times = 2)
  )

  res <- schedule_stages(recipe_grid, recipe_wflow)
  expect_named(res, c("threshold", "disp_df", "model_stage"))
  expect_identical(nrow(res), 4L)
  # Apart from the nested column, the schedule should reproduce the grid
  expect_identical(dplyr::select(res, -model_stage), recipe_grid)
})

# `schedule_grid()` ------------------------------------------------------

# No tuning or postprocessing estimation

test_that("grid processing schedule - no parameters", {
  # Formula preprocessor plus a default model spec: nothing to tune
  plain_wflow <- workflow(outcome ~ ., parsnip::logistic_reg())
  sched <- schedule_grid(tibble::tibble(), plain_wflow)

  expect_named(sched, "model_stage")
  expect_equal(nrow(sched), 0)

  resample_classes <- c(
    "resample_schedule", "schedule", "tbl_df", "tbl", "data.frame"
  )
  expect_s3_class(sched, resample_classes)
})

test_that("grid processing schedule - recipe and model", {
  skip_if_not_installed("splines2")

  # Recipe and model are both present but neither has tuning tags
  untuned_wflow <- workflow(rec_no_tune, parsnip::logistic_reg())
  sched <- schedule_grid(tibble::tibble(), untuned_wflow)

  expect_named(sched, "model_stage")
  expect_equal(nrow(sched), 0)

  resample_classes <- c(
    "resample_schedule", "schedule", "tbl_df", "tbl", "data.frame"
  )
  expect_s3_class(sched, resample_classes)
})

test_that("grid processing schedule - recipe, model, and post", {
  skip_if_not_installed("splines2")
  skip_if_not_installed("probably")

  # All three workflow components are present, none has tuning tags
  untuned_wflow <- workflow(rec_no_tune, parsnip::linear_reg(), tlr_no_tune)
  sched <- schedule_grid(tibble::tibble(), untuned_wflow)

  expect_named(sched, "model_stage")
  expect_equal(nrow(sched), 0)

  resample_classes <- c(
    "resample_schedule", "schedule", "tbl_df", "tbl", "data.frame"
  )
  expect_s3_class(sched, resample_classes)
})

# Tuning, no postprocessing estimation

test_that("grid processing schedule - recipe only", {
  skip_if_not_installed("splines2")

  recipe_wflow <- workflow(rec_tune, parsnip::logistic_reg())
  recipe_params <- extract_parameter_set_dials(recipe_wflow)
  # Sort the grid so that it can be compared row-for-row with the schedule
  recipe_grid <- arrange(
    dials::grid_regular(recipe_params, levels = 3),
    threshold,
    disp_df
  )
  sched <- schedule_grid(recipe_grid, recipe_wflow)

  expect_named(sched, c("threshold", "disp_df", "model_stage"))
  expect_equal(nrow(sched), nrow(recipe_grid))

  # All of the other nested tibbles should be empty, so fully unnesting the
  # schedule should give back the original grid.
  flattened <- sched |>
    tidyr::unnest(model_stage) |>
    tidyr::unnest(predict_stage) |>
    tidyr::unnest(post_stage)
  expect_equal(flattened, recipe_grid)

  grid_classes <- c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  expect_s3_class(sched, grid_classes)
})

test_that("grid processing schedule - model only, no submodels", {
  # Only the model spec is tuned and it has no submodel parameter
  wflow_rf_only <- workflow(outcome ~ ., mod_tune_no_submodel)
  grid_rf_only <-
    extract_parameter_set_dials(wflow_rf_only) |>
    dials::grid_regular(levels = 3)
  sched_rf_only <-
    schedule_grid(grid_rf_only, wflow_rf_only)

  # No preprocessing parameters: a single row with only the nested model stage
  expect_named(sched_rf_only, c("model_stage"))
  expect_equal(nrow(sched_rf_only), 1L)

  # `seq_along()` rather than `1:n`: if `model_stage` were unexpectedly empty,
  # `1:0` would iterate over positions 1 and 0 instead of not iterating at all.
  for (i in seq_along(sched_rf_only$model_stage)) {
    # No real need for the loop here
    expect_named(sched_rf_only$model_stage[[i]], c("min_n", "predict_stage"))
    # Unnesting the later stages should give back the original grid
    expect_equal(
      sched_rf_only$model_stage[[i]] |>
        tidyr::unnest(predict_stage) |>
        tidyr::unnest(post_stage),
      grid_rf_only
    )
  }

  expect_s3_class(
    sched_rf_only,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})

test_that("grid processing schedule - model only, submodels, regular grid", {
  wflow_bst <- workflow(outcome ~ ., mod_tune)
  grid_bst <- extract_parameter_set_dials(wflow_bst) |>
    dials::grid_regular(levels = 3)

  # Sorted distinct values of each parameter, used as references below
  min_n_only <- grid_bst |> dplyr::distinct(min_n) |> dplyr::arrange(min_n)
  trees_only <- grid_bst |> dplyr::distinct(trees) |> dplyr::arrange(trees)

  # regular grid
  sched_bst <- schedule_grid(grid_bst, wflow_bst)

  expect_named(sched_bst, c("model_stage"))
  expect_equal(nrow(sched_bst), 1L)

  # `seq_along()` rather than `1:n`: if `model_stage` were unexpectedly empty,
  # `1:0` would iterate over positions 1 and 0 instead of not iterating at all.
  for (i in seq_along(sched_bst$model_stage)) {
    expect_named(
      sched_bst$model_stage[[i]],
      c("trees", "min_n", "predict_stage")
    )

    # One model-stage row per distinct `min_n`
    expect_equal(
      sched_bst$model_stage[[i]] |>
        dplyr::select(-trees, -predict_stage),
      min_n_only
    )

    # Each predict stage should list every `trees` value of the regular grid
    for (j in seq_along(sched_bst$model_stage[[i]]$predict_stage)) {
      expect_named(
        sched_bst$model_stage[[i]]$predict_stage[[j]],
        c("trees", "post_stage")
      )
      expect_equal(
        sched_bst$model_stage[[i]]$predict_stage[[j]] |>
          dplyr::select(trees),
        trees_only
      )
    }

    # Drop the model-stage `trees` column and unnest: the `trees` values from
    # the predict stage, together with `min_n`, should reproduce the grid.
    expect_equal(
      sched_bst$model_stage[[i]] |>
        dplyr::select(-trees) |>
        tidyr::unnest(predict_stage) |>
        tidyr::unnest(post_stage) |>
        dplyr::select(trees, min_n),
      grid_bst
    )
  }

  expect_s3_class(
    sched_bst,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})

test_that("grid processing schedule - model only, submodels, SFD grid", {
  boost_wflow <- workflow(outcome ~ ., mod_tune)
  sfd_grid <- dials::grid_space_filling(
    extract_parameter_set_dials(boost_wflow),
    size = 5,
    type = "uniform"
  )
  sched <- schedule_grid(sfd_grid, boost_wflow)

  expect_named(sched, "model_stage")
  expect_equal(nrow(sched), 1L)
  expect_equal(length(sched$model_stage), 1L)

  model_stage <- sched$model_stage[[1]]
  expect_named(model_stage, c("trees", "min_n", "predict_stage"))

  # Reference: the original grid in a fixed column and row order
  sorted_grid <- sfd_grid |>
    dplyr::select(trees, min_n) |>
    dplyr::arrange(trees, min_n)

  # The model-stage rows themselves should match the grid ...
  fitted_models <- model_stage |>
    dplyr::select(-predict_stage) |>
    dplyr::select(trees, min_n) |>
    dplyr::arrange(trees, min_n)
  expect_equal(fitted_models, sorted_grid)

  # ... and so should the combinations recovered by unnesting later stages
  predicted_models <- model_stage |>
    dplyr::select(-trees) |>
    tidyr::unnest(predict_stage) |>
    tidyr::unnest(post_stage) |>
    dplyr::select(trees, min_n) |>
    dplyr::arrange(trees, min_n)
  expect_equal(predicted_models, sorted_grid)

  grid_classes <- c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  expect_s3_class(sched, grid_classes)
})

test_that("grid processing schedule - model only, submodels, irregular design", {
  wflow_bst <- workflow(outcome ~ ., mod_tune)
  # Only `min_n = 1` is paired with more than one `trees` value
  grid_odd_bst <- tibble::tibble(
    min_n = c(1, 1, 2, 3, 4, 5),
    trees = rep(1:2, 3)
  )
  sched_odd_bst <- schedule_grid(grid_odd_bst, wflow_bst)

  expect_named(sched_odd_bst, c("model_stage"))
  expect_equal(nrow(sched_odd_bst), 1L)

  odd_n <- length(sched_odd_bst$model_stage)
  expect_equal(odd_n, 1L)

  expect_named(
    sched_odd_bst$model_stage[[1]],
    c("trees", "min_n", "predict_stage")
  )
  # One model per `min_n`, carrying the largest `trees` value for that `min_n`
  expect_equal(
    sched_odd_bst$model_stage[[1]] |>
      dplyr::select(-predict_stage) |>
      dplyr::select(trees, min_n),
    tibble::tibble(trees = c(2, 1, 2, 1, 2), min_n = c(1, 2, 3, 4, 5))
  )

  # `seq_len(nrow())` rather than `1:nrow()` so that a zero-row model stage
  # skips the loop instead of iterating over positions 1 and 0.
  for (i in seq_len(nrow(sched_odd_bst$model_stage[[1]]))) {
    prd <- sched_odd_bst$model_stage[[1]]$predict_stage[[i]]
    # Two `trees` values to predict for `min_n = 1`, one for every other row
    if (sched_odd_bst$model_stage[[1]]$min_n[i] == 1) {
      expect_equal(nrow(prd), 2L)
    } else {
      expect_equal(nrow(prd), 1L)
    }
    # Every nested post-stage tibble should have exactly one row
    expect_true(
      all(purrr::map_int(prd$post_stage, nrow) == 1)
    )
  }

  expect_s3_class(
    sched_odd_bst,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})

test_that("grid processing schedule - model only, submodels, 1 point design", {
  boost_wflow <- workflow(outcome ~ ., mod_tune)

  # Seed set so that the single random grid point is reproducible
  set.seed(1)
  one_point <- dials::grid_random(
    extract_parameter_set_dials(boost_wflow),
    size = 1
  )
  sched <- schedule_grid(one_point, boost_wflow)

  # Exactly one entry is expected at every nesting level
  expect_named(sched, "model_stage")
  expect_equal(nrow(sched), 1L)
  expect_equal(length(sched$model_stage), 1L)

  model_stage <- sched$model_stage[[1]]
  expect_named(model_stage, c("trees", "min_n", "predict_stage"))
  expect_equal(length(model_stage$predict_stage), 1L)

  predict_stage <- model_stage$predict_stage[[1]]
  expect_named(predict_stage, c("trees", "post_stage"))
  expect_equal(length(predict_stage$post_stage), 1L)

  # The innermost tibble should have one row and no columns
  expect_equal(dim(predict_stage$post_stage[[1]]), c(1L, 0L))

  single_classes <- c(
    "single_schedule",
    "grid_schedule",
    "schedule",
    "tbl_df",
    "tbl",
    "data.frame"
  )
  expect_s3_class(sched, single_classes)
})

test_that("grid processing schedule - postprocessing only", {
  skip_if_not_installed("probably")

  # Only the postprocessor is tuned; give `lower_limit` an explicit range
  # before building the regular grid.
  wflow_thrsh <- workflow(outcome ~ ., parsnip::linear_reg(), tlr_tune)
  grid_thrsh <- extract_parameter_set_dials(wflow_thrsh) |>
    update(lower_limit = dials::lower_limit(c(0, 1))) |>
    dials::grid_regular(levels = 3)

  sched_thrsh <- schedule_grid(grid_thrsh, wflow_thrsh)

  expect_named(sched_thrsh, c("model_stage"))
  expect_equal(nrow(sched_thrsh), 1L)

  expect_named(sched_thrsh$model_stage[[1]], c("predict_stage"))
  expect_equal(nrow(sched_thrsh$model_stage[[1]]), 1L)

  expect_named(
    sched_thrsh$model_stage[[1]]$predict_stage[[1]],
    c("post_stage")
  )
  # Check the row count of the predict-stage tibble itself. This assertion
  # previously repeated the model-stage row check made a few lines above.
  expect_equal(nrow(sched_thrsh$model_stage[[1]]$predict_stage[[1]]), 1L)

  # The full postprocessing grid should sit in the single post-stage tibble
  expect_equal(
    sched_thrsh$model_stage[[1]]$predict_stage[[1]]$post_stage[[1]],
    grid_thrsh
  )

  expect_s3_class(
    sched_thrsh,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})

test_that("grid processing schedule - recipe + postprocessing, regular grid", {
  skip_if_not_installed("splines2")
  skip_if_not_installed("probably")

  wflow <- workflow(rec_tune, parsnip::linear_reg(), tlr_tune)
  params <- update(
    extract_parameter_set_dials(wflow),
    lower_limit = dials::lower_limit(c(0, 1))
  )
  full_grid <- dials::grid_regular(params, levels = 3)

  # References: distinct recipe settings and sorted postprocessing candidates
  expected_pre <- distinct(full_grid, threshold, disp_df)
  expected_post <- arrange(distinct(full_grid, lower_limit), lower_limit)

  sched <- schedule_grid(full_grid, wflow)

  expect_named(sched, c("threshold", "disp_df", "model_stage"))
  expect_equal(
    tibble::as_tibble(select(sched, -model_stage)),
    expected_pre
  )

  # The model is not tuned: one row, holding only the nested predict stage
  model_stages <- sched$model_stage
  for (i in seq_along(model_stages)) {
    expect_named(model_stages[[i]], "predict_stage")
    expect_equal(nrow(model_stages[[i]]), 1L)
  }

  # Regular grid: every recipe setting sees the same postprocessing candidates
  for (i in seq_along(model_stages)) {
    predict_i <- model_stages[[i]]$predict_stage[[1]]
    expect_named(predict_i, "post_stage")
    expect_identical(
      arrange(predict_i$post_stage[[1]], lower_limit),
      expected_post
    )
  }

  grid_classes <- c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  expect_s3_class(sched, grid_classes)
})

test_that("grid processing schedule - recipe + postprocessing, irregular grid", {
  skip_if_not_installed("splines2")
  skip_if_not_installed("probably")

  wflow <- workflow(rec_tune, parsnip::linear_reg(), tlr_tune)
  params <- update(
    extract_parameter_set_dials(wflow),
    lower_limit = dials::lower_limit(c(0, 1))
  )
  # Dropping two rows from the regular grid makes it irregular
  full_grid <- dplyr::slice(dials::grid_regular(params), -c(1, 14))

  expected_pre <- distinct(full_grid, threshold, disp_df)

  # Sorted postprocessing candidates for each recipe setting
  nested_post <- mutate(
    dplyr::group_nest(full_grid, threshold, disp_df),
    data = purrr::map(data, \(.x) arrange(.x, lower_limit))
  )

  sched <- schedule_grid(full_grid, wflow)

  expect_named(sched, c("threshold", "disp_df", "model_stage"))
  expect_equal(
    tibble::as_tibble(select(sched, -model_stage)),
    expected_pre
  )

  # The model is not tuned: one row, holding only the nested predict stage
  model_stages <- sched$model_stage
  for (i in seq_along(model_stages)) {
    expect_named(model_stages[[i]], "predict_stage")
    expect_equal(nrow(model_stages[[i]]), 1L)
  }

  for (i in seq_along(model_stages)) {
    predict_i <- model_stages[[i]]$predict_stage[[1]]
    expect_named(predict_i, "post_stage")

    # Look up the candidates that belong to the i-th recipe setting
    pre_i <- select(slice(sched, i), threshold, disp_df)
    expected_post_i <- pre_i |>
      inner_join(nested_post, by = dplyr::join_by(threshold, disp_df)) |>
      purrr::pluck("data", 1) |>
      arrange(lower_limit)

    expect_identical(
      arrange(predict_i$post_stage[[1]], lower_limit),
      expected_post_i
    )
  }

  grid_classes <- c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  expect_s3_class(sched, grid_classes)
})

test_that("grid processing schedule - recipe + model, no submodels, regular grid", {
  skip_if_not_installed("splines2")

  wflow <- workflow(rec_tune, mod_tune_no_submodel)
  full_grid <- dials::grid_regular(extract_parameter_set_dials(wflow))

  # References: distinct recipe settings and sorted distinct `min_n` values
  expected_pre <- distinct(full_grid, threshold, disp_df)
  expected_model <- arrange(distinct(full_grid, min_n), min_n)

  sched <- schedule_grid(full_grid, wflow)

  expect_named(sched, c("threshold", "disp_df", "model_stage"))
  expect_equal(
    tibble::as_tibble(select(sched, -model_stage)),
    expected_pre
  )

  # Regular grid: every recipe setting is paired with every `min_n` value
  model_stages <- sched$model_stage
  for (i in seq_along(model_stages)) {
    expect_named(model_stages[[i]], c("min_n", "predict_stage"))
    expect_equal(
      arrange(select(model_stages[[i]], min_n), min_n),
      expected_model
    )
  }

  # Only the first predict stage of each model stage is inspected here
  for (i in seq_along(model_stages)) {
    first_predict <- model_stages[[i]]$predict_stage[[1]]
    expect_named(first_predict, "post_stage")
    expect_equal(nrow(first_predict), 1L)
    expect_equal(nrow(first_predict$post_stage[[1]]), 1L)
  }

  grid_classes <- c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  expect_s3_class(sched, grid_classes)
})

test_that("grid processing schedule - recipe + model, submodels, irregular grid", {
  skip_if_not_installed("splines2")

  # Recipe and model are both tuned; `trees` is the submodel parameter.
  wflow_pre_model <- workflow(rec_tune, mod_tune)
  grid_pre_model <-
    extract_parameter_set_dials(wflow_pre_model) |>
    dials::grid_regular() |>
    # This will make the submodel parameter (trees) unbalanced for some
    # combination of parameters of the other parameters.
    slice(-c(1, 2, 11))

  # Reference: the distinct recipe settings, in grid order.
  grid_pre <-
    grid_pre_model |>
    distinct(threshold, disp_df)

  # Reference: for each recipe setting, the largest `trees` value per `min_n`,
  # sorted by `min_n`. This is what the model stage is expected to contain.
  grid_model <-
    grid_pre_model |>
    dplyr::group_nest(threshold, disp_df) |>
    mutate(
      data = purrr::map(
        data,
        \(.x) .x |> dplyr::summarize(trees = max(trees), .by = c(min_n))
      ),
      data = purrr::map(data, \(.x) .x |> arrange(min_n))
    )

  sched_pre_model <- schedule_grid(grid_pre_model, wflow_pre_model)

  expect_named(sched_pre_model, c("threshold", "disp_df", "model_stage"))
  expect_equal(
    sched_pre_model |> select(-model_stage) |> tibble::as_tibble(),
    grid_pre
  )

  # Outer loop: one model stage per recipe setting.
  # NOTE(review): `grid_model$data[[i]]` is indexed by position, which assumes
  # `group_nest()` returns the recipe settings in the same order as the
  # schedule rows; confirm if the grid construction changes.
  for (i in seq_along(sched_pre_model$model_stage)) {
    model_i <- sched_pre_model$model_stage[[i]]
    expect_named(model_i, c("trees", "min_n", "predict_stage"))
    expect_equal(
      model_i |> select(min_n, trees) |> arrange(min_n),
      grid_model$data[[i]]
    )

    # Inner loop: one predict stage per model-stage row.
    for (j in seq_along(sched_pre_model$model_stage[[i]]$predict_stage)) {
      predict_j <- model_i$predict_stage[[j]]

      # We need to figure out the trees that need predicting for the current
      # set of other parameters.

      # Get the settings that have already been resolved:
      other_ij <-
        model_i |>
        select(-predict_stage, -trees) |>
        slice(j) |>
        vctrs::vec_cbind(
          sched_pre_model |>
            select(threshold, disp_df) |>
            slice(i)
        )
      # What are the matching values from the grid?
      trees_ij <-
        grid_pre_model |>
        inner_join(other_ij, by = c("min_n", "threshold", "disp_df")) |>
        select(trees)

      # Compare after sorting so that row order does not matter.
      expect_equal(
        predict_j |> select(trees) |> arrange(trees),
        trees_ij |> arrange(trees)
      )
    }
  }

  expect_s3_class(
    sched_pre_model,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})

test_that("grid processing schedule - recipe + model + tailor, submodels, irregular grid", {
  skip_if_not_installed("splines2")
  skip_if_not_installed("probably")

  # All three workflow components are tuned; `trees` is the submodel parameter.
  wflow_pre_model_post <- workflow(rec_tune, mod_tune, tlr_tune)
  grid_pre_model_post <-
    extract_parameter_set_dials(wflow_pre_model_post) |>
    update(lower_limit = dials::lower_limit(c(0, 1))) |>
    dials::grid_regular() |>
    # This will make the submodel parameter (trees) unbalanced for some
    # combination of parameters of the other parameters.
    slice(seq(1, 240, by = 7))

  # Reference: the distinct recipe settings, in grid order.
  grid_pre <-
    grid_pre_model_post |>
    distinct(threshold, disp_df)

  # For each recipe setting, the largest `trees` value per `min_n`.
  # NOTE(review): `grid_model` is computed but not referenced by any
  # expectation in this test.
  grid_model <-
    grid_pre_model_post |>
    select(-lower_limit) |>
    dplyr::group_nest(threshold, disp_df) |>
    mutate(
      data = purrr::map(
        data,
        \(.x) .x |> dplyr::summarize(trees = max(trees), .by = c(min_n))
      ),
      data = purrr::map(data, \(.x) .x |> arrange(min_n))
    )

  sched_pre_model_post <- schedule_grid(
    grid_pre_model_post,
    wflow_pre_model_post
  )

  expect_named(sched_pre_model_post, c("threshold", "disp_df", "model_stage"))
  expect_equal(
    sched_pre_model_post |> select(-model_stage) |> tibble::as_tibble(),
    grid_pre
  )

  # Outer loop: one model stage per recipe setting.
  for (i in seq_along(sched_pre_model_post$model_stage)) {
    model_i <- sched_pre_model_post$model_stage[[i]]

    # Get the current set of preproc parameters to remove
    other_i <-
      sched_pre_model_post[i, ] |>
      dplyr::select(-model_stage)

    # We expect to evaluate these specific models for this set of preprocessors
    exp_i <-
      grid_pre_model_post |>
      inner_join(other_i, by = c("threshold", "disp_df")) |>
      arrange(trees, min_n, lower_limit) |>
      select(trees, min_n, lower_limit)

    # What we will evaluate. `unnest()` is namespaced as `tidyr::unnest()` to
    # match the rest of this file and to avoid relying on tidyr being attached.
    subgrid_i <-
      model_i |>
      select(-trees) |>
      tidyr::unnest(predict_stage) |>
      tidyr::unnest(post_stage) |>
      arrange(trees, min_n, lower_limit) |>
      select(trees, min_n, lower_limit)

    expect_equal(subgrid_i, exp_i)

    # Inner loop: one predict stage per model-stage row.
    for (j in seq_along(sched_pre_model_post$model_stage[[i]]$predict_stage)) {
      model_ij <- model_i[j, ]
      expect_named(model_ij, c("trees", "min_n", "predict_stage"))

      predict_j <- model_ij$predict_stage[[1]]
      expect_named(predict_j, c("trees", "post_stage"))

      exp_post_grid <-
        # Condition on the current set of non-submodel or post param to see
        # what we should be evaluating:
        model_ij |>
        dplyr::select(-trees) |>
        vctrs::vec_cbind(other_i) |>
        dplyr::inner_join(
          grid_pre_model_post,
          by = c("threshold", "disp_df", "min_n")
        ) |>
        dplyr::select(trees, lower_limit) |>
        dplyr::arrange(trees, lower_limit)

      # Which are scheduled to be evaluated:
      subgrid_ij <-
        predict_j |>
        tidyr::unnest(post_stage) |>
        dplyr::arrange(trees, lower_limit)

      expect_equal(subgrid_ij, exp_post_grid)
    }
  }

  expect_s3_class(
    sched_pre_model_post,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})

test_that("parameter information with engine, recipe, and tailor parameters", {
  # Model spec with both main-argument and engine-argument tuning tags
  mlp_spec <-
    mlp(
      hidden_units = tune(),
      penalty = tune(),
      learn_rate = tune(),
      epochs = 500,
      activation = tune()
    ) |>
    set_engine(
      "brulee",
      stop_iter = tune(),
      class_weights = tune()
    ) |>
    set_mode("classification")

  # Minimal one-row data set; only the recipe's tuning tag matters here
  rec <- recipe(Class ~ ., data = tibble(Class = "a", x = 1)) |>
    step_pca(all_numeric_predictors(), num_comp = tune())

  tlr <- tailor() |> adjust_probability_threshold(threshold = tune())

  mlp_wflow <- workflow(rec, mlp_spec, tlr)

  # Call the internal function directly, as the other tests in this file do.
  # An unused tuning grid that used to be built here has been removed.
  mlp_info <- get_param_info(mlp_wflow)

  # Every parameter, including the engine-specific ones, must get a definite
  # TRUE/FALSE submodel flag rather than NA.
  expect_false(anyNA(mlp_info$has_submodel))
})

# Try the tune package in your browser
#
# Any scripts or data that you put into this service are public.
#
# tune documentation built on Sept. 1, 2025, 5:10 p.m.