# Nothing
# Load the Chicago ridership data used by the parameter-extraction tests below.
# `envir = environment()` keeps `Chicago` local to this test file instead of
# data()'s default behaviour of writing it into the global environment.
data(Chicago, package = "modeldata", envir = environment())
# ------------------------------------------------------------------------------
# extract_preprocessor()
test_that("can extract a formula preprocessor", {
  # A workflow whose only preprocessor is a formula hands that formula back.
  wf <- add_formula(workflow(), mpg ~ cyl)
  expect_equal(extract_preprocessor(wf), mpg ~ cyl)
})
test_that("can extract a recipe preprocessor", {
  # The recipe object given to add_recipe() is what extract_preprocessor()
  # returns for an untrained workflow.
  rec <- recipes::recipe(mpg ~ cyl, mtcars)
  wf <- add_recipe(workflow(), rec)
  expect_equal(extract_preprocessor(wf), rec)
})
test_that("can extract a variables preprocessor", {
  # A workflow_variables() spec must round-trip exactly (identical(), not
  # just equal).
  vars <- workflow_variables(mpg, c(cyl, disp))
  wf <- add_variables(workflow(), variables = vars)
  expect_identical(extract_preprocessor(wf), vars)
})
# extract_preprocessor() on an empty workflow() is expected to error; the
# error output is recorded by expect_snapshot().
test_that("error if no preprocessor", {
expect_snapshot(error = TRUE, extract_preprocessor(workflow()))
})
# extract_preprocessor() on a non-workflow input (a bare number) is expected
# to error; the error output is recorded by expect_snapshot().
test_that("error if not a workflow", {
expect_snapshot(error = TRUE, extract_preprocessor(1))
})
# ------------------------------------------------------------------------------
# extract_spec_parsnip()
test_that("can extract a model spec", {
  # An untrained workflow still exposes the parsnip spec it was given.
  spec <- parsnip::linear_reg()
  wf <- add_model(workflow(), spec)
  expect_equal(extract_spec_parsnip(wf), spec)
})
# extract_spec_parsnip() on a workflow with no model is expected to error;
# the error output is recorded by expect_snapshot().
test_that("error if no spec", {
expect_snapshot(error = TRUE, extract_spec_parsnip(workflow()))
})
# extract_spec_parsnip() on a non-workflow input (a bare number) is expected
# to error; the error output is recorded by expect_snapshot().
test_that("error if not a workflow", {
expect_snapshot(error = TRUE, extract_spec_parsnip(1))
})
# ------------------------------------------------------------------------------
# extract_fit_parsnip()
test_that("can extract a parsnip model fit", {
  # After fit(), extract_fit_parsnip() returns the object stored in the
  # workflow's `$fit$fit` slot.
  spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")
  wf <- add_formula(add_model(workflow(), spec), mpg ~ cyl)
  fitted_wf <- fit(wf, mtcars)
  expect_equal(extract_fit_parsnip(fitted_wf), fitted_wf$fit$fit)
})
# extract_fit_parsnip() on an untrained, empty workflow() is expected to
# error; the error output is recorded by expect_snapshot().
test_that("error if no parsnip fit", {
expect_snapshot(error = TRUE, extract_fit_parsnip(workflow()))
})
# extract_fit_parsnip() on a non-workflow input (a bare number) is expected
# to error; the error output is recorded by expect_snapshot().
test_that("error if not a workflow", {
expect_snapshot(error = TRUE, extract_fit_parsnip(1))
})
# ------------------------------------------------------------------------------
# extract_fit_engine()
test_that("can extract an engine model fit", {
  # extract_fit_engine() reaches through the parsnip wrapper to the engine
  # object stored at `$fit$fit$fit` of a fitted workflow.
  model <- parsnip::linear_reg()
  model <- parsnip::set_engine(model, "lm")
  workflow <- workflow()
  workflow <- add_model(workflow, model)
  workflow <- add_formula(workflow, mpg ~ cyl)
  workflow <- fit(workflow, mtcars)

  engine_fit <- extract_fit_engine(workflow)
  # With the "lm" engine the underlying engine object is a stats::lm fit.
  expect_s3_class(engine_fit, "lm")
  expect_equal(engine_fit, workflow$fit$fit$fit)
})

# Error paths, mirroring the other extractor sections. expect_error() is used
# rather than expect_snapshot() so no new snapshot files are required.
test_that("extract_fit_engine() errors if the workflow has no fit", {
  expect_error(extract_fit_engine(workflow()))
})

test_that("extract_fit_engine() errors if not a workflow", {
  expect_error(extract_fit_engine(1))
})
# ------------------------------------------------------------------------------
# extract_mold()
test_that("can extract a mold", {
  # Fitting a workflow stores its mold in `$pre$mold`; extract_mold() returns
  # that list.
  spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")
  wf <- add_formula(add_model(workflow(), spec), mpg ~ cyl)
  wf <- fit(wf, mtcars)

  mold <- extract_mold(wf)
  expect_type(mold, "list")
  expect_equal(mold, wf$pre$mold)
})
# extract_mold() on an untrained, empty workflow() is expected to error; the
# error output is recorded by expect_snapshot().
test_that("error if no mold", {
expect_snapshot(error = TRUE, extract_mold(workflow()))
})
# extract_mold() on a non-workflow input (a bare number) is expected to error;
# the error output is recorded by expect_snapshot().
test_that("error if not a workflow", {
expect_snapshot(error = TRUE, extract_mold(1))
})
# ------------------------------------------------------------------------------
# extract_recipe()
test_that("can extract a prepped recipe", {
  # After fit(), extract_recipe() returns the recipe held by the mold's
  # blueprint. Passing a non-TRUE `estimated` value is expected to error.
  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")
  rec <- recipes::recipe(mpg ~ cyl, mtcars)
  workflow <- add_recipe(add_model(workflow(), lm_spec), rec)
  workflow <- fit(workflow, mtcars)

  prepped <- extract_recipe(workflow)
  expect_s3_class(prepped, "recipe")
  expect_equal(prepped, workflow$pre$mold$blueprint$recipe)

  # Snapshot calls kept verbatim: their code text is part of the snapshot.
  expect_snapshot(error = TRUE, extract_recipe(workflow, FALSE))
  expect_snapshot(error = TRUE, extract_recipe(workflow, estimated = "yes please"))
})
# extract_recipe() on a workflow without a recipe preprocessor is expected to
# error; the error output is recorded by expect_snapshot().
test_that("error if no recipe preprocessor", {
expect_snapshot(error = TRUE, extract_recipe(workflow()))
})
# A workflow that has a recipe but was never passed to fit() has no mold, so
# extract_recipe() is expected to error; the error output is recorded by
# expect_snapshot().
test_that("error if no mold", {
recipe <- recipes::recipe(mpg ~ cyl, mtcars)
workflow <- workflow()
workflow <- add_recipe(workflow, recipe)
expect_snapshot(error = TRUE, extract_recipe(workflow))
})
# extract_recipe() on a non-workflow input (a bare number) is expected to
# error; the error output is recorded by expect_snapshot().
test_that("error if not a workflow", {
expect_snapshot(error = TRUE, extract_recipe(1))
})
# ------------------------------------------------------------------------------
# extract_parameter_set_dials()
test_that("extract parameter set from workflow with tunable recipe", {
  # Recipe with four tune() placeholders: step_impute_knn (id "imputation"),
  # step_other (threshold) and step_bs (deg_free, degree).
  spline_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
    recipes::step_date(date) %>%
    recipes::step_holiday(date) %>%
    recipes::step_rm(date, ends_with("away")) %>%
    recipes::step_impute_knn(
      recipes::all_predictors(),
      neighbors = hardhat::tune("imputation")
    ) %>%
    recipes::step_other(recipes::all_nominal(), threshold = hardhat::tune()) %>%
    recipes::step_dummy(recipes::all_nominal()) %>%
    recipes::step_normalize(recipes::all_predictors()) %>%
    recipes::step_bs(
      recipes::all_predictors(),
      deg_free = hardhat::tune(),
      degree = hardhat::tune()
    )
  # The model itself has nothing to tune.
  lm_model <- parsnip::linear_reg() %>%
    parsnip::set_engine("lm")

  wf_tunable_recipe <- workflow(spline_rec, lm_model)
  wf_info <- extract_parameter_set_dials(wf_tunable_recipe)

  check_parameter_set_tibble(wf_info)
  # Pin the whole column: all() over zero rows would pass vacuously, so also
  # require exactly the four recipe parameters.
  expect_equal(wf_info$source, rep("recipe", 4))
})
test_that("extract parameter set from workflow with tunable model", {
  # Recipe with no tunable arguments.
  rm_rec <- recipes::step_rm(
    recipes::recipe(ridership ~ ., data = head(Chicago)),
    date,
    ends_with("away")
  )
  # Two tunable model arguments: a main argument with a custom id and an
  # engine-specific argument.
  bst_model <- parsnip::set_engine(
    parsnip::boost_tree(
      mode = "classification",
      trees = hardhat::tune("funky name \n")
    ),
    "C5.0",
    rules = hardhat::tune(),
    noGlobalPruning = TRUE
  )

  wf_info <- extract_parameter_set_dials(workflow(rm_rec, bst_model))

  check_parameter_set_tibble(wf_info)
  expect_equal(nrow(wf_info), 2)
  expect_true(all(wf_info[["source"]] == "model_spec"))
})
test_that("extract parameter set from workflow with tunable recipe and model", {
  # Untuned preprocessing first...
  base_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
    recipes::step_date(date) %>%
    recipes::step_holiday(date) %>%
    recipes::step_rm(date, ends_with("away"))
  # ...then the steps carrying the four recipe tune() placeholders.
  spline_rec <- base_rec %>%
    recipes::step_impute_knn(
      recipes::all_predictors(),
      neighbors = hardhat::tune("imputation")
    ) %>%
    recipes::step_other(recipes::all_nominal(), threshold = hardhat::tune()) %>%
    recipes::step_dummy(recipes::all_nominal()) %>%
    recipes::step_normalize(recipes::all_predictors()) %>%
    recipes::step_bs(
      recipes::all_predictors(),
      deg_free = hardhat::tune(),
      degree = hardhat::tune()
    )
  bst_model <- parsnip::boost_tree(
    mode = "classification",
    trees = hardhat::tune("funky name \n")
  ) %>%
    parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE)

  wf_info <- extract_parameter_set_dials(workflow(spline_rec, bst_model))

  check_parameter_set_tibble(wf_info)
  # Model parameters are listed before recipe parameters.
  expected_source <- c(rep("model_spec", 2), rep("recipe", 4))
  expect_equal(wf_info$source, expected_source)
})
# ------------------------------------------------------------------------------
# extract_parameter_dials()
test_that("extract single parameter from workflow with tunable recipe", {
  spline_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
    recipes::step_date(date) %>%
    recipes::step_holiday(date) %>%
    recipes::step_rm(date, ends_with("away")) %>%
    recipes::step_impute_knn(
      recipes::all_predictors(),
      neighbors = hardhat::tune("imputation")
    ) %>%
    recipes::step_other(recipes::all_nominal(), threshold = hardhat::tune()) %>%
    recipes::step_dummy(recipes::all_nominal()) %>%
    recipes::step_normalize(recipes::all_predictors()) %>%
    recipes::step_bs(
      recipes::all_predictors(),
      deg_free = hardhat::tune(),
      degree = hardhat::tune()
    )
  lm_model <- parsnip::set_engine(parsnip::linear_reg(), "lm")
  wf_tunable_recipe <- workflow(spline_rec, lm_model)

  # Parameter id -> the dials object extract_parameter_dials() should return.
  expected_params <- list(
    imputation = dials::neighbors(),
    threshold = dials::threshold(c(0, 1/10)),
    deg_free = dials::spline_degree(range = c(1, 15)),
    degree = dials::degree_int(c(1, 2))
  )
  for (param_id in names(expected_params)) {
    expect_equal(
      extract_parameter_dials(wf_tunable_recipe, param_id),
      expected_params[[param_id]]
    )
  }
})
test_that("extract single parameter from workflow with tunable model", {
  rm_rec <- recipes::step_rm(
    recipes::recipe(ridership ~ ., data = head(Chicago)),
    date,
    ends_with("away")
  )
  bst_model <- parsnip::set_engine(
    parsnip::boost_tree(
      mode = "classification",
      trees = hardhat::tune("funky name \n")
    ),
    "C5.0",
    rules = hardhat::tune(),
    noGlobalPruning = TRUE
  )
  wf_tunable_model <- workflow(rm_rec, bst_model)

  # The custom id (including its embedded newline) is used for lookup.
  trees_param <- hardhat::extract_parameter_dials(
    wf_tunable_model,
    parameter = "funky name \n"
  )
  expect_equal(trees_param, dials::trees(c(1, 100)))

  # The test expects NA back for the engine argument `rules`.
  rules_param <- extract_parameter_dials(wf_tunable_model, parameter = "rules")
  expect_equal(rules_param, NA)
})
test_that("extract single parameter from workflow with tunable recipe and model", {
  spline_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
    recipes::step_date(date) %>%
    recipes::step_holiday(date) %>%
    recipes::step_rm(date, ends_with("away")) %>%
    recipes::step_impute_knn(
      recipes::all_predictors(),
      neighbors = hardhat::tune("imputation")
    ) %>%
    recipes::step_other(recipes::all_nominal(), threshold = hardhat::tune()) %>%
    recipes::step_dummy(recipes::all_nominal()) %>%
    recipes::step_normalize(recipes::all_predictors()) %>%
    recipes::step_bs(
      recipes::all_predictors(),
      deg_free = hardhat::tune(),
      degree = hardhat::tune()
    )
  bst_model <- parsnip::set_engine(
    parsnip::boost_tree(
      mode = "classification",
      trees = hardhat::tune("funky name \n")
    ),
    "C5.0",
    rules = hardhat::tune(),
    noGlobalPruning = TRUE
  )
  wf_tunable <- workflow(spline_rec, bst_model)

  # Recipe-side parameters: id -> expected dials object.
  recipe_params <- list(
    imputation = dials::neighbors(),
    threshold = dials::threshold(c(0, 1/10)),
    deg_free = dials::spline_degree(range = c(1, 15)),
    degree = dials::degree_int(c(1, 2))
  )
  for (param_id in names(recipe_params)) {
    expect_equal(
      extract_parameter_dials(wf_tunable, param_id),
      recipe_params[[param_id]]
    )
  }

  # Model-side parameters.
  expect_equal(
    hardhat::extract_parameter_dials(wf_tunable, parameter = "funky name \n"),
    dials::trees(c(1, 100))
  )
  expect_equal(extract_parameter_dials(wf_tunable, parameter = "rules"), NA)
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.