# tests/testthat/test-select_best.R

# ------------------------------------------------------------------------------
# objects for these tests are created using inst/test_objects.R

# ------------------------------------------------------------------------------

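# select_best() returns the single best parameter combination for one metric,
# as a one-row tibble of parameter values plus its .config identifier.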
test_that("select_best()", {
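  # widen printing and quiet pillar so snapshot output is stable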
  options(width = 200, pillar.advice = FALSE, pillar.min_title_chars = Inf)

  rcv_results <- readRDS(test_path("data", "rcv_results.rds"))

  expect_true(
    tibble::is_tibble(select_best(rcv_results, metric = "rmse"))
  )
  best_rmse <-
    tibble::tribble(
      ~deg_free, ~degree, ~`wt df`, ~`wt degree`,
      6L,        2L,      2L,       1L
    )
  best_rsq <-
    tibble::tribble(
      ~deg_free, ~degree, ~`wt df`, ~`wt degree`,
      10L,       2L,      2L,       2L
    )

  expect_equal(
    select_best(rcv_results, metric = "rmse") %>% select(-.config),
    best_rmse
  )
  expect_equal(
    select_best(rcv_results, metric = "rsq") %>% select(-.config),
    best_rsq
  )

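  # an unknown metric is an error; supplying several metrics warns and falls
  # back to the first one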
  expect_snapshot(error = TRUE, {
    select_best(rcv_results, metric = "random")
  })
  expect_snapshot(
    select_best(rcv_results, metric = c("rmse", "rsq"))
  )
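  # when no metric is supplied, the first metric in the result set is used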
  expect_snapshot({
    best_default_metric <- select_best(rcv_results)
    best_rmse <- select_best(rcv_results, metric = "rmse")
  })
  expect_equal(best_default_metric, best_rmse)

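  # objects that are not tuning results are rejected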
  expect_snapshot(error = TRUE, {
    select_best(mtcars, metric = "disp")
  })
})


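# show_best() returns the top-n candidates for a metric, ranked from best to
# worst, in the same format as collect_metrics().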
test_that("show_best()", {
  rcv_results <- readRDS(test_path("data", "rcv_results.rds"))

  rcv_rmse <-
    rcv_results %>%
    collect_metrics() %>%
    dplyr::filter(.metric == "rmse") %>%
    dplyr::arrange(mean)

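  # rows come back in ranked order, with n capped at the number of candidates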
  expect_equal(
    show_best(rcv_results, metric = "rmse", n = 1),
    rcv_rmse %>% slice(1)
  )
  expect_equal(
    show_best(rcv_results, metric = "rmse", n = nrow(rcv_rmse) + 1),
    rcv_rmse
  )
  expect_equal(
    show_best(rcv_results, metric = "rmse", n = 1) %>% names(),
    rcv_rmse %>% names()
  )
  expect_snapshot({
    best_default_metric <- show_best(rcv_results)
    best_rmse <- show_best(rcv_results, metric = "rmse")
  })
  expect_equal(best_default_metric, best_rmse)

  expect_snapshot(error = TRUE, {
    show_best(mtcars, metric = "disp")
  })
})

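# select_by_one_std_err() applies the one-standard-error rule: among
# candidates whose mean performance is within one standard error of the
# numerically best one, pick the simplest, where "simplest" is defined by the
# sorting variables passed through `...`.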
test_that("one-std error rule", {
  options(width = 200, pillar.advice = FALSE, pillar.min_title_chars = Inf)

  rcv_results <- readRDS(test_path("data", "rcv_results.rds"))
  knn_results <- readRDS(test_path("data", "knn_results.rds"))

  expect_true(
    tibble::is_tibble(select_by_one_std_err(knn_results, metric = "accuracy", K))
  )

  expect_equal(
    select_by_one_std_err(rcv_results, metric = "rmse", deg_free, `wt degree`)$.config,
    "Preprocessor05_Model1"
  )
  expect_equal(
    select_by_one_std_err(knn_results, metric = "accuracy", K)$K,
    25L
  )

  expect_snapshot(error = TRUE, {
    select_by_one_std_err(rcv_results, metric = "random", deg_free)
  })
  expect_snapshot(
    select_by_one_std_err(rcv_results, metric = c("rmse", "rsq"), deg_free)
  )
  expect_snapshot({
    select_via_default_metric <- select_by_one_std_err(knn_results, K)
    select_via_roc <- select_by_one_std_err(knn_results, K, metric = "roc_auc")
  })
  expect_equal(select_via_default_metric, select_via_roc)

  expect_snapshot(error = TRUE, {
    select_by_one_std_err(rcv_results, metric = "random")
  })

  expect_snapshot(error = TRUE, {
    select_by_one_std_err(mtcars, metric = "disp")
  })
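  # `weight_funk` and `Kay` are deliberately misspelled parameter names; each
  # call should fail with an informative error, captured as a snapshot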
  expect_snapshot(error = TRUE, {
    select_by_one_std_err(knn_results, metric = "roc_auc", weight_funk)
  })
  expect_snapshot(error = TRUE, {
    select_by_one_std_err(knn_results, metric = "roc_auc", weight_funk, K)
  })
  expect_snapshot(error = TRUE, {
    select_by_one_std_err(knn_results, metric = "roc_auc", weight_funk, Kay)
  })
  expect_snapshot(error = TRUE, {
    select_by_one_std_err(knn_results, metric = "roc_auc", weight_funk, desc(K))
  })
})


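# select_by_pct_loss() picks the simplest candidate (per the sorting
# variables in `...`) whose percent loss relative to the best candidate is
# within `limit`.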
test_that("percent loss", {
  options(width = 200, pillar.advice = FALSE, pillar.min_title_chars = Inf)

  rcv_results <- readRDS(test_path("data", "rcv_results.rds"))
  knn_results <- readRDS(test_path("data", "knn_results.rds"))

  expect_true(
    tibble::is_tibble(select_by_pct_loss(knn_results, metric = "accuracy", K))
  )
  expect_equal(
    select_by_pct_loss(rcv_results, metric = "rmse", deg_free, `wt degree`)$.config,
    "Preprocessor05_Model1"
  )
  expect_equal(
    select_by_pct_loss(knn_results, metric = "accuracy", K)$K,
    12L
  )

  expect_snapshot(error = TRUE, {
    select_by_pct_loss(rcv_results, metric = "random", deg_free)
  })
  expect_snapshot(
    select_by_pct_loss(rcv_results, metric = c("rmse", "rsq"), deg_free)
  )
  expect_snapshot({
    select_via_default_metric <- select_by_pct_loss(knn_results, K)
    select_via_roc <- select_by_pct_loss(knn_results, K, metric = "roc_auc")
  })
  expect_equal(select_via_default_metric, select_via_roc)

  expect_snapshot(error = TRUE, {
    select_by_pct_loss(rcv_results, metric = "random")
  })

  expect_snapshot(error = TRUE, {
    select_by_pct_loss(mtcars, metric = "disp")
  })
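  # as above, deliberately misspelled parameter names should fail informatively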
  expect_snapshot(error = TRUE, {
    select_by_pct_loss(knn_results, metric = "roc_auc", weight_funk)
  })
  expect_snapshot(error = TRUE, {
    select_by_pct_loss(knn_results, metric = "roc_auc", weight_funk, K)
  })
  expect_snapshot(error = TRUE, {
    select_by_pct_loss(knn_results, metric = "roc_auc", weight_funk, Kay)
  })
  expect_snapshot(error = TRUE, {
    select_by_pct_loss(knn_results, metric = "roc_auc", weight_funk, desc(K))
  })

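  # for k-NN, larger K means a smoother, simpler model, so desc(K) with
  # limit = 10 selects the largest K within 10% loss of the best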
  data("example_ames_knn")
  expect_equal(
    select_by_pct_loss(ames_grid_search, metric = "rmse", limit = 10, desc(K))$K,
    40
  )
})

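# yardstick metrics with direction == "zero" (e.g. msd and mpe) are best when
# closest to zero from either side, so candidates are ranked by |mean|.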
test_that("select_by_* can handle metrics with direction == 'zero'", {
  skip_on_cran()

  set.seed(1)
  resamples <- rsample::bootstraps(mtcars, times = 5)

  set.seed(1)
  tune_res <-
    tune::tune_grid(
      parsnip::nearest_neighbor(mode = "regression", neighbors = tune()),
      mpg ~ .,
      resamples,
      metrics = yardstick::metric_set(yardstick::mpe, yardstick::msd)
    )

  tune_res_metrics <- tune_res %>% collect_metrics()

  expect_equal(
    select_best(tune_res, metric = "msd")$.config,
    tune_res_metrics %>%
      filter(.metric == "msd") %>%
      arrange(abs(mean)) %>%
      slice(1) %>%
      select(.config) %>%
      pull()
  )

  expect_equal(
    select_best(tune_res, metric = "mpe")$.config,
    tune_res_metrics %>%
      filter(.metric == "mpe") %>%
      arrange(abs(mean)) %>%
      slice(1) %>%
      select(.config) %>%
      pull()
  )

  expect_equal(
    show_best(tune_res, metric = "msd", n = 5)$.config,
    tune_res_metrics %>%
      filter(.metric == "msd") %>%
      arrange(abs(mean)) %>%
      slice(1:5) %>%
      select(.config) %>%
      pull()
  )

  expect_equal(
    show_best(tune_res, metric = "mpe", n = 5)$.config,
    tune_res_metrics %>%
      filter(.metric == "mpe") %>%
      arrange(abs(mean)) %>%
      slice(1:5) %>%
      select(.config) %>%
      pull()
  )

  # one std error, msd ----------
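  # for a zero-direction metric the one-SE window is symmetric about zero:
  # a candidate qualifies when |mean| falls below |best mean| + |best std err|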
  best <-
    tune_res_metrics %>%
    filter(.metric == "msd") %>%
    arrange(abs(mean)) %>%
    slice(1)

  bound_lower <- -abs(best$mean) - abs(best$std_err)
  bound_upper <- abs(best$mean) + abs(best$std_err)
  expect_equal(bound_lower, -bound_upper)

  simplest_within_bound <-
    tune_res_metrics %>%
    filter(.metric == "msd") %>%
    filter(abs(mean) < bound_upper) %>%
    arrange(desc(neighbors)) %>%
    slice(1)

  expect_equal(
    select_by_one_std_err(tune_res, metric = "msd", desc(neighbors))$.config,
    simplest_within_bound$.config
  )

  # one std error, mpe ----------
  best <-
    tune_res_metrics %>%
    filter(.metric == "mpe") %>%
    arrange(abs(mean)) %>%
    slice(1)

  bound_lower <- -abs(best$mean) - abs(best$std_err)
  bound_upper <- abs(best$mean) + abs(best$std_err)
  expect_equal(bound_lower, -bound_upper)

  simplest_within_bound <-
    tune_res_metrics %>%
    filter(.metric == "mpe") %>%
    filter(abs(mean) < bound_upper) %>%
    arrange(desc(neighbors)) %>%
    slice(1)

  expect_equal(
    select_by_one_std_err(tune_res, metric = "mpe", desc(neighbors))$.config,
    simplest_within_bound$.config
  )

  # pct loss, msd ----------
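  # percent loss for zero-direction metrics is computed on absolute means:
  # loss = |(|mean| - |best mean|) / best mean| * 100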
  best <-
    tune_res_metrics %>%
    filter(.metric == "msd") %>%
    arrange(abs(mean)) %>%
    slice(1)

  expect_equal(
    select_by_pct_loss(tune_res, metric = "msd", limit = 10, desc(neighbors))$.config,
    tune_res_metrics %>%
      filter(.metric == "msd") %>%
      dplyr::rowwise() %>%
      mutate(loss = abs((abs(mean) - abs(best$mean)) / best$mean) * 100) %>%
      ungroup() %>%
      arrange(desc(neighbors)) %>%
      slice(1:which(.config == best$.config)) %>%
      filter(loss < 10) %>%
      slice(1) %>%
      select(.config) %>%
      pull()
  )

  # pct loss, mpe ----------
  best <-
    tune_res_metrics %>%
    filter(.metric == "mpe") %>%
    arrange(abs(mean)) %>%
    slice(1)

  expect_equal(
    select_by_pct_loss(tune_res, metric = "mpe", limit = 10, desc(neighbors))$.config,
    tune_res_metrics %>%
      filter(.metric == "mpe") %>%
      dplyr::rowwise() %>%
      mutate(loss = abs((abs(mean) - abs(best$mean)) / best$mean) * 100) %>%
      ungroup() %>%
      arrange(desc(neighbors)) %>%
      slice(1:which(.config == best$.config)) %>%
      filter(loss < 10) %>%
      slice(1) %>%
      select(.config) %>%
      pull()
  )
})

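# surv_boost_tree_res.rds contains a static metric (concordance_survival), a
# dynamic metric (brier_survival), and an integrated one
# (brier_survival_integrated). eval_time only applies to the dynamic metric
# and must match a time evaluated during tuning; unmatched times are errors.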
test_that("show_best with survival models", {
  surv_res <- readRDS(test_path("data", "surv_boost_tree_res.rds"))

  expect_snapshot(
    show_best(surv_res)
  )
  expect_snapshot(
    show_best(surv_res, metric = "concordance_survival")
  )
  expect_snapshot(
    show_best(surv_res, metric = "brier_survival_integrated")
  )
  expect_snapshot(
    show_best(surv_res, metric = "brier_survival")
  )
  expect_snapshot(
    show_best(surv_res, metric = c("brier_survival", "roc_auc_survival"))
  )
  expect_snapshot(
    show_best(surv_res, metric = "brier_survival", eval_time = 1)
  )
  expect_snapshot(
    show_best(surv_res, metric = "concordance_survival", eval_time = 1)
  )
  expect_snapshot(
    show_best(surv_res, metric = "concordance_survival", eval_time = 1.1)
  )
  expect_snapshot(
    show_best(surv_res, metric = "brier_survival", eval_time = 1.1),
    error = TRUE
  )
  expect_snapshot(
    show_best(surv_res, metric = "brier_survival", eval_time = 1:2)
  )
  expect_snapshot(
    show_best(surv_res, metric = "brier_survival", eval_time = 3:4),
    error = TRUE
  )
})

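# the same eval_time cases as show_best() above, exercised via select_best()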
test_that("select_best with survival models", {
  surv_res <- readRDS(test_path("data", "surv_boost_tree_res.rds"))

  expect_snapshot(
    select_best(surv_res)
  )
  expect_snapshot(
    select_best(surv_res, metric = "concordance_survival")
  )
  expect_snapshot(
    select_best(surv_res, metric = "brier_survival_integrated")
  )
  expect_snapshot(
    select_best(surv_res, metric = "brier_survival")
  )
  expect_snapshot(
    select_best(surv_res, metric = c("brier_survival", "roc_auc_survival"))
  )
  expect_snapshot(
    select_best(surv_res, metric = "brier_survival", eval_time = 1)
  )
  expect_snapshot(
    select_best(surv_res, metric = "concordance_survival", eval_time = 1)
  )
  expect_snapshot(
    select_best(surv_res, metric = "concordance_survival", eval_time = 1.1)
  )
  expect_snapshot(
    select_best(surv_res, metric = "brier_survival", eval_time = 1.1),
    error = TRUE
  )
  expect_snapshot(
    select_best(surv_res, metric = "brier_survival", eval_time = 1:2)
  )
  expect_snapshot(
    select_best(surv_res, metric = "brier_survival", eval_time = 3:4),
    error = TRUE
  )
})

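# the same eval_time cases again, with trees as the simplicity criterion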
test_that("select_by_one_std_err with survival models", {
  surv_res <- readRDS(test_path("data", "surv_boost_tree_res.rds"))

  expect_snapshot(
    select_by_one_std_err(surv_res, trees)
  )
  expect_snapshot(
    select_by_one_std_err(surv_res, metric = "concordance_survival", trees)
  )
  expect_snapshot(
    select_by_one_std_err(surv_res, metric = "brier_survival_integrated", trees)
  )
  expect_snapshot(
    select_by_one_std_err(surv_res, metric = "brier_survival", trees)
  )
  expect_snapshot(
    select_by_one_std_err(surv_res, metric = c("brier_survival", "roc_auc_survival"),
                          trees)
  )
  expect_snapshot(
    select_by_one_std_err(surv_res, metric = "brier_survival",
                          eval_time = 1, trees)
  )
  expect_snapshot(
    select_by_one_std_err(surv_res, metric = "concordance_survival",
                          eval_time = 1, trees)
  )
  expect_snapshot(
    select_by_one_std_err(surv_res, metric = "concordance_survival",
                          eval_time = 1.1, trees)
  )
  expect_snapshot(
    select_by_one_std_err(surv_res, metric = "brier_survival",
                          eval_time = 1.1, trees),
    error = TRUE
  )
  expect_snapshot(
    select_by_one_std_err(surv_res, metric = "brier_survival",
                          eval_time = 1:2, trees)
  )
  expect_snapshot(
    select_by_one_std_err(surv_res, metric = "brier_survival",
                          eval_time = 3:4, trees),
    error = TRUE
  )
})

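# and once more through select_by_pct_loss()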
test_that("select_by_pct_loss with survival models", {
  surv_res <- readRDS(test_path("data", "surv_boost_tree_res.rds"))

  expect_snapshot(
    select_by_pct_loss(surv_res, trees)
  )
  expect_snapshot(
    select_by_pct_loss(surv_res, metric = "concordance_survival", trees)
  )
  expect_snapshot(
    select_by_pct_loss(surv_res, metric = "brier_survival_integrated", trees)
  )
  expect_snapshot(
    select_by_pct_loss(surv_res, metric = "brier_survival", trees)
  )
  expect_snapshot(
    select_by_pct_loss(surv_res, metric = c("brier_survival", "roc_auc_survival"),
                       trees)
  )
  expect_snapshot(
    select_by_pct_loss(surv_res, metric = "brier_survival",
                       eval_time = 1, trees)
  )
  expect_snapshot(
    select_by_pct_loss(surv_res, metric = "concordance_survival",
                       eval_time = 1, trees)
  )
  expect_snapshot(
    select_by_pct_loss(surv_res, metric = "concordance_survival",
                       eval_time = 1.1, trees)
  )
  expect_snapshot(
    select_by_pct_loss(surv_res, metric = "brier_survival",
                       eval_time = 1.1, trees),
    error = TRUE
  )
  expect_snapshot(
    select_by_pct_loss(surv_res, metric = "brier_survival",
                       eval_time = 1:2, trees)
  )
  expect_snapshot(
    select_by_pct_loss(surv_res, metric = "brier_survival",
                       eval_time = 3:4, trees),
    error = TRUE
  )
})
