Nothing
context("semiparametric models fit with stan4bart")
source(system.file("common", "friedmanData.R", package = "bartCause"))
testData <- generateFriedmanData(100, ranef = TRUE, causal = TRUE)
test.df <- with(testData, data.frame(x, y, z, g.1))
test_that("semiparametric models are consistent with each other", {
# Because this is not documented, to enable this test execute from R
# Sys.setenv(NOT_CRAN = "true")
# or from shell
# export NOT_CRAN=true
skip_on_cran()
skip_if_not_installed("stan4bart")
seed <- 0
fit1 <- bartc(y, z, X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
group.by = g.1,
seed = seed,
data = test.df,
verbose = FALSE)
summary1 <- summary(fit1)
fit2 <- bartc(y, z, X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
parametric = (1 | g.1),
seed = seed,
data = test.df,
verbose = FALSE)
summary2 <- summary(fit2)
expect_true(inherits(fit2$fit.trt, "stan4bartFit"))
expect_equal(fit2$fit.trt$call$formula, str2lang("z ~ (1 | g.1) + bart(X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 +
X9 + X10)"))
data_frame <- test.df[setdiff(names(test.df), "y")]
fit_frame <- fit2$fit.trt$frame[names(data_frame)]
expect_equal(fit_frame, data_frame)
expect_true(inherits(fit2$fit.rsp, "stan4bartFit"))
expect_equal(fit2$fit.rsp$call$formula, str2lang("y ~ z + ps + bart(X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 +
X9 + X10 + z + ps) + (1 | g.1)"))
expect_equal(fit2$fit.rsp$frame[names(test.df)], test.df)
fit3 <- bartc(y, z, X1 + X2 + X3 + X5 + X6 + X7 + X8 + X9 + X10,
parametric = X4 + (1 | g.1),
seed = seed,
data = test.df,
verbose = FALSE)
summary3 <- summary(fit3)
expect_true(inherits(fit3$fit.trt, "stan4bartFit"))
expect_equal(fit3$fit.trt$call$formula, str2lang("z ~ X4 + (1 | g.1) + bart(X1 + X2 + X3 + X5 + X6 + X7 + X8 +
X9 + X10)"))
data_frame <- test.df[setdiff(names(test.df), "y")]
fit_frame <- fit3$fit.trt$frame[names(data_frame)]
expect_equal(fit_frame, data_frame)
expect_true(inherits(fit3$fit.rsp, "stan4bartFit"))
expect_equal(fit3$fit.rsp$call$formula, str2lang("y ~ z + ps + bart(X1 + X2 + X3 + X5 + X6 + X7 + X8 +
X9 + X10 + z + ps) + (X4 + (1 | g.1))"))
expect_equal(fit3$fit.rsp$frame[names(test.df)], test.df)
expect_in_range <- function(x, r) expect_true(all(x >= r[1L] & x <= r[2L]))
expect_in_range(summary1$estimates$estimate, c(4.8, 5.4))
expect_in_range(summary2$estimates$estimate, c(4.8, 5.4))
expect_in_range(summary3$estimates$estimate, c(4.8, 5.4))
expect_in_range(summary1$estimates$sd, c(0.7, 0.85))
expect_in_range(summary2$estimates$sd, c(0.7, 0.87))
expect_in_range(summary3$estimates$sd, c(0.7, 0.85))
})
test_that("semiparametric model does not have constant treatment effect", {
skip_if_not_installed("stan4bart")
p.score <- fitted(glm(z ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10, data = test.df, family = binomial))
fit <- bartc(y, z, X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
parametric = (1 | g.1),
seed = 0,
data = test.df,
method.trt = p.score,
iter = 20, warmup = 10, chains = 1,
verbose = FALSE)
icate <- extract(fit, type = "icate")
expect_true(sd(colMeans(icate)) > 0.2)
})
test_that("semiparametric predict works", {
skip_if_not_installed("stan4bart")
fit <- bartc(y, z, X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
parametric = (1 | g.1),
seed = 0,
data = test.df,
verbose = FALSE,
iter = 10, warmup = 3, chains = 2,
bart_args = list(keepTrees = TRUE, n.trees = 7))
expect_true(is.finite(fitted(fit)))
p.score.train <- extract(fit, type = "p.score")
p.score.test <- predict(fit, test.df, type = "p.score")
expect_true(sqrt(mean((p.score.train - p.score.test)^2)) <= 1e-10)
icate.train <- extract(fit, type = "icate")
icate.test <- predict(fit, test.df, type = "icate")
expect_true(sqrt(mean((icate.train - icate.test)^2)) <= 1e-10)
})
test_that("refit works with common support", {
skip_on_cran()
skip_if_not_installed("stan4bart")
sdcut <- 0.5
chisqcut <- 0.3
fit_sd <- bartc(y, z, X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
parametric = (1 | g.1), data = test.df,
verbose = FALSE, seed = 5,
commonSup.rule = "sd", commonSup.cut = sdcut)
fit_chisq <- bartc(y, z, X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
parametric = (1 | g.1), data = test.df,
verbose = FALSE, seed = 5,
commonSup.rule = "chisq", commonSup.cut = chisqcut)
refit_chisq <- refit(fit_chisq, commonSup.rule = "chisq", commonSup.cut = chisqcut)
refit_sd <- refit(fit_sd, commonSup.rule = "sd", commonSup.cut = sdcut)
expect_equal(fit_sd$commonSup.sub, refit_sd$commonSup.sub)
expect_equal(fit_chisq$commonSup.sub, refit_chisq$commonSup.sub)
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.