# tests/testthat/test-fit-action-model.R

test_that("can add a model to a workflow", {
  # An "lm" linear-regression spec; adding it should register a model action.
  spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  wf <- add_model(workflow(), spec)

  model_action <- wf$fit$actions$model
  expect_s3_class(model_action, "action_model")
})

# Passing a non-model object (here the number 1) to `add_model()` should
# error. `expect_snapshot()` records the code it evaluates as well as the
# error text, so keep the snapshotted expression verbatim.
test_that("model is validated", {
  expect_snapshot(error = TRUE, add_model(workflow(), 1))
})

# `add_model()` should error when the spec's mode has not been set (#160).
test_that("model must contain a known mode (#160)", {
  # No `set_mode()` call here, so the tree spec presumably keeps parsnip's
  # default "unknown" mode, which is the condition under test.
  mod <- parsnip::decision_tree()

  workflow <- workflow()

  # `expect_snapshot()` records both the code and the error text; keep the
  # snapshotted expression verbatim or the stored snapshot will not match.
  expect_snapshot(error = TRUE, {
    add_model(workflow, mod)
  })
})

# Using a spec whose implementation is not loaded should error, both through
# `add_model()` and through `workflow(spec = )` (#174).
test_that("prompt on spec without a loaded implementation (#174)", {
  # NOTE(review): assumes `bag_tree()` engines are provided by an extension
  # package (e.g. baguette) that is not loaded in this session; confirm
  # against the stored snapshot.
  mod <- parsnip::bag_tree() %>%
    parsnip::set_mode("regression")

  # This local `workflow` shadows the `workflow()` function by name only: R
  # skips non-function bindings when resolving a call, so
  # `workflow(spec = mod)` below still calls the function.
  workflow <- workflow()

  # Both snapshotted expressions are recorded verbatim by `expect_snapshot()`.
  expect_snapshot(error = TRUE, add_model(workflow, mod))
  expect_snapshot(error = TRUE, workflow(spec = mod))
})

test_that("cannot add two models", {
  # Start from a workflow that already carries a model action ...
  mod <- parsnip::set_engine(parsnip::linear_reg(), "lm")
  workflow <- add_model(workflow(), mod)

  # ... then try to add the same spec again. The snapshot records this exact
  # expression, so its code and variable names must stay as written.
  expect_snapshot(error = TRUE, add_model(workflow, mod))
})

test_that("can provide a model formula override", {
  # disp is in the recipe, but excluded from the model formula, so only `cyl`
  # should reach the fitted `lm` model.
  rec <- recipes::recipe(mpg ~ cyl + disp, mtcars)
  rec <- recipes::step_center(rec, cyl)

  mod <- parsnip::linear_reg()
  mod <- parsnip::set_engine(mod, "lm")

  workflow <- workflow()
  workflow <- add_recipe(workflow, rec)
  workflow <- add_model(workflow, mod, formula = mpg ~ cyl)

  result <- fit(workflow, mtcars)

  # Use the public extractor instead of reaching through `$fit$fit$fit`, and
  # pass the actual value first so testthat's failure messages label the
  # observed and expected values correctly.
  lm_result <- hardhat::extract_fit_engine(result)
  expect_named(lm_result$coefficients, c("(Intercept)", "cyl"))
})

test_that("model formula override can contain `offset()` (#162)", {
  # Small data set: response `y`, predictor `x`, offset column `o`.
  df <- vctrs::data_frame(
    y = c(1.5, 2.5, 3.5, 1, 3),
    x = c(2, 6, 7, 3, 6),
    o = c(1.1, 2, 3, .5, 2)
  )

  spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  # The override formula refers to `o` through `offset()`; `o` is passed as a
  # predictor via add_variables() so the formula can see it.
  wf <- add_variables(
    add_model(workflow(), spec, formula = y ~ x + offset(o)),
    y,
    c(x, o)
  )

  fitted_lm <- hardhat::extract_fit_engine(fit(wf, data = df))

  # The offset gets no coefficient of its own ...
  expect_named(fitted_lm$coefficients, c("(Intercept)", "x"))
  # ... and is flagged in the terms as the third variable.
  expect_identical(attr(fitted_lm$terms, "offset"), 3L)
})

test_that("remove a model", {
  spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  # Baseline: a formula-only workflow with no model attached.
  base_wf <- add_formula(workflow(), mpg ~ cyl)

  # Adding and then removing the model should restore the baseline fit stage.
  round_trip <- remove_model(add_model(base_wf, spec))

  expect_equal(base_wf$fit, round_trip$fit)
})

test_that("remove a model after model fit", {
  spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  # Baseline: a formula-only workflow with no model attached.
  base_wf <- add_formula(workflow(), mpg ~ cyl)

  # Attach the model and train the workflow ...
  trained_wf <- fit(add_model(base_wf, spec), data = mtcars)

  # ... then removing the model should leave a fit stage identical to the
  # untouched baseline's.
  stripped_wf <- remove_model(trained_wf)

  expect_equal(base_wf$fit, stripped_wf$fit)
})

test_that("update a model", {
  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")
  # Same model type, different engine.
  glmnet_spec <- parsnip::set_engine(lm_spec, "glmnet")

  wf <- add_formula(workflow(), mpg ~ cyl)
  wf <- add_model(wf, lm_spec)
  # update_model() should swap the "lm" spec for the "glmnet" one.
  wf <- update_model(wf, glmnet_spec)

  engine <- wf$fit$actions$model$spec$engine
  expect_equal(engine, "glmnet")
})


test_that("update a model after model fit", {
  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")
  # Same engine, plus an extra `model = FALSE` engine argument.
  lm_spec_no_model <- parsnip::set_engine(lm_spec, "lm", model = FALSE)

  wf <- add_formula(add_model(workflow(), lm_spec_no_model), mpg ~ cyl)
  wf <- fit(wf, data = mtcars)

  # Replace the spec on an already-trained workflow.
  wf <- update_model(wf, lm_spec)

  # Should no longer have `model = FALSE` engine arg
  engine_arg_names <- names(wf$fit$actions$model$spec$eng_args)
  expect_false("model" %in% engine_arg_names)

  # The fitted model should be removed
  expect_null(wf$fit$fit)
})

# Try the workflows package in your browser

# Any scripts or data that you put into this service are public.

# workflows documentation built on May 29, 2024, 3:57 a.m.