Nothing
# Tests for predict() on bartc fits.
context("predict")

# Load the shared linear test data; the code below expects this to define
# `testData` (with elements x, y and z) in the current environment.
source(
  system.file("common", "linearData.R", package = "bartCause"),
  local = TRUE
)

# The first n.train observations are used to fit the models ...
n.train <- 80L
y <- testData$y[seq_len(n.train)]
z <- testData$z[seq_len(n.train)]
x <- testData$x[seq_len(n.train), ]

# ... and every remaining row of the covariate matrix is held out as new data.
x.new <- testData$x[seq.int(from = n.train + 1L, to = nrow(testData$x)), ]
n.test <- nrow(x.new)
# Checks the shape of what predict() returns for held-out data, and that the
# individual effects agree with mu.1 - mu.0 once the chains are stacked.
test_that("predict gives sane results", {
  n.samples <- 7L
  n.chains <- 2L
  n.draws <- n.samples * n.chains

  fit <- bartc(
    y, z, x,
    method.trt = "glm", method.rsp = "bart",
    n.chains = n.chains, n.threads = 1L, n.burn = 0L,
    n.samples = n.samples, n.trees = 13L,
    keepTrees = TRUE,
    verbose = FALSE
  )

  # check predict for single row (x.new[1, ] drops to a plain vector)
  single.row <- predict(fit, x.new[1, ], type = "mu.0")
  expect_equal(length(single.row), n.draws)

  p.score <- predict(fit, x.new, type = "p.score")
  # mu.1 keeps the chains separate; mu.0 and icate have them combined.
  mu.1 <- predict(fit, x.new, type = "mu.1", combineChains = FALSE)
  mu.0 <- predict(fit, x.new, type = "mu.0", combineChains = TRUE)
  icate <- predict(fit, x.new, type = "icate", combineChains = TRUE)

  # The propensity score here carries no dim attribute.
  expect_null(dim(p.score))
  expect_equal(dim(mu.1), c(n.chains, n.samples, n.test))
  expect_equal(dim(mu.0), c(n.draws, n.test))

  # Swap the chain and sample dimensions, then flatten them into rows so that
  # mu.1 lines up with the chain-combined mu.0 before differencing.
  mu.1.stacked <- matrix(aperm(mu.1, c(2L, 1L, 3L)), nrow = n.draws)
  expect_equal(
    as.vector(icate),
    as.vector(mu.1.stacked) - as.vector(mu.0)
  )
})
# Predicting on the covariates used for fitting should reproduce the values
# that extract() reports for the fit itself.
test_that("predict results matches training data", {
  n.samples <- 7L
  n.chains <- 2L

  fit <- bartc(
    y, z, x,
    method.trt = "bart", method.rsp = "bart",
    n.chains = n.chains, n.threads = 1L, n.burn = 0L,
    n.samples = n.samples, n.trees = 13L,
    keepTrees = TRUE,
    args.trt = list(k = 1.5), verbose = FALSE
  )

  # Quantities as stored in / derived from the fitted object.
  p.score.fit <- extract(fit, type = "p.score")
  mu.1.fit <- extract(fit, type = "mu.1")
  mu.0.fit <- extract(fit, type = "mu.0")
  icate.fit <- extract(fit, type = "icate")
  mu.obs.fit <- extract(fit, type = "mu.obs")

  # The same quantities recomputed through predict() on the training inputs.
  p.score.pred <- predict(fit, x, type = "p.score")
  mu.1.pred <- predict(fit, x, type = "mu.1")
  mu.0.pred <- predict(fit, x, type = "mu.0")
  icate.pred <- predict(fit, x, type = "icate")
  # type = "mu" is given the covariates together with the observed treatment.
  mu.pred <- predict(fit, cbind(x, z), type = "mu")

  expect_equal(p.score.fit, p.score.pred)
  expect_equal(mu.0.fit, mu.0.pred)
  expect_equal(mu.1.fit, mu.1.pred)
  expect_equal(icate.fit, icate.pred)
  expect_equal(mu.obs.fit, mu.pred)
})
# Fixtures shared by the grouped-data tests that follow.

# Seed the RNG so the random group labels are reproducible.
set.seed(22)

# One label from 1..3 per training row, drawn with replacement.
g <- sample(3L, size = nrow(x), replace = TRUE)

# Sampler settings reused by each grouped fit.
n.samples <- 7L
n.chains <- 2L
# Grouped fit with a glm treatment model and use.ranef = FALSE: predictions on
# the training inputs (with the same grouping) should match the fit's values.
test_that("predict works with grouped data, glm trt model", {
  fit <- bartc(
    y, z, x,
    method.trt = "glm", method.rsp = "bart", group.by = g,
    n.chains = n.chains, n.threads = 1L, n.burn = 0L,
    n.samples = n.samples, n.trees = 13L,
    keepTrees = TRUE, use.ranef = FALSE,
    args.trt = list(k = 1.5), verbose = FALSE
  )

  # Values taken from the fitted object; the propensity score comes from
  # fitted() here rather than extract().
  in.sample <- list(
    p.score = fitted(fit, type = "p.score"),
    mu.1 = extract(fit, type = "mu.1"),
    mu.0 = extract(fit, type = "mu.0"),
    icate = extract(fit, type = "icate")
  )

  # The same quantities recomputed via predict() on the training inputs.
  re.predicted <- list(
    p.score = predict(fit, x, group.by = g, type = "p.score"),
    mu.1 = predict(fit, x, group.by = g, type = "mu.1"),
    mu.0 = predict(fit, x, group.by = g, type = "mu.0"),
    icate = predict(fit, x, group.by = g, type = "icate")
  )

  expect_equal(in.sample[["p.score"]], re.predicted[["p.score"]])
  expect_equal(in.sample[["mu.0"]], re.predicted[["mu.0"]])
  expect_equal(in.sample[["mu.1"]], re.predicted[["mu.1"]])
  expect_equal(in.sample[["icate"]], re.predicted[["icate"]])
})
# Grouped fit where the treatment model treats the grouping variable as a
# random effect. Predictions on the training inputs (with the same grouping)
# should match the values obtained from the fit itself.
test_that("predict works with grouped data, glmer trt model", {
  # The random-effects treatment model relies on lme4.
  skip_if_not_installed("lme4")

  # use.ranef must be TRUE for this test to differ from the glm test above:
  # with use.ranef = FALSE the treatment model is fit exactly as in that test
  # and the lme4-backed path is never exercised.
  # NOTE(review): warnings are suppressed presumably because the mixed model
  # can emit convergence / singular-fit warnings on this small data set.
  suppressWarnings(
    fit <- bartc(
      y, z, x,
      method.trt = "glm", method.rsp = "bart", group.by = g,
      n.chains = n.chains, n.threads = 1L, n.burn = 0L,
      n.samples = n.samples, n.trees = 13L,
      keepTrees = TRUE, use.ranef = TRUE,
      args.trt = list(k = 1.5), verbose = FALSE
    )
  )

  # Values taken from the fitted object.
  p.score <- fitted(fit, type = "p.score")
  mu.1 <- extract(fit, type = "mu.1")
  mu.0 <- extract(fit, type = "mu.0")
  icate <- extract(fit, type = "icate")

  # The same quantities recomputed via predict() on the training inputs.
  p.score.new <- predict(fit, x, group.by = g, type = "p.score")
  mu.1.new <- predict(fit, x, group.by = g, type = "mu.1")
  mu.0.new <- predict(fit, x, group.by = g, type = "mu.0")
  icate.new <- predict(fit, x, group.by = g, type = "icate")

  expect_equal(p.score, p.score.new)
  expect_equal(mu.0, mu.0.new)
  expect_equal(mu.1, mu.1.new)
  expect_equal(icate, icate.new)
})
# Grouped fits with a bart treatment model, once with bartc's default
# use.ranef and once with use.ranef = FALSE. In both cases predict() on the
# training inputs should reproduce what extract() reports for the fit.
test_that("predict works with grouped data, bart trt model", {
  # Runs the extract / predict / compare sequence for one fitted model.
  expect_predict_matches_extract <- function(fit) {
    p.score <- extract(fit, type = "p.score")
    mu.1 <- extract(fit, type = "mu.1")
    mu.0 <- extract(fit, type = "mu.0")
    icate <- extract(fit, type = "icate")

    p.score.new <- predict(fit, x, group.by = g, type = "p.score")
    mu.1.new <- predict(fit, x, group.by = g, type = "mu.1")
    mu.0.new <- predict(fit, x, group.by = g, type = "mu.0")
    icate.new <- predict(fit, x, group.by = g, type = "icate")

    expect_equal(p.score, p.score.new)
    expect_equal(mu.0, mu.0.new)
    expect_equal(mu.1, mu.1.new)
    expect_equal(icate, icate.new)
  }

  # First fit: use.ranef left at its default.
  fit <- bartc(
    y, z, x,
    method.trt = "bart", method.rsp = "bart", group.by = g,
    n.chains = n.chains, n.threads = 1L, n.burn = 0L,
    n.samples = n.samples, n.trees = 13L,
    keepTrees = TRUE,
    args.trt = list(k = 1.5), verbose = FALSE
  )
  expect_predict_matches_extract(fit)

  # Second fit: identical settings except use.ranef = FALSE.
  fit <- bartc(
    y, z, x,
    method.trt = "bart", method.rsp = "bart", group.by = g,
    n.chains = n.chains, n.threads = 1L, n.burn = 0L,
    n.samples = n.samples, n.trees = 13L,
    keepTrees = TRUE, use.ranef = FALSE,
    args.trt = list(k = 1.5), verbose = FALSE
  )
  expect_predict_matches_extract(fit)
})
# Remove the objects this file created at top level.
rm(list = c(
  "testData", "n.train", "x", "y", "z", "g",
  "n.samples", "n.chains", "x.new", "n.test"
))
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.