# tests/testthat/test-rpart.R

################################################################################
# This is part of the 'tidyrules' R package hosted at
# https://github.com/talegari/tidyrules under the GPL-3 license.
################################################################################

context("test-rpart")

# setup some models ----
# `attrition` is loaded from the 'modeldata' package; it is the base data for
# both the classification and the regression setups in this file.
data("attrition", package = "modeldata")

# classification test
# Ordered factors are recast as plain (unordered) factors before fitting, and
# `Attrition` is rebuilt with the explicit level order "No", "Yes".
# `FALSE` is spelled out because the shorthand `F` is an ordinary variable
# that can be reassigned.
attrition_class =
  attrition %>%
  tidytable::mutate(
    tidytable::across(
      tidytable::where(is.ordered),
      ~ factor(.x, ordered = FALSE)
    )
  ) %>%
  tidytable::mutate(Attrition = factor(Attrition, levels = c("No", "Yes")))

# fit the classification tree and convert it with tidy()
rpart_att    = rpart::rpart(Attrition ~ ., data = attrition_class)
tr_att_class = tidy(rpart_att)

# regression test
# Same recasting of ordered factors as unordered ones; the `Attrition` column
# is dropped so it is not used as a predictor of `MonthlyIncome`.
# `FALSE` is spelled out because the shorthand `F` is an ordinary variable
# that can be reassigned.
attrition_reg =
  attrition %>%
  tidytable::mutate(
    tidytable::across(
      tidytable::where(is.ordered),
      ~ factor(.x, ordered = FALSE)
    )
  ) %>%
  tidytable::select(-Attrition)

# fit the regression tree and convert it with tidy()
rpart_att_reg = rpart::rpart(MonthlyIncome ~ ., data = attrition_reg)
tr_att_reg    = tidy(rpart_att_reg)

# BreastCancer
# The `Id` column is dropped and every ordered factor is recast as an
# unordered factor before the tree is fitted.
data(BreastCancer, package = "mlbench")
bc = BreastCancer %>%
  dplyr::select(-Id) %>%
  dplyr::mutate_if(is.ordered, function(x) factor(x, ordered = FALSE))

# fit the classification tree on `Class` and convert it with tidy()
bc_1m = rpart::rpart(Class ~ ., data = bc)

tr_bc_1 = tidy(bc_1m)

# variables with spaces
# Work on a copy of `bc` in which two predictors get names containing a space.
bc2 = bc

colnames(bc2)[colnames(bc2) == "Cell.size"]  = "Cell size"
colnames(bc2)[colnames(bc2) == "Cell.shape"] = "Cell shape"

# fit the tree on the renamed data and convert it with tidy()
bc_2m   = rpart::rpart(Class ~ ., data = bc2)
tr_bc_2 = tidy(bc_2m)

# Apply a single rule to a dataset.
#
# rule: a string holding an R expression; it is parsed and evaluated as the
#       filter condition.
# data: the data frame to filter.
#
# Returns the result of dplyr::filter(); any parsing or evaluation error is
# left to the caller to handle.
ruleFilterable = function(rule, data){
  filtered = dplyr::filter(.data = data, eval(parse(text = rule)))
  return(filtered)
}

# Check, rule by rule, whether every rule can be used to filter a dataset.
#
# tr:   an object whose "LHS" element holds the rules as strings.
# data: the data frame each rule is applied to via ruleFilterable().
#
# Returns a logical vector with one element per rule: TRUE when filtering
# produced a data frame, FALSE when it raised an error. Rules that filter
# successfully but match zero rows are printed for inspection.
allRulesFilterable = function(tr, data){
  parse_status = vapply(
    tr[["LHS"]],
    function(arule){
      trydf = try(ruleFilterable(arule, data), silent = TRUE)
      is_df = inherits(trydf, "data.frame")
      # nrow() of a "try-error" is NULL, which would make `if` itself fail, so
      # the row count is inspected only once we know we hold a data frame
      if (is_df && nrow(trydf) == 0) print(arule)
      is_df
    },
    logical(1)
  )
  return(parse_status)
}

# test for error while ordered features are present ----
# NOTE(review): `rpart_att_1` is not defined anywhere in this file as shown.
# Unless it is provided elsewhere, the expectation below is satisfied by an
# "object not found" error rather than by tidy() itself -- confirm.
test_that("check error", {
  expect_error(
    tidy(rpart_att_1)
  )
})

# test output type ----

# Every tidy() result built in the setup above should carry the "rulelist"
# class. expect_s3_class() is used because expect_is() is deprecated in the
# third edition of testthat.
# NOTE(review): assumes "rulelist" is an S3 class -- confirm.
test_that("creates rulelist", {
  expect_s3_class(tr_att_class, "rulelist")
  expect_s3_class(tr_bc_1, "rulelist")
  expect_s3_class(tr_bc_2, "rulelist")
  expect_s3_class(tr_att_reg, "rulelist")
})

# test NA ----
# None of the tidy() results built in the setup above may contain an NA.
test_that("Are NA present", {
  rulelists = list(
    tr_att_class = tr_att_class,
    tr_bc_1      = tr_bc_1,
    tr_bc_2      = tr_bc_2,
    tr_att_reg   = tr_att_reg
  )
  for (rulelist_name in names(rulelists)) {
    # `info` keeps the offending object identifiable if an expectation fails
    expect_false(anyNA(rulelists[[rulelist_name]]), info = rulelist_name)
  }
})

# test parsable ----
# Every rule of every rulelist must be usable as a filter on its dataset;
# allRulesFilterable() reports one TRUE/FALSE per rule.
test_that("rules are parsable", {
  cases = list(
    list(rules = tr_att_class, data = attrition),
    list(rules = tr_bc_1,      data = bc),
    list(rules = tr_bc_2,      data = bc2),
    list(rules = tr_att_reg,   data = attrition)
  )
  for (case in cases) {
    filterable = allRulesFilterable(case[["rules"]], case[["data"]])
    expect_true(all(filterable))
  }
})

# Try the tidyrules package in your browser
#
# Any scripts or data that you put into this service are public.
#
# tidyrules documentation built on June 30, 2024, 1:07 a.m.