R/xgb.model.dt.tree.R

Defines functions xgb.model.dt.tree

Documented in xgb.model.dt.tree

#' Parse a boosted tree model text dump
#'
#' Parse a boosted tree model text dump into a \code{data.table} structure.
#'
#' @param feature_names character vector of feature names. If the model already
#'          contains feature names, those would be used when \code{feature_names=NULL} (default value).
#'          Non-null \code{feature_names} could be provided to override those in the model.
#' @param model object of class \code{xgb.Booster}
#' @param text \code{character} vector previously generated by the \code{xgb.dump}
#'          function  (where parameter \code{with_stats = TRUE} should have been set).
#'          \code{text} takes precedence over \code{model}.
#' @param trees an integer vector of tree indices that should be parsed.
#'          If set to \code{NULL}, all trees of the model are parsed.
#'          It could be useful, e.g., in multiclass classification to get only
#'          the trees of one certain class. IMPORTANT: the tree index in xgboost models
#'          is zero-based (e.g., use \code{trees = 0:4} for first 5 trees).
#' @param use_int_id a logical flag indicating whether nodes in columns "Yes", "No", "Missing" should be
#'          represented as integers (when TRUE) or as "Tree-Node" character strings (when FALSE).
#' @param ... currently not used.
#'
#' @return
#' A \code{data.table} with detailed information about model trees' nodes.
#'
#' The columns of the \code{data.table} are:
#'
#' \itemize{
#'  \item \code{Tree}: integer ID of a tree in a model (zero-based index)
#'  \item \code{Node}: integer ID of a node in a tree (zero-based index)
#'  \item \code{ID}: character identifier of a node in a model (only when \code{use_int_id=FALSE})
#'  \item \code{Feature}: for a branch node, it's a feature id or name (when available);
#'              for a leaf node, it simply labels it as \code{'Leaf'}
#'  \item \code{Split}: location of the split for a branch node (split condition is always "less than")
#'  \item \code{Yes}: ID of the next node when the split condition is met
#'  \item \code{No}: ID of the next node when the split condition is not met
#'  \item \code{Missing}: ID of the next node when branch value is missing
#'  \item \code{Quality}: either the split gain (change in loss) or the leaf value
#'  \item \code{Cover}: metric related to the number of observations either seen by a split
#'                      or collected by a leaf during training.
#' }
#'
#' When \code{use_int_id=FALSE}, columns "Yes", "No", and "Missing" point to model-wide node identifiers
#' in the "ID" column. When \code{use_int_id=TRUE}, those columns point to node identifiers from
#' the corresponding trees in the "Node" column.
#'
#' @examples
#' # Basic use:
#'
#' data(agaricus.train, package='xgboost')
#'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 2,
#'                eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
#'
#' (dt <- xgb.model.dt.tree(colnames(agaricus.train$data), bst))
#'
#' # This bst model already has feature_names stored with it, so those would be used when
#' # feature_names is not set:
#' (dt <- xgb.model.dt.tree(model = bst))
#'
#' # How to match feature names of splits that are following a current 'Yes' branch:
#'
#' merge(dt, dt[, .(ID, Y.Feature=Feature)], by.x='Yes', by.y='ID', all.x=TRUE)[order(Tree,Node)]
#'
#' @export
xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
                              trees = NULL, use_int_id = FALSE, ...){
  check.deprecation(...)

  # At least one usable input is required: a booster object or a text dump.
  if (!inherits(model, "xgb.Booster") && !is.character(text)) {
    stop("Either 'model' must be an object of class xgb.Booster\n",
         "  or 'text' must be a character vector with the result of xgb.dump\n",
         "  (or NULL if 'model' was provided).")
  }

  # Fall back to the feature names stored in the model when none were given.
  if (is.null(feature_names) && !is.null(model) && !is.null(model$feature_names))
    feature_names <- model$feature_names

  if (!(is.null(feature_names) || is.character(feature_names))) {
    stop("feature_names: must be a character vector")
  }

  if (!(is.null(trees) || is.numeric(trees))) {
    stop("trees: must be a vector of integers.")
  }

  # 'text' takes precedence; the model is only dumped when no text was supplied.
  if (is.null(text)){
    text <- xgb.dump(model = model, with_stats = TRUE)
  }

  # A tree dump must contain at least one "leaf=<number>" entry. The sign is
  # optional: leaf values are frequently negative, and a digits-only pattern
  # would wrongly reject a model whose leaves are all negative.
  if (length(text) < 2 ||
      !any(grepl("leaf=[-+]?\\d", text))) {
    stop("Non-tree model detected! This function can only be used with tree models.")
  }

  # lines that open a new tree in the dump
  position <- which(grepl("booster", text, fixed = TRUE))

  # node reference: plain node number, or "Tree-Node" string when use_int_id is FALSE
  add.tree.id <- function(node, tree) if (use_int_id) node else paste(tree, node, sep = "-")

  # NOTE: the parenthesised exponent part adds one extra capture group wherever
  # this pattern is embedded, which the column selections below account for.
  anynumber_regex <- "[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?"

  # One row per dump line; the zero-based tree id is obtained by counting the
  # tree header lines seen so far.
  td <- data.table(t = text)
  td[position, Tree := 1L]
  td[, Tree := cumsum(ifelse(is.na(Tree), 0L, Tree)) - 1L]

  if (is.null(trees)) {
    trees <- 0:max(td$Tree)
  } else {
    # silently drop requested indices that do not exist in the dump
    trees <- trees[trees >= 0 & trees <= max(td$Tree)]
  }
  # keep only node lines of the requested trees (tree header lines are removed)
  td <- td[Tree %in% trees & !grepl('^booster', t)]

  td[, Node := as.integer(sub("^([0-9]+):.*", "\\1", t))]
  if (!use_int_id) td[, ID := add.tree.id(Node, Tree)]
  td[, isLeaf := grepl("leaf", t, fixed = TRUE)]

  # parse branch lines
  branch_rx <- paste0("f(\\d+)<(", anynumber_regex, ")\\] yes=(\\d+),no=(\\d+),missing=(\\d+),",
                      "gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
  branch_cols <- c("Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover")
  td[
    isLeaf == FALSE,
    (branch_cols) := {
      matches <- regmatches(t, regexec(branch_rx, t))
      # skip some indices with spurious capture groups from anynumber_regex
      xtr <- do.call(rbind, matches)[, c(2, 3, 5, 6, 7, 8, 10), drop = FALSE]
      xtr[, 3:5] <- add.tree.id(xtr[, 3:5], Tree)
      if (length(xtr) == 0) {
        # no branch lines at all: return placeholders so that the columns exist
        as.data.table(
          list(Feature = "NA", Split = "NA", Yes = "NA", No = "NA", Missing = "NA", Quality = "NA", Cover = "NA")
        )
      } else {
        as.data.table(xtr)
      }
    }
  ]

  # assign feature_names when available
  # When no branch line was parsed (every selected tree is a single leaf, or
  # nothing was selected) there is no feature index to translate; skipping
  # also avoids calling max() on an all-NA vector, which would warn.
  has_no_splits <- all(is.na(td$Feature))
  if (!is.null(feature_names) && !has_no_splits) {
    if (length(feature_names) <= max(as.numeric(td$Feature), na.rm = TRUE))
      stop("feature_names has less elements than there are features used in the model")
    # feature ids in the dump are zero-based, R indexing is one-based
    td[isLeaf == FALSE, Feature := feature_names[as.numeric(Feature) + 1]]
  }

  # parse leaf lines
  leaf_rx <- paste0("leaf=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
  leaf_cols <- c("Feature", "Quality", "Cover")
  td[
    isLeaf == TRUE,
    (leaf_cols) := {
      matches <- regmatches(t, regexec(leaf_rx, t))
      # columns 2 and 4 hold the leaf value and the cover (3 is the exponent group)
      xtr <- do.call(rbind, matches)[, c(2, 4)]
      if (length(xtr) == 2) {
        # a single leaf line: the matrix subset collapsed to a plain vector
        c("Leaf", as.data.table(xtr[1]), as.data.table(xtr[2]))
      } else {
        c("Leaf", as.data.table(xtr))
      }
    }
  ]

  # convert some columns to numeric
  numeric_cols <- c("Split", "Quality", "Cover")
  td[, (numeric_cols) := lapply(.SD, as.numeric), .SDcols = numeric_cols]
  if (use_int_id) {
    int_cols <- c("Yes", "No", "Missing")
    td[, (int_cols) := lapply(.SD, as.integer), .SDcols = int_cols]
  }

  # drop helper columns that are not part of the documented output
  td[, t := NULL]
  td[, isLeaf := NULL]

  td[order(Tree, Node)]
}

# Declare the symbols below as known globals so that the CRAN check does not
# complain about them: they are never assigned as ordinary variables, being
# mainly column names (and special symbols) resolved by data.table.
globalVariables(
  c(
    "Tree", "Node", "ID", "Feature", "t", "isLeaf",
    ".SD", ".SDcols"
  )
)

Try the xgboost package in your browser

Any scripts or data that you put into this service are public.

xgboost documentation built on March 31, 2023, 10:05 p.m.