R/lgb.model.dt.tree.R

Defines functions single.tree.parse lgb.model.dt.tree

Documented in lgb.model.dt.tree

#' @name lgb.model.dt.tree
#' @title Parse a LightGBM model json dump
#' @description Parse a LightGBM model json dump into a \code{data.table} structure.
#' @param model object of class \code{lgb.Booster}
#' @param num_iteration number of iterations to include in the dump. NULL or
#'                      <= 0 means use best iteration
#' @return
#' A \code{data.table} with detailed information about model trees' nodes and leaves.
#' Each row describes either an internal (split) node or a leaf; columns that
#' do not apply to that kind of row are \code{NA}.
#'
#' The columns of the \code{data.table} are:
#'
#' \itemize{
#'  \item{\code{tree_index}: ID of a tree in a model (integer)}
#'  \item{\code{depth}: depth of the node or leaf within its tree; the root has
#'                      depth 0 (integer)}
#'  \item{\code{split_index}: ID of a node in a tree (integer)}
#'  \item{\code{split_feature}: for a node, it's a feature name (character);
#'                              for a leaf, it is \code{NA}}
#'  \item{\code{node_parent}: ID of the parent node for current node (integer)}
#'  \item{\code{leaf_index}: ID of a leaf in a tree (integer)}
#'  \item{\code{leaf_parent}: ID of the parent node for current leaf (integer)}
#'  \item{\code{split_gain}: Split gain of a node}
#'  \item{\code{threshold}: Splitting threshold value of a node}
#'  \item{\code{decision_type}: Decision type of a node}
#'  \item{\code{default_left}: Determine how to handle NA value, TRUE -> Left, FALSE -> Right
#'                             (stored as character)}
#'  \item{\code{internal_value}: Node value}
#'  \item{\code{internal_count}: The number of observations collected by a node}
#'  \item{\code{leaf_value}: Leaf value}
#'  \item{\code{leaf_count}: The number of observations collected by a leaf}
#' }
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#'
#' params <- list(
#'   objective = "binary"
#'   , learning_rate = 0.01
#'   , num_leaves = 63L
#'   , max_depth = -1L
#'   , min_data_in_leaf = 1L
#'   , min_sum_hessian_in_leaf = 1.0
#' )
#' model <- lgb.train(params, dtrain, 10L)
#'
#' tree_dt <- lgb.model.dt.tree(model)
#' }
#' @importFrom data.table := rbindlist
#' @importFrom jsonlite fromJSON
#' @export
lgb.model.dt.tree <- function(model, num_iteration = NULL) {

  # Serialize the booster's trees to JSON text
  json_model <- lgb.dump(booster = model, num_iteration = num_iteration)

  # Keep JSON objects as nested lists (no data.frame / matrix simplification)
  # so that single.tree.parse() can walk each tree's node structure
  parsed_json_model <- jsonlite::fromJSON(
    txt = json_model
    , simplifyVector = TRUE
    , simplifyDataFrame = FALSE
    , simplifyMatrix = FALSE
    , flatten = FALSE
  )

  # Parse each tree into its own data.table
  tree_list <- lapply(parsed_json_model$tree_info, single.tree.parse)

  # Combine into single data.table
  tree_dt <- data.table::rbindlist(l = tree_list, use.names = TRUE)

  # Substitute feature index with the actual feature name

  # Since the index comes from C++ (which is 0-indexed), be sure
  # to add 1 (e.g. index 28 means the 29th feature in feature_names)
  split_feature_indx <- tree_dt[, split_feature] + 1L

  # Get corresponding feature names. Positions in split_feature_indx
  # which are NA (i.e. leaf rows) will result in an NA feature name
  feature_names <- parsed_json_model$feature_names[split_feature_indx]

  # Replace the integer index column with the character names, by reference
  tree_dt[, split_feature := feature_names]

  return(tree_dt)

}


#' @importFrom data.table := data.table rbindlist
single.tree.parse <- function(lgb_tree) {

  # Parse one element of a model dump's "tree_info" (nested lists, as produced
  # by jsonlite::fromJSON with simplifyDataFrame = FALSE) into a data.table
  # with one row per internal node and one row per leaf, in pre-order
  # (node, then its left subtree, then its right subtree).
  #
  # Rows are collected in a list and bound with a single rbindlist() call at
  # the end. Calling rbindlist() once per visited node re-copies the whole
  # table for every node, which is quadratic in the size of the tree.

  # Mutable accumulator shared with the recursive helper below
  collector <- new.env(parent = emptyenv())

  # Zero-row template: fixes the column order and the starting column types.
  # rbindlist() promotes a column's type when parsed values require it
  # (e.g. an integer column that receives doubles becomes double).
  collector$rows <- list(
    data.table::data.table(
      tree_index = integer(0L)
      , depth = integer(0L)
      , split_index = integer(0L)
      , split_feature = integer(0L)
      , node_parent = integer(0L)
      , leaf_index = integer(0L)
      , leaf_parent = integer(0L)
      , split_gain = numeric(0L)
      , threshold = numeric(0L)
      , decision_type = character(0L)
      , default_left = character(0L)
      , internal_value = numeric(0L)
      , internal_count = integer(0L)
      , leaf_value = numeric(0L)
      , leaf_count = integer(0L)
    )
  )

  # Append one row (a named list) to the accumulator
  add_row <- function(row) {
    collector$rows[[length(collector$rows) + 1L]] <- row
    invisible(NULL)
  }

  # Record `tree_node_leaf` and, for an internal node, recurse into its
  # children. Internal nodes are recognized by "split_index", leaves by
  # "leaf_index"; anything else (including NULL) contributes no row.
  pre_order_traversal <- function(tree_node_leaf, current_depth, parent_index) {
    split_index <- tree_node_leaf[["split_index"]]
    if (!is.null(split_index)) {
      add_row(c(
        tree_node_leaf[c(
          "split_index"
          , "split_feature"
          , "split_gain"
          , "threshold"
          , "decision_type"
          , "default_left"
          , "internal_value"
          , "internal_count"
        )]
        , depth = current_depth
        , node_parent = parent_index
      ))
      # Left subtree first, then right, to keep pre-order row order
      for (child in c("left_child", "right_child")) {
        pre_order_traversal(
          tree_node_leaf = tree_node_leaf[[child]]
          , current_depth = current_depth + 1L
          , parent_index = split_index
        )
      }
    } else if (!is.null(tree_node_leaf[["leaf_index"]])) {
      add_row(c(
        tree_node_leaf[c("leaf_index", "leaf_value", "leaf_count")]
        , depth = current_depth
        , leaf_parent = parent_index
      ))
    }
    invisible(NULL)
  }

  root <- lgb_tree[["tree_structure"]]

  # A root that carries a leaf value but neither split_index nor leaf_index
  # matches neither traversal branch, so such a tree used to contribute no
  # rows at all. Record it explicitly as leaf 0 at depth 0.
  # NOTE(review): assumes this shape is how the JSON dump encodes a tree with
  # a single leaf — confirm against the C++ model dump.
  is_bare_leaf <- !is.null(root[["leaf_value"]]) &&
    is.null(root[["split_index"]]) &&
    is.null(root[["leaf_index"]])

  if (is_bare_leaf) {
    add_row(list(
      leaf_index = 0L
      , leaf_value = root[["leaf_value"]]
      , depth = 0L
      , leaf_parent = NA_integer_
    ))
  } else {
    pre_order_traversal(
      tree_node_leaf = root
      , current_depth = 0L
      , parent_index = NA_integer_
    )
  }

  # Bind all rows at once; fill = TRUE leaves non-applicable columns as NA
  single_tree_dt <- data.table::rbindlist(
    l = collector$rows
    , use.names = TRUE
    , fill = TRUE
  )

  # Store index (by reference; a length-1 value is recycled over all rows)
  single_tree_dt[, tree_index := lgb_tree[["tree_index"]]]

  return(single_tree_dt)

}

Try the lightgbm package in your browser

Any scripts or data that you put into this service are public.

lightgbm documentation built on Jan. 17, 2023, 1:13 a.m.