# tests/testthat/test-metric-args.R

test_that("metric inputs are checked for regression models", {
  library(parsnip)
  library(workflows)
  library(yardstick)
  library(rsample)

  # Two regression workflows built on one formula: a plain linear model and a
  # nearest-neighbour model whose `neighbors` argument is marked for tuning.
  # NOTE(review): `y`, `X1` and `X2` are not columns of `mtcars`; presumably
  # harmless because every call below that receives data is expected to
  # error -- confirm that none of them is meant to reach model fitting.
  reg_formula <- y ~ X1 + X2
  knn_spec <- set_mode(nearest_neighbor(neighbors = tune()), "regression")
  wflow <- workflow(reg_formula, linear_reg())
  wflow_tune <- workflow(reg_formula, knn_spec)

  # Seed once, then draw the split before the folds (same RNG order as ever).
  set.seed(1)
  split <- initial_split(mtcars)
  rs <- vfold_cv(mtcars)

  # Metric sets ----------------------------------------------------------------
  # One regression set, one classification set, and one built from three
  # survival metrics.

  met_reg <- metric_set(rmse)
  met_cls <- metric_set(brier_class)
  met_mix_int <- metric_set(
    brier_survival_integrated,
    brier_survival,
    concordance_survival
  )

  # Direct checks --------------------------------------------------------------
  # `NULL` and the regression set are snapshotted as successful calls; the
  # other two sets are snapshotted as errors.

  expect_snapshot(check_metrics_arg(NULL, wflow))

  expect_snapshot(check_metrics_arg(met_reg, wflow))
  expect_snapshot(
    error = TRUE,
    check_metrics_arg(met_cls, wflow)
  )
  expect_snapshot(
    error = TRUE,
    check_metrics_arg(met_mix_int, wflow)
  )

  # Resampling -----------------------------------------------------------------

  expect_snapshot(
    error = TRUE,
    fit_resamples(wflow, rs, metrics = met_cls)
  )
  expect_snapshot(
    error = TRUE,
    fit_resamples(wflow, rs, metrics = met_mix_int)
  )

  # Tuning ---------------------------------------------------------------------

  expect_snapshot(
    error = TRUE,
    tune_grid(wflow_tune, rs, metrics = met_cls)
  )
  expect_snapshot(
    error = TRUE,
    tune_grid(wflow_tune, rs, metrics = met_mix_int)
  )

  expect_snapshot(
    error = TRUE,
    tune_bayes(wflow_tune, rs, metrics = met_cls)
  )
  expect_snapshot(
    error = TRUE,
    tune_bayes(wflow_tune, rs, metrics = met_mix_int)
  )

  # Final fit ------------------------------------------------------------------

  expect_snapshot(
    error = TRUE,
    last_fit(wflow, split, metrics = met_cls)
  )
  expect_snapshot(
    error = TRUE,
    last_fit(wflow, split, metrics = met_mix_int)
  )
})

test_that("metric inputs are checked for classification models", {
  # `two_class_dat` is loaded from the modeldata package below; skip (rather
  # than fail) when that package is unavailable, mirroring the guard used for
  # the optional dependency in the censored regression test of this file.
  skip_if_not_installed("modeldata")

  library(parsnip)
  library(workflows)
  library(yardstick)
  library(rsample)

  data(two_class_dat, package = "modeldata")

  # A fixed logistic-regression workflow and a nearest-neighbour workflow
  # whose `neighbors` argument is marked for tuning.
  wflow <- workflow(Class ~ A + B, logistic_reg())
  knn_spec <- nearest_neighbor(neighbors = tune()) %>%
    set_mode("classification")
  wflow_tune <- workflow(Class ~ A + B, knn_spec)

  # Seed once, then draw the split before the folds.
  set.seed(1)
  split <- initial_split(two_class_dat)
  rs <- vfold_cv(two_class_dat)

  # ----------------------------------------------------------------------------
  # setup metric sets: three survival metrics, one regression metric, and one
  # classification metric

  met_mix_int <-
    metric_set(
      brier_survival_integrated,
      brier_survival,
      concordance_survival
    )
  met_reg <- metric_set(rmse)
  met_cls <- metric_set(brier_class)

  # ----------------------------------------------------------------------------
  # check inputs: `NULL` and the classification set are snapshotted as
  # successful calls; the regression and survival sets are snapshotted as errors

  expect_snapshot(check_metrics_arg(NULL, wflow))

  expect_snapshot(check_metrics_arg(met_reg, wflow), error = TRUE)
  expect_snapshot(check_metrics_arg(met_cls, wflow))
  expect_snapshot(check_metrics_arg(met_mix_int, wflow), error = TRUE)

  # ----------------------------------------------------------------------------
  # resampling

  expect_snapshot(
    fit_resamples(wflow, rs, metrics = met_reg),
    error = TRUE
  )
  expect_snapshot(
    fit_resamples(wflow, rs, metrics = met_mix_int),
    error = TRUE
  )

  # ----------------------------------------------------------------------------
  # tuning

  expect_snapshot(
    tune_grid(wflow_tune, rs, metrics = met_reg),
    error = TRUE
  )
  expect_snapshot(
    tune_grid(wflow_tune, rs, metrics = met_mix_int),
    error = TRUE
  )

  expect_snapshot(
    tune_bayes(wflow_tune, rs, metrics = met_reg),
    error = TRUE
  )
  expect_snapshot(
    tune_bayes(wflow_tune, rs, metrics = met_mix_int),
    error = TRUE
  )

  # ----------------------------------------------------------------------------
  # final fit

  expect_snapshot(
    last_fit(wflow, split, metrics = met_reg),
    error = TRUE
  )
  expect_snapshot(
    last_fit(wflow, split, metrics = met_mix_int),
    error = TRUE
  )
})


test_that("metric inputs are checked for censored regression models", {
  skip_if_not_installed("censored")
  library(parsnip)
  library(workflows)
  library(yardstick)
  library(rsample)
  library(censored)

  # Work on a local copy reduced to the survival outcome and one predictor.
  # NOTE(review): `stanford2` and `Surv()` are not defined in this file;
  # presumably both come from the survival package being attached alongside
  # censored -- confirm.
  surv_dat <- stanford2
  surv_dat$event_time <- Surv(surv_dat$time, surv_dat$status)
  surv_dat <- surv_dat[c("event_time", "age")]

  # A fixed survival-regression workflow and one whose `dist` argument is
  # marked for tuning, both on the same formula.
  surv_formula <- event_time ~ age
  tunable_spec <- survival_reg(dist = tune())
  wflow <- workflow(surv_formula, survival_reg())
  wflow_tune <- workflow(surv_formula, tunable_spec)

  # Seed once, then draw the split before the folds (same RNG order as ever).
  set.seed(1)
  split <- initial_split(surv_dat)
  rs <- vfold_cv(surv_dat)

  # Metric sets ----------------------------------------------------------------
  # One survival set, one regression set, and one classification set.

  met_srv <- metric_set(concordance_survival)
  met_reg <- metric_set(rmse)
  met_cls <- metric_set(brier_class)

  # Direct checks --------------------------------------------------------------
  # `NULL` and the survival set are snapshotted as successful calls; the
  # regression and classification sets are snapshotted as errors.

  expect_snapshot(check_metrics_arg(NULL, wflow))

  expect_snapshot(
    error = TRUE,
    check_metrics_arg(met_reg, wflow)
  )
  expect_snapshot(
    error = TRUE,
    check_metrics_arg(met_cls, wflow)
  )
  expect_snapshot(check_metrics_arg(met_srv, wflow))

  # Resampling -----------------------------------------------------------------

  expect_snapshot(
    error = TRUE,
    fit_resamples(wflow, rs, metrics = met_cls)
  )
  expect_snapshot(
    error = TRUE,
    fit_resamples(wflow, rs, metrics = met_reg)
  )

  # Tuning ---------------------------------------------------------------------

  expect_snapshot(
    error = TRUE,
    tune_grid(wflow_tune, rs, metrics = met_cls)
  )
  expect_snapshot(
    error = TRUE,
    tune_grid(wflow_tune, rs, metrics = met_reg)
  )

  expect_snapshot(
    error = TRUE,
    tune_bayes(wflow_tune, rs, metrics = met_cls)
  )
  expect_snapshot(
    error = TRUE,
    tune_bayes(wflow_tune, rs, metrics = met_reg)
  )

  # Final fit ------------------------------------------------------------------

  expect_snapshot(
    error = TRUE,
    last_fit(wflow, split, metrics = met_cls)
  )
  expect_snapshot(
    error = TRUE,
    last_fit(wflow, split, metrics = met_reg)
  )
})

# Try the tune package in your browser
#
# Any scripts or data that you put into this service are public.
#
# tune documentation built on May 29, 2024, 7:32 a.m.