# tests/testthat/test-augment.R

# Load the `two_class_dat` example data from modeldata; the classification
# tests in this file fit models to it.
data("two_class_dat", package = "modeldata")

# ------------------------------------------------------------------------------

# augment() on a fit_resamples() result built from 30 bootstrap resamples.
# The expectations require one augmented row per original row, each with a
# non-missing class prediction.
test_that("augment fit_resamples", {
  lr_spec <- parsnip::logistic_reg() %>% parsnip::set_engine("glm")

  set.seed(1)
  two_class_dat <- as.data.frame(two_class_dat)
  bt1 <- rsample::bootstraps(two_class_dat, times = 30)

  set.seed(1)
  fit_1 <- fit_resamples(
    lr_spec,
    Class ~ .,
    bt1,
    control = control_resamples(save_pred = TRUE)
  )
  aug_1 <- augment(fit_1)
  n_rows <- nrow(two_class_dat)

  # Rows line up with the original data and every row has a prediction.
  # expect_equal() is used instead of expect_true(a == b) so that a failure
  # reports the actual and expected values.
  expect_equal(nrow(aug_1), n_rows)
  expect_equal(aug_1[["A"]], two_class_dat[["A"]])
  expect_equal(sum(!is.na(aug_1$.pred_class)), n_rows)

  # Each prediction column appears exactly once; no residual column is
  # expected for this classification fit.
  expect_equal(sum(names(aug_1) == ".pred_class"), 1L)
  expect_equal(sum(names(aug_1) == ".pred_Class1"), 1L)
  expect_equal(sum(names(aug_1) == ".pred_Class2"), 1L)
  expect_equal(sum(names(aug_1) == ".resid"), 0L)
  expect_s3_class_bare_tibble(aug_1)

  # The prediction columns come first, in this order.
  expect_identical(
    names(aug_1)[1:3],
    c(".pred_class", ".pred_Class1", ".pred_Class2")
  )

  # Dot-prefixed columns form a single leading run, followed only by
  # columns without a leading dot.
  expect_identical(
    rle(grepl("^\\.", names(aug_1)))$values,
    c(TRUE, FALSE)
  )

  # An extra argument (`hey`) must produce an error; the message is
  # captured in the snapshot.
  expect_snapshot(error = TRUE, augment(fit_1, hey = "you"))
})


# augment() on a fit_resamples() result built from only 3 bootstrap
# resamples. The expectations require that some rows end up without a
# prediction while the output still has one row per original row.
#
# NOTE(review): this test shares its description with the previous one. It is
# left unchanged because the snapshot below is presumably recorded under this
# description; renaming would need the snapshot file regenerated.
test_that("augment fit_resamples", {
  # `new_rng_snapshots` is defined outside this file (presumably a helper
  # flag tied to RNG-dependent snapshot output -- TODO confirm).
  skip_if(new_rng_snapshots)
  lr_spec <- parsnip::logistic_reg() %>% parsnip::set_engine("glm")

  set.seed(1)
  two_class_dat <- as.data.frame(two_class_dat)
  bt2 <- rsample::bootstraps(two_class_dat, times = 3)

  set.seed(1)
  fit_2 <- fit_resamples(
    lr_spec,
    Class ~ .,
    bt2,
    control = control_resamples(save_pred = TRUE)
  )
  # The snapshot records whatever augment() emits for this fit.
  expect_snapshot(aug_2 <- augment(fit_2))

  n_rows <- nrow(two_class_dat)

  # Same number of rows as the input, but strictly fewer non-missing
  # predictions. expect_equal()/expect_lt() report the compared values on
  # failure, unlike expect_true(a == b).
  expect_equal(nrow(aug_2), n_rows)
  expect_equal(aug_2[["A"]], two_class_dat[["A"]])
  expect_lt(sum(!is.na(aug_2$.pred_class)), n_rows)

  # Each prediction column appears exactly once.
  expect_equal(sum(names(aug_2) == ".pred_class"), 1L)
  expect_equal(sum(names(aug_2) == ".pred_Class1"), 1L)
  expect_equal(sum(names(aug_2) == ".pred_Class2"), 1L)
  expect_s3_class_bare_tibble(aug_2)

  # The prediction columns come first, in this order.
  expect_identical(
    names(aug_2)[1:3],
    c(".pred_class", ".pred_Class1", ".pred_Class2")
  )

  # Dot-prefixed columns form a single leading run, followed only by
  # columns without a leading dot.
  expect_identical(
    rle(grepl("^\\.", names(aug_2)))$values,
    c(TRUE, FALSE)
  )
})

# ------------------------------------------------------------------------------

# augment() on tuning results: first a tune_grid() fit of a linear SVM
# regression over three cost values, then a tune_bayes() fit initialised
# from that grid result.
test_that("augment tune_grid", {
  skip_if_not_installed("kernlab")

  svm_spec <- parsnip::svm_linear(cost = tune(), margin = 0.1) %>%
    parsnip::set_engine("kernlab") %>%
    parsnip::set_mode("regression")
  set.seed(1)
  cv1 <- rsample::vfold_cv(mtcars)

  set.seed(1)
  fit_1 <- tune_grid(
    svm_spec,
    mpg ~ .,
    cv1,
    grid = data.frame(cost = 1:3),
    control = control_grid(save_pred = TRUE)
  )
  aug_1 <- augment(fit_1)
  n_rows <- nrow(mtcars)

  # One row per original row, each with a prediction. expect_equal() is
  # used instead of expect_true(a == b) so failures show both values.
  expect_equal(nrow(aug_1), n_rows)
  expect_equal(aug_1[["wt"]], mtcars[["wt"]])
  expect_equal(sum(!is.na(aug_1$.pred)), n_rows)

  # Regression output carries exactly one prediction and one residual column.
  expect_equal(sum(names(aug_1) == ".pred"), 1L)
  expect_equal(sum(names(aug_1) == ".resid"), 1L)
  expect_s3_class_bare_tibble(aug_1)

  expect_identical(names(aug_1)[1], ".pred")

  # Dot-prefixed columns form a single leading run, followed only by
  # columns without a leading dot.
  expect_identical(
    rle(grepl("^\\.", names(aug_1)))$values,
    c(TRUE, FALSE)
  )

  # Requesting cost = 3 explicitly must change at least one prediction by
  # more than 1 relative to augment() called without `parameters`.
  aug_2 <- augment(fit_1, parameters = data.frame(cost = 3))
  expect_true(any(abs(aug_1$.pred - aug_2$.pred) > 1))

  # Each of the following calls must error; messages are snapshotted.
  # `parameters` given as a plain list rather than a data frame:
  expect_snapshot(error = TRUE, {
    augment(fit_1, parameters = list(cost = 3))
  })

  # `parameters` given as a data frame with two rows:
  expect_snapshot(error = TRUE, {
    augment(fit_1, parameters = data.frame(cost = 3:4))
  })

  # The tuning value passed as a separate `cost` argument:
  expect_snapshot(error = TRUE, {
    augment(fit_1, cost = 4)
  })

  # ------------------------------------------------------------------------------

  # Continue from the grid results with two iterations of tune_bayes();
  # messages emitted during the search are suppressed.
  suppressMessages({
    set.seed(1)
    fit_2 <- tune_bayes(
      svm_spec,
      mpg ~ .,
      cv1,
      initial = fit_1,
      iter = 2,
      param_info = parameters(dials::cost(c(-10, 5))),
      control = control_bayes(save_pred = TRUE)
    )
  })
  aug_3 <- augment(fit_2)

  # Same shape and column expectations as for the grid result.
  expect_equal(nrow(aug_3), n_rows)
  expect_equal(aug_3[["wt"]], mtcars[["wt"]])
  expect_equal(sum(!is.na(aug_3$.pred)), n_rows)
  expect_equal(sum(names(aug_3) == ".pred"), 1L)
  expect_equal(sum(names(aug_3) == ".resid"), 1L)
  expect_s3_class_bare_tibble(aug_3)

  expect_identical(names(aug_3)[1], ".pred")
})


# ------------------------------------------------------------------------------

# augment() on a last_fit() result: the augmented data must match the
# assessment set of the initial split, with class predictions added.
test_that("augment last_fit", {
  lr_spec <- parsnip::logistic_reg() %>% parsnip::set_engine("glm")
  set.seed(1)
  split <- rsample::initial_split(two_class_dat)
  fit_1 <- last_fit(lr_spec, Class ~ ., split = split)

  aug_1 <- augment(fit_1)

  # Extract the assessment set once and compare against it throughout.
  holdout <- rsample::assessment(split)
  n_holdout <- nrow(holdout)

  # One row per assessment row, each with a prediction. expect_equal() is
  # used instead of expect_true(a == b) so failures show both values.
  expect_equal(nrow(aug_1), n_holdout)
  expect_equal(aug_1[["A"]], holdout[["A"]])
  expect_equal(sum(!is.na(aug_1$.pred_class)), n_holdout)

  # Each prediction column appears exactly once.
  expect_equal(sum(names(aug_1) == ".pred_class"), 1L)
  expect_equal(sum(names(aug_1) == ".pred_Class1"), 1L)
  expect_equal(sum(names(aug_1) == ".pred_Class2"), 1L)
  expect_s3_class_bare_tibble(aug_1)

  # The prediction columns come first, in this order.
  expect_identical(
    names(aug_1)[1:3],
    c(".pred_class", ".pred_Class1", ".pred_Class2")
  )

  # Dot-prefixed columns form a single leading run, followed only by
  # columns without a leading dot.
  expect_identical(
    rle(grepl("^\\.", names(aug_1)))$values,
    c(TRUE, FALSE)
  )

  # An extra argument (`potato`) must produce an error; the message is
  # captured in the snapshot.
  expect_snapshot(error = TRUE, augment(fit_1, potato = TRUE))
})

# Try the tune package in your browser
#
# Any scripts or data that you put into this service are public.
#
# tune documentation built on May 29, 2024, 7:32 a.m.