Nothing
#' Plot Trees with Customisations
#'
#' This function plots trees from a list of tidygraph objects. It allows for various
#' customisations such as fill colour based on node response or value, node size adjustments,
#' and color palettes.
#'
#' @param trees A data frame of trees.
#' @param iter An integer specifying the iteration number of trees to be included in the output.
#' If NULL, trees from all iterations are included.
#' @param treeNo An integer specifying the number of the tree to include in the output.
#' If NULL, all trees are included.
#' @param fillBy A character string specifying the attribute to color nodes by.
#' Options are 'response' for coloring nodes based on their mean response values or
#' 'mu' for coloring nodes based on their predicted value, or NULL for no
#' specific fill attribute.
#' @param sizeNodes A logical value indicating whether to adjust node sizes.
#' If TRUE, node sizes are adjusted; if FALSE, all nodes are given the same size.
#' @param removeStump A logical value. If TRUE, then stumps are removed from plot.
#' @param selectedVars A vector of selected variables to display. Either a character vector of names
#' or the variables column number.
#' @param pal A colour palette for node colouring. Palette is used when 'fillBy' is specified for gradient colouring.
#' @param center_Mu A logical value indicating whether to center the color scale for the 'mu'
#' attribute around zero. Applicable only when 'fillBy' is set to "mu".
#' @param cluster A character string that specifies the criterion for reordering trees in the output.
#' Currently supports "depth" for ordering by the maximum depth of nodes, and "var" for a
#' clustering based on variables. If NULL, no reordering is performed.
#'
#' @return A ggplot object representing the plotted trees with the specified customisations.
#'
#' @importFrom purrr map
#' @importFrom tidygraph activate pull tbl_graph
#' @importFrom dplyr as_tibble mutate case_when
#' @importFrom tidyr replace_na
#' @importFrom scales hue_pal squish
#' @importFrom ggnewscale new_scale_fill new_scale_color
#' @importFrom igraph gsize
#' @import ggplot2
#' @import ggraph
#' @importFrom grDevices colorRampPalette
#' @importFrom stats as.formula
#'
#' @examples
#'if (requireNamespace("dbarts", quietly = TRUE)) {
#' # Load the dbarts package to access the bart function
#' library(dbarts)
#' # Get Data
#' df <- na.omit(airquality)
#' # Create Simple dbarts Model For Regression:
#' set.seed(1701)
#' dbartModel <- bart(df[2:6],
#' df[, 1],
#' ntree = 5,
#' keeptrees = TRUE,
#' nskip = 10,
#' ndpost = 10
#' )
#' # Tree Data
#' trees_data <- extractTreeData(model = dbartModel, data = df)
#' plotTrees(trees = trees_data, fillBy = 'response', sizeNodes = TRUE)
#' }
#'
#' @export
plotTrees <- function(trees,
iter = NULL,
treeNo = NULL,
fillBy = NULL,
sizeNodes = FALSE,
removeStump = FALSE,
selectedVars = NULL,
pal = rev(colorRampPalette(c('steelblue', '#f7fcfd', 'orange'))(5)),
center_Mu = TRUE,
cluster = NULL){
# get variable names
variable_names <- trees$varName
# make list of trees
trees <- treeList(trees = trees,
iter = iter,
treeNo = treeNo)
# sort graph
if(is.null(cluster)){
facet_name = "iteration"
} else{
if (cluster == "var") {
trees <- clusterTrees(trees)
facet_name <- "varID"
} else if (cluster == "depth") {
facet_name <- "-depthMax"
}
}
# remove stump
if (removeStump) {
# Remove trees with no edges
trees <- Filter(function(x) igraph::gsize(x) > 0, trees)
}
# get unique 'var' values from each tree
all_vars <- purrr::map(trees, function(tree) {
tree %>%
activate(nodes) |>
as_tibble() |>
pull(var) |>
unique()
}) |>
unlist() |>
unique()
# Select variables to display
if(!is.null(selectedVars)){
if(is.numeric(selectedVars)){
selected_names <- variable_names[selectedVars]
not_selected_names <- variable_names[-selectedVars]
}else if(is.character(selectedVars)){
selected_names <- variable_names[variable_names %in% selectedVars]
not_selected_names <- variable_names[!(variable_names %in% selectedVars)]
}
}
# set stump name
if (!is.null(fillBy)) {
stump_name <- "Stump"
} else {
stump_name <- "Stump/Leaf"
}
# Initialize empty data frames for nodes and edges
all_nodes <- data.frame()
all_edges <- data.frame()
# Extract and combine nodes and edges from each tree
for (i in seq_along(trees)) {
nodes <- trees[[i]] |> activate(nodes) |> as_tibble()
edges <- trees[[i]] |> activate(edges) |> as_tibble()
# Adjust node IDs in edges to make them unique across combined graph
edge_offset <- ifelse(nrow(all_nodes) == 0, 0, max(all_nodes$node))
edges$from <- edges$from + edge_offset
edges$to <- edges$to + edge_offset
# Combine
all_nodes <- rbind(all_nodes, mutate(nodes, node = node + edge_offset))
all_nodes$var[is.na(all_nodes$var)] <- stump_name
all_edges <- rbind(all_edges, edges)
}
# rename other variables if selected vars
if(!is.null(selectedVars)){
all_nodes$var <- ifelse(all_nodes$var %in% not_selected_names, "Others", all_nodes$var)
}
# add varID columns for use when clustering by var
all_nodes <- all_nodes |>
mutate(varID = cumsum(c(1, diff(treeNum) != 0)))
# get the limits
leaf_stumps <- all_nodes |>
filter(var == "Stump")
if(is.null(fillBy)) {
fill_value <- NULL
lims <- NULL
legend_name <- 'Variable'
} else if(fillBy == 'response') {
lims <- range(leaf_stumps$respNode, na.rm = TRUE)
lims <- pretty(c(lims[1], lims[2]))
lims <- c(min(lims), max(lims))
legend_name <- 'Mean \nResponse'
all_nodes <- all_nodes |>
mutate(fill_value = respNode)
} else if(fillBy == "mu") {
lims <- range(leaf_stumps$value, na.rm = TRUE)
if(center_Mu){
lims <- c(-max(abs(lims)), max(abs(lims)))
}else{
lims <- pretty(c(lims[1], lims[2]))
lims <- c(min(lims), max(lims))
}
legend_name <- 'Mu'
all_nodes <- all_nodes |>
mutate(fill_value = value)
}
# set node colours
nodeNames <- unique(all_nodes$var) #sort(all_vars)
nodecolors <- setNames(scales::hue_pal(c(0, 360) + 15, 100, 64, 0, 1)(length(nodeNames)), nodeNames)
# set stump colour
if(is.null(fillBy)){
nodecolors[[stump_name]] <- '#808080'
} else{
mean_value <- ifelse(fillBy == 'response', all_nodes$respNode[1],
if(fillBy == 'mu') mean(leaf_stumps$value))
# make sure stumps are coloured by mean value
all_nodes$fill_value <- ifelse(all_nodes$isStump, mean_value, all_nodes$fill_value)
max_colours <- 1000000 # threshold for the maximum number of colors
# error handling if too many colours are selected
if(lims[2] > max_colours) {
warning(paste("lims[2] is too large (", lims[2], "). Using max_colors =", max_colours))
pal_stump <- rev(colorRampPalette(c('steelblue', '#f7fcfd', 'orange'))(max_colours))
} else {
pal_stump <- rev(colorRampPalette(c('steelblue', '#f7fcfd', 'orange'))(length(nodecolors)))
}
nodecolors[[stump_name]] <- get_stump_colour_for_legend(lims = lims,
mean_value = mean_value,
palette = pal_stump)
}
# set "Others" node colour if selected
if(!is.null(selectedVars)){
nodecolors[['Others']] <- '#e6e6e6'
}
if(removeStump){
nodecolors <- nodecolors[setdiff(names(nodecolors), stump_name)]
}
# set node size
if(sizeNodes){
all_nodes <- all_nodes |>
mutate(dynamic_weight = noObs)
}else{
all_nodes <- all_nodes |>
mutate(dynamic_weight = 1)
}
# Create the combined graph
combined_graph <- tbl_graph(nodes = all_nodes, edges = all_edges, directed = TRUE)
# reconfigure var column for plotting
combined_graph <- combined_graph |>
mutate(var = dplyr::case_when(
var == 'Stump' & isStump == FALSE ~ NA_character_,
TRUE ~ var
))
# plot set up
num_plots <- length(trees)
# Dynamic settings based on the number of plots
if (num_plots <= 250) {
panel_spacing_x <- unit(1, "lines")
} else if (num_plots <= 500){
panel_spacing_x <- unit(0.5, "lines")
} else {
panel_spacing_x <- unit(0.25, "lines")
}
if(!is.null(cluster)){
facet_formula <- as.formula(paste0("~ ", facet_name, " + iteration + treeNum"))
} else{
facet_formula <- as.formula(paste0("~ iteration + treeNum"))
}
suppressMessages(
p <- ggraph(combined_graph, layout = "partition", weight = dynamic_weight) +
geom_node_tile(aes(fill = var), linewidth = 0.25) +
scale_y_reverse() +
scale_fill_manual(values = nodecolors,
name = "Variable",
na.value = "#808080") +
facet_nodes(facets = facet_formula, scales = "free") +
theme_void() +
theme(aspect.ratio = 1,
legend.position = "right",
panel.spacing.x = unit(panel_spacing_x, "lines"),
strip.text.x = element_text(size = 0)) +
ggnewscale::new_scale_fill() +
ggnewscale::new_scale_color() +
geom_node_tile(linewidth = 0.15,
data = . %>% filter(is.na(var)),
aes(fill = fill_value)) +
scale_fill_gradientn(
colours = pal,
limits = lims,
name = legend_name,
guide = guide_colorbar(
frame.colour = "black",
ticks.colour = "black",
order = 2
),
)
)
return(p)
}
# END
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.