# tests/testthat/test-discretize_cart.R

library(testthat)
library(dplyr)
library(rpart)

# Load the shared helper script; the simulators called below
# (sim_data_2class(), sim_data_reg()) are not defined in this file, so they
# presumably come from here -- confirm against make_binned_data.R.
source(test_path("make_binned_data.R"))

# Classification fixtures: the seed is fixed right before simulating so the
# training/test sets are reproducible across runs.
set.seed(8497)
sim_tr_cls <- sim_data_2class(1000)
sim_te_cls <- sim_data_2class(100)

# Regression fixtures: the seed is reset to the same value before simulating.
set.seed(8497)
sim_tr_reg <- sim_data_reg(1000)
sim_te_reg <- sim_data_reg(100)

# Reference tree: a plain rpart fit of y on x using rpart's defaults.
mod <- rpart(y ~ x, data = sim_tr_reg)

# Split points of the reference tree (the "index" column of its splits
# matrix), with names stripped. The tests below use this as the expected
# value for the rules learned on `x`.
best_split <- unname(mod$splits[, "index"])

# cart_binning() on the classification training data should run without error
# and return the reference split points; with a shuffled predictor it should
# return no splits at all.
test_that("low-level binning for classification", {
  # expect_no_error() states the intent directly (instead of the legacy
  # `expect_error(regexp = NA)` form). The assignment is evaluated in this
  # test's environment, so `splits` is available afterwards.
  expect_no_error(
    splits <-
      embed:::cart_binning(
        sim_tr_cls$x,
        "x",
        sim_tr_cls$class,
        cost_complexity = 0.01,
        tree_depth = 5,
        min_n = 10
      )
  )
  expect_equal(splits, best_split)

  # Fix the permutation drawn by sample() so the snapshot is reproducible.
  set.seed(283834)
  expect_snapshot({
    splits <-
      embed:::cart_binning(
        sample(sim_tr_cls$x),
        "x",
        sim_tr_cls$class,
        cost_complexity = 0.01,
        tree_depth = 5,
        min_n = 10
      )
  })
  expect_equal(splits, numeric(0))
})

# cart_binning() on the regression training data should run without error and
# return the reference split points; with a shuffled predictor it should
# return no splits at all.
test_that("low-level binning for regression", {
  # expect_no_error() states the intent directly (instead of the legacy
  # `expect_error(regexp = NA)` form). The assignment is evaluated in this
  # test's environment, so `splits` is available afterwards.
  expect_no_error(
    splits <-
      embed:::cart_binning(
        sim_tr_reg$x,
        "x",
        sim_tr_reg$y,
        cost_complexity = 0.01,
        tree_depth = 5,
        min_n = 10
      )
  )
  expect_equal(splits, best_split)

  # Fix the permutation drawn by sample() so the snapshot is reproducible.
  set.seed(283834)
  expect_snapshot({
    splits <-
      embed:::cart_binning(
        sample(sim_tr_reg$x),
        "potato",
        sim_tr_reg$y,
        cost_complexity = 0.01,
        tree_depth = 5,
        min_n = 10
      )
  })
  expect_equal(splits, numeric(0))
})

# End-to-end check of step_discretize_cart() with a class outcome: prep()
# should learn rules only for `x`, matching the reference split points, and
# bake() should turn `x` into a three-level factor while leaving `z` numeric.
test_that("step function for classification", {
  expect_snapshot({
    cart_rec <-
      recipe(class ~ ., data = sim_tr_cls) %>%
      step_discretize_cart(all_predictors(), outcome = "class") %>%
      prep()
  })

  rules <- cart_rec$steps[[1]]$rules
  expect_named(rules, "x")
  expect_equal(rules$x, best_split)

  # Bake on the training data with its third column dropped (presumably the
  # outcome -- confirm the column order in make_binned_data.R).
  expect_no_error(
    cart_pred <- bake(cart_rec, sim_tr_cls[, -3])
  )

  # Dedicated expectations give more informative failures than
  # expect_true()/expect_equal(length()).
  expect_s3_class(cart_pred$x, "factor")
  expect_length(levels(cart_pred$x), 3)
  expect_true(is.numeric(cart_pred$z))
})

# End-to-end check of step_discretize_cart() with a numeric outcome: prep()
# should learn rules only for `x`, matching the reference split points, and
# bake() should turn `x` into a three-level factor while leaving `z` numeric.
test_that("step function for regression", {
  expect_snapshot({
    cart_rec <-
      recipe(y ~ ., data = sim_tr_reg) %>%
      step_discretize_cart(all_predictors(), outcome = "y") %>%
      prep()
  })

  rules <- cart_rec$steps[[1]]$rules
  expect_named(rules, "x")
  expect_equal(rules$x, best_split)

  # Bake on the training data with its third column dropped (presumably the
  # outcome -- confirm the column order in make_binned_data.R).
  expect_no_error(
    cart_pred <- bake(cart_rec, sim_tr_reg[, -3])
  )

  # Dedicated expectations give more informative failures than
  # expect_true()/expect_equal(length()).
  expect_s3_class(cart_pred$x, "factor")
  expect_length(levels(cart_pred$x), 3)
  expect_true(is.numeric(cart_pred$z))
})

# prep() is expected to fail when the selection includes an unsuitable
# column; the exact error text is pinned by the snapshot.
test_that("bad args", {
  # Copy the regression training data and add a character column `w` holding
  # random letters a-d. The draw is unseeded; presumably only the column's
  # type matters for the error, not its values -- confirm against the snapshot.
  tmp <- sim_tr_reg
  tmp$w <- sample(letters[1:4], nrow(tmp), replace = TRUE)

  # all_predictors() now also selects `w`; `error = TRUE` makes the snapshot
  # require that this code throws and records the message.
  expect_snapshot(error = TRUE, {
    cart_rec <-
      recipe(y ~ ., data = tmp) %>%
      step_discretize_cart(all_predictors(), outcome = "y") %>%
      prep()
  })
})

# tidy() should report the raw selector with an NA value before prep(), and
# one row per learned split point for `x` afterwards.
test_that("tidy method", {
  cart_rec <- step_discretize_cart(
    recipe(y ~ ., data = sim_tr_reg),
    all_predictors(),
    outcome = "y"
  )

  # Untrained step: nothing has been estimated yet.
  untrained <- tidy(cart_rec, number = 1)
  expect_equal(untrained$terms, "all_predictors()")
  expect_equal(untrained$value, NA_real_)

  expect_snapshot({
    cart_rec <- prep(cart_rec)
  })

  # Trained step: two rows, both for `x`, carrying the reference splits.
  trained <- tidy(cart_rec, number = 1)
  expect_equal(trained$terms, c("x", "x"))
  expect_equal(trained$value, best_split)
})

# With importance weights of 0 for the first 500 rows and 1 for the last 500,
# the learned rules for `x` should match a weighted rpart reference fit.
test_that("case weights step functions", {
  sim_tr_cls_cw <- mutate(
    sim_tr_cls,
    weight = importance_weights(rep(0:1, each = 500))
  )
  sim_tr_reg_cw <- mutate(
    sim_tr_reg,
    weight = importance_weights(rep(0:1, each = 500))
  )

  # Reference: regression tree of y on x fitted with the same 0/1 weights.
  weighted_tree <- rpart(
    y ~ x,
    data = sim_tr_reg,
    weights = rep(0:1, each = 500)
  )
  weighted_split <- unname(weighted_tree$splits[, "index"])

  # Classification
  expect_snapshot({
    cart_rec <-
      recipe(class ~ ., data = sim_tr_cls_cw) %>%
      step_discretize_cart(all_predictors(), outcome = "class") %>%
      prep()
  })

  cls_rules <- cart_rec$steps[[1]]$rules
  expect_equal(names(cls_rules), "x")
  expect_equal(cls_rules$x, weighted_split)

  # Regression
  expect_snapshot({
    cart_rec <-
      recipe(y ~ ., data = sim_tr_reg_cw) %>%
      step_discretize_cart(all_predictors(), outcome = "y") %>%
      prep()
  })

  reg_rules <- cart_rec$steps[[1]]$rules
  expect_equal(names(reg_rules), c("x", "z"))
  expect_equal(reg_rules$x, weighted_split)

  expect_snapshot(cart_rec)
})

# The tunable method should describe exactly three recipe-sourced parameters
# in the standard five-column layout.
test_that("tunable", {
  rec <- step_discretize_cart(
    recipe(~., data = mtcars),
    all_predictors(),
    outcome = "mpg"
  )
  tune_info <- tunable.step_discretize_cart(rec$steps[[1]])

  # Shape of the returned table.
  expect_equal(nrow(tune_info), 3)
  expect_equal(
    names(tune_info),
    c("name", "call_info", "source", "component", "component_id")
  )

  # Contents of the returned table.
  expect_equal(tune_info$name, c("cost_complexity", "tree_depth", "min_n"))
  expect_true(all(tune_info$source == "recipe"))
  expect_true(is.list(tune_info$call_info))
})

# Infrastructure ---------------------------------------------------------------

# A column used by the step must still be present at bake() time, even when
# its role has been changed and marked as not required for baking.
test_that("bake method errors when needed non-standard role columns are missing", {
  rec <- recipe(class ~ ., data = sim_tr_cls)
  rec <- step_discretize_cart(rec, x, z, outcome = "class")
  rec <- update_role(rec, x, new_role = "potato")
  rec <- update_role_requirements(rec, role = "potato", bake = FALSE)

  # Training is expected to emit a warning; only its presence is checked.
  expect_warning(
    rec_trained <- prep(rec, training = sim_tr_cls, verbose = FALSE)
  )

  # Baking without the first column (presumably `x` -- confirm the column
  # order in make_binned_data.R) must raise the classed error.
  expect_error(
    bake(rec_trained, new_data = sim_tr_cls[, -1]),
    class = "new_data_missing_column"
  )
})

# Printed output of a step with no selected terms, before and after prep(),
# is pinned by snapshots.
test_that("empty printing", {
  rec <- recipe(mpg ~ ., mtcars) %>%
    step_discretize_cart(outcome = "mpg")

  expect_snapshot(rec)

  rec <- prep(rec, mtcars)

  expect_snapshot(rec)
})

# A step with no selected terms must not change what the recipe bakes.
test_that("empty selection prep/bake is a no-op", {
  plain_rec <- recipe(mpg ~ ., mtcars)
  step_rec <- step_discretize_cart(plain_rec, outcome = "mpg")

  baked_plain <- bake(prep(plain_rec, mtcars), mtcars)
  baked_step <- bake(prep(step_rec, mtcars), mtcars)

  expect_identical(baked_plain, baked_step)
})

# tidy() on a step with no selected terms should return the same zero-row
# tibble both before and after prep().
test_that("empty selection tidy method works", {
  rec <- step_discretize_cart(recipe(mpg ~ ., mtcars), outcome = "mpg")

  empty_tidy <- tibble(
    terms = character(),
    value = double(),
    id = character()
  )

  expect_identical(tidy(rec, number = 1), empty_tidy)

  rec <- prep(rec, mtcars)

  expect_identical(tidy(rec, number = 1), empty_tidy)
})

# Printed output of the untrained and trained recipe is pinned by snapshots.
test_that("printing", {
  rec <- step_discretize_cart(
    recipe(class ~ ., data = sim_tr_cls),
    all_predictors(),
    outcome = "class"
  )

  expect_snapshot(print(rec))
  expect_snapshot(prep(rec))
})

# When all three arguments are marked with tune(), the extracted parameter
# set should contain three entries. Skipped when dials is not installed.
test_that("tunable is setup to works with extract_parameter_set_dials", {
  skip_if_not_installed("dials")

  rec <- recipe(~., data = mtcars)
  rec <- step_discretize_cart(
    rec,
    all_predictors(),
    outcome = "mpg",
    cost_complexity = hardhat::tune(),
    tree_depth = hardhat::tune(),
    min_n = hardhat::tune()
  )

  param_set <- extract_parameter_set_dials(rec)

  expect_s3_class(param_set, "parameters")
  expect_identical(nrow(param_set), 3L)
})
# topepo/embed documentation built on March 26, 2024, 4:11 a.m.