R/get_rules.R

Defines functions get_rules

Documented in get_rules

# get_rules.R
# ::rtemis::
# 2023 EDG lambdamd.org

#' Get RuleFit rules
#'
#' Get rules generated by [s_RuleFit] or [s_LightRuleFit]
#'
#' @param mod Model created by [s_RuleFit] or [s_LightRuleFit]
#' @param formatted Logical: If TRUE, return human-readable rules, otherwise return
#' R-parsable rules
#' @param collapse Logical: If TRUE, collapse all rules to a single character vector
#' @param collapse.keep.names Logical: If TRUE, keep names when collapsing (will
#' be able to tell which run each rule came from). However, has no effect if
#' `collapse.unique = TRUE`, as `unique()` removes names.
#' @param collapse.unique Logical: If TRUE and `collapse = TRUE`, will return only
#' unique rules
#'
#' @author ED Gennatas
#' @export
get_rules <- function(
    mod,
    formatted = FALSE,
    collapse = TRUE,
    collapse.keep.names = FALSE,
    collapse.unique = TRUE) {
  mod.name <- mod$mod.name
  stopifnot(mod.name %in% c("RuleFit", "LightRuleFit"))
  if (inherits(mod, "rtMod")) {
    if (formatted) {
      out <- if (mod.name == "RuleFit") {
        mod$mod$rules.selected.formatted
      } else {
        mod$mod$rules_selected_formatted
      }
    } else {
      out <- if (mod.name == "RuleFit") {
        mod$mod$rules.selected
      } else {
        mod$mod$rules_selected
      }
    }
  } else if (inherits(mod, "rtModCV")) {
    if (formatted) {
      out <- if (mod.name == "RuleFit") {
        lapply(mod$mod, \(rep) lapply(rep, \(mod) mod$mod1$mod$rules.selected.formatted))
      } else {
        lapply(mod$mod, \(rep) lapply(rep, \(mod) mod$mod1$mod$rules_selected_formatted))
      }
    } else {
      out <- if (mod.name == "RuleFit") {
        lapply(mod$mod, \(rep) lapply(rep, \(mod) mod$mod1$mod$rules.selected))
      } else {
        lapply(mod$mod, \(rep) lapply(rep, \(mod) mod$mod1$mod$rules_selected))
      }
    }
    if (length(out) == 1) {
      out <- out[[1]]
    } else {
      names(out) <- paste0("Repeat", seq_along(out))
    }
  }

  if (collapse) {
    out <- unlist(out, recursive = TRUE, use.names = collapse.keep.names)
    if (collapse.unique) {
      out <- unique(out)
    }
  }
  out
} # rtemis::get_rules
egenn/rtemis documentation built on May 4, 2024, 7:40 p.m.