tests/testthat/test-select_best.R

# ------------------------------------------------------------------------------
# library(tidymodels)
# set.seed(7898)
# data_folds <- vfold_cv(mtcars, repeats = 2)
#
# base_rec <-
#   recipe(mpg ~ ., data = mtcars) %>%
#   step_normalize(all_predictors())
#
# disp_rec <-
#   base_rec %>%
#   step_bs(disp, degree = tune(), deg_free = tune()) %>%
#   step_bs(wt, degree = tune("wt degree"), deg_free = tune("wt df"))
#
# lm_model <-
#   linear_reg(mode = "regression") %>%
#   set_engine("lm")
#
# cars_wflow <-
#   workflow() %>%
#   add_recipe(disp_rec) %>%
#   add_model(lm_model)
#
# cars_set <-
#   cars_wflow %>%
#   parameters %>%
#   update(degree = degree_int(1:2)) %>%
#   update(deg_free = deg_free(c(2, 10))) %>%
#   update(`wt degree` = degree_int(1:2)) %>%
#   update(`wt df` = deg_free(c(2, 10)))
#
# set.seed(255)
# cars_grid <-
#   cars_set %>%
#   grid_regular(levels = c(3, 2, 3, 2))
#
#
# cars_res <- tune_grid(cars_wflow,
#                       resamples = data_folds,
#                       grid = cars_grid,
#                       control = control_grid(verbose = TRUE, save_pred = TRUE))
# saveRDS(cars_res,
#         file = testthat::test_path("data", "rcv_results.rds"),
#         version = 2, compress = "xz")

# ------------------------------------------------------------------------------

test_that("select_best()", {
  # Printing options used while the snapshots below are recorded. They are set
  # with local_options() so they are undone when this test exits instead of
  # leaking into whichever test runs next.
  withr::local_options(
    width = 200,
    pillar.advice = FALSE,
    pillar.min_title_chars = Inf
  )

  rcv_results <- readRDS(test_path("data", "rcv_results.rds"))

  expect_s3_class(select_best(rcv_results, metric = "rmse"), "tbl_df")

  # Expected winning parameter combinations. Named `expected_*` so they are
  # not shadowed by the `best_rmse` that is assigned inside a snapshot below.
  expected_rmse <-
    tibble::tribble(
      ~deg_free, ~degree, ~`wt df`, ~`wt degree`,
      6L,        2L,      2L,       1L
    )
  expected_rsq <-
    tibble::tribble(
      ~deg_free, ~degree, ~`wt df`, ~`wt degree`,
      10L,       2L,      2L,       2L
    )

  # `.config` is dropped because the expected tibbles only list parameters.
  expect_equal(
    select_best(rcv_results, metric = "rmse") %>% select(-.config),
    expected_rmse
  )
  expect_equal(
    select_best(rcv_results, metric = "rsq") %>% select(-.config),
    expected_rsq
  )
  expect_snapshot(
    select_best(rcv_results, metric = "rsq", maximize = TRUE)
  )

  # Invalid `metric` values are expected to error.
  expect_snapshot(error = TRUE, {
    select_best(rcv_results, metric = "random")
  })
  expect_snapshot(error = TRUE, {
    select_best(rcv_results, metric = c("rmse", "rsq"))
  })

  # Omitting `metric` must give the same answer as asking for "rmse".
  expect_snapshot({
    best_default_metric <- select_best(rcv_results)
    best_rmse <- select_best(rcv_results, metric = "rmse")
  })
  expect_equal(best_default_metric, best_rmse)

  # A plain data frame is not a supported input.
  expect_snapshot(error = TRUE, {
    select_best(mtcars, metric = "disp")
  })
})


test_that("show_best()", {
  rcv_results <- readRDS(test_path("data", "rcv_results.rds"))

  # Reference ranking: every rmse row of the collected metrics, smallest mean
  # first.
  all_metrics <- collect_metrics(rcv_results)
  rmse_ranked <- dplyr::arrange(
    dplyr::filter(all_metrics, .metric == "rmse"),
    mean
  )

  # n = 1 returns the top-ranked row.
  top_one <- show_best(rcv_results, metric = "rmse", n = 1)
  expect_equal(top_one, slice(rmse_ranked, 1))

  # Asking for more rows than exist returns the full ranking.
  too_many <- nrow(rmse_ranked) + 1
  expect_equal(
    show_best(rcv_results, metric = "rmse", n = too_many),
    rmse_ranked
  )

  # Column names match those of the collected metrics.
  expect_equal(
    names(show_best(rcv_results, metric = "rmse", n = 1)),
    names(rmse_ranked)
  )

  # Omitting `metric` must give the same answer as asking for "rmse".
  expect_snapshot({
    best_default_metric <- show_best(rcv_results)
    best_rmse <- show_best(rcv_results, metric = "rmse")
  })
  expect_equal(best_default_metric, best_rmse)

  # A plain data frame is not a supported input.
  expect_snapshot(error = TRUE, {
    show_best(mtcars, metric = "disp")
  })
})

test_that("one-std error rule", {
  # Printing options used while the snapshots below are recorded. They are set
  # with local_options() so they are undone when this test exits instead of
  # leaking into whichever test runs next.
  withr::local_options(
    width = 200,
    pillar.advice = FALSE,
    pillar.min_title_chars = Inf
  )

  rcv_results <- readRDS(test_path("data", "rcv_results.rds"))
  knn_results <- readRDS(test_path("data", "knn_results.rds"))

  expect_s3_class(
    select_by_one_std_err(knn_results, metric = "accuracy", K),
    "tbl_df"
  )

  # Pinned reference values for the two saved result sets.
  expect_equal(
    select_by_one_std_err(
      rcv_results,
      metric = "rmse",
      deg_free,
      `wt degree`
    )$mean,
    2.94252798698909
  )
  expect_equal(
    select_by_one_std_err(knn_results, metric = "accuracy", K)$K,
    25L
  )

  expect_snapshot(
    select_by_one_std_err(knn_results, metric = "accuracy", K, maximize = TRUE)
  )

  # Invalid `metric` values are expected to error.
  expect_snapshot(error = TRUE, {
    select_by_one_std_err(rcv_results, metric = "random", deg_free)
  })
  expect_snapshot(error = TRUE, {
    select_by_one_std_err(rcv_results, metric = c("rmse", "rsq"), deg_free)
  })

  # Omitting `metric` must give the same answer as asking for "roc_auc".
  expect_snapshot({
    select_via_default_metric <- select_by_one_std_err(knn_results, K)
    select_via_roc <- select_by_one_std_err(knn_results, K, metric = "roc_auc")
  })
  expect_equal(select_via_default_metric, select_via_roc)

  # No sorting expression supplied.
  expect_snapshot(error = TRUE, {
    select_by_one_std_err(rcv_results, metric = "random")
  })

  # A plain data frame is not a supported input.
  expect_snapshot(error = TRUE, {
    select_by_one_std_err(mtcars, metric = "disp")
  })

  # Sorting expressions that use the misspelled names `weight_funk` / `Kay`,
  # alone or mixed with the valid `K`, are expected to error.
  expect_snapshot(error = TRUE, {
    select_by_one_std_err(knn_results, metric = "roc_auc", weight_funk)
  })
  expect_snapshot(error = TRUE, {
    select_by_one_std_err(knn_results, metric = "roc_auc", weight_funk, K)
  })
  expect_snapshot(error = TRUE, {
    select_by_one_std_err(knn_results, metric = "roc_auc", weight_funk, Kay)
  })
  expect_snapshot(error = TRUE, {
    select_by_one_std_err(knn_results, metric = "roc_auc", weight_funk, desc(K))
  })
})


test_that("percent loss", {
  # Printing options used while the snapshots below are recorded. They are set
  # with local_options() so they are undone when this test exits instead of
  # leaking into whichever test runs next.
  withr::local_options(
    width = 200,
    pillar.advice = FALSE,
    pillar.min_title_chars = Inf
  )

  rcv_results <- readRDS(test_path("data", "rcv_results.rds"))
  knn_results <- readRDS(test_path("data", "knn_results.rds"))

  expect_s3_class(
    select_by_pct_loss(knn_results, metric = "accuracy", K),
    "tbl_df"
  )

  # Pinned reference values for the two saved result sets.
  expect_equal(
    select_by_pct_loss(
      rcv_results,
      metric = "rmse",
      deg_free,
      `wt degree`
    )$mean,
    2.94252798698909
  )
  expect_equal(
    select_by_pct_loss(knn_results, metric = "accuracy", K)$K,
    12L
  )

  expect_snapshot(
    select_by_pct_loss(knn_results, metric = "accuracy", K, maximize = TRUE)
  )

  # Invalid `metric` values are expected to error.
  expect_snapshot(error = TRUE, {
    select_by_pct_loss(rcv_results, metric = "random", deg_free)
  })
  expect_snapshot(error = TRUE, {
    select_by_pct_loss(rcv_results, metric = c("rmse", "rsq"), deg_free)
  })

  # Omitting `metric` must give the same answer as asking for "roc_auc".
  expect_snapshot({
    select_via_default_metric <- select_by_pct_loss(knn_results, K)
    select_via_roc <- select_by_pct_loss(knn_results, K, metric = "roc_auc")
  })
  expect_equal(select_via_default_metric, select_via_roc)

  # No sorting expression supplied.
  expect_snapshot(error = TRUE, {
    select_by_pct_loss(rcv_results, metric = "random")
  })

  # A plain data frame is not a supported input.
  expect_snapshot(error = TRUE, {
    select_by_pct_loss(mtcars, metric = "disp")
  })

  # Sorting expressions that use the misspelled names `weight_funk` / `Kay`,
  # alone or mixed with the valid `K`, are expected to error.
  expect_snapshot(error = TRUE, {
    select_by_pct_loss(knn_results, metric = "roc_auc", weight_funk)
  })
  expect_snapshot(error = TRUE, {
    select_by_pct_loss(knn_results, metric = "roc_auc", weight_funk, K)
  })
  expect_snapshot(error = TRUE, {
    select_by_pct_loss(knn_results, metric = "roc_auc", weight_funk, Kay)
  })
  expect_snapshot(error = TRUE, {
    select_by_pct_loss(knn_results, metric = "roc_auc", weight_funk, desc(K))
  })

  # NOTE(review): `ames_grid_search` is presumably supplied by the
  # "example_ames_knn" dataset loaded here -- confirm.
  data("example_ames_knn")
  expect_equal(
    select_by_pct_loss(
      ames_grid_search,
      metric = "rmse",
      limit = 10,
      desc(K)
    )$K,
    40
  )
})

test_that("select_by_* can handle metrics with direction == 'zero'", {
  skip_on_cran()

  set.seed(1)
  resamples <- bootstraps(mtcars, times = 5)

  set.seed(1)
  tune_res <-
    tune::tune_grid(
      nearest_neighbor(mode = "regression", neighbors = tune()),
      mpg ~ .,
      resamples,
      metrics = yardstick::metric_set(yardstick::mpe, yardstick::msd)
    )

  tune_res_metrics <- tune_res %>% collect_metrics()

  # Rows of `tune_res_metrics` for one metric, in their collected order.
  metric_rows <- function(metric_name) {
    filter(tune_res_metrics, .metric == metric_name)
  }

  # The `n` rows of one metric whose mean is closest to zero, closest first.
  # Sorting on abs(mean) is what "best" means for these reference values.
  closest_to_zero <- function(metric_name, n = 1) {
    metric_rows(metric_name) %>%
      arrange(abs(mean)) %>%
      slice(seq_len(n))
  }

  # select_best() / show_best() ----------
  expect_equal(
    select_best(tune_res, metric = "msd")$.config,
    closest_to_zero("msd")$.config
  )

  expect_equal(
    select_best(tune_res, metric = "mpe")$.config,
    closest_to_zero("mpe")$.config
  )

  expect_equal(
    show_best(tune_res, metric = "msd", n = 5)$.config,
    closest_to_zero("msd", n = 5)$.config
  )

  expect_equal(
    show_best(tune_res, metric = "mpe", n = 5)$.config,
    closest_to_zero("mpe", n = 5)$.config
  )

  # one std error, msd ----------
  # The reference row has to be the one closest to zero. Arranging by
  # min(abs(mean)) sorts on a single constant, leaves the rows in their
  # incoming order, and so would pick whichever row happened to be first.
  best <- closest_to_zero("msd")

  bound_lower <- -abs(best$mean) - abs(best$std_err)
  bound_upper <- abs(best$mean) + abs(best$std_err)
  expect_equal(bound_lower, -bound_upper)

  # Among candidates inside the bound, the largest `neighbors` sorts first.
  simplest_within_bound <-
    metric_rows("msd") %>%
    filter(abs(mean) < bound_upper) %>%
    arrange(desc(neighbors)) %>%
    slice(1)

  expect_equal(
    select_by_one_std_err(tune_res, metric = "msd", desc(neighbors))$.config,
    simplest_within_bound$.config
  )

  # one std error, mpe ----------
  best <- closest_to_zero("mpe")

  bound_lower <- -abs(best$mean) - abs(best$std_err)
  bound_upper <- abs(best$mean) + abs(best$std_err)
  expect_equal(bound_lower, -bound_upper)

  simplest_within_bound <-
    metric_rows("mpe") %>%
    filter(abs(mean) < bound_upper) %>%
    arrange(desc(neighbors)) %>%
    slice(1)

  expect_equal(
    select_by_one_std_err(tune_res, metric = "mpe", desc(neighbors))$.config,
    simplest_within_bound$.config
  )

  # pct loss, msd ----------
  best <- closest_to_zero("msd")

  # The loss formula is vectorised over `mean`, so no row-wise grouping is
  # needed. Rows sorted after the best one are discarded before the limit is
  # applied.
  expect_equal(
    select_by_pct_loss(
      tune_res,
      metric = "msd",
      limit = 10,
      desc(neighbors)
    )$.config,
    metric_rows("msd") %>%
      mutate(loss = abs((abs(mean) - abs(best$mean)) / best$mean) * 100) %>%
      arrange(desc(neighbors)) %>%
      slice(seq_len(which(.config == best$.config))) %>%
      filter(loss < 10) %>%
      slice(1) %>%
      pull(.config)
  )

  # pct loss, mpe ----------
  best <- closest_to_zero("mpe")

  expect_equal(
    select_by_pct_loss(
      tune_res,
      metric = "mpe",
      limit = 10,
      desc(neighbors)
    )$.config,
    metric_rows("mpe") %>%
      mutate(loss = abs((abs(mean) - abs(best$mean)) / best$mean) * 100) %>%
      arrange(desc(neighbors)) %>%
      slice(seq_len(which(.config == best$.config))) %>%
      filter(loss < 10) %>%
      slice(1) %>%
      pull(.config)
  )
})

Try the tune package in your browser

Any scripts or data that you put into this service are public.

tune documentation built on Aug. 24, 2023, 1:09 a.m.