tests/testthat/test-extract.R

test_that("tune recipe only", {
  skip_if_not_installed("kernlab")

  helper_objects <- helper_objects_tune()
  set.seed(363)
  mt_folds <- rsample::vfold_cv(mtcars, v = 5)

  # Extractor: pull the recipe out of each fitted workflow and tidy step 2.
  extr_1_1 <- function(x) {
    extract_recipe(x) |> tidy(number = 2)
  }

  # Tuning with an extractor must not change the session's RNG kind.
  before_kind <- RNGkind()[[1]]
  expect_no_error(
    res_1_1 <-
      workflow() |>
      add_recipe(helper_objects$rec_tune_1) |>
      add_model(helper_objects$lm_mod) |>
      tune_grid(
        resamples = mt_folds,
        control = control_grid(extract = extr_1_1)
      )
  )
  after_kind <- RNGkind()[[1]]
  expect_equal(before_kind, after_kind)
  expect_no_error(extract_1_1 <- dplyr::bind_rows(res_1_1$.extracts))

  # expect_named() checks names exactly; `all(names(x) == ...)` would pass
  # vacuously on NULL names and recycle silently on mismatched lengths.
  expect_named(extract_1_1, c("num_comp", ".extracts", ".config"))
  expect_true(
    all(purrr::map_lgl(extract_1_1$.extracts, tibble::is_tibble))
  )
})

# ------------------------------------------------------------------------------

test_that("tune model only", {
  skip_if_not_installed("kernlab")

  helper_objects <- helper_objects_tune()
  set.seed(363)
  mt_folds <- rsample::vfold_cv(mtcars, v = 5)

  # Extractor: read the `alphaindex` and `coef` slots from the kernlab
  # engine fit and return them as a two-column tibble.
  extr_2_1 <- function(x) {
    mod <- extract_fit_engine(x)
    tibble(index = mod@alphaindex[[1]], estimate = mod@coef[[1]])
  }

  expect_no_error(
    res_2_1 <-
      workflow() |>
      add_recipe(helper_objects$rec_no_tune_1) |>
      add_model(helper_objects$svm_mod) |>
      tune_grid(
        resamples = mt_folds,
        grid = 2,
        control = control_grid(extract = extr_2_1)
      )
  )
  expect_no_error(extract_2_1 <- dplyr::bind_rows(res_2_1$.extracts))

  # expect_named()/identical() compare names exactly, unlike
  # `all(names(x) == ...)`, which can pass vacuously or via recycling.
  expect_named(extract_2_1, c("cost", ".extracts", ".config"))
  expect_true(
    all(purrr::map_lgl(extract_2_1$.extracts, tibble::is_tibble))
  )
  expect_true(
    all(purrr::map_lgl(
      extract_2_1$.extracts,
      \(.x) identical(names(.x), c("index", "estimate"))
    ))
  )

  # Extractor: record whether extract_recipe() returned NULL.
  extr_2_2 <- function(x) {
    tibble(is_null_rec = is.null(extract_recipe(x)))
  }

  # Tuning with this extractor should succeed, and extract_recipe() should
  # never return NULL for a workflow that contains a recipe.
  expect_no_error(
    res_2_2 <-
      workflow() |>
      add_recipe(helper_objects$rec_tune_1) |>
      add_model(helper_objects$lm_mod) |>
      tune_grid(
        resamples = mt_folds,
        grid = 2,
        control = control_grid(extract = extr_2_2)
      )
  )

  expect_no_error(
    extract_2_2 <-
      dplyr::bind_rows(res_2_2$.extracts) |>
      tidyr::unnest(cols = c(.extracts))
  )
  expect_false(any(extract_2_2$is_null_rec))
})

test_that("mis-specified extract function", {
  # Resample a formula + linear_reg() workflow over three bootstraps with
  # extractors that misbehave in different ways. Code inside the
  # expect_snapshot() calls is recorded in the snapshot files, so it must
  # stay exactly as written.
  wf <-
    workflows::workflow(
      preprocessor = mpg ~ .,
      spec = parsnip::linear_reg()
    )

  set.seed(1)
  boots <- rsample::bootstraps(mtcars, 3)

  # Extractor that warns but still returns TRUE.
  raise_warning <- function(x) {
    warning("AHHH")
    TRUE
  }
  # Extractor that always errors; the trailing `TRUE` is never reached.
  raise_error <- function(x) {
    stop("AHHH")
    TRUE
  }
  # Extractor that warns and then errors; the trailing `TRUE` is never
  # reached.
  raise_both <- function(x) {
    warning("AH")
    stop("AHHH")
    TRUE
  }
  # Extractor that errors on its first call only and returns "hi" after
  # that. The `first` flag lives in a local() closure and is flipped with
  # `<<-`, so it persists across calls within the same R process.
  raise_error_once <- local({
    first <- TRUE

    function(x) {
      if (first) {
        first <<- FALSE
        stop("oh no")
      }

      "hi"
    }
  })

  # Snapshot the output emitted while resampling with each extractor.
  # `catalog_lines` is defined elsewhere; presumably it normalizes the
  # issue-catalog output for stable snapshots (TODO confirm).
  expect_snapshot(
    res_extract_warning <-
      fit_resamples(
        wf,
        boots,
        control = control_resamples(extract = raise_warning)
      ),
    transform = catalog_lines
  )

  expect_snapshot(
    res_extract_error <-
      fit_resamples(
        wf,
        boots,
        control = control_resamples(extract = raise_error)
      ),
    transform = catalog_lines
  )

  expect_snapshot(
    res_extract_both <-
      fit_resamples(
        wf,
        boots,
        control = control_resamples(extract = raise_both)
      ),
    transform = catalog_lines
  )

  expect_snapshot(
    res_extract_error_once <-
      fit_resamples(
        wf,
        boots,
        control = control_resamples(extract = raise_error_once)
      ),
    transform = catalog_lines
  )

  # Snapshot how each resampling result prints.
  expect_snapshot(res_extract_warning)
  expect_snapshot(res_extract_error)
  expect_snapshot(res_extract_both)
  expect_snapshot(res_extract_error_once)

  # A warning does not discard the extracted value (TRUE is kept); the
  # snapshot pins what is recorded in `.notes`.
  expect_true(res_extract_warning$.extracts[[1]]$.extracts[[1]])
  expect_snapshot(res_extract_warning$.notes[[1]])

  # When the extractor errors, pin what is stored in place of the value
  # and what is recorded in `.notes`.
  expect_snapshot(res_extract_error$.extracts[[1]]$.extracts[[1]])
  expect_snapshot(res_extract_error$.notes[[1]])

  expect_snapshot(res_extract_both$.extracts[[1]]$.extracts[[1]])

  # Only the first call fails; later resamples get the normal value "hi".
  expect_snapshot(res_extract_error_once$.extracts[[1]]$.extracts[[1]])
  expect_equal(res_extract_error_once$.extracts[[2]]$.extracts[[1]], "hi")
  expect_equal(res_extract_error_once$.extracts[[3]]$.extracts[[1]], "hi")
})

# ------------------------------------------------------------------------------

test_that("tune model and recipe", {
  skip_if_not_installed("kernlab")

  helper_objects <- helper_objects_tune()
  set.seed(363)
  mt_folds <- rsample::vfold_cv(mtcars, v = 5)

  # Identity extractor: keep the whole fitted workflow for each candidate.
  extr_3_1 <- function(x) {
    x
  }

  wflow_3 <-
    workflow() |>
    add_recipe(helper_objects$rec_tune_1) |>
    add_model(helper_objects$svm_mod)

  # Four space-filling candidates over the workflow's tuning parameters
  # (`cost` and `num_comp`, per the names checked below), with `num_comp`
  # restricted to the range 2-5.
  set.seed(35)
  grid_3 <-
    extract_parameter_set_dials(wflow_3) |>
    update(num_comp = dials::num_comp(c(2, 5))) |>
    dials::grid_space_filling(size = 4)

  expect_no_error(
    res_3_1 <- tune_grid(
      wflow_3,
      resamples = mt_folds,
      grid = grid_3,
      control = control_grid(extract = extr_3_1)
    )
  )
  expect_no_error(extract_3_1 <- dplyr::bind_rows(res_3_1$.extracts))

  expect_named(extract_3_1, c("cost", "num_comp", ".extracts", ".config"))
  expect_true(
    all(purrr::map_lgl(extract_3_1$.extracts, inherits, "workflow"))
  )
})

# ------------------------------------------------------------------------------

test_that("check .config in extracts", {
  # Loads the pre-computed tuning results used below: mt_spln_lm_grid,
  # mt_spln_lm_bo, mt_spln_knn_grid, mt_spln_knn_bo, mt_knn_grid, mt_knn_bo.
  load(test_path("data", "test_objects.RData"))

  # Assert that every element of `res$.extracts` has a `.config` column.
  # seq_len() keeps a zero-row result from iterating over c(1, 0).
  expect_config_in_extracts <- function(res) {
    for (i in seq_len(nrow(res))) {
      expect_true(".config" %in% names(res$.extracts[[i]]))
    }
  }

  # recipe only
  expect_config_in_extracts(mt_spln_lm_grid)
  expect_config_in_extracts(mt_spln_lm_bo)

  # For the first resample of the `_bo` results, join each iteration's
  # metrics (first metric only) to its extracts by `.iter` and `deg_free`.
  # Both sides carry `.config`, so the suffixed copies must agree.
  recipe_only_configs <-
    full_join(
      mt_spln_lm_bo |>
        filter(id == first(id)) |>
        select(.iter, .metrics) |>
        unnest(cols = .metrics) |>
        filter(.metric == first(.metric)),
      mt_spln_lm_bo |>
        filter(id == first(id)) |>
        select(.iter, .extracts) |>
        unnest(cols = .extracts),
      by = c(".iter", "deg_free")
    )

  expect_equal(
    recipe_only_configs$.config.x,
    recipe_only_configs$.config.y
  )

  # recipe and model
  expect_config_in_extracts(mt_spln_knn_grid)
  expect_config_in_extracts(mt_spln_knn_bo)

  # model only
  expect_config_in_extracts(mt_knn_grid)
  expect_config_in_extracts(mt_knn_bo)
})

Try the tune package in your browser

Any scripts or data that you put into this service are public.

tune documentation built on Sept. 1, 2025, 5:10 p.m.