R/light_profile.R

#' Partial Dependence and other Profiles
#'
#' @description
#' Calculates different types of profiles across covariable values.
#' By default, partial dependence profiles are calculated (see Friedman).
#' Other options are profiles of ALE (accumulated local effects, see Apley),
#' response, predicted values ("M plots" or "marginal plots", see Apley),
#' residuals, and shap. The results are aggregated either by (weighted) means or by
#' (weighted) quartiles.
#'
#' Note that ALE profiles are calibrated by (weighted) average predictions.
#' In contrast to the suggestions in Apley, we calculate ALE profiles of factors
#' in the same order as the factor levels.
#' They are not reordered based on the similarity of other variables.
#'
#' @details
#' Numeric covariables `v` with more than `n_bins` distinct values
#' are binned into approximately `n_bins` bins. Alternatively, `breaks` can be
#' provided to specify the binning. For partial dependence profiles
#' (and partly also ALE profiles), this behaviour can be overridden either
#' by providing a vector of evaluation points (`pd_evaluate_at`) or an
#' evaluation grid (`pd_grid`). The latter is a data frame whose column(s)
#' define a (multi-)variate evaluation grid.
#'
#' For partial dependence, ALE, and prediction profiles, "model", "predict_function",
#' "linkinv" and "data" are required. For response profiles, "y", "linkinv" and
#' "data" are required, and for shap profiles just "shap". "data" can be passed on the fly.
#'
#' @param x An object of class "flashlight" or "multiflashlight".
#' @param v The variable name to be profiled.
#' @param data An optional `data.frame`. Not used for `type = "shap"`.
#' @param by An optional vector of column names used to additionally group the results.
#' @param type Type of the profile: Either "partial dependence", "ale", "predicted",
#'   "response", "residual", or "shap".
#' @param stats Statistic to calculate: "mean" or "quartiles". For ALE profiles,
#'   only "mean" makes sense.
#' @param breaks Cut breaks for a numeric `v`. Used to override automatic binning via
#'   `n_bins` and `cut_type`. Ignored if `v` is not numeric.
#' @param n_bins Approximate number of unique values to evaluate for numeric `v`.
#'   Ignored if `v` is not numeric or if `breaks` is specified.
#' @param cut_type Should a numeric `v` be cut into "equal" or "quantile" bins?
#'   Ignored if `v` is not numeric or if `breaks` is specified.
#' @param use_linkinv Should the retransformation function be applied? Default is `TRUE`.
#'   Not used for type "shap".
#' @param counts Should observation counts be added?
#' @param counts_weighted If `counts = TRUE`: Should counts be weighted by the
#'   case weights? If `TRUE`, the sum of `w` is returned by group.
#' @param v_labels If `FALSE`, return group centers of `v` instead of labels.
#'   Only relevant for types "response", "predicted" or "residual" and if `v`
#'   is being binned. In that case useful, for instance, if different flashlights
#'   use different data sets and bin labels would not match.
#' @param pred Optional vector with predictions (after application of inverse link).
#'   Can be used to avoid recalculating predictions if the function
#'   is called repeatedly for different `v` and predictions are computationally
#'   expensive. Not implemented for multiflashlight.
#' @param pd_evaluate_at Vector with values of `v` used to evaluate the profile.
#'   Only relevant for type = "partial dependence" and "ale".
#' @param pd_grid A `data.frame` with grid values, e.g., generated by [expand.grid()].
#'   Only used for type = "partial dependence".
#' @param pd_indices A vector of row numbers to consider in calculating
#'   partial dependence and ALE profiles.
#' @param pd_n_max Maximum number of ICE profiles to calculate (will be randomly
#'   picked from `data`) for partial dependence and ALE.
#' @param pd_seed Integer random seed used to select ICE profiles for partial dependence
#'   and ALE.
#' @param pd_center How should ICE curves be centered?
#'   - Default is "no".
#'   - Choose "first", "middle", or "last" to 0-center at specific evaluation points.
#'   - Choose "mean" to center all profiles at the within-group means.
#'   - Choose "0" to mean-center curves at 0. Only relevant for partial dependence.
#' @param ale_two_sided If `TRUE`, `v` is continuous, and `breaks`
#'   are passed or calculated, then two-sided derivatives are calculated
#'   for ALE instead of left derivatives. More specifically: Usually, local effects
#'   at value x are calculated using points in \eqn{[x-e, x]}.
#'   Set `ale_two_sided = TRUE` to use points in \eqn{[x-e/2, x+e/2]}.
#' @param ... Further arguments passed to [cut3()] in forming the
#'   cut breaks of the `v` variable.
#' @returns
#'   An object of class "light_profile" with the following elements:
#'   - `data` A tibble containing results. Can be used to build fully customized
#'     visualizations. Column names can be controlled by `options(flashlight.column_name)`.
#'   - `by` Names of the group-by variables.
#'   - `v` The variable(s) evaluated.
#'   - `type` Same as input `type`. For information only.
#'   - `stats` Same as input `stats`.
#' @export
#' @references
#' - Friedman J. H. (2001). Greedy function approximation: A gradient boosting machine. The Annals of Statistics, 29:1189–1232.
#' - Apley D. W. (2016). Visualizing the effects of predictor variables in black box supervised learning models.
#' @examples
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' fl <- flashlight(model = fit, label = "iris", data = iris, y = "Sepal.Length")
#' light_profile(fl, v = "Species")
#' light_profile(fl, v = "Petal.Width", type = "residual")
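#'
#' # A hedged sketch of further calls using only the arguments documented above.
#' # The argument values (number of bins, evaluation points, grid values) and the
#' # object names `grid`, `fit_part`, `fl_part`, `fls` are illustrative assumptions.
#'
#' # ALE profile with fewer bins
#' light_profile(fl, v = "Petal.Width", type = "ale", n_bins = 5)
#'
#' # Partial dependence within each Species, using quantile binning
#' light_profile(fl, v = "Petal.Width", by = "Species", cut_type = "quantile")
#'
#' # Partial dependence at user-chosen evaluation points instead of binning
#' light_profile(fl, v = "Petal.Width", pd_evaluate_at = c(0.5, 1, 1.5, 2))
#'
#' # Partial dependence on a two-variable grid; `v` is then taken from the grid
#' grid <- expand.grid(Species = levels(iris$Species), Petal.Width = c(0.5, 1.5))
#' light_profile(fl, pd_grid = grid)
#'
#' # Comparing two models on common breaks via a multiflashlight
#' fit_part <- lm(Sepal.Length ~ Petal.Length, data = iris)
#' fl_part <- flashlight(
#'   model = fit_part, label = "part", data = iris, y = "Sepal.Length"
#' )
#' fls <- multiflashlight(list(fl, fl_part))
#' light_profile(fls, v = "Petal.Length")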
#' @seealso [light_effects()], [plot.light_profile()]
light_profile <- function(x, ...) {
  UseMethod("light_profile")
}

#' @describeIn light_profile Default method not implemented yet.
#' @export
light_profile.default <- function(x, ...) {
  stop("No default method available yet.")
}

#' @describeIn light_profile Profiles for flashlight.
#' @export
light_profile.flashlight <- function(x, v = NULL, data = NULL, by = x$by,
                                     type = c("partial dependence", "ale",
                                              "predicted", "response",
                                              "residual", "shap"),
                                     stats = c("mean", "quartiles"),
                                     breaks = NULL, n_bins = 11L,
                                     cut_type = c("equal", "quantile"),
                                     use_linkinv = TRUE, counts = TRUE,
                                     counts_weighted = FALSE,
                                     v_labels = TRUE, pred = NULL,
                                     pd_evaluate_at = NULL, pd_grid = NULL,
                                     pd_indices = NULL, pd_n_max = 1000L,
                                     pd_seed = NULL,
                                     pd_center = c("no", "first", "middle",
                                                   "last", "mean", "0"),
                                     ale_two_sided = FALSE, ...) {
  type <- match.arg(type)
  stats <- match.arg(stats)
  cut_type <- match.arg(cut_type)
  pd_center <- match.arg(pd_center)

  if (stats == "quartiles") {
    message("stats = 'quartiles' is deprecated and will be removed in flashlight 1.0.0.")
  }
  if (type == "shap") {
    message("type = 'shap' is deprecated and will be removed in flashlight 1.0.0.")
  }

  warning_on_names(
    c("value_name", "label_name", "q1_name", "q3_name", "type_name", "counts_name"),
    ...
  )

  value_name <- getOption("flashlight.value_name")
  label_name <- getOption("flashlight.label_name")
  q1_name <- getOption("flashlight.q1_name")
  q3_name <- getOption("flashlight.q3_name")
  type_name <- getOption("flashlight.type_name")
  counts_name <- getOption("flashlight.counts_name")

  # If SHAP, extract data
  if (type == "shap") {
    if (!is.shap(x$shap)) {
      stop("No shap values calculated. Run 'add_shap' for the flashlight first.")
    }
    stopifnot(v %in% colnames(x$shap$data))
    variable_name <- getOption("flashlight.variable_name")
    data <- x$shap$data[x$shap$data[[variable_name]] == v, ]
  } else if (is.null(data)) {
    data <- x$data
  }

  # Checks (more will be done below or in the called functions)
  stopifnot(
    "No data!" = is.data.frame(data) && nrow(data) >= 1L,
    "'by' not in 'data'!" = by %in% colnames(data),
    "'v' not in 'data'." = v %in% colnames(data),
    "'v' or 'pd_grid' is missing." = !is.null(pd_grid) || !is.null(v)
  )
  check_unique(
    c(by, union(v, names(pd_grid))),
    opt_names = c(if (counts) counts_name,
                  if (stats == "quartiles") c(q1_name, q3_name),
                  value_name, label_name, type_name)
  )
  if (!is.null(pred) && type == "predicted" && length(pred) != nrow(data)) {
    stop("Wrong number of predicted values passed.")
  }
  if (type == "ale" && stats == "quartiles") {
    stop("The cumsum step of ALE does not make sense for quartiles, so this option is not available.")
  }

  # Update flashlight
  if (type != "shap") {
    x <- flashlight(
      x, data = data, by = by, linkinv = if (use_linkinv) x$linkinv else function(z) z
    )
  }

  # Calculate profiles
  arg_list <- list(
    x = x,
    v = v,
    evaluate_at = pd_evaluate_at,
    breaks = breaks,
    n_bins = n_bins,
    cut_type = cut_type,
    indices = pd_indices,
    n_max = pd_n_max,
    seed = pd_seed
  )
  if (type == "partial dependence") {
    arg_list <- c(arg_list, list(grid = pd_grid, center = pd_center))
    cp_profiles <- withr::with_options(
      list(flashlight.id_name = "id_xxx"),  # safer than default
      do.call(light_ice, arg_list)
    )
    v <- cp_profiles$v
    data <- cp_profiles$data
  } else if (type == "ale") {
    arg_list <- c(
      arg_list,
      list(
        counts = counts,
        counts_weighted = counts_weighted,
        pred = pred,
        two_sided = ale_two_sided
      )
    )
    agg <- do.call(ale_profile, arg_list)
  } else {
    stopifnot(
      "'v' is missing." = !is.null(v),
      "'v' not in data." = v %in% colnames(data)
    )
    if (type %in% c("response", "residual") && is.null(x$y)) {
      stop("You need to specify 'y' in flashlight.")
    }

    # Add predictions/response to data
    data[[value_name]] <- switch(type,
      response = response(x),
      predicted = if (is.null(pred)) stats::predict(x) else pred,
      residual = stats::residuals(x),
      shap = data[["shap_"]]
    )

    # Replace v values by binned ones
    cuts <- auto_cut(
      data[[v]], breaks = breaks, n_bins = n_bins, cut_type = cut_type, ...
    )
    data[[v]] <- cuts$data[[if (v_labels) "level" else "value"]]
  }

  # Aggregate predicted values
  if (type != "ale") { # ale is already aggregated
    agg <- grouped_stats(
      data = data,
      x = value_name,
      w = x$w,
      by = c(by, v),
      stats = stats,
      counts = counts,
      counts_weighted = counts_weighted,
      counts_name = counts_name,
      q1_name = q1_name,
      q3_name = q3_name,
      na.rm = TRUE
    )
  }

  # Finalize results
  agg[[label_name]] <- x$label

  # Code type as factor (relevant for light_effects)
  agg[[type_name]] <- factor(
    type, c("response", "predicted", "partial dependence", "ale", "residual", "shap")
  )

  # Collect results
  out <- list(data = agg, by = by, v = v, type = type, stats = stats)
  add_classes(out, c("light_profile", "light"))
}

#' @describeIn light_profile Profiles for multiflashlight.
#' @export
light_profile.multiflashlight <- function(x, v = NULL, data = NULL,
                                          type = c("partial dependence", "ale",
                                                   "predicted", "response",
                                                   "residual", "shap"),
                                          breaks = NULL, n_bins = 11L,
                                          cut_type = c("equal", "quantile"),
                                          pd_evaluate_at = NULL,
                                          pd_grid = NULL, ...) {
  type <- match.arg(type)
  cut_type <- match.arg(cut_type)

  is_pd <- type == "partial dependence"
  is_ale <- type == "ale"

  if ("pred" %in% names(list(...))) {
    stop("'pred' not implemented for multiflashlight")
  }

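  # Align breaks for numeric v so that all flashlights share the same bins.
  # Skipped if breaks are given, or if evaluation points or a grid define the
  # evaluation values (partial dependence and ALE only).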
  if (is.null(pd_grid) || !is_pd) {
    stopifnot("Need exactly one 'v'." = length(v) == 1L)
    if (is.null(breaks) && (is.null(pd_evaluate_at) || (!is_pd && !is_ale))) {
      breaks <- common_breaks(
        x = x, v = v, data = data, n_bins = n_bins, cut_type = cut_type
      )
    }
  }
  all_profiles <- lapply(
    x,
    light_profile,
    v = v,
    data = data,
    type = type,
    breaks = breaks,
    n_bins = n_bins,
    cut_type = cut_type,
    pd_evaluate_at = pd_evaluate_at,
    pd_grid = pd_grid,
    ...
  )
  light_combine(all_profiles, new_class = "light_profile_multi")
}

flashlight documentation built on May 31, 2023, 6:19 p.m.