library(testthat)
library(dplyr)
library(rpart)
source(test_path("make_binned_data.R"))
set.seed(8497)
sim_tr_cls <- sim_data_2class(1000)
sim_te_cls <- sim_data_2class(100)
set.seed(8497)
sim_tr_reg <- sim_data_reg(1000)
sim_te_reg <- sim_data_reg(100)
mod <- rpart(y ~ x, data = sim_tr_reg)
best_split <- unname(mod$splits[, "index"])
test_that("low-level binning for classification", {
expect_error(
splits <-
embed:::cart_binning(
sim_tr_cls$x,
"x",
sim_tr_cls$class,
cost_complexity = 0.01,
tree_depth = 5,
min_n = 10
),
regexp = NA
)
expect_equal(splits, best_split)
set.seed(283834)
expect_snapshot({
splits <-
embed:::cart_binning(
sample(sim_tr_cls$x),
"x",
sim_tr_cls$class,
cost_complexity = 0.01,
tree_depth = 5,
min_n = 10
)
})
expect_equal(splits, numeric(0))
})
test_that("low-level binning for regression", {
expect_error(
splits <-
embed:::cart_binning(
sim_tr_reg$x,
"x",
sim_tr_reg$y,
cost_complexity = 0.01,
tree_depth = 5,
min_n = 10
),
regexp = NA
)
expect_equal(splits, best_split)
set.seed(283834)
expect_snapshot({
splits <-
embed:::cart_binning(
sample(sim_tr_reg$x),
"potato",
sim_tr_reg$y,
cost_complexity = 0.01,
tree_depth = 5,
min_n = 10
)
})
expect_equal(splits, numeric(0))
})
test_that("step function for classification", {
expect_snapshot({
cart_rec <-
recipe(class ~ ., data = sim_tr_cls) %>%
step_discretize_cart(all_predictors(), outcome = "class") %>%
prep()
})
expect_equal(names(cart_rec$steps[[1]]$rules), "x")
expect_equal(cart_rec$steps[[1]]$rules$x, best_split)
expect_error(
cart_pred <- bake(cart_rec, sim_tr_cls[, -3]),
regexp = NA
)
expect_true(is.factor(cart_pred$x))
expect_equal(length(levels(cart_pred$x)), 3)
expect_true(is.numeric(cart_pred$z))
})
test_that("step function for regression", {
expect_snapshot({
cart_rec <-
recipe(y ~ ., data = sim_tr_reg) %>%
step_discretize_cart(all_predictors(), outcome = "y") %>%
prep()
})
expect_equal(names(cart_rec$steps[[1]]$rules), "x")
expect_equal(cart_rec$steps[[1]]$rules$x, best_split)
expect_error(
cart_pred <- bake(cart_rec, sim_tr_reg[, -3]),
regexp = NA
)
expect_true(is.factor(cart_pred$x))
expect_equal(length(levels(cart_pred$x)), 3)
expect_true(is.numeric(cart_pred$z))
})
test_that("bad args", {
tmp <- sim_tr_reg
tmp$w <- sample(letters[1:4], nrow(tmp), replace = TRUE)
expect_snapshot(error = TRUE, {
cart_rec <-
recipe(y ~ ., data = tmp) %>%
step_discretize_cart(all_predictors(), outcome = "y") %>%
prep()
})
})
test_that("tidy method", {
cart_rec <-
recipe(y ~ ., data = sim_tr_reg) %>%
step_discretize_cart(all_predictors(), outcome = "y")
res <- tidy(cart_rec, number = 1)
expect_equal(
res$terms,
"all_predictors()"
)
expect_equal(
res$value,
NA_real_
)
expect_snapshot({
cart_rec <- prep(cart_rec)
})
res <- tidy(cart_rec, number = 1)
expect_equal(
res$terms,
rep("x", 2)
)
expect_equal(
res$value,
best_split
)
})
test_that("case weights step functions", {
sim_tr_cls_cw <- sim_tr_cls %>%
mutate(weight = importance_weights(rep(0:1, each = 500)))
sim_tr_reg_cw <- sim_tr_reg %>%
mutate(weight = importance_weights(rep(0:1, each = 500)))
mod_cw <- rpart(y ~ x, data = sim_tr_reg, weights = rep(0:1, each = 500))
best_split_cw <- unname(mod_cw$splits[, "index"])
# Classification
expect_snapshot({
cart_rec <-
recipe(class ~ ., data = sim_tr_cls_cw) %>%
step_discretize_cart(all_predictors(), outcome = "class") %>%
prep()
})
expect_equal(names(cart_rec$steps[[1]]$rules), "x")
expect_equal(cart_rec$steps[[1]]$rules$x, best_split_cw)
# Regression
expect_snapshot({
cart_rec <-
recipe(y ~ ., data = sim_tr_reg_cw) %>%
step_discretize_cart(all_predictors(), outcome = "y") %>%
prep()
})
expect_equal(names(cart_rec$steps[[1]]$rules), c("x", "z"))
expect_equal(cart_rec$steps[[1]]$rules$x, best_split_cw)
expect_snapshot(cart_rec)
})
test_that("tunable", {
rec <-
recipe(~., data = mtcars) %>%
step_discretize_cart(all_predictors(), outcome = "mpg")
rec_param <- tunable.step_discretize_cart(rec$steps[[1]])
expect_equal(rec_param$name, c("cost_complexity", "tree_depth", "min_n"))
expect_true(all(rec_param$source == "recipe"))
expect_true(is.list(rec_param$call_info))
expect_equal(nrow(rec_param), 3)
expect_equal(
names(rec_param),
c("name", "call_info", "source", "component", "component_id")
)
})
# Infrastructure ---------------------------------------------------------------
test_that("bake method errors when needed non-standard role columns are missing", {
rec <- recipe(class ~ ., data = sim_tr_cls) %>%
step_discretize_cart(x, z, outcome = "class") %>%
update_role(x, new_role = "potato") %>%
update_role_requirements(role = "potato", bake = FALSE)
expect_warning(
rec_trained <- prep(rec, training = sim_tr_cls, verbose = FALSE)
)
expect_error(
bake(rec_trained, new_data = sim_tr_cls[, -1]),
class = "new_data_missing_column"
)
})
test_that("empty printing", {
rec <- recipe(mpg ~ ., mtcars)
rec <- step_discretize_cart(rec, outcome = "mpg")
expect_snapshot(rec)
rec <- prep(rec, mtcars)
expect_snapshot(rec)
})
test_that("empty selection prep/bake is a no-op", {
rec1 <- recipe(mpg ~ ., mtcars)
rec2 <- step_discretize_cart(rec1, outcome = "mpg")
rec1 <- prep(rec1, mtcars)
rec2 <- prep(rec2, mtcars)
baked1 <- bake(rec1, mtcars)
baked2 <- bake(rec2, mtcars)
expect_identical(baked1, baked2)
})
test_that("empty selection tidy method works", {
rec <- recipe(mpg ~ ., mtcars)
rec <- step_discretize_cart(rec, outcome = "mpg")
expect <- tibble(terms = character(), value = double(), id = character())
expect_identical(tidy(rec, number = 1), expect)
rec <- prep(rec, mtcars)
expect_identical(tidy(rec, number = 1), expect)
})
test_that("printing", {
rec <- recipe(class ~ ., data = sim_tr_cls) %>%
step_discretize_cart(all_predictors(), outcome = "class")
expect_snapshot(print(rec))
expect_snapshot(prep(rec))
})
test_that("tunable is setup to works with extract_parameter_set_dials", {
skip_if_not_installed("dials")
rec <- recipe(~., data = mtcars) %>%
step_discretize_cart(
all_predictors(),
outcome = "mpg",
cost_complexity = hardhat::tune(),
tree_depth = hardhat::tune(),
min_n = hardhat::tune()
)
params <- extract_parameter_set_dials(rec)
expect_s3_class(params, "parameters")
expect_identical(nrow(params), 3L)
})
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.