tests/testthat/test-09-bartArgs.R

context("extract for trees")

source(system.file("common", "friedmanData.R", package = "stan4bart"), local = TRUE)

testData <- generateFriedmanData(100, TRUE, TRUE, FALSE)
rm(generateFriedmanData)

df <- with(testData, data.frame(x, g.1, g.2, y, z))

test_that("bart args are set correctly", {
  output <- character()
  messages <- character()
  outputConnection <- textConnection("output", "w", local = TRUE)
  messagesConnection <- textConnection("messages", "w", local = TRUE)
  sink(outputConnection)
  sink(messagesConnection, type = "message")
  
  fit <- stan4bart(y ~ bart(. - g.1 - g.2 - X4 - z) + X4 + z + (1 + X4 | g.1) + (1 | g.2), df,
                   cores = 1, verbose = 2, chains = 1, warmup = 0, iter = 1,
                   bart_args = list(n.trees = 2, power = 2.5, base = 0.9, split.probs = c(X3 = 2, .default = 1)),
                   treatment = z)

  sink(type = "message")
  sink()

  powerAndBaseLine <- grep("\tpower and base for tree prior:", output)
  expect_true(length(powerAndBaseLine) > 0L)
  
  powerAndBase <- 
    as.numeric(strsplit(sub("^[^0-9]+([0-9.]+ [0-9.]+)$", "\\1", output[powerAndBaseLine]), " ")[[1L]])
  
  expect_equal(powerAndBase[1L], 2.5)
  expect_equal(powerAndBase[2L], 0.9)
  
  splitProbsLine <- grep("\ttree split probabilities:", output)
  expect_true(length(splitProbsLine) > 0L)
  splitProbs <- 
    as.numeric(strsplit(sub("^[^0-9]+((?:[0-9.]+, )*[0-9.]+)$", "\\1", output[splitProbsLine], perl = TRUE), ", ")[[1L]])
  expect_equal(splitProbs, c(0.1, 0.1, 0.2, 0.1, 0.1))
})

Try the stan4bart package in your browser

Any scripts or data that you put into this service are public.

stan4bart documentation built on Sept. 12, 2024, 7:39 a.m.