# tests/testthat/test-sa-overall.R

test_that("formula interface", {
  skip_on_cran()
  expect_snapshot({
    set.seed(1)
    res <- f_wflow %>%
      tune_sim_anneal(cell_folds,
                      iter = 2,
                      control = control_sim_anneal(verbose = TRUE)
      )
  })
  expect_equal(class(res), c("iteration_results", "tune_results", "tbl_df", "tbl", "data.frame"))
  expect_true(nrow(collect_metrics(res)) == 9)
  expect_equal(res, .Last.tune.result)
  expect_null(.get_tune_eval_times(res))
  expect_null(.get_tune_eval_time_target(res))
})

# ------------------------------------------------------------------------------

test_that("recipe interface", {
  skip_on_cran()
  skip_on_os("windows")
  skip_on_os("linux")

  expect_silent({
    set.seed(1)
    res <- rec_wflow %>%
      tune_sim_anneal(cell_folds,
                      iter = 2,
                      control = control_sim_anneal(verbose = FALSE, verbose_iter = FALSE)
      )
  })

  expect_equal(class(res), c("iteration_results", "tune_results", "tbl_df", "tbl", "data.frame"))
  expect_true(nrow(collect_metrics(res)) == 9)
  expect_equal(res, .Last.tune.result)
})

# ------------------------------------------------------------------------------

test_that("variable interface", {
  skip_on_cran()
  expect_snapshot({
    set.seed(1)
    res <- var_wflow %>%
      tune_sim_anneal(cell_folds,
                      iter = 2,
                      control = control_sim_anneal(verbose = TRUE, verbose_iter = TRUE)
      )
  })
  expect_equal(class(res), c("iteration_results", "tune_results", "tbl_df", "tbl", "data.frame"))
  expect_true(nrow(collect_metrics(res)) == 9)
  expect_equal(res, .Last.tune.result)

  # Check to see if iterations are picked up when an iterative object is used
  # as the initial object

  expect_snapshot({
    set.seed(1)
    new_res <- var_wflow %>%
      tune_sim_anneal(cell_folds,
                      iter = 2, initial = res,
                      control = control_sim_anneal(verbose = FALSE)
      )
  })
  expect_true(nrow(collect_metrics(new_res)) == 15)
  expect_true(max(new_res$.iter) == 4)
  expect_true(sum(grepl("^initial", collect_metrics(new_res)$.config)) == 9)
  expect_equal(new_res, .Last.tune.result)

  # but not for non-iterative objects
  set.seed(1)
  grid_res <- var_wflow %>%
    tune_grid(cell_folds, grid = 2)

  expect_snapshot({
    set.seed(1)
    new_new_res <- var_wflow %>%
      tune_sim_anneal(cell_folds,
                      iter = 2, initial = grid_res,
                      control = control_sim_anneal(verbose = FALSE)
      )
  })
  expect_true(nrow(collect_metrics(new_new_res)) == 12)
  expect_true(max(new_new_res$.iter) == 2)
  expect_true(sum(grepl("^initial", collect_metrics(new_new_res)$.config)) == 6)
  expect_equal(new_new_res, .Last.tune.result)
})


test_that("unfinalized parameters", {
  skip_on_cran()
  skip_on_os("windows")
  skip_on_os("linux")

  data(two_class_dat, package = "modeldata")

  set.seed(5046)
  bt <- bootstraps(two_class_dat, times = 5)

  rec_example <- recipe(Class ~ ., data = two_class_dat)

  # RF
  model_rf <- rand_forest(mtry = tune()) %>%
    set_mode("classification") %>%
    set_engine("ranger")

  wf_rf <- workflow() %>%
    add_model(model_rf) %>%
    add_recipe(rec_example)

  set.seed(30)
  rf_res <- wf_rf %>%
    tune_grid(resamples = bt, grid = 4)

  expect_snapshot({
    set.seed(40)
    rf_res_finetune <- wf_rf %>%
      tune_sim_anneal(resamples = bt, initial = rf_res)
  })

  # don't supply an initial grid (#39)
  expect_snapshot({
    set.seed(40)
    rf_res_finetune <- wf_rf %>%
      tune_sim_anneal(resamples = bt)
  })
})

test_that("incompatible parameter objects", {
  skip_on_cran()

  skip_if_not_installed("ranger")
  skip_if_not_installed("modeldata")
  skip_if_not_installed("rsample")

  rf_spec <- parsnip::rand_forest(mode = "regression", mtry = tune::tune())

  set.seed(1)
  grid_with_bigger_range <-
    dials::grid_latin_hypercube(dials::mtry(range = c(1, 16)))

  set.seed(1)
  car_folds <- rsample::vfold_cv(car_prices, v = 2)

  car_wflow <- workflows::workflow() %>%
    workflows::add_formula(Price ~ .) %>%
    workflows::add_model(rf_spec)

  set.seed(1)
  tune_res_with_bigger_range <- tune::tune_grid(
    car_wflow,
    resamples = car_folds,
    grid = grid_with_bigger_range
  )

  set.seed(1)
  parameter_set_with_smaller_range <-
    dials::parameters(dials::mtry(range = c(1, 5)))

  scrub_best <- function(lines) {
    has_best <- grepl("Initial best", lines)
    lines[has_best] <- ""
    lines
  }

  set.seed(1)
  expect_snapshot(error = TRUE, transform = scrub_best, {
    res <-
      tune_sim_anneal(
        car_wflow,
        param_info = parameter_set_with_smaller_range,
        resamples = car_folds,
        initial = tune_res_with_bigger_range,
        iter = 2
      )
  })
})

test_that("set event-level", {
  # See issue 40
  skip_if_not_installed("rpart")
  skip_if_not_installed("modeldata")
  skip_if_not_installed("yardstick")
  skip_if_not_installed("rsample")
  skip_on_cran()

  # ------------------------------------------------------------------------------

  set.seed(1)
  dat <- modeldata::sim_classification(500, intercept = 8)

  # We should get high sensitivity and low specificity when event_level = "first"
  # count(dat, class)
  # levels(dat$class)

  set.seed(2)
  rs <- vfold_cv(dat, strata = class)

  cart_spec <- decision_tree(min_n = tune()) %>% set_mode("classification")

  stats <- metric_set(accuracy, sensitivity, specificity)

  # ------------------------------------------------------------------------------
  # high sensitivity and low specificity


  set.seed(3)
  cart_res_first <-
    cart_spec %>%
    tune_sim_anneal(class ~ .,
                    rs,
                    control = control_sim_anneal(event_level = "first", verbose_iter = FALSE),
                    metrics = stats)

  results_first <-
    cart_res_first %>%
    collect_metrics() %>%
    dplyr::filter(.metric != "accuracy") %>%
    dplyr::select(.config, .metric, mean) %>%
    tidyr::pivot_wider(id_cols = .config, names_from = .metric, values_from = mean)

  dir_check <- all(results_first$sensitivity > results_first$specificity)
  expect_true(dir_check)

  # ------------------------------------------------------------------------------
  # Now reversed

  set.seed(3)
  cart_res_second <-
    cart_spec %>%
    tune_sim_anneal(class ~ .,
                    rs,
                    control = control_sim_anneal(event_level = "second", verbose_iter = FALSE),
                    metrics = stats)

  results_second <-
    cart_res_second %>%
    collect_metrics() %>%
    dplyr::filter(.metric != "accuracy") %>%
    dplyr::select(.config, .metric, mean) %>%
    tidyr::pivot_wider(id_cols = .config, names_from = .metric, values_from = mean)

  rev_dir_check <- all(results_second$sensitivity < results_second$specificity)
  expect_true(rev_dir_check)

})
# tidymodels/finetune documentation built on March 23, 2024, 6:50 p.m.