# `get_param_info()` -----------------------------------------------------
test_that("`get_param_info()` works for a workflow without tags for tuning", {
  untuned_wflow <- workflow(mpg ~ ., parsnip::linear_reg())
  info <- get_param_info(untuned_wflow)

  # Even with nothing to tune, the result keeps its full column structure.
  expect_named(info, c("name", "id", "source", "has_submodel"))
  expect_identical(nrow(info), 0L)
})
test_that("`get_param_info()` works for a workflow with tags for tuning", {
  skip_if_not_installed("splines2")
  skip_if_not_installed("probably")

  # Tuning tags are present in the recipe, the model spec, and the tailor.
  tagged_wflow <- workflow(rec_tune, mod_tune_no_submodel, tlr_tune)
  info <- get_param_info(tagged_wflow)
  expected_ids <- c("min_n", "threshold", "disp_df", "lower_limit")

  expect_named(info, c("name", "id", "source", "has_submodel"))
  expect_identical(info$id, expected_ids)
  # Only `min_n` gets a submodel flag here; the remaining parameters get NA.
  expect_identical(info$has_submodel, c(FALSE, NA, NA, NA))
})
test_that("`get_param_info()` works when there are submodel parameters", {
  skip_if_not_installed("probably")

  # Only the model specification carries tuning tags here.
  plain_rec <- recipes::recipe(mpg ~ ., data = mtcars)
  info <- get_param_info(workflow(plain_rec, mod_tune, tlr_no_tune))

  expect_named(info, c("name", "id", "source", "has_submodel"))
  expect_identical(info$name, c("trees", "min_n"))
  expect_identical(info$has_submodel, c(TRUE, FALSE))
})
# `schedule_predict_stage_i()` -------------------------------------------
test_that("`schedule_predict_stage_i()` works with: no submodel, no post-processing", {
  skip_if_not_installed("probably")

  info <- get_param_info(workflow(mpg ~ ., mod_no_tune, tlr_no_tune))
  empty_grid <- tibble::tibble()

  res <- schedule_predict_stage_i(empty_grid, info)
  expect_named(res, c("post_stage"))
  expect_identical(nrow(res), 0L)
})
test_that("`schedule_predict_stage_i()` works with: no submodel, with post-processing", {
  skip_if_not_installed("probably")

  info <- get_param_info(workflow(mpg ~ ., mod_no_tune, tlr_tune))
  post_grid <- tibble::tibble(lower_limit = 1:2)

  res <- schedule_predict_stage_i(post_grid, info)
  expect_named(res, c("post_stage"))
  # Without submodels, all post-processing candidates share a single row.
  expect_identical(nrow(res), 1L)
  expect_identical(res$post_stage[[1]], post_grid)
})
test_that("`schedule_predict_stage_i()` works with: with submodel, no post-processing", {
  skip_if_not_installed("probably")

  info <- get_param_info(workflow(mpg ~ ., mod_tune_submodel, tlr_no_tune))
  sub_grid <- tibble::tibble(trees = 1:2)

  res <- schedule_predict_stage_i(sub_grid, info)
  expect_named(res, c("trees", "post_stage"))
  expect_identical(res$trees, sub_grid$trees)

  # No post-processing parameters, so every nested post stage has no columns.
  post_widths <- purrr::map_int(res$post_stage, ncol)
  expect_identical(post_widths, c(0L, 0L))
})
test_that("`schedule_predict_stage_i()` works with: with submodel, with post-processing", {
  skip_if_not_installed("probably")

  info <- get_param_info(workflow(mpg ~ ., mod_tune_submodel, tlr_tune))
  # Semi-regular grid: `trees = 1` has one post value, `trees = 2` has two.
  pred_grid <- purrr::list_rbind(list(
    tibble::tibble(trees = 1L, lower_limit = 1L),
    tibble::tibble(trees = c(2L, 2L), lower_limit = 1:2)
  ))

  res <- schedule_predict_stage_i(pred_grid, info)
  expect_named(res, c("trees", "post_stage"))
  expect_identical(res$trees, 1:2)

  expected_post <- list(
    tibble::tibble(lower_limit = 1L),
    tibble::tibble(lower_limit = 1:2)
  )
  for (k in seq_along(expected_post)) {
    expect_identical(res$post_stage[[k]], expected_post[[k]])
  }
})
# `schedule_model_stage_i()` ---------------------------------------------
test_that("`schedule_model_stage_i()` works with: no tuning at all", {
  skip_if_not_installed("probably")

  untuned <- workflow(mpg ~ ., mod_no_tune, tlr_no_tune)
  info <- get_param_info(untuned)
  res <- schedule_model_stage_i(tibble::tibble(), info, untuned)

  expect_named(res, c("predict_stage"))
  expect_identical(nrow(res), 0L)
})
test_that("`schedule_model_stage_i()` works with only non-submodel: with non-submodel, no submodel, no post", {
  skip_if_not_installed("probably")

  wf <- workflow(mpg ~ ., mod_tune_no_submodel, tlr_no_tune)
  info <- get_param_info(wf)
  model_grid <- tibble::tibble(min_n = 1:2)

  res <- schedule_model_stage_i(model_grid, info, wf)
  expect_named(res, c("min_n", "predict_stage"))
  expect_identical(nrow(res), 2L)
  expect_identical(res$min_n, 1:2)

  # Each model row nests a tibble that holds only the post stage.
  first_class <- function(x) class(x)[1]
  expect_identical(
    purrr::map_chr(res$predict_stage, first_class),
    c("tbl_df", "tbl_df")
  )
  expect_identical(
    purrr::map_chr(res$predict_stage, names),
    c("post_stage", "post_stage")
  )
})
test_that("`schedule_model_stage_i()` works with only submodel: no non-submodel, with submodel, no post", {
  skip_if_not_installed("probably")

  wf <- workflow(mpg ~ ., mod_tune_submodel, tlr_no_tune)
  info <- get_param_info(wf)
  model_grid <- tibble::tibble(trees = 1:2)

  res <- schedule_model_stage_i(model_grid, info, wf)
  expect_named(res, c("trees", "predict_stage"))
  # The submodel values collapse onto one fit at the largest `trees` value ...
  expect_identical(nrow(res), 1L)
  expect_identical(res$trees, 2L)

  first_class <- function(x) class(x)[1]
  expect_identical(purrr::map_chr(res$predict_stage, first_class), c("tbl_df"))
  expect_identical(
    purrr::map(res$predict_stage, names),
    list(c("trees", "post_stage"))
  )
  # ... while the predict stage still lists every submodel value.
  nested <- res$predict_stage[[1]]
  expect_identical(pull(nested, "trees"), 1:2)
})
test_that("`schedule_model_stage_i()` works with only post: no non-submodel, no submodel, with post", {
  skip_if_not_installed("probably")

  wf <- workflow(mpg ~ ., mod_no_tune, tlr_tune)
  info <- get_param_info(wf)
  post_grid <- tibble::tibble(lower_limit = 1:2)

  res <- schedule_model_stage_i(post_grid, info, wf)
  expect_named(res, c("predict_stage"))
  expect_identical(nrow(res), 1L)

  first_class <- function(x) class(x)[1]
  expect_identical(purrr::map_chr(res$predict_stage, first_class), c("tbl_df"))
  expect_identical(
    purrr::map(res$predict_stage, names),
    list(c("post_stage"))
  )
  # The whole post-processing grid ends up in the single nested post stage.
  nested <- res$predict_stage[[1]]
  expect_identical(pull(nested, "post_stage"), list(post_grid))
})
test_that("`schedule_model_stage_i()` works with both model types only: with non-submodel, with submodel, no post", {
  skip_if_not_installed("probably")

  wf <- workflow(mpg ~ ., mod_tune, tlr_no_tune)
  info <- get_param_info(wf)
  # Irregular grid: each `min_n` value is paired with a different set of trees.
  model_grid <- purrr::list_rbind(list(
    tibble::tibble(trees = 1:2, min_n = 1L),
    tibble::tibble(trees = 2L, min_n = 2L),
    tibble::tibble(trees = 2:3, min_n = 3L)
  ))

  res <- schedule_model_stage_i(model_grid, info, wf)
  expect_named(res, c("trees", "min_n", "predict_stage"))
  expect_identical(nrow(res), 3L)
  expect_identical(res$trees, c(2L, 2L, 3L))
  expect_identical(res$min_n, 1:3)

  nested_classes <- purrr::map_chr(res$predict_stage, function(x) class(x)[1])
  expect_identical(unique(nested_classes), "tbl_df")
  expect_identical(
    purrr::map(res$predict_stage, names),
    rep(list(c("trees", "post_stage")), 3)
  )

  # The submodel `trees` values scheduled for a single `min_n` value.
  trees_for <- function(value) {
    res %>%
      filter(min_n == value) %>%
      pull("predict_stage") %>%
      purrr::pluck(1) %>%
      pull("trees")
  }
  expect_identical(trees_for(1), 1:2)
  expect_identical(trees_for(2), 2L)
  expect_identical(trees_for(3), 2:3)
})
test_that("`schedule_model_stage_i()` works without submodel: with non-submodel, no submodel, with post", {
  skip_if_not_installed("probably")

  wf <- workflow(mpg ~ ., mod_tune_no_submodel, tlr_tune)
  info <- get_param_info(wf)
  # Semi-regular grid: the number of post values differs by `min_n`.
  model_grid <- purrr::list_rbind(list(
    tibble::tibble(min_n = 1L, lower_limit = 1:2),
    tibble::tibble(min_n = 2L, lower_limit = 1:3),
    tibble::tibble(min_n = 3L, lower_limit = 3L)
  ))

  res <- schedule_model_stage_i(model_grid, info, wf)
  expect_named(res, c("min_n", "predict_stage"))
  expect_identical(nrow(res), 3L)
  expect_identical(res$min_n, 1:3)

  nested_classes <- purrr::map_chr(res$predict_stage, function(x) class(x)[1])
  expect_identical(unique(nested_classes), "tbl_df")
  nested_names <- purrr::list_c(purrr::map(res$predict_stage, names))
  expect_identical(unique(nested_names), "post_stage")

  # The `lower_limit` values nested under a single `min_n` value.
  post_values_for <- function(value) {
    res %>%
      filter(min_n == value) %>%
      pull("predict_stage") %>%
      purrr::pluck(1) %>%
      pull("post_stage") %>%
      purrr::pluck(1) %>%
      pull("lower_limit")
  }
  expect_identical(post_values_for(1), 1:2)
  expect_identical(post_values_for(2), 1:3)
  expect_identical(post_values_for(3), 3L)
})
test_that("`schedule_model_stage_i()` works everything: with non-submodel, with submodel, with post", {
  skip_if_not_installed("probably")

  wf <- workflow(mpg ~ ., mod_tune, tlr_tune)
  info <- get_param_info(wf)
  # Semi-regular grid; the last piece adds one more row that belongs with
  # `min_n = 1`.
  model_grid <- purrr::list_rbind(list(
    tibble::tibble(
      min_n = 1L,
      trees = c(1L, 1L, 2L, 2L),
      lower_limit = c(1:2, 1:2)
    ),
    tibble::tibble(
      min_n = 2L,
      trees = c(1L, rep(2L, 3)),
      lower_limit = c(1L, 1:3)
    ),
    tibble::tibble(min_n = 1L, trees = 3L, lower_limit = 4L)
  ))

  res <- schedule_model_stage_i(model_grid, info, wf)
  expect_named(res, c("trees", "min_n", "predict_stage"))
  expect_identical(nrow(res), 2L)
  expect_identical(res$min_n, 1:2)

  nested_classes <- purrr::map_chr(res$predict_stage, function(x) class(x)[1])
  expect_identical(unique(nested_classes), "tbl_df")
  nested_names <- purrr::list_c(purrr::map(res$predict_stage, names))
  expect_identical(unique(nested_names), c("trees", "post_stage"))

  # Accessors for the schedule row, its predict stage, and the flattened post
  # values that belong to a single `min_n` value.
  row_for <- function(value) filter(res, min_n == value)
  predict_for <- function(value) {
    purrr::pluck(pull(row_for(value), "predict_stage"), 1)
  }
  post_values_for <- function(value) {
    predict_for(value) %>%
      pull("post_stage") %>%
      purrr::list_rbind() %>%
      pull("lower_limit")
  }

  expect_identical(
    select(row_for(1), -predict_stage),
    tibble::tibble(trees = 3L, min_n = 1L)
  )
  expect_identical(pull(predict_for(1), "trees"), 1:3)
  expect_identical(post_values_for(1), c(1:2, 1:2, 4L))

  expect_identical(
    select(row_for(2), -predict_stage),
    tibble::tibble(trees = 2L, min_n = 2L)
  )
  expect_identical(pull(predict_for(2), "trees"), 1:2)
  expect_identical(post_values_for(2), c(1L, 1:3))
})
# `schedule_stages()` ----------------------------------------------------
test_that("`schedule_stages()` works without preprocessing", {
  skip_if_not_installed("probably")

  untuned <- workflow(mpg ~ ., mod_no_tune, tlr_no_tune)
  res <- schedule_stages(tibble::tibble(), untuned)

  expect_named(res, c("model_stage"))
  expect_identical(nrow(res), 0L)
})
test_that("`schedule_stages()` works with preprocessing", {
  skip_if_not_installed("splines2")
  skip_if_not_installed("probably")

  wf <- workflow(rec_tune, mod_no_tune, tlr_no_tune)
  # A 2 x 2 crossing of the two recipe parameters.
  pre_grid <- tibble::tibble(
    threshold = rep(1:2, each = 2),
    disp_df = rep(1:2, times = 2)
  )

  res <- schedule_stages(pre_grid, wf)
  expect_named(res, c("threshold", "disp_df", "model_stage"))
  expect_identical(nrow(res), 4L)
  # The preprocessing columns reproduce the grid row for row.
  expect_identical(dplyr::select(res, -model_stage), pre_grid)
})
# `schedule_grid()` ------------------------------------------------------
# No tuning or postprocessing estimation
test_that("grid processing schedule - no parameters", {
  wf <- workflow(outcome ~ ., parsnip::logistic_reg())
  res <- schedule_grid(tibble::tibble(), wf)

  expect_named(res, "model_stage")
  expect_equal(nrow(res), 0)
  # NOTE(review): with the default `exact = FALSE`, expect_s3_class() only
  # requires the object to inherit from ANY of these classes. Consider
  # `exact = TRUE` once the intended class vector is confirmed.
  expect_s3_class(
    res,
    c("resample_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})
test_that("grid processing schedule - recipe and model", {
  skip_if_not_installed("splines2")

  wf <- workflow(rec_no_tune, parsnip::logistic_reg())
  res <- schedule_grid(tibble::tibble(), wf)

  # An empty grid still produces a well-formed, zero-row schedule.
  expect_named(res, c("model_stage"))
  expect_equal(nrow(res), 0)
  expect_s3_class(
    res,
    c("resample_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})
test_that("grid processing schedule - recipe, model, and post", {
  skip_if_not_installed("splines2")
  skip_if_not_installed("probably")

  wf <- workflow(rec_no_tune, parsnip::logistic_reg(), tlr_no_tune)
  res <- schedule_grid(tibble::tibble(), wf)

  # An empty grid still produces a well-formed, zero-row schedule.
  expect_named(res, c("model_stage"))
  expect_equal(nrow(res), 0)
  expect_s3_class(
    res,
    c("resample_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})
# Tuning, no postprocessing estimation
test_that("grid processing schedule - recipe only", {
  skip_if_not_installed("splines2")

  wf <- workflow(rec_tune, parsnip::logistic_reg())
  pre_grid <- wf %>%
    extract_parameter_set_dials() %>%
    dials::grid_regular(levels = 3) %>%
    arrange(threshold, disp_df)

  res <- schedule_grid(pre_grid, wf)
  expect_named(res, c("threshold", "disp_df", "model_stage"))
  expect_equal(nrow(res), nrow(pre_grid))

  # Unnesting every stage gives back exactly the original grid, so the nested
  # model/predict/post stages add no columns and no extra rows.
  fully_unnested <- res %>%
    tidyr::unnest(model_stage) %>%
    tidyr::unnest(predict_stage) %>%
    tidyr::unnest(post_stage)
  expect_equal(fully_unnested, pre_grid)

  expect_s3_class(
    res,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})
test_that("grid processing schedule - model only, no submodels", {
  wflow_rf_only <- workflow(outcome ~ ., mod_tune_no_submodel)
  grid_rf_only <-
    extract_parameter_set_dials(wflow_rf_only) %>%
    dials::grid_regular(levels = 3)
  sched_rf_only <- schedule_grid(grid_rf_only, wflow_rf_only)

  # No preprocessing parameters, so everything sits in one model stage.
  expect_named(sched_rf_only, c("model_stage"))
  expect_equal(nrow(sched_rf_only), 1L)

  # Use `seq_along()` rather than `1:length()`. For an empty schedule it gives
  # zero iterations instead of the sequence c(1, 0).
  for (i in seq_along(sched_rf_only$model_stage)) {
    model_i <- sched_rf_only$model_stage[[i]]
    expect_named(model_i, c("min_n", "predict_stage"))
    # Unnesting the predict and post stages recovers the original grid.
    expect_equal(
      model_i %>%
        tidyr::unnest(predict_stage) %>%
        tidyr::unnest(post_stage),
      grid_rf_only
    )
  }
  expect_s3_class(
    sched_rf_only,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})
test_that("grid processing schedule - model only, submodels, regular grid", {
  wflow_bst <- workflow(outcome ~ ., mod_tune)
  grid_bst <- extract_parameter_set_dials(wflow_bst) %>%
    dials::grid_regular(levels = 3)
  # The distinct values of each parameter, used to check each nesting level.
  min_n_only <- grid_bst %>% dplyr::distinct(min_n) %>% dplyr::arrange(min_n)
  trees_only <- grid_bst %>% dplyr::distinct(trees) %>% dplyr::arrange(trees)

  # regular grid
  sched_bst <- schedule_grid(grid_bst, wflow_bst)
  expect_named(sched_bst, c("model_stage"))
  expect_equal(nrow(sched_bst), 1L)

  # Use `seq_along()` rather than `1:length()`. For an empty schedule it gives
  # zero iterations instead of the sequence c(1, 0).
  for (i in seq_along(sched_bst$model_stage)) {
    model_i <- sched_bst$model_stage[[i]]
    expect_named(model_i, c("trees", "min_n", "predict_stage"))
    # One model row per distinct `min_n` value ...
    expect_equal(
      model_i %>% dplyr::select(-trees, -predict_stage),
      min_n_only
    )
    # ... and each row's predict stage lists every distinct `trees` value.
    for (j in seq_along(model_i$predict_stage)) {
      expect_named(model_i$predict_stage[[j]], c("trees", "post_stage"))
      expect_equal(
        model_i$predict_stage[[j]] %>% dplyr::select(trees),
        trees_only
      )
    }
    # Fully unnested, the (trees, min_n) pairs reproduce the original grid.
    expect_equal(
      model_i %>%
        dplyr::select(-trees) %>%
        tidyr::unnest(predict_stage) %>%
        tidyr::unnest(post_stage) %>%
        dplyr::select(trees, min_n),
      grid_bst
    )
  }
  expect_s3_class(
    sched_bst,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})
test_that("grid processing schedule - model only, submodels, SFD grid", {
  wf <- workflow(outcome ~ ., mod_tune)
  sfd_grid <- wf %>%
    extract_parameter_set_dials() %>%
    dials::grid_space_filling(size = 5, type = "uniform")
  res <- schedule_grid(sfd_grid, wf)

  # Order-independent form used to compare sets of (trees, min_n) candidates.
  canonical <- function(df) {
    df %>%
      dplyr::select(trees, min_n) %>%
      dplyr::arrange(trees, min_n)
  }
  expected <- canonical(sfd_grid)

  expect_named(res, c("model_stage"))
  expect_equal(nrow(res), 1L)
  expect_equal(length(res$model_stage), 1L)

  stage_1 <- res$model_stage[[1]]
  expect_named(stage_1, c("trees", "min_n", "predict_stage"))
  # Every candidate appears as a row of the model stage ...
  expect_equal(canonical(stage_1), expected)
  # ... and, once the nested stages are unnested.
  expect_equal(
    stage_1 %>%
      dplyr::select(-trees) %>%
      tidyr::unnest(predict_stage) %>%
      tidyr::unnest(post_stage) %>%
      canonical(),
    expected
  )

  expect_s3_class(
    res,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})
test_that("grid processing schedule - model only, submodels, irregular design", {
  wflow_bst <- workflow(outcome ~ ., mod_tune)
  # `min_n = 1` is paired with two `trees` values; every other `min_n` has one.
  grid_odd_bst <- tibble::tibble(
    min_n = c(1, 1, 2, 3, 4, 5),
    trees = rep(1:2, 3)
  )
  sched_odd_bst <- schedule_grid(grid_odd_bst, wflow_bst)
  expect_named(sched_odd_bst, c("model_stage"))
  expect_equal(nrow(sched_odd_bst), 1L)
  odd_n <- length(sched_odd_bst$model_stage)
  expect_equal(odd_n, 1L)

  model_stage <- sched_odd_bst$model_stage[[1]]
  expect_named(model_stage, c("trees", "min_n", "predict_stage"))
  # Each model row uses the largest `trees` value requested for its `min_n`.
  expect_equal(
    model_stage %>% dplyr::select(trees, min_n),
    tibble::tibble(trees = c(2, 1, 2, 1, 2), min_n = c(1, 2, 3, 4, 5))
  )
  # Use `seq_len(nrow())` rather than `1:nrow()`. For a zero-row model stage
  # it gives zero iterations instead of the sequence c(1, 0).
  for (i in seq_len(nrow(model_stage))) {
    prd <- model_stage$predict_stage[[i]]
    # Only `min_n = 1` has two submodel (`trees`) values to predict.
    expected_rows <- if (model_stage$min_n[i] == 1) 2L else 1L
    expect_equal(nrow(prd), expected_rows)
    # Every nested post stage is expected to be a single row.
    expect_true(
      all(purrr::map_int(prd$post_stage, nrow) == 1)
    )
  }
  expect_s3_class(
    sched_odd_bst,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})
test_that("grid processing schedule - model only, submodels, 1 point design", {
  wf <- workflow(outcome ~ ., mod_tune)
  set.seed(1)
  one_point <- wf %>%
    extract_parameter_set_dials() %>%
    dials::grid_random(size = 1)
  res <- schedule_grid(one_point, wf)

  expect_named(res, c("model_stage"))
  expect_equal(nrow(res), 1L)
  expect_equal(length(res$model_stage), 1L)

  stage_model <- res$model_stage[[1]]
  expect_named(stage_model, c("trees", "min_n", "predict_stage"))
  expect_equal(length(stage_model$predict_stage), 1L)

  stage_pred <- stage_model$predict_stage[[1]]
  expect_named(stage_pred, c("trees", "post_stage"))
  expect_equal(length(stage_pred$post_stage), 1L)
  # No post-processing parameters: one row with zero columns.
  expect_equal(dim(stage_pred$post_stage[[1]]), c(1L, 0L))

  expect_s3_class(
    res,
    c(
      "single_schedule",
      "grid_schedule",
      "schedule",
      "tbl_df",
      "tbl",
      "data.frame"
    )
  )
})
test_that("grid processing schedule - postprocessing only", {
  skip_if_not_installed("probably")
  wflow_thrsh <- workflow(outcome ~ ., parsnip::logistic_reg(), tlr_tune)
  grid_thrsh <- extract_parameter_set_dials(wflow_thrsh) %>%
    update(lower_limit = dials::lower_limit(c(0, 1))) %>%
    dials::grid_regular(levels = 3)
  sched_thrsh <- schedule_grid(grid_thrsh, wflow_thrsh)

  # Only post-processing is tuned. The schedule has one model stage, which
  # holds one predict stage, whose post stage is the whole grid.
  expect_named(sched_thrsh, c("model_stage"))
  expect_equal(nrow(sched_thrsh), 1L)

  model_stage <- sched_thrsh$model_stage[[1]]
  expect_named(model_stage, c("predict_stage"))
  expect_equal(nrow(model_stage), 1L)

  predict_stage <- model_stage$predict_stage[[1]]
  expect_named(predict_stage, c("post_stage"))
  # Check the predict stage's own row count. The original code repeated the
  # model-stage check here, so this level was never verified.
  expect_equal(nrow(predict_stage), 1L)
  expect_equal(predict_stage$post_stage[[1]], grid_thrsh)

  expect_s3_class(
    sched_thrsh,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})
test_that("grid processing schedule - recipe + postprocessing, regular grid", {
  skip_if_not_installed("splines2")
  skip_if_not_installed("probably")

  wf <- workflow(rec_tune, parsnip::logistic_reg(), tlr_tune)
  full_grid <- wf %>%
    extract_parameter_set_dials() %>%
    update(lower_limit = dials::lower_limit(c(0, 1))) %>%
    dials::grid_regular(levels = 3)
  pre_values <- distinct(full_grid, threshold, disp_df)
  post_values <- arrange(distinct(full_grid, lower_limit), lower_limit)

  res <- schedule_grid(full_grid, wf)
  expect_named(res, c("threshold", "disp_df", "model_stage"))
  expect_equal(tibble::as_tibble(select(res, -model_stage)), pre_values)

  # Each preprocessing combination gets a one-row model stage with no model
  # parameters ...
  for (stage in res$model_stage) {
    expect_named(stage, c("predict_stage"))
    expect_equal(nrow(stage), 1L)
  }
  # ... whose single predict stage holds every post-processing candidate.
  for (stage in res$model_stage) {
    pred <- stage$predict_stage[[1]]
    expect_named(pred, c("post_stage"))
    expect_identical(arrange(pred$post_stage[[1]], lower_limit), post_values)
  }

  expect_s3_class(
    res,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})
test_that("grid processing schedule - recipe + postprocessing, irregular grid", {
  skip_if_not_installed("splines2")
  skip_if_not_installed("probably")

  wf <- workflow(rec_tune, parsnip::logistic_reg(), tlr_tune)
  # Dropping two rows leaves some preprocessing combinations with fewer
  # post-processing candidates than the others.
  full_grid <- wf %>%
    extract_parameter_set_dials() %>%
    update(lower_limit = dials::lower_limit(c(0, 1))) %>%
    dials::grid_regular() %>%
    dplyr::slice(-c(1, 14))
  pre_values <- distinct(full_grid, threshold, disp_df)
  # The expected post-processing candidates for each preprocessing combination.
  post_by_pre <- full_grid %>%
    dplyr::group_nest(threshold, disp_df) %>%
    mutate(data = purrr::map(data, function(d) arrange(d, lower_limit)))

  res <- schedule_grid(full_grid, wf)
  expect_named(res, c("threshold", "disp_df", "model_stage"))
  expect_equal(tibble::as_tibble(select(res, -model_stage)), pre_values)

  for (stage in res$model_stage) {
    expect_named(stage, c("predict_stage"))
    expect_equal(nrow(stage), 1L)
  }
  for (i in seq_along(res$model_stage)) {
    pred <- res$model_stage[[i]]$predict_stage[[1]]
    expect_named(pred, c("post_stage"))
    # Look up the expected candidates by this row's preprocessing values.
    expected_post <- res %>%
      slice(i) %>%
      select(threshold, disp_df) %>%
      inner_join(post_by_pre, by = dplyr::join_by(threshold, disp_df)) %>%
      purrr::pluck("data", 1) %>%
      arrange(lower_limit)
    expect_identical(
      arrange(pred$post_stage[[1]], lower_limit),
      expected_post
    )
  }

  expect_s3_class(
    res,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})
test_that("grid processing schedule - recipe + model, no submodels, regular grid", {
  skip_if_not_installed("splines2")

  wf <- workflow(rec_tune, mod_tune_no_submodel)
  full_grid <- wf %>%
    extract_parameter_set_dials() %>%
    dials::grid_regular()
  pre_values <- distinct(full_grid, threshold, disp_df)
  min_n_values <- arrange(distinct(full_grid, min_n), min_n)

  res <- schedule_grid(full_grid, wf)
  expect_named(res, c("threshold", "disp_df", "model_stage"))
  expect_equal(tibble::as_tibble(select(res, -model_stage)), pre_values)

  # Every preprocessing combination fits every `min_n` value ...
  for (stage in res$model_stage) {
    expect_named(stage, c("min_n", "predict_stage"))
    expect_equal(arrange(select(stage, min_n), min_n), min_n_values)
  }
  # ... and the nested predict and post stages are single rows.
  for (stage in res$model_stage) {
    pred <- stage$predict_stage[[1]]
    expect_named(pred, c("post_stage"))
    expect_equal(nrow(pred), 1L)
    expect_equal(nrow(pred$post_stage[[1]]), 1L)
  }

  expect_s3_class(
    res,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})
test_that("grid processing schedule - recipe + model, submodels, irregular grid", {
  skip_if_not_installed("splines2")
  wflow_pre_model <- workflow(rec_tune, mod_tune)
  grid_pre_model <-
    extract_parameter_set_dials(wflow_pre_model) %>%
    dials::grid_regular() %>%
    # Dropping these rows makes the submodel parameter (trees) unbalanced for
    # some combinations of the other parameters.
    slice(-c(1, 2, 11))
  grid_pre <-
    grid_pre_model %>%
    distinct(threshold, disp_df)
  # For each preprocessing combination: one row per `min_n`, carrying the
  # largest `trees` value requested for it.
  grid_model <-
    grid_pre_model %>%
    dplyr::group_nest(threshold, disp_df) %>%
    mutate(
      data = purrr::map(data, ~ .x %>% dplyr::summarize(trees = max(trees), .by = c(min_n))),
      data = purrr::map(data, ~ .x %>% arrange(min_n))
    )
  sched_pre_model <- schedule_grid(grid_pre_model, wflow_pre_model)
  expect_named(sched_pre_model, c("threshold", "disp_df", "model_stage"))
  expect_equal(
    sched_pre_model %>% select(-model_stage) %>% tibble::as_tibble(),
    grid_pre
  )
  for (i in seq_along(sched_pre_model$model_stage)) {
    model_i <- sched_pre_model$model_stage[[i]]
    # The preprocessing values for schedule row `i`.
    pre_i <-
      sched_pre_model %>%
      select(threshold, disp_df) %>%
      slice(i)
    # Match the expected model stage by key, not by position. `group_nest()`
    # sorts its groups, and that order need not equal the schedule's row order.
    expected_model_i <-
      pre_i %>%
      inner_join(grid_model, by = c("threshold", "disp_df")) %>%
      purrr::pluck("data", 1)
    expect_named(model_i, c("trees", "min_n", "predict_stage"))
    expect_equal(
      model_i %>% select(min_n, trees) %>% arrange(min_n),
      expected_model_i
    )
    for (j in seq_along(model_i$predict_stage)) {
      predict_j <- model_i$predict_stage[[j]]
      # Work out which trees need predicting for the current set of other
      # parameters. First collect the settings that have already been
      # resolved:
      other_ij <-
        model_i %>%
        select(-predict_stage, -trees) %>%
        slice(j) %>%
        vctrs::vec_cbind(pre_i)
      # Then find the matching `trees` values in the grid:
      trees_ij <-
        grid_pre_model %>%
        inner_join(other_ij, by = c("min_n", "threshold", "disp_df")) %>%
        select(trees)
      expect_equal(
        predict_j %>% select(trees) %>% arrange(trees),
        trees_ij %>% arrange(trees)
      )
    }
  }
  expect_s3_class(
    sched_pre_model,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})
test_that("grid processing schedule - recipe + model + tailor, submodels, irregular grid", {
  skip_if_not_installed("splines2")
  skip_if_not_installed("probably")
  wflow_pre_model_post <- workflow(rec_tune, mod_tune, tlr_tune)
  grid_pre_model_post <-
    extract_parameter_set_dials(wflow_pre_model_post) %>%
    update(lower_limit = dials::lower_limit(c(0, 1))) %>%
    dials::grid_regular() %>%
    # Keeping every 7th row makes the submodel parameter (trees) unbalanced
    # across combinations of the other parameters.
    slice(seq(1, 240, by = 7))
  grid_pre <-
    grid_pre_model_post %>%
    distinct(threshold, disp_df)
  # (The original also built a per-preprocessor `grid_model` summary here, but
  # nothing read it, so it has been removed.)
  sched_pre_model_post <- schedule_grid(
    grid_pre_model_post,
    wflow_pre_model_post
  )
  expect_named(sched_pre_model_post, c("threshold", "disp_df", "model_stage"))
  expect_equal(
    sched_pre_model_post %>% select(-model_stage) %>% tibble::as_tibble(),
    grid_pre
  )
  for (i in seq_along(sched_pre_model_post$model_stage)) {
    model_i <- sched_pre_model_post$model_stage[[i]]
    # The preprocessing parameter values for schedule row `i`.
    other_i <-
      sched_pre_model_post[i, ] %>%
      dplyr::select(-model_stage)
    # The specific candidates expected for this set of preprocessing values.
    exp_i <-
      grid_pre_model_post %>%
      inner_join(other_i, by = c("threshold", "disp_df")) %>%
      arrange(trees, min_n, lower_limit) %>%
      select(trees, min_n, lower_limit)
    # The candidates the schedule will actually evaluate:
    subgrid_i <-
      model_i %>%
      select(-trees) %>%
      tidyr::unnest(predict_stage) %>%
      tidyr::unnest(post_stage) %>%
      arrange(trees, min_n, lower_limit) %>%
      select(trees, min_n, lower_limit)
    expect_equal(subgrid_i, exp_i)
    for (j in seq_along(model_i$predict_stage)) {
      model_ij <- model_i[j, ]
      expect_named(model_ij, c("trees", "min_n", "predict_stage"))
      predict_j <- model_ij$predict_stage[[1]]
      expect_named(predict_j, c("trees", "post_stage"))
      exp_post_grid <-
        # Condition on the current non-submodel and preprocessing values to
        # see which (trees, lower_limit) pairs should be evaluated:
        model_ij %>%
        dplyr::select(-trees) %>%
        vctrs::vec_cbind(other_i) %>%
        dplyr::inner_join(
          grid_pre_model_post,
          by = c("threshold", "disp_df", "min_n")
        ) %>%
        dplyr::select(trees, lower_limit) %>%
        dplyr::arrange(trees, lower_limit)
      # The pairs that are scheduled to be evaluated:
      subgrid_ij <-
        predict_j %>%
        tidyr::unnest(post_stage) %>%
        dplyr::arrange(trees, lower_limit)
      expect_equal(subgrid_ij, exp_post_grid)
    }
  }
  expect_s3_class(
    sched_pre_model_post,
    c("grid_schedule", "schedule", "tbl_df", "tbl", "data.frame")
  )
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.