# tests/testthat/test-06-regression.R

# Label the tests in this file with the "regression" context.
context("regression")

# Source the linearData.R script found under "common" in the installed
# bartCause package. Presumably this defines `testData`, which the tests below
# pass as `data` -- TODO confirm against linearData.R.
source(system.file("common", "linearData.R", package = "bartCause"))

# Fits bartc with method.rsp = "bart" and method.trt = "none" under a fixed
# seed and a very small sampler configuration, then compares the fitted
# "cate" value against a stored constant.
test_that("naive bart matches old", {
  set.seed(22)

  naive_fit <- bartc(
    y, z, x,
    data = testData,
    method.rsp = "bart",
    method.trt = "none",
    estimand = "att",
    verbose = FALSE,
    n.samples = 5L,
    n.burn = 5L,
    n.chains = 1L,
    n.threads = 1L,
    n.trees = 5L
  )

  expected_cate <- -0.237484791465709
  expect_equal(fitted(naive_fit, "cate"), expected_cate)
})

# Fits bartc with method.rsp = "bart" and method.trt = "bart" under a fixed
# seed and a very small sampler configuration, then compares the fitted
# "cate" value against a stored constant.
test_that("bart on p.score matches old", {
  set.seed(22)

  pscore_fit <- bartc(
    y, z, x,
    data = testData,
    method.rsp = "bart",
    method.trt = "bart",
    estimand = "att",
    verbose = FALSE,
    n.samples = 5L,
    n.burn = 5L,
    n.chains = 1L,
    n.threads = 1L,
    n.trees = 5L,
    n.reps = 5L
  )

  expected_cate <- 0.0213594201614189
  expect_equal(fitted(pscore_fit, "cate"), expected_cate)
})

# Fits bartc with method.rsp = "p.weight" and method.trt = "bart" under a
# fixed seed and a very small sampler configuration, then compares the fitted
# "pate" value against a stored constant.
test_that("bart w/p.weighting matches old", {
  set.seed(22)

  pweight_fit <- bartc(
    y, z, x,
    data = testData,
    method.rsp = "p.weight",
    method.trt = "bart",
    estimand = "att",
    verbose = FALSE,
    n.samples = 5L,
    n.burn = 5L,
    n.chains = 1L,
    n.threads = 1L,
    n.trees = 5L,
    n.reps = 5L
  )

  expected_pate <- 0.0211982499681833
  expect_equal(fitted(pweight_fit, "pate"), expected_pate)
})

# Fits bartc with method.rsp = "tmle" and method.trt = "bart" under a fixed
# seed and a very small sampler configuration, then compares the fitted
# "pate" value against a stored constant. Skipped when the tmle package is
# not installed.
test_that("bart w/TMLE matches old", {
  skip_if_not_installed("tmle")

  set.seed(22)

  tmle_fit <- bartc(
    y, z, x,
    data = testData,
    method.rsp = "tmle",
    method.trt = "bart",
    estimand = "att",
    verbose = FALSE,
    n.samples = 5L,
    n.burn = 5L,
    n.chains = 1L,
    n.threads = 1L,
    n.trees = 5L,
    n.reps = 5L
  )

  # The stored reference value depends on the installed tmle version:
  # one constant for 1.5.0 and later, another for earlier releases.
  expected_pate <- if (packageVersion("tmle") >= "1.5.0") {
    0.30048319956979
  } else {
    0.293195268298812
  }
  expect_equal(fitted(tmle_fit, "pate"), expected_pate)
})

# Try the bartCause package in your browser
#
# Any scripts or data that you put into this service are public.
#
# bartCause documentation built on Jan. 24, 2023, 1:09 a.m.