# tests/testthat/test-09-binary.R

context("binary outcomes")

# Run the shared binary-outcome data script shipped in the installed package's
# "common" directory. The tests below use a `testData` object that is
# presumably created by this script -- TODO confirm.
source(system.file("common", "binaryData.R", package = "bartCause"))

# Seed the RNG, then fit one small model (5 burn-in + 5 kept samples, single
# chain, single thread) that the tests in this file share.
set.seed(22)
bartcFit <- bartc(
  y, z, x,
  data = testData,
  method.rsp = "bart",
  method.trt = "bart",
  verbose = FALSE,
  n.samples = 5L,
  n.burn = 5L,
  n.chains = 1L,
  n.threads = 1L
)

# Refits the treatment and response models by hand with dbarts::bart2() and
# checks that the results agree with what the shared `bartcFit` reports.
test_that("binary outcome model matches manual", {
  # Same seed value as the one set right before `bartcFit` was fitted.
  set.seed(22)

  # Treatment model: map the training draws through pnorm() and take the mean
  # of each column to get one propensity score per column.
  trtFit <- dbarts::bart2(z ~ x, testData, verbose = FALSE,
                          n.samples = 5L, n.burn = 5L, n.chains = 1L, n.threads = 1L)
  propensity <- apply(pnorm(trtFit$yhat.train), 2L, mean)
  expect_equal(propensity, fitted(bartcFit, type = "p.score"))

  # Response-model design: treatment indicator, confounders, propensity score.
  trainDesign <- cbind(z = testData$z, testData$x, ps = propensity)

  # Test design stacks two copies of the data: first with z = 1, then z = 0.
  testDesign <- rbind(cbind(z = 1, testData$x, ps = propensity),
                      cbind(z = 0, testData$x, ps = propensity))
  controlRows <- seq.int(nrow(testData$x) + 1L, nrow(testDesign))

  rspFit <- dbarts::bart2(trainDesign, testData$y, testDesign, verbose = FALSE,
                          n.samples = 5L, n.burn = 5L, n.chains = 1L, n.threads = 1L)

  # The z = 0 half of the test predictions should reproduce the "mu.0" draws.
  manualMu0 <- pnorm(rspFit$yhat.test)[, controlRows]
  expect_equal(extract(bartcFit, type = "mu.0"), manualMu0)
})

# summary() on the shared binary-outcome fit should return an object of class
# "bartcFit.summary".
test_that("summary works for binary outcomes", {
  fitSummary <- summary(bartcFit)
  expect_is(fitSummary, "bartcFit.summary")
})


# Smoke test: bartc() with method.rsp = "tmle" on the binary-outcome data
# should return an object of class "bartcFit".
test_that("binary outcome runs with tmle", {
  # Register the restore of the warning level before touching it, so that an
  # error inside bartc() cannot leave warnings disabled for later tests.
  oldWarn <- getOption("warn")
  on.exit(options(warn = oldWarn), add = TRUE)

  # If the optional tmle package is not installed, silence warnings for the
  # duration of the fit.
  if (!requireNamespace("tmle", quietly = TRUE)) {
    options(warn = -1)
  }

  tmleFit <- bartc(y, z, x, data = testData,
                   method.rsp = "tmle", method.trt = "bart", verbose = FALSE,
                   n.samples = 5L, n.burn = 5L, n.chains = 1L, n.threads = 1L)
  expect_is(tmleFit, "bartcFit")
})
# vdorie/bartCause documentation built on May 5, 2024, 9:29 a.m.