
test_that("can add a model to a workflow", {
  mod <- parsnip::linear_reg()
  mod <- parsnip::set_engine(mod, "lm")

  workflow <- workflow()
  workflow <- add_model(workflow, mod)

  # The spec is stored as a model action in the workflow's fit stage.
  expect_s3_class(workflow$fit$actions$model, "action_model")
})

test_that("model is validated", {
  # A non-spec object (here a bare number) must be rejected with an error.
  expect_snapshot(error = TRUE, add_model(workflow(), 1))
})

test_that("model must contain a known mode (#160)", {
  # The spec is created here without setting a mode.
  mod <- parsnip::decision_tree()

  workflow <- workflow()

  expect_snapshot(error = TRUE, {
    add_model(workflow, mod)
  })
})

test_that("prompt on spec without a loaded implementation (#174)", {
  # Set a mode so this exercises the missing-implementation prompt rather
  # than the unknown-mode error covered by #160.
  mod <- parsnip::bag_tree() %>%
    parsnip::set_mode("classification")

  workflow <- workflow()

  expect_snapshot(error = TRUE, add_model(workflow, mod))
  expect_snapshot(error = TRUE, workflow(spec = mod))
})

test_that("cannot add two models", {
  mod <- parsnip::linear_reg()
  mod <- parsnip::set_engine(mod, "lm")

  workflow <- workflow()
  workflow <- add_model(workflow, mod)

  # A second `add_model()` on a workflow that already has a model must error.
  expect_snapshot(error = TRUE, add_model(workflow, mod))
})

test_that("can provide a model formula override", {
  # disp is in the recipe, but excluded from the model formula
  rec <- recipes::recipe(mpg ~ cyl + disp, mtcars)
  rec <- recipes::step_center(rec, cyl)

  mod <- parsnip::linear_reg()
  mod <- parsnip::set_engine(mod, "lm")

  workflow <- workflow()
  workflow <- add_recipe(workflow, rec)
  workflow <- add_model(workflow, mod, formula = mpg ~ cyl)

  result <- fit(workflow, mtcars)
  lm_result <- hardhat::extract_fit_engine(result)

  # Only the override formula's terms reach the `lm` engine; `disp` is dropped.
  expect_named(lm_result$coefficients, c("(Intercept)", "cyl"))
})

test_that("model formula override can contain `offset()` (#162)", {
  df <- vctrs::data_frame(
    y = c(1.5, 2.5, 3.5, 1, 3),
    x = c(2, 6, 7, 3, 6),
    o = c(1.1, 2, 3, 0.5, 2)
  )

  lm_model <- parsnip::linear_reg()
  lm_model <- parsnip::set_engine(lm_model, "lm")

  workflow <- workflow()
  workflow <- add_model(workflow, lm_model, formula = y ~ x + offset(o))
  # `o` must be passed through as a predictor so `offset(o)` can find it.
  workflow <- add_variables(workflow, y, c(x, o))

  result <- fit(workflow, data = df)
  lm_result <- hardhat::extract_fit_engine(result)

  # The offset is not an estimated coefficient.
  expect_named(lm_result$coefficients, c("(Intercept)", "x"))
  # The terms "offset" attribute indexes the model variables (y, x, offset(o)),
  # so the offset is variable 3.
  expect_identical(attr(lm_result$terms, "offset"), 3L)
})

test_that("remove a model", {
  lm_model <- parsnip::linear_reg()
  lm_model <- parsnip::set_engine(lm_model, "lm")

  workflow_no_model <- workflow()
  workflow_no_model <- add_formula(workflow_no_model, mpg ~ cyl)

  workflow_with_model <- add_model(workflow_no_model, lm_model)
  workflow_removed_model <- remove_model(workflow_with_model)

  # Removing the model restores the fit stage of the model-free workflow.
  expect_equal(workflow_no_model$fit, workflow_removed_model$fit)
})

test_that("remove a model after model fit", {
  lm_model <- parsnip::linear_reg()
  lm_model <- parsnip::set_engine(lm_model, "lm")

  workflow_no_model <- workflow()
  workflow_no_model <- add_formula(workflow_no_model, mpg ~ cyl)

  workflow_with_model <- add_model(workflow_no_model, lm_model)
  workflow_with_model <- fit(workflow_with_model, data = mtcars)

  workflow_removed_model <- remove_model(workflow_with_model)

  # The fitted results must be discarded too, leaving the same fit stage as a
  # workflow that never had a model or was never fitted.
  expect_equal(workflow_no_model$fit, workflow_removed_model$fit)
})

test_that("update a model", {
  lm_model <- parsnip::linear_reg()
  lm_model <- parsnip::set_engine(lm_model, "lm")
  glmn_model <- parsnip::set_engine(lm_model, "glmnet")

  workflow <- workflow()
  workflow <- add_formula(workflow, mpg ~ cyl)
  workflow <- add_model(workflow, lm_model)
  # Unlike a second `add_model()`, `update_model()` replaces the existing spec.
  workflow <- update_model(workflow, glmn_model)

  expect_equal(workflow$fit$actions$model$spec$engine, "glmnet")
})

test_that("update a model after model fit", {
  lm_model <- parsnip::linear_reg()
  lm_model <- parsnip::set_engine(lm_model, "lm")
  no_model <- parsnip::set_engine(lm_model, "lm", model = FALSE)

  workflow <- workflow()
  workflow <- add_model(workflow, no_model)
  workflow <- add_formula(workflow, mpg ~ cyl)

  workflow <- fit(workflow, data = mtcars)
  workflow <- update_model(workflow, lm_model)

  # Should no longer have `model = FALSE` engine arg
  engine_args <- workflow$fit$actions$model$spec$eng_args
  expect_false(any(names(engine_args) == "model"))

  # The fitted model should be removed
  expect_null(workflow$fit$fit)
})

# Try the workflows package in your browser

# Any scripts or data that you put into this service are public.

# workflows documentation built on May 29, 2024, 3:57 a.m.