R/xrf.R

#############################################
## functions for preconditions on user input
#############################################

condition_xgb_control <- function(
  family,
  xgb_control,
  data,
  response_var,
  prefit_xgb,
  call = rlang::caller_env()
) {
  # this is a duplicated but necessary check
  if (!(response_var %in% colnames(data))) {
    cli::cli_abort(
      "Response variable {.var {response_var}} not present in supplied data",
      call = call
    )
  }

  data_mutated <- data

  if (
    family == 'multinomial' &&
      is.null(xgb_control$num_class) &&
      is.null(prefit_xgb)
  ) {
    n_classes <- n_distinct(data[[response_var]])
    cli::cli_warn(
      "Detected {n_classes} class{?es}; setting the {.arg num_class} element of {.arg xgb_control} automatically",
      call = call
    )
    xgb_control$num_class <- n_classes
  }

  # xgboost expects multinomial labels to be in 0:(num_class - 1)
  if (
    family == 'multinomial' &&
      (is.factor(data[[response_var]]) ||
        is.character(data[[response_var]]))
  ) {
    integer_response <- as.integer(as.factor(data[[response_var]]))
    data_mutated[[response_var]] <- integer_response - min(integer_response)
  } else if (
    family == 'binomial' &&
      (is.factor(data[[response_var]]) ||
        is.character(data[[response_var]]))
  ) {
    integer_response <- as.integer(as.factor(data[[response_var]]))
    data_mutated[[response_var]] <- integer_response - min(integer_response)
  }

  xgb_control$objective <- get_xgboost_objective(family, call = call)
  list(xgb_control = xgb_control, data = data_mutated)
}
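
# A small illustrative sketch (guarded so it never runs when this file is
# sourced) of the label recoding performed above: a two-level factor response
# is mapped to the 0/1 labels xgboost expects. The toy data frame here is
# hypothetical.
if (FALSE) {
  toy <- data.frame(y = factor(c('no', 'yes', 'no')), x = 1:3)
  conditioned <- condition_xgb_control(
    family = 'binomial',
    xgb_control = list(nrounds = 5),
    data = toy,
    response_var = 'y',
    prefit_xgb = NULL
  )
  conditioned$data$y # 0 1 0
  conditioned$xgb_control$objective # "binary:logistic"
}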

xrf_preconditions <- function(
  family,
  xgb_control,
  glm_control,
  data,
  response_var,
  prefit_xgb,
  call = rlang::caller_env()
) {
  supported_families <- c('gaussian', 'binomial', 'multinomial')
  if (!(family %in% supported_families)) {
    cli::cli_abort(
      "Family {.val {family}} is not currently supported. Supported families
      are: {.val {supported_families}}.",
      call = call
    )
  }

  if (!(response_var %in% colnames(data))) {
    cli::cli_abort(
      "Response variable {.var {response_var}} not present in supplied data",
      call = call
    )
  }

  if (any(is.na(data[[response_var]]))) {
    cli::cli_abort(
      "The response variable contains missing values.",
      call = call
    )
  }

  if (n_distinct(data[[response_var]]) <= 1) {
    # TODO cv.glmnet will still warn/fail when there is a very small number of observations per class for logistic regression
    cli::cli_abort(
      "Response variable shows no variation, model cannot be fit.",
      call = call
    )
  }

  if (
    family == 'multinomial' &&
      ((is.null(xgb_control$num_class) ||
        n_distinct(data[[response_var]]) != xgb_control$num_class) &&
        is.null(prefit_xgb))
  ) {
    cli::cli_abort(
      "Must supply a {.arg num_class} list element in {.arg xgb_control} that
      matches the number of response classes when using the multinomial objective",
      call = call
    )
  }

  if (
    !('type.measure' %in% names(glm_control)) ||
      !any(c('nfolds', 'foldid') %in% names(glm_control))
  ) {
    cli::cli_abort(
      "Must supply {.arg type.measure} and ({.arg nfolds} or {.arg foldid})
      as {.arg glm_control} parameters",
      call = call
    )
  }

  allowed_tree_ensemble_classes <- c('xgb.Booster')
  if (
    !is.null(prefit_xgb) &&
      length(intersect(allowed_tree_ensemble_classes, class(prefit_xgb))) == 0
  ) {
    cli::cli_abort(
      "Prefit tree ensemble must be of class {.cls {allowed_tree_ensemble_classes}}.",
      call = call
    )
  }

  features_with_commas <- grepl(',', colnames(data), fixed = TRUE)
  if (any(features_with_commas)) {
    feature_names <- colnames(data)[features_with_commas]
    cli::cli_abort(
      "The column names {.val {features_with_commas}} contain illegal characters.",
      call = call
    )
  }
}

## the choice of ensemble loss is currently hidden from the api to protect implementation details
## this may be exposed to the user in the future
get_xgboost_objective <- function(family, call = rlang::caller_env()) {
  if (family == 'gaussian') {
    return('reg:squarederror')
  } else if (family == 'binomial') {
    return('binary:logistic')
  } else if (family == 'multinomial') {
    return('multi:softmax')
  }

  cli::cli_abort(
    "Unrecognized family {.val {family}} which should have failed fast in preconditions.",
    call = call
  )
}

#############################################
## functions for extracting xgboost rule sets
#############################################

augment_rules <- function(row, rule_ids, less_than) {
  bind_rows(lapply(rule_ids, function(rule_id) {
    list(
      split_id = row$ID,
      rule_id = rule_id,
      feature = row$Feature,
      split = row$Split,
      less_than = less_than
    )
  }))
}

# this is of course slow, but it shouldn't be a bottleneck since ensembles are generally small and tree depth is < 6
rule_traverse <- function(row, tree) {
  if (row$Feature == 'Leaf') {
    return(
      data.frame(
        split_id = row$ID,
        rule_id = paste0('r', gsub('-', '_', row$ID)),
        # leaf nodes uniquely identify a rule
        feature = NA,
        split = NA,
        less_than = NA,
        stringsAsFactors = FALSE
      )
    )
  } else {
    # the Yes/No naming obscures the simplicity of the algorithm: an in-order tree traversal
    left_child <- tree[tree$ID == row$Yes, ]
    stopifnot(nrow(left_child) == 1) # this can be trusted from xgboost, but fail if that changes
    right_child <- tree[tree$ID == row$No, ]
    stopifnot(nrow(right_child) == 1)

    # recursion will bubble up the conjunctive rule to this split
    left_rules <- rule_traverse(left_child, tree)
    right_rules <- rule_traverse(right_child, tree)

    left_rules_augmented <- augment_rules(
      row,
      unique(left_rules$rule_id),
      less_than = TRUE
    )
    right_rules_augmented <- augment_rules(
      row,
      unique(right_rules$rule_id),
      less_than = FALSE
    )

    return(
      rbind(
        left_rules_augmented,
        right_rules_augmented,
        left_rules,
        right_rules,
        stringsAsFactors = FALSE
      )
    )
  }
}
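
# A minimal sketch (guarded so it never runs when this file is sourced) of how
# rule_traverse unwinds a single split into two rules. The toy tree below is
# hypothetical and carries only the columns the traversal reads (ID, Feature,
# Split, Yes, No); real input comes from xgb.model.dt.tree.
if (FALSE) {
  toy_tree <- data.frame(
    ID = c('0-0', '0-1', '0-2'),
    Feature = c('Sepal.Length', 'Leaf', 'Leaf'),
    Split = c(5.5, NA, NA),
    Yes = c('0-1', NA, NA),
    No = c('0-2', NA, NA),
    stringsAsFactors = FALSE
  )
  # rule r0_1 collects the clause Sepal.Length < 5.5,
  # rule r0_2 collects the clause Sepal.Length >= 5.5
  rule_traverse(toy_tree[1, ], toy_tree)
}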

# note that xgboost produces nrounds * classes trees for multi-class classification - we proceed by extracting all rules from all trees, treating them equally
# this function produces a data frame with columns rule_id (which identifies a complete conjunction), feature, split, and less_than
# each row represents an individual boolean clause evaluated as feature (less_than ? < : >=) split
extract_xgb_rules <- function(m) {
  rules <- xgb.model.dt.tree(model = m) |>
    group_by(Tree) |>
    arrange(Node) |> # put the root at the top of each tree group
    do(
      harvested_rules = rule_traverse(.data[1, ], .data) |>
        filter(!is.na(feature))
    ) |>
    pull(harvested_rules) |>
    lapply(drop_zero_row_tbl) |>
    bind_rows()

  rules
}
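
# A hedged usage sketch (guarded so it never runs when this file is sourced):
# extract clauses from a small regression booster fit on mtcars, mirroring the
# xgb.train call made in xrf.formula below. The result has one row per boolean
# clause with columns split_id, rule_id, feature, split, and less_than.
if (FALSE) {
  design <- stats::model.matrix(mpg ~ . - 1, mtcars)
  booster <- xgboost::xgb.train(
    xgboost::xgb.DMatrix(design, label = mtcars$mpg),
    nrounds = 2,
    params = xgb_params(list(objective = 'reg:squarederror', max_depth = 2)),
    verbose = 0
  )
  extract_xgb_rules(booster)
}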

drop_zero_row_tbl <- function(tbl) {
  if (nrow(tbl) == 0) {
    return(NULL)
  }

  tbl
}


##################################################
## functions for parsing out model matrix features
##################################################

build_feature_metadata <- function(data) {
  all_features <- data.frame(
    feature_name = colnames(data),
    stringsAsFactors = FALSE
  )

  feature_metadata <- all_features |>
    mutate(
      is_continuous = sapply(feature_name, function(fname) {
        is.numeric(data[[fname]])
      })
    )

  xlev <- data |>
    select_if(function(x) {
      !is.numeric(x)
    }) |>
    lapply(function(x) {
      if (is.factor(x)) {
        levels(x)
      } else {
        as.character(unique(x))
      }
    })

  list(xlev = xlev, feature_metadata = feature_metadata)
}
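
# A brief sketch (guarded so it never runs when this file is sourced) of the
# metadata built above, using iris: the four measurement columns are flagged
# continuous and the levels of the Species factor are collected into xlev.
if (FALSE) {
  meta <- build_feature_metadata(iris)
  meta$feature_metadata # feature_name plus an is_continuous flag per column
  meta$xlev$Species # "setosa" "versicolor" "virginica"
}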

has_matching_level <- function(feature_name, level_remainder, xlev) {
  for (ix in seq_along(feature_name)) {
    fn <- feature_name[ix]
    lr <- level_remainder[ix]

    if (lr %in% xlev[[fn]]) {
      return(TRUE)
    }
  }

  return(FALSE)
}

correct_xgb_sparse_categoricals <- function(
  rules,
  feature_metadata,
  xlev,
  # .5 matches what xgboost does with dense matrices
  categorical_split_value = .5,
  call = rlang::caller_env()
) {
  if (nrow(rules) == 0) {
    return(rules)
  }

  for (row_ix in 1:nrow(rules)) {
    feature_level <- rules[row_ix, 'feature']
    classified_features <- feature_metadata |>
      mutate(
        level_remainder = sapply(feature_name, function(fn) {
          lstrip(feature_level, fn)
        }),
        may_be_rule_feature = sapply(feature_name, function(fn) {
          !startsWith(feature_level, fn)
        })
      )

    feature_level_matches <- classified_features |>
      filter(!may_be_rule_feature) |>
      filter(
        level_remainder == '' |
          has_matching_level(feature_name, level_remainder, xlev)
      )

    if (nrow(feature_level_matches) > 1) {
      # this means that several features and their levels may be concatenated to produce the same column name
      # e.g. feature "ora" with level "nge" and another feature "oran" with level "ge", or even a continuous feature named "orange"
      cli::cli_abort(
        "In attempting to parse sparse design matrix columns, several feature/level
        matches were found for: {.val {feature_level}}. Conservatively failing; the
        user should change feature/level names or use dense matrices.",
        call = call
      )
    } else if (nrow(feature_level_matches) == 0) {
      # the feature couldn't be found. this is usually because a transformation was applied via the formula
      cli::cli_abort(
        c(
          "In attempting to parse sparse design matrix columns, no feature/level
          matches were found for: {.val {feature_level}}.",
          "i" = "This is often caused by supplying a transformation in the input formula.",
          "i" = "Either transform the source data and use a main-effects-only
          formula, or set the argument {.code sparse = FALSE}."
        ),
        call = call
      )
    }

    if (!feature_level_matches$is_continuous) {
      # xgb always makes the split value negative, so that "Missing" (= 0 one-hot) really maps to "Yes" (the left, less than split)
      # and the right, greater than split (1 one-hot) maps to "No"
      # as such, we don't have to invert the inequality ("less_than")
      # of course, this is reliant on, as far as I can tell, undocumented/unspecified behavior in XGBoost. So the durability isn't great, but:
      # 1. it doesn't seem liable to change (https://github.com/dmlc/xgboost/issues/1112)
      # 2. that lack of specification (dare I call it a bug) is the whole reason we have to do this exercise in xrf
      rules[row_ix, 'split'] <- categorical_split_value
    }
  }

  rules
}
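
# A hedged sketch (guarded so it never runs when this file is sourced) of the
# correction above: a sparse one-hot column such as "Speciesversicolor" has its
# split pinned to 0.5, turning the clause into a level-membership test rather
# than a threshold on the tiny negative split xgboost emits for sparse
# one-hots. The rule and split value here are illustrative.
if (FALSE) {
  meta <- build_feature_metadata(iris)
  toy_rules <- data.frame(
    rule_id = 'r0_1',
    feature = 'Speciesversicolor',
    split = -9.5e-07,
    less_than = TRUE,
    stringsAsFactors = FALSE
  )
  corrected <- correct_xgb_sparse_categoricals(
    toy_rules,
    meta$feature_metadata,
    meta$xlev
  )
  corrected$split # expected to be pinned to 0.5
}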

#############################################
## functions for evaluating rulesets
#############################################

evaluate_rules <- function(rules, data, call = rlang::caller_env()) {
  per_rule_evaluation <- rules |>
    group_by(rule_id) |>
    do(
      rule_evaluation = sapply(1:nrow(.data), function(split_ix) {
        split <- .data[split_ix, ]
        feature_ix <- which(split$feature == colnames(data))
        if (length(feature_ix) == 0) {
          cli::cli_abort(
            "Feature {.val {split$feature}} from ruleset is not present in the
            input data to be evaluated",
            call = call
          )
        } else if (length(feature_ix) > 1) {
          cli::cli_abort(
            "Unexpectedly found {.val {split$feature}} twice in input data
            (user must resolve).",
            call = call
          )
        }
        split_comparison <- data[, feature_ix] < split$split
        return(split_comparison == split$less_than)
      }) |>
        apply(1, all) |>
        as.integer() |>
        data.frame()
    )
  rule_features <- bind_cols(
    per_rule_evaluation$rule_evaluation,
    .name_repair = "minimal"
  )
  colnames(rule_features) <- per_rule_evaluation$rule_id

  rule_features
}

evaluate_rules_dense_only <- function(rules, data) {
  data_df <- as.data.frame(data)
  per_rule_evaluation <- rules |>
    group_by(rule_id) |>
    do(
      # yes, this is gross
      # yes, this is fast
      rule_evaluation = eval(
        parse(
          text = paste0(
            paste0('`', .data$feature, '`'),
            ifelse(.data$less_than, ' < ', ' >= '),
            .data$split,
            collapse = ' & '
          )
        ),
        data_df
      ) |>
        as.integer() |>
        data.frame()
    )
  rule_features <- bind_cols(
    per_rule_evaluation$rule_evaluation,
    .name_repair = "minimal"
  )
  colnames(rule_features) <- per_rule_evaluation$rule_id

  rule_features
}
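
# A small sketch (guarded so it never runs when this file is sourced) of the
# expression the do() block above assembles for a hypothetical two-clause rule,
# roughly
#   `Petal.Length` < 4.5 & `Sepal.Width` >= 3.1
# which is evaluated against the design matrix in a single vectorized pass.
if (FALSE) {
  toy_rules <- data.frame(
    rule_id = c('r0_1', 'r0_1'),
    feature = c('Petal.Length', 'Sepal.Width'),
    split = c(4.5, 3.1),
    less_than = c(TRUE, FALSE),
    stringsAsFactors = FALSE
  )
  design <- stats::model.matrix(~ . - 1, iris[, 1:4])
  evaluate_rules_dense_only(toy_rules, design) # a single 0/1 column named r0_1
}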

#############################################
## functions for cleaning up evaluated rules
#############################################

# returns the list of rules with non-zero variance
# if, by an unexpected outcome of the tree fitting process, a rule shows no variance, remove it
remove_no_variance_rules <- function(evaluated_rules) {
  keep_columns <- sapply(evaluated_rules, function(feature) {
    length(unique(feature)) > 1
  })

  return(colnames(evaluated_rules)[keep_columns])
}

# removes any exactly equal rules
dedupe_train_rules <- function(evaluated_rules) {
  as.matrix(evaluated_rules) |>
    unique(MARGIN = 2) |>
    colnames()
}
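
# A tiny sketch (guarded so it never runs when this file is sourced): two
# identical rule columns collapse to a single surviving column name. The rule
# names here are hypothetical.
if (FALSE) {
  evaluated <- data.frame(r0_1 = c(0, 1, 1), r0_2 = c(0, 1, 1), r0_3 = c(1, 0, 0))
  dedupe_train_rules(evaluated) # "r0_1" "r0_3"
}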

#' Fit an eXtreme RuleFit model
#'
#' S3 method for building an "eXtreme RuleFit" model.
#' See \code{\link{xrf.formula}} for the preferred entry point.
#'
#' @param object an object describing the model to be fit
#' @param ... additional arguments
#'
#' @examples
#' m <- xrf(Petal.Length ~ ., iris,
#'          xgb_control = list(nrounds = 2, max_depth = 2),
#'          family = 'gaussian')
#'
#' @export
xrf <- function(object, ...) {
  UseMethod('xrf', object)
}

#' Fit an eXtreme RuleFit model
#'
#' See Friedman & Popescu (2008) for a description of the general RuleFit algorithm.
#' This method uses XGBoost to fit a tree ensemble, extracts a ruleset as the conjunction of tree
#' traversals, and fits a sparse linear model to the resulting feature set
#' (including the original feature set) using glmnet.
#'
#' @param object a formula prescribing the features to use in the model. Transformation of the response variable is not supported. When using transformations of the input features (not generally recommended), set \code{sparse = FALSE}
#' @param data a data frame with columns corresponding to the formula
#' @param family the family of the fitted model. One of 'gaussian', 'binomial', 'multinomial'
#' @param xgb_control a list of parameters for xgboost. Must include an \code{nrounds} element
#' @param glm_control a list of parameters for the glmnet fit. Must include \code{type.measure} and \code{nfolds} (or \code{foldid}) elements (for the lambda cross-validation)
#' @param sparse whether a sparse design matrix should be used
#' @param prefit_xgb an xgboost model (of class \code{xgb.Booster}) to be used instead of the model that \code{xrf} would normally fit
#' @param deoverlap if TRUE, the tree-derived rules are de-overlapped so that no two rules in the resulting rule set overlap
#' @param ... ignored arguments
#'
#' @importFrom xgboost xgboost
#' @importFrom xgboost xgb.model.dt.tree
#' @import dplyr
#' @importFrom Matrix sparse.model.matrix
#' @importFrom rlang .data
#' @importFrom stats as.formula
#' @importFrom stats coef
#' @importFrom stats delete.response
#' @importFrom stats model.matrix
#' @importFrom stats predict
#' @importFrom stats terms
#' @importFrom stats update
#'
#' @details
#'
#' In November 2025, the new version of \pkg{xgboost} (3.1.2.1) introduced
#' significant breaking changes. This version of \pkg{xrf} can reproduce
#' predictions from older versions of \pkg{xgboost}. However, there are likely
#' to be differences in \pkg{xrf} model fits between old and new versions of
#' \pkg{xgboost}.
#'
#' @references
#' Friedman, J. H., & Popescu, B. E. (2008). Predictive learning via rule
#' ensembles. \emph{The Annals of Applied Statistics, 2}(3), 916-954.
#'
#' @examples
#' m <- xrf(Petal.Length ~ ., iris,
#'          xgb_control = list(nrounds = 2, max_depth = 2),
#'          family = 'gaussian')
#'
#' @export
xrf.formula <- function(
  object,
  data,
  family,
  xgb_control = list(nrounds = 100, max_depth = 3),
  glm_control = list(type.measure = 'deviance', nfolds = 5),
  sparse = TRUE,
  prefit_xgb = NULL,
  deoverlap = FALSE,
  ...
) {
  expanded_formula <- expand_formula(object, data)
  response_var <- get_response(expanded_formula)

  xgboost_conditioned <- condition_xgb_control(
    family,
    xgb_control,
    data,
    response_var,
    prefit_xgb
  )
  xgb_control <- xgboost_conditioned$xgb_control
  data <- xgboost_conditioned$data
  xrf_preconditions(
    family,
    xgb_control,
    glm_control,
    data,
    response_var,
    prefit_xgb
  )

  model_matrix_method <- if (sparse) {
    sparse.model.matrix
  } else {
    model.matrix
  }
  design_matrix <- model_matrix_method(expanded_formula, data)

  nrounds <- xgb_control$nrounds
  # nrounds must be removed from the params list to avoid false-positive warnings from xgboost
  xgb_control <- within(xgb_control, rm(nrounds))

  if (is.null(prefit_xgb)) {
    m_xgb <- xgboost::xgb.train(
      xgboost::xgb.DMatrix(design_matrix, label = data[[response_var]]),
      nrounds = nrounds,
      params = xgb_params(xgb_control),
      verbose = 0
    )
    rules <- extract_xgb_rules(m_xgb)
  } else {
    m_xgb <- prefit_xgb
    rules <- extract_xgb_rules(m_xgb)
    if (length(setdiff(rules$feature, colnames(design_matrix))) > 0) {
      cli::cli_abort(
        "prefit_xgb contains features (or factor-levels) not present in the
        input training data. This is currently not supported."
      )
      # one simple approach would be to simply remove these feature splits from the rules
      # but that potentially dilutes the power of this method. for now, it's on the user to rectify this issue
    }
  }

  if (sparse) {
    feature_metadata <- build_feature_metadata(data)
    rules <- correct_xgb_sparse_categoricals(
      rules,
      feature_metadata$feature_metadata,
      feature_metadata$xlev
    )
  }

  if (deoverlap) {
    rules <- xrf_deoverlap_rules(rules) |>
      select(rule_id, feature, split, less_than)
  }

  rule_features <- if (sparse) {
    evaluate_rules(rules, design_matrix)
  } else {
    evaluate_rules_dense_only(rules, design_matrix)
  }

  varying_rules <- remove_no_variance_rules(rule_features)
  rule_features <- rule_features[, varying_rules]
  rules <- rules |>
    filter(rule_id %in% varying_rules)

  non_duplicate_rules <- dedupe_train_rules(rule_features)
  rule_features <- rule_features[, non_duplicate_rules]
  rules <- rules |>
    filter(rule_id %in% non_duplicate_rules)

  overlapped_feature_names <- intersect(colnames(rule_features), colnames(data))
  if (length(overlapped_feature_names) > 0) {
    cli::cli_warn(
      "Found overlapped raw feature & rule names {overlapped_feature_names}
      (the rule features will be dropped)."
    )
    rule_features <- rule_features[,
      !(colnames(rule_features) %in% overlapped_feature_names)
    ]
  }

  # todo we already have a design matrix, so re-generating it with glmnot is a bit wasteful
  full_data <- cbind(data, rule_features, stringsAsFactors = FALSE)

  # todo glmnet is a bottleneck on data size - it may be interesting to fit the glm to much larger data, e.g. with spark or biglasso
  full_formula <- add_predictors(expanded_formula, colnames(rule_features))

  # glmnet automatically adds an intercept
  full_formula <- update(full_formula, . ~ . - 1)

  m_glm <- glmnot(
    full_formula,
    full_data,
    family = family,
    alpha = 1, # alpha = 1 specifies the lasso penalty
    sparse = sparse,
    glm_control = glm_control
  )

  structure(
    list(
      glm = m_glm,
      xgb = m_xgb,
      base_formula = expanded_formula,
      rule_augmented_formula = full_formula,
      rules = rules
    ),
    class = 'xrf'
  )
}

#' Generate the design matrix from an eXtreme RuleFit model
#'
#' @param object an object of class "xrf"
#' @param data data to generate design matrix from
#' @param sparse a logical indicating whether a sparse design matrix should be used
#' @param ... ignored arguments
#'
#' @importFrom Matrix sparse.model.matrix
#'
#' @examples
#' m <- xrf(Petal.Length ~ ., iris,
#'          xgb_control = list(nrounds = 2, max_depth = 2),
#'          family = 'gaussian')
#' design <- model.matrix(m, iris, sparse = FALSE)
#'
#' @export
model.matrix.xrf <- function(object, data, sparse = TRUE, ...) {
  # TODO: handle missing factor levels more elegantly (both for rule evaluation & glmnet)
  # TODO: when rules have a zero coefficient and we just want to predict, we don't need to evaluate them
  stopifnot(is.data.frame(data))

  trms <- terms(object$base_formula)
  trms <- delete.response(trms)

  design_matrix_method <- if (sparse) {
    sparse.model.matrix
  } else {
    model.matrix
  }

  raw_design_matrix <- design_matrix_method(trms, data)
  rules_features <- if (sparse) {
    evaluate_rules(object$rules, raw_design_matrix)
  } else {
    evaluate_rules_dense_only(object$rules, raw_design_matrix)
  }
  full_data <- cbind(data, rules_features, stringsAsFactors = FALSE)

  full_data
}

#' Draw predictions from a RuleFit xrf model
#'
#' @param object an object of class "xrf"
#' @param newdata data to predict on
#' @param sparse a logical indicating whether a sparse design matrix should be used
#' @param lambda the lasso penalty parameter to be applied
#' @param type the type of predicted value produced
#' @param ... ignored arguments
#'
#' @examples
#' m <- xrf(Petal.Length ~ ., iris,
#'          xgb_control = list(nrounds = 2, max_depth = 2),
#'          family = 'gaussian')
#' predictions <- predict(m, iris)
#'
#' @export
predict.xrf <- function(
  object,
  newdata,
  sparse = TRUE,
  lambda = 'lambda.min',
  type = 'response',
  ...
) {
  if (!is.data.frame(newdata)) {
    cli::cli_abort("{.arg newdata} should be a data frame.")
  }
  full_data <- model.matrix(object, newdata, sparse)

  predict(
    object$glm,
    newdata = full_data,
    sparse = sparse,
    lambda = lambda,
    type = type
  )
}

synthesize_conjunctions <- function(rules) {
  rules |>
    group_by(rule_id) |>
    arrange(feature, split) |>
    summarize(
      conjunction = paste0(
        feature,
        ifelse(less_than, '<', '>='),
        format(split, scientific = FALSE, digits = 4),
        collapse = ' & '
      )
    )
}
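
# A sketch (guarded so it never runs when this file is sourced) of the
# human-readable conjunction synthesized above for a hypothetical two-clause
# rule; the expected string is roughly "Petal.Length<4.5 & Sepal.Width>=3.1".
if (FALSE) {
  toy_rules <- data.frame(
    rule_id = 'r0_1',
    feature = c('Petal.Length', 'Sepal.Width'),
    split = c(4.5, 3.1),
    less_than = c(TRUE, FALSE),
    stringsAsFactors = FALSE
  )
  synthesize_conjunctions(toy_rules)$conjunction
}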

#' Produce rules & coefficients for the RuleFit model
#'
#' @param object an object of class "xrf"
#' @param lambda the lasso penalty parameter to be applied as in 'glmnet'
#' @param ... ignored arguments
#'
#' @examples
#' m <- xrf(Petal.Length ~ ., iris,
#'          xgb_control = list(nrounds = 2, max_depth = 2),
#'          family = 'gaussian')
#' linear_model_coefficients <- coef(m, lambda = 'lambda.1se')
#'
#' @export
coef.xrf <- function(object, lambda = 'lambda.min', ...) {
  rule_conjunctions <- synthesize_conjunctions(object$rules)
  glm_coefficients <- coef(object$glm, s = lambda)
  glm_df <- as.data.frame(as.matrix(glm_coefficients))
  colnames(glm_df) <- sapply(lambda, function(lambda_value) {
    paste0('coefficient_', lambda_value)
  })
  glm_df$term <- rownames(glm_df)
  rownames(glm_df) <- NULL
  glm_df |>
    left_join(rule_conjunctions, by = c('term' = 'rule_id')) |>
    arrange_at(colnames(glm_df[1])) |>
    mutate(rule = conjunction) |>
    select(-conjunction)
}

#' Summarize an eXtreme RuleFit model
#'
#' @param object an object of class "xrf"
#' @param ... ignored arguments
#'
#' @import dplyr
#' @importFrom methods show
#'
#' @examples
#' m <- xrf(Petal.Length ~ ., iris,
#'          xgb_control = list(nrounds = 2, max_depth = 2),
#'          family = 'gaussian')
#' summary(m)
#'
#' @export
summary.xrf <- function(object, ...) {
  cat(paste0(
    'An eXtreme RuleFit model of ',
    n_distinct(object$rules$rule_id),
    ' rules.'
  ))
  cat(paste0('\n\nOriginal Formula:\n\n'))
  cat(smaller_formula(object$base_formula), "\n")
  cli::cli_rule("Tree model")
  cat("\n")
  print(object$xgb)
  cat("\n")
  cli::cli_rule("GLM")
  print(object$glm$model)
  invisible(object)
}

#' Print an eXtreme RuleFit model
#'
#' @param x an object of class "xrf"
#' @param ... ignored arguments
#'
#' @examples
#' m <- xrf(Petal.Length ~ ., iris,
#'          xgb_control = list(nrounds = 2, max_depth = 2),
#'          family = 'gaussian')
#' print(m)
#'
#' @export
print.xrf <- function(x, ...) {
  cat(paste0(
    'An eXtreme RuleFit model of ',
    n_distinct(x$rules$rule_id),
    ' rules.'
  ))
  cat(paste0('\n\nOriginal Formula:\n\n'))
  cat(smaller_formula(x$base_formula), "\n")
}

smaller_formula <- function(x, ...) {
  chr_form <- deparse(x, width.cutoff = getOption("width") - 12)
  if (length(chr_form) > 1) {
    chr_form <- paste0(chr_form[1], "[truncated]")
  }
  chr_form
}

xgb_params <- function(x, call = rlang::caller_env()) {
  cl <- rlang::call2("xgb.params", .ns = "xgboost", !!!x)
  res <- try(rlang::eval_tidy(cl, data = x), silent = TRUE)
  if (inherits(res, "try-error")) {
    msg <- as.character(res)
    msg <- strsplit(msg, split = ":")[[1]][-(1:3)]
    msg <- gsub("\\n", "", msg)
    msg <- trimws(msg)
    msg <- paste0(msg, collapse = "")
    cli::cli_abort(
      "There was an error when parsing the xgboost arguments: {msg}",
      call = call
    )
  }
  res
}
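
# A usage sketch (guarded so it never runs when this file is sourced):
# xgb_params forwards the control list to xgboost::xgb.params for validation,
# so a misspelled parameter (deliberate below) surfaces as an informative error.
if (FALSE) {
  xgb_params(list(objective = 'binary:logistic', max_depth = 3)) # validated parameter list
  xgb_params(list(max_deepth = 3)) # errors: unknown xgboost argument
}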
