# tests/testthat/test-11-stan4bart.R

# Label the group of tests in this file for testthat's reporter.
context("semiparametric models fit with stan4bart")

# Source the Friedman data helper from the "common" directory of the installed
# bartCause package.
# NOTE(review): generateFriedmanData is presumably defined by this file --
# confirm against common/friedmanData.R.
source(system.file("common", "friedmanData.R", package = "bartCause"))

# NOTE(review): 100 is presumably the number of observations, with ranef and
# causal requesting grouping factors and a treatment/response pair -- confirm
# against the helper's definition.
testData <- generateFriedmanData(100, ranef = TRUE, causal = TRUE)

# Data frame shared by every test below, built from the x, y, z and g.1
# elements of the generated data. The tests refer to columns X1..X10, which
# presumably come from expanding the columns of x -- TODO confirm.
test.df <- with(testData, data.frame(x, y, z, g.1))

test_that("semiparametric models are consistent with each other", {
  # This test only runs when NOT_CRAN is set, which is not documented
  # elsewhere. Enable it from R with
  #   Sys.setenv(NOT_CRAN = "true")
  # or from a shell with
  #   export NOT_CRAN=true
  skip_on_cran()
  skip_if_not_installed("stan4bart")

  # Passes when every element of `values` lies in the closed interval
  # [bounds[1], bounds[2]].
  expect_within <- function(values, bounds) {
    expect_true(all(values >= bounds[1L] & values <= bounds[2L]))
  }

  # Every column of the shared data except the response; these are the columns
  # compared against the treatment model's stored frame.
  trt_columns <- setdiff(names(test.df), "y")
  rng_seed <- 0

  # Model 1: grouping variable supplied through group.by.
  fit_grouped <- bartc(y, z, X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
                       data = test.df,
                       group.by = g.1,
                       seed = rng_seed,
                       verbose = FALSE)
  summary_grouped <- summary(fit_grouped)

  # Model 2: grouping variable supplied as a parametric (1 | g.1) term.
  fit_ranef <- bartc(y, z, X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
                     data = test.df,
                     parametric = (1 | g.1),
                     seed = rng_seed,
                     verbose = FALSE)
  summary_ranef <- summary(fit_ranef)

  # Treatment model: class, generated formula and stored frame.
  expect_true(inherits(fit_ranef$fit.trt, "stan4bartFit"))
  expect_equal(
    fit_ranef$fit.trt$call$formula,
    str2lang("z ~ (1 | g.1) + bart(X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + 
    X9 + X10)")
  )
  expect_equal(fit_ranef$fit.trt$frame[trt_columns], test.df[trt_columns])

  # Response model: class, generated formula and stored frame.
  expect_true(inherits(fit_ranef$fit.rsp, "stan4bartFit"))
  expect_equal(
    fit_ranef$fit.rsp$call$formula,
    str2lang("y ~ z + ps + bart(X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + 
    X9 + X10 + z + ps) + (1 | g.1)")
  )
  expect_equal(fit_ranef$fit.rsp$frame[names(test.df)], test.df)

  # Model 3: X4 is taken out of the confounder expression and passed in the
  # parametric part alongside (1 | g.1).
  fit_mixed <- bartc(y, z, X1 + X2 + X3 + X5 + X6 + X7 + X8 + X9 + X10,
                     data = test.df,
                     parametric = X4 + (1 | g.1),
                     seed = rng_seed,
                     verbose = FALSE)
  summary_mixed <- summary(fit_mixed)

  expect_true(inherits(fit_mixed$fit.trt, "stan4bartFit"))
  expect_equal(
    fit_mixed$fit.trt$call$formula,
    str2lang("z ~ X4 + (1 | g.1) + bart(X1 + X2 + X3 + X5 + X6 + X7 + X8 + 
    X9 + X10)")
  )
  expect_equal(fit_mixed$fit.trt$frame[trt_columns], test.df[trt_columns])

  expect_true(inherits(fit_mixed$fit.rsp, "stan4bartFit"))
  expect_equal(
    fit_mixed$fit.rsp$call$formula,
    str2lang("y ~ z + ps + bart(X1 + X2 + X3 + X5 + X6 + X7 + X8 + 
    X9 + X10 + z + ps) + (X4 + (1 | g.1))")
  )
  expect_equal(fit_mixed$fit.rsp$frame[names(test.df)], test.df)

  # All three fits must report estimates inside the same window ...
  estimate_bounds <- c(4.8, 5.4)
  expect_within(summary_grouped$estimates$estimate, estimate_bounds)
  expect_within(summary_ranef$estimates$estimate, estimate_bounds)
  expect_within(summary_mixed$estimates$estimate, estimate_bounds)

  # ... and standard deviations inside per-model windows (the second model is
  # given a slightly wider upper bound).
  expect_within(summary_grouped$estimates$sd, c(0.7, 0.85))
  expect_within(summary_ranef$estimates$sd, c(0.7, 0.87))
  expect_within(summary_mixed$estimates$sd, c(0.7, 0.85))
})

test_that("semiparametric model does not have constant treatment effect", {
  skip_if_not_installed("stan4bart")

  # Fitted values of a binomial GLM of z on the ten covariates; handed to bartc
  # as method.trt in place of a treatment-model name.
  trt_glm <- glm(z ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
                 data = test.df,
                 family = binomial)
  trt_fitted <- fitted(trt_glm)

  # Short single-chain run.
  fit <- bartc(y, z, X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
               data = test.df,
               parametric = (1 | g.1),
               method.trt = trt_fitted,
               seed = 0,
               iter = 20, warmup = 10, chains = 1,
               verbose = FALSE)

  # The column means of the extracted icate values must vary: a standard
  # deviation above 0.2 is required.
  icate_draws <- extract(fit, type = "icate")
  icate_spread <- sd(colMeans(icate_draws))
  expect_true(icate_spread > 0.2)
})

test_that("semiparametric predict works", {
  skip_if_not_installed("stan4bart")

  # Root-mean-square difference between two equally shaped numeric objects.
  rmse <- function(lhs, rhs) sqrt(mean((lhs - rhs)^2))

  # Small, fast fit; trees are kept via bart_args.
  fit <- bartc(y, z, X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
               data = test.df,
               parametric = (1 | g.1),
               seed = 0,
               iter = 10, warmup = 3, chains = 2,
               bart_args = list(keepTrees = TRUE, n.trees = 7),
               verbose = FALSE)

  fitted_value <- fitted(fit)
  expect_true(is.finite(fitted_value))

  # Predicting on the training data must reproduce the extracted training
  # quantities to within 1e-10 RMSE, for both p.score and icate.
  train_p_score <- extract(fit, type = "p.score")
  test_p_score <- predict(fit, test.df, type = "p.score")
  expect_true(rmse(train_p_score, test_p_score) <= 1e-10)

  train_icate <- extract(fit, type = "icate")
  test_icate <- predict(fit, test.df, type = "icate")
  expect_true(rmse(train_icate, test_icate) <= 1e-10)
})

test_that("refit works with common support", {
  skip_on_cran()
  skip_if_not_installed("stan4bart")

  # Cutoffs passed to the "sd" and "chisq" common support rules, respectively.
  sd_cutoff <- 0.5
  chisq_cutoff <- 0.3

  # Two fits that differ only in the common support rule and its cutoff.
  sd_fit <- bartc(y, z, X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
                  parametric = (1 | g.1),
                  data = test.df,
                  seed = 5,
                  verbose = FALSE,
                  commonSup.rule = "sd",
                  commonSup.cut = sd_cutoff)

  chisq_fit <- bartc(y, z, X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10,
                     parametric = (1 | g.1),
                     data = test.df,
                     seed = 5,
                     verbose = FALSE,
                     commonSup.rule = "chisq",
                     commonSup.cut = chisq_cutoff)

  # Refit each model with the same rule and cutoff it was fit with.
  chisq_refit <- refit(chisq_fit,
                       commonSup.rule = "chisq",
                       commonSup.cut = chisq_cutoff)
  sd_refit <- refit(sd_fit,
                    commonSup.rule = "sd",
                    commonSup.cut = sd_cutoff)

  # The refit must leave commonSup.sub unchanged.
  expect_equal(sd_fit$commonSup.sub, sd_refit$commonSup.sub)
  expect_equal(chisq_fit$commonSup.sub, chisq_refit$commonSup.sub)
})
# vdorie/bartCause documentation built on May 5, 2024, 9:29 a.m.