knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.path = "vig/"
)
options(rmarkdown.html_vignette.check_title = FALSE)
For more detailed information and a comprehensive discussion, please refer to our paper associated with this document, available at DOI: 10.52933/jdssv.v4i1.79.
Tree-based regression and classification has become a standard tool in modern data science. Bayesian Additive Regression Trees (BART)^[Chipman, H. A., George, E. I., & McCulloch, R. E. (2010). BART: Bayesian additive regression trees. The Annals of Applied Statistics, 4(1), 266-298.] has gained particularly wide popularity due to its flexibility in dealing with interactions and non-linear effects. BART is a Bayesian tree-based machine learning method that can be applied to both regression and classification problems, and it often yields results that are competitive with, or better than, those of other predictive models. As a Bayesian model, BART allows the practitioner to explore the uncertainty around predictions through the posterior distribution. In the bartMan package, we present new visualisation techniques for exploring BART models. We construct conventional plots to analyze a model's performance and stability, as well as new tree-based plots to analyze variable importance, interactions, and tree structure. We employ Value-Suppressing Uncertainty Palettes (VSUP)^[Correll, M., Moritz, D., & Heer, J. (2018, April). Value-suppressing uncertainty palettes. In Proceedings of the 2018 CHI Conference on Human Factors in Computing Systems (pp. 1-11).] to construct heatmaps that display variable importance and interactions jointly, using a color scale to represent posterior uncertainty. Our new visualisations are designed to work with the most popular BART R packages, namely BART
^[Sparapani R, Spanbauer C, McCulloch R (2021). Nonparametric Machine Learning and Efficient Computation with Bayesian Additive Regression Trees: The BART R Package. Journal of Statistical Software], dbarts
^[Vincent Dorie, dbarts: Discrete Bayesian Additive Regression Trees Sampler, 2020], and bartMachine
^[Adam Kapelner, Justin Bleich (2016). bartMachine: Machine Learning with Bayesian Additive Regression Trees. Journal of Statistical Software].
In this document, we demonstrate our visualisations for evaluation of BART models using the bartMan
(BART Model ANalysis) package.
BART (Bayesian Additive Regression Trees) is a Bayesian non-parametric model that uses an ensemble of trees to predict continuous and multi-class responses. Unlike linear regression, BART adopts a flexible functional form, which allows it to uncover main and interaction effects. The model is formulated as
$$ y_i = \sum_{j=1}^m g(x_i, T_j, M_j) + \varepsilon_i $$
with $\varepsilon_i \sim N(0, \sigma^2)$, where $g(x_i, T_j, M_j) = \mu_{j\ell}$ maps observation $x_i$ to the predicted value $\mu_{j\ell}$ of the terminal node $\ell$ of tree $j$ into which it falls. $T_j$ and $M_j$ denote the tree's structure and the set of predicted values at its terminal nodes, respectively. Tree structures consist of binary splits of the form $[x_r \leq c]$ for a predictor $x_r$ and cut-point $c$, with split variables and cut-points chosen randomly and updated throughout the model fitting process. This updating occurs via Markov chain Monte Carlo (MCMC) methods, through tree modifications such as growing, pruning, changing, or swapping nodes, with the specifics depending on the implementation.
Chipman et al. (2010) propose a measure called the inclusion proportion to evaluate variable importance in a BART model from the posterior samples of the tree structures. For each iteration, this measure first calculates the proportion of times a variable is used to split nodes across all $m$ trees, and then averages these proportions across all iterations.
More formally, let $K$ be the number of posterior samples obtained from a BART model. Let $c_{rk}$ be the number of splitting rules using the $r$th predictor as a split variable in the $k$th posterior sample of the trees' structure across $m$ trees. Additionally, let $c_{.k} = \sum_{r=1}^p c_{rk}$ represent the total number of splitting rules found in the $k$th posterior sample across all $p$ variables. Then $z_{rk} = c_{rk} / c_{.k}$ is the proportion of splitting rules for the $r$th variable, and the average use per splitting rule is given by:
$$ \text{VImp}_r = \frac{1}{K} \sum_{k=1}^K z_{rk} $$
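The two-step calculation above (proportions within each posterior sample, then an average over samples) can be sketched in a few lines of base R. This is a minimal illustration of the formula, not bartMan's implementation, and the matrix `counts` of split counts $c_{rk}$ (one row per posterior sample $k$, one column per predictor $r$) is hypothetical.

```r
# Hypothetical split counts c_{rk}: rows = posterior samples (k), columns = predictors (r)
counts <- matrix(c(4, 2, 0,
                   3, 3, 1,
                   5, 1, 0),
                 nrow = 3, byrow = TRUE,
                 dimnames = list(NULL, c("x1", "x2", "x3")))

# z_{rk} = c_{rk} / c_{.k}: divide each row by that sample's total number of splitting rules
z <- counts / rowSums(counts)

# VImp_r = (1/K) * sum_k z_{rk}: average the proportions over the K posterior samples
vimp <- colMeans(z)
vimp
```

Because every row of `z` sums to one, the importances also sum to one. A posterior sample made up only of stumps would have $c_{.k} = 0$ and would need to be handled separately.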
Variable interaction refers to the combined effect of two or more variables on the response. We focus on bivariate interactions, identifying them through the dependency of tree structures on multiple variables. Chipman et al. (2010)^[Chipman, H. A., George, E. I., & McCulloch, R. E. (2010). BART: Bayesian additive regression trees. The Annals of Applied Statistics, 4(1), 266-298.] and Kapelner & Bleich (2016)^[Adam Kapelner, Justin Bleich (2016). bartMachine: Machine Learning with Bayesian Additive Regression Trees. Journal of Statistical Software] introduced interaction measures based on the analysis of successive splitting rules in tree models. Similar approaches are applied in random forests, where minimal depth is used to evaluate interaction and importance. In contrast to random forests, where splits are optimized, BART models employ a stochastic search, so the order of successive splits is treated as inconsequential. The interaction strength between two variables $r$ and $q$ is quantified by
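$$ \text{VInt}_{rq} = \frac{1}{K} \sum_{k=1}^K z_{rqk} $$
where $z_{rqk}$ is the proportion of successive splitting rules (a split together with the split immediately below it) in the $k$th posterior sample that involve variables $r$ and $q$, in either order.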
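A sketch of the counting behind $z_{rqk}$ follows, in base R and under assumptions that are ours rather than bartMan's. The hypothetical data frame `trees` mimics a few columns of trees_data$structure (iteration, treeNum, node, parent, var). A successive pair is a splitting node together with its splitting parent, the order within a pair is ignored, and the denominator is the total number of such pairs in each posterior sample.

```r
# Hypothetical tree data: two posterior samples (iterations) with one tree each.
# 'parent' is the parent's node id within the same tree (NA for the root);
# 'var' is NA for terminal nodes.
trees <- data.frame(
  iteration = c(rep(1, 5), rep(2, 7)),
  treeNum   = 1,
  node      = c(1:5, 1:7),
  parent    = c(NA, 1, 2, 2, 1, NA, 1, 1, 3, 4, 4, 3),
  var       = c("x1", "x2", NA, NA, NA, "x2", NA, "x1", "x3", NA, NA, NA),
  stringsAsFactors = FALSE
)

# Look up each node's parent variable within the same iteration and tree
key_node   <- paste(trees$iteration, trees$treeNum, trees$node)
key_parent <- paste(trees$iteration, trees$treeNum, trees$parent)
trees$parentVar <- trees$var[match(key_parent, key_node)]

# Successive pairs: splitting nodes whose parent is also a splitting node
succ <- trees[!is.na(trees$var) & !is.na(trees$parentVar), ]
# Ignore the order within a pair by sorting the two variable names
succ$pair <- paste(pmin(succ$var, succ$parentVar),
                   pmax(succ$var, succ$parentVar), sep = ":")

# c_{rqk}: pair counts per posterior sample; z_{rqk}: divide by each sample's total
counts <- table(succ$iteration, succ$pair)
z <- counts / rowSums(counts)
vint <- colMeans(z)   # VInt_{rq}: average over posterior samples
vint
```

Posterior samples with no successive splits do not appear in `counts` here; to average over all $K$ samples they would need to be added back as rows of zeros.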
To install the development version from GitHub, use:
# install.packages("devtools")
# devtools::install_github("AlanInglis/bartMan")
library(bartMan)
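The data used in the following examples are simulated from the Friedman benchmark problem 7^[Friedman, Jerome H. (1991) Multivariate adaptive regression splines. The Annals of Statistics 19 (1), pages 1-67.]. This benchmark problem is commonly used for testing purposes. The output is created according to the equation:
$$ y = 10 \sin(\pi x_1 x_2) + 20 (x_3 - 0.5)^2 + 10 x_4 + 5 x_5 + \varepsilon, \quad \varepsilon \sim N(0, \sigma^2), $$
where the predictors are drawn independently from a $U(0, 1)$ distribution. In the code below we generate 10 predictors, so $x_6, \dots, x_{10}$ are noise variables with no effect on the response.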
## Create Friedman data
fData <- function(n = 200, sigma = 1.0, seed = 1701, nvar = 5) {
  set.seed(seed)
  x <- matrix(runif(n * nvar), n, nvar)
  colnames(x) <- paste0("x", 1:nvar)
  Ey <- 10 * sin(pi * x[, 1] * x[, 2]) + 20 * (x[, 3] - 0.5)^2 + 10 * x[, 4] + 5 * x[, 5]
  y <- rnorm(n, Ey, sigma)
  data <- as.data.frame(cbind(x, y))
  return(data)
}
f_data <- fData(nvar = 10)
x <- f_data[, 1:10]
y <- f_data$y
Now we will create a basic BART model using the dbarts package. The visualisation process outlined in this document is the same whether the model is built with the dbarts, BART, or bartMachine package.
To begin, we load the libraries and then create our model.
# load libraries
library(dbarts)  # for model
library(ggplot2) # for plots
# create dbarts model:
set.seed(1701)
dbartModel <- bart(x, y,
                   ntree = 20,
                   keeptrees = TRUE,
                   nskip = 100,
                   ndpost = 1000
)
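In dbartModel we have specified 20 trees (ntree), 1000 post-burn-in iterations (ndpost), and a burn-in of 100 iterations (nskip). Setting keeptrees = TRUE stores the sampled trees so that they can be extracted afterwards.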
Once the model is built we can extract the data concerning the trees via the extractTreeData
function.
# Create data frames ------------------------------------------------------
trees_data <- extractTreeData(model = dbartModel, data = f_data)
The object created by the extractTreeData
function is a list containing five elements. These are:
The tree data frame created by the extractTreeData function contains 17 columns describing different attributes of the trees, and its structure is the same across all supported BART packages. It can be accessed via $structure. This data frame is used by many of the bartMan functions to create the visualisations shown later. In the code below, we take a look at the structure of the data frame of trees.
options(tibble.width = Inf) # used to display full tibble in output
head(trees_data$structure, 5)
The following table explains each of the columns of trees_data$structure:
library(knitr)

# Define a data frame with column descriptions
column_descriptions <- data.frame(
  Column = c("**var**", "**splitValue**", "**terminal**", "**leafValue**",
             "**iteration**", "**treeNum**", "**node**", "**childLeft**",
             "**childRight**", "**parent**", "**depth**", "**depthMax**",
             "**isStump**", "**label**", "**value**", "**obsNode**", "**noObs**"),
  Description = c(
    "Variable name used for splitting.",
    "Value of the variable at which the split occurs.",
    "Logical indicator if the node is terminal (TRUE) or not (FALSE).",
    "Value at the leaf node, NA for non-terminal nodes.",
    "Iteration number.",
    "Tree number.",
    "Unique identifier for the node (following depth-first-left-side traversal).",
    "Identifier for the left child of the node, NA for terminal nodes.",
    "Identifier for the right child of the node, NA for terminal nodes.",
    "Identifier for the parent of the node, NA for root nodes.",
    "Depth of the node in the tree, starting from 0 for root nodes.",
    "Maximum depth of the tree.",
    "Logical indicator if the node is a stump (TRUE) or not (FALSE).",
    "Node label.",
    "The value in a node (i.e., either the split value or leaf value).",
    "List of observations in the node, represented in a compact form.",
    "Number of observations in the node."
  )
)

# Print the table with column descriptions, ensuring Markdown is correctly processed
kable(column_descriptions, format = "markdown", escape = FALSE,
      col.names = c("Column", "Description"))
The trees in the data frame are ordered using a depth-first, left-side traversal: each parent node is listed before its children, and a node's left subtree is listed before its right subtree. An example is shown below in Figure 1, where we can see the ordering and node numbers used by this method. For example, the $structure$var column (created by the extractTreeData() function) for this tree would be ordered as X1, NA, X2, X2, NA, NA, NA, where NA indicates terminal (or leaf) nodes.
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/tree_example.png?raw=true")
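To make the depth-first, left-side ordering concrete, the sketch below rebuilds node numbers, parents, and depths from a var sequence alone, using the example ordering X1, NA, X2, X2, NA, NA, NA. It is a base R illustration that assumes every non-NA entry is a binary split and every NA is a leaf. The helper `preorder_structure()` is hypothetical and is not the code bartMan uses internally.

```r
# Hypothetical helper: rebuild parent ids and depths from a depth-first,
# left-side (pre-order) var sequence. Non-NA = split with two children, NA = leaf.
preorder_structure <- function(var) {
  n <- length(var)
  parent <- rep(NA_integer_, n)
  depth <- integer(n)
  stack <- integer(0)              # splits still waiting for a child
  remaining <- integer(n)          # number of children each split still needs
  for (i in seq_len(n)) {
    if (length(stack) > 0) {
      p <- stack[length(stack)]    # the most recent unfinished split is the parent
      parent[i] <- p
      depth[i] <- depth[p] + 1L
      remaining[p] <- remaining[p] - 1L
      if (remaining[p] == 0L) stack <- stack[-length(stack)]  # both children assigned
    }
    if (!is.na(var[i])) {          # a split: its next two subtrees are its children
      remaining[i] <- 2L
      stack <- c(stack, i)
    }
  }
  data.frame(node = seq_len(n), var = var, parent = parent, depth = depth)
}

preorder_structure(c("X1", NA, "X2", "X2", NA, NA, NA))
```

The resulting parent and depth columns correspond conceptually to the parent and depth columns described in the table above, with depth counted from 0 at the root.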
In Inglis et al. (2022)^[Inglis, A., Parnell, A., & Hurley, C. B. (2022). Visualizing Variable Importance and Variable Interaction Effects in Machine Learning Models. Journal of Computational and Graphical Statistics, 1-13.], the authors propose using a heatmap to display both variable importance (VImp) and variable interactions (VInt) simultaneously (together, VIVI), with the importance values on the diagonal and the interaction values on the off-diagonal. We adapt these heatmap displays of importance and interactions to include uncertainty by using a VSUP. To begin, we first generate a heatmap containing the raw VIVI values without uncertainty.
stdMat <- viviBartMatrix(trees_data, type = 'standard', metric = 'propMean')
Next, we create a list of two matrices: one containing the raw inclusion proportions and one containing their uncertainties. Here we use the coefficient of variation as our uncertainty measure; however, the standard deviation or standard error can be used instead by setting metricError = 'SD' or metricError = 'SE' in the code below.
vsupMat <- viviBartMatrix(trees_data, type = 'vsup', metric = 'propMean', metricError = "CV")
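For reference, the coefficient of variation is the standard deviation of a quantity's per-iteration values divided by their mean. The sketch below computes it, together with the SD and SE alternatives, for inclusion proportions stored in a hypothetical matrix `z` (iterations in rows, variables in columns). It illustrates the measures only and does not reproduce viviBartMatrix(); in particular, using the number of iterations as the sample size for the SE is an assumption.

```r
# Hypothetical per-iteration inclusion proportions z_{rk}: rows = iterations, columns = variables
z <- rbind(c(0.50, 0.40, 0.10),
           c(0.60, 0.30, 0.10),
           c(0.40, 0.50, 0.10))
colnames(z) <- c("x1", "x2", "x3")

vimp_mean <- colMeans(z)              # posterior mean inclusion proportion
vimp_sd   <- apply(z, 2, sd)          # 'SD': spread across iterations
vimp_se   <- vimp_sd / sqrt(nrow(z))  # 'SE': SD scaled by the number of iterations
vimp_cv   <- vimp_sd / vimp_mean      # 'CV': SD relative to the mean
round(vimp_cv, 3)
```

A CV near zero (as for x3 in this toy example) means the proportion barely changes across iterations, while a large CV flags an unstable estimate. The CV is undefined when the mean is zero.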
Once the matrices have been created, they can be plotted, with the VSUP version displaying the uncertainty. For illustration purposes, we show the plot without uncertainty in Figure 2 and with uncertainty in Figure 3:
viviBartPlot(stdMat)
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/heatmap_vivid_1.png?raw=true")
viviBartPlot(vsupMat, max_desat = 1, pow_desat = 0.6, max_light = 0.6, pow_light = 1, label = 'CV')
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/heatmap_vsup_1.png?raw=true")
In Figure 3, where uncertainty is included, both the importance and the interaction values for the noise variables (x6 to x10) have a high coefficient of variation. These values are therefore suppressed by the palette, so the most influential variables are highlighted.
When plotting the VSUP with uncertainty, some of the relevant function arguments include:
• unc_levels: The number of uncertainty levels.
• max_desat: The maximum desaturation level of the uncertainty palette.
• pow_desat: The power of the desaturation level.
• max_light: The maximum light brightness of the uncertainty palette.
• pow_light: The power of the light level.

Here we examine more closely the structure of the decision trees created when building a BART model. Examining the tree structure may yield information on the stability and variability of the tree structures as the algorithm iterates to create the posterior. By sorting and colouring the trees appropriately, we can identify important variables and common interactions between variables for a given iteration. Alternatively, we can look at how a single tree evolves through the iterations to explore the fitting algorithm's stability.
To plot an individual tree, we can choose to display it either in dendrogram format or in icicle format. Additionally, we can choose which tree number and iteration to display:
plotSingleTree(trees = trees_data, treeNo = 1, iter = 1, plotType = "dendrogram")
plotSingleTree(trees = trees_data, treeNo = 1, iter = 1, plotType = "icicle")
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/single_trees_1.png?raw=true")
The plotTrees
function allows for a few different options when plotting. The function arguments for plotTrees
are outlined as follows:
• trees: A data frame of trees, usually created via extractTreeData().
• iter: An integer specifying the iteration number of trees to be included in the output. If NULL, trees from all iterations are included.
• treeNo: An integer specifying the number of the tree to include in the output. If NULL, all trees are included.
• fillBy: A character string specifying the attribute to color nodes by. Options are 'response' for coloring nodes based on their mean response values, 'mu' for coloring nodes based on their predicted value, or NULL for no specific fill attribute.
• sizeNodes: A logical value indicating whether to adjust node sizes. If TRUE, node sizes are adjusted; if FALSE, all nodes are given the same size.
• removeStump: A logical value. If TRUE, stumps are removed from the plot.
• selectedVars: A vector of selected variables to display, given either as a character vector of names or as the variables' column numbers.
• pal: A colour palette for node colouring. The palette is used when 'fillBy' is specified for gradient colouring.
• center_Mu: A logical value indicating whether to center the color scale for the 'mu' attribute around zero. Applicable only when 'fillBy' is set to "mu".
• cluster: A character string that specifies the criterion for reordering trees in the output. Currently supports "depth" for ordering by the maximum depth of nodes and "var" for clustering based on variables. If NULL, no reordering is performed.
For example, we can choose to display all the trees from a selected iteration:
plotTrees(trees = trees_data, iter = 1)
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/trees_iter_1_1.png?raw=true")
When the number of variables or trees is large it can become hard to identify interesting features. By using the selectedVars
argument, we can highlight selected variables by colouring them brightly while uniformly colouring the remaining variables a light grey.
plotTrees(trees = trees_data, iter = 1, removeStump = TRUE, selectedVars = c("x1", "x2") )
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/trees_selected_1.png?raw=true")
We can plot a single tree over all iterations by selecting a tree via the treeNo argument. This visually shows BART's grow, prune, change, and swap mechanisms in action.
plotTrees(trees = trees_data, treeNo = 1)
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/trees_treeNum_1_1.png?raw=true")
When viewing the trees, it can be useful to view different aspects or metrics. In Figure 8 we show some of these aspects by displaying all the trees in a selected iteration. In (a) we color them by the terminal node parameter value. In (b) we color terminal nodes and stumps by the mean response. In (c) we sort the trees by structure, starting with the most common tree structure and descending to the least common one (useful for identifying the most important splits). Finally, in (d) we sort the trees by depth. As the $\mu$ values in (a) are centered around zero by default, we use a colorblind-friendly diverging color palette to display the values. For comparison, we use the same palette to represent the mean response values in (b).
plotTrees(trees = trees_data, iter = 1, sizeNodes = TRUE, fillBy = 'mu')
plotTrees(trees = trees_data, iter = 1, sizeNodes = TRUE, fillBy = 'response')
plotTrees(trees = trees_data, iter = 1, cluster = "var")
plotTrees(trees = trees_data, iter = 1, cluster = "depth")
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/tree_quad_1.png?raw=true")
As an alternative to sorting the tree structures, as seen in Figure 8 (c), we provide a bar plot summarizing the tree structures. Here we choose to display the 10 most frequent tree structures across all iterations; the structures from a single iteration can be displayed instead via the iter argument.
treeBarPlot(trees = trees_data, topTrees = 10, iter = NULL)
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/trees_barplot_1.png?raw=true")
From Figure 9, we can see that the most frequent tree is a tree with a single binary split on $x_3$. However, the rest of the trees show the prevalence of splits on $x_1$ and $x_2$, indicating that these are important variables.
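The frequency count behind such a bar plot can be sketched by turning each tree into a string signature of its depth-first variable sequence and tabulating the signatures. The data frame below is hypothetical (columns iteration, treeNum, and var, as in trees_data$structure), and treating two trees as having the same structure when their variable sequences match, ignoring split values, is an assumption of this sketch rather than a description of treeBarPlot().

```r
# Hypothetical trees: two iterations with two trees each (var = NA marks a leaf)
trees <- data.frame(
  iteration = c(1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2),
  treeNum   = c(1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2),
  var       = c("x3", NA, NA, "x1", NA, NA, "x3", NA, NA, "x3", NA, NA),
  stringsAsFactors = FALSE
)

# Signature of one tree: its depth-first variable sequence, with "." for leaves
tree_signature <- function(v) paste(ifelse(is.na(v), ".", v), collapse = " ")

# One signature per (iteration, tree); split() keeps rows in their original order
sigs <- sapply(split(trees$var, list(trees$iteration, trees$treeNum), drop = TRUE),
               tree_signature)

# Structure frequencies, most common first; the top rows are what a bar plot would show
freq <- sort(table(sigs), decreasing = TRUE)
freq
```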
Proximity matrices combined with multidimensional scaling (MDS) are commonly used in random forests to identify outlying observations^[Breiman, L. (2001). Random forests. Machine Learning, 45(1), 5-32.]. When two observations repeatedly lie in the same terminal node, they can be said to be similar. An $N \times N$ proximity matrix is therefore obtained by counting, for each pair of observations, the number of times this occurs, and then dividing by the total number of trees. A higher value indicates that two observations are more similar.
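The accumulation described above can be sketched in base R. The sketch assumes a hypothetical matrix `leaf` in which leaf[i, t] identifies the terminal node that observation i falls into in tree t (leaf ids only need to be unique within a tree). It illustrates the definition rather than reproducing proximityMatrix().

```r
# Hypothetical terminal-node memberships for 4 observations in 3 trees.
# matrix() fills by column, so each line below is one tree (one column of 'leaf').
leaf <- matrix(c(1, 1, 2, 2,
                 3, 3, 3, 4,
                 5, 6, 6, 6),
               nrow = 4)

proximity <- function(leaf) {
  n <- nrow(leaf)
  prox <- matrix(0, n, n)
  for (t in seq_len(ncol(leaf))) {
    # outer() is TRUE for every pair of observations sharing a terminal node in tree t
    prox <- prox + outer(leaf[, t], leaf[, t], "==")
  }
  prox / ncol(leaf)   # divide by the number of trees
}

proximity(leaf)
```

Observations 1 and 2, for example, share a terminal node in the first two trees but not the third, so their proximity is 2/3.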
To begin, we first create a proximity matrix. This can be seriated to group similar observations together by setting reorder = TRUE. The normalize argument divides the proximity scores by the total number of trees. Additionally, we can choose to compute the proximity matrix for a single iteration (as shown below) or over all iterations; the latter is achieved by setting iter = NULL.
bmProx <- proximityMatrix(trees = trees_data, reorder = TRUE, normalize = TRUE, iter = 1)
We can then visualize the proximity matrix using the plotProximity
function.
plotProximity(matrix = bmProx) +
  theme(axis.text.x = element_text(angle = 90))
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/proximity_plot_1.png?raw=true")
The proximity matrix can then be visualized using classical MDS (henceforth MDS), which displays the relationships between observations in a lower-dimensional projection.
In BART there is a proximity matrix for every iteration, and hence a posterior distribution of proximity matrices. We introduce a rotational constraint so that we can similarly obtain a posterior distribution of each observation's location in the lower-dimensional space. We first choose a target iteration (as shown above) and apply MDS. For each subsequent iteration, we rotate the MDS solution matrix to match this target as closely as possible using Procrustes' method. This gives a point for each observation per iteration per MDS dimension. We then average each observation's coordinates across iterations and produce a scatter plot, where each point represents the centroid of that observation's locations across all the MDS solutions. We extend this further by displaying confidence ellipses around each observation's posterior location in the reduced space (via the level argument). Since these ellipses often overlap, we have created an interactive version that highlights an observation's ellipse and displays the observation number when the mouse pointer hovers over the ellipse (Figure 11 shows a screenshot of this interaction in use). Non-interactive versions are also available via the plotType argument.
mdsBart(trees = trees_data, data = f_data, target = bmProx, plotType = 'interactive', level = 0.25, response = 'y')
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/mds.png?raw=true")
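The alignment step can be sketched with base R's cmdscale() for classical MDS and an orthogonal Procrustes rotation computed from an SVD. Converting proximities to distances as 1 minus proximity, the hypothetical point coordinates used to generate the proximities, and the helper names `to_prox()`, `mds2()`, and `procrustes_rotate()` are all assumptions of this sketch; mdsBart() may differ in these details.

```r
# Hypothetical data: 6 observations with known 2D positions, plus a perturbed copy
pts <- cbind(c(0, 1, 2, 0.5, 1.5, 3), c(0, 0.5, 0.2, 1.5, 1, 0.8))
noisy <- pts + cbind(c(0.05, -0.03, 0.02, 0, -0.04, 0.01),
                     c(-0.02, 0.04, 0, 0.03, -0.01, 0.02))

# Hypothetical proximity: 1 on the diagonal, decreasing with Euclidean distance
to_prox <- function(p) {
  d <- as.matrix(dist(p))
  1 - d / max(d)
}

# Classical MDS on the distances 1 - proximity, keeping 2 dimensions
mds2 <- function(prox) cmdscale(as.dist(1 - prox), k = 2)

# Orthogonal Procrustes: the orthogonal R (rotation, possibly with reflection)
# minimising ||Y %*% R - X||. cmdscale() output is already column-centred.
procrustes_rotate <- function(Y, X) {
  s <- svd(crossprod(Y, X))   # SVD of t(Y) %*% X
  Y %*% s$u %*% t(s$v)
}

X <- mds2(to_prox(pts))       # target iteration
Y <- mds2(to_prox(noisy))     # a later iteration
Y_rot <- procrustes_rotate(Y, X)

# Misfit to the target before and after rotation (rotation can only reduce it)
c(before = sum((Y - X)^2), after = sum((Y_rot - X)^2))
```

Repeating the rotation for every iteration and averaging each observation's rotated coordinates gives the centroids described above.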
In addition to the above, we also provide visualisations for general diagnostics of a BART model. These include checking for convergence, the stability of the trees, the efficiency of the algorithm, and the predictive performance of the model. To begin, we take a look at some general diagnostics to assess the stability of the model fit. The burnIn argument should be set to the burn-in value selected when building the model; it marks the separation between the pre- and post-burn-in periods in the plot.
bartDiag(model = dbartModel, response = f_data$y, burnIn = 100, data = f_data)
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/diagnostics_1.png?raw=true")
The post burn-in percentage acceptance rate across all iterations can also be visualized, where each point represents a single iteration. A regression line is shown to indicate the changes in acceptance rate across iterations and to identify the mean rate.
acceptRate(trees = trees_data)
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/accept_rate_1.png?raw=true")
As with the acceptance rate, the average tree depth and the average number of nodes per iteration can give insight into the stability of the fit. Figure 14 displays these two metrics. A locally estimated scatterplot smoothing (LOESS) regression line is shown to indicate the changes in both the average tree depth and the average number of nodes across iterations:
treeDepth(trees = trees_data)
treeNodes(trees = trees_data)
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/tree_depth_node_1.png?raw=true")
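Both summaries are per-iteration averages over the trees. A minimal sketch with base R's aggregate() is shown below. It assumes a hypothetical data frame with one row per node and columns iteration, treeNum, and depth (as in trees_data$structure). bartMan's own definitions, for example how stumps are counted, may differ.

```r
# Hypothetical nodes: iteration 2, tree 2 is a stump (a single node at depth 0)
nodes <- data.frame(
  iteration = c(1, 1, 1, 1, 1, 1, 2, 2, 2, 2),
  treeNum   = c(1, 1, 1, 2, 2, 2, 1, 1, 1, 2),
  depth     = c(0, 1, 1, 0, 1, 1, 0, 1, 1, 0)
)

# Per tree: maximum node depth and number of nodes
tree_depth <- aggregate(depth ~ iteration + treeNum, data = nodes, FUN = max)
tree_nodes <- aggregate(depth ~ iteration + treeNum, data = nodes, FUN = length)

# Per iteration: average over the trees
avg_depth <- aggregate(depth ~ iteration, data = tree_depth, FUN = mean)
avg_nodes <- aggregate(depth ~ iteration, data = tree_nodes, FUN = mean)
names(avg_nodes)[2] <- "nodes"

summary_by_iter <- merge(avg_depth, avg_nodes)
summary_by_iter
```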
Figure 15 shows the densities of split values over all post-burn-in iterations for each variable (in green), combined with the densities of the predictor variables (labeled “data”, in red):
splitDensity(trees = trees_data, data = f_data, display = 'dataSplit')
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/split_density_1.png?raw=true")
Alternatively, we can examine only the split value densities in a ridge plot style.
splitDensity(trees = trees_data, data = f_data, display = 'ridge')
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/ridges_1.png?raw=true")
To assess inclusion proportions related to variable importance and interactions, we offer functions that facilitate the extraction and visualisation of these metrics. For instance, the viviBart
function allows retrieval of a list encompassing both variable importance and interactions, along with their associated error metrics, an example of which is shown below. To select either just the importance or just the interactions, the out
argument should be set to 'vimp'
or 'vint'
, respectively.
# show both vimp and vint
viviBart(trees = trees_data, out = 'vivi')
To visualize the inclusion proportion variable importance (with the 25% to 75% quantile interval included), we use the vimpPlot function.
# plot inclusion proportions of each variable:
vimpPlot(trees = trees_data, plotType = 'point')
vimpPlot(trees = trees_data, plotType = 'barplot')
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/vimp_point_bar_1.png?raw=true")
An alternative method to display the inclusion proportions is a letter-value plot^[Hofmann, H., Wickham, H., & Kafadar, K. (2017). Letter-value plots: Boxplots for large data. Journal of Computational and Graphical Statistics, 26(3), 469-477.]. This type of plot is useful for visualizing the distribution of a continuous variable (here, the variable inclusion proportions). The innermost box shows the lower and upper fourths, as in a conventional box plot, with the median shown as a black line and outliers as blue triangles. Each extending section is drawn at incremental steps of upper and lower eighths, sixteenths, and so on, until a stopping rule has been reached. The color of each box corresponds to the density of the data, with darker shades indicating higher data density. Note, when plotType = 'lvp'
, the geom_lv
function from the lvplot
package is used. This function requires the ggplot2
package to be loaded.
library("ggplot2")
vimpPlot(trees = trees_data, plotType = 'lvp') + coord_flip()
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/vimp_lvp_1.png?raw=true")
Similarly, we provide a function for viewing the inclusion proportions for interactions, again with the 25% to 75% quantile interval included. In Figure 19, we display only the 5 strongest variable-pair interactions.
# plot inclusion proportions of each variable pair:
vintPlot(trees = trees_data, top = 5)
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/vint_point_bar_1.png?raw=true")
In our package we also implement one of the variable selection procedures developed in Bleich et al. (2014)^[Bleich, J., Kapelner, A., George, E. I., & Jensen, S. T. (2014). Variable selection for BART: an application to gene regulation. The Annals of Applied Statistics, 8(3), 1750-1781.], specifically the so-called local threshold procedure. In this method, the proportion of splitting rules for each variable is first calculated. The response variable is then randomly permuted, which breaks the relationship between the response and the covariates, and the model is rebuilt as a null model using the permuted response. The inclusion proportions of this null model are calculated, and repeating the permutation several times gives a null distribution of inclusion proportions for each variable, against which the original proportions are compared.
When using this method, there are three key arguments: numRep, numTreesRep, and alpha. numRep determines the number of replicates to perform for the BART null model's variable inclusion proportions, while numTreesRep determines the number of trees used in the replicates. alpha sets the cut-off level for the thresholds: a predictor is deemed important if its variable inclusion proportion exceeds the $1 - \alpha$ quantile of its own null distribution. If shift = TRUE, the inclusion proportions are shifted by the difference between the quantile and the inclusion proportion value.
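The thresholding rule itself can be sketched in base R. The observed inclusion proportions `vip` and the matrix `null_vip` of null-model proportions (one row per permutation replicate) are hypothetical values for three of ten predictors, and the use of quantile()'s default type is an assumption. The sketch shows only the decision rule, not the permutation and refitting performed by localProcedure().

```r
# Hypothetical observed inclusion proportions for three of ten predictors
vip <- c(x1 = 0.30, x2 = 0.25, x3 = 0.10)

# Hypothetical null inclusion proportions: one row per permutation replicate
null_vip <- rbind(c(0.12, 0.10, 0.11),
                  c(0.10, 0.13, 0.12),
                  c(0.11, 0.12, 0.13),
                  c(0.09, 0.11, 0.10),
                  c(0.13, 0.09, 0.12))
colnames(null_vip) <- names(vip)

alpha <- 0.05
# Local threshold: the (1 - alpha) quantile of each variable's own null distribution
thresholds <- apply(null_vip, 2, quantile, probs = 1 - alpha)
important <- vip > thresholds
important
```

Because each predictor is compared with its own null distribution, a predictor that is split on often merely by chance needs a correspondingly higher observed proportion to be selected.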
localProcedure(model = dbartModel, data = f_data, numRep = 5, numTreesRep = 5, alpha = 0.5, shift = FALSE)
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/local_procedure_both.png?raw=true")
In Figure 20, shift = FALSE in the left plot, whereas shift = TRUE in the right plot.
We also provide functionality for a variable selection approach that creates a null model by permuting the response once, rebuilding the model, and calculating the inclusion proportions on the null model. The final result displayed is the original model's inclusion proportion minus the null inclusion proportion. This function is available for both the importance and the interactions.
permVimp(model = dbartModel, data = f_data, response = 'y', numTreesPerm = 5)
permVimp(model = dbartModel, data = f_data, response = 'y', numTreesPerm = 10)
permVimp(model = dbartModel, data = f_data, response = 'y', numTreesPerm = 20)
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/permvimp_all.png?raw=true")
In Figure 21, we compare the difference in importance when numTreesPerm is equal to 5, 10, and 20. We can see that as we increase the number of trees used in the permuted model, variables $x_1$ and $x_2$ start to emerge as the most important variables.
For assessing the interactions using the single permutation method we have:
permVint(trees = trees_data, model = dbartModel, data = f_data, response = 'y', top = 5)
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/permvint_1.png?raw=true")
If any of the variables used to build the BART model are categorical, the aforementioned BART packages replace the categorical variables with $d$ dummy variables, where $d$ is the number of factor levels. However, we provide the functionality to adjust the inclusion proportions for variable importance and interaction by aggregating over factor levels. This provides a complete picture of the importance of a factor, rather than that associated with individual factor levels.
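Aggregating over factor levels amounts to summing the inclusion proportions of the dummy variables that belong to the same factor. All $z_{rk}$ in a posterior sample share the denominator $c_{.k}$, so summing the dummies' proportions gives the proportion of splits that use the factor. Because averaging is linear, the sum can equally be taken after averaging. The sketch below uses a hypothetical named vector of proportions and a hypothetical mapping from dummy names to original variables. How a given BART package names its dummy columns, and whether combineDummy() aggregates in exactly this way, are assumptions here.

```r
# Hypothetical inclusion proportions with Species expanded into three dummies
vimp <- c(Sepal.Width = 0.20, Petal.Length = 0.30, Petal.Width = 0.14,
          Species1 = 0.16, Species2 = 0.12, Species3 = 0.08)

# Hypothetical mapping from model columns to original variables
original <- c(Sepal.Width = "Sepal.Width", Petal.Length = "Petal.Length",
              Petal.Width = "Petal.Width", Species1 = "Species",
              Species2 = "Species", Species3 = "Species")

# Sum the dummy proportions per original variable (the total still sums to 1)
combined <- tapply(vimp, original[names(vimp)], sum)
combined
```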
In the following example, we build a BART model using the BART
package where one of the covariates is a factor and extract the tree data.
library(BART)
data(iris)
set.seed(1701)
bartModel <- wbart(
  x.train = iris[, 2:5],
  y.train = iris[, 1],
  nskip = 10,
  ndpost = 100,
  nkeeptreedraws = 100,
  ntree = 5
)
bt_trees <- extractTreeData(model = bartModel, data = iris)
As Species
is a factor with three levels, it is split into three dummy variables when building the model. For a practical example of what this looks like, we examine the variable importance inclusion proportions both before and after aggregating over the factor levels.
# extract the vimp data
vimpData <- viviBart(trees = bt_trees, out = 'vimp')
vimpData[, 1:3] # looking at the relevant columns
To combine the dummy factor variables, we use the combineDummy() function. This function takes a tree data object created by extractTreeData(), combines the dummy factors, and outputs a new tree data object with the combined dummies.
bt_trees_combined <- combineDummy(trees = bt_trees)
Looking again we see the dummies are combined:
# extract the vimp data
vimpData_combined <- viviBart(trees = bt_trees_combined, out = 'vimp')
vimpData_combined[, 1:3] # looking at the relevant columns
We can also compare the trees before and after combining the dummy variables:
plotTrees(trees = bt_trees, iter = 1)
plotTrees(trees = bt_trees_combined, iter = 1)
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/trees_dummy_1.png?raw=true")
The bartMan
package, initially designed for three primary BART packages, also allows users to input their own data to create an object comparable with the extractTreeData()
function output. This integration is facilitated through the tree_dataframe()
function, which ensures its output aligns with extractTreeData()
. The tree_dataframe()
function requires the original dataset used to build the model and a tree data frame. The tree data frame must include a var
column that follows a depth-first left-side traversal order, a value
column with the split or terminal node values, and columns for iteration
and treeNum
to denote the specific iteration and tree number, respectively.
In the example below, we have both a data set f_data
and a tree data frame df_tree
containing the needed columns.
# Original Data
f_data <- data.frame(
  x1 = c(0.127393428, 0.766723202, 0.054421675, 0.561384595, 0.937597936,
         0.296445079, 0.665117463, 0.652215607, 0.002313313, 0.661490602),
  x2 = c(0.85600486, 0.02407293, 0.51942589, 0.20590965, 0.71206404,
         0.27272126, 0.66765977, 0.94837341, 0.46710461, 0.84157353),
  x3 = c(0.4791849, 0.8265008, 0.1076198, 0.2213454, 0.6717478,
         0.5053170, 0.8849426, 0.3560469, 0.5732139, 0.5091688),
  x4 = c(0.55089910, 0.35612092, 0.80230714, 0.73043828, 0.72341749,
         0.98789408, 0.04751297, 0.06630861, 0.55040341, 0.95719901),
  x5 = c(0.9201376, 0.9279873, 0.5993939, 0.1135139, 0.2472984,
         0.4514940, 0.3097986, 0.2608917, 0.5375610, 0.9608329),
  y = c(14.318655, 12.052513, 13.689970, 13.433919, 18.542184,
        14.927344, 14.843248, 13.611167, 7.777591, 23.895456)
)

# Your own data frame of trees
df_tree <- data.frame(
  var = c("x3", "x1", NA, NA, NA, "x1", NA, NA, "x3", NA, NA, "x1", NA, NA),
  value = c(0.823, 0.771, -0.0433, 0.0188, -0.252, 0.215, -0.269, 0.117,
            0.823, 0.0036, -0.244, 0.215, -0.222, 0.0783),
  iteration = c(1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2),
  treeNum = c(1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2)
)

# take a look at the dataframe of trees
df_tree
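Because tree_dataframe() relies on the var column being in depth-first, left-side order, it can help to check that each (iteration, treeNum) group describes one complete binary tree before calling it. The helpers below, `is_valid_preorder()` and `check_trees()`, are hypothetical and not part of bartMan. The first tracks the number of open child slots, which starts at 1 for the root, rises by one at each split, and falls by one at each leaf; a valid tree reaches zero exactly at its last node.

```r
# Hypothetical helper (not part of bartMan): TRUE if a var sequence in depth-first,
# left-side order forms exactly one complete binary tree (NA = terminal node).
is_valid_preorder <- function(var) {
  open <- 1L                            # child slots still to be filled (the root)
  for (i in seq_along(var)) {
    if (open == 0L) return(FALSE)       # extra nodes after the tree is complete
    open <- open - 1L + (if (is.na(var[i])) 0L else 2L)
  }
  open == 0L
}

# Apply the check to every (iteration, treeNum) group of a tree data frame
check_trees <- function(trees) {
  groups <- split(trees$var, list(trees$iteration, trees$treeNum), drop = TRUE)
  vapply(groups, is_valid_preorder, logical(1))
}

# A small example in the same layout as df_tree; tree 2 is truncated on purpose
toy <- data.frame(
  var       = c("x3", "x1", NA, NA, NA, "x1", NA),
  iteration = c(1, 1, 1, 1, 1, 1, 1),
  treeNum   = c(1, 1, 1, 1, 1, 2, 2),
  stringsAsFactors = FALSE
)
check_trees(toy)
```

Applied to df_tree, check_trees(df_tree) returns TRUE for all four trees.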
Applying the tree_dataframe() function creates an object that aligns with the output of extractTreeData(). The optional response argument is a character string giving the name of the response variable in your BART model. Including the response removes it from the Variable names and nVar list elements.
trees_data <- tree_dataframe(data = f_data, trees = df_tree, response = 'y')
# look at tree data object
trees_data
The newly created object can then be passed to any of the plotting functions, for example:
plotTrees(trees = trees_data, iter = NULL)
knitr::include_graphics("https://github.com/AlanInglis/bartMan/blob/master/bartman_vignettte_new_plots_1/own_trees_iter_null_1.png?raw=true")
tree_dataframe() function.
In addition to the data frame of trees, we also provide the option of separating the trees into a list in which each element is a tidygraph object containing the structure of a single tree. This lets a user quickly access each tree for custom data analysis or visualisation, selecting a specific iteration, a specific tree number, or both. Each tidygraph object includes node and edge information. For example, to create a list of the trees used in the first iteration, we would do the following:
tree_list <- treeList(trees = trees_data, iter = 1, treeNo = NULL)
Examining the first tree yields:
tree_list[[1]]
As we can see, it contains many of the same columns found in the data frame of trees, namely var, node, iteration, treeNum, label, value, depthMax, noObs, respNode, obsNode, and isStump, each with the same meaning as outlined previously. The one exception is respNode, which contains the mean response value of all observations that fall in a given node.
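For readers who want a per-tree split without tidygraph, the base-R sketch below shows the same grouping idea on the df_tree data frame from above. split_trees() is a hypothetical helper that only mimics the iteration and tree-number filtering of treeList(); it is not the package's code and returns plain data frames rather than tidygraph objects.

```r
# The df_tree example from above: one row per node, in depth-first order
df_tree <- data.frame(
  var = c("x3", "x1", NA, NA, NA, "x1", NA, NA, "x3", NA, NA, "x1", NA, NA),
  value = c(0.823, 0.771, -0.0433, 0.0188, -0.252, 0.215, -0.269, 0.117,
            0.823, 0.0036, -0.244, 0.215, -0.222, 0.0783),
  iteration = c(1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2),
  treeNum = c(1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2)
)

# Hypothetical helper, not bartMan's treeList(): optionally filter by iteration
# and/or tree number, then return one data frame per (iteration, treeNum) pair
split_trees <- function(trees, iter = NULL, treeNo = NULL) {
  if (!is.null(iter))   trees <- trees[trees$iteration %in% iter, ]
  if (!is.null(treeNo)) trees <- trees[trees$treeNum %in% treeNo, ]
  split(trees, list(trees$iteration, trees$treeNum), drop = TRUE)
}

first_iter <- split_trees(df_tree, iter = 1)
names(first_iter)   # "1.1" "1.2", i.e. iteration.treeNum
first_iter[[1]]     # the five rows of tree 1 in iteration 1
```

Because split() keeps the original row order within each group, every piece stays in depth-first order, which is what any downstream tree reconstruction relies on.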
Internally, the tree_dataframe() function uses a set of utility functions that are also available for direct use when processing and analysing tree structures:
node_depth: This function calculates the depth of each node within a tree, assuming a binary tree structure. It requires the tree data frame to have a terminal column indicating whether each node is terminal. The function is typically used as follows:
depthList <- lapply(split(trees, ~treeNum + iteration), function(x) cbind(x, depth = node_depth(x) - 1))
trees <- dplyr::bind_rows(depthList, .id = "list_id")
In this example, node_depth is applied to each tree (each treeNum and iteration combination) separately, a depth column is appended to each piece, and the pieces are bound back into a single, enriched tree data frame.
getChildren: This function adds childLeft, childRight, and parent columns to the tree data frame, establishing the parent-child relationships between nodes. This structural information is needed to navigate the hierarchy of the tree.
getChildren(data = trees)
getObservations: This function identifies which observations in a dataset fall into which nodes of a tree, given the tree's structural data (treeData). The treeData must include the columns iteration, treeNum, var, and splitValue, which are used to map the observations to the appropriate nodes of the tree. The original data used to build the model is also required.
getObservations(data = original_dataframe, treeData = trees_data)
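All three utilities above rest on the depth-first, left-first row order. To illustrate why that order is enough, the base-R sketch below recomputes, for a single tree, each node's depth, its parent/childLeft/childRight links, and the terminal node each observation falls into, and then averages the response per terminal node in the spirit of respNode. The helpers preorder_depth(), preorder_children(), subtree_end(), and route_obs() are hypothetical and are not bartMan's implementations. The sketch assumes NA in var marks a terminal node, returns 0-based depths (root = 0), and assumes an observation goes left when its value is less than or equal to the split value, which may not match the convention of every BART package. The three-row data set is made up for illustration.

```r
# Hypothetical helpers, not bartMan's code. Each takes one tree whose rows are in
# depth-first, left-first order, with NA in `var` marking a terminal node.

# Depth of every node (root = 0), using a stack of depths of nodes still to visit
preorder_depth <- function(var) {
  depth <- integer(length(var))
  pending <- 0L                          # the root is the first node to visit
  for (i in seq_along(var)) {
    d <- pending[length(pending)]        # next node's depth is on top of the stack
    pending <- pending[-length(pending)]
    depth[i] <- d
    if (!is.na(var[i])) pending <- c(pending, d + 1L, d + 1L)  # two children
  }
  depth
}

# Parent and left/right child row indices, using a stack of open split nodes
preorder_children <- function(var) {
  n <- length(var)
  parent <- childLeft <- childRight <- rep(NA_integer_, n)
  open <- integer(0)                     # split nodes still missing a child
  for (i in seq_len(n)) {
    if (length(open) > 0) {
      top <- open[length(open)]
      parent[i] <- top
      if (is.na(childLeft[top])) {
        childLeft[top] <- i              # first child seen is the left one
      } else {
        childRight[top] <- i             # second child completes the split
        open <- open[-length(open)]
      }
    }
    if (!is.na(var[i])) open <- c(open, i)
  }
  data.frame(var, parent, childLeft, childRight)
}

# Row index just past the subtree that starts at row j
subtree_end <- function(var, j) {
  need <- 1L
  while (need > 0L) {
    need <- need - 1L
    if (!is.na(var[j])) need <- need + 2L
    j <- j + 1L
  }
  j
}

# Terminal node (row index) reached by each observation.
# ASSUMPTION: values <= the split value go to the left child.
route_obs <- function(data, var, value) {
  vapply(seq_len(nrow(data)), function(r) {
    i <- 1L
    while (!is.na(var[i])) {
      if (data[[var[i]]][r] <= value[i]) {
        i <- i + 1L                      # left child directly follows its parent
      } else {
        i <- subtree_end(var, i + 1L)    # right child follows the left subtree
      }
    }
    i
  }, integer(1))
}

# Tree 1 of iteration 1 from df_tree, with a small made-up data set
tree_var   <- c("x3", "x1", NA, NA, NA)
tree_value <- c(0.823, 0.771, -0.0433, 0.0188, -0.252)
dat <- data.frame(x1 = c(0.1, 0.9, 0.5), x3 = c(0.5, 0.5, 0.9), y = c(10, 12, 14))

preorder_depth(tree_var)                    # 0 1 2 2 1
preorder_children(tree_var)
leaf <- route_obs(dat, tree_var, tree_value) # 3 4 5
tapply(dat$y, leaf, mean)                   # mean response per terminal node
```

Both stack-based helpers make a single pass over the rows, which is possible only because a left child always directly follows its parent and a right child directly follows the end of its sibling's subtree.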
The bartMan package offers a suite of tools for visualising and interpreting the output of Bayesian Additive Regression Trees (BART) models. Its functions help users uncover the underlying structure and interactions in their data and better understand the behaviour of their models.