tests/testthat/test-filter_irrelevant_rules.R

test_that("Filter irrelevant rules run correctly", {
  # Builds rules through the pipeline estimate_ite() -> single-tree
  # randomForest fits merged with combine() -> inTrees_RF2List() ->
  # extract_rules(), pinning intermediate values along the way, then checks
  # filter_irrelevant_rules() on the extracted rules.
  skip_on_cran()

  # Generate sample data ----
  set.seed(181)
  dataset_cont <- generate_cre_dataset(n = 100, rho = 0, n_rules = 2, p = 10,
                                       effect_size = 2, binary_outcome = FALSE)
  ite_method <- "aipw"
  learner_ps <- "SL.xgboost"
  learner_y <- "SL.xgboost"
  ntrees <- 100
  node_size <- 20
  max_nodes <- 5
  max_depth <- 3

  # Coerce inputs to matrices; this is the form passed to estimate_ite()
  # and randomForest() below.
  X <- as.matrix(dataset_cont[["X"]])
  y <- as.matrix(dataset_cont[["y"]])
  z <- as.matrix(dataset_cont[["z"]])

  # Estimate ITE ----
  ite <- estimate_ite(y, z, X, ite_method,
                      learner_ps = learner_ps,
                      learner_y = learner_y)

  expect_equal(ite[10], 0.6874263, tolerance = 1e-6)
  expect_equal(ite[25], -0.2175163, tolerance = 1e-6)
  expect_equal(ite[70], 1.656867, tolerance = 1e-6)

  # Grow the forest one tree at a time ----
  # Each tree gets its own randomly drawn `maxnodes`, so trees are fit
  # individually and merged with randomForest::combine().
  N <- nrow(X)
  # Fraction of the N rows sampled per tree (capped at 1).
  sf <- min(1, (11 * sqrt(N) + 1) / N)

  # 2 + floor of an exponential draw with mean (max_nodes - 2).
  draw_max_nodes <- function() {
    2 + floor(stats::rexp(1, 1 / (max_nodes - 2)))
  }

  # Fit a one-tree forest on (X, ite). `maxnodes` is drawn BEFORE the fit so
  # the RNG stream is consumed in the same order as a draw-then-fit sequence;
  # the pinned values below depend on that order.
  fit_one_tree <- function(replace) {
    mn <- draw_max_nodes()
    suppressWarnings(randomForest::randomForest(
      x = X,
      y = ite,
      sampsize = sf * N,
      replace = replace,
      ntree = 1,
      maxnodes = mn,
      nodesize = node_size
    ))
  }

  # NOTE(review): the first tree samples WITH replacement (randomForest's
  # default, made explicit here) while the remaining trees do not. The pinned
  # node predictions below were recorded under this scheme; confirm whether
  # it is intentional before changing it, as the expected values would shift.
  forest <- fit_one_tree(replace = TRUE)
  for (i in seq_len(ntrees - 1)) {
    next_tree <- fit_one_tree(replace = FALSE)
    forest <- randomForest::combine(forest, next_tree)
  }
  treelist <- inTrees_RF2List(forest)
  trees <- treelist[["list"]]

  expect_length(treelist, 2)
  expect_length(trees, ntrees)
  expect_equal(colnames(trees[[1]])[1], "left daughter")
  expect_equal(trees[[1]][2, 6], 0.5720872, tolerance = 1e-6)
  expect_equal(trees[[2]][3, 6], 0.6381637, tolerance = 1e-6)
  expect_equal(trees[[10]][3, 6], -0.6764889, tolerance = 1e-6)

  rules <- extract_rules(treelist, X, ntrees, max_depth)

  t_decay <- 0.025

  # Invalid inputs are rejected ----
  expect_error(filter_irrelevant_rules(rules = NA, X, ite, t_decay))
  expect_error(filter_irrelevant_rules(rules, X = NA, ite, t_decay))
  expect_error(filter_irrelevant_rules(rules, X, ite = NA, t_decay))
  expect_error(filter_irrelevant_rules(rules, X, ite, t_decay = NA))

  # Valid inputs return character output ----
  rules_RF <- filter_irrelevant_rules(rules, X, ite, t_decay)
  expect_type(rules_RF, "character")
})

Try the CRE package in your browser

Any scripts or data that you put into this service are public.

CRE documentation built on Oct. 19, 2024, 5:07 p.m.