Nothing
context("predict")
source(system.file("common", "friedmanData.R", package = "dbarts"), local = TRUE)
test_that("combine/uncombine chains and convert from/to bart style works correctly",
{
n.chains <- 3L
n.samples <- 5L
n.obs <- 7L
dbartsSamples <- array(seq_len(n.chains * n.samples * n.obs),
c(n.obs, n.samples, n.chains),
dimnames = list(as.character(seq_len(n.obs)), NULL, NULL))
bartSamples <- dbarts:::convertSamplesFromDbartsToBart(dbartsSamples, n.chains)
expect_equal(dim(bartSamples), c(n.chains, n.samples, n.obs))
expect_equal(dimnames(bartSamples)[c(3L, 2L, 1L)], dimnames(dbartsSamples))
for (k in seq_len(n.chains))
expect_equal(t(bartSamples[k,,]), dbartsSamples[,,k])
bartSamples.cc <- dbarts:::convertSamplesFromDbartsToBart(dbartsSamples, n.chains, combineChains = TRUE)
expect_equal(dim(bartSamples.cc), c(n.chains * n.samples, n.obs))
expect_equal(colnames(bartSamples.cc), dimnames(dbartsSamples)[[1L]])
for (k in seq_len(n.chains))
expect_equal(bartSamples[k,,], bartSamples.cc[seq_len(n.samples) + (k - 1L) * n.samples,])
expect_equal(dbarts:::uncombineChains(bartSamples.cc, n.chains), bartSamples)
expect_equal(dbarts:::combineChains(bartSamples), bartSamples.cc)
expect_equal(dbarts:::convertSamplesFromBartsToDbarts(bartSamples, n.chains),
dbartsSamples)
expect_equal(dbarts:::convertSamplesFromBartsToDbarts(bartSamples.cc, n.chains, uncombineChains = TRUE),
dbartsSamples)
})
test_that("predict fails if sampler not saved", {
bartFit <- bart(testData$x, testData$y, ndpost = 20, nskip = 5, ntree = 5L, verbose = FALSE)
expect_error(predict(bartFit, testData$x))
})
test_that("predict gives same result as x_train with linear data", {
bartFit <- bart(testData$x, testData$y, ndpost = 20, nskip = 5, ntree = 5L, verbose = FALSE, keeptrees = TRUE)
predictions <- predict(bartFit, testData$x)
expect_equal(predictions, bartFit$yhat.train)
bartFit <- bart(testData$x, testData$y, ndpost = 20, nskip = 5, ntree = 5L, nchain = 4L, nthread = 1L, verbose = FALSE, keeptrees = TRUE)
predictions <- predict(bartFit, testData$x)
expect_equal(predictions, bartFit$yhat.train)
})
test_that("extract and fitted give correct results", {
n.chains <- 4L
n.samples <- 20L
bartFit <- bart(testData$x, testData$y, testData$x[1:10,], ndpost = n.samples, nskip = 5, ntree = 5L, nchain = n.chains, verbose = FALSE)
expect_equal(extract(bartFit), bartFit$yhat.train)
expect_equal(fitted(bartFit), bartFit$yhat.train.mean)
expect_equal(extract(bartFit, sample = "test"), bartFit$yhat.test)
expect_equal(fitted(bartFit, sample = "test"), bartFit$yhat.test.mean)
extracted <- extract(bartFit, combineChains = FALSE)
for (i in seq_len(n.chains))
expect_equal(extracted[i,,], bartFit$yhat.train[seq_len(n.samples) + (i - 1L) * n.samples,])
bartFit <- bart(testData$x, testData$y, testData$x[1:10,], ndpost = n.samples, nskip = 5, ntree = 5L, nchain = n.chains, verbose = FALSE, combinechains = FALSE)
extracted <- extract(bartFit)
for (i in seq_len(n.chains))
expect_equal(extracted[seq_len(n.samples) + (i - 1L) * n.samples,], bartFit$yhat.train[i,,])
})
test_that("posterior predictive distribution samples use correct sigma", {
n.samples <- 7L
n.chains <- 2L
n.obs <- length(testData$y)
bartFit <- bart(testData$x, testData$y, verbose = FALSE,
ndpost = n.samples, nskip = 0L, nchain = n.chains,
ntree = 25L, nthread = 1L)
set.seed(0)
samples.ppd <- extract(bartFit, type = "ppd")
set.seed(0)
samples.pm <- extract(bartFit)
for (i in seq_len(n.obs))
expect_equal(samples.pm[,i] + rnorm(n.samples * n.chains, 0, bartFit$sigma),
samples.ppd[,i])
set.seed(0)
samples.ppd <- extract(bartFit, type = "ppd", combineChains = FALSE)
set.seed(0)
samples.pm <- extract(bartFit, combineChains = FALSE)
for (i in seq_len(n.obs))
expect_equal(samples.pm[,,i] + matrix(rnorm(n.samples * n.chains, 0, bartFit$sigma), nrow = n.chains),
samples.ppd[,,i])
})
test_that("fixed sample mode when run sequentially gives same predictions as sequential updates mode", {
set.seed(0)
pred.bart <- bart2(testData$x, testData$y, testData$x, n.samples = 5, n.burn = 0L, n.trees = 4L,
k = 2, n.chains = 1L, n.threads = 1L, keepTrees = TRUE, verbose = FALSE)$yhat.test
set.seed(0)
sampler <- dbarts(testData$x, testData$y,
control = dbartsControl(n.samples = 5, n.burn = 0L, n.trees = 4L,
n.chains = 1L, n.threads = 1L, keepTrees = TRUE, updateState = FALSE))
sampler$sampleTreesFromPrior()
for (i in seq_len(5L))
invisible(sampler$run(0L, 1L))
pred.dbarts <- sampler$predict(testData$x)
expect_equal(pred.bart, t(pred.dbarts))
})
test_that("sequentially running samples don't overflow with fixed trees", {
sampler <- dbarts(testData$x, testData$y,
control = dbartsControl(n.samples = 5, n.burn = 0L, n.trees = 4L, n.chains = 1L, n.threads = 1L, keepTrees = TRUE, updateState = FALSE))
sampler$sampleTreesFromPrior()
for (i in seq_len(6L))
invisible(sampler$run(0L, 1L))
expect_is(sampler, "dbartsSampler")
})
source(system.file("common", "probitData.R", package = "dbarts"), local = TRUE)
test_that("predict gives same result as x_train with binary data", {
bartFit <- bart(y.train = testData$Z, x.train = testData$X, ndpost = 20, nskip = 5, ntree = 5L, k = 4.5, verbose = FALSE, keeptrees = TRUE)
predictions <- predict(bartFit, testData$X, type = "bart")
expect_equal(predictions, bartFit$yhat.train)
bartFit <- bart(y.train = testData$Z, x.train = testData$X, ndpost = 20, nskip = 5, ntree = 5L, k = 4.5, nchain = 4L, nthread = 1L, verbose = FALSE, keeptrees = TRUE)
predictions <- predict(bartFit, testData$X, type = "bart")
expect_equal(predictions, bartFit$yhat.train)
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.