# tests/testthat/test-finalize.R

# When a recipe argument is marked for tuning and tune_grid() has to build the
# grid itself (`grid = 3`), the call is expected to error; the error is
# captured as a snapshot. The same model tuned with a recipe whose `deg_free`
# is fixed must run without error.
test_that("cannot finalize with recipe parameters", {
  skip_if_not_installed("randomForest")
  skip_if_not_installed("splines2")

  set.seed(21983)
  rs <- rsample::vfold_cv(mtcars)

  # Random forest with two tuned parameters. NOTE(review): presumably `mtry`
  # is the parameter whose range must be finalized from the data, which is
  # what conflicts with tuned recipe parameters -- confirm against tune's
  # finalization logic.
  mod_1 <-
    parsnip::rand_forest(mtry = tune(), trees = 20, min_n = tune()) |>
    parsnip::set_engine("randomForest") |>
    parsnip::set_mode("regression")

  # Recipe with a tuned spline argument: used in the call expected to error.
  rec_1 <-
    recipes::recipe(mpg ~ ., data = mtcars) |>
    recipes::step_spline_natural(disp, deg_free = tune())

  # Same recipe with `deg_free` fixed, so only model parameters are tuned.
  rec_2 <-
    recipes::recipe(mpg ~ ., data = mtcars) |>
    recipes::step_spline_natural(disp, deg_free = 3)

  # NOTE(review): testthat snapshots typically record this code alongside the
  # error output, so editing the expression likely requires re-accepting the
  # stored snapshot.
  expect_snapshot(error = TRUE, {
    mod_1 |> tune_grid(rec_1, resamples = rs, grid = 3)
  })

  # Reseed right before tuning so this run does not depend on how much RNG
  # state the code above consumed.
  set.seed(987323)
  expect_no_error(
    suppressMessages(mod_1 |> tune_grid(rec_2, resamples = rs, grid = 3))
  )
})


# Counterpart to the "cannot finalize" test: when the caller hands tune_grid()
# an explicit grid covering every tuned parameter (model and recipe alike),
# nothing needs to be finalized and tuning must succeed.
test_that("skip error if grid is supplied", {
  skip_if_not_installed("randomForest")
  skip_if_not_installed("splines2")

  set.seed(21983)
  folds <- rsample::vfold_cv(mtcars)

  # Build the model spec step by step instead of piping.
  rf_spec <- parsnip::rand_forest(mtry = tune(), trees = 20, min_n = tune())
  rf_spec <- parsnip::set_engine(rf_spec, "randomForest")
  rf_spec <- parsnip::set_mode(rf_spec, "regression")

  # Spline basis on `disp` with a tuned number of degrees of freedom.
  spline_rec <- recipes::step_spline_natural(
    recipes::recipe(mpg ~ ., data = mtcars),
    disp,
    deg_free = tune()
  )

  # One row per candidate; every tuned parameter has a column.
  param_grid <- tibble::tibble(
    mtry = 1:3,
    deg_free = c(3, 3, 4),
    min_n = c(5, 4, 6)
  )

  set.seed(987323)
  expect_no_error(
    tune_grid(rf_spec, spline_rec, resamples = folds, grid = param_grid)
  )
})


# finalize_recipe() must fill in every tuned argument of a step when a single
# step has more than one argument marked with tune().
test_that("finalize recipe step with multiple tune parameters", {
  skip_if_not_installed("modeldata")
  skip_if_not_installed("splines2")

  # Read the data set from the package namespace. `data()` would assign
  # `biomass` into the global environment (its default `envir`), leaking state
  # into every test that runs afterwards.
  biomass <- modeldata::biomass

  rec <- recipes::recipe(
    HHV ~ carbon + hydrogen + oxygen + nitrogen + sulfur,
    data = biomass
  ) |>
    recipes::step_spline_b(carbon, hydrogen, deg_free = tune(), degree = tune())

  best <- tibble(deg_free = 2, degree = 1, .config = "Preprocessor1_Model1")

  # Finalize once and inspect the single result.
  finalized <- finalize_recipe(rec, best)

  expect_s3_class(finalized, "recipe")
  expect_equal(finalized$steps[[1]]$degree, 1)
  expect_equal(finalized$steps[[1]]$deg_free, 2)
})

# ------------------------------------------------------------------------------
# post-processing

# finalize_tailor() checks:
# - it substitutes the tuned arguments named in the parameter tibble;
# - it leaves unmatched tune() markers in place;
# - it ignores unrelated columns and returns the tailor unchanged for an
#   empty tibble;
# - it errors on objects that are not tailors.
test_that("finalize tailors", {
  skip_if_not_installed("probably")
  # The test constructs tailors, so it must skip (not error) without tailor.
  skip_if_not_installed("tailor")
  skip_if_not_installed("dials", "1.3.0.9000")

  # Namespaced calls instead of library(tailor): attaching a package inside a
  # test leaves it on the search path for every test that runs afterwards.
  adjust_rng <-
    tailor::tailor() |>
    tailor::adjust_numeric_range(lower_limit = tune(), upper_limit = tune())

  # Value of argument `arg` in the first adjustment of tailor `x`. `[[` uses
  # exact name matching, unlike `$` on plain lists.
  adj_arg <- function(x, arg) {
    x$adjustments[[1]]$arguments[[arg]]
  }

  # Only one of the two tuned arguments is supplied.
  adj_1 <- finalize_tailor(adjust_rng, tibble(lower_limit = 2))
  expect_equal(adj_arg(adj_1, "lower_limit"), 2)
  expect_equal(adj_arg(adj_1, "upper_limit"), tune())

  adj_2 <- finalize_tailor(adjust_rng, tibble(lower_limit = 2, upper_limit = 3))
  expect_equal(adj_arg(adj_2, "lower_limit"), 2)
  expect_equal(adj_arg(adj_2, "upper_limit"), 3)

  # A column (`a`) with no matching tuned argument.
  adj_3 <- finalize_tailor(
    adjust_rng,
    tibble(lower_limit = 2, upper_limit = 3, a = 2)
  )
  expect_equal(adj_arg(adj_3, "lower_limit"), 2)
  expect_equal(adj_arg(adj_3, "upper_limit"), 3)

  adj_4 <- finalize_tailor(adjust_rng, tibble())
  expect_equal(adj_4, adjust_rng)

  expect_snapshot(
    finalize_tailor(linear_reg(), tibble()),
    error = TRUE
  )
})

# finalize_workflow() must finalize tuned arguments of a workflow's tailor
# postprocessor: partial and full parameter sets, extra columns, an empty
# tibble, and a tuned argument with a custom id (issue #998).
test_that("finalize workflows with tailors", {
  skip_if_not_installed("probably")
  # The test constructs tailors, so it must skip (not error) without tailor.
  skip_if_not_installed("tailor")
  skip_if_not_installed("dials", "1.3.0.9000")

  # Value of argument `arg` in the first adjustment of the workflow's tailor.
  # Replaces a repeated purrr::pluck() chain, so neither purrr nor tailor has
  # to be attached with library() (which would leak onto the search path).
  post_arg <- function(wf, arg) {
    extract_postprocessor(wf)$adjustments[[1]]$arguments[[arg]]
  }

  adjust_rng <-
    tailor::tailor() |>
    tailor::adjust_numeric_range(lower_limit = tune(), upper_limit = tune())
  wflow <- workflow(y ~ ., linear_reg(), adjust_rng)

  # Only one of the two tuned arguments is supplied.
  wflow_1 <- finalize_workflow(wflow, tibble(lower_limit = 2))
  expect_equal(post_arg(wflow_1, "lower_limit"), 2)
  expect_equal(post_arg(wflow_1, "upper_limit"), tune())

  wflow_2 <- finalize_workflow(wflow, tibble(lower_limit = 2, upper_limit = 3))
  expect_equal(post_arg(wflow_2, "lower_limit"), 2)
  expect_equal(post_arg(wflow_2, "upper_limit"), 3)

  # A column (`a`) with no matching tuned argument.
  wflow_3 <- finalize_workflow(
    wflow,
    tibble(lower_limit = 2, upper_limit = 3, a = 2)
  )
  expect_equal(post_arg(wflow_3, "lower_limit"), 2)
  expect_equal(post_arg(wflow_3, "upper_limit"), 3)

  wflow_4 <- finalize_workflow(wflow, tibble())
  expect_equal(wflow_4, wflow)

  # https://github.com/tidymodels/tune/issues/998
  # The threshold is tuned under the id "cut"; `min_n` has no matching tuned
  # argument in this workflow.
  cls_post <- tailor::tailor() |>
    tailor::adjust_probability_threshold(threshold = tune("cut"))
  wflow_thrsh <- workflow(y ~ ., logistic_reg(), cls_post)
  thrsh_param <- tibble(min_n = 2, cut = 1 / 3)
  wflow_thrsh <- wflow_thrsh |> finalize_workflow(thrsh_param)

  expect_equal(post_arg(wflow_thrsh, "threshold"), 1 / 3)
})

# Try the tune package in your browser
#
# Any scripts or data that you put into this service are public.
#
# tune documentation built on Sept. 1, 2025, 5:10 p.m.