context("bartc response fits")
# Shared fixture; presumably defines `testData` (with elements y, z, x) used by
# the tests below -- TODO confirm against linearData.R.
source(system.file("common", "linearData.R", package = "bartCause"))

# The response model built by getBartResponseFit() must be identical to a
# direct dbarts::bart2() fit on (z, x) that also predicts at the flipped
# treatment, given the same seed and sampler settings.
test_that("bart fit matches manual call", {
  set.seed(22)
  bartcFit <- bartCause:::getBartResponseFit(
    y, z, x, data = testData, estimand = "ate",
    group.by = NULL, commonSup.rule = "none", commonSup.cut = NA,
    n.chains = 1L, n.threads = 1L, n.burn = 3L, n.samples = 13L, n.trees = 7L
  )

  # Training design: treatment column "z" first, then the confounders.
  x.train <- with(testData, cbind(z, x))
  # Test design: same rows with the treatment flipped (1 - z; assumes z is
  # coded 0/1).
  x.test <- x.train
  x.test[, "z"] <- 1 - x.test[, "z"]
  y <- testData$y

  # Re-seed so the manual fit consumes the same random stream as bartcFit.
  set.seed(22)
  bartFit <- dbarts::bart2(
    x.train, y, x.test,
    n.chains = 1L, n.threads = 1L, n.burn = 3L, n.samples = 13L, n.trees = 7L,
    verbose = FALSE
  )
  expect_equal(bartFit$yhat.train, bartcFit$fit$yhat.train)
  expect_equal(bartFit$yhat.test, bartcFit$fit$yhat.test)
})
# bartc() with the "p.weight" response method should return a "bartcFit" for
# every combination of chain count (1, 4), case weights (none, w) and treatment
# model ("glm", "bart").
test_that("p.weight fits", {
  set.seed(22)
  # Integer case weights >= 1, added to the test-local copy of testData.
  testData$w <- 1 + rpois(length(testData$y), 0.5)
  testCall <- quote(bartc(y, z, x, data = testData, method.trt = "glm", method.rsp = "p.weight",
                          n.chains = 1L, n.threads = 1L, n.samples = 13L, n.burn = 3L, n.trees = 7L,
                          verbose = FALSE))

  # Loop order matches the original sequence of fits so the RNG stream is
  # consumed identically: for each chain count, unweighted then weighted, each
  # with a glm then a bart treatment model. Assigning NULL to `weights` drops
  # the argument from the call.
  for (nChains in c(1L, 4L)) {
    testCall$n.chains <- nChains
    for (weightExpr in list(NULL, quote(w))) {
      testCall$weights <- weightExpr
      for (trtMethod in c("glm", "bart")) {
        testCall$method.trt <- trtMethod
        expect_s3_class(eval(testCall), "bartcFit")
      }
    }
  }
})
# Shared fixture for the grouped tests; presumably redefines `testData` with a
# grouping variable `g` (used below) -- TODO confirm against groupedData.R.
source(system.file("common", "groupedData.R", package = "bartCause"))

# With group.by set, the response model built by getBartResponseFit() must be
# identical to a direct dbarts::rbart_vi() fit on (z, x), grouped by
# testData$g, that also predicts at the flipped treatment, given the same seed
# and sampler settings.
test_that("rbart_vi fit matches manual call", {
  set.seed(22)
  bartcFit <- bartCause:::getBartResponseFit(
    y, z, x, data = testData, estimand = "ate",
    group.by = g, commonSup.rule = "none", commonSup.cut = NA,
    n.chains = 1L, n.threads = 1L, n.burn = 3L, n.samples = 13L, n.trees = 7L
  )

  # Training design: treatment column "z" first, then the confounders.
  x.train <- with(testData, cbind(z, x))
  # Test design: same rows with the treatment flipped (1 - z; assumes z is
  # coded 0/1). The groups are unchanged, so group.by.test reuses testData$g.
  x.test <- x.train
  x.test[, "z"] <- 1 - x.test[, "z"]
  y <- testData$y

  # Re-seed so the manual fit consumes the same random stream as bartcFit.
  set.seed(22)
  bartFit <- dbarts::rbart_vi(
    x.train, y, x.test, group.by = testData$g, group.by.test = testData$g,
    n.chains = 1L, n.threads = 1L, n.burn = 3L, n.samples = 13L, n.trees = 7L,
    verbose = FALSE
  )
  expect_equal(bartFit$yhat.train, bartcFit$fit$yhat.train)
  expect_equal(bartFit$yhat.test, bartcFit$fit$yhat.test)
})
# Disabled until crossvalidation calls have more control over run time.
# skip() (rather than `if (FALSE)`) keeps the test visible as skipped in the
# test report instead of silently dropping it.
# NOTE(review): the body below makes no assertions yet; it needs a comparison
# against a manual call before it can be re-enabled meaningfully.
test_that("xbart fit matches manual call", {
  skip("crossvalidation calls need more control over run time")
  set.seed(22)
  res <- bartCause:::getBartResponseFit(
    y, z, x, data = testData,
    estimand = "ate", group.by = NULL, commonSup.rule = "none", commonSup.cut = NA,
    n.chains = 1L, n.threads = 1L, n.burn = 3L, n.samples = 13L, n.trees = 7L,
    crossvalidate = TRUE
  )
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.