# Nothing
# Engine-specific options (`rules`, `bands`) given to `bagger()` should be
# visible on each fitted C5.0 member, as reported by a user-supplied
# `extract` function, and `$imp` should be available on the result.
test_that('check C5.0 opt', {
  # Extractor: TRUE when the fit has empty tree text and non-trivial rule text.
  check_rules <- function(x, ...) {
    tree_is_empty <- x$tree == ""
    rules_present <- nchar(x$rules) > 10
    tree_is_empty & rules_present
  }

  rules_bag <- bagger(
    Species ~ .,
    data = iris,
    base_model = "C5.0",
    control = control_bag(extract = check_rules),
    rules = TRUE
  )

  rules_flags <- unlist(rules_bag$model_df$extras)
  expect_true(all(rules_flags))
  expect_false(is.null(rules_bag$imp))

  # Extractor: same rule checks, plus the `bands` value stored in the
  # fit's control object must equal the value passed to `bagger()`.
  check_winnow <- function(x, ...) {
    tree_is_empty <- x$tree == ""
    rules_present <- nchar(x$rules) > 10
    bands_match <- x$control$bands == 3
    tree_is_empty & rules_present & bands_match
  }

  bands_bag <- bagger(
    Species ~ .,
    data = iris,
    base_model = "C5.0",
    control = control_bag(var_imp = TRUE, extract = check_winnow),
    rules = TRUE,
    bands = 3
  )

  bands_flags <- unlist(bands_bag$model_df$extras)
  expect_true(all(bands_flags))
  expect_true(inherits(bands_bag$imp, "tbl_df"))
})
# ------------------------------------------------------------------------------
# Compares the first stored C5.0 fit under the default control settings
# against one fitted with `control_bag(reduce = FALSE)`.
test_that('check model reduction', {
  # Default control: the stored fit is expected to carry a one-element
  # `control`, a placeholder `call`, and no captured `output`.
  set.seed(36323)
  reduced <- bagger(
    Species ~ .,
    data = iris,
    base_model = "C5.0",
    times = 3
  )
  reduced_fit <- reduced$model_df$model[[1]]$fit

  expect_true(length(reduced_fit$control) == 1)
  expect_equal(reduced_fit$call, rlang::call2("dummy_call"))
  expect_equal(reduced_fit$output, character(0))

  # Same seed, but with reduction switched off: the full `control`, a real
  # call object, and the printed model output should all be present.
  set.seed(36323)
  full <- bagger(
    Species ~ .,
    data = iris,
    base_model = "C5.0",
    times = 3,
    control = control_bag(reduce = FALSE)
  )
  full_fit <- full$model_df$model[[1]]$fit

  expect_true(length(full_fit$control) > 1)
  expect_true(is.call(full_fit$call))
  expect_true(nchar(full_fit$output) > 10)
})
# ------------------------------------------------------------------------------
# Exercises the parsnip front end (`bag_tree()` with the "C5.0" engine):
# fitting, class and probability prediction, the `class_cost` argument,
# printing, `update()`, and the `class_cost()` parameter object.
test_that('check C5 parsnip interface', {
  skip_if_not_installed("modeldata")
  # Load the data explicitly so this test does not depend on another test or
  # helper having put `two_class_dat` in scope (mirrors the 'case weights'
  # test in this file).
  data("two_class_dat", package = "modeldata")

  # Rows to predict on; the third column is dropped before prediction.
  new_obs <- two_class_dat[1:5, -3]

  # --- plain classification fit -------------------------------------------
  set.seed(4779)
  # `regexp = NA` asserts that no error is raised.
  expect_error(
    class_mod <- bag_tree(min_n = 3) %>%
      set_engine("C5.0", times = 3) %>%
      set_mode("classification") %>%
      fit(Class ~ ., data = two_class_dat),
    regexp = NA
  )
  expect_true(
    all(purrr::map_lgl(class_mod$fit$model_df$model, ~ inherits(.x, "model_fit")))
  )
  expect_true(
    all(purrr::map_lgl(class_mod$fit$model_df$model, ~ inherits(.x$fit, "C5.0")))
  )

  expect_error(
    class_mod_pred <- predict(class_mod, new_obs),
    regexp = NA
  )
  expect_true(tibble::is_tibble(class_mod_pred))
  expect_equal(nrow(class_mod_pred), 5)
  expect_equal(names(class_mod_pred), ".pred_class")

  expect_error(
    class_mod_prob <- predict(class_mod, new_obs, type = "prob"),
    regexp = NA
  )
  expect_true(tibble::is_tibble(class_mod_prob))
  expect_equal(nrow(class_mod_prob), 5)
  expect_equal(names(class_mod_prob), c(".pred_Class1", ".pred_Class2"))

  # --- fit with a class cost ----------------------------------------------
  set.seed(4779)
  expect_error(
    class_cost <- bag_tree(min_n = 3, class_cost = 2) %>%
      set_engine("C5.0", times = 3) %>%
      set_mode("classification") %>%
      fit(Class ~ ., data = two_class_dat),
    regexp = NA
  )
  expect_true(
    all(purrr::map_lgl(class_cost$fit$model_df$model, ~ inherits(.x, "model_fit")))
  )
  expect_true(
    all(purrr::map_lgl(class_cost$fit$model_df$model, ~ inherits(.x$fit, "C5.0")))
  )

  expect_error(
    class_cost_pred <- predict(class_cost, new_obs),
    regexp = NA
  )
  expect_true(tibble::is_tibble(class_cost_pred))
  expect_equal(nrow(class_cost_pred), 5)
  expect_equal(names(class_cost_pred), ".pred_class")

  expect_error(
    class_cost_prob <- predict(class_cost, new_obs, type = "prob"),
    regexp = NA
  )
  expect_true(tibble::is_tibble(class_cost_prob))
  expect_equal(nrow(class_cost_prob), 5)
  expect_equal(names(class_cost_prob), c(".pred_Class1", ".pred_Class2"))

  # --- model specification helpers ----------------------------------------
  expect_output(print(bag_tree(min_n = 3)))

  # `update()` on an empty spec should match constructing the spec directly.
  expect_equal(update(bag_tree(), min_n = 3), bag_tree(min_n = 3))
  expect_equal(update(bag_tree(), cost_complexity = 3), bag_tree(cost_complexity = 3))
  expect_equal(update(bag_tree(), tree_depth = 3), bag_tree(tree_depth = 3))
  expect_equal(update(bag_tree(), class_cost = 3), bag_tree(class_cost = 3))

  # The fitted model stored in the local `class_cost` is a non-function
  # object, so R's function lookup still resolves `class_cost()` to the
  # parameter constructor.
  expect_equal(class_cost(c(1, 5))$range$lower, 1)
  expect_equal(class_cost(c(1, 5))$range$upper, 5)
})
# Checks the packages registered for `bag_tree` with the "C5.0" engine:
# two packages for classification, and no entry at all for regression.
test_that('mode specific package dependencies', {
  pkg_tbl <- get_from_env(paste0("bag_tree", "_pkgs"))

  classification_rows <- dplyr::filter(
    pkg_tbl,
    engine == "C5.0",
    mode == "classification"
  )
  expect_identical(
    dplyr::pull(classification_rows, pkg),
    list(c("C50", "baguette"))
  )

  regression_rows <- dplyr::filter(
    pkg_tbl,
    engine == "C5.0",
    mode == "regression"
  )
  expect_identical(
    dplyr::pull(regression_rows, pkg),
    list()
  )
})
# Fits a bagged C5.0 model with 0/1 case weights and checks that every
# member fit has its `caseWeights` flag set.
test_that('case weights', {
  skip_if_not_installed("modeldata")
  data("two_class_dat", package = "modeldata")

  # Seeded uniform draws turned into 0/1 weights: draws below 1/5 get
  # weight 0, all others weight 1.
  set.seed(1)
  draws <- runif(nrow(two_class_dat))
  wts <- ifelse(draws < 1/5, 0, 1)

  # `regexp = NA` asserts that fitting raises no error.
  expect_error(
    {
      set.seed(1)
      c5_wts_fit <- bagger(
        Class ~ A + B,
        data = two_class_dat,
        weights = wts,
        base_model = "C5.0"
      )
    },
    regexp = NA
  )

  used_weights <- purrr::map_lgl(
    c5_wts_fit$model_df$model,
    ~ .x$fit$caseWeights
  )
  expect_true(all(used_weights))
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.