R/lightgbm_ops.R

Defines functions fmt_thresh_right fmt_thresh decision_right decision_left rtlgb2rules lgb2rules lgbtree2rules preorderlgb get_lgb_tree

Documented in preorderlgb

# lightgbm_ops
# ::rtemis::
# 2023 EDG lambdamd.org

# get_lgb_tree ----

#' Get LightGBM Booster Trees
#'
#' @returns A list of trees
#' @keywords internal
#' @noRd
get_lgb_tree <- function(x,
                         n_iter = -1) {
  out <- lapply(
    jsonlite::fromJSON(
      lightgbm::lgb.dump(
        booster = x,
        num_iteration = n_iter
      ),
      simplifyVector = FALSE
    )$tree_info,
    \(y) y$tree_structure
  )
  names(out) <- paste0("Tree_", seq_along(out))
  out
} # rtemis::get_lgb_tree


# preorderlgb ----

#' Preorder Traversal of LightGBM Tree
#'
#' Called by `lgbtree2rules` and operates on `tree` environment in place.
#'
#' @param tree Environment that will hold the extracted rules
#' @param node LightGBM tree
#' @param rule Character: current rule
#' @param left Character: left child label
#' @param right Character: right child label
#' @param split_feature Character: split feature label
#' @param threshold Character: threshold label
#' @param right_cat_type Character: "in" or "notin": operator for right categorical
#'
#' @keywords internal
preorderlgb <- function(tree,
                        node,
                        rule = "TRUE",
                        left = "left_child",
                        right = "right_child",
                        split_feature = "split_feature",
                        threshold = "threshold",
                        right_cat_type = "in",
                        xnames,
                        factor_levels,
                        trace = 0) {
  if (is.null(node[[split_feature]])) {
    names(rule) <- "leaf"
    if (trace > 0) {
      cat("Reached a leaf; rule is ", rule, "\n")
    }
    tree$leafs <- c(tree$leafs, rule)
    return(rule)
  }
  rule_left <- paste0(
    rule,
    " & ",
    xnames[node[[split_feature]] + 1],
    decision_left(node[["decision_type"]]),
    fmt_thresh(
      catsplit = node[["decision_type"]] == "==",
      feature = xnames[node[[split_feature]] + 1],
      threshold = node[["threshold"]],
      factor_levels = factor_levels
    )
  )
  rule_right <- paste0(
    rule,
    " & ",
    xnames[node[[split_feature]] + 1],
    decision_right(node[["decision_type"]], right_cat_type),
    fmt_thresh_right(
      catsplit = node[["decision_type"]] == "==",
      feature = xnames[node[[split_feature]] + 1],
      threshold = node[["threshold"]],
      factor_levels = factor_levels,
      cat_type = right_cat_type
    )
  )
  rule_left <- preorderlgb(
    tree,
    node[[left]], rule_left, left, right, split_feature,
    threshold,
    right_cat_type = right_cat_type,
    xnames = xnames,
    factor_levels = factor_levels,
    trace = trace
  )
  rule <- c(rule, rule_left)
  rule_right <- preorderlgb(
    tree,
    node[[right]], rule_right, left, right, split_feature,
    threshold,
    right_cat_type = right_cat_type,
    xnames = xnames,
    factor_levels = factor_levels,
    trace = trace
  )
  rule <- c(rule, rule_right)
} # rtemis::preorderlgb


# lgbtree2rules ----
lgbtree2rules <- function(x,
                          xnames,
                          factor_levels,
                          right_cat_type = "in") {
  tree <- new.env()
  tree$leafs <- character()
  preorderlgb(
    tree,
    x,
    xnames = xnames,
    right_cat_type = right_cat_type,
    factor_levels = factor_levels
  )
  # remove root node "TRUE & "
  substr(tree$leafs, 8, 99999)
} # rtemis::lgbtree2rules


# lgb2rules ----
#' Convert LightGBM Booster to set of rules
#'
#' @param x LightGBM Booster object
#' @param n_iter Integer: Number of trees to convert to rules
#' @param xnames Character vector: Names of features
#'
#' @return Character vector of rules
#' @keywords internal
#' @noRd
lgb2rules <- function(Booster,
                      n_iter = NULL,
                      xnames,
                      factor_levels,
                      right_cat_type = "in",
                      return_unique = TRUE) {
  if (is.null(n_iter)) n_iter <- length(Booster)
  trees <- get_lgb_tree(Booster, n_iter)
  rules <- lapply(
    trees,
    \(x) lgbtree2rules(
      x,
      xnames,
      factor_levels = factor_levels,
      right_cat_type = right_cat_type
    )
  ) |> unlist()
  if (return_unique) unique(rules) else rules
} # rtemis::lgb2rules


#' Extract set of rules from LightGBM rtMod
#'
#' @param rtmod LightGBM rtMod object
#' @param dat Data frame used to train rtmod
#' @param n_iter Integer: Number of trees to convert to rules
#'
#' @return Character vector of rules
#' @keywords internal
#' @noRd
rtlgb2rules <- function(rtmod,
                        dat,
                        n_iter = NULL,
                        right_cat_type = "in",
                        return_unique = TRUE) {
  lgb2rules(
    Booster = rtmod$mod,
    n_iter = n_iter,
    xnames = rtmod$xnames,
    factor_levels = dt_get_factor_levels(dat),
    right_cat_type = right_cat_type,
    return_unique = return_unique
  )
} # rtemis::rtlgb2rules


# decision_left ----
decision_left <- function(decision_type) {
  switch(decision_type,
    "<=" = " <= ",
    "==" = " %in% "
  )
} # rtemis::decision_left

decision_right <- function(decision_type, cat_type) {
  switch(decision_type,
    "<=" = " > ",
    "==" = if (cat_type == "in") " %in% " else " %notin% "
  )
} # rtemis::decision_right


#' Format rule thresholds
#'
#' @param catsplit Logical: If TRUE, feature is categorical
#' @param feature Character: feature name
#' @param threshold Character: threshold as reported by lightgbm
#' @param factor_levels Named list of factor levels. Names should correspond to training
#' set column names.
#'
#' @keywords internal
#' @noRd
fmt_thresh <- function(catsplit,
                       feature,
                       threshold,
                       factor_levels) {
  if (catsplit) {
    flevels <- as.integer(strsplit(threshold, "\\|\\|")[[1]]) + 1 # 0- to 1-based factor level index
    flevels <- factor_levels[[feature]][flevels]
    paste0(
      "c(",
      paste0("'", flevels, "'", collapse = ","),
      ")"
    )
  } else {
    threshold
  }
} # rtemis::fmt_thresh

#' @rdname fmt_thresh
#' @keywords internal
#' @noRd
fmt_thresh_right <- function(catsplit,
                             feature,
                             threshold,
                             factor_levels,
                             cat_type) {
  if (catsplit) {
    flevels <- as.integer(strsplit(threshold, "\\|\\|")[[1]]) + 1 # 0- to 1-based factor level index
    flevels <- factor_levels[[feature]][flevels]
    if (cat_type == "in") {
      flevels <- setdiff(factor_levels[[feature]], flevels)
    }
    paste0(
      "c(",
      paste0("'", flevels, "'", collapse = ","),
      ")"
    )
  } else {
    threshold
  }
} # rtemis::fmt_thresh_right
egenn/rtemis documentation built on May 4, 2024, 7:40 p.m.