# tests/testthat/test-03-responseFit.R

context("bartc response fits")

# Load the shared linear-response fixture that provides `testData` for the
# tests below. mustWork = TRUE turns a missing installed file into an
# explicit error; otherwise system.file() returns "" and source("") would
# read from standard input instead of the fixture.
source(system.file("common", "linearData.R", package = "bartCause",
                   mustWork = TRUE))

test_that("bart fit matches manual call", {
  # Fit through bartCause's internal helper under a fixed seed, so the
  # manual dbarts::bart2 call below can replay the same random stream.
  set.seed(22)
  bartcFit <- bartCause:::getBartResponseFit(
    y, z, x, data = testData, estimand = "ate",
    group.by = NULL, commonSup.rule = "none", commonSup.cut = NA,
    n.chains = 1L, n.threads = 1L, n.burn = 3L, n.samples = 13L, n.trees = 7L
  )

  # Hand-built design: treatment column plus covariates for training; the
  # test design is the same rows with the treatment indicator flipped.
  trainMatrix <- with(testData, cbind(z, x))
  testMatrix <- trainMatrix
  testMatrix[, "z"] <- 1 - trainMatrix[, "z"]
  response <- testData$y

  set.seed(22)
  bartFit <- dbarts::bart2(trainMatrix, response, testMatrix,
                           n.chains = 1L, n.threads = 1L, n.burn = 3L,
                           n.samples = 13L, n.trees = 7L, verbose = FALSE)

  expect_equal(bartFit$yhat.train, bartcFit$fit$yhat.train)
  expect_equal(bartFit$yhat.test,  bartcFit$fit$yhat.test)
})

test_that("p.weight fits", {
  set.seed(22)
  testData$w <- 1 + rpois(length(testData$y), 0.5)

  testCall <- quote(bartc(y, z, x, data = testData, method.trt = "glm", method.rsp = "p.weight",
                          n.chains = 1L, n.threads = 1L, n.samples = 13L, n.burn = 3L, n.trees = 7L,
                          verbose = FALSE))

  # Cover every combination of chain count, observation weights, and
  # treatment model. The nesting order (chains, then weights, then method)
  # reproduces the original sequence of fits, so the RNG stream is consumed
  # in the same order. Assigning NULL drops `weights` from the call.
  for (numChains in c(1L, 4L)) {
    testCall$n.chains <- numChains
    for (useWeights in c(FALSE, TRUE)) {
      testCall$weights <- if (useWeights) quote(w) else NULL
      for (trtMethod in c("glm", "bart")) {
        testCall$method.trt <- trtMethod
        expect_is(eval(testCall), "bartcFit")
      }
    }
  }
})

# Switch to the grouped fixture for the remaining tests; the rbart_vi test
# reads its grouping column `g` from `testData`. mustWork = TRUE turns a
# missing installed file into an explicit error instead of source("")
# reading standard input.
source(system.file("common", "groupedData.R", package = "bartCause",
                   mustWork = TRUE))

test_that("rbart_vi fit matches manual call", {
  # Grouped fit through bartCause's internal helper, seeded so that the
  # manual dbarts::rbart_vi call below can replay the same random stream.
  set.seed(22)
  bartcFit <- bartCause:::getBartResponseFit(
    y, z, x, data = testData, estimand = "ate", group.by = g,
    commonSup.rule = "none", commonSup.cut = NA,
    n.chains = 1L, n.threads = 1L, n.burn = 3L, n.samples = 13L, n.trees = 7L
  )

  # Hand-built design: treatment column plus covariates for training; the
  # test design is the same rows with the treatment indicator flipped.
  trainMatrix <- with(testData, cbind(z, x))
  testMatrix <- trainMatrix
  testMatrix[, "z"] <- 1 - trainMatrix[, "z"]
  response <- testData$y

  set.seed(22)
  bartFit <- dbarts::rbart_vi(trainMatrix, response, testMatrix,
                              group.by = testData$g, group.by.test = testData$g,
                              n.chains = 1L, n.threads = 1L, n.burn = 3L,
                              n.samples = 13L, n.trees = 7L, verbose = FALSE)

  expect_equal(bartFit$yhat.train, bartcFit$fit$yhat.train)
  expect_equal(bartFit$yhat.test,  bartcFit$fit$yhat.test)
})

# Cross-validated response fit; disabled until crossvalidation calls have
# more control over run time. skip() reports the test as skipped rather than
# hiding it completely the way an `if (FALSE)` wrapper does.
test_that("xbart fit matches manual call", {
  skip("crossvalidation calls need more control over run time")
  set.seed(22)
  res <- bartCause:::getBartResponseFit(y, z, x, data = testData,
                                        estimand = "ate", group.by = NULL, commonSup.rule = "none", commonSup.cut = NA,
                                        n.chains = 1L, n.threads = 1L, n.burn = 3L, n.samples = 13L, n.trees = 7L,
                                        crossvalidate = TRUE)
  # TODO: compare `res` against a manual dbarts call; the test currently
  # makes no expectation about the result.
})
# vdorie/bartCause documentation built on May 5, 2024, 9:29 a.m.