context("test grfcate.R -- Generalized Random Forests")
test_that("Lrnr_grfcate binary A preds match those of grf::causal_forest", {
library(grf)
# Generate some data
set.seed(791)
n <- 500
p <- 10
X <- matrix(rnorm(n * p), n, p)
A <- rbinom(n, 1, 0.15)
X <- cbind(A, X)
X.test <- matrix(0, 101, p)
X.test[, 1] <- seq(-2, 2, length.out = 101)
Y <- X[, 1] * rnorm(n) + 0.5 * A # added a constant treatment effect
d <- data.frame(list(Y = Y, X = X))
names(d)[2] <- "A"
task <- sl3_Task$new(d, covariates = names(d)[-1], outcome = names(d)[1])
seed_int <- 496L
set.seed(seed_int)
# GRF learner class
grfcate_learner <- Lrnr_grfcate$new(A = "A", seed = seed_int, num.threads = 1L)
grfcate_fit <- grfcate_learner$train(task)
grfcate_pred <- grfcate_fit$predict(task)
set.seed(seed_int)
grfcate_pkg <- grf::causal_forest(
X = X[, -1], W = X[, 1], Y = Y, seed = seed_int, num.threads = 1L
)
grfcate_pkg_pred_out <- predict(grfcate_pkg)
grfcate_pkg_pred <- as.numeric(grfcate_pkg_pred_out$predictions)
# test equivalence
expect_equal(grfcate_pred, grfcate_pkg_pred)
})
test_that("Lrnr_grfcate continuous A preds match those of grf::causal_forest", {
library(grf)
# Generate some data
set.seed(791)
n <- 500
p <- 10
X <- matrix(rnorm(n * p), n, p)
A <- rnorm(n)
X <- cbind(A, X)
X.test <- matrix(0, 101, p)
X.test[, 1] <- seq(-2, 2, length.out = 101)
Y <- X[, 1] * rnorm(n) + 0.5 * A
d <- data.frame(list(Y = Y, X = X))
names(d)[2] <- "A"
task <- sl3_Task$new(d, covariates = names(d)[-1], outcome = names(d)[1])
seed_int <- 496L
set.seed(seed_int)
# GRF learner class
grfcate_learner <- Lrnr_grfcate$new(A = "A", seed = seed_int, num.threads = 1L)
grfcate_fit <- grfcate_learner$train(task)
grfcate_pred <- grfcate_fit$predict(task)
set.seed(seed_int)
grfcate_pkg <- grf::causal_forest(
X = X[, -1], W = X[, 1], Y = Y, seed = seed_int, num.threads = 1L
)
grfcate_pkg_pred_out <- predict(grfcate_pkg)
grfcate_pkg_pred <- as.numeric(grfcate_pkg_pred_out$predictions)
# test equivalence
expect_equal(grfcate_pred, grfcate_pkg_pred)
})
test_that("Lrnr_grfcate errors when A not in covariates", {
data(mtcars)
task <- sl3_Task$new(
mtcars,
covariates = c("cyl", "disp", "hp", "drat"), outcome = "mpg"
)
grfcate_learner <- Lrnr_grfcate$new(A = "vs", num.threads = 1L)
expect_error(grfcate_learner$train(task))
})
test_that("Lrnr_grfcate errors when prediction task missing covariates", {
data(mtcars)
task <- sl3_Task$new(
mtcars,
covariates = c("cyl", "disp", "hp", "drat", "vs"), outcome = "mpg"
)
task2 <- sl3_Task$new(
mtcars,
covariates = c("cyl", "disp", "hp", "vs")
)
grfcate_learner <- Lrnr_grfcate$new(A = "vs", num.threads = 1L)
grfcate_fit <- grfcate_learner$train(task)
expect_error(grfcate_fit$predict(task2))
})
# test_that("Lrnr_grfcate warns when prediction task has extra covariates", {
# data(mtcars)
# task <- sl3_Task$new(
# mtcars, covariates = c("cyl", "disp", "hp", "vs"), outcome = "mpg"
# )
# task2 <- sl3_Task$new(
# mtcars, covariates = c("cyl", "disp", "hp", "drat", "vs")
# )
# grfcate_learner <- Lrnr_grfcate$new(A = "vs", num.threads = 1L)
# grfcate_fit <- grfcate_learner$train(task)
# expect_warning(preds <- grfcate_fit$predict(task2))
# })
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.