R/utils_lightgbm.R

Defines functions fmt_thresh_right fmt_thresh decision_right decision_left lgb2rules lgbtree2rules preorderlgb get_lgb_tree

# lightgbm_ops
# ::rtemis::
# 2023 EDG rtemis.org

# get_lgb_tree ----

#' Extract Tree Structures from a LightGBM Booster
#'
#' Dumps the booster to JSON with `lightgbm::lgb.dump()`, parses the JSON with
#' `jsonlite::fromJSON()` and keeps the `tree_structure` element of every entry
#' found under `tree_info`.
#'
#' @param x LightGBM Booster; passed to `lightgbm::lgb.dump()` as `booster`.
#' @param n_iter Integer: passed to `lightgbm::lgb.dump()` as `num_iteration`.
#'
#' @return Named list of tree structures; names are "Tree_1", "Tree_2", ...
#' @keywords internal
#' @noRd
get_lgb_tree <- function(x, n_iter = -1) {
  model_json <- lightgbm::lgb.dump(booster = x, num_iteration = n_iter)
  # Keep the parsed JSON as nested lists so every node stays a list
  parsed <- jsonlite::fromJSON(model_json, simplifyVector = FALSE)
  trees <- lapply(
    parsed[["tree_info"]],
    function(info) info[["tree_structure"]]
  )
  names(trees) <- paste0("Tree_", seq_along(trees))
  trees
} # rtemis::get_lgb_tree


# preorderlgb ----

#' Preorder Traversal of LightGBM Tree
#'
#' Called by `lgbtree2rules` and operates on `tree` environment in place.
#' Walks the node, then its left subtree, then its right subtree, extending
#' `rule` with one condition per split. Every time a leaf is reached, the rule
#' accumulated so far is appended to `tree[["leafs"]]`.
#'
#' @param tree Environment that will hold the extracted rules in its `leafs`
#' element.
#' @param node LightGBM tree (nested list); a node without a `split_feature`
#' element is treated as a leaf.
#' @param rule Character: current rule.
#' @param left Character: name of the node element holding the left child.
#' @param right Character: name of the node element holding the right child.
#' @param split_feature Character: name of the node element holding the index
#' of the split feature.
#' @param threshold Character: name of the node element holding the split
#' threshold.
#' @param right_cat_type Character: "in" or "notin": operator for right categorical.
#' @param xnames Character vector: feature names.
#' @param factor_levels Named list of factor levels.
#' @param verbosity Integer: Verbosity level.
#'
#' @return Invisibly, for a split node, a character vector with the current
#' rule followed by the rules returned for the left and right subtrees; for a
#' leaf, the rule named "leaf".
#'
#' @keywords internal
#' @noRd
preorderlgb <- function(
  tree,
  node,
  rule = "TRUE",
  left = "left_child",
  right = "right_child",
  split_feature = "split_feature",
  threshold = "threshold",
  right_cat_type = "in",
  xnames,
  factor_levels,
  verbosity = 0L
) {
  # Leaf: record the accumulated rule and stop descending
  if (is.null(node[[split_feature]])) {
    names(rule) <- "leaf"
    if (verbosity > 0L) {
      message("Reached a leaf; rule is ", rule, ".")
    }
    tree[["leafs"]] <- c(tree[["leafs"]], rule)
    return(rule)
  }

  # The index stored in the node is shifted by one to index into `xnames`
  feature <- xnames[node[[split_feature]] + 1]
  decision_type <- node[["decision_type"]]
  is_categorical <- decision_type == "=="
  # Honor the `threshold` argument instead of a hard-coded element name
  split_value <- node[[threshold]]

  rule_left <- paste0(
    rule,
    " & ",
    feature,
    decision_left(decision_type),
    fmt_thresh(
      catsplit = is_categorical,
      feature = feature,
      threshold = split_value,
      factor_levels = factor_levels
    )
  )
  rule_right <- paste0(
    rule,
    " & ",
    feature,
    decision_right(decision_type, right_cat_type),
    fmt_thresh_right(
      catsplit = is_categorical,
      feature = feature,
      threshold = split_value,
      factor_levels = factor_levels,
      cat_type = right_cat_type
    )
  )

  # Left subtree is visited before the right one, which fixes the order of
  # the rules collected in `tree[["leafs"]]`
  rule_left <- preorderlgb(
    tree = tree,
    node = node[[left]],
    rule = rule_left,
    left = left,
    right = right,
    split_feature = split_feature,
    threshold = threshold,
    right_cat_type = right_cat_type,
    xnames = xnames,
    factor_levels = factor_levels,
    verbosity = verbosity
  )
  rule_right <- preorderlgb(
    tree = tree,
    node = node[[right]],
    rule = rule_right,
    left = left,
    right = right,
    split_feature = split_feature,
    threshold = threshold,
    right_cat_type = right_cat_type,
    xnames = xnames,
    factor_levels = factor_levels,
    verbosity = verbosity
  )
  invisible(c(rule, rule_left, rule_right))
} # rtemis::preorderlgb


# lgbtree2rules ----
#' Convert a single LightGBM tree to rules
#'
#' Creates an environment with an empty `leafs` vector, lets `preorderlgb()`
#' fill it, and returns the collected rules with their first seven characters
#' (the "TRUE & " prefix of the root rule) dropped.
#'
#' @param x LightGBM tree; passed to `preorderlgb()` as `node`.
#' @param xnames Character vector: feature names.
#' @param factor_levels Named list of factor levels.
#' @param right_cat_type Character: "in" or "notin": operator for right categorical.
#'
#' @return Character vector of rules, one per leaf.
#' @keywords internal
#' @noRd
lgbtree2rules <- function(x, xnames, factor_levels, right_cat_type = "in") {
  collector <- new.env()
  collector[["leafs"]] <- character()
  # preorderlgb() appends to collector[["leafs"]] in place
  preorderlgb(
    collector,
    x,
    factor_levels = factor_levels,
    right_cat_type = right_cat_type,
    xnames = xnames
  )
  leaf_rules <- collector[["leafs"]]
  # remove root node "TRUE & "
  substr(leaf_rules, start = 8, stop = 99999)
} # rtemis::lgbtree2rules


# lgb2rules ----
#' Convert LightGBM Booster to set of rules
#'
#' Extracts the trees of the booster with `get_lgb_tree()`, converts each one
#' with `lgbtree2rules()` and concatenates the results.
#'
#' @param Booster LightGBM Booster object
#' @param n_iter Integer: Number of iterations whose trees are converted to
#' rules. If NULL, `-1` is passed on to `get_lgb_tree()` (its own default).
#' @param xnames Character vector: Names of features
#' @param factor_levels Named list of factor levels.
#' @param right_cat_type Character: "in" or "notin": operator for right categorical.
#' @param return_unique Logical: If TRUE, duplicated rules are dropped.
#' @param verbosity Integer: Verbosity level.
#'
#' @return Character vector of rules
#' @keywords internal
#' @noRd
lgb2rules <- function(
  Booster,
  n_iter = NULL,
  xnames,
  factor_levels,
  right_cat_type = "in",
  return_unique = TRUE,
  verbosity = 1L
) {
  if (verbosity > 0L) {
    msg2start("Extracting LightGBM rules...")
  }
  if (is.null(n_iter)) {
    # NOTE(review): this used to be `length(Booster)`. A lgb.Booster is
    # presumably an R6 object (an environment), whose length is the number of
    # its members, not the number of boosting iterations. A non-positive
    # value defers the choice to lightgbm::lgb.dump().
    n_iter <- -1L
  }
  trees <- get_lgb_tree(Booster, n_iter)
  rules <- lapply(trees, function(tree) {
    lgbtree2rules(
      tree,
      xnames,
      factor_levels = factor_levels,
      right_cat_type = right_cat_type
    )
  }) |>
    unlist()
  if (verbosity > 0L) {
    msg2done()
  }
  if (return_unique) unique(rules) else rules
} # rtemis::lgb2rules

#' Extract rules from a LightGBM Booster
#'
#' Method of `extract_rules` for `class_lgb.Booster`. Extracts the trees of the
#' booster with `get_lgb_tree()`, converts each one with `lgbtree2rules()`,
#' concatenates the results and optionally reports how many rules were found.
#'
#' @param x LightGBM Booster object
#' @param n_iter Integer: Number of iterations whose trees are converted to
#' rules. If NULL, `-1` is passed on to `get_lgb_tree()` (its own default).
#' @param xnames Character vector: Names of features
#' @param factor_levels Named list of factor levels.
#' @param right_cat_type Character: "in" or "notin": operator for right categorical.
#' @param return_unique Logical: If TRUE, duplicated rules are dropped.
#' @param verbosity Integer: Verbosity level.
#'
#' @return Character vector of rules
#' @keywords internal
#' @noRd
method(extract_rules, class_lgb.Booster) <- function(
  x,
  n_iter = NULL,
  xnames,
  factor_levels,
  right_cat_type = "in",
  return_unique = TRUE,
  verbosity = 1L
) {
  if (verbosity > 0L) {
    msg2start("Extracting LightGBM rules...")
  }
  if (is.null(n_iter)) {
    # NOTE(review): this used to be `length(x)`. A lgb.Booster is presumably
    # an R6 object (an environment), whose length is the number of its
    # members, not the number of boosting iterations. A non-positive value
    # defers the choice to lightgbm::lgb.dump().
    n_iter <- -1L
  }
  trees <- get_lgb_tree(x, n_iter)
  rules <- lapply(trees, function(tree) {
    lgbtree2rules(
      tree,
      xnames,
      factor_levels = factor_levels,
      right_cat_type = right_cat_type
    )
  }) |>
    unlist()

  if (return_unique) {
    rules <- unique(rules)
  }
  if (verbosity > 0L) {
    msg2done()
    msg20(
      "Extracted ",
      hilite(length(rules)),
      if (return_unique) " unique" else "",
      " rules."
    )
  }
  rules
} # /rtemis::extract_rules.lgb.Booster


# decision_left ----
#' Operator for the left branch of a LightGBM split
#'
#' @param decision_type Character: decision type of the split node, either
#' "<=" or "==".
#'
#' @return Character: " <= " for "<=" and " %in% " for "==", padded with
#' spaces so it can be pasted between feature name and threshold. Any other
#' decision type is an error.
#' @keywords internal
#' @noRd
decision_left <- function(decision_type) {
  switch(
    decision_type,
    "<=" = " <= ",
    "==" = " %in% ",
    # Without a default, switch() returns NULL and the caller would paste a
    # rule that silently lacks its operator
    stop(
      "Unsupported LightGBM decision type: ",
      decision_type,
      call. = FALSE
    )
  )
} # rtemis::decision_left

#' Operator for the right branch of a LightGBM split
#'
#' @param decision_type Character: decision type of the split node, either
#' "<=" or "==".
#' @param cat_type Character: "in" selects " %in% " for "==" splits; any other
#' value selects " %notin% ".
#'
#' @return Character: " > " for "<=" and " %in% " or " %notin% " for "==",
#' padded with spaces so it can be pasted between feature name and threshold.
#' Any other decision type is an error.
#' @keywords internal
#' @noRd
decision_right <- function(decision_type, cat_type) {
  switch(
    decision_type,
    "<=" = " > ",
    "==" = if (cat_type == "in") " %in% " else " %notin% ",
    # Without a default, switch() returns NULL and the caller would paste a
    # rule that silently lacks its operator
    stop(
      "Unsupported LightGBM decision type: ",
      decision_type,
      call. = FALSE
    )
  )
} # rtemis::decision_right


#' Format the threshold of the left branch of a split
#'
#' For a categorical split, `threshold` is split on "||", the resulting
#' 0-based indices are mapped to the levels of `feature`, and the levels are
#' returned as the text of a `c('a','b')` call. Otherwise `threshold` is
#' returned unchanged.
#'
#' @param catsplit Logical: If TRUE, feature is categorical
#' @param feature Character: feature name
#' @param threshold Threshold as reported by lightgbm; for categorical splits,
#' a string of 0-based level indices separated by "||".
#' @param factor_levels Named list of factor levels. Names should correspond to training
#' set column names.
#'
#' @return For categorical splits, a character string; otherwise `threshold`.
#' @keywords internal
#' @noRd
fmt_thresh <- function(catsplit, feature, threshold, factor_levels) {
  if (!catsplit) {
    return(threshold)
  }
  # 0- to 1-based factor level index
  level_idx <- as.integer(strsplit(threshold, "||", fixed = TRUE)[[1]]) + 1
  quoted <- paste0(
    "'",
    factor_levels[[feature]][level_idx],
    "'",
    collapse = ","
  )
  paste0("c(", quoted, ")")
} # rtemis::fmt_thresh

#' Format the threshold of the right branch of a split
#'
#' For a categorical split, `threshold` is split on "||" and the resulting
#' 0-based indices are mapped to the levels of `feature`. With
#' `cat_type = "in"` the remaining levels of the feature are used instead.
#' The levels are returned as the text of a `c('a','b')` call. For
#' non-categorical splits `threshold` is returned unchanged.
#'
#' @param catsplit Logical: If TRUE, feature is categorical
#' @param feature Character: feature name
#' @param threshold Threshold as reported by lightgbm; for categorical splits,
#' a string of 0-based level indices separated by "||".
#' @param factor_levels Named list of factor levels. Names should correspond to training
#' set column names.
#' @param cat_type Character: if "in", the complement of the listed levels is
#' returned; for any other value, the listed levels themselves.
#'
#' @return For categorical splits, a character string; otherwise `threshold`.
#' @keywords internal
#' @noRd
fmt_thresh_right <- function(
  catsplit,
  feature,
  threshold,
  factor_levels,
  cat_type
) {
  if (!catsplit) {
    return(threshold)
  }
  all_levels <- factor_levels[[feature]]
  # 0- to 1-based factor level index
  level_idx <- as.integer(strsplit(threshold, "||", fixed = TRUE)[[1]]) + 1
  chosen <- all_levels[level_idx]
  if (cat_type == "in") {
    # Right branch expressed with %in%: list every level not sent left
    chosen <- setdiff(all_levels, chosen)
  }
  quoted <- paste0("'", chosen, "'", collapse = ",")
  paste0("c(", quoted, ")")
} # rtemis::fmt_thresh_right
egenn/rtemis documentation built on June 14, 2025, 11:54 p.m.