R/treeList.R

Defines functions treeList

Documented in treeList

#' Generate a List of Tree Structures from BART Model Output
#'
#' This function takes the tree data extracted from a BART model and converts
#' it into a list of tree structures, one \code{tidygraph} object per tree.
#' The output can optionally be restricted to a single iteration, a single
#' tree number, or both.
#'
#' @param trees A list describing the trees, such as the output of
#'   \code{extractTreeData()}. The following elements are used:
#'   \describe{
#'     \item{\code{structure}}{A data frame with one row per node. The columns
#'       \code{iteration}, \code{treeNum}, \code{parent} and \code{node} are
#'       required. \code{obsNode}, when present, is a list-column whose
#'       per-row mean is stored as \code{respNode} (the mean response value
#'       of the node). The columns \code{var}, \code{label}, \code{value},
#'       \code{depthMax}, \code{noObs} and \code{isStump} are carried through
#'       when present.}
#'     \item{\code{nMCMC}}{The largest allowed value of \code{iter}.}
#'     \item{\code{nTree}}{The largest allowed value of \code{treeNo}.}
#'   }
#' @param iter A single whole number between 1 and \code{trees$nMCMC} giving
#'   the iteration whose trees are returned. If \code{NULL} (the default),
#'   trees from all iterations are included.
#' @param treeNo A single whole number between 1 and \code{trees$nTree} giving
#'   the tree number to return. If \code{NULL} (the default), all trees are
#'   included.
#'
#' @return A list of directed \code{tidygraph::tbl_graph} objects, one for each
#'   combination of \code{iteration} and \code{treeNum} remaining after
#'   filtering. Each graph's node table holds the retained columns listed
#'   above except \code{parent}. Its edge table links each \code{parent} to
#'   its \code{node}; rows whose \code{parent} is \code{NA} contribute no edge.
#'
#' @importFrom dplyr filter select one_of any_of group_by group_split
#' @importFrom purrr map map2
#' @importFrom tidygraph tbl_graph
#'
#' @examples
#' if (requireNamespace("dbarts", quietly = TRUE)) {
#'   # Load the dbarts package to access the bart function
#'   library(dbarts)
#'   # Get Data
#'   df <- na.omit(airquality)
#'   # Create Simple dbarts Model For Regression:
#'   set.seed(1701)
#'   dbartModel <- bart(df[2:6], df[, 1], ntree = 5, keeptrees = TRUE,
#'                      nskip = 10, ndpost = 10)
#'
#'   # Tree Data
#'   trees_data <- extractTreeData(model = dbartModel, data = df)
#'   trees_list <- treeList(trees_data)
#'   # Only tree number 1 from every iteration
#'   first_trees <- treeList(trees_data, treeNo = 1)
#' }
#'
#' @export
treeList <- function(trees, iter = NULL, treeNo = NULL) {

  # Validate 'iter'. The type/length check runs first so that NA, vectors or
  # non-whole numbers fail with a clear message instead of erroring inside
  # `||` or silently matching no rows.
  if (!is.null(iter)) {
    if (!is.numeric(iter) || length(iter) != 1L || is.na(iter) ||
        iter != round(iter)) {
      stop("'iter' must be a single whole number.", call. = FALSE)
    }
    if (iter < 1 || iter > trees$nMCMC) {
      stop("Error: 'iter' value is out of allowed range. Max iteration is ",
           trees$nMCMC, ".", call. = FALSE)
    }
  }

  # Validate 'treeNo' the same way, against the number of trees.
  if (!is.null(treeNo)) {
    if (!is.numeric(treeNo) || length(treeNo) != 1L || is.na(treeNo) ||
        treeNo != round(treeNo)) {
      stop("'treeNo' must be a single whole number.", call. = FALSE)
    }
    if (treeNo < 1 || treeNo > trees$nTree) {
      stop("Error: 'treeNo' value is out of allowed range. Max tree number is ",
           trees$nTree, ".", call. = FALSE)
    }
  }

  # Apply each requested filter independently. `.env$` forces the function
  # arguments to be used even if the data happens to have a same-named column.
  df <- trees$structure
  if (!is.null(iter)) {
    df <- dplyr::filter(df, iteration == .env$iter)
  }
  if (!is.null(treeNo)) {
    df <- dplyr::filter(df, treeNum == .env$treeNo)
  }

  if (is.null(iter) && is.null(treeNo)) {
    message("Displaying All Trees.")
  } else if (is.null(iter)) {
    message("Tree Number ", treeNo, " Selected.")
  } else if (is.null(treeNo)) {
    message("Iteration ", iter, " Selected.")
  } else {
    message("Iteration ", iter, " and Tree Number ", treeNo, " Selected.")
  }

  # Mean response value per node. vapply() guarantees exactly one number per
  # row and still yields a (zero-length) column when no rows remain.
  if ("obsNode" %in% names(df)) {
    df$respNode <- vapply(df[["obsNode"]], mean, FUN.VALUE = numeric(1),
                          USE.NAMES = FALSE)
  }

  # Columns to keep (in this order) when present.
  keeps <- c("var", "node", "parent", "iteration", "treeNum", "label",
             "value", "depthMax", "noObs", "respNode", "obsNode", "isStump")

  # One group per tree. Both splits below use the same grouping, so
  # nodeList[[i]] and edgeList[[i]] describe the same tree.
  res <- df |>
    dplyr::select(dplyr::any_of(keeps)) |>
    dplyr::group_by(iteration, treeNum)

  nodeList <- dplyr::group_split(dplyr::select(res, -parent), .keep = TRUE)

  # Edge tables hold only (parent, node); rows without a parent have no
  # incoming edge and are dropped.
  edgeList <- res |>
    dplyr::select(iteration, treeNum, parent, node) |>
    dplyr::group_split(.keep = FALSE) |>
    purrr::map(\(edges) dplyr::filter(edges, !is.na(parent)))

  # Turn each node/edge pair into a directed tidygraph object.
  purrr::map2(
    nodeList,
    edgeList,
    \(nodes, edges) tidygraph::tbl_graph(
      nodes = nodes,
      edges = edges,
      directed = TRUE
    )
  )
}
AlanInglis/BartVis documentation built on July 27, 2024, 12:02 a.m.