# tests/testthat/test_augment.R

# augment() on linear_reg() fits made through both the formula interface
# (fit()) and the x/y interface (fit_xy()), scored on new data with and
# without the `mpg` outcome column.
test_that("regression models", {
  lm_spec <- linear_reg() %>% set_engine("lm")

  fit_formula <- lm_spec %>% fit(mpg ~ ., data = mtcars)
  fit_matrix <- lm_spec %>% fit_xy(mtcars[, -1], mtcars$mpg)

  # Checks the column names of augment()'s result and that it has one row per
  # row of `new_data`; returns the augmented result invisibly for reuse.
  expect_augmented <- function(object, new_data, expected_cols) {
    res <- augment(object, new_data)
    expect_equal(colnames(res), expected_cols)
    expect_equal(nrow(res), nrow(new_data))
    invisible(res)
  }

  # Every mtcars column except the outcome, in data order.
  predictors <- c(
    "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am", "gear", "carb"
  )
  with_outcome <- head(mtcars)
  without_outcome <- head(mtcars[, -1])

  # The formula fit is expected to add `.resid` only when `mpg` is present.
  form_aug <- expect_augmented(
    fit_formula, with_outcome, c(".pred", ".resid", "mpg", predictors)
  )
  expect_augmented(fit_formula, without_outcome, c(".pred", predictors))

  # The x/y fit is expected to add no `.resid` column, even with `mpg` present.
  expect_augmented(fit_matrix, with_outcome, c(".pred", "mpg", predictors))
  expect_augmented(fit_matrix, without_outcome, c(".pred", predictors))

  expect_s3_class(form_aug, "tbl_df")

  # With an unrecognized mode on the spec, augment() must error.
  fit_formula$spec$mode <- "depeche"
  expect_error(augment(fit_formula, without_outcome), "Unknown mode: depeche")
})



# augment() on logistic_reg() fits made through both the formula interface
# (fit()) and the x/y interface (fit_xy()), scored on new data with and
# without the `Class` outcome column.
test_that("classification models", {
  # modeldata is only a suggested dependency: skip rather than error without it.
  skip_if_not_installed("modeldata")
  # data() defaults to envir = .GlobalEnv; load into this test's environment
  # instead so the dataset does not leak into other tests.
  data(two_class_dat, package = "modeldata", envir = environment())

  glm_spec <- logistic_reg() %>% set_engine("glm")
  predictors <- c("A", "B")

  fit_formula <- glm_spec %>% fit(Class ~ ., data = two_class_dat)
  fit_matrix <- glm_spec %>%
    fit_xy(two_class_dat[, predictors], two_class_dat$Class)

  # Checks the column names of augment()'s result and that it has one row per
  # row of `new_data`; returns the augmented result invisibly for reuse.
  expect_augmented <- function(object, new_data, expected_cols) {
    res <- augment(object, new_data)
    expect_equal(colnames(res), expected_cols)
    expect_equal(nrow(res), nrow(new_data))
    invisible(res)
  }

  # Hard class prediction followed by one probability column per class level.
  pred_cols <- c(".pred_class", ".pred_Class1", ".pred_Class2")
  with_outcome <- head(two_class_dat)
  without_outcome <- head(two_class_dat[, predictors])

  expect_augmented(
    fit_formula, with_outcome, c(pred_cols, predictors, "Class")
  )
  expect_augmented(fit_formula, without_outcome, c(pred_cols, predictors))

  expect_augmented(
    fit_matrix, with_outcome, c(pred_cols, predictors, "Class")
  )
  expect_augmented(fit_matrix, without_outcome, c(pred_cols, predictors))
})


# augment() on a classification fit whose engine is expected to produce only
# hard class predictions (no per-level probability columns).
test_that("augment for model without class probabilities", {
  skip_if_not_installed("LiblineaR")
  # modeldata is only a suggested dependency: skip rather than error without it.
  skip_if_not_installed("modeldata")
  # data() defaults to envir = .GlobalEnv; load into this test's environment
  # instead so the dataset does not leak into other tests.
  data(two_class_dat, package = "modeldata", envir = environment())

  svm_spec <- svm_linear(mode = "classification") %>% set_engine("LiblineaR")
  fit_formula <- svm_spec %>% fit(Class ~ ., data = two_class_dat)

  new_data <- head(two_class_dat)
  res <- augment(fit_formula, new_data)

  # Only `.pred_class` is expected to be added; no `.pred_<level>` columns.
  expect_equal(colnames(res), c(".pred_class", "A", "B", "Class"))
  expect_equal(nrow(res), 6)
})

# Try the parsnip package in your browser
#
# Any scripts or data that you put into this service are public.
#
# parsnip documentation built on Aug. 18, 2023, 1:07 a.m.