Nothing
test_that("formula interface", {
skip_on_cran()
expect_snapshot({
set.seed(1)
res <- f_wflow %>%
tune_sim_anneal(cell_folds,
iter = 2,
control = control_sim_anneal(verbose = TRUE)
)
})
expect_equal(class(res), c("iteration_results", "tune_results", "tbl_df", "tbl", "data.frame"))
expect_true(nrow(collect_metrics(res)) == 9)
expect_equal(res, .Last.tune.result)
expect_null(.get_tune_eval_times(res))
expect_null(.get_tune_eval_time_target(res))
})
# ------------------------------------------------------------------------------
test_that("recipe interface", {
skip_on_cran()
skip_on_os("windows")
skip_on_os("linux")
expect_silent({
set.seed(1)
res <- rec_wflow %>%
tune_sim_anneal(cell_folds,
iter = 2,
control = control_sim_anneal(verbose = FALSE, verbose_iter = FALSE)
)
})
expect_equal(class(res), c("iteration_results", "tune_results", "tbl_df", "tbl", "data.frame"))
expect_true(nrow(collect_metrics(res)) == 9)
expect_equal(res, .Last.tune.result)
})
# ------------------------------------------------------------------------------
test_that("variable interface", {
skip_on_cran()
expect_snapshot({
set.seed(1)
res <- var_wflow %>%
tune_sim_anneal(cell_folds,
iter = 2,
control = control_sim_anneal(verbose = TRUE, verbose_iter = TRUE)
)
})
expect_equal(class(res), c("iteration_results", "tune_results", "tbl_df", "tbl", "data.frame"))
expect_true(nrow(collect_metrics(res)) == 9)
expect_equal(res, .Last.tune.result)
# Check to see if iterations are picked up when an iterative object is used
# as the initial object
expect_snapshot({
set.seed(1)
new_res <- var_wflow %>%
tune_sim_anneal(cell_folds,
iter = 2, initial = res,
control = control_sim_anneal(verbose = FALSE)
)
})
expect_true(nrow(collect_metrics(new_res)) == 15)
expect_true(max(new_res$.iter) == 4)
expect_true(sum(grepl("^initial", collect_metrics(new_res)$.config)) == 9)
expect_equal(new_res, .Last.tune.result)
# but not for non-iterative objects
set.seed(1)
grid_res <- var_wflow %>%
tune_grid(cell_folds, grid = 2)
expect_snapshot({
set.seed(1)
new_new_res <- var_wflow %>%
tune_sim_anneal(cell_folds,
iter = 2, initial = grid_res,
control = control_sim_anneal(verbose = FALSE)
)
})
expect_true(nrow(collect_metrics(new_new_res)) == 12)
expect_true(max(new_new_res$.iter) == 2)
expect_true(sum(grepl("^initial", collect_metrics(new_new_res)$.config)) == 6)
expect_equal(new_new_res, .Last.tune.result)
})
test_that("unfinalized parameters", {
skip_on_cran()
skip_on_os("windows")
skip_on_os("linux")
data(two_class_dat, package = "modeldata")
set.seed(5046)
bt <- bootstraps(two_class_dat, times = 5)
rec_example <- recipe(Class ~ ., data = two_class_dat)
# RF
model_rf <- rand_forest(mtry = tune()) %>%
set_mode("classification") %>%
set_engine("ranger")
wf_rf <- workflow() %>%
add_model(model_rf) %>%
add_recipe(rec_example)
set.seed(30)
rf_res <- wf_rf %>%
tune_grid(resamples = bt, grid = 4)
expect_snapshot({
set.seed(40)
rf_res_finetune <- wf_rf %>%
tune_sim_anneal(resamples = bt, initial = rf_res)
})
# don't supply an initial grid (#39)
expect_snapshot({
set.seed(40)
rf_res_finetune <- wf_rf %>%
tune_sim_anneal(resamples = bt)
})
})
test_that("incompatible parameter objects", {
skip_on_cran()
skip_if_not_installed("ranger")
skip_if_not_installed("modeldata")
skip_if_not_installed("rsample")
rf_spec <- parsnip::rand_forest(mode = "regression", mtry = tune::tune())
set.seed(1)
grid_with_bigger_range <-
dials::grid_latin_hypercube(dials::mtry(range = c(1, 16)))
set.seed(1)
car_folds <- rsample::vfold_cv(car_prices, v = 2)
car_wflow <- workflows::workflow() %>%
workflows::add_formula(Price ~ .) %>%
workflows::add_model(rf_spec)
set.seed(1)
tune_res_with_bigger_range <- tune::tune_grid(
car_wflow,
resamples = car_folds,
grid = grid_with_bigger_range
)
set.seed(1)
parameter_set_with_smaller_range <-
dials::parameters(dials::mtry(range = c(1, 5)))
scrub_best <- function(lines) {
has_best <- grepl("Initial best", lines)
lines[has_best] <- ""
lines
}
set.seed(1)
expect_snapshot(error = TRUE, transform = scrub_best, {
res <-
tune_sim_anneal(
car_wflow,
param_info = parameter_set_with_smaller_range,
resamples = car_folds,
initial = tune_res_with_bigger_range,
iter = 2
)
})
})
test_that("set event-level", {
# See issue 40
skip_if_not_installed("rpart")
skip_if_not_installed("modeldata")
skip_if_not_installed("yardstick")
skip_if_not_installed("rsample")
skip_on_cran()
# ------------------------------------------------------------------------------
set.seed(1)
dat <- modeldata::sim_classification(500, intercept = 8)
# We should get high sensitivity and low specificity when event_level = "first"
# count(dat, class)
# levels(dat$class)
set.seed(2)
rs <- vfold_cv(dat, strata = class)
cart_spec <- decision_tree(min_n = tune()) %>% set_mode("classification")
stats <- metric_set(accuracy, sensitivity, specificity)
# ------------------------------------------------------------------------------
# high sensitivity and low specificity
set.seed(3)
cart_res_first <-
cart_spec %>%
tune_sim_anneal(class ~ .,
rs,
control = control_sim_anneal(event_level = "first", verbose_iter = FALSE),
metrics = stats)
results_first <-
cart_res_first %>%
collect_metrics() %>%
dplyr::filter(.metric != "accuracy") %>%
dplyr::select(.config, .metric, mean) %>%
tidyr::pivot_wider(id_cols = .config, names_from = .metric, values_from = mean)
dir_check <- all(results_first$sensitivity > results_first$specificity)
expect_true(dir_check)
# ------------------------------------------------------------------------------
# Now reversed
set.seed(3)
cart_res_second <-
cart_spec %>%
tune_sim_anneal(class ~ .,
rs,
control = control_sim_anneal(event_level = "second", verbose_iter = FALSE),
metrics = stats)
results_second <-
cart_res_second %>%
collect_metrics() %>%
dplyr::filter(.metric != "accuracy") %>%
dplyr::select(.config, .metric, mean) %>%
tidyr::pivot_wider(id_cols = .config, names_from = .metric, values_from = mean)
rev_dir_check <- all(results_second$sensitivity < results_second$specificity)
expect_true(rev_dir_check)
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.