#' Generate a List of Tree Structures from BART Model Output
#'
#' Takes the tree data extracted from a BART model and organises it into a list
#' of tree structures, one per combination of iteration and tree number. The
#' trees can optionally be restricted to a single iteration and/or a single
#' tree number. A message describing the selection is emitted.
#'
#'
#' @param trees A list of tree data, as created by `extractTreeData()` in the
#' example below. It must contain a data frame named `structure` with one row
#' per node, and the values `nMCMC` and `nTree`, which are used as the upper
#' bounds when validating `iter` and `treeNo`. `structure` must include the
#' columns `iteration`, `treeNum`, `parent` and `node`. If an `obsNode` column
#' is present, the mean of each of its entries is stored in a new `respNode`
#' column. The columns `var`, `label`, `value`, `depthMax`, `noObs` and
#' `isStump` are carried over to the node data.
#' @param iter A single number specifying the iteration of the trees to be
#' included in the output. If NULL, trees from all iterations are included.
#' @param treeNo A single number specifying the tree number to include in the
#' output. If NULL, all trees are included.
#'
#' @return A list of tidygraph objects, each representing the structure of a tree. Each tidygraph object includes
#' node and edge information necessary for visualisation.
#'
#' @importFrom dplyr filter select one_of group_by group_split
#' @importFrom purrr map map2
#' @importFrom tidygraph tbl_graph
#'
#' @examples
#' if(requireNamespace("dbarts", quietly = TRUE)){
#' # Load the dbarts package to access the bart function
#' library(dbarts)
#' library(ggplot2)
#' # Get Data
#' df <- na.omit(airquality)
#' # Create Simple dbarts Model For Regression:
#' set.seed(1701)
#' dbartModel <- bart(df[2:6], df[, 1], ntree = 5, keeptrees = TRUE, nskip = 10, ndpost = 10)
#'
#' # Tree Data
#' trees_data <- extractTreeData(model = dbartModel, data = df)
#' trees_list <- treeList(trees_data)
#' }
#'
#' @export
#'
#'
treeList <- function(trees, iter = NULL, treeNo = NULL) {
  # Fail early with a readable message when 'trees' has the wrong shape
  if (!is.list(trees) || !is.data.frame(trees$structure)) {
    stop("Error: 'trees' must be a list containing a data frame named 'structure'.")
  }

  # Error check for 'iter' input
  if (!is.null(iter)) {
    # Guard the scalar comparisons below against vectors, NA and non-numbers
    if (!is.numeric(iter) || length(iter) != 1L || is.na(iter)) {
      stop("Error: 'iter' must be a single non-missing number.")
    }
    if (iter < 1 || iter > trees$nMCMC) {
      stop("Error: 'iter' value is out of allowed range. Max iteration is ", trees$nMCMC, ".")
    }
  }

  # Error check for 'treeNo' input
  if (!is.null(treeNo)) {
    if (!is.numeric(treeNo) || length(treeNo) != 1L || is.na(treeNo)) {
      stop("Error: 'treeNo' must be a single non-missing number.")
    }
    if (treeNo < 1 || treeNo > trees$nTree) {
      stop("Error: 'treeNo' value is out of allowed range. Max tree number is ", trees$nTree, ".")
    }
  }

  # Filter if selected; each restriction is applied only when requested
  df <- trees$structure
  if (!is.null(iter)) {
    df <- dplyr::filter(df, iteration == iter)
  }
  if (!is.null(treeNo)) {
    df <- dplyr::filter(df, treeNum == treeNo)
  }

  # Report what was selected
  if (is.null(iter) && is.null(treeNo)) {
    message("Displaying All Trees.")
  } else if (is.null(iter)) {
    message("Tree Number ", treeNo, " Selected.")
  } else if (is.null(treeNo)) {
    message("Iteration ", iter, " Selected.")
  } else {
    message("Iteration ", iter, " and Tree Number ", treeNo, " Selected.")
  }

  # Mean of each 'obsNode' entry per node. vapply() keeps the result a numeric
  # vector of the right length even when the selection has no rows.
  if ("obsNode" %in% names(df)) {
    df$respNode <- vapply(df$obsNode, mean, numeric(1), USE.NAMES = FALSE)
  }

  # Which columns to display
  keeps <- c(
    "var",
    "node",
    "parent",
    "iteration",
    "treeNum",
    "label",
    "value",
    "depthMax",
    "noObs",
    "respNode",
    "obsNode",
    "isStump"
  )

  # One group per tree; node and edge lists are split from the same grouping
  # so that their elements line up pairwise.
  res <- df |>
    dplyr::select(dplyr::one_of(keeps)) |>
    dplyr::group_by(iteration, treeNum)

  # Node data: everything except the parent link
  nodeList <- res |>
    dplyr::select(-parent) |>
    dplyr::group_split(.keep = TRUE)

  # Edge data: parent -> node pairs; rows without a parent have no incoming edge
  edgeList <- res |>
    dplyr::select(iteration, treeNum, parent, node) |>
    dplyr::group_split(.keep = FALSE) |>
    purrr::map(\(edges) dplyr::filter(edges, !is.na(parent)))

  # Turn into data structure for tidy graph manipulation
  purrr::map2(
    nodeList,
    edgeList,
    \(nodes, edges) tidygraph::tbl_graph(
      nodes = nodes,
      edges = edges,
      directed = TRUE
    )
  )
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.