# tests/testthat/test-fit_best.R

# Exercises fit_best() on grid-search results for a PCA + nearest-neighbor
# model: the default call, an alternate metric, verbose output, user-supplied
# parameters, and a series of invalid calls whose errors are snapshotted.
test_that("fit_best", {
  skip_if_not_installed("kknn")
  # The meats data below comes from modeldata, so guard on it as the other
  # tests in this file do.
  skip_if_not_installed("modeldata")

  library(recipes)
  library(rsample)
  library(dplyr)
  library(parsnip)

  # Load into this test's environment, matching the other tests in this file.
  data(meats, package = "modeldata", envir = rlang::current_env())
  meats <- meats %>% select(-water, -fat)

  set.seed(1)
  meat_split <- initial_split(meats)
  meat_train <- training(meat_split)
  meat_test <- testing(meat_split)

  set.seed(2)
  meat_rs <- vfold_cv(meat_train, v = 3)

  # Two tuning parameters in total: num_comp (recipe) and neighbors (model).
  pca_rec <-
    recipe(protein ~ ., data = meat_train) %>%
    step_pca(all_predictors(), num_comp = tune())

  knn_mod <- nearest_neighbor(neighbors = tune()) %>% set_mode("regression")

  # NOTE(review): save_workflow = TRUE is presumably what lets fit_best()
  # recover the workflow from the tuning results -- confirm against fit_best().
  ctrl <- control_grid(save_workflow = TRUE)

  set.seed(3)
  knn_pca_res <-
    tune_grid(knn_mod, pca_rec, resamples = meat_rs, grid = 3, control = ctrl)

  # Successful calls: quiet by default, and the returned object is trained.
  expect_silent(knn_fit <- fit_best(knn_pca_res))
  expect_true(knn_fit$trained)
  expect_silent(fit_best(knn_pca_res, metric = "rsq"))

  # Verbose output is captured in snapshots, with and without explicit
  # parameter values.
  expect_snapshot(fit_best(knn_pca_res, verbose = TRUE))
  expect_snapshot(
    tmp <-
      fit_best(knn_pca_res,
               verbose = TRUE,
               parameters = tibble(neighbors = 1, num_comp = 1))
  )

  # Inputs that are not tuning results.
  expect_snapshot_error(
    fit_best(1L)
  )
  expect_snapshot_error(
    fit_best(tibble())
  )
  # A metric name that was not part of the tuning results.
  expect_snapshot_error(
    fit_best(knn_pca_res, metric = "WAT")
  )
  # Parameter tibbles that are empty or omit num_comp.
  expect_snapshot_error(
    fit_best(knn_pca_res, parameters = tibble())
  )
  expect_snapshot_error(
    fit_best(knn_pca_res, parameters = tibble(neighbors = 2))
  )
  # An argument name that fit_best() is expected to reject.
  expect_snapshot_error(
    fit_best(knn_pca_res, chickens = 2)
  )
  # ames_iter_search is provided by the example_ames_knn data set; this call
  # is expected to error and the message is snapshotted.
  data(example_ames_knn, envir = rlang::current_env())
  expect_snapshot_error(
    fit_best(ames_iter_search)
  )
})

# Tunes a nearest-neighbor workflow against the validation set of a three-way
# split and checks that fit_best() predicts the test set exactly like a model
# fit directly on the training portion with neighbors = 5.
test_that("fit_best() works with validation split: 3-way split", {
  skip_if_not_installed("kknn")
  skip_if_not_installed("modeldata")
  data(ames, package = "modeldata", envir = rlang::current_env())

  set.seed(23598723)
  three_way_split <- rsample::initial_validation_split(ames)
  holdout <- testing(three_way_split)
  val_rset <- validation_set(three_way_split)

  f <- Sale_Price ~ Gr_Liv_Area + Year_Built
  tunable_knn <- set_mode(nearest_neighbor(neighbors = tune()), "regression")
  knn_wflow <- workflow(f, tunable_knn)

  # Two candidate values of neighbors; warnings raised while tuning are
  # silenced.
  tuned <- tune_grid(
    knn_wflow,
    resamples = val_rset,
    grid = tibble(neighbors = c(1, 5)),
    control = control_grid(save_workflow = TRUE)
  ) %>% suppressWarnings()

  # Fit via fit_best() and predict on the held-out test set.
  set.seed(3)
  best_fit <- fit_best(tuned)
  actual_pred <- predict(best_fit, holdout)

  # Reference: the same model with neighbors = 5, fit on the training set only.
  set.seed(3)
  fixed_knn <- set_mode(nearest_neighbor(neighbors = 5), "regression")
  reference_fit <- fit(fixed_knn, f, training(three_way_split))
  reference_pred <- predict(reference_fit, holdout)

  expect_equal(actual_pred, reference_pred)
})

# Builds a validation set from two successive two-way splits, tunes a
# nearest-neighbor workflow on it, and checks that fit_best() predicts the
# test set exactly like a model fit directly on the combined training and
# validation data with neighbors = 5.
test_that("fit_best() works with validation split: 2x 2-way splits", {
  skip_if_not_installed("kknn")
  skip_if_not_installed("modeldata")
  data(ames, package = "modeldata", envir = rlang::current_env())

  set.seed(23598723)
  outer_split <- rsample::initial_split(ames)
  holdout <- testing(outer_split)
  non_test <- training(outer_split)
  # Second two-way split, applied to the non-test rows.
  val_rset <- rsample::validation_split(non_test)

  f <- Sale_Price ~ Gr_Liv_Area + Year_Built
  tunable_knn <- set_mode(nearest_neighbor(neighbors = tune()), "regression")
  knn_wflow <- workflow(f, tunable_knn)

  # Two candidate values of neighbors.
  tuned <- tune_grid(
    knn_wflow,
    resamples = val_rset,
    grid = tibble(neighbors = c(1, 5)),
    control = control_grid(save_workflow = TRUE)
  )

  # Fit via fit_best() and predict on the held-out test set.
  set.seed(3)
  best_fit <- fit_best(tuned)
  actual_pred <- predict(best_fit, holdout)

  # Reference: the same model with neighbors = 5, fit on all non-test rows.
  set.seed(3)
  fixed_knn <- set_mode(nearest_neighbor(neighbors = 5), "regression")
  reference_fit <- fit(fixed_knn, f, non_test)
  reference_pred <- predict(reference_fit, holdout)

  expect_equal(actual_pred, reference_pred)
})

# Try the tune package in your browser

# Any scripts or data that you put into this service are public.

# tune documentation built on Aug. 24, 2023, 1:09 a.m.