# tests/testthat/test-08-predict.R

context("predict")

# Shared fixture script from the installed package; it is expected to define
# `testData` (with components x, y, and z) in this environment.
source(system.file("common", "linearData.R", package = "bartCause"), local = TRUE)

# The first `n.train` observations are used to fit models. The remaining rows
# of the covariate matrix serve as out-of-sample inputs to predict().
n.train <- 80L
y <- testData$y[seq_len(n.train)]
z <- testData$z[seq_len(n.train)]
x <- testData$x[seq_len(n.train), ]

x.new  <- testData$x[(n.train + 1L):nrow(testData$x), ]
n.test <- dim(x.new)[1L]

test_that("predict gives sane results", {
  n.samples <- 7L
  n.chains  <- 2L
  n.draws   <- n.samples * n.chains

  fit <- bartc(y, z, x, method.trt = "glm", method.rsp = "bart",
               n.chains = n.chains, n.threads = 1L, n.burn = 0L,
               n.samples = n.samples, n.trees = 13L,
               keepTrees = TRUE,
               verbose = FALSE)

  # A single covariate row (dropped to a vector) still yields one prediction
  # per posterior draw across all chains.
  single.row <- predict(fit, x.new[1, ], type = "mu.0")
  expect_equal(length(single.row), n.draws)

  p.score <- predict(fit, x.new, type = "p.score")
  mu.1    <- predict(fit, x.new, type = "mu.1", combineChains = FALSE)
  mu.0    <- predict(fit, x.new, type = "mu.0", combineChains = TRUE)
  icate   <- predict(fit, x.new, type = "icate", combineChains = TRUE)

  # The propensity-score prediction is expected to carry no dim attribute.
  expect_null(dim(p.score))

  # Uncombined draws keep a leading chain dimension; combined draws are
  # stacked into a (draws x observations) matrix.
  expect_equal(dim(mu.1), c(n.chains, n.samples, n.test))
  expect_equal(dim(mu.0), c(n.draws, n.test))

  # Stack mu.1's chains by hand (rows grouped by chain, samples within chain);
  # icate must then equal mu.1 - mu.0 element by element.
  mu.1.stacked <- matrix(aperm(mu.1, c(2L, 1L, 3L)), n.draws)
  expect_equal(as.vector(icate), as.vector(mu.1.stacked) - as.vector(mu.0))
})

test_that("predict results matches training data", {
  n.samples <- 7L
  n.chains  <- 2L
  fit <- bartc(y, z, x, method.trt = "bart", method.rsp = "bart",
               n.chains = n.chains, n.threads = 1L, n.burn = 0L,
               n.samples = n.samples, n.trees = 13L,
               keepTrees = TRUE,
               args.trt = list(k = 1.5), verbose = FALSE)

  # Quantities stored by the fit for the training observations.
  stored <- list(
    p.score = extract(fit, type = "p.score"),
    mu.1    = extract(fit, type = "mu.1"),
    mu.0    = extract(fit, type = "mu.0"),
    icate   = extract(fit, type = "icate"),
    mu      = extract(fit, type = "mu.obs")
  )

  # The same quantities recomputed by predicting on the training covariates.
  # The "mu" prediction takes the observed treatment `z` as an extra column.
  predicted <- list(
    p.score = predict(fit, x, type = "p.score"),
    mu.1    = predict(fit, x, type = "mu.1"),
    mu.0    = predict(fit, x, type = "mu.0"),
    icate   = predict(fit, x, type = "icate"),
    mu      = predict(fit, cbind(x, z), type = "mu")
  )

  for (quantity in c("p.score", "mu.0", "mu.1", "icate", "mu")) {
    expect_equal(stored[[quantity]], predicted[[quantity]], info = quantity)
  }
})

# Shared settings for the grouped-data tests: a small MCMC budget and a
# reproducible random assignment of the training rows to three groups.
n.samples <- 7L
n.chains  <- 2L

set.seed(22)
g <- sample.int(3L, nrow(x), replace = TRUE)

test_that("predict works with grouped data, glm trt model", {
  fit <- bartc(y, z, x, method.trt = "glm", method.rsp = "bart", group.by = g,
               n.chains = n.chains, n.threads = 1L, n.burn = 0L,
               n.samples = n.samples, n.trees = 13L,
               keepTrees = TRUE, use.ranef = FALSE,
               args.trt = list(k = 1.5), verbose = FALSE)

  # Training-set quantities; the propensity score is read via fitted().
  stored <- list(
    p.score = fitted(fit, type = "p.score"),
    mu.1    = extract(fit, type = "mu.1"),
    mu.0    = extract(fit, type = "mu.0"),
    icate   = extract(fit, type = "icate")
  )

  # Predicting on the training covariates and groups must reproduce them.
  predicted <- list(
    p.score = predict(fit, x, group.by = g, type = "p.score"),
    mu.1    = predict(fit, x, group.by = g, type = "mu.1"),
    mu.0    = predict(fit, x, group.by = g, type = "mu.0"),
    icate   = predict(fit, x, group.by = g, type = "icate")
  )

  for (quantity in c("p.score", "mu.0", "mu.1", "icate")) {
    expect_equal(stored[[quantity]], predicted[[quantity]], info = quantity)
  }
})

test_that("predict works with grouped data, glmer trt model", {
  skip_if_not_installed("lme4")

  # A "glm" treatment model with group.by becomes a random-effects (lme4
  # glmer) fit only when use.ranef = TRUE (see ?bartc, `use.ranef`); that is
  # why this test requires lme4. With use.ranef = FALSE it would be identical
  # to the plain-glm grouped test and never exercise the glmer path.
  # Warnings from fitting are suppressed.
  suppressWarnings(
    fit <- bartc(y, z, x, method.trt = "glm", method.rsp = "bart", group.by = g,
                 n.chains = n.chains, n.threads = 1L, n.burn = 0L,
                 n.samples = n.samples, n.trees = 13L,
                 keepTrees = TRUE, use.ranef = TRUE,
                 args.trt = list(k = 1.5), verbose = FALSE)
  )

  # Training-set quantities stored by the fit.
  p.score <- fitted(fit, type = "p.score")
  mu.1    <- extract(fit, type = "mu.1")
  mu.0    <- extract(fit, type = "mu.0")
  icate   <- extract(fit, type = "icate")

  # Predicting on the training covariates and groups must reproduce them.
  p.score.new <- predict(fit, x, group.by = g, type = "p.score")
  mu.1.new    <- predict(fit, x, group.by = g, type = "mu.1")
  mu.0.new    <- predict(fit, x, group.by = g, type = "mu.0")
  icate.new   <- predict(fit, x, group.by = g, type = "icate")

  expect_equal(p.score, p.score.new)
  expect_equal(mu.0, mu.0.new)
  expect_equal(mu.1, mu.1.new)
  expect_equal(icate, icate.new)
})

test_that("predict works with grouped data, bart trt model", {
  # Compares the quantities stored in `fit` with predictions made on the
  # training covariates and groups; called once per fitted model below.
  check.predictions <- function(fit) {
    stored <- list(
      p.score = extract(fit, type = "p.score"),
      mu.1    = extract(fit, type = "mu.1"),
      mu.0    = extract(fit, type = "mu.0"),
      icate   = extract(fit, type = "icate")
    )
    predicted <- list(
      p.score = predict(fit, x, group.by = g, type = "p.score"),
      mu.1    = predict(fit, x, group.by = g, type = "mu.1"),
      mu.0    = predict(fit, x, group.by = g, type = "mu.0"),
      icate   = predict(fit, x, group.by = g, type = "icate")
    )
    for (quantity in c("p.score", "mu.0", "mu.1", "icate")) {
      expect_equal(stored[[quantity]], predicted[[quantity]], info = quantity)
    }
  }

  # Default use.ranef setting.
  fit.default <- bartc(y, z, x, method.trt = "bart", method.rsp = "bart", group.by = g,
                       n.chains = n.chains, n.threads = 1L, n.burn = 0L,
                       n.samples = n.samples, n.trees = 13L,
                       keepTrees = TRUE,
                       args.trt = list(k = 1.5), verbose = FALSE)
  check.predictions(fit.default)

  # Group included without random effects.
  fit.fixed <- bartc(y, z, x, method.trt = "bart", method.rsp = "bart", group.by = g,
                     n.chains = n.chains, n.threads = 1L, n.burn = 0L,
                     n.samples = n.samples, n.trees = 13L,
                     keepTrees = TRUE, use.ranef = FALSE,
                     args.trt = list(k = 1.5), verbose = FALSE)
  check.predictions(fit.fixed)
})

# Remove the objects this file created at top level.
rm(list = c("testData", "n.train", "x", "y", "z", "g",
            "n.samples", "n.chains", "x.new", "n.test"))
# vdorie/bartCause documentation built on May 5, 2024, 9:29 a.m.