test_that("recipe merges", {
  skip_if_not_installed("modeldata")
  skip_if_not_installed("splines2")
  data("Chicago", package = "modeldata")

  # Tuning placeholders live in step 4 (`neighbors`, id "imputation"),
  # step 5 (`threshold`) and step 8 (`deg_free`, `degree`).
  spline_rec <-
    recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
    recipes::step_date(date) %>%
    recipes::step_holiday(date) %>%
    recipes::step_rm(date, dplyr::ends_with("away")) %>%
    recipes::step_impute_knn(
      recipes::all_predictors(),
      neighbors = tune("imputation")
    ) %>%
    recipes::step_other(recipes::all_nominal(), threshold = tune()) %>%
    recipes::step_dummy(recipes::all_nominal()) %>%
    recipes::step_normalize(recipes::all_numeric_predictors()) %>%
    recipes::step_spline_b(
      recipes::all_predictors(),
      deg_free = tune(),
      degree = tune()
    )

  spline_grid <-
    tibble::tribble(
      ~imputation, ~threshold, ~deg_free, ~degree,
      3L, 0.088, 14L, 1L,
      6L, 0.058, 8L, 1L,
      8L, 0.051, 14L, 1L,
      9L, 0.007, 10L, 1L,
      1L, 0.032, 15L, 2L,
      8L, 0.018, 9L, 2L,
      1L, 0.036, 5L, 1L,
      10L, 0.1, 10L, 2L,
      9L, 0.094, 12L, 1L,
      4L, 0.025, 12L, 1L
    )

  expect_no_error(
    spline_updated <- merge(spline_rec, spline_grid)
  )
  check_merged_tibble(spline_updated)

  # Every grid column names a tuning parameter, so each row's recipe should
  # carry exactly that row's values in place of the placeholders.
  for (i in seq_len(nrow(spline_grid))) {
    expect_equal(
      spline_updated$x[[i]]$steps[[4]]$neighbors, spline_grid$imputation[[i]]
    )
    expect_equal(
      spline_updated$x[[i]]$steps[[5]]$threshold, spline_grid$threshold[[i]]
    )
    expect_equal(
      spline_updated$x[[i]]$steps[[8]]$deg_free, spline_grid$deg_free[[i]]
    )
    expect_equal(
      spline_updated$x[[i]]$steps[[8]]$degree, spline_grid$degree[[i]]
    )
  }
})
test_that("partially recipe merge", {
  skip_if_not_installed("modeldata")
  skip_if_not_installed("splines2")
  data("Chicago", package = "modeldata")

  # Tuning placeholders live in step 4 (`neighbors`, id "imputation"),
  # step 5 (`threshold`) and step 8 (`deg_free`, `degree`).
  spline_rec <-
    recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
    recipes::step_date(date) %>%
    recipes::step_holiday(date) %>%
    recipes::step_rm(date, dplyr::ends_with("away")) %>%
    recipes::step_impute_knn(
      recipes::all_predictors(),
      neighbors = tune("imputation")
    ) %>%
    recipes::step_other(recipes::all_nominal(), threshold = tune()) %>%
    recipes::step_dummy(recipes::all_nominal()) %>%
    recipes::step_normalize(recipes::all_numeric_predictors()) %>%
    recipes::step_spline_b(
      recipes::all_predictors(),
      deg_free = tune(),
      degree = tune()
    )

  spline_grid <-
    tibble::tribble(
      ~imputation, ~threshold, ~deg_free, ~degree,
      3L, 0.088, 14L, 1L,
      6L, 0.058, 8L, 1L,
      8L, 0.051, 14L, 1L,
      9L, 0.007, 10L, 1L,
      1L, 0.032, 15L, 2L,
      8L, 0.018, 9L, 2L,
      1L, 0.036, 5L, 1L,
      10L, 0.1, 10L, 2L,
      9L, 0.094, 12L, 1L,
      4L, 0.025, 12L, 1L
    )

  # Drop the `imputation` column so only three of the four placeholders
  # can be filled in.
  expect_no_error(
    spline_updated <- merge(spline_rec, spline_grid[, -1])
  )
  check_merged_tibble(spline_updated, complete = FALSE)

  for (i in seq_len(nrow(spline_grid))) {
    # No grid column for "imputation": the placeholder must survive.
    expect_equal(
      spline_updated$x[[i]]$steps[[4]]$neighbors, tune("imputation")
    )
    expect_equal(
      spline_updated$x[[i]]$steps[[5]]$threshold, spline_grid$threshold[[i]]
    )
    expect_equal(
      spline_updated$x[[i]]$steps[[8]]$deg_free, spline_grid$deg_free[[i]]
    )
    expect_equal(
      spline_updated$x[[i]]$steps[[8]]$degree, spline_grid$degree[[i]]
    )
  }
})
test_that("unmerged recipe merge", {
  skip_if_not_installed("modeldata")
  skip_if_not_installed("splines2")
  data("Chicago", package = "modeldata")

  # Tuning placeholders live in step 4 (`neighbors`, id "imputation"),
  # step 5 (`threshold`) and step 8 (`deg_free`, `degree`).
  spline_rec <-
    recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
    recipes::step_date(date) %>%
    recipes::step_holiday(date) %>%
    recipes::step_rm(date, dplyr::ends_with("away")) %>%
    recipes::step_impute_knn(
      recipes::all_predictors(),
      neighbors = tune("imputation")
    ) %>%
    recipes::step_other(recipes::all_nominal(), threshold = tune()) %>%
    recipes::step_dummy(recipes::all_nominal()) %>%
    recipes::step_normalize(recipes::all_numeric_predictors()) %>%
    recipes::step_spline_b(
      recipes::all_predictors(),
      deg_free = tune(),
      degree = tune()
    )

  # None of these column names matches a tuning parameter of the recipe.
  bst_grid <- tibble::tibble(
    "funky name \n" = 1:4,
    rules = rep(c(TRUE, FALSE), each = 2)
  )

  expect_no_error(
    spline_updated <- merge(spline_rec, bst_grid)
  )
  check_merged_tibble(spline_updated, complete = FALSE)

  # Nothing matched, so every placeholder must be left untouched.
  for (i in seq_len(nrow(bst_grid))) {
    expect_equal(
      spline_updated$x[[i]]$steps[[4]]$neighbors, tune("imputation")
    )
    expect_equal(
      spline_updated$x[[i]]$steps[[5]]$threshold, tune()
    )
    expect_equal(
      spline_updated$x[[i]]$steps[[8]]$deg_free, tune()
    )
    expect_equal(
      spline_updated$x[[i]]$steps[[8]]$degree, tune()
    )
  }
})
# ------------------------------------------------------------------------------
test_that("model spec merges", {
  # NOTE(review): this attaches parsnip for the remainder of the test run.
  # Every parsnip call below is namespace-qualified, so the attach may be
  # removable if no other test file depends on it -- confirm before dropping.
  library(parsnip)

  bst_model <-
    parsnip::boost_tree(
      mode = "classification",
      trees = tune("funky name \n")
    ) %>%
    parsnip::set_engine("C5.0", rules = tune(), noGlobalPruning = TRUE)
  # The first column deliberately uses a non-syntactic id to check that
  # `merge()` matches parameter ids by exact name.
  bst_grid <- tibble::tibble(
    "funky name \n" = 1:4,
    rules = rep(c(TRUE, FALSE), each = 2)
  )

  expect_no_error(
    bst_updated <- merge(bst_model, bst_grid)
  )
  check_merged_tibble(bst_updated, "model_spec")

  # Main args (`trees`) and engine args (`rules`) are both stored as
  # quosures, so compare against quosures wrapping the grid values.
  for (i in seq_len(nrow(bst_grid))) {
    expect_equal(
      bst_updated$x[[i]]$args$trees,
      rlang::as_quosure(bst_grid[["funky name \n"]][[i]], rlang::empty_env())
    )
    expect_equal(
      bst_updated$x[[i]]$eng_args$rules,
      rlang::as_quosure(bst_grid$rules[[i]], rlang::empty_env())
    )
  }

  # ensure that `grid` can handle list-columns
  bst_model_obj <-
    parsnip::boost_tree(mode = "classification") %>%
    parsnip::set_args(objective = tune())
  bst_grid_obj <- tibble::tibble(objective = list("hey", "there"))
  bst_updated_obj <- merge(bst_model_obj, bst_grid_obj)
  expect_equal(
    rlang::quo_get_expr(bst_updated_obj$x[[1]]$eng_args$objective),
    "hey"
  )
  expect_equal(
    rlang::quo_get_expr(bst_updated_obj$x[[2]]$eng_args$objective),
    "there"
  )
})
test_that("partially model spec merge", {
  bst_model <-
    parsnip::boost_tree(
      mode = "classification",
      trees = tune("funky name \n")
    ) %>%
    parsnip::set_engine("C5.0", rules = tune(), noGlobalPruning = TRUE)
  bst_grid <- tibble::tibble(
    "funky name \n" = 1:4,
    rules = rep(c(TRUE, FALSE), each = 2)
  )

  # Drop the `trees` column so only the engine arg `rules` can be merged.
  expect_no_error(
    bst_updated <- merge(bst_model, bst_grid[, -1])
  )
  check_merged_tibble(bst_updated, "model_spec", complete = FALSE)

  for (i in seq_len(nrow(bst_grid))) {
    # `trees` had no matching column, so its placeholder call survives.
    expect_equal(
      rlang::get_expr(bst_updated$x[[i]]$args$trees), tune("funky name \n")
    )
    expect_equal(
      bst_updated$x[[i]]$eng_args$rules,
      rlang::as_quosure(bst_grid$rules[[i]], rlang::empty_env())
    )
  }
})
test_that("unmerged model spec merge", {
  bst_model <-
    parsnip::boost_tree(
      mode = "classification",
      trees = tune("funky name \n")
    ) %>%
    parsnip::set_engine("C5.0", rules = tune(), noGlobalPruning = TRUE)
  bst_grid <- tibble::tibble(
    "funky name \n" = 1:4,
    rules = rep(c(TRUE, FALSE), each = 2)
  )

  # Rename the copy that is actually merged (not `bst_grid`) so that no
  # column matches a tuning parameter id. Renaming `bst_grid` instead would
  # leave `other_grid` fully mergeable and the scenario untested.
  other_grid <- bst_grid
  names(other_grid) <- letters[1:2]

  expect_no_error(
    bst_not_updated <- merge(bst_model, other_grid)
  )
  check_merged_tibble(bst_not_updated, "model_spec", complete = FALSE)

  # Nothing matched, so every row keeps the original placeholder calls.
  for (i in seq_len(nrow(other_grid))) {
    expect_equal(
      rlang::get_expr(bst_not_updated$x[[i]]$args$trees),
      tune("funky name \n")
    )
    expect_equal(
      rlang::get_expr(bst_not_updated$x[[i]]$eng_args$rules),
      tune()
    )
  }
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.