# tests/testthat/test-resample.R

# fit_resamples() --------------------------------------------------------------

# The object returned by `fit_resamples()` must inherit from
# `resample_results` and must be equal to `.Last.tune.result`.
test_that("`fit_resamples()` returns a `resample_results` object", {
  set.seed(6735)
  folds <- rsample::vfold_cv(mtcars, v = 2)

  lin_mod <- parsnip::linear_reg() %>%
    parsnip::set_engine("lm")

  result <- lin_mod %>%
    fit_resamples(mpg ~ ., folds)

  # Note the plural: the asserted class is "resample_results".
  expect_s3_class(result, "resample_results")

  # The returned value is compared against `.Last.tune.result`.
  expect_equal(result, .Last.tune.result)
})

# Formula interface: resampling must not raise a warning, and the metrics
# tibble of the first resample must have two rows.
test_that("can use `fit_resamples()` with a formula", {
  set.seed(6735)
  cv_folds <- rsample::vfold_cv(mtcars, v = 2)

  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  # A pattern of `NA` asserts that no warning is signalled.
  expect_warning(
    res <- fit_resamples(lm_spec, mpg ~ ., cv_folds),
    NA
  )

  first_metrics <- res$.metrics[[1]]
  expect_equal(nrow(first_metrics), 2L)
})

# Recipe interface: the recipe pulled out of the first extract must have both
# spline steps trained, and the first resample must report two metric rows.
test_that("can use `fit_resamples()` with a recipe", {
  set.seed(6735)
  folds <- rsample::vfold_cv(mtcars, v = 2)

  rec <- recipes::recipe(mpg ~ ., data = mtcars) %>%
    recipes::step_ns(disp) %>%
    recipes::step_ns(wt)

  # Namespace-qualified like the other tests in this file, so the test does
  # not depend on parsnip being attached.
  lin_mod <- parsnip::linear_reg() %>%
    parsnip::set_engine("lm")

  # Ensure the recipes are prepped and returned: the extract function hands
  # back its input unchanged.
  control <- control_resamples(extract = function(x) x)

  result <- fit_resamples(lin_mod, rec, folds, control = control)

  prepped_rec <- extract_recipe(result$.extracts[[1]][[1]][[1]])

  expect_true(prepped_rec$steps[[1]]$trained)
  expect_true(prepped_rec$steps[[2]]$trained)

  expect_equal(nrow(result$.metrics[[1]]), 2L)
})

# A workflow built from a recipe and a model must yield the same collected
# metrics as passing the model and recipe to `fit_resamples()` directly.
test_that("can use `fit_resamples()` with a workflow - recipe", {
  set.seed(6735)
  cv_folds <- rsample::vfold_cv(mtcars, v = 2)

  # Steps are added one at a time, in the same order as before.
  spline_rec <- recipes::recipe(mpg ~ ., data = mtcars)
  spline_rec <- recipes::step_ns(spline_rec, disp)
  spline_rec <- recipes::step_ns(spline_rec, wt)

  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  wflow <- workflow()
  wflow <- add_recipe(wflow, spline_rec)
  wflow <- add_model(wflow, lm_spec)

  # Reference fit first, workflow fit second.
  expected <- fit_resamples(lm_spec, spline_rec, cv_folds)
  actual <- fit_resamples(wflow, cv_folds)

  expect_equal(collect_metrics(expected), collect_metrics(actual))
})

# A workflow using `add_variables()` must yield the same collected metrics as
# the formula `mpg ~ cyl + disp` passed straight to `fit_resamples()`.
test_that("can use `fit_resamples()` with a workflow - variables", {
  set.seed(6735)
  cv_folds <- rsample::vfold_cv(mtcars, v = 2)

  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  wflow <- workflow()
  wflow <- add_variables(wflow, mpg, c(cyl, disp))
  wflow <- add_model(wflow, lm_spec)

  # Reference fit first, workflow fit second.
  expected <- fit_resamples(lm_spec, mpg ~ cyl + disp, cv_folds)
  actual <- fit_resamples(wflow, cv_folds)

  expect_equal(collect_metrics(expected), collect_metrics(actual))
})

# A workflow using `add_formula()` must yield the same collected metrics as
# the same formula passed straight to `fit_resamples()`.
test_that("can use `fit_resamples()` with a workflow - formula", {
  set.seed(6735)
  cv_folds <- rsample::vfold_cv(mtcars, v = 2)

  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  wflow <- workflow()
  wflow <- add_formula(wflow, mpg ~ cyl + disp)
  wflow <- add_model(wflow, lm_spec)

  # Reference fit first, workflow fit second.
  expected <- fit_resamples(lm_spec, mpg ~ cyl + disp, cv_folds)
  actual <- fit_resamples(wflow, cv_folds)

  expect_equal(collect_metrics(expected), collect_metrics(actual))
})

# With `extract = identity`, the object stored in the first extract must be
# flagged as trained.
test_that("extracted workflow is finalized", {
  set.seed(6735)
  cv_folds <- rsample::vfold_cv(mtcars, v = 2)

  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  wflow <- workflow()
  wflow <- add_variables(wflow, mpg, c(cyl, disp))
  wflow <- add_model(wflow, lm_spec)

  ctrl <- control_resamples(extract = identity)

  res <- fit_resamples(wflow, cv_folds, control = ctrl)

  # First resample -> its `.extracts` column -> first extracted object.
  first_resample <- res[[".extracts"]][[1]]
  extracted_wflow <- first_resample[[".extracts"]][[1]]

  expect_true(extracted_wflow[["trained"]])
})

# Error capture ----------------------------------------------------------------

# A recipe step that fails must be recorded in the notes, not abort the run.
# The names used inside `expect_snapshot()` are part of the recorded snapshot
# and are therefore left untouched.
test_that("failure in recipe is caught elegantly", {
  set.seed(6735)
  folds <- rsample::vfold_cv(mtcars, v = 2)

  # `deg_free = NA_real_` is the deliberately broken setting.
  rec <- recipes::step_ns(
    recipes::recipe(mpg ~ ., data = mtcars),
    disp,
    deg_free = NA_real_
  )

  lin_mod <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  control <- control_resamples(extract = function(x) x, save_pred = TRUE)

  expect_snapshot(
    result <- fit_resamples(lin_mod, rec, folds, control = control)
  )

  all_notes <- result$.notes
  first_note <- all_notes[[1]]$note

  # One notes entry per resample.
  expect_length(all_notes, 2L)

  # Known failure in the recipe
  expect_true(any(grepl("missing value where", first_note)))

  # Nothing could be extracted or predicted for either resample.
  expect_equal(result$.extracts, list(NULL, NULL))
  expect_equal(result$.predictions, list(NULL, NULL))
})

# Selecting a column that does not exist (`foobar`) must be recorded in the
# notes, not abort the run. The names used inside `expect_snapshot()` are part
# of the recorded snapshot and are therefore left untouched.
test_that("failure in variables tidyselect specification is caught elegantly", {
  set.seed(6735)
  folds <- rsample::vfold_cv(mtcars, v = 2)

  lin_mod <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  workflow <- add_model(workflow(), lin_mod)
  workflow <- add_variables(workflow, mpg, foobar)

  control <- control_resamples(extract = function(x) x, save_pred = TRUE)

  expect_snapshot(
    result <- fit_resamples(workflow, folds, control = control)
  )

  all_notes <- result$.notes
  first_note <- all_notes[[1]]$note

  # One notes entry per resample.
  expect_length(all_notes, 2L)

  # Known failure in the variables part
  expect_true(any(grepl("foobar", first_note)))

  # Nothing could be extracted or predicted for either resample.
  expect_equal(result$.extracts, list(NULL, NULL))
  expect_equal(result$.predictions, list(NULL, NULL))
})

# Fitting a logistic regression with `vs` as the outcome must record a note
# saying the outcome should be a factor. The names used inside
# `expect_snapshot()` are part of the recorded snapshot and are left untouched.
test_that("classification models generate correct error message", {
  set.seed(6735)
  folds <- rsample::vfold_cv(mtcars, v = 2)

  rec <- recipes::recipe(vs ~ ., data = mtcars)

  log_mod <- parsnip::set_engine(parsnip::logistic_reg(), "glm")

  control <- control_resamples(extract = function(x) x, save_pred = TRUE)

  expect_snapshot(
    result <- fit_resamples(log_mod, rec, folds, control = control)
  )

  all_notes <- result$.notes
  first_note <- all_notes[[1]]$note

  # One notes entry per resample.
  expect_length(all_notes, 2L)

  # Every note of the first resample must mention the outcome type problem.
  expect_true(all(grepl("outcome should be a `factor`", first_note)))

  # Nothing could be extracted or predicted for either resample.
  expect_equal(result$.extracts, list(NULL, NULL))
  expect_equal(result$.predictions, list(NULL, NULL))
})

# tune_grid() fallback ---------------------------------------------------------

# `tune_grid()` on a model and formula with nothing to tune must give the same
# collected metrics as `fit_resamples()`. The snapshot expression is unchanged.
test_that("`tune_grid()` falls back to `fit_resamples()` - formula", {
  set.seed(6735)
  folds <- rsample::vfold_cv(mtcars, v = 2)

  lin_mod <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  # Reference fit against which the fallback is compared.
  expected <- fit_resamples(lin_mod, mpg ~ ., folds)

  expect_snapshot(
    result <- tune_grid(lin_mod, mpg ~ ., folds)
  )

  expect_equal(collect_metrics(expected), collect_metrics(result))
})

# `tune_grid()` on a variables-based workflow with nothing to tune must give
# the same collected metrics as `fit_resamples()`. The snapshot expression is
# unchanged.
test_that("`tune_grid()` falls back to `fit_resamples()` - workflow variables", {
  set.seed(6735)
  folds <- rsample::vfold_cv(mtcars, v = 2)

  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  wf <- add_model(workflow(), lm_spec)
  wf <- add_variables(wf, mpg, c(cyl, disp))

  # Reference fit against which the fallback is compared.
  expected <- fit_resamples(wf, folds)

  expect_snapshot(
    result <- tune_grid(wf, folds)
  )

  expect_equal(collect_metrics(expected), collect_metrics(result))
})

# Supplying a `grid` when nothing is tunable must not change the collected
# metrics relative to `fit_resamples()`. The snapshot expression is unchanged.
test_that("`tune_grid()` ignores `grid` if there are no tuning parameters", {
  set.seed(6735)
  folds <- rsample::vfold_cv(mtcars, v = 2)

  lin_mod <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  # Reference fit against which the `grid` call is compared.
  expected <- fit_resamples(lin_mod, mpg ~ ., folds)

  expect_snapshot(
    result <- lin_mod %>% tune_grid(mpg ~ ., grid = data.frame(x = 1), folds)
  )

  expect_equal(collect_metrics(expected), collect_metrics(result))
})


# autoplot() -------------------------------------------------------------------

# Calling `autoplot()` on a `fit_resamples()` result must raise an error whose
# message is pinned by the snapshot.
test_that("cannot autoplot `fit_resamples()` results", {
  set.seed(6735)
  cv_folds <- rsample::vfold_cv(mtcars, v = 2)

  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  # `result` is referenced inside the snapshot, so its name is kept.
  result <- fit_resamples(lm_spec, mpg ~ ., cv_folds)

  expect_snapshot(error = TRUE, {
    autoplot(result)
  })
})

# Passing an unused argument through `...` (`something = "wrong"`) produces
# output that is pinned by the snapshot.
test_that("ellipses with fit_resamples", {
  # Seed the RNG like the other `vfold_cv(mtcars, v = 2)` tests in this file,
  # so the folds behind the snapshot are the same on every run.
  set.seed(6735)
  folds <- rsample::vfold_cv(mtcars, v = 2)

  lin_mod <- parsnip::linear_reg() %>%
    parsnip::set_engine("lm")

  expect_snapshot(
    lin_mod %>% fit_resamples(mpg ~ ., folds, something = "wrong")
  )
})

# Passing the preprocessor (recipe or formula) before the model must raise an
# error; both messages are pinned by snapshots. The names used inside the
# snapshots are left untouched.
test_that("argument order gives errors for recipe/formula", {
  set.seed(6735)
  folds <- rsample::vfold_cv(mtcars, v = 2)

  # Steps are added one at a time, in the same order as before.
  rec <- recipes::recipe(mpg ~ ., data = mtcars)
  rec <- recipes::step_ns(rec, disp)
  rec <- recipes::step_ns(rec, wt)

  lin_mod <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  # Recipe given first.
  expect_snapshot(error = TRUE, {
    fit_resamples(rec, lin_mod, folds)
  })
  # Formula given first.
  expect_snapshot(error = TRUE, {
    fit_resamples(mpg ~ ., lin_mod, folds)
  })
})



# The result must carry `metrics`, `outcomes` and `parameters` attributes of
# the expected types. The `workflow` attribute must be present only when
# `save_workflow = TRUE`.
test_that("retain extra attributes", {
  set.seed(6735)
  folds <- rsample::vfold_cv(mtcars, v = 2)

  lin_mod <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  plain_res <- fit_resamples(lin_mod, mpg ~ ., folds)

  res_attrs <- attributes(plain_res)
  attr_names <- names(res_attrs)

  # Presence of the three attributes.
  expect_true("metrics" %in% attr_names)
  expect_true("outcomes" %in% attr_names)
  expect_true("parameters" %in% attr_names)

  # Type and content of each attribute.
  expect_true(is.character(res_attrs$outcomes))
  expect_true(res_attrs$outcomes == "mpg")
  expect_true(inherits(res_attrs$parameters, "parameters"))
  expect_true(inherits(res_attrs$metrics, "metric_set"))

  save_ctrl <- control_resamples(save_workflow = TRUE)
  saved_res <- fit_resamples(lin_mod, mpg ~ ., folds, control = save_ctrl)

  expect_null(attr(plain_res, "workflow"))
  expect_true(inherits(attr(saved_res, "workflow"), "workflow"))

  # The recipe here is built on mtcars with its rows repeated 3000 times; the
  # snapshot pins what is emitted when such a workflow is saved. The
  # expression is kept verbatim because the snapshot records it.
  expect_snapshot(
    fit_resamples(
      lin_mod,
      recipes::recipe(mpg ~ ., mtcars[rep(1:32, 3000), ]),
      folds,
      control = control_resamples(save_workflow = TRUE)
    )
  )
})


# `fit_resamples()` must error when the recipe, the model, or both still
# contain `tune()` placeholders; each message is pinned by a snapshot.
test_that("`fit_resamples()` when objects need tuning", {
  tunable_rec <- step_ns(
    recipe(mpg ~ ., data = mtcars),
    disp,
    deg_free = tune()
  )
  tunable_spec <- set_engine(linear_reg(penalty = tune()), "glmnet")
  plain_spec <- linear_reg()

  # Tunable recipe + tunable model, formula + tunable model,
  # tunable recipe + plain model.
  wflow_both <- workflow(tunable_rec, tunable_spec)
  wflow_model <- workflow(mpg ~ ., tunable_spec)
  wflow_recipe <- workflow(tunable_rec, plain_spec)

  cv_folds <- rsample::vfold_cv(mtcars)

  expect_snapshot_error(fit_resamples(wflow_both, cv_folds))
  expect_snapshot_error(fit_resamples(wflow_model, cv_folds))
  expect_snapshot_error(fit_resamples(wflow_recipe, cv_folds))
})

# Try the tune package in your browser
#
# Any scripts or data that you put into this service are public.
#
# tune documentation built on Aug. 24, 2023, 1:09 a.m.