tests/testthat/test-broom.R

# ------------------------------------------------------------------------------
# tidy()

test_that("can't tidy the model of an unfit workflow", {
  # An empty workflow has never been fit, so tidying its model (the default
  # target of `tidy()` on a workflow) must error. The snapshotted expression
  # is recorded verbatim, so `tidy(x)` must not be renamed.
  x <- workflow()

  expect_snapshot(
    tidy(x),
    error = TRUE
  )
})

test_that("can't tidy the recipe of an unfit workflow", {
  x <- workflow()

  # Case 1: no recipe has been added to the workflow at all.
  expect_snapshot(tidy(x, what = "recipe"), error = TRUE)

  # Case 2: a recipe is attached, but the workflow was never fit; tidying
  # the recipe is still expected to error.
  x <- add_recipe(x, recipes::recipe(y ~ x, data.frame(y = 1, x = 1)))

  expect_snapshot(tidy(x, what = "recipe"), error = TRUE)
})

test_that("can tidy workflow model or recipe", {
  skip_if_not_installed("broom")

  df <- data.frame(y = c(2, 3, 4), x = c(1, 5, 3))

  # Single-step recipe (log of `x`) paired with a plain `lm` model.
  rec <- recipes::step_log(recipes::recipe(y ~ x, df), x)
  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  wf <- add_model(add_recipe(workflow(), rec), lm_spec)
  wf <- fit(wf, df)

  # Default: tidy the fitted model, one row per coefficient.
  model_tidy <- tidy(wf)
  expect_identical(model_tidy$term, c("(Intercept)", "x"))

  # `what = "recipe"` tidies the recipe instead; its only step is number 1.
  recipe_tidy <- tidy(wf, what = "recipe")
  expect_identical(recipe_tidy$number, 1L)
})

# ------------------------------------------------------------------------------
# glance()

test_that("can't glance at the model of an unfit workflow", {
  # Nothing has been fit, so there is no model to `glance()` at.
  x <- workflow()

  # NOTE(review): this test uses `expect_snapshot_error()`, while the
  # `tidy()` tests above use `expect_snapshot(error = TRUE, ...)`.
  # Switching would change the recorded snapshot format, so the older
  # helper is kept here.
  expect_snapshot_error(
    glance(x)
  )
})

test_that("can glance at a fitted workflow's model", {
  skip_if_not_installed("broom")

  df <- data.frame(y = c(2, 3, 4), x = c(1, 5, 3))

  lm_spec <- parsnip::linear_reg()
  lm_spec <- parsnip::set_engine(lm_spec, "lm")

  wf <- workflow()
  wf <- add_formula(wf, y ~ x)
  wf <- add_model(wf, lm_spec)

  wf <- fit(wf, df)

  x <- glance(wf)

  expect_s3_class(x, "tbl_df")
  expect_identical(nrow(x), 1L)
})

# ------------------------------------------------------------------------------
# augment()

test_that("can't augment with the model of an unfit workflow", {
  # The workflow carries a model specification but has never been fit.
  x <- add_model(workflow(), parsnip::linear_reg())

  expect_snapshot_error(augment(x, mtcars))
})

test_that("can augment using a fitted workflow's model", {
  skip_if_not_installed("broom")

  df <- data.frame(y = c(2, 3, 4), x = c(1, 5, 3))

  lm_spec <- parsnip::linear_reg()
  lm_spec <- parsnip::set_engine(lm_spec, "lm")

  wf <- workflow()
  wf <- add_formula(wf, y ~ x)
  wf <- add_model(wf, lm_spec)

  wf <- fit(wf, df)

  x <- augment(wf, df)

  # One output row per row of `new_data`.
  expect_s3_class(x, "tbl_df")
  expect_identical(nrow(x), 3L)

  # At least 1 prediction specific column should be added, including `.pred`.
  # `expect_gt()` reports both column counts on failure, which a bare
  # `expect_true(ncol(x) > ncol(df))` does not.
  expect_gt(ncol(x), ncol(df))
  expect_true(".pred" %in% names(x))
})

test_that("augment returns `new_data`, not the pre-processed version of `new_data`", {
  skip_if_not_installed("broom")

  # A factor predictor means the pre-processed data would differ from `df`.
  df <- data.frame(y = c(2, 3, 4), x = factor(c("a", "b", "a")))

  lm_spec <- parsnip::linear_reg()
  lm_spec <- parsnip::set_engine(lm_spec, "lm")

  wf <- workflow()
  wf <- add_formula(wf, y ~ x)
  wf <- add_model(wf, lm_spec)

  wf <- fit(wf, df)

  # Returns `new_data` + prediction columns
  x <- augment(wf, df)

  # Every original column is present...
  expect_true(all(names(df) %in% names(x)))
  # ...and `x` is the untouched original factor, not something derived from
  # the pre-processed data. Checking names alone cannot detect that.
  expect_identical(x$x, df$x)
})

test_that("augment fails if it can't preprocess `new_data`", {
  skip_if_not_installed("broom")

  # `x` is a factor at fit time but an integer in `new_data`.
  df <- data.frame(y = c(2, 3, 4), x = factor(c("a", "b", "a")))
  new_data <- data.frame(y = c(2, 3, 4), x = 1:3)

  lm_spec <- parsnip::linear_reg()
  lm_spec <- parsnip::set_engine(lm_spec, "lm")

  wf <- workflow()
  wf <- add_formula(wf, y ~ x)
  wf <- add_model(wf, lm_spec)

  wf <- fit(wf, df)

  # vctrs type error. Pinning the condition class stops an unrelated error
  # (e.g. a typo in this test) from satisfying the expectation.
  expect_error(
    augment(wf, new_data),
    class = "vctrs_error_incompatible_type"
  )
})

test_that("augment works with matrix compositions (#148)", {
  skip_if_not_installed("broom")

  train <- data.frame(y = c(2, 3, 4), x = c(1, 5, 2), z = c(6, 8, 10))
  new_data <- data.frame(x = 1:3, z = 4:6)

  # Predictors are composed into a dense matrix before reaching the model.
  wf <- add_formula(
    workflow(),
    y ~ x + z,
    blueprint = hardhat::default_formula_blueprint(composition = "matrix")
  )
  wf <- add_model(wf, parsnip::set_engine(parsnip::linear_reg(), "lm"))
  wf <- fit(wf, train)

  out <- augment(wf, new_data = new_data)

  # The original `new_data` columns come back, followed by `.pred`.
  expect_s3_class(out, "tbl_df")
  expect_named(out, c("x", "z", ".pred"))
})

test_that("augment works with sparse matrix compositions (#148)", {
  skip_if_not_installed("broom")

  # Neither Matrix nor ranger is listed in Suggests, so in practice this test
  # runs only locally. They are needed solely for this broom test, and adding
  # them to Suggests just for broom support is not worth the extra weight.
  skip_if_not_installed("Matrix")
  # ranger is a parsnip engine that accepts sparse matrices.
  skip_if_not_installed("ranger")

  train <- data.frame(y = c(2, 3, 4), x = c(1, 5, 2), z = c(6, 8, 10))
  new_data <- data.frame(x = 1:3, z = 4:6)

  # Predictors are composed into a sparse `dgCMatrix`.
  wf <- add_formula(
    workflow(),
    y ~ x + z,
    blueprint = hardhat::default_formula_blueprint(composition = "dgCMatrix")
  )
  wf <- add_model(
    wf,
    parsnip::rand_forest(mode = "regression", engine = "ranger")
  )
  wf <- fit(wf, train)

  out <- augment(wf, new_data = new_data)

  # The original `new_data` columns come back, followed by `.pred`.
  expect_s3_class(out, "tbl_df")
  expect_named(out, c("x", "z", ".pred"))
})

Try the workflows package in your browser

Any scripts or data that you put into this service are public.

workflows documentation built on March 7, 2023, 7:50 p.m.