R/orsf_pd.R

Defines functions orsf_dependence orsf_ice_new orsf_ice_inb orsf_ice_oob orsf_pd_new orsf_pd_inb orsf_pd_oob pred_spec_auto

Documented in orsf_ice_inb orsf_ice_new orsf_ice_oob orsf_pd_inb orsf_pd_new orsf_pd_oob pred_spec_auto

#' Automatic variable values for dependence
#'
#' For partial dependence and individual conditional expectations,
#'   this function allows a variable to be considered without having
#'   to specify what values to set the variable at. The values used
#'   are based on quantiles for continuous variables (10th, 25th, 50th,
#'   75th, and 90th) and unique categories for categorical variables.
#'
#' @param ... names of the variables to use. These can be in quotes
#'   or not in quotes (see examples).
#'
#' @return a character vector with the names
#'
#' @details This function should only be used in the context of
#'   `orsf_pd` or `orsf_ice` functions.
#'
#'
#' @export
#'
#' @examples
#'
#' fit <- orsf(penguins_orsf, species ~., n_tree = 5)
#'
#' orsf_pd_oob(fit, pred_spec_auto(flipper_length_mm))
#'
pred_spec_auto <- function(...){

 input_string <- gsub(".*\\((.*)\\).*", "\\1", match.call())[-1]
 result <- trimws(unlist(strsplit(input_string, ",")))
 class(result) <- c("pspec_auto", class(result))

 result

}


#' Partial dependence
#'
#' Compute partial dependence for an oblique random forest.
#' `r roxy_pd_explain()`
#' `r roxy_pd_oob_explain('partial dependence')`
#'
#' @inheritParams predict.ObliqueForest
#'
#'
#' @param pred_spec (*named list*, *pspec_auto*, or *data.frame*).
#'
#'  -  If `pred_spec` is a named list,
#'   Each item in the list should be a vector of values that will be used as
#'   points in the partial dependence function. The name of each item in the
#'   list should indicate which variable will be modified to take the
#'   corresponding values.
#'
#'  - If `pred_spec` is created using `pred_spec_auto()`, all that is needed
#'    is the names of variables to use (see [pred_spec_auto]).
#'
#'  - If `pred_spec` is a `data.frame`, columns will
#'   indicate variable names, values will indicate variable values, and
#'   partial dependence will be computed using the inputs on each row.
#'
#' @param pred_type (_character_) the type of predictions to compute. Valid
#'   Valid options for survival are:
#'
#'   - 'risk' : probability of having an event at or before `pred_horizon`.
#'   - 'surv' : 1 - risk.
#'   - 'chf': cumulative hazard function
#'   - 'mort': mortality prediction
#'   - 'time': survival time prediction
#'
#'  For classification:
#'
#'  - 'prob': probability for each class
#'
#'  For regression:
#'
#'  - 'mean': predicted mean, i.e., the expected value
#'
#' @param na_action `r roxy_na_action_header("new_data")`
#'
#'   - `r roxy_na_action_fail("new_data")`
#'   - `r roxy_na_action_omit("new_data")`
#'
#' @param expand_grid (_logical_) if `TRUE`, partial dependence will be
#'   computed at all possible combinations of inputs in `pred_spec`. If
#'   `FALSE`, partial dependence will be computed for each variable
#'   in `pred_spec`, separately.
#'
#' @param prob_values (_numeric_) a vector of values between 0 and 1,
#'   indicating what quantiles will be used to summarize the partial
#'   dependence values at each set of inputs. `prob_values` should
#'   have the same length as `prob_labels`. The quantiles are calculated
#'   based on predictions from `object` at each set of values indicated
#'   by `pred_spec`.
#'
#' @param prob_labels (_character_) a vector of labels with the same length
#'   as `prob_values`, with each label indicating what the corresponding
#'   value in `prob_values` should be labelled as in summarized outputs.
#'   `prob_labels` should have the same length as `prob_values`.
#'
#' @param boundary_checks (_logical_) if `TRUE`, `pred_spec` will be checked
#'  to make sure the requested values are between the 10th and 90th
#'  percentile in the object's training data. If `FALSE`, these checks are
#'  skipped.
#'
#' @param n_thread `r roxy_n_thread_header("computing predictions")`
#'
#' @param verbose_progress (_logical_) if `TRUE`, progress will be
#'  printed to console. If `FALSE` (the default), nothing will be
#'  printed.
#'
#' @param ... `r roxy_dots()`
#'
#' @return a [data.table][data.table::data.table-package] containing
#'   partial dependence values for the specified variable(s)
#'   and, if relevant, at the specified prediction horizon(s).
#'
#' @details
#'
#' `r roxy_pd_limitations()`
#'
#' @references
#'
#' 1. `r cite("hooker_2021")`
#'
#' @includeRmd Rmd/orsf_pd_examples.Rmd
#'
#' @export
#'
orsf_pd_oob <- function(object,
                        pred_spec,
                        pred_horizon = NULL,
                        pred_type = NULL,
                        expand_grid = TRUE,
                        prob_values = c(0.025, 0.50, 0.975),
                        prob_labels = c('lwr', 'medn', 'upr'),
                        boundary_checks = TRUE,
                        n_thread = NULL,
                        verbose_progress = NULL,
                        ...){

 check_dots(list(...), orsf_pd_oob)

 orsf_dependence(object = object,
                 pred_spec = pred_spec,
                 pd_data = NULL,
                 pred_horizon = pred_horizon,
                 pred_type = pred_type,
                 expand_grid = expand_grid,
                 prob_values = prob_values,
                 prob_labels = prob_labels,
                 boundary_checks = boundary_checks,
                 n_thread = n_thread,
                 verbose_progress = verbose_progress,
                 oobag = TRUE,
                 type_output = 'smry')

}

#' @rdname orsf_pd_oob
#' @export
orsf_pd_inb <- function(object,
                        pred_spec,
                        pred_horizon = NULL,
                        pred_type = NULL,
                        expand_grid = TRUE,
                        prob_values = c(0.025, 0.50, 0.975),
                        prob_labels = c('lwr', 'medn', 'upr'),
                        boundary_checks = TRUE,
                        n_thread = NULL,
                        verbose_progress = NULL,
                        ...){

 check_dots(list(...), orsf_pd_inb)

 if(is.null(object$data))
  stop("no data were found in object. ",
       "did you use attach_data = FALSE when ",
       "running orsf()?", call. = FALSE)

 orsf_dependence(object = object,
                 pred_spec = pred_spec,
                 pd_data = object$data,
                 pred_horizon = pred_horizon,
                 pred_type = pred_type,
                 expand_grid = expand_grid,
                 prob_values = prob_values,
                 prob_labels = prob_labels,
                 boundary_checks = boundary_checks,
                 n_thread = n_thread,
                 verbose_progress = verbose_progress,
                 oobag = FALSE,
                 type_output = 'smry')

}

#' @rdname orsf_pd_oob
#' @export
orsf_pd_new <- function(object,
                        pred_spec,
                        new_data,
                        pred_horizon = NULL,
                        pred_type = NULL,
                        na_action = 'fail',
                        expand_grid = TRUE,
                        prob_values = c(0.025, 0.50, 0.975),
                        prob_labels = c('lwr', 'medn', 'upr'),
                        boundary_checks = TRUE,
                        n_thread = NULL,
                        verbose_progress = NULL,
                        ...){

 check_dots(list(...), orsf_pd_new)

 orsf_dependence(object = object,
                 pred_spec = pred_spec,
                 pd_data = new_data,
                 pred_horizon = pred_horizon,
                 pred_type = pred_type,
                 na_action = na_action,
                 expand_grid = expand_grid,
                 prob_values = prob_values,
                 prob_labels = prob_labels,
                 boundary_checks = boundary_checks,
                 n_thread = n_thread,
                 verbose_progress = verbose_progress,
                 oobag = FALSE,
                 type_output = 'smry')

}


#' Individual Conditional Expectations
#'
#' Compute individual conditional expectations for an
#'  oblique random forest. `r roxy_ice_explain()`
#'  `r roxy_pd_oob_explain('individual conditional expectations')`
#'
#' @inheritParams orsf_pd_oob
#'
#' @return a [data.table][data.table::data.table-package] containing
#'   individual conditional expectations for the specified variable(s)
#'   and, if relevant, at the specified prediction horizon(s).
#'
#' @includeRmd Rmd/orsf_ice_examples.Rmd
#'
#' @export
#'
#'
orsf_ice_oob <- function(object,
                         pred_spec,
                         pred_horizon = NULL,
                         pred_type = NULL,
                         expand_grid = TRUE,
                         boundary_checks = TRUE,
                         n_thread = NULL,
                         verbose_progress = NULL,
                         ...){

 check_dots(list(...), orsf_ice_oob)

 orsf_dependence(object = object,
                 pred_spec = pred_spec,
                 pd_data = NULL,
                 pred_horizon = pred_horizon,
                 pred_type = pred_type,
                 expand_grid = expand_grid,
                 boundary_checks = boundary_checks,
                 n_thread = n_thread,
                 verbose_progress = verbose_progress,
                 oobag = TRUE,
                 type_output = 'ice')

}

#' @rdname orsf_ice_oob
#' @export
orsf_ice_inb <- function(object,
                         pred_spec,
                         pred_horizon = NULL,
                         pred_type = NULL,
                         expand_grid = TRUE,
                         boundary_checks = TRUE,
                         n_thread = NULL,
                         verbose_progress = NULL,
                         ...){

 check_dots(list(...), orsf_ice_oob)

 if(is.null(object$data))
  stop("no data were found in object. ",
       "did you use attach_data = FALSE when ",
       "running orsf()?", call. = FALSE)

 orsf_dependence(object = object,
                 pred_spec = pred_spec,
                 pd_data = object$data,
                 pred_horizon = pred_horizon,
                 pred_type = pred_type,
                 expand_grid = expand_grid,
                 boundary_checks = boundary_checks,
                 n_thread = n_thread,
                 verbose_progress = verbose_progress,
                 oobag = FALSE,
                 type_output = 'ice')

}

#' @rdname orsf_ice_oob
#' @export
orsf_ice_new <- function(object,
                         pred_spec,
                         new_data,
                         pred_horizon = NULL,
                         pred_type = NULL,
                         na_action = 'fail',
                         expand_grid = TRUE,
                         boundary_checks = TRUE,
                         n_thread = NULL,
                         verbose_progress = NULL,
                         ...){

 check_dots(list(...), orsf_ice_new)

 orsf_dependence(object = object,
                 pred_spec = pred_spec,
                 pd_data = new_data,
                 pred_horizon = pred_horizon,
                 pred_type = pred_type,
                 na_action = na_action,
                 expand_grid = expand_grid,
                 boundary_checks = boundary_checks,
                 n_thread = n_thread,
                 verbose_progress = verbose_progress,
                 oobag = FALSE,
                 type_output = 'ice')


}


#' delegates the work functions for partial dependence
#'
#' this function takes inputs from the main API functions and
#'   determines which of the lower level functions to call.
#'
#' @inheritParams orsf_pd
#' @param type_output 'ice' or 'smry'.
#'   this in combination with oobag determines which cpp routine to use.
#' @param type_input if 'grid', then all combos of pd_data are considered.
#'   if 'loop', then each entry of pd_data is considered separately.
#'
#' @return output from one of the pd working functions
#'
#' @noRd

orsf_dependence <- function(object,
                            pd_data,
                            pred_spec,
                            pred_horizon,
                            pred_type,
                            na_action = NULL,
                            expand_grid,
                            prob_values = NULL,
                            prob_labels = NULL,
                            boundary_checks,
                            n_thread = NULL,
                            verbose_progress = NULL,
                            oobag,
                            type_output){

 check_arg_is(object, arg_name = 'object', expected_class = 'ObliqueForest')

 if(object$sample_fraction > 0.90 && oobag){
  stop("sample_fraction must be <= 9/10 for out-of-bag dependence calculations",
       call. = FALSE)
 }

 if(oobag && is.null(object$data))
  stop("no data were found in object. ",
       "did you use attach_data = FALSE when ",
       "running orsf()?", call. = FALSE)

 object$compute_dependence(
  pd_data = pd_data,
  pred_spec = pred_spec,
  pred_horizon = pred_horizon,
  pred_type = pred_type,
  na_action = na_action,
  expand_grid = expand_grid,
  prob_values = prob_values,
  prob_labels = prob_labels,
  boundary_checks = boundary_checks,
  n_thread = n_thread %||% object$n_thread,
  verbose_progress = verbose_progress %||% object$verbose_progress,
  oobag = oobag,
  type_output = type_output
 )


}
bcjaeger/aorsf documentation built on April 3, 2025, 4:16 p.m.