# tests/testthat/test-18-extractedTrees.R

# Tests for pulling fitted trees out of models via extract(fit, "trees").
context("extract for trees")

# Run the Friedman data script shipped in the dbarts "common" directory,
# evaluating it in the current environment (local = TRUE).
# NOTE(review): presumably this script defines `testData`, which is used
# below — confirm against friedmanData.R.
source(system.file("common", "friedmanData.R", package = "dbarts"), local = TRUE)

# Bundle testData's x and y into a single data frame; it is passed as the data
# argument to the model fits in the tests below (formula y ~ .).
df <- with(testData, data.frame(x, y))

test_that("base bart extracts trees correctly", {
  # A deliberately tiny run: a handful of trees and draws is enough to
  # exercise the sample/tree indexing of the extracted data frame.
  n.trees <- 3L
  n.samples <- 4L
  fit <- bart(
    y ~ ., df,
    nthread = 1L, ntree = n.trees, nskip = 0L, ndpost = n.samples,
    keeptrees = TRUE, verbose = FALSE
  )
  allTrees <- extract(fit, "trees")
  treeColumns <- colnames(allTrees)

  # The extraction must be indexed by sample and tree, with no chain column.
  expect_true(all(c("sample", "tree") %in% treeColumns))
  expect_false("chain" %in% treeColumns)

  # Every (sample, tree) pair has to show up at least once in the output.
  expectedPairs <- expand.grid(
    tree = seq_len(n.trees),
    sample = seq_len(n.samples)
  )
  expectedKeys <- interaction(expectedPairs$sample, expectedPairs$tree)
  observedKeys <- interaction(allTrees$sample, allTrees$tree)
  expect_true(all(expectedKeys %in% observedKeys))

  # Extracting one sample at a time and stacking the pieces must reproduce
  # the result of extracting everything in a single call.
  perSample <- vector("list", n.samples)
  for (i in seq_len(n.samples)) {
    perSample[[i]] <- extract(fit, "trees", sampleNums = i)
  }
  stacked <- do.call(rbind, perSample)
  # Stacking can leave non-sequential row names; renumber before comparing.
  row.names(stacked) <- as.character(seq_len(nrow(stacked)))

  expect_equal(allTrees, stacked)
})

# Add a random group-level intercept to the response and store the group
# labels in df$g, which the rbart test below uses as its grouping variable.
# The temporaries live inside local() so they do not leak into the test file's
# environment.
df <- local({
  groupCount <- 5L
  groupIds <- sample(groupCount, length(testData$y), replace = TRUE)

  interceptSd <- 1.5
  groupIntercepts <- rnorm(groupCount, 0, interceptSd)

  df$y <- df$y + groupIntercepts[groupIds]
  df$g <- groupIds
  df
})

test_that("rbart extracts trees correctly", {
  # A deliberately tiny multi-chain run: enough to exercise the
  # chain/sample/tree indexing of the extracted data frame.
  n.trees <- 3L
  n.samples <- 4L
  n.chains <- 2L
  fit <- rbart_vi(y ~ ., df, group.by = g,
                  n.threads = 1L, n.trees = n.trees, n.burn = 0L, n.thin = 1L,
                  n.chains = n.chains,
                  n.samples = n.samples, keepTrees = TRUE, verbose = FALSE)
  allTrees <- extract(fit, "trees")

  # With more than one chain the extraction must also carry a chain index.
  expect_true(all(c("sample", "chain", "tree") %in% colnames(allTrees)))

  # Every (chain, sample, tree) triple has to show up at least once.
  combinations <- data.frame(
    chain  = rep(seq_len(n.chains), each = n.trees * n.samples),
    sample = rep(rep(seq_len(n.samples), each = n.trees), times = n.chains),
    tree   = rep(rep(seq_len(n.trees), times = n.samples), times = n.chains)
  )
  # The chain index is part of the key: comparing only (sample, tree) would
  # still pass if an entire chain were missing from the extraction.
  expectedKeys <- interaction(combinations$chain, combinations$sample,
                              combinations$tree)
  observedKeys <- interaction(allTrees$chain, allTrees$sample, allTrees$tree)
  expect_true(all(expectedKeys %in% observedKeys))

  # Extracting one chain at a time and stacking the pieces must reproduce the
  # result of extracting everything in a single call.
  individualChains <-
    lapply(seq_len(n.chains), function(i) extract(fit, "trees", chainNums = i))
  individualChains <- do.call(rbind, individualChains)
  # Stacking can leave non-sequential row names; renumber before comparing.
  row.names(individualChains) <- as.character(seq_len(nrow(individualChains)))

  expect_equal(allTrees, individualChains)
})
# vdorie/dbarts documentation built on May 6, 2024, 7:32 p.m.