tests/testthat/test-grfcate.R

context("test grfcate.R -- Generalized Random Forests")

test_that("Lrnr_grfcate binary A preds match those of grf::causal_forest", {
  library(grf)

  # Generate some data
  set.seed(791)
  n <- 500
  p <- 10
  X <- matrix(rnorm(n * p), n, p)
  A <- rbinom(n, 1, 0.15)
  X <- cbind(A, X)
  X.test <- matrix(0, 101, p)
  X.test[, 1] <- seq(-2, 2, length.out = 101)
  Y <- X[, 1] * rnorm(n) + 0.5 * A # added a constant treatment effect
  d <- data.frame(list(Y = Y, X = X))
  names(d)[2] <- "A"
  task <- sl3_Task$new(d, covariates = names(d)[-1], outcome = names(d)[1])

  seed_int <- 496L
  set.seed(seed_int)
  # GRF learner class
  grfcate_learner <- Lrnr_grfcate$new(A = "A", seed = seed_int, num.threads = 1L)
  grfcate_fit <- grfcate_learner$train(task)
  grfcate_pred <- grfcate_fit$predict(task)
  set.seed(seed_int)
  grfcate_pkg <- grf::causal_forest(
    X = X[, -1], W = X[, 1], Y = Y, seed = seed_int, num.threads = 1L
  )
  grfcate_pkg_pred_out <- predict(grfcate_pkg)
  grfcate_pkg_pred <- as.numeric(grfcate_pkg_pred_out$predictions)

  # test equivalence
  expect_equal(grfcate_pred, grfcate_pkg_pred)
})

test_that("Lrnr_grfcate continuous A preds match those of grf::causal_forest", {
  library(grf)

  # Generate some data
  set.seed(791)
  n <- 500
  p <- 10
  X <- matrix(rnorm(n * p), n, p)
  A <- rnorm(n)
  X <- cbind(A, X)
  X.test <- matrix(0, 101, p)
  X.test[, 1] <- seq(-2, 2, length.out = 101)
  Y <- X[, 1] * rnorm(n) + 0.5 * A
  d <- data.frame(list(Y = Y, X = X))
  names(d)[2] <- "A"
  task <- sl3_Task$new(d, covariates = names(d)[-1], outcome = names(d)[1])

  seed_int <- 496L
  set.seed(seed_int)
  # GRF learner class
  grfcate_learner <- Lrnr_grfcate$new(A = "A", seed = seed_int, num.threads = 1L)
  grfcate_fit <- grfcate_learner$train(task)
  grfcate_pred <- grfcate_fit$predict(task)
  set.seed(seed_int)
  grfcate_pkg <- grf::causal_forest(
    X = X[, -1], W = X[, 1], Y = Y, seed = seed_int, num.threads = 1L
  )
  grfcate_pkg_pred_out <- predict(grfcate_pkg)
  grfcate_pkg_pred <- as.numeric(grfcate_pkg_pred_out$predictions)

  # test equivalence
  expect_equal(grfcate_pred, grfcate_pkg_pred)
})

test_that("Lrnr_grfcate errors when A not in covariates", {
  data(mtcars)
  task <- sl3_Task$new(
    mtcars,
    covariates = c("cyl", "disp", "hp", "drat"), outcome = "mpg"
  )

  grfcate_learner <- Lrnr_grfcate$new(A = "vs", num.threads = 1L)
  expect_error(grfcate_learner$train(task))
})

test_that("Lrnr_grfcate errors when prediction task missing covariates", {
  data(mtcars)
  task <- sl3_Task$new(
    mtcars,
    covariates = c("cyl", "disp", "hp", "drat", "vs"), outcome = "mpg"
  )
  task2 <- sl3_Task$new(
    mtcars,
    covariates = c("cyl", "disp", "hp", "vs")
  )
  grfcate_learner <- Lrnr_grfcate$new(A = "vs", num.threads = 1L)
  grfcate_fit <- grfcate_learner$train(task)
  expect_error(grfcate_fit$predict(task2))
})

# test_that("Lrnr_grfcate warns when prediction task has extra covariates", {
#   data(mtcars)
#   task <- sl3_Task$new(
#     mtcars, covariates = c("cyl", "disp", "hp", "vs"), outcome = "mpg"
#   )
#   task2 <- sl3_Task$new(
#     mtcars, covariates = c("cyl", "disp", "hp", "drat", "vs")
#   )
#   grfcate_learner <- Lrnr_grfcate$new(A = "vs", num.threads = 1L)
#   grfcate_fit <- grfcate_learner$train(task)
#   expect_warning(preds <- grfcate_fit$predict(task2))
# })
tlverse/sl3 documentation built on Nov. 18, 2024, 12:46 a.m.