# tests/testthat/test-19-splitProbabilities.R

# Tests for the split-probability options of bart() / dbarts().
context("split probabilities")

# Load the shared fixture shipped with dbarts into this environment.
# NOTE(review): presumably this is what defines `testData` - confirm.
source(system.file("common", "friedmanData.R", package = "dbarts"), local = TRUE)

# Data-frame copy of the fixture for the formula-interface tests. X10 is
# replaced by a factor whose labels are "C<k>" with k = 1 + round(4 * X10),
# so that one predictor is categorical.
df <- with(testData, data.frame(x, y))
df[["X10"]] <- factor(paste0("C", round(4 * df[["X10"]]) + 1))

# Unevaluated bart() call on testData$x / testData$y with a single tree and a
# single posterior draw (no burn-in), keeping the trees. It is stored as a
# quoted call so that each test can add a `splitprobs` argument to its own
# copy before running it with eval().
fitCall <- quote(bart(
    testData$x,
    testData$y,
    ndpost = 1,
    nskip = 0,
    ntree = 1,
    keeptrees = TRUE,
    verbose = FALSE))

test_that("works defaults, no column names",
{
  # Each of the three specifications below (argument omitted, `1 / numvars`,
  # and the constant 1) is expected to leave the splitProbabilities slot of
  # the tree prior empty. expect_length() is used so that a failure reports
  # the actual length instead of a bare FALSE.

  # 1. No splitprobs argument at all.
  bartFit <- eval(fitCall)
  expect_length(bartFit$fit$model@tree.prior@splitProbabilities, 0L)

  # 2. An expression in terms of `numvars`.
  # NOTE(review): `numvars` is not defined in this file; presumably bart()
  # resolves it when evaluating the argument - confirm.
  fitCall$splitprobs <- quote(1 / numvars)

  bartFit <- eval(fitCall)
  expect_length(bartFit$fit$model@tree.prior@splitProbabilities, 0L)

  # 3. A single constant weight.
  fitCall$splitprobs <- quote(1)

  bartFit <- eval(fitCall)
  expect_length(bartFit$fit$model@tree.prior@splitProbabilities, 0L)
})

test_that("works specific values, no column names",
{
  # The quoted call refers to the symbol `probs`, so every case below works by
  # rebinding `probs` in this environment before the call is evaluated.
  fitCall[["splitprobs"]] <- as.name("probs")

  # Valid case: unnormalized weights, first predictor twice as heavy as the
  # rest. The stored probabilities should be the weights rescaled to sum to 1.
  n_predictors <- ncol(testData$x)
  base_probs <- c(2, rep(1, n_predictors - 1L))

  probs <- base_probs
  bartFit <- eval(fitCall)
  expected <- probs / sum(probs)
  expect_equal(bartFit$fit$model@tree.prior@splitProbabilities, expected)

  # Invalid: one entry too few.
  probs <- base_probs[-1L]
  expect_error(eval(fitCall))

  # Invalid: correct length, but a negative weight in the first position.
  probs <- replace(base_probs, 1L, -1)
  expect_error(eval(fitCall))

  # Invalid: correct length, but a missing weight in the first position.
  probs <- replace(base_probs, 1L, NA_real_)
  expect_error(eval(fitCall))
})


# Rebind fitCall to the formula interface on the data frame `df` (which
# includes the factor column X10), now with two trees and three posterior
# draws. The tests defined after this point evaluate this version of the call.
fitCall <- quote(bart(
  y ~ .,
  df,
  ndpost = 3,
  nskip = 0,
  ntree = 2,
  verbose = FALSE,
  keeptrees = TRUE))

test_that("works with column names",
{
  # Named weights: X4 gets 2, the factor X10 gets 1.5, everything else falls
  # back to `.default`.
  fitCall$splitprobs <- quote(c(X4 = 2, X10 = 1.5, .default = 1))
  bartFit <- eval(fitCall)

  split.probs <- bartFit$fit$model@tree.prior@splitProbabilities

  # One entry per non-factor predictor (all columns except y and X10) plus
  # one entry per level of X10.
  expect_length(split.probs, ncol(df) - 2 + nlevels(df$X10))
  expect_equal(split.probs[["X4"]], 2 * split.probs[["X1"]])

  # Entries for the factor are identified by the "X10." name prefix.
  is_x10 <- startsWith(names(split.probs), "X10.")
  expect_equal(sum(is_x10), nlevels(df$X10))

  # Each X10 entry should stand in a 1.5 : 2 ratio to X4. The probabilities
  # are normalized doubles, so compare with expect_equal()'s tolerance rather
  # than `==`, which can fail on rounding alone.
  expect_equal(unname(split.probs[is_x10]),
               rep(split.probs[["X4"]] * 1.5 / 2, sum(is_x10)))

  # Everything that is neither X4 nor an X10 entry should share one value.
  is_default <- !is_x10 & names(split.probs) != "X4"
  default_probs <- unname(split.probs[is_default])
  expect_equal(default_probs, rep(default_probs[1L], length(default_probs)))

  # Negative and missing named weights must be rejected.
  fitCall$splitprobs <- quote(c(X4 = -1, X10 = 1.5, .default = 1))
  expect_error(eval(fitCall))

  fitCall$splitprobs <- quote(c(X4 = NA_real_, X10 = 1.5, .default = 1))
  expect_error(eval(fitCall))
})

test_that("split probabilities sample from prior",
{
  # Draw 200 trees from the prior with X4 weighted twice as heavily as any
  # other predictor, then check that X4 is used for splits roughly twice as
  # often as the average of the others.
  set.seed(0)
  n.trees <- 200L
  control <- dbartsControl(updateState = FALSE, verbose = FALSE,
                           n.burn = 0L, n.samples = 1L, n.thin = 1L,
                           n.trees = n.trees, keepTrees = FALSE,
                           n.chains = 1L, n.threads = 1L)
  sampler <- dbarts(y ~ ., df, control = control,
                    tree.prior = cgm(split.probs = c(X4 = 2, .default = 1)))
  sampler$sampleTreesFromPrior()

  trees <- sampler$getTrees()
  predictor_names <- colnames(sampler$data@x)

  # Leaf nodes are coded as var == -1 and are dropped. The remaining values
  # are treated as 1-based column indices into the predictor matrix (per the
  # dbarts getTrees documentation). tabulate() with a fixed number of bins
  # keeps a zero count for any predictor that was never split on, whereas
  # assigning names to a table() result fails with a length mismatch then.
  split_vars <- trees$var[trees$var != -1L]
  split_counts <- tabulate(split_vars, nbins = length(predictor_names))
  names(split_counts) <- predictor_names

  other_counts <- split_counts[names(split_counts) != "X4"]
  # Discrepancy between the X4 count and twice the mean of the other counts,
  # scaled by the number of trees, must stay below 5%.
  expect_lt(abs(split_counts[["X4"]] - 2 * mean(other_counts)) / n.trees, 0.05)
})

test_that("split probabilities sample from posterior",
{
  # X6 is uncorrelated (original author's note). It is given double prior
  # weight here, and the checks below bound how often it is actually used.
  set.seed(0)
  n.trees <- 5L
  control <- dbartsControl(
    updateState = FALSE,
    verbose = FALSE,
    n.burn = 1000L,
    n.samples = 100L,
    n.thin = 1L,
    n.trees = n.trees,
    keepTrees = FALSE,
    n.chains = 1L,
    n.threads = 1L
  )
  weighted_prior <- cgm(split.probs = c(X6 = 2, .default = 1))
  sampler <- dbarts(y ~ ., df, control = control, tree.prior = weighted_prior)

  # NOTE(review): run() is given explicit 200L / 100L arguments, presumably
  # burn-in and sample counts that take precedence over the control values
  # above - confirm.
  samples <- sampler$run(200L, 100L)

  # Average usage count per predictor, taken across the first margin of
  # varcount and labelled with the predictor column names.
  varcounts <- apply(samples$varcount, 1L, mean)
  names(varcounts) <- colnames(sampler$data@x)

  x6_count <- varcounts[["X6"]]
  first_five <- paste0("X", 1:5)
  remaining <- setdiff(names(varcounts), c(first_five, "X6"))

  # X6 must not be used more often than any of X1..X5 ...
  expect_true(all(x6_count <= varcounts[first_five]))
  # ... and must be used at least as often (within 0.02) as every other one.
  expect_true(all(x6_count >= varcounts[remaining] - 0.02))
})
# Source: vdorie/dbarts documentation, built on May 6, 2024, 7:32 p.m.