tests/testthat/test-pre-action-variables.R

test_that("can add variables to a workflow", {
  wf <- workflow()
  wf <- add_variables(wf, y, c(x1, x2))

  expect_s3_class(wf$pre$actions$variables, "action_variables")
  expect_s3_class(wf$pre$actions$variables$variables, "workflow_variables")
  expect_identical(wf$pre$actions$variables$variables$outcomes, quo(y))
  expect_identical(wf$pre$actions$variables$variables$predictors, quo(c(x1, x2)))
})

test_that("can add variables to a workflow with `variables` specification", {
  wf <- workflow()
  wf <- add_variables(wf, variables = workflow_variables(y, c(x1, x2)))

  expect_s3_class(wf$pre$actions$variables, "action_variables")
  expect_s3_class(wf$pre$actions$variables$variables, "workflow_variables")
  expect_identical(wf$pre$actions$variables$variables$outcomes, quo(y))
  expect_identical(wf$pre$actions$variables$variables$predictors, quo(c(x1, x2)))
})

test_that("cannot add variables if a recipe already exists", {
  rec <- recipes::recipe(mpg ~ cyl, mtcars)

  wf <- workflow()
  wf <- add_recipe(wf, rec)

  expect_snapshot(error = TRUE, add_variables(wf, y, x))
})

test_that("cannot add variables if a formula already exist", {
  wf <- workflow()
  wf <- add_formula(wf, mpg ~ cyl)

  expect_snapshot(error = TRUE, add_variables(wf, y, x))
})

test_that("informative error if either `predictors` or `outcomes` aren't provided (#144)", {
  expect_snapshot(error = TRUE, add_variables(workflow(), outcomes = mpg))
  expect_snapshot(error = TRUE, add_variables(workflow(), predictors = mpg))
})

test_that("works with fit()", {
  mod <- parsnip::linear_reg()
  mod <- parsnip::set_engine(mod, "lm")

  workflow <- workflow()
  workflow <- add_variables(workflow, mpg, c(cyl, disp))
  workflow <- add_model(workflow, mod)

  result <- fit(workflow, mtcars)

  expect_identical(
    names(result$fit$fit$fit$coefficients),
    c("(Intercept)", "cyl", "disp")
  )

  expect_identical(
    names(result$pre$mold$outcomes),
    "mpg"
  )

  expect_identical(
    names(result$pre$mold$predictors),
    c("cyl", "disp")
  )
})

test_that("works with fit() when using `variables` specification", {
  mod <- parsnip::linear_reg()
  mod <- parsnip::set_engine(mod, "lm")

  variables <- workflow_variables(mpg, c(cyl, disp))

  workflow <- workflow()
  workflow <- add_variables(workflow, variables = variables)
  workflow <- add_model(workflow, mod)

  result <- fit(workflow, mtcars)

  expect_identical(
    names(result$fit$fit$fit$coefficients),
    c("(Intercept)", "cyl", "disp")
  )

  expect_identical(
    names(result$pre$mold$outcomes),
    "mpg"
  )

  expect_identical(
    names(result$pre$mold$predictors),
    c("cyl", "disp")
  )
})

test_that("can use a `NULL` outcome", {
  mod <- parsnip::linear_reg()
  mod <- parsnip::set_engine(mod, "lm")

  workflow <- workflow()
  workflow <- add_variables(workflow, NULL, c(cyl, disp))
  workflow <- add_model(workflow, mod)

  workflow <- .fit_pre(workflow, mtcars)

  expect_named(
    workflow$pre$mold$predictors,
    c("cyl", "disp")
  )

  expect_s3_class(
    workflow$pre$mold$outcomes,
    "tbl_df"
  )

  expect_identical(
    nrow(workflow$pre$mold$outcomes),
    32L
  )

  expect_identical(
    ncol(workflow$pre$mold$outcomes),
    0L
  )
})

test_that("can use `all_of(x)` when `x` is in the scope of `fit()`", {
  x <- c("disp", "cyl")

  mod <- parsnip::linear_reg()
  mod <- parsnip::set_engine(mod, "lm")

  workflow <- workflow()
  workflow <- add_variables(workflow, mpg, all_of(x))
  workflow <- add_model(workflow, mod)

  result <- fit(workflow, mtcars)

  expect_named(
    result$pre$mold$predictors,
    c("disp", "cyl")
  )
})

test_that("`outcomes` are removed from set of possible `predictors` (#72)", {
  mod <- parsnip::linear_reg()
  mod <- parsnip::set_engine(mod, "lm")

  workflow <- workflow()
  workflow <- add_model(workflow, mod)

  workflow1 <- add_variables(workflow, mpg, everything())

  result <- .fit_pre(workflow1, mtcars)

  mtcars2 <- mtcars
  mtcars2$mpg <- NULL

  expect_identical(
    colnames(result$pre$mold$predictors),
    colnames(mtcars2)
  )

  expect_identical(
    colnames(result$pre$mold$outcomes),
    "mpg"
  )

  workflow2 <- add_variables(workflow, mpg, mpg)

  # vctrs subscript error
  expect_error(.fit_pre(workflow2, mtcars))
})

test_that("selecting no `outcomes` doesn't break selection of `predictors`", {
  mod <- parsnip::linear_reg()
  mod <- parsnip::set_engine(mod, "lm")

  workflow <- workflow()
  workflow <- add_variables(workflow, any_of("not_here"), everything())
  workflow <- add_model(workflow, mod)

  result <- .fit_pre(workflow, mtcars)

  expect_identical(
    colnames(result$pre$mold$predictors),
    colnames(mtcars)
  )

  expect_identical(
    colnames(result$pre$mold$outcomes),
    character()
  )

  expect_identical(
    nrow(result$pre$mold$outcomes),
    32L
  )
})

test_that("cannot add two variables", {
  workflow <- workflow()
  workflow <- add_variables(workflow, mpg, cyl)

  expect_snapshot(error = TRUE, add_variables(workflow, mpg, cyl))
  expect_snapshot(error = TRUE, add_variables(workflow, variables = workflow_variables(mpg, cyl)))
})

test_that("can remove variables", {
  wf <- workflow()
  wf1 <- add_variables(wf, mpg, cyl)
  wf2 <- remove_variables(wf1)

  expect_equal(wf2$pre, wf$pre)
})

test_that("can remove variables after model fit", {
  lm_model <- parsnip::linear_reg()
  lm_model <- parsnip::set_engine(lm_model, "lm")

  wf <- workflow()
  wf <- add_model(wf, lm_model)

  wf1 <- add_variables(wf, mpg, cyl)
  wf1 <- fit(wf1, data = mtcars)

  wf2 <- remove_variables(wf1)

  expect_equal(wf$pre, wf2$pre)
  expect_equal(wf$fit, wf2$fit)
})

test_that("removing variables doesn't remove case weights", {
  wf <- workflow()
  wf <- add_variables(wf, mpg, cyl)
  wf <- add_case_weights(wf, disp)

  wf <- remove_variables(wf)

  expect_identical(names(wf$pre$actions), "case_weights")
})

test_that("can update a formula", {
  wf <- workflow()
  wf1 <- add_variables(wf, mpg, cyl)
  wf2 <- update_variables(wf1, cyl, mpg)
  wf3 <- update_variables(wf1, variables = workflow_variables(cyl, mpg))

  expect_identical(
    extract_preprocessor(wf2),
    new_workflow_variables(outcomes = quo(cyl), predictors = quo(mpg))
  )
  expect_identical(
    extract_preprocessor(wf3),
    new_workflow_variables(outcomes = quo(cyl), predictors = quo(mpg))
  )
})

test_that("can only use a 'xy_blueprint' blueprint", {
  blueprint <- hardhat::default_recipe_blueprint()

  workflow <- workflow()

  expect_snapshot(error = TRUE, add_variables(workflow, mpg, cyl, blueprint = blueprint))
})

test_that("can pass a blueprint through to hardhat::mold()", {
  lm_model <- parsnip::linear_reg()
  lm_model <- parsnip::set_engine(lm_model, "lm")

  blueprint <- hardhat::default_xy_blueprint(intercept = TRUE)

  workflow <- workflow()
  workflow <- add_model(workflow, lm_model)
  workflow <- add_variables(workflow, mpg, cyl, blueprint = blueprint)

  workflow <- fit(workflow, data = mtcars)

  expect_true("(Intercept)" %in% colnames(workflow$pre$mold$predictors))
  expect_identical(workflow$pre$actions$variables$blueprint, blueprint)
  expect_true(workflow$pre$mold$blueprint$intercept)
})

Try the workflows package in your browser

Any scripts or data that you put into this service are public.

workflows documentation built on May 29, 2024, 3:57 a.m.