R/xgb.plot.deepness.R

Defines functions xgb.plot.deepness

Documented in xgb.plot.deepness

#' Plot model trees deepness
#'
#' Visualizes distributions related to the depth of tree leaves.
#' \code{xgb.plot.deepness} uses base R graphics, while \code{xgb.ggplot.deepness} uses the ggplot backend.
#'
#' @param model either an \code{xgb.Booster} model generated by the \code{xgb.train} function
#'        or a data.table result of the \code{xgb.model.dt.tree} function.
#' @param plot (base R barplot) whether a barplot should be produced.
#'        If FALSE, only a data.table is returned.
#' @param which which distribution to plot (see details).
#' @param ... other parameters passed to \code{barplot} or \code{plot}.
#'
#' @details
#'
#' When \code{which="2x1"}, two distributions with respect to the leaf depth
#' are plotted on top of each other:
#' \itemize{
#'  \item the distribution of the number of leaves in a tree model at a certain depth;
#'  \item the distribution of the average weighted number of observations ("cover")
#'        ending up in leaves at a certain depth.
#' }
#' These can be helpful in determining sensible ranges of the \code{max_depth}
#' and \code{min_child_weight} parameters.
#'
#' When \code{which="max.depth"} or \code{which="med.depth"}, plots of either the maximum or the median depth
#' per tree with respect to tree number are created. \code{which="med.weight"} shows how
#' a tree's median absolute leaf weight changes through the iterations.
#'
#' Leaf depth is counted in nodes along the path from the root to the leaf (both included),
#' so a leaf whose parent is the root has depth 2.
#'
#' This function was inspired by the blog post
#' \url{https://github.com/aysent/random-forest-leaf-visualization}.
#'
#' @return
#'
#' Other than producing plots (when \code{plot=TRUE}), the \code{xgb.plot.deepness} function
#' silently returns a processed data.table where each row corresponds to a terminal leaf in a tree model,
#' and contains information about the leaf's depth, cover, and weight (which is used in calculating predictions).
#' The table has the columns \code{Tree}, \code{ID}, \code{Depth}, \code{Cover} and \code{Weight},
#' and is keyed by \code{Tree} and \code{ID}.
#'
#' The \code{xgb.ggplot.deepness} silently returns either a list of two ggplot graphs when \code{which="2x1"}
#' or a single ggplot graph for the other \code{which} options.
#'
#' @seealso
#'
#' \code{\link{xgb.train}}, \code{\link{xgb.model.dt.tree}}.
#'
#' @examples
#'
#' data(agaricus.train, package='xgboost')
#'
#' # Change max_depth to a higher number to get a more significant result
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 6,
#'                eta = 0.1, nthread = 2, nrounds = 50, objective = "binary:logistic",
#'                subsample = 0.5, min_child_weight = 2)
#'
#' xgb.plot.deepness(bst)
#' xgb.ggplot.deepness(bst)
#'
#' xgb.plot.deepness(bst, which='max.depth', pch=16, col=rgb(0,0,1,0.3), cex=2)
#'
#' xgb.plot.deepness(bst, which='med.weight', pch=16, col=rgb(0,0,1,0.3), cex=2)
#'
#' @rdname xgb.plot.deepness
#' @export
xgb.plot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med.depth", "med.weight"),
                              plot = TRUE, ...) {

  if (!(inherits(model, "xgb.Booster") || is.data.table(model)))
    stop("model: Has to be either an xgb.Booster model generated by the xgb.train function\n",
         "or a data.table result of the xgb.model.dt.tree function", call. = FALSE)

  # igraph is used by get.leaf.depth(), so it is needed even when plot = FALSE
  if (!requireNamespace("igraph", quietly = TRUE))
    stop("igraph package is required for plotting the graph deepness.", call. = FALSE)

  which <- match.arg(which)

  dt_tree <- model
  if (inherits(model, "xgb.Booster"))
    dt_tree <- xgb.model.dt.tree(model = model)

  # "Quality" is read below as the leaf weight, so it is validated together with
  # the other columns instead of failing later inside a data.table expression.
  if (!all(c("Feature", "Tree", "ID", "Yes", "No", "Cover", "Quality") %in% colnames(dt_tree)))
    stop("Model tree columns are not as expected!\n",
         "  Note that this function works only for tree models.", call. = FALSE)

  # one row per leaf: its tree, depth, cover and weight
  dt_depths <- merge(get.leaf.depth(dt_tree), dt_tree[, .(ID, Cover, Weight = Quality)], by = "ID")
  setkeyv(dt_depths, c("Tree", "ID"))
  # count by depth levels, and also calculate average cover at a depth
  dt_summaries <- dt_depths[, .(.N, Cover = mean(Cover)), Depth]
  setkey(dt_summaries, "Depth")

  if (plot) {
    if (which == "2x1") {
      op <- par(no.readonly = TRUE)
      # restore the caller's graphics settings even if one of the barplots fails
      on.exit(par(op), add = TRUE)
      par(mfrow = c(2, 1),
          oma = c(3, 1, 3, 1) + 0.1,
          mar = c(1, 4, 1, 0) + 0.1)

      dt_summaries[, barplot(N, border = NA, ylab = 'Number of leafs', ...)]

      dt_summaries[, barplot(Cover, border = NA, ylab = "Weighted cover", names.arg = Depth, ...)]

      title("Model complexity", xlab = "Leaf depth", outer = TRUE, line = 1)
    } else if (which == "max.depth") {
      # small jitter so that trees with identical depth remain distinguishable
      dt_depths[, max(Depth), Tree][
                , plot(jitter(V1, amount = 0.1) ~ Tree, ylab = 'Max tree leaf depth', xlab = "tree #", ...)]
    } else if (which == "med.depth") {
      dt_depths[, median(as.numeric(Depth)), Tree][
                , plot(jitter(V1, amount = 0.1) ~ Tree, ylab = 'Median tree leaf depth', xlab = "tree #", ...)]
    } else if (which == "med.weight") {
      dt_depths[, median(abs(Weight)), Tree][
                , plot(V1 ~ Tree, ylab = 'Median absolute leaf weight', xlab = "tree #", ...)]
    }
  }
  invisible(dt_depths)
}

# Extract path depths from root to leaf
# from data.table containing the nodes and edges of the trees.
# internal utility function
#
# dt_tree: a data.table with (at least) the columns Tree, ID, Feature, Yes, No,
#          as produced by xgb.model.dt.tree(); leaf rows have Feature == "Leaf".
# Returns a data.table with one row per leaf and the columns Tree, Depth, ID,
# where Depth is the number of nodes on the path from the root to that leaf
# (root and leaf included). A tree made of a single root leaf gets Depth 1.
get.leaf.depth <- function(dt_tree) {
  # extract tree graph's edges (parent ID -> child To), one per Yes/No branch
  dt_edges <- rbindlist(list(
      dt_tree[Feature != "Leaf", .(ID, To = Yes, Tree)],
      dt_tree[Feature != "Leaf", .(ID, To = No, Tree)]
    ))
  # whether "To" is a leaf:
  dt_edges <-
    merge(dt_edges,
          dt_tree[Feature == "Leaf", .(ID, Leaf = TRUE)],
          all.x = TRUE, by.x = "To", by.y = "ID")
  dt_edges[is.na(Leaf), Leaf := FALSE]

  dt_depths <- dt_edges[, {
    graph <- igraph::graph_from_data_frame(.SD[, .(ID, To)])
    # the root is the only internal node that is never a child; found
    # structurally so it does not depend on how ID strings sort
    root <- setdiff(ID, To)[1L]
    leaves <- To[Leaf]
    paths_tmp <- igraph::shortest_paths(graph, from = root, to = leaves)
    # vertex count of each root-to-leaf path is the leaf depth
    data.table(Depth = lengths(paths_tmp$vpath), ID = leaves)
  }, by = Tree]

  # Trees consisting of a single leaf have no edges and would otherwise be
  # dropped; their only node is both root and leaf, i.e. depth 1.
  dt_single_leaf <- dt_tree[Feature == "Leaf" & !(Tree %in% dt_edges$Tree),
                            .(Tree, Depth = rep(1L, .N), ID)]

  rbindlist(list(dt_depths, dt_single_leaf), use.names = TRUE, fill = TRUE)
}

# Avoid "no visible binding for global variable" notes during R CMD check (CRAN).
# These names are data.table column names and data.table special symbols
# (".", ".N", ".SD", and the auto-named "V1" result column) that are used via
# non-standard evaluation in this file, so static code checks cannot see them.
globalVariables(
  c(
    ".", ".N", ".SD", "N", "Depth", "Quality", "Cover", "Tree", "ID", "Yes", "No",
    "Feature", "Leaf", "Weight", "To", "V1"
  )
)

Try the xgboost package in your browser

Any scripts or data that you put into this service are public.

xgboost documentation built on March 31, 2023, 10:05 p.m.