context("generic functions")

# Load the shared grouped-data fixture script shipped with the package.
# NOTE(review): presumably this is what defines `testData` - confirm.
source(system.file("common", "groupedData.R", package = "bartCause"))

# Fit under test: glm treatment model, grouped by `g`, with
# `use.ranef = FALSE` and `group.effects = TRUE`.
set.seed(22)
fit <- bartc(
  y, z, x,
  data = testData, method.trt = "glm", n.samples = 50L,
  group.by = g, use.ranef = FALSE, group.effects = TRUE,
  n.burn = 25L, n.chains = 4L, n.threads = 1L, verbose = FALSE
)

# Hand-computed propensity score: fitted values of a binomial glm of z on
# x and g. Compared with fitted(fit, "p.score") in the tests below.
p.score <- fitted(glm(z ~ x + g, family = stats::binomial, data = testData))

# Re-seed with the same value and call dbarts::bart2 directly. The test
# matrix is the training matrix with the "z" column flipped (1 - z).
# The cbind() arguments are left as written because they determine the
# column names of x.train.
set.seed(22)
x.train <- cbind(z = testData$z, testData$x, p.score, testData$g)
x.test <- x.train
x.test[, "z"] <- 1 - x.train[, "z"]
bartFit <- dbarts::bart2(
  x.train, testData$y, x.test,
  n.samples = 50L, n.burn = 25L,
  n.chains = 4L, n.threads = 1L, verbose = FALSE
)
# Blend two equally shaped arrays element-wise along their last dimension.
#
# For each index i of the last dimension the result is
#   obs[.., i] * trt[i] + cf[.., i] * (1 - trt[i]),
# so with a 0/1 `trt` it takes the `obs` slice where trt is 1 and the `cf`
# slice where trt is 0.
#
# obs, cf: matrices, or arrays with more than two dimensions (the permutation
#          used below is written for exactly three).
# trt:     numeric vector recycled along the last dimension of obs / cf.
# Returns an array with the same dimension order as `obs`.
obsCfToTrtCtl <- function(obs, cf, trt) {
  moreThanTwoDims <- length(dim(obs)) > 2L

  if (!moreThanTwoDims) {
    # Transpose so that `trt` recycles down the rows, then transpose back.
    blended <- t(obs) * trt + t(cf) * (1 - trt)
    return(t(blended))
  }

  # Bring the last dimension to the front so `trt` recycles along it,
  # combine, then restore the original dimension order.
  lastToFront <- c(3L, 1L, 2L)
  frontToLast <- c(2L, 3L, 1L)
  blended <- aperm(obs, lastToFront) * trt + aperm(cf, lastToFront) * (1 - trt)
  aperm(blended, frontToLast)
}
# Manual reference samples taken from the direct bart2 fit: train
# predictions are used as the "observed" samples and test predictions
# (the test matrix has z flipped) as the "counterfactual" samples.
samples.obs <- bartFit$yhat.train
samples.cf <- bartFit$yhat.test

# mu.1 takes the observed sample where z is 1 and the counterfactual one
# otherwise; mu.0 is the mirror image (weights 1 - z).
samples.mu.0 <- obsCfToTrtCtl(obs = samples.obs, cf = samples.cf, trt = 1 - testData$z)
samples.mu.1 <- obsCfToTrtCtl(obs = samples.obs, cf = samples.cf, trt = testData$z)
test_that("combine chains works as expected", {
  # Looked up but not called below; the lookup itself errors if the internal
  # function is missing from the bartCause namespace.
  combineChains <- bartCause:::combineChains

  # Combined mu.obs should flatten to the same vector as the stored array
  # with its first two dimensions swapped.
  combined.mu.obs <- extract(fit, "mu.obs")
  expected.mu.obs <- as.vector(aperm(fit$mu.hat.obs, c(2, 1, 3)))
  expect_equal(as.vector(combined.mu.obs), expected.mu.obs)

  # Combined sigma should equal the transposed stored sigma, flattened.
  combined.sigma <- extract(fit, "sigma")
  expected.sigma <- as.vector(t(fit$fit.rsp$sigma))
  expect_equal(combined.sigma, expected.sigma)
})
test_that("fitted matches manual fit", {
  # Mean over the first two dimensions for each index of the third.
  meanOverThirdDim <- function(samples) apply(samples, 3L, mean)

  fitted.cate <- fitted(fit, "cate")
  fitted.mu.1 <- fitted(fit, "mu.1")
  fitted.mu.0 <- fitted(fit, "mu.0")
  fitted.icate <- fitted(fit, "icate")
  fitted.mu.obs <- fitted(fit, "mu.obs")

  # Per-observation quantities against the manual bart2 samples.
  expect_equal(fitted.mu.0, meanOverThirdDim(samples.mu.0))
  expect_equal(fitted.mu.1, meanOverThirdDim(samples.mu.1))
  expect_equal(fitted.icate, meanOverThirdDim(samples.mu.1 - samples.mu.0))
  expect_equal(fitted.mu.obs, meanOverThirdDim(samples.obs))

  # Propensity score against the hand-fitted glm.
  expect_equal(fitted(fit, "p.score"), p.score)

  # One cate entry per group, each equal to the mean icate within the group.
  groupLevels <- levels(as.factor(testData$g))
  expect_equal(length(fitted.cate), length(groupLevels))
  for (level in groupLevels) {
    inGroup <- testData$g == level
    expect_equal(fitted.cate[[as.character(level)]], mean(fitted.icate[inGroup]))
  }
})
test_that("extract matches manual fit", {
  ## first that combine chains works
  # The combined result should be the manual array with its first two
  # dimensions swapped and then collapsed into the rows of a matrix.
  combined.mu.0 <- extract(fit, "mu.0")
  manualDims <- dim(samples.mu.0)
  stacked.mu.0 <- matrix(
    aperm(samples.mu.0, c(2L, 1L, 3L)),
    manualDims[1L] * manualDims[2L],
    manualDims[3L]
  )
  expect_equal(combined.mu.0, stacked.mu.0)

  # Without combining chains, extract() should reproduce the manual arrays.
  chain.mu.0 <- extract(fit, "mu.0", combineChains = FALSE)
  chain.mu.1 <- extract(fit, "mu.1", combineChains = FALSE)
  chain.cate <- extract(fit, "cate", combineChains = FALSE)
  chain.icate <- extract(fit, "icate", combineChains = FALSE)
  chain.mu.obs <- extract(fit, "mu.obs", combineChains = FALSE)

  expect_equal(chain.mu.0, samples.mu.0)
  expect_equal(chain.mu.1, samples.mu.1)
  expect_equal(chain.icate, samples.mu.1 - samples.mu.0)
  expect_equal(chain.mu.obs, samples.obs)

  # One cate entry per group: the icate samples averaged over the group's
  # observations (third dimension), keeping the first two dimensions.
  groupLevels <- levels(as.factor(testData$g))
  expect_equal(length(chain.cate), length(groupLevels))
  for (level in groupLevels) {
    inGroup <- testData$g == level
    expect_equal(
      chain.cate[[as.character(level)]],
      apply(chain.icate[, , inGroup], c(1L, 2L), mean)
    )
  }
})
test_that("ppd-based estimates match manual", {
  # +1 where z is 1 and -1 where z is 0.
  signedTrt <- 2 * testData$z - 1

  # ite should be (observed - counterfactual) with the sign set by treatment.
  manual.ite <- as.numeric((testData$y - fitted(fit, "y.cf")) * signedTrt)
  expect_equal(manual.ite, fitted(fit, "ite"))

  # The overall mean of that quantity should equal the per-group sate values
  # weighted by each group's share of the observations.
  manual.mean <- mean((testData$y - fitted(fit, "y.cf")) * signedTrt)
  groupShare <- table(testData$g) / length(testData$y)
  expect_equal(manual.mean, sum(fitted(fit, "sate") * groupShare))

  # Under the arm actually received, the fitted outcome is the observed y.
  isTreated <- testData$z == 1
  isControl <- testData$z == 0
  expect_equal(testData$y[isTreated], fitted(fit, "y.1")[isTreated])
  expect_equal(testData$y[isControl], fitted(fit, "y.0")[isControl])
})
test_that("summary object contains correct information", {
  fitSummary <- summary(fit)

  # Reference call with every argument named; compared with the call stored
  # on the summary after sorting both by argument name.
  testCall <- parse(text = 'bartc(response = y, treatment = z, confounders = x, data = testData, method.trt = "glm", group.by = g, group.effects = TRUE, use.ranef = FALSE, n.samples = 50L,
n.burn = 25L, n.chains = 4L, n.threads = 1L, verbose = FALSE)')[[1L]]
  storedCall <- fitSummary$call
  expect_true(
    length(testCall) == length(storedCall) &&
      all(names(storedCall) %in% names(testCall)) &&
      all(names(testCall) %in% names(storedCall)) &&
      all(storedCall[order(names(storedCall))] == testCall[order(names(testCall))])
  )

  expect_equal(fitSummary$method.rsp, "bart")
  expect_equal(fitSummary$method.trt, "glm")

  # Interval settings should be the defaults declared on summary.bartcFit.
  summaryDefaults <- formals(bartCause:::summary.bartcFit)
  expect_equal(fitSummary$ci.info$ci.style, eval(summaryDefaults$ci.style)[1L])
  expect_equal(fitSummary$ci.info$ci.level, eval(summaryDefaults$ci.level))

  # The last n.obs entry should be the full sample size.
  expect_equal(tail(fitSummary$n.obs, 1L), length(testData$y))
  expect_equal(fitSummary$n.samples, 50L)
  expect_equal(fitSummary$n.chains, 4L)

  # All estimate rows except the last should match the fitted group cates.
  groupEstimates <- head(fitSummary$estimates$estimate, -1L)
  expect_equal(groupEstimates, unname(fitted(fit, "cate")))
})
test_that("generics work for p.weights", {
  pfit <- bartc(y, z, x, data = testData, method.trt = "bart", method.rsp = "p.weight", estimand = "att",
                group.by = g, group.effects = TRUE, n.chains = 3L,
                n.samples = 7L, n.burn = 3L, n.threads = 1L, verbose = FALSE)
  pfit.sum <- summary(pfit)
  p.weights <- extract(pfit, "p.weights", sample = "all")

  # Row indices of the observations belonging to each group level.
  groups <- levels(as.factor(testData$g))
  g.sel <- lapply(groups, function(group) which(testData$g == group))

  boundValues <- bartCause:::boundValues
  ## match internal bounding
  yBounds <- c(.005, .995)
  p.scoreBounds <- c(0.025, 0.975)

  # warnings because "mu.0" isn't meaningful if using p-weights to compute ATT
  mu.0 <- suppressWarnings(extract(pfit, type = "mu.0", sample = "all"))
  mu.1 <- suppressWarnings(extract(pfit, type = "mu.1", sample = "all"))
  p.score <- extract(pfit, sample = "all", type = "p.score")

  for (j in seq_along(groups)) {
    # Observed response range within the group, used to rescale to [0, 1].
    m <- min(testData$y[g.sel[[j]]])
    M <- max(testData$y[g.sel[[j]]])

    # Clamp to [m, M], rescale to the unit interval, then clamp to yBounds.
    mu.hat.0 <- boundValues((boundValues(mu.0[,g.sel[[j]]], c(m, M)) - m) / (M - m), yBounds)
    mu.hat.1 <- boundValues((boundValues(mu.1[,g.sel[[j]]], c(m, M)) - m) / (M - m), yBounds)

    icate <- mu.hat.1 - mu.hat.0

    # replicate internal with:
    # temp <- bartCause:::getPWeightEstimates(testData$y[g.sel[[j]]], testData$z[g.sel[[j]]], NULL, "att", mu.hat.0, mu.hat.1,
    #                                         extract(pfit, sample = "all", type = "p.score")[g.sel[[j]],], yBounds, p.scoreBounds)
    # f <- bartCause:::getPWeightFunction("att", NULL, icate, boundValues(p.weights[g.sel[[j]],], p.scoreBounds))
    # mean(f(testData$z[g.sel[[j]]], NULL, icate, p.score[g.sel[[j]],])) * (M - m)

    # Average of the per-column means of icate times the bounded p.score,
    # divided by the mean of z within the group.
    est.unscaled <- mean(apply((icate * boundValues(p.score[,g.sel[[j]]], p.scoreBounds)), 2L, mean)) / mean(testData$z[g.sel[[j]]])

    # `estimates` is spelled out in full rather than relying on `$` partial
    # matching, consistent with the other tests in this file.
    expect_equal(pfit.sum$estimates$estimate[j], est.unscaled * (M - m))
  }

  # fitted() should be the mean of the extracted weights over all but the
  # last dimension.
  expect_equal(apply(p.weights, length(dim(p.weights)), mean), fitted(pfit, "p.weights", sample = "all"))
})
test_that("summary works with different styles", {
  # Every supported interval / pate style should yield a summary object.
  expect_is(summary(fit, ci.style = "norm"), "bartcFit.summary")
  expect_is(summary(fit, ci.style = "quant"), "bartcFit.summary")
  expect_is(summary(fit, ci.style = "hpd"), "bartcFit.summary")
  expect_is(summary(fit, pate.style = "ppd"), "bartcFit.summary")

  # Two fits from the same seed: one without weights, one with unit weights.
  set.seed(22)
  unweighted_fit <- bartc(y, z, x, data = testData, method.trt = "glm", verbose = FALSE,
                          n.chains = 2L, n.threads = 1L, n.burn = 0L, n.samples = 7L, n.trees = 13L)
  expect_is(unweighted_summary <- summary(unweighted_fit, pate.style = "var.exp"), "bartcFit.summary")

  set.seed(22)
  weighted_fit <- bartc(y, z, x, data = testData, method.trt = "glm", verbose = FALSE,
                        weights = rep(1, length(testData$y)),
                        n.chains = 2L, n.threads = 1L, n.burn = 0L, n.samples = 7L, n.trees = 13L)
  expect_is(weighted_summary <- summary(weighted_fit, pate.style = "var.exp"), "bartcFit.summary")

  # Unit weights must not change the estimate or its sd. Element and column
  # names are written out in full: with `$` partial matching, an ambiguous or
  # missing name yields NULL on both sides and the comparison would pass
  # without checking anything.
  expect_equal(unweighted_summary$estimates$estimate, weighted_summary$estimates$estimate)
  expect_equal(unweighted_summary$estimates$sd, weighted_summary$estimates$sd)

  # Variance decomposition of the icate samples (rows index samples, columns
  # index observations, per the dim() reads below).
  icates <- extract(unweighted_fit, "icate")
  n.samples <- dim(icates)[1L]
  n.obs     <- dim(icates)[2L]

  var_tot <- var(as.vector(icates))
  # Variance of the cate samples.
  var_w <- var(extract(unweighted_fit, "cate"))
  # Mean over samples of the across-observation variance within each sample.
  var_b <- mean(apply((icates - apply(icates, 1, mean))^2, 1, sum) / (n.obs - 1))

  expect_equal(sqrt(var_w + var_b), unweighted_summary$estimates$sd)
  # The two components, with their degrees-of-freedom corrections undone and
  # reapplied for the pooled sample, must reproduce the total variance.
  expect_equal(var_tot, (var_w * (n.samples - 1) / n.samples + var_b * (n.obs - 1) / n.obs) * n.obs * n.samples / (n.obs * n.samples - 1))
})
test_that("summary works with different styles for method tmle", {
  skip_on_cran()

  # Warnings are switched off only when the optional tmle package is not
  # installed. NOTE(review): presumably bartc warns in that case - confirm.
  oldWarn <- getOption("warn")
  if (!requireNamespace("tmle", quietly = TRUE))
    options(warn = -1)

  # Restore the warn option in `finally` so that an error inside bartc cannot
  # leave warnings globally suppressed for the remaining tests.
  fit <- tryCatch(
    bartc(y, z, x, data = testData, method.trt = "glm", method.rsp = "tmle", verbose = FALSE,
          group.by = g, group.effects = TRUE, use.ranef = FALSE,
          n.chains = 2L, n.threads = 1L, n.burn = 0L, n.samples = 7L, n.trees = 13L),
    finally = options(warn = oldWarn)
  )

  expect_is(summary(fit), "bartcFit.summary")
  expect_is(summary(fit, pate.style = "ppd"), "bartcFit.summary")
})
test_that("summary works with att/atc", {
  attFit <- bartc(
    y, z, x,
    data = testData, estimand = "att",
    method.trt = "bart", method.rsp = "bart", verbose = FALSE,
    n.chains = 2L, n.threads = 1L, n.burn = 0L, n.samples = 7L, n.trees = 13L
  )

  # Each value is passed as the second positional argument of summary(),
  # in the same order as before.
  for (target in c("pate", "sate", "cate")) {
    expect_is(summary(attFit, target), "bartcFit.summary")
  }
})
test_that("summary gives consistent answers with grouped data", {
  inGroupFit <- bartc(y, z, x, data = testData, estimand = "ate",
                      group.by = g, group.effects = TRUE,
                      method.trt = "bart", method.rsp = "bart", verbose = FALSE,
                      n.chains = 2L, n.threads = 1L, n.burn = 0L, n.samples = 7L, n.trees = 13L)
  sum.g.cate <- summary(inGroupFit, target = "cate")
  sum.g.sate <- summary(inGroupFit, target = "sate")
  sum.g.pate <- summary(inGroupFit, target = "pate")

  # test that sub group estimates are actually subgroup estimates
  # Per-group cate samples: row means of the icate samples restricted to the
  # group's columns.
  samples.icate <- extract(inGroupFit, "icate", "all")
  groupIds <- unique(testData$g)
  samples.gcate <- lapply(groupIds, function(j) rowMeans(samples.icate[,testData$g == j]))
  names(samples.gcate) <- groupIds

  # vapply() instead of sapply(): each group must reduce to exactly one
  # number, otherwise the test errors instead of comparing a reshaped result.
  expect_equal(sum.g.cate$estimates[names(samples.gcate),]$estimate,
               unname(vapply(samples.gcate, mean, numeric(1L))))
  expect_equal(sum.g.cate$estimates[names(samples.gcate),]$sd,
               unname(vapply(samples.gcate, sd, numeric(1L))))

  # The "total" row pools every observation.
  expect_equal(sum.g.cate$estimates["total",]$estimate, mean(samples.icate))
  expect_equal(sum.g.cate$estimates["total",]$sd, sd(rowMeans(samples.icate)))

  # One row per group plus the total row.
  expect_equal(nrow(sum.g.cate$estimates), length(groupIds) + 1L)

  expect_true(length(unique(sum.g.sate$estimates$estimate)) > 1L)

  # pate shares the cate point estimates but must report strictly larger sds.
  expect_equal(sum.g.cate$estimates$estimate, sum.g.pate$estimates$estimate)
  expect_true(all(sum.g.pate$estimates$sd > sum.g.cate$estimates$sd))

  # Group rows (all but the last) should equal the mean of each group's
  # extracted cate samples.
  expect_equal(unname(vapply(extract(inGroupFit, "cate"), mean, numeric(1L))),
               head(sum.g.cate$estimates$estimate, -1L))
})
test_that("common support cutoffs are being applied consistently", {
# Compares summary() and extract() output for a fit using
# commonSup.rule = "sd" with manual computations restricted to the
# observations selected by fit$commonSup.sub.
n.chains <- 2L
n.samples <- 7L
n.obs <- length(testData$y)
fit <- bartc(y, z, x, data = testData, estimand = "ate",
method.trt = "bart", method.rsp = "bart", verbose = FALSE,
commonSup.rule = "sd", seed = 5,
n.chains = n.chains, n.threads = 1L, n.burn = 0L, n.samples = n.samples,
n.trees = 13L)
sum.cate <- summary(fit, target = "cate")
sum.sate <- summary(fit, target = "sate")
# icate samples, keeping only the columns selected by fit$commonSup.sub.
icates <- extract(fit, "icate", "all")[,fit$commonSup.sub]
y.obs <- as.vector(testData$y)
# Save the global RNG state and replace it with the seed stored on the fit,
# so that the rnorm() draw below can be compared with extract(fit, "y.cf").
# NOTE(review): presumably extract() generates its noise starting from
# fit$seed - confirm.
oldSeed <- .GlobalEnv$.Random.seed
.GlobalEnv$.Random.seed <- fit$seed
mu.cf <- extract(fit, "mu.cf", combineChains = TRUE)
# (observed - counterfactual mean) with sign +1 where z is 1 and -1 where
# z is 0, then restricted to the common support columns.
iscates <- t((y.obs - t(mu.cf)) * (2 * testData$z - 1))
iscates <- iscates[,fit$commonSup.sub]
# Reshape the combined matrix to n.samples x n.chains x n.obs and swap the
# first two dimensions; this must equal the combineChains = FALSE result.
mu.cf <- aperm(array(mu.cf, c(n.samples, n.chains, n.obs)), c(2L, 1L, 3L))
expect_equal(mu.cf, extract(fit, "mu.cf", combineChains = FALSE))
# One normal draw per element of mu.cf; the sd vector is the extracted sigma
# repeated n.samples times and recycled by rnorm().
sigma <- rep(extract(fit, "sigma", combineChains = FALSE), times = n.samples)
epsilon <- rnorm(prod(dim(mu.cf)), 0, sigma)
# Put the caller's RNG state back now that the manual draw is done.
.GlobalEnv$.Random.seed <- oldSeed
y.cf <- mu.cf + epsilon
expect_equal(y.cf, extract(fit, "y.cf", combineChains = FALSE))
# Collapse the first two dimensions (after swapping them) into rows to match
# the combineChains = TRUE layout.
y.cf <- matrix(aperm(y.cf, c(2L, 1L, 3L)), nrow = n.samples * n.chains)
expect_equal(y.cf, extract(fit, "y.cf", combineChains = TRUE))
# Same signed difference as above, but using the sampled y.cf.
ites <- t((y.obs - t(y.cf)) * ifelse(testData$z == 1, 1, -1))
ites <- ites[,fit$commonSup.sub]
expect_equal(sd(apply(ites, 1, mean)), sd(extract(fit, "sate")))
# The point estimates must be neither NaN nor non-finite.
expect_true(!is.nan(sum.cate$estimates$estimate) && is.finite(sum.cate$estimates$estimate))
expect_true(!is.nan(sum.sate$estimates$estimate) && is.finite(sum.sate$estimates$estimate))
expect_equal(sum.cate$estimates$estimate, mean(icates))
expect_equal(sum.cate$estimates$sd, sd(apply(icates, 1, mean)))
expect_equal(sum.sate$estimates$estimate, mean(iscates))
# Expected sate sd: variance of the per-sample means plus the mean squared
# sigma divided by sum(fit$commonSup.sub).
# NOTE(review): assumes commonSup.sub is logical, so the sum is a count -
# confirm.
expect_equal(sum.sate$estimates$sd,
sqrt(var(rowMeans(iscates)) + mean(extract(fit, "sigma")^2) / sum(fit$commonSup.sub)))
})
# Load the Friedman-data fixture script shipped with the package.
# NOTE(review): presumably this defines generateFriedmanData() - confirm.
source(system.file("common", "friedmanData.R", package = "bartCause"))

test_that("sate summary is correct quantity", {
  data <- generateFriedmanData(n = 100, causal = TRUE)

  # NOTE(review): this call passes `samples = 50L`, whereas every other bartc
  # call in this file uses `n.samples`. Left unchanged here; confirm which
  # argument is intended.
  sateFit <- bartc(
    y, z, x,
    data = data, verbose = FALSE, samples = 50L,
    n.burn = 25L, n.chains = 4L, n.threads = 1L, seed = 0
  )

  sateSummary <- summary(sateFit, "sate")
  samples.sate <- extract(sateFit, "sate")

  # The summary's estimate and sd should each be within 0.1 (absolute) of the
  # mean and sd of the extracted sate samples.
  estimateGap <- abs(sateSummary$estimates$estimate - mean(samples.sate))
  expect_true(estimateGap < 1e-1)
  sdGap <- abs(sateSummary$estimates$sd - sd(samples.sate))
  expect_true(sdGap < 1e-1)
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.