# Nothing
library(parsnip)
library(rsample)
library(recipes)
# Shared fixtures ----

# `Chicago` is used by the recipes built in the parameter-extraction tests.
data(Chicago, package = "modeldata")

# Linear regression specification using the "lm" engine.
lr_spec <- linear_reg() %>% set_engine("lm")

set.seed(1)

# Workflow set crossing two preprocessors (a recipe and a formula) with the
# single linear model, then mapped through "fit_resamples" on a 3-fold
# cross-validation of `mtcars` with out-of-sample predictions saved.
# The tests below look up the workflow with id "reg_lm".
car_set_1 <-
  workflow_set(
    list(
      reg = recipe(mpg ~ ., data = mtcars) %>% step_log(disp),
      # `I()` is required here: inside a formula a bare `/` is the nesting
      # operator, not arithmetic division, so `1 / sqrt(disp)` would not
      # produce the reciprocal-root term this preprocessor is named for.
      nonlin = mpg ~ wt + I(1 / sqrt(disp))
    ),
    list(lm = lr_spec)
  ) %>%
  workflow_map(
    "fit_resamples",
    resamples = vfold_cv(mtcars, v = 3),
    control = tune::control_resamples(save_pred = TRUE)
  )
# ------------------------------------------------------------------------------
test_that("extracts", {
  wflow_id <- "reg_lm"

  # These failures are raised by workflows itself, so only the presence of an
  # error is asserted; the message text is deliberately not captured.
  expect_error(extract_fit_engine(car_set_1, id = wflow_id))
  expect_error(extract_fit_parsnip(car_set_1, id = wflow_id))
  expect_error(extract_mold(car_set_1, id = wflow_id))
  expect_error(extract_recipe(car_set_1, id = wflow_id))

  # Extractors expected to succeed, checked by the class of what they return.
  preproc <- extract_preprocessor(car_set_1, id = wflow_id)
  expect_s3_class(preproc, "recipe")

  model_spec <- extract_spec_parsnip(car_set_1, id = wflow_id)
  expect_s3_class(model_spec, "model_spec")

  wflow <- extract_workflow(car_set_1, id = wflow_id)
  expect_s3_class(wflow, "workflow")

  raw_recipe <- extract_recipe(car_set_1, id = wflow_id, estimated = FALSE)
  expect_s3_class(raw_recipe, "recipe")

  # The extracted objects must match what is stored in the first row of the
  # workflow set.
  expect_equal(
    extract_workflow(car_set_1, "reg_lm"),
    car_set_1$info[[1]]$workflow[[1]]
  )
  expect_equal(
    extract_workflow_set_result(car_set_1, "reg_lm"),
    car_set_1$result[[1]]
  )

  # Ids that are not in the set: the error output is pinned by snapshot, so
  # the expressions below are kept exactly as recorded.
  expect_snapshot(error = TRUE, {
    car_set_1 %>% extract_workflow_set_result("Gideon Nav")
  })
  expect_snapshot(error = TRUE, {
    car_set_1 %>% extract_workflow("Coronabeth Tridentarius")
  })
})
test_that("extract parameter set from workflow set with untunable workflow", {
  # Recipe on the head of `Chicago` that removes `date` and every column
  # whose name ends with "away".
  chicago_rec <- recipes::step_rm(
    recipes::recipe(ridership ~ ., data = head(Chicago)),
    date, ends_with("away")
  )

  # Linear model with no arguments marked for tuning.
  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  # Boosted tree with two `tune()` placeholders: the main argument `trees`
  # (given a custom id) and the engine argument `rules`.
  c5_spec <- parsnip::set_engine(
    parsnip::boost_tree(
      mode = "classification",
      trees = hardhat::tune("funky name \n")
    ),
    "C5.0",
    rules = hardhat::tune(),
    noGlobalPruning = TRUE
  )

  wf_set <- workflow_set(
    list(reg = chicago_rec),
    list(lm = lm_spec, bst = c5_spec)
  )

  # The linear-model workflow should give a well-formed parameter set with
  # zero rows.
  lm_params <- hardhat::extract_parameter_set_dials(wf_set, id = "reg_lm")
  check_parameter_set_tibble(lm_params)
  expect_equal(nrow(lm_params), 0)
})
test_that("extract parameter set from workflow set with tunable workflow", {
  # Recipe on the head of `Chicago` that removes `date` and every column
  # whose name ends with "away".
  chicago_rec <- recipes::step_rm(
    recipes::recipe(ridership ~ ., data = head(Chicago)),
    date, ends_with("away")
  )

  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  # Boosted tree with two `tune()` placeholders: the main argument `trees`
  # (given a custom id) and the engine argument `rules`.
  c5_spec <- parsnip::set_engine(
    parsnip::boost_tree(
      mode = "classification",
      trees = hardhat::tune("funky name \n")
    ),
    "C5.0",
    rules = hardhat::tune(),
    noGlobalPruning = TRUE
  )

  wf_set <- workflow_set(
    list(reg = chicago_rec),
    list(lm = lm_spec, bst = c5_spec)
  )

  # Extracting through the workflow set must agree with extracting directly
  # from the model specification.
  c5_info <- extract_parameter_set_dials(wf_set, id = "reg_bst")
  expect_equal(c5_info, extract_parameter_set_dials(c5_spec))

  # Column-by-column checks of the two-row parameter set.
  check_parameter_set_tibble(c5_info)
  expect_equal(nrow(c5_info), 2)
  expect_true(all(c5_info$source == "model_spec"))
  expect_true(all(c5_info$component == "boost_tree"))
  expect_equal(c5_info$component_id, c("main", "engine"))
  expect_equal(c5_info$name, c("trees", "rules"))
  expect_equal(c5_info$id, c("funky name \n", "rules"))
  expect_equal(c5_info$object[[1]], dials::trees(c(1, 100)))
  expect_equal(c5_info$object[[2]], NA)

  # Supply a parameter object for `rules`, store the updated set as the
  # `param_info` option of the workflow, and check that extraction returns it.
  rules_param <- dials::new_qual_param(
    "logical",
    values = c(TRUE, FALSE),
    label = c(rules = "Rules")
  )
  c5_new_info <- update(c5_info, rules = rules_param)
  wf_set_2 <- option_add(wf_set, id = "reg_bst", param_info = c5_new_info)

  check_parameter_set_tibble(c5_new_info)
  expect_s3_class(c5_new_info$object[[2]], "qual_param")
  expect_equal(
    c5_new_info,
    extract_parameter_set_dials(wf_set_2, "reg_bst")
  )
})
test_that("extract single parameter from workflow set with untunable workflow", {
  # Recipe on the head of `Chicago` that removes `date` and every column
  # whose name ends with "away".
  chicago_rec <- recipes::step_rm(
    recipes::recipe(ridership ~ ., data = head(Chicago)),
    date, ends_with("away")
  )

  # Linear model with no arguments marked for tuning.
  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  c5_spec <- parsnip::set_engine(
    parsnip::boost_tree(
      mode = "classification",
      trees = hardhat::tune("funky name \n")
    ),
    "C5.0",
    rules = hardhat::tune(),
    noGlobalPruning = TRUE
  )

  wf_set <- workflow_set(
    list(reg = chicago_rec),
    list(lm = lm_spec, bst = c5_spec)
  )

  # Requesting a parameter from the linear-model workflow is expected to fail.
  expect_error(
    hardhat::extract_parameter_dials(
      wf_set,
      id = "reg_lm",
      parameter = "non there"
    )
  )
})
test_that("extract single parameter from workflow set with tunable workflow", {
  # Recipe on the head of `Chicago` that removes `date` and every column
  # whose name ends with "away".
  chicago_rec <- recipes::step_rm(
    recipes::recipe(ridership ~ ., data = head(Chicago)),
    date, ends_with("away")
  )

  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")

  # Boosted tree with two `tune()` placeholders: the main argument `trees`
  # (given a custom id) and the engine argument `rules`.
  c5_spec <- parsnip::set_engine(
    parsnip::boost_tree(
      mode = "classification",
      trees = hardhat::tune("funky name \n")
    ),
    "C5.0",
    rules = hardhat::tune(),
    noGlobalPruning = TRUE
  )

  wf_set <- workflow_set(
    list(reg = chicago_rec),
    list(lm = lm_spec, bst = c5_spec)
  )

  # The parameter is looked up by its custom id, not by the argument name.
  trees_param <- hardhat::extract_parameter_dials(
    wf_set,
    id = "reg_bst",
    parameter = "funky name \n"
  )
  expect_equal(trees_param, dials::trees(c(1, 100)))

  # `rules` is expected to come back as NA.
  rules_param <- extract_parameter_dials(
    wf_set,
    id = "reg_bst",
    parameter = "rules"
  )
  expect_equal(rules_param, NA)
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.