Nothing
context("split probabilities")
source(system.file("common", "friedmanData.R", package = "dbarts"), local = TRUE)
df <- with(testData, data.frame(x, y))
df$X10 <- as.factor(paste0("C", 1 + round(4 * df$X10, 0)))
fitCall <- quote(bart(
testData$x,
testData$y,
ndpost = 1,
nskip = 0,
ntree = 1,
keeptrees = TRUE,
verbose = FALSE))
test_that("works defaults, no column names",
{
bartFit <- eval(fitCall)
expect_true(length(bartFit$fit$model@tree.prior@splitProbabilities) == 0L)
fitCall$splitprobs <- quote(1 / numvars)
bartFit <- eval(fitCall)
expect_true(length(bartFit$fit$model@tree.prior@splitProbabilities) == 0L)
fitCall$splitprobs <- quote(1)
bartFit <- eval(fitCall)
expect_true(length(bartFit$fit$model@tree.prior@splitProbabilities) == 0L)
})
test_that("works specific values, no column names",
{
probs <- c(2, rep.int(1, ncol(testData$x) - 1L))
fitCall$splitprobs <- quote(probs)
bartFit <- eval(fitCall)
expect_equal(bartFit$fit$model@tree.prior@splitProbabilities, probs / sum(probs))
probs <- probs[-1L]
expect_error(eval(fitCall))
probs <- c(-1, probs)
expect_error(eval(fitCall))
probs[1L] <- NA_real_
expect_error(eval(fitCall))
})
fitCall <- quote(bart(
y ~ .,
df,
ndpost = 3,
nskip = 0,
ntree = 2,
verbose = FALSE,
keeptrees = TRUE))
test_that("works with column names",
{
fitCall$splitprobs <- quote(c(X4 = 2, X10 = 1.5, .default = 1))
bartFit <- eval(fitCall)
split.probs <- bartFit$fit$model@tree.prior@splitProbabilities
expect_equal(length(split.probs), ncol(df) - 2 + nlevels(df$X10))
expect_equal(split.probs[["X4"]], 2 * split.probs[["X1"]])
x10_values <- startsWith(names(split.probs), "X10.")
expect_equal(sum(x10_values), nlevels(df$X10))
expect_true(all(split.probs[["X4"]] == 2 * split.probs[x10_values] / 1.5))
default_values <- !x10_values
default_values[names(split.probs) == "X4"] <- FALSE
expect_true(all(split.probs[default_values] == split.probs[default_values][1]))
fitCall$splitprobs <- quote(c(X4 = -1, X10 = 1.5, .default = 1))
expect_error(eval(fitCall))
fitCall$splitprobs <- quote(c(X4 = NA_real_, X10 = 1.5, .default = 1))
expect_error(eval(fitCall))
})
test_that("split probabilities sample from prior",
{
set.seed(0)
n.trees <- 200L
control <- dbartsControl(updateState = FALSE, verbose = FALSE,
n.burn = 0L, n.samples = 1L, n.thin = 1L,
n.trees = n.trees, keepTrees = FALSE,
n.chains = 1L, n.threads = 1L)
sampler <- dbarts(y ~ ., df, control = control,
tree.prior = cgm(split.probs = c(X4 = 2, .default = 1)))
sampler$sampleTreesFromPrior()
trees <- sampler$getTrees()
treeTable <- table(trees$var)
treeTable <- treeTable[setdiff(names(treeTable), "-1")]
names(treeTable) <- colnames(sampler$data@x)
expect_true(abs(treeTable[["X4"]] - 2 * mean(treeTable[setdiff(names(treeTable), "X4")])) / n.trees < 0.05)
})
test_that("split probabilities sample from posterior",
{
# X6 is uncorrelated
set.seed(0)
n.trees <- 5L
control <- dbartsControl(updateState = FALSE, verbose = FALSE,
n.burn = 1000L, n.samples = 100L, n.thin = 1L,
n.trees = n.trees, keepTrees = FALSE,
n.chains = 1L, n.threads = 1L)
sampler <- dbarts(y ~ ., df, control = control,
tree.prior = cgm(split.probs = c(X6 = 2, .default = 1)))
samples <- sampler$run(200L, 100L)
varcounts <- apply(samples$varcount, 1L, mean)
names(varcounts) <- colnames(sampler$data@x)
expect_true(all(varcounts[["X6"]] <= varcounts[paste0("X", 1:5)]))
expect_true(all(varcounts[["X6"]] >= varcounts[setdiff(names(varcounts), paste0("X", 1:6))] - 0.02))
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.