# tests/testthat/test-05-generics.R

context("generic functions")

# Shared fixture script shipped with the package; presumably it defines the
# `testData` object used throughout this file - TODO confirm.
source(system.file("common", "groupedData.R", package = "bartCause"))

# Reference fit exercised by the tests below. The seed is fixed here and again
# before the manual bart2 fit so that the two sets of posterior draws can be
# compared directly.
set.seed(22)
fit <- bartc(y, z, x, data = testData, method.trt = "glm", n.samples = 50L,
             group.by = g, use.ranef = FALSE, group.effects = TRUE,
             n.burn = 25L, n.chains = 4L, n.threads = 1L, verbose = FALSE)

# Propensity score recomputed by hand with a binomial glm on the confounders
# and the grouping variable; later compared against fitted(fit, "p.score").
p.score <- fitted(glm(z ~ x + g, family = stats::binomial, data = testData))

# Manual replication of the response model: treatment, confounders, propensity
# score and group are bound into one design matrix and passed to bart2 directly.
set.seed(22)
x.train <- cbind(z = testData$z, testData$x, p.score, testData$g)
# The test matrix is the training matrix with the treatment column flipped, so
# predictions on it are for the treatment each unit did not receive.
x.test  <- x.train; x.test[,"z"] <- 1 - x.test[,"z"]
bartFit <- dbarts::bart2(x.train, testData$y, x.test, n.samples = 50L, n.burn = 25L,
                         n.chains = 4L, n.threads = 1L, verbose = FALSE)

# Blend observed and counterfactual predictions into one potential outcome.
#
# For every observation i the result takes the value from `obs` where
# trt[i] == 1 and from `cf` where trt[i] == 0 (a linear blend for other
# weights).
#
# obs, cf: arrays of identical shape whose LAST dimension indexes observations;
#          either 3-d arrays or matrices.
# trt:     numeric vector with one entry per observation.
#
# Returns an array with the same shape as `obs`.
obsCfToTrtCtl <- function(obs, cf, trt) {
  hasChainDim <- length(dim(obs)) > 2L

  if (!hasChainDim) {
    # Matrix case: transpose so observations run down the rows and `trt`
    # recycles along them, then transpose back.
    return(t(t(obs) * trt + t(cf) * (1 - trt)))
  }

  # 3-d case: rotate the observation dimension to the front so that `trt`
  # recycles along it, blend, then rotate back to the original layout.
  obsFirst <- c(3L, 1L, 2L)
  obsLast  <- c(2L, 3L, 1L)

  fromObs <- aperm(obs, obsFirst) * trt
  fromCf  <- aperm(cf, obsFirst) * (1 - trt)

  aperm(fromObs + fromCf, obsLast)
}

# Posterior draws from the manual bart2 fit: training-set predictions are for
# the observed treatment, test-set predictions for the flipped treatment.
samples.obs <- bartFit[["yhat.train"]]
samples.cf  <- bartFit[["yhat.test"]]

# Rearranged into potential outcomes under treatment (mu.1) and control (mu.0).
samples.mu.1 <- obsCfToTrtCtl(samples.obs, samples.cf, testData$z)
samples.mu.0 <- obsCfToTrtCtl(samples.obs, samples.cf, 1 - testData$z)

test_that("combine chains works as expected", {
  # extract() with its default arguments must return the stored per-chain
  # array with its first two dimensions swapped and then flattened.
  mu.obs <- extract(fit, "mu.obs")
  expect_equal(as.vector(mu.obs), as.vector(aperm(fit$mu.hat.obs, c(2, 1, 3))))

  # Same check for the residual standard deviation draws, which are stored as
  # a matrix on the response fit.
  sigma <- extract(fit, "sigma")
  expect_equal(sigma, as.vector(t(fit$fit.rsp$sigma)))
})

test_that("fitted matches manual fit", {
  # Mean over all draws for each observation (observations are the third
  # dimension of the manual sample arrays).
  obsMeans <- function(samples) apply(samples, 3L, mean)

  fit.cate   <- fitted(fit, "cate")
  fit.mu.1   <- fitted(fit, "mu.1")
  fit.mu.0   <- fitted(fit, "mu.0")
  fit.icate  <- fitted(fit, "icate")
  fit.mu.obs <- fitted(fit, "mu.obs")

  # potential outcomes
  expect_equal(fit.mu.0, obsMeans(samples.mu.0))
  expect_equal(fit.mu.1, obsMeans(samples.mu.1))

  # individual effects, observed-treatment predictions and propensity score
  expect_equal(fit.icate, obsMeans(samples.mu.1 - samples.mu.0))
  expect_equal(fit.mu.obs, obsMeans(samples.obs))
  expect_equal(fitted(fit, "p.score"), p.score)

  # one conditional average effect per group, each equal to the average of the
  # individual effects of that group's members
  groupLevels <- levels(as.factor(testData$g))
  expect_equal(length(fit.cate), length(groupLevels))
  for (level in groupLevels) {
    inGroup <- testData$g == level
    expect_equal(fit.cate[[level]], mean(fit.icate[inGroup]))
  }
})

test_that("extract matches manual fit", {
  ## combined-chain layout first: the first two dimensions of the manual array
  ## are swapped and collapsed into a single row dimension
  manualDims <- dim(samples.mu.0)
  combined.mu.0 <- extract(fit, "mu.0")
  flattened.mu.0 <- matrix(aperm(samples.mu.0, c(2L, 1L, 3L)),
                           manualDims[1L] * manualDims[2L],
                           manualDims[3L])
  expect_equal(combined.mu.0, flattened.mu.0)

  ## then every quantity with the chains kept separate
  chain.mu.0   <- extract(fit, "mu.0",   combineChains = FALSE)
  chain.mu.1   <- extract(fit, "mu.1",   combineChains = FALSE)
  chain.cate   <- extract(fit, "cate",   combineChains = FALSE)
  chain.icate  <- extract(fit, "icate",  combineChains = FALSE)
  chain.mu.obs <- extract(fit, "mu.obs", combineChains = FALSE)

  expect_equal(chain.mu.0, samples.mu.0)
  expect_equal(chain.mu.1, samples.mu.1)

  expect_equal(chain.icate, samples.mu.1 - samples.mu.0)
  expect_equal(chain.mu.obs, samples.obs)

  ## group-level effects are the individual effects averaged over the members
  ## of the group, separately for every cell of the first two dimensions
  groupLevels <- levels(as.factor(testData$g))
  expect_equal(length(chain.cate), length(groupLevels))
  for (level in groupLevels) {
    inGroup <- testData$g == level
    expect_equal(chain.cate[[level]],
                 apply(chain.icate[,,inGroup], c(1L, 2L), mean))
  }
})

test_that("ppd-based estimates match manual", {
  # +1 where z == 1 and -1 where z == 0, so that (observed - counterfactual)
  # is always oriented as treated minus control
  trtSign <- 2 * testData$z - 1

  # individual treatment effects
  manual.ite <- as.numeric((testData$y - fitted(fit, "y.cf")) * trtSign)
  expect_equal(manual.ite, fitted(fit, "ite"))

  # the overall sample average effect equals the per-group effects weighted by
  # each group's share of the observations
  manual.sate <- mean((testData$y - fitted(fit, "y.cf")) * trtSign)
  groupShare <- table(testData$g) / length(testData$y)
  expect_equal(manual.sate, sum(fitted(fit, "sate") * groupShare))

  # the potential outcome under the received treatment is the observed response
  isTreated <- testData$z == 1
  isControl <- testData$z == 0
  expect_equal(testData$y[isTreated], fitted(fit, "y.1")[isTreated])
  expect_equal(testData$y[isControl], fitted(fit, "y.0")[isControl])
})

test_that("summary object contains correct information", {
  fitSummary <- summary(fit)

  # The call the summary is expected to carry: the top-of-file bartc() call
  # with every argument spelled out by name.
  expectedCall <- parse(text = 'bartc(response = y, treatment = z, confounders = x, data = testData, method.trt = "glm", group.by = g, group.effects = TRUE, use.ranef = FALSE, n.samples = 50L,
             n.burn = 25L, n.chains = 4L, n.threads = 1L, verbose = FALSE)')[[1L]]
  actualCall <- fitSummary$call

  # Same length, same set of argument names, and equal values once both calls
  # are put into the same (alphabetical) argument order.
  expect_true(length(expectedCall) == length(actualCall) &&
              setequal(names(actualCall), names(expectedCall)) &&
              all(actualCall[order(names(actualCall))] == expectedCall[order(names(expectedCall))]))

  expect_equal(fitSummary$method.rsp, "bart")
  expect_equal(fitSummary$method.trt, "glm")

  # Interval settings should be the defaults declared by the summary method.
  summaryDefaults <- formals(bartCause:::summary.bartcFit)
  expect_equal(fitSummary$ci.info$ci.style, eval(summaryDefaults$ci.style)[1L])
  expect_equal(fitSummary$ci.info$ci.level, eval(summaryDefaults$ci.level))

  # The last entries of n.obs / estimates are compared with the whole sample;
  # the leading estimates are compared with the fitted group-level effects.
  expect_equal(tail(fitSummary$n.obs, 1L), length(testData$y))
  expect_equal(fitSummary$n.samples, 50L)
  expect_equal(fitSummary$n.chains, 4L)
  expect_equal(head(fitSummary$estimates$estimate, -1L), unname(fitted(fit, "cate")))
})

# Checks that summary() and fitted() on a propensity-weighted ATT fit
# (method.rsp = "p.weight") agree with values recomputed by hand from the
# extracted samples, one group at a time.
test_that("generics work for p.weights", {
  pfit <- bartc(y, z, x, data = testData, method.trt = "bart", method.rsp = "p.weight", estimand = "att",
                group.by = g, group.effects = TRUE, n.chains = 3L,
                n.samples = 7L, n.burn = 3L, n.threads = 1L, verbose = FALSE)
  pfit.sum <- summary(pfit)
  
  p.weights  <- extract(pfit, "p.weights", sample = "all")
  
  # row indices of the observations belonging to each group
  groups <- levels(as.factor(testData$g))
  g.sel <- lapply(groups, function(group) which(testData$g == group))
  
  # internal helper; presumably clamps values into the given (lower, upper)
  # range - confirm against the package source
  boundValues <- bartCause:::boundValues
  
  ## match internal bounding
  yBounds <- c(.005, .995)
  p.scoreBounds <- c(0.025, 0.975)
  
  # warnings because "mu.0" isn't meaningful if using p-weights to compute ATT
  mu.0 <- suppressWarnings(extract(pfit, type = "mu.0", sample = "all"))
  mu.1 <- suppressWarnings(extract(pfit, type = "mu.1", sample = "all"))
  p.score <- extract(pfit, sample = "all", type = "p.score")
  
  for (j in seq_along(groups)) {
    # range of the observed response within the group, used to rescale
    # predictions and to scale the estimate back at the end
    m <- min(testData$y[g.sel[[j]]]); M <- max(testData$y[g.sel[[j]]])
    
    # bound predictions to [m, M], rescale by that range, bound again by yBounds
    mu.hat.0 <- boundValues((boundValues(mu.0[,g.sel[[j]]], c(m, M)) - m) / (M - m), yBounds)
    mu.hat.1 <- boundValues((boundValues(mu.1[,g.sel[[j]]], c(m, M)) - m) / (M - m), yBounds)
    
    icate <- mu.hat.1 - mu.hat.0

    # replicate internal with:
    # temp <- bartCause:::getPWeightEstimates(testData$y[g.sel[[j]]], testData$z[g.sel[[j]]], NULL, "att", mu.hat.0, mu.hat.1,
    #                            extract(pfit, sample = "all", type = "p.score")[g.sel[[j]],], yBounds, p.scoreBounds)
    # f <- bartCause:::getPWeightFunction("att", NULL, icate, boundValues(p.weights[g.sel[[j]],], p.scoreBounds))
    # mean(f(testData$z[g.sel[[j]]], NULL, icate, p.score[g.sel[[j]],])) * (M - m)
    
    # individual effects weighted by the bounded propensity score, averaged
    # over the first dimension and then over the group's observations, and
    # divided by the proportion treated in the group
    est.unscaled <- mean(apply((icate * boundValues(p.score[,g.sel[[j]]], p.scoreBounds)), 2L, mean)) / mean(testData$z[g.sel[[j]]])
    # NOTE(review): `$est` relies on `$` partial matching; other tests in this
    # file spell the component out as `$estimates` - confirm and consider
    # using the full name.
    expect_equal(pfit.sum$est$estimate[j], est.unscaled * (M - m))
  }
  
  # fitted() should equal the extracted weights averaged over every dimension
  # except the last
  expect_equal(apply(p.weights, length(dim(p.weights)), mean), fitted(pfit, "p.weights", sample = "all"))
})

test_that("summary works with different styles", {
  summaryClass <- "bartcFit.summary"

  # every interval style, then the ppd style for the population estimate
  for (ciStyle in c("norm", "quant", "hpd")) {
    expect_is(summary(fit, ci.style = ciStyle), summaryClass)
  }
  expect_is(summary(fit, pate.style = "ppd"), summaryClass)

  # A fit without weights and a fit with all weights equal to one, started
  # from the same seed, must produce the same estimate and standard deviation.
  set.seed(22)
  plainFit <- bartc(y, z, x, data = testData, method.trt = "glm", verbose = FALSE,
                    n.chains = 2L, n.threads = 1L, n.burn = 0L, n.samples = 7L, n.trees = 13L)
  plainSummary <- summary(plainFit, pate.style = "var.exp")
  expect_is(plainSummary, summaryClass)

  set.seed(22)
  unitWeightFit <- bartc(y, z, x, data = testData, method.trt = "glm", verbose = FALSE,
                         weights = rep(1, length(testData$y)),
                         n.chains = 2L, n.threads = 1L, n.burn = 0L, n.samples = 7L, n.trees = 13L)
  unitWeightSummary <- summary(unitWeightFit, pate.style = "var.exp")
  expect_is(unitWeightSummary, summaryClass)

  expect_equal(plainSummary$estimate$est, unitWeightSummary$estimate$est)
  expect_equal(plainSummary$estimate$sd, unitWeightSummary$estimate$sd)

  # Decompose the variance of all individual-effect draws into the variance of
  # the per-draw averages plus the average per-draw variance across
  # observations.
  icates <- extract(plainFit, "icate")
  drawCount <- nrow(icates)
  obsCount <- ncol(icates)

  totalVar <- var(as.vector(icates))
  varOfMeans <- var(extract(plainFit, "cate"))
  drawMeans <- apply(icates, 1, mean)
  squaredDev <- (icates - drawMeans)^2
  meanOfVars <- mean(apply(squaredDev, 1, sum) / (obsCount - 1))

  # the "var.exp" standard deviation is the square root of the two pieces
  expect_equal(sqrt(varOfMeans + meanOfVars), plainSummary$estimate$sd)

  # and, after undoing the degrees-of-freedom corrections, the pieces add up
  # to the total variance
  pooledVar <- varOfMeans * (drawCount - 1) / drawCount + meanOfVars * (obsCount - 1) / obsCount
  expect_equal(totalVar, pooledVar * obsCount * drawCount / (obsCount * drawCount - 1))
})

test_that("summary works with different styles for method tmle", {
  skip_on_cran()

  # Warnings are switched off while fitting when the optional tmle package is
  # not installed (presumably bartc warns in that case - TODO confirm).
  oldWarn <- getOption("warn")
  if (!requireNamespace("tmle", quietly = TRUE))
    options(warn = -1)

  # Restore the option in `finally` so that a failing bartc() call cannot leak
  # warn = -1 into the tests that run afterwards.
  fit <- tryCatch(
    bartc(y, z, x, data = testData, method.trt = "glm", method.rsp = "tmle", verbose = FALSE,
          group.by = g, group.effects = TRUE, use.ranef = FALSE,
          n.chains = 2L, n.threads = 1L, n.burn = 0L, n.samples = 7L, n.trees = 13L),
    finally = options(warn = oldWarn))

  expect_is(summary(fit), "bartcFit.summary")
  expect_is(summary(fit, pate.style = "ppd"), "bartcFit.summary")
})

test_that("summary works with att/atc", {
  # Small ATT fit; only checks that summary() returns an object of the right
  # class for every target.
  attFit <- bartc(y, z, x, data = testData, estimand = "att",
                  method.trt = "bart", method.rsp = "bart", verbose = FALSE,
                  n.chains = 2L, n.threads = 1L, n.burn = 0L, n.samples = 7L, n.trees = 13L)

  for (target in c("pate", "sate", "cate")) {
    expect_is(summary(attFit, target), "bartcFit.summary")
  }
})

test_that("summary gives consistent answers with grouped data", {
  inGroupFit <- bartc(y, z, x, data = testData, estimand = "ate",
                      group.by = g, group.effects = TRUE,
                      method.trt = "bart", method.rsp = "bart", verbose = FALSE,
                      n.chains = 2L, n.threads = 1L, n.burn = 0L, n.samples = 7L, n.trees = 13L)

  cateSummary <- summary(inGroupFit, target = "cate")
  sateSummary <- summary(inGroupFit, target = "sate")
  pateSummary <- summary(inGroupFit, target = "pate")

  # test that sub group estimates are actually subgroup estimates: average the
  # individual effects over each group's members, draw by draw
  groupIds <- unique(testData$g)
  samples.icate <- extract(inGroupFit, "icate", "all")
  samples.gcate <- setNames(
    lapply(groupIds, function(id) rowMeans(samples.icate[,testData$g == id])),
    groupIds)

  groupRows <- cateSummary$estimates[names(samples.gcate),]
  expect_equal(groupRows$estimate, unname(vapply(samples.gcate, mean, numeric(1L))))
  expect_equal(groupRows$sd, unname(vapply(samples.gcate, sd, numeric(1L))))

  # the "total" row pools every observation
  totalRow <- cateSummary$estimates["total",]
  expect_equal(totalRow$estimate, mean(samples.icate))
  expect_equal(totalRow$sd, sd(rowMeans(samples.icate)))

  # one row per group plus the total, and the groups do not all share one value
  expect_equal(nrow(cateSummary$estimates), length(groupIds) + 1L)
  expect_true(length(unique(sateSummary$estimates$estimate)) > 1L)

  # cate and pate share point estimates, but pate must be strictly less certain
  expect_equal(cateSummary$estimates$estimate, pateSummary$estimates$estimate)
  expect_true(all(pateSummary$estimates$sd > cateSummary$estimates$sd))

  # summary rows (minus the total) agree with the extracted group samples
  expect_equal(unname(sapply(extract(inGroupFit, "cate"), mean)),
               head(cateSummary$estimates$estimate, -1L))
})

# Fits a model with the "sd" common-support rule and checks that summary() and
# extract() agree with quantities recomputed by hand on the observations
# selected by fit$commonSup.sub.
test_that("common support cutoffs are being applied consistently", {
  n.chains <- 2L
  n.samples <- 7L
  n.obs <- length(testData$y)
  fit <- bartc(y, z, x, data = testData, estimand = "ate",
               method.trt = "bart", method.rsp = "bart", verbose = FALSE,
               commonSup.rule = "sd", seed = 5,
               n.chains = n.chains, n.threads = 1L, n.burn = 0L, n.samples = n.samples,
               n.trees = 13L)
  
  sum.cate <- summary(fit, target = "cate")
  sum.sate <- summary(fit, target = "sate")
  
  # individual effects restricted to the observations kept by the rule
  icates <- extract(fit, "icate", "all")[,fit$commonSup.sub]
  y.obs <- as.vector(testData$y)
  
  
  # Swap the global RNG state for the one stored on the fit so that the rnorm()
  # call below starts from fit$seed; the previous state is restored afterwards.
  # Presumably this mirrors how the package draws its predictive noise - the
  # y.cf comparisons below depend on it.
  oldSeed <- .GlobalEnv$.Random.seed
  .GlobalEnv$.Random.seed <- fit$seed
  
  # observed response minus counterfactual mean, signed +1 for z == 1 and -1
  # for z == 0, then restricted to common support
  mu.cf <- extract(fit, "mu.cf", combineChains = TRUE)
  iscates <- t((y.obs - t(mu.cf)) * (2 * testData$z - 1))
  iscates <- iscates[,fit$commonSup.sub]
  
  # reshape the combined matrix back to chains x samples x observations and
  # confirm it equals the uncombined extract
  mu.cf <- aperm(array(mu.cf, c(n.samples, n.chains, n.obs)), c(2L, 1L, 3L))
  expect_equal(mu.cf, extract(fit, "mu.cf", combineChains = FALSE))
  
  # sigma draws repeated n.samples times; rnorm() then recycles them over every
  # element of mu.cf
  sigma <- rep(extract(fit, "sigma", combineChains = FALSE), times = n.samples)
  
  epsilon <- rnorm(prod(dim(mu.cf)), 0, sigma)
  
  .GlobalEnv$.Random.seed <- oldSeed
  
  # counterfactual predictive draws = mean + noise, checked in both layouts
  y.cf <- mu.cf + epsilon
  expect_equal(y.cf, extract(fit, "y.cf", combineChains = FALSE))
  y.cf <- matrix(aperm(y.cf, c(2L, 1L, 3L)), nrow = n.samples * n.chains)
  expect_equal(y.cf, extract(fit, "y.cf", combineChains = TRUE))
  
  # individual treatment effects from the predictive draws, on common support
  ites <- t((y.obs - t(y.cf)) * ifelse(testData$z == 1, 1, -1))
  ites <- ites[,fit$commonSup.sub]
  
  expect_equal(sd(apply(ites, 1, mean)), sd(extract(fit, "sate")))
  
  # summary estimates must be finite and match the restricted manual values
  expect_true(!is.nan(sum.cate$estimates$estimate) && is.finite(sum.cate$estimates$estimate))
  expect_true(!is.nan(sum.sate$estimates$estimate) && is.finite(sum.sate$estimates$estimate))
  expect_equal(sum.cate$estimates$estimate, mean(icates))
  expect_equal(sum.cate$estimates$sd, sd(apply(icates, 1, mean)))
  expect_equal(sum.sate$estimates$estimate, mean(iscates))
  # sate sd: variance of the per-draw averages plus the average sigma^2 divided
  # by the number of observations on common support
  expect_equal(sum.sate$estimates$sd,
               sqrt(var(rowMeans(iscates)) + mean(extract(fit, "sigma")^2) / sum(fit$commonSup.sub)))
})

# Second shared fixture script shipped with the package; presumably it defines
# generateFriedmanData() used by the test below - TODO confirm.
source(system.file("common", "friedmanData.R", package = "bartCause"))

test_that("sate summary is correct quantity", {
  friedman <- generateFriedmanData(n = 100, causal = TRUE)

  # NOTE(review): every other fit in this file passes `n.samples`; `samples`
  # here may be a typo that is forwarded through `...` - confirm before changing.
  friedmanFit <- bartc(y, z, x, data = friedman, verbose = FALSE, samples = 50L,
                       n.burn = 25L, n.chains = 4L, n.threads = 1L, seed = 0)

  sateSummary <- summary(friedmanFit, "sate")
  samples.sate <- extract(friedmanFit, "sate")

  # The summary's point estimate and standard deviation should be close to the
  # mean and standard deviation of the extracted sate samples.
  estimateGap <- abs(sateSummary$estimates$estimate - mean(samples.sate))
  sdGap <- abs(sateSummary$estimates$sd - sd(samples.sate))

  expect_true(estimateGap < 1e-1)
  expect_true(sdGap < 1e-1)
})
# vdorie/bartCause documentation built on May 5, 2024, 9:29 a.m.