# Nothing
## ----setup, include=FALSE-----------------------------------------------------
# Make every chunk echo its source code in the rendered document by default.
knitr::opts_chunk$set(echo = TRUE)
## ----simulateData-------------------------------------------------------------
# Noise-free mean response used to simulate the outcome.
#
# x: matrix (or data frame) indexed by column position; columns 1 through 5
#    are read, any further columns are ignored.
#
# Returns one value per row of `x`.
f <- function(x) {
  interactionTerm <- 10 * sin(pi * x[, 1] * x[, 2])
  quadraticTerm <- 20 * (x[, 3] - 0.5)^2
  # The terms are summed left to right, exactly as in a single expression, so
  # the floating-point result does not depend on how the terms are named.
  interactionTerm + quadraticTerm + 10 * x[, 4] + 5 * x[, 5]
}
# Simulation settings: residual standard deviation and number of observations.
sigma <- 1.0
n <- 100

# Seed the generator right before the first random draw so that the predictors
# and the noise are reproducible.
set.seed(99)

# Ten uniform predictors per observation; f() reads only the first five.
x <- matrix(runif(n * 10), nrow = n, ncol = 10)

# The outcome is the mean response plus Gaussian noise.
y <- rnorm(n, mean = f(x), sd = sigma)

data <- data.frame(x, y)
## ----fitModel-----------------------------------------------------------------
library(dbarts, quietly = TRUE)
# Fit a deliberately tiny BART model (few trees, few retained samples) so that
# the saved trees are small enough to inspect by hand.
bartFit <- bart(
  y ~ ., data,
  ndpost = 4,        # number of posterior samples
  nskip = 1000,      # number of "warmup" samples to discard
  nchain = 2,        # number of independent, parallel chains
  nthread = 1,       # units of parallel execution
  ntree = 3,         # number of trees per chain
  seed = 2,          # chosen to generate a deep tree
  keeptrees = TRUE,  # retain the trees; they are extracted from the fit later
  verbose = FALSE
)
## ----extractTrees-------------------------------------------------------------
# Pull the saved trees out of the fit as a single flattened data frame. The
# rest of this script relies on its columns chain, sample, tree, var, value and
# n, on var == -1 marking a leaf, and on each tree's rows being stored so that
# a node is followed by its whole left subtree and then its right subtree.
trees <- extract(bartFit, "trees")
## ----printFlattenedTrees------------------------------------------------------
# Preview the first ten rows of the flattened representation.
print(head(trees, n = 10))
## ----rebuildTrees-------------------------------------------------------------
# Converts the flattened data frame of a single tree into nested lists, i.e. a
# "natural" tree structure.
#
# tree:   data frame holding the rows of one tree with columns `var`, `value`
#         and `n`. A node is followed by its left subtree and then its right
#         subtree; `var == -1` marks a leaf.
# object: fitted model; the column names of `object$fit$data@x` are used to
#         translate split-variable indices into names.
#
# Returns a list with `value` and `n` for every node, plus `var`, `left` and
# `right` for internal nodes.
rebuildTree <- function(tree, object) {
  splitNames <- colnames(object$fit$data@x)

  # Builds the node described by the first row of `rows`, together with
  # everything below it. The temporary `n_nodes` element tells the caller how
  # many rows the subtree occupied; it is stripped once it has been used.
  buildNode <- function(rows) {
    current <- list(
      value = rows$value[1],
      n = rows$n[1]
    )

    # Leaves end the recursion and occupy exactly one row.
    if (rows$var[1] == -1) {
      current$n_nodes <- 1
      return(current)
    }

    current$var <- splitNames[rows$var[1]]

    # The left subtree begins on the row directly after the current node.
    leftChild <- buildNode(rows[-1, ])
    leftSize <- leftChild$n_nodes
    leftChild$n_nodes <- NULL
    current$left <- leftChild

    # The right subtree begins once the current node and the whole left
    # subtree have been skipped.
    rightRows <- seq.int(2 + leftSize, nrow(rows))
    rightChild <- buildNode(rows[rightRows, ])
    rightSize <- rightChild$n_nodes
    rightChild$n_nodes <- NULL
    current$right <- rightChild

    current$n_nodes <- 1L + leftSize + rightSize
    current
  }

  rebuilt <- buildNode(tree)
  rebuilt$n_nodes <- NULL
  rebuilt
}
# Keep only the rows that belong to one tree: chain 1, posterior sample 3,
# tree 1.
treeOfInterest <- subset(trees, chain == 1 & sample == 3 & tree == 1)
# Show that tree as a nested list instead of flattened rows.
print(rebuildTree(treeOfInterest, bartFit))
## ----rebuildAllTrees----------------------------------------------------------
# Rebuild every saved tree: by() splits the flattened rows on each
# chain/sample/tree combination and hands every piece to rebuildTree, passing
# the fit along as its `object` argument.
allTrees <- by(
data = trees,
INDICES = trees[,c("chain", "sample", "tree")],
FUN = rebuildTree,
object = bartFit)
# One way to index the result of this:
# allTrees[chain = "1", sample = "2", tree = "3"]
## ----plotTree-----------------------------------------------------------------
# Draw the tree for chain 1, sample 3, tree 1 (the same indices used to select
# treeOfInterest) with the plotTree method of bartFit$fit.
bartFit$fit$plotTree(chainNum = 1, sampleNum = 3, treeNum = 1)
## ----getPredictionsForTree----------------------------------------------------
# Evaluates a single flattened tree on the rows of `x`.
#
# tree: data frame holding the rows of one tree with columns `var` and `value`.
#       A node is followed by its left subtree and then its right subtree;
#       `var == -1` marks a leaf, whose `value` is the prediction.
# x:    matrix of predictors; split variables are looked up by column position.
#
# Returns a numeric vector with one prediction per row of `x`.
getPredictionsForTree <- function(tree, x) {
  predictions <- rep(NA_real_, nrow(x))

  # Sends the observations listed in `rowIds` down the subtree rooted at the
  # first row of `subtree`. The return value is the number of rows the subtree
  # occupies, which the caller needs to find the start of the sibling subtree.
  walkSubtree <- function(subtree, rowIds) {
    splitVar <- subtree$var[1]
    splitValue <- subtree$value[1]

    if (splitVar == -1) {
      # `<<-` updates `predictions` in the enclosing function call.
      predictions[rowIds] <<- splitValue
      return(1)
    }

    # Observations at or below the cut point follow the left branch.
    sendLeft <- x[rowIds, splitVar] <= splitValue

    leftSize <- walkSubtree(subtree[-1, ], rowIds[sendLeft])

    # Skip the current node and the whole left subtree to reach the right one.
    rightRows <- seq.int(2 + leftSize, nrow(subtree))
    rightSize <- walkSubtree(subtree[rightRows, ], rowIds[!sendLeft])

    1 + leftSize + rightSize
  }

  walkSubtree(tree, seq_len(nrow(x)))
  predictions
}
# Evaluate the tree of interest on the first five rows of the predictor matrix
# stored in the fit.
getPredictionsForTree(treeOfInterest, bartFit$fit$data@x[1:5,])
## ----mapOverNodes-------------------------------------------------------------
# Applies `f` to every node of a flattened tree and returns the tree as nested
# lists with the result of each call stored in the node's `f.x` element.
#
# tree: data frame holding the rows of one tree with columns `var`, `value`
#       and `n`. A node is followed by its left subtree and then its right
#       subtree; `var == -1` marks a leaf.
# f:    function called as f(node, ...). `node` is a list with `value`, `n`
#       and `depth` (the root has depth 1); internal nodes also carry `var`,
#       and leaves carry a bookkeeping element `n_nodes`.
# ...:  forwarded unchanged to every call of `f`.
#
# Nodes are visited parent first, then the left subtree, then the right one.
mapOverNodes <- function(tree, f, ...) {
  visit <- function(rows, level, f, ...) {
    current <- list(
      value = rows$value[1],
      n = rows$n[1],
      depth = level
    )
    splitVar <- rows$var[1]

    # A leaf occupies a single row and has no children to descend into.
    if (splitVar == -1) {
      current$n_nodes <- 1
      current$f.x <- f(current, ...)
      return(current)
    }

    # `f` sees an internal node before either child has been attached.
    current$var <- splitVar
    current$f.x <- f(current, ...)

    leftChild <- visit(rows[-1, ], level + 1, f, ...)
    leftSize <- leftChild$n_nodes
    leftChild$n_nodes <- NULL
    current$left <- leftChild

    # Skip the current node and the whole left subtree to reach the right one.
    rightRows <- seq.int(2 + leftSize, nrow(rows))
    rightChild <- visit(rows[rightRows, ], level + 1, f, ...)
    rightSize <- rightChild$n_nodes
    rightChild$n_nodes <- NULL
    current$right <- rightChild

    current$n_nodes <- 1 + leftSize + rightSize
    current
  }

  mapped <- visit(tree, 1, f, ...)
  mapped$n_nodes <- NULL
  mapped
}
## ----observeInteractions------------------------------------------------------
# Node callback that records which split variables appear below which other
# split variables in a tree. It is meant to be called once per node, parent
# first, then the left subtree, then the right subtree.
#
# node: list describing the current node. `var` (integer index of the split
#       variable) is absent for leaves; `depth` is the node's depth, with the
#       root at depth 1.
# ...:  must contain `interactionData`, an environment holding
#         currentDepth      depth of the most recently visited internal node,
#         currentVariables  split variables on the path to that node,
#         hasInteraction    logical matrix, rows = ancestors and
#                           columns = descendants.
#       The environment is updated in place.
#
# Returns NULL; the function is used for its side effects only.
observeInteractions <- function(node, ...) {
  # Leaves have no split variable, so there is nothing to record.
  if (is.null(node$var)) return(NULL)

  interactionData <- list(...)$interactionData
  if (!is.environment(interactionData)) {
    stop("'interactionData' must be supplied as an environment", call. = FALSE)
  }

  # All state is read from and written to the environment that was passed in,
  # so the function works no matter what that environment is called by the
  # caller and never creates temporary bindings inside it.
  ancestors <- interactionData$currentVariables
  if (node$depth <= interactionData$currentDepth) {
    # We have backtracked to go down a right branch, so the variables that
    # were collected on the way down the left branch no longer apply.
    ancestors <- ancestors[seq_len(node$depth - 1)]
  }

  if (length(ancestors) > 0) {
    # Every variable on the path from the root is an ancestor of this split.
    interactionData$hasInteraction[ancestors, node$var] <- TRUE
  }

  interactionData$currentVariables <- c(ancestors, node$var)
  interactionData$currentDepth <- node$depth

  NULL
}
# Names and number of the predictors, taken from the matrix stored in the fit.
variableNames <- colnames(bartFit$fit$data@x)
numVariables <- ncol(bartFit$fit$data@x)

# The traversal state lives in an environment because environments are
# modified in place, so updates made by the node callback remain visible here.
interactionData <- list2env(list(
  currentDepth = 0,
  currentVariables = integer(),
  hasInteraction = matrix(
    FALSE,
    nrow = numVariables,
    ncol = numVariables,
    dimnames = list(ancestor = variableNames, descendant = variableNames)
  )
))

# Walk the tree for its side effects on `interactionData`; the nested list
# that mapOverNodes returns is not needed.
invisible(
  mapOverNodes(
    treeOfInterest,
    observeInteractions,
    interactionData = interactionData
  )
)
## ----printObserveInteractionResults-------------------------------------------
# Show the ancestor-by-descendant matrix: rows are ancestor variables and
# columns are descendant variables.
print(interactionData$hasInteraction)
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.