Nothing
test_that("tune recipe only", {
helper_objects <- helper_objects_tune()
set.seed(4400)
wflow <- workflow() %>%
add_recipe(helper_objects$rec_tune_1) %>%
add_model(helper_objects$lm_mod)
pset <- extract_parameter_set_dials(wflow) %>%
update(num_comp = dials::num_comp(c(1, 3)))
grid <- dials::grid_regular(pset, levels = 3)
folds <- rsample::vfold_cv(mtcars)
control <- control_grid(extract = identity)
res <- tune_grid(wflow, resamples = folds, grid = grid, control = control)
res_est <- collect_metrics(res)
res_workflow <- res$.extracts[[1]]$.extracts[[1]]
expect_equal(res, .Last.tune.result)
# Ensure tunable parameters in recipe are finalized
num_comp <- res_workflow$pre$actions$recipe$recipe$steps[[2]]$num_comp
expect_equal(res$id, folds$id)
expect_equal(nrow(res_est), nrow(grid) * 2)
expect_equal(sum(res_est$.metric == "rmse"), nrow(grid))
expect_equal(sum(res_est$.metric == "rsq"), nrow(grid))
expect_equal(res_est$n, rep(10, nrow(grid) * 2))
expect_false(identical(num_comp, expr(tune())))
expect_true(res_workflow$trained)
})
# ------------------------------------------------------------------------------
test_that("tune model only (with recipe)", {
helper_objects <- helper_objects_tune()
set.seed(4400)
wflow <- workflow() %>%
add_recipe(helper_objects$rec_no_tune_1) %>%
add_model(helper_objects$svm_mod)
pset <- extract_parameter_set_dials(wflow)
grid <- dials::grid_regular(pset, levels = 3)
folds <- rsample::vfold_cv(mtcars)
control <- control_grid(extract = identity)
res <- tune_grid(wflow, resamples = folds, grid = grid, control = control)
res_est <- collect_metrics(res)
res_workflow <- res$.extracts[[1]]$.extracts[[1]]
# Ensure tunable parameters in spec are finalized
cost_quo <- res_workflow$fit$fit$spec$args$cost
cost <- quo_get_expr(cost_quo)
expect_equal(res$id, folds$id)
expect_equal(nrow(res_est), nrow(grid) * 2)
expect_equal(sum(res_est$.metric == "rmse"), nrow(grid))
expect_equal(sum(res_est$.metric == "rsq"), nrow(grid))
expect_equal(res_est$n, rep(10, nrow(grid) * 2))
expect_false(identical(cost, expr(tune())))
expect_true(res_workflow$trained)
})
# ------------------------------------------------------------------------------
test_that("tune model only (with variables)", {
helper_objects <- helper_objects_tune()
set.seed(4400)
wflow <- workflow() %>%
add_variables(mpg, everything()) %>%
add_model(helper_objects$svm_mod)
pset <- extract_parameter_set_dials(wflow)
grid <- dials::grid_regular(pset, levels = 3)
folds <- rsample::vfold_cv(mtcars)
res <- tune_grid(wflow, resamples = folds, grid = grid)
expect_equal(res$id, folds$id)
res_est <- collect_metrics(res)
expect_equal(nrow(res_est), nrow(grid) * 2)
expect_equal(sum(res_est$.metric == "rmse"), nrow(grid))
expect_equal(sum(res_est$.metric == "rsq"), nrow(grid))
expect_equal(res_est$n, rep(10, nrow(grid) * 2))
})
# ------------------------------------------------------------------------------
test_that("tune model only (with recipe, multi-predict)", {
helper_objects <- helper_objects_tune()
set.seed(4400)
wflow <- workflow() %>%
add_recipe(helper_objects$rec_no_tune_1) %>%
add_model(helper_objects$svm_mod)
pset <- extract_parameter_set_dials(wflow)
grid <- dials::grid_regular(pset, levels = 3)
folds <- rsample::vfold_cv(mtcars)
res <- tune_grid(wflow, resamples = folds, grid = grid)
expect_equal(res$id, folds$id)
expect_equal(
colnames(res$.metrics[[1]]),
c("cost", ".metric", ".estimator", ".estimate", ".config")
)
res_est <- collect_metrics(res)
expect_equal(nrow(res_est), nrow(grid) * 2)
expect_equal(sum(res_est$.metric == "rmse"), nrow(grid))
expect_equal(sum(res_est$.metric == "rsq"), nrow(grid))
expect_equal(res_est$n, rep(10, nrow(grid) * 2))
})
# ------------------------------------------------------------------------------
test_that("tune model and recipe", {
helper_objects <- helper_objects_tune()
set.seed(4400)
wflow <- workflow() %>%
add_recipe(helper_objects$rec_tune_1) %>%
add_model(helper_objects$svm_mod)
pset <- extract_parameter_set_dials(wflow) %>%
update(num_comp = dials::num_comp(c(1, 3)))
grid <- dials::grid_regular(pset, levels = 3)
folds <- rsample::vfold_cv(mtcars)
control <- control_grid(extract = identity)
res <- tune_grid(wflow, resamples = folds, grid = grid, control = control)
res_est <- collect_metrics(res)
res_workflow <- res$.extracts[[1]]$.extracts[[1]]
# Ensure tunable parameters in spec are finalized
cost_quo <- res_workflow$fit$fit$spec$args$cost
cost <- quo_get_expr(cost_quo)
# Ensure tunable parameters in recipe are finalized
num_comp <- res_workflow$pre$actions$recipe$recipe$steps[[2]]$num_comp
expect_equal(res$id, folds$id)
expect_equal(
colnames(res$.metrics[[1]]),
c("cost", "num_comp", ".metric", ".estimator", ".estimate", ".config")
)
expect_equal(nrow(res_est), nrow(grid) * 2)
expect_equal(sum(res_est$.metric == "rmse"), nrow(grid))
expect_equal(sum(res_est$.metric == "rsq"), nrow(grid))
expect_equal(res_est$n, rep(10, nrow(grid) * 2))
expect_false(identical(cost, expr(tune())))
expect_false(identical(num_comp, expr(tune())))
expect_true(res_workflow$trained)
})
# ------------------------------------------------------------------------------
test_that("tune model and recipe (multi-predict)", {
helper_objects <- helper_objects_tune()
set.seed(4400)
wflow <- workflow() %>%
add_recipe(helper_objects$rec_tune_1) %>%
add_model(helper_objects$svm_mod)
pset <- extract_parameter_set_dials(wflow) %>%
update(num_comp = dials::num_comp(c(2, 3)))
grid <- dials::grid_regular(pset, levels = c(3, 2))
folds <- rsample::vfold_cv(mtcars)
res <- tune_grid(wflow, resamples = folds, grid = grid)
expect_equal(res$id, folds$id)
res_est <- collect_metrics(res)
expect_equal(nrow(res_est), nrow(grid) * 2)
expect_equal(sum(res_est$.metric == "rmse"), nrow(grid))
expect_equal(sum(res_est$.metric == "rsq"), nrow(grid))
expect_equal(res_est$n, rep(10, nrow(grid) * 2))
})
# ------------------------------------------------------------------------------
test_that('tune model and recipe (parallel_over = "everything")', {
helper_objects <- helper_objects_tune()
set.seed(4400)
wflow <- workflow() %>%
add_recipe(helper_objects$rec_tune_1) %>%
add_model(helper_objects$svm_mod)
pset <- extract_parameter_set_dials(wflow) %>%
update(num_comp = dials::num_comp(c(1, 3)))
grid <- dials::grid_regular(pset, levels = 3)
folds <- rsample::vfold_cv(mtcars)
control <- control_grid(extract = identity, parallel_over = "everything")
res <- tune_grid(wflow, resamples = folds, grid = grid, control = control)
res_est <- collect_metrics(res)
expect_equal(res$id, folds$id)
expect_equal(
colnames(res$.metrics[[1]]),
c("cost", "num_comp", ".metric", ".estimator", ".estimate", ".config")
)
expect_equal(nrow(res_est), nrow(grid) * 2)
expect_equal(sum(res_est$.metric == "rmse"), nrow(grid))
expect_equal(sum(res_est$.metric == "rsq"), nrow(grid))
expect_equal(res_est$n, rep(10, nrow(grid) * 2))
})
# ------------------------------------------------------------------------------
test_that("tune recipe only - failure in recipe is caught elegantly", {
helper_objects <- helper_objects_tune()
set.seed(7898)
data_folds <- rsample::vfold_cv(mtcars, v = 2)
rec <- recipe(mpg ~ ., data = mtcars) %>%
step_bs(disp, deg_free = tune())
model <- linear_reg(mode = "regression") %>%
set_engine("lm")
# NA values not allowed in recipe
cars_grid <- tibble(deg_free = c(3, NA_real_, 4))
# ask for predictions and extractions
control <- control_grid(
save_pred = TRUE,
extract = function(x) 1L
)
suppressMessages({
cars_res <- tune_grid(
model,
preprocessor = rec,
resamples = data_folds,
grid = cars_grid,
control = control
)
})
notes <- cars_res$.notes
note <- notes[[1]]$note
extract <- cars_res$.extracts[[1]]
predictions <- cars_res$.predictions[[1]]
used_deg_free <- sort(unique(predictions$deg_free))
expect_length(notes, 2L)
# failing rows are not in the output
expect_equal(nrow(extract), 2L)
expect_equal(extract$deg_free, c(3, 4))
expect_equal(used_deg_free, c(3, 4))
})
test_that("tune model only - failure in recipe is caught elegantly", {
helper_objects <- helper_objects_tune()
set.seed(7898)
data_folds <- rsample::vfold_cv(mtcars, v = 2)
# NA values not allowed in recipe
rec <- recipe(mpg ~ ., data = mtcars) %>%
step_bs(disp, deg_free = NA_real_)
cars_grid <- tibble(cost = c(0.01, 0.02))
expect_snapshot(
cars_res <- tune_grid(
helper_objects$svm_mod,
preprocessor = rec,
resamples = data_folds,
grid = cars_grid,
control = control_grid(extract = function(x) {1}, save_pred = TRUE)
)
)
notes <- cars_res$.notes
note <- notes[[1]]$note
extracts <- cars_res$.extracts
predictions <- cars_res$.predictions
expect_length(notes, 2L)
# recipe failed - no models run
expect_equal(extracts, list(NULL, NULL))
expect_equal(predictions, list(NULL, NULL))
})
test_that("tune model only - failure in formula is caught elegantly", {
helper_objects <- helper_objects_tune()
set.seed(7898)
data_folds <- rsample::vfold_cv(mtcars, v = 2)
cars_grid <- tibble(cost = 0.01)
# these terms don't exist!
expect_snapshot(
cars_res <- tune_grid(
helper_objects$svm_mod,
y ~ z,
resamples = data_folds,
grid = cars_grid,
control = control_grid(extract = function(x) {1}, save_pred = TRUE)
)
)
notes <- cars_res$.notes
note <- notes[[1]]$note
extracts <- cars_res$.extracts
predictions <- cars_res$.predictions
expect_length(notes, 2L)
# formula failed - no models run
expect_equal(extracts, list(NULL, NULL))
expect_equal(predictions, list(NULL, NULL))
})
test_that("tune model and recipe - failure in recipe is caught elegantly", {
helper_objects <- helper_objects_tune()
set.seed(7898)
data_folds <- rsample::vfold_cv(mtcars, v = 2)
rec <- recipe(mpg ~ ., data = mtcars) %>%
step_bs(disp, deg_free = tune())
# NA values not allowed in recipe
cars_grid <- tibble(deg_free = c(NA_real_, 10L), cost = 0.01)
suppressMessages({
cars_res <- tune_grid(
helper_objects$svm_mod,
preprocessor = rec,
resamples = data_folds,
grid = cars_grid,
control = control_grid(extract = function(x) {1}, save_pred = TRUE)
)
})
notes <- cars_res$.notes
note <- notes[[1]]$note
extract <- cars_res$.extracts[[1]]
prediction <- cars_res$.predictions[[1]]
expect_length(notes, 2L)
# recipe failed half of the time, only 1 model passed
expect_equal(nrow(extract), 1L)
expect_equal(extract$deg_free, 10L)
expect_equal(extract$cost, 0.01)
expect_equal(
unique(prediction[, c("deg_free", "cost")]),
tibble(deg_free = 10, cost = 0.01)
)
})
test_that("argument order gives errors for recipes", {
helper_objects <- helper_objects_tune()
expect_snapshot(error = TRUE, {
tune_grid(
helper_objects$rec_tune_1,
helper_objects$lm_mod,
rsample::vfold_cv(mtcars, v = 2)
)
})
})
test_that("argument order gives errors for formula", {
helper_objects <- helper_objects_tune()
expect_snapshot(error = TRUE, {
tune_grid(mpg ~ ., helper_objects$lm_mod, rsample::vfold_cv(mtcars, v = 2))
})
})
test_that("ellipses with tune_grid", {
helper_objects <- helper_objects_tune()
wflow <- workflow() %>%
add_recipe(helper_objects$rec_tune_1) %>%
add_model(helper_objects$lm_mod)
folds <- rsample::vfold_cv(mtcars)
expect_snapshot(
tune_grid(wflow, resamples = folds, grid = 3, something = "wrong")
)
})
test_that("determining the grid type", {
grid_1 <- expand.grid(a = 1:100, b = letters[1:2])
expect_true(tune:::is_regular_grid(grid_1))
expect_true(tune:::is_regular_grid(grid_1[-(1:10), ]))
expect_false(tune:::is_regular_grid(grid_1[-(1:100), ]))
set.seed(1932)
grid_2 <- data.frame(a = runif(length(letters)), b = letters)
expect_false(tune:::is_regular_grid(grid_2))
})
test_that("retain extra attributes", {
helper_objects <- helper_objects_tune()
set.seed(4400)
wflow <- workflow() %>%
add_recipe(helper_objects$rec_no_tune_1) %>%
add_model(helper_objects$svm_mod)
pset <- extract_parameter_set_dials(wflow)
grid <- dials::grid_regular(pset, levels = 3)
folds <- rsample::vfold_cv(mtcars)
res <- tune_grid(wflow, resamples = folds, grid = grid)
att <- attributes(res)
att_names <- names(att)
expect_true(any(att_names == "metrics"))
expect_true(any(att_names == "outcomes"))
expect_true(any(att_names == "parameters"))
expect_true(is.character(att$outcomes))
expect_true(att$outcomes == "mpg")
expect_true(inherits(att$parameters, "parameters"))
expect_true(inherits(att$metrics, "metric_set"))
set.seed(4400)
wflow <- workflow() %>%
add_formula(mpg ~ .) %>%
add_model(helper_objects$svm_mod)
pset <- extract_parameter_set_dials(wflow)
grid <- dials::grid_regular(pset, levels = 3)
folds <- rsample::vfold_cv(mtcars)
res <- tune_grid(wflow, resamples = folds, grid = grid)
att <- attributes(res)
att_names <- names(att)
expect_true(any(att_names == "metrics"))
expect_true(any(att_names == "outcomes"))
expect_true(any(att_names == "parameters"))
expect_true(is.character(att$outcomes))
expect_true(att$outcomes == "mpg")
expect_true(inherits(att$parameters, "parameters"))
expect_true(inherits(att$metrics, "metric_set"))
res2 <- tune_grid(
wflow,
resamples = folds,
grid = grid,
control = control_grid(save_workflow = TRUE)
)
expect_null(attr(res, "workflow"))
expect_true(inherits(attr(res2, "workflow"), "workflow"))
wflow2 <- workflow() %>%
add_recipe(recipes::recipe(mpg ~ ., mtcars[rep(1:32, 3000), ])) %>%
add_model(helper_objects$svm_mod)
pset2 <- extract_parameter_set_dials(wflow2)
grid2 <- dials::grid_regular(pset2, levels = 3)
expect_message(
tune_grid(
wflow2,
resamples = folds,
grid = grid2,
control = control_grid(save_workflow = TRUE)
),
"being saved contains a recipe, which is"
)
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.