# tests/testthat/test-predict.R

# File-level guard: skip when the recipes package is not installed. Some of
# the tests below build their preprocessor with `recipes::recipe()` and
# `recipes::step_log()`.
skip_if_not_installed("recipes")

test_that("can predict from a workflow", {
  # Linear regression specification using the "lm" engine.
  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  # Formula preprocessor plus model, assembled in a single expression.
  wf <- add_model(add_formula(workflow(), mpg ~ cyl), lm_spec)

  # Fit on mtcars, then predict on the very same data.
  preds <- predict(fit(wf, mtcars), mtcars)

  # Predictions come back as a tibble with one row per row of mtcars.
  expect_s3_class(preds, "tbl_df")
  expect_equal(nrow(preds), 32)
})

test_that("workflow must have been `fit()` before prediction can be done", {
  # An empty `workflow()` is handed straight to `predict()` without ever being
  # fit. The call is expected to error, and that error is recorded as a
  # snapshot. The snapshot stores the expression text, so keep the call as is.
  expect_snapshot(error = TRUE, predict(workflow(), mtcars))
})

test_that("formula preprocessing is done to the `new_data`", {
  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  # Workflow A: the log transform lives inside the preprocessing formula, so
  # `predict()` has to re-apply it to the raw `new_data`.
  wf_inline_log <- add_model(add_formula(workflow(), mpg ~ log(cyl)), lm_spec)
  preds_inline_log <- predict(fit(wf_inline_log, mtcars), mtcars)

  # Workflow B: log `cyl` by hand up front, then use a plain formula on the
  # already-transformed data.
  logged_cars <- mtcars
  logged_cars[["cyl"]] <- log(logged_cars[["cyl"]])

  wf_plain <- add_model(add_formula(workflow(), mpg ~ cyl), lm_spec)
  preds_plain <- predict(fit(wf_plain, logged_cars), logged_cars)

  # Both routes must give the same predictions.
  expect_equal(preds_inline_log, preds_plain)
})

test_that("recipe preprocessing is done to the `new_data`", {
  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  # Workflow A: a recipe that logs `cyl`, so `predict()` has to apply the
  # recipe step to the raw `new_data`.
  log_recipe <- recipes::step_log(recipes::recipe(mpg ~ cyl, mtcars), cyl)
  wf_recipe <- add_model(add_recipe(workflow(), log_recipe), lm_spec)
  preds_recipe <- predict(fit(wf_recipe, mtcars), mtcars)

  # Workflow B: log `cyl` by hand up front, then use a plain formula on the
  # already-transformed data.
  logged_cars <- mtcars
  logged_cars[["cyl"]] <- log(logged_cars[["cyl"]])

  wf_plain <- add_model(add_formula(workflow(), mpg ~ cyl), lm_spec)
  preds_plain <- predict(fit(wf_plain, logged_cars), logged_cars)

  # Both routes must give the same predictions.
  expect_equal(preds_recipe, preds_plain)
})

test_that("`new_data` must have all of the original predictors", {
  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")
  log_recipe <- recipes::step_log(recipes::recipe(mpg ~ cyl, mtcars), cyl)

  # Fit a recipe-based workflow whose only predictor is `cyl`.
  fitted_wf <- fit(
    add_model(add_recipe(workflow(), log_recipe), lm_spec),
    mtcars
  )

  # Prediction data with that single predictor removed.
  no_cyl <- mtcars
  no_cyl[["cyl"]] <- NULL

  # The error is raised by hardhat rather than by this package, so it is only
  # checked for existence and deliberately not snapshotted.
  expect_error(predict(fitted_wf, no_cyl))
})

test_that("blueprint will get passed on to hardhat::forge()", {
  # Training data has two factor levels (`x`, `y`); the prediction data adds a
  # third level (`z`) that was never seen while fitting.
  train <- data.frame(
    y = c(1L, 5L, 3L, 4L),
    x = factor(c("x", "y", "x", "y"))
  )

  test <- data.frame(
    x = factor(c("x", "y", "z"))
  )

  spec <- parsnip::linear_reg()
  spec <- parsnip::set_engine(spec, "lm")

  # Two blueprints that differ only in `allow_novel_levels`. If the blueprint
  # is honoured at prediction time, the two fits must treat `z` differently.
  bp1 <- hardhat::default_formula_blueprint(
    intercept = TRUE,
    allow_novel_levels = FALSE
  )
  bp2 <- hardhat::default_formula_blueprint(
    intercept = TRUE,
    allow_novel_levels = TRUE
  )

  workflow <- workflow()
  workflow <- add_model(workflow, spec)

  workflow1 <- add_formula(workflow, y ~ x, blueprint = bp1)
  workflow2 <- add_formula(workflow, y ~ x, blueprint = bp2)

  mod1 <- fit(workflow1, train)
  mod2 <- fit(workflow2, train)

  # Warning from hardhat, so we don't snapshot it
  expect_warning(pred1 <- predict(mod1, test))
  expect_no_warning(pred2 <- predict(mod2, test))

  # Expected values for the known levels are the per-level means of `y` in
  # `train`: (1 + 3) / 2 = 2 for `x` and (5 + 4) / 2 = 4.5 for `y`.
  #
  # These are floating-point results of a least-squares fit, so they are
  # compared with `expect_equal()` (numeric tolerance) rather than
  # `expect_identical()`, which demands bit-for-bit equality and can fail
  # spuriously across platforms or linear-algebra backends.

  # Novel levels disallowed: the unseen `z` row is expected to predict `NA`.
  expect_equal(
    pred1[[".pred"]],
    c(2, 4.5, NA)
  )

  # Novel levels allowed: the unseen `z` row is expected to predict 2, the
  # same value as the `x` rows.
  expect_equal(
    pred2[[".pred"]],
    c(2, 4.5, 2)
  )
})

test_that("monitoring: no double intercept due to dot expansion in model formula #210", {
  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  # The model formula uses a dot, meaning "everything available after the
  # preprocessing formula".
  base_wf <- add_model(workflow(), lm_spec, formula = mpg ~ .)

  # Preprocessing formula whose blueprint adds an intercept column.
  intercept_bp <- hardhat::default_formula_blueprint(intercept = TRUE)
  wf_intercept <- add_formula(
    base_wf,
    mpg ~ hp + disp,
    blueprint = intercept_bp
  )
  fitted_intercept <- fit(wf_intercept, mtcars)

  # Historically the dot expansion picked up the blueprint's intercept column
  # as an ordinary predictor. `parsnip:::prepare_data()` then dropped that
  # column, which made lm's predict method error.
  # The column is now removed before fitting (lm handles the intercept
  # itself), so predicting must succeed. (tidymodels/parsnip#1033)
  expect_no_error(predict(fitted_intercept, mtcars))
})

test_that("predict(type) is respected with a postprocessor (#251)", {
  skip_if_not_installed("tailor")

  # NOTE: the names `d` and `wflow_fit` appear verbatim in the snapshotted
  # call at the bottom of this test, so they must not be renamed.

  # Example data: a two-class outcome derived from a numeric grid.
  grid <- seq(0, 7, by = .1)
  d <- data.frame(
    y = as.factor(ifelse(grid > 3.5, "yes", "no")),
    x = grid + (grid - 3)^2
  )

  # Workflow with a formula, a logistic regression and an empty tailor as
  # postprocessor, fit in one go.
  wflow_fit <- fit(
    workflow(y ~ ., parsnip::logistic_reg(), tailor::tailor()),
    d
  )

  first_rows <- d[1:5, ]
  class_preds <- predict(wflow_fit, first_rows, type = "class")
  prob_preds <- predict(wflow_fit, first_rows, type = "prob")
  default_preds <- predict(wflow_fit, first_rows)

  # Each `type` yields its own column set; omitting `type` matches "class".
  expect_named(class_preds, ".pred_class")
  expect_named(prob_preds, c(".pred_no", ".pred_yes"), ignore.order = TRUE)
  expect_equal(class_preds, default_preds)

  # An unknown `type` must error; the message is tracked via snapshot.
  expect_snapshot(error = TRUE, predict(wflow_fit, d[1:5, ], type = "boop"))
})
# tidymodels/workflows documentation built on June 10, 2025, 12:12 a.m.