# tests/testthat/test-extract.R

# Load the `Chicago` data set from modeldata; it is used as training data by
# the parameter-extraction tests further down. The name is given as a string,
# which `data()` accepts interchangeably with a bare name.
data("Chicago", package = "modeldata")

# ------------------------------------------------------------------------------
# extract_preprocessor()

test_that("can extract a formula preprocessor", {
  # The formula registered through add_formula() should be handed back.
  formula_wf <- add_formula(workflow(), mpg ~ cyl)

  expect_equal(extract_preprocessor(formula_wf), mpg ~ cyl)
})

test_that("can extract a recipe preprocessor", {
  # The (unprepped) recipe registered through add_recipe() should be
  # handed back.
  rec <- recipes::recipe(mpg ~ cyl, mtcars)
  recipe_wf <- add_recipe(workflow(), rec)

  expect_equal(extract_preprocessor(recipe_wf), rec)
})

test_that("can extract a variables preprocessor", {
  vars <- workflow_variables(mpg, c(cyl, disp))
  vars_wf <- add_variables(workflow(), variables = vars)

  # Compared with expect_identical(), i.e. without numeric tolerance.
  expect_identical(extract_preprocessor(vars_wf), vars)
})

test_that("error if no preprocessor", {
  # A freshly created workflow has nothing to extract; the error text is
  # recorded in the snapshot file, so the captured expression must not change.
  expect_snapshot(extract_preprocessor(workflow()), error = TRUE)
})

test_that("error if not a workflow", {
  # A plain number is not a workflow; the error text is snapshotted.
  expect_snapshot(extract_preprocessor(1), error = TRUE)
})

# ------------------------------------------------------------------------------
# extract_spec_parsnip()

test_that("can extract a model spec", {
  spec <- parsnip::linear_reg()
  spec_wf <- add_model(workflow(), spec)

  # The spec registered through add_model() should be handed back.
  expect_equal(extract_spec_parsnip(spec_wf), spec)
})

test_that("error if no spec", {
  # No model has been added; the error text is snapshotted.
  expect_snapshot(extract_spec_parsnip(workflow()), error = TRUE)
})

test_that("error if not a workflow", {
  # A plain number is not a workflow; the error text is snapshotted.
  expect_snapshot(extract_spec_parsnip(1), error = TRUE)
})

# ------------------------------------------------------------------------------
# extract_fit_parsnip()

test_that("can extract a parsnip model fit", {
  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")
  unfitted_wf <- add_formula(add_model(workflow(), lm_spec), mpg ~ cyl)
  fitted_wf <- fit(unfitted_wf, mtcars)

  # The extractor should return the object stored at `$fit$fit`.
  expect_equal(extract_fit_parsnip(fitted_wf), fitted_wf$fit$fit)
})

test_that("error if no parsnip fit", {
  # The workflow was never fit; the error text is snapshotted.
  expect_snapshot(extract_fit_parsnip(workflow()), error = TRUE)
})

test_that("error if not a workflow", {
  # A plain number is not a workflow; the error text is snapshotted.
  expect_snapshot(extract_fit_parsnip(1), error = TRUE)
})

# ------------------------------------------------------------------------------
# extract_fit_engine()

test_that("can extract a engine model fit", {
  model <- parsnip::linear_reg()
  model <- parsnip::set_engine(model, "lm")

  workflow <- workflow()
  workflow <- add_model(workflow, model)
  workflow <- add_formula(workflow, mpg ~ cyl)

  workflow <- fit(workflow, mtcars)

  # The engine-level fit sits one level below the parsnip fit, at
  # `$fit$fit$fit` of the fitted workflow.
  expect_equal(
    extract_fit_engine(workflow),
    workflow$fit$fit$fit
  )
})

# Error-path coverage matching the other extractors in this file. These use
# expect_error() instead of expect_snapshot() so no new snapshot entries are
# required.
test_that("extract_fit_engine() errors if no parsnip fit", {
  expect_error(extract_fit_engine(workflow()))
})

test_that("extract_fit_engine() errors if not a workflow", {
  expect_error(extract_fit_engine(1))
})

# ------------------------------------------------------------------------------
# extract_mold()

test_that("can extract a mold", {
  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")
  fitted_wf <- fit(
    add_formula(add_model(workflow(), lm_spec), mpg ~ cyl),
    mtcars
  )

  # The mold is a list, and it should be the one stored at `$pre$mold`.
  mold <- extract_mold(fitted_wf)
  expect_type(mold, "list")
  expect_equal(mold, fitted_wf$pre$mold)
})

test_that("error if no mold", {
  # An unfit, empty workflow has no mold; the error text is snapshotted.
  expect_snapshot(extract_mold(workflow()), error = TRUE)
})

test_that("error if not a workflow", {
  # A plain number is not a workflow; the error text is snapshotted.
  expect_snapshot(extract_mold(1), error = TRUE)
})

# ------------------------------------------------------------------------------
# extract_recipe()

test_that("can extract a prepped recipe", {
  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")
  unprepped_rec <- recipes::recipe(mpg ~ cyl, mtcars)

  # NOTE: the variable must stay named `workflow` because it appears inside
  # the snapshotted expressions below.
  workflow <- fit(
    add_recipe(add_model(workflow(), lm_spec), unprepped_rec),
    mtcars
  )

  # After fitting, the recipe is taken from the mold's blueprint.
  prepped_rec <- extract_recipe(workflow)
  expect_s3_class(prepped_rec, "recipe")
  expect_equal(prepped_rec, workflow$pre$mold$blueprint$recipe)

  # Invalid extra / `estimated` arguments; error texts are snapshotted.
  expect_snapshot(extract_recipe(workflow, FALSE), error = TRUE)
  expect_snapshot(
    extract_recipe(workflow, estimated = "yes please"),
    error = TRUE
  )
})

test_that("error if no recipe preprocessor", {
  # No recipe has been added; the error text is snapshotted.
  expect_snapshot(extract_recipe(workflow()), error = TRUE)
})

test_that("error if no mold", {
  # A recipe is attached but the workflow is never fit. The variable keeps
  # the name `workflow` because it is part of the snapshotted expression.
  workflow <- add_recipe(workflow(), recipes::recipe(mpg ~ cyl, mtcars))

  expect_snapshot(extract_recipe(workflow), error = TRUE)
})

test_that("error if not a workflow", {
  # A plain number is not a workflow; the error text is snapshotted.
  expect_snapshot(extract_recipe(1), error = TRUE)
})

# ------------------------------------------------------------------------------
# extract_parameter_set_dials()

test_that("extract parameter set from workflow with tunable recipe", {

  spline_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
    recipes::step_date(date) %>%
    recipes::step_holiday(date) %>%
    recipes::step_rm(date, ends_with("away")) %>%
    recipes::step_impute_knn(recipes::all_predictors(),
                             neighbors = hardhat::tune("imputation")) %>%
    recipes::step_other(recipes::all_nominal(), threshold = hardhat::tune()) %>%
    recipes::step_dummy(recipes::all_nominal()) %>%
    recipes::step_normalize(recipes::all_predictors()) %>%
    recipes::step_bs(recipes::all_predictors(),
                     deg_free = hardhat::tune(), degree = hardhat::tune())
  lm_model <- parsnip::linear_reg() %>%
    parsnip::set_engine("lm")
  wf_tunable_recipe <- workflow(spline_rec, lm_model)

  wf_info <- extract_parameter_set_dials(wf_tunable_recipe)
  check_parameter_set_tibble(wf_info)
  # Pin the row count as well as the source column: `all()` over a
  # zero-length comparison is TRUE, so the source check alone would pass
  # vacuously if no parameters were found. The recipe tags four parameters
  # (imputation, threshold, deg_free, degree); the lm model tags none.
  expect_equal(nrow(wf_info), 4)
  expect_true(all(wf_info$source == "recipe"))
})

test_that("extract parameter set from workflow with tunable model", {

  rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
    recipes::step_rm(date, ends_with("away"))
  c5_spec <- parsnip::boost_tree(
    mode = "classification",
    trees = hardhat::tune("funky name \n")
  ) %>%
    parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE)

  # Only the model carries tune() markers: `trees` and the engine arg `rules`.
  param_set <- extract_parameter_set_dials(workflow(rm_rec, c5_spec))
  check_parameter_set_tibble(param_set)
  expect_equal(nrow(param_set), 2)
  expect_true(all(param_set$source == "model_spec"))
})

test_that("extract parameter set from workflow with tunable recipe and model", {

  spline_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
    recipes::step_date(date) %>%
    recipes::step_holiday(date) %>%
    recipes::step_rm(date, ends_with("away")) %>%
    recipes::step_impute_knn(recipes::all_predictors(),
                             neighbors = hardhat::tune("imputation")) %>%
    recipes::step_other(recipes::all_nominal(), threshold = hardhat::tune()) %>%
    recipes::step_dummy(recipes::all_nominal()) %>%
    recipes::step_normalize(recipes::all_predictors()) %>%
    recipes::step_bs(recipes::all_predictors(),
                     deg_free = hardhat::tune(), degree = hardhat::tune())
  c5_spec <- parsnip::boost_tree(
    mode = "classification",
    trees = hardhat::tune("funky name \n")
  ) %>%
    parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE)

  param_set <- extract_parameter_set_dials(workflow(spline_rec, c5_spec))
  check_parameter_set_tibble(param_set)

  # Two model parameters come first, followed by the four recipe parameters.
  expected_source <- rep(c("model_spec", "recipe"), times = c(2, 4))
  expect_equal(param_set$source, expected_source)
})


# ------------------------------------------------------------------------------
# extract_parameter_dials()

test_that("extract single parameter from workflow with tunable recipe", {

  spline_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
    recipes::step_date(date) %>%
    recipes::step_holiday(date) %>%
    recipes::step_rm(date, ends_with("away")) %>%
    recipes::step_impute_knn(recipes::all_predictors(),
                             neighbors = hardhat::tune("imputation")) %>%
    recipes::step_other(recipes::all_nominal(), threshold = hardhat::tune()) %>%
    recipes::step_dummy(recipes::all_nominal()) %>%
    recipes::step_normalize(recipes::all_predictors()) %>%
    recipes::step_bs(recipes::all_predictors(),
                     deg_free = hardhat::tune(), degree = hardhat::tune())
  wf_tunable_recipe <- workflow(
    spline_rec,
    parsnip::set_engine(parsnip::linear_reg(), "lm")
  )

  # Expected dials object for each tune() id used in the recipe above.
  expected_params <- list(
    imputation = dials::neighbors(),
    threshold = dials::threshold(c(0, 1 / 10)),
    deg_free = dials::spline_degree(range = c(1, 15)),
    degree = dials::degree_int(c(1, 2))
  )

  for (param_id in names(expected_params)) {
    expect_equal(
      extract_parameter_dials(wf_tunable_recipe, param_id),
      expected_params[[param_id]],
      info = param_id
    )
  }
})

test_that("extract single parameter from workflow with tunable model", {

  rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
    recipes::step_rm(date, ends_with("away"))
  c5_spec <- parsnip::boost_tree(
    mode = "classification",
    trees = hardhat::tune("funky name \n")
  ) %>%
    parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE)
  wf_tunable_model <- workflow(rm_rec, c5_spec)

  # A tune() id containing spaces and a newline is looked up verbatim.
  trees_param <- hardhat::extract_parameter_dials(
    wf_tunable_model,
    parameter = "funky name \n"
  )
  expect_equal(trees_param, dials::trees(c(1, 100)))

  # The engine argument `rules` is expected to yield NA, not a dials object.
  rules_param <- extract_parameter_dials(wf_tunable_model, parameter = "rules")
  expect_equal(rules_param, NA)
})

test_that("extract single parameter from workflow with tunable recipe and model", {

  spline_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
    recipes::step_date(date) %>%
    recipes::step_holiday(date) %>%
    recipes::step_rm(date, ends_with("away")) %>%
    recipes::step_impute_knn(recipes::all_predictors(),
                             neighbors = hardhat::tune("imputation")) %>%
    recipes::step_other(recipes::all_nominal(), threshold = hardhat::tune()) %>%
    recipes::step_dummy(recipes::all_nominal()) %>%
    recipes::step_normalize(recipes::all_predictors()) %>%
    recipes::step_bs(recipes::all_predictors(),
                     deg_free = hardhat::tune(), degree = hardhat::tune())
  c5_spec <- parsnip::boost_tree(
    mode = "classification",
    trees = hardhat::tune("funky name \n")
  ) %>%
    parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE)
  wf_tunable <- workflow(spline_rec, c5_spec)

  # Expected result for every tune() id across the recipe and the model.
  # `rules` (a C5.0 engine argument) is expected to come back as NA.
  expected_params <- list(
    imputation = dials::neighbors(),
    threshold = dials::threshold(c(0, 1 / 10)),
    deg_free = dials::spline_degree(range = c(1, 15)),
    degree = dials::degree_int(c(1, 2)),
    "funky name \n" = dials::trees(c(1, 100)),
    rules = NA
  )

  for (param_id in names(expected_params)) {
    expect_equal(
      extract_parameter_dials(wf_tunable, parameter = param_id),
      expected_params[[param_id]],
      info = param_id
    )
  }
})

# Try the workflows package in your browser

# Any scripts or data that you put into this service are public.

# workflows documentation built on May 29, 2024, 3:57 a.m.