R/light_ice.R

Defines functions light_ice.multiflashlight light_ice.flashlight light_ice.default light_ice

Documented in light_ice light_ice.default light_ice.flashlight light_ice.multiflashlight

#' Individual Conditional Expectation (ICE)
#'
#' Generates Individual Conditional Expectation (ICE) profiles.
#' An ICE profile shows how the prediction of an observation changes if
#' one or multiple variables are systematically changed across their ranges,
#' holding all other values fixed (see the reference below for details).
#' The curves can be centered in order to increase visibility of interaction effects.
#'
#' @details
#' There are two ways to specify the variable(s) to be profiled.
#' 1. Pass the variable name via `v` and an optional vector with evaluation points
#'   `evaluate_at` (or `breaks`). This works for dependence on a single variable.
#' 2. More general: Specify any `grid` as a `data.frame` with one or
#'   more columns. For instance, it can be generated by a call to [expand.grid()].
#'
#' The minimum required elements in the (multi-)flashlight are "predict_function",
#' "model", "linkinv" and "data", where the latter can be passed on the fly.
#'
#' Which rows in `data` are profiled? This is specified by `indices`.
#' If not given and `n_max` is smaller than the number of rows in `data`,
#' then row indices will be sampled randomly from `data`.
#' If the same rows should be used for all flashlights in a multiflashlight,
#' there are two options: Either pass a `seed` or a vector of indices used to select rows.
#' In both cases, `data` should be the same for all flashlights considered.
#'
#' @param x An object of class "flashlight" or "multiflashlight".
#' @param v The variable name to be profiled.
#' @param data An optional `data.frame`.
#' @param by An optional vector of column names used to additionally group the results.
#' @param evaluate_at Vector with values of `v` used to evaluate the profile.
#' @param breaks Cut breaks for a numeric `v`. Used to overwrite automatic
#'   binning via `n_bins` and `cut_type`. Ignored if `v` is not numeric or if `grid`
#'   or `evaluate_at` are specified.
#' @param grid A `data.frame` with evaluation grid. For instance, can be generated by
#'   [expand.grid()].
#' @param n_bins Approximate number of unique values to evaluate for numeric `v`.
#'   Ignored if `v` is not numeric or if `breaks`, `grid` or `evaluate_at` are specified.
#' @param cut_type Should a numeric `v` be cut into "equal" or "quantile" bins?
#'   Ignored if `v` is not numeric or if `breaks`, `grid` or `evaluate_at` are specified.
#' @param indices A vector of row numbers to consider.
#' @param n_max If `indices` is not given, maximum number of rows to consider.
#'   Will be randomly picked from `data` if necessary.
#' @param seed An integer random seed.
#' @param use_linkinv Should retransformation function be applied? Default is `TRUE`.
#' @param center How should curves be centered?
#'   - Default is "no".
#'   - Choose "first", "middle", or "last" to 0-center at specific evaluation points.
#'   - Choose "mean" to center all profiles at the within-group means.
#'   - Choose "0" to mean-center curves at 0.
#' @param ... Further arguments passed to or from other methods.
#' @returns
#'   An object of class "light_ice" with the following elements:
#'   - `data` A tibble containing the results. Can be used to build fully customized
#'     visualizations. Column names can be controlled by `options(flashlight.column_name)`.
#'   - `by` Same as input `by`.
#'   - `v` The variable(s) evaluated.
#'   - `center` How centering was done.
#' @export
#' @references
#'   Goldstein, A. et al. (2015). Peeking inside the black box: Visualizing statistical
#'     learning with plots of individual conditional expectation.
#'     Journal of Computational and Graphical Statistics, 24:1
#'     \doi{10.1080/10618600.2014.907095}.
#' @examples
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' fl <- flashlight(model = fit, label = "lm", data = iris)
#' light_ice(fl, v = "Species")
#' @seealso [light_profile()], [plot.light_ice()]
light_ice <- function(x, ...) {
  # S3 generic: dispatches to the light_ice method matching the class of `x`.
  # Classes without a dedicated method fall through to light_ice.default().
  UseMethod("light_ice")
}

#' @describeIn light_ice Default method not implemented yet.
#' @export
light_ice.default <- function(x, ...) {
  # Fallback for unsupported classes: always signals an error, since ICE
  # profiles are only defined for (multi-)flashlight objects.
  msg <- "light_ice method is only available for objects of class flashlight or multiflashlight."
  stop(msg)
}

#' @describeIn light_ice ICE profiles for a flashlight object.
#' @export
light_ice.flashlight <- function(x, v = NULL, data = x$data, by = x$by,
                                 evaluate_at = NULL, breaks = NULL,
                                 grid = NULL, n_bins = 27L,
                                 cut_type = c("equal", "quantile"),
                                 indices = NULL, n_max = 20L,
                                 seed = NULL, use_linkinv = TRUE,
                                 center = c("no", "first", "middle",
                                            "last", "mean", "0"), ...) {
  # Resolve the enumerated arguments to a single valid choice each.
  cut_type <- match.arg(cut_type)
  center <- match.arg(center)

  # NOTE(review): presumably warns if any of these names is passed via '...'
  # (column names are read from options() below) - confirm against helper.
  warning_on_names(c("value_name", "label_name", "id_name"), ...)

  # Names of the output columns holding predictions, model label, and row id.
  value_name <- getOption("flashlight.value_name")
  label_name <- getOption("flashlight.label_name")
  id_name <- getOption("flashlight.id_name")

  # 'by' and 'v' may be NULL: `NULL %in% ...` is logical(0), which passes.
  stopifnot(
    "No data!" = is.data.frame(data) && nrow(data) >= 1L,
    "'by' not in 'data'!" = by %in% colnames(data),
    "'v' not in 'data'." = v %in% colnames(data),
    "'v' or 'grid' misses." = !is.null(grid) || !is.null(v)
  )
  check_unique(c(by, union(v, names(grid))), c(value_name, label_name, id_name))

  n <- nrow(data)

  # Complete/evaluate grid
  if (is.null(grid)) {
    if (is.null(evaluate_at)) {
      # 'breaks' is documented to be ignored unless 'v' is numeric, so both
      # the breaks and the profiled column must be numeric to use midpoints.
      # Otherwise, evaluation points are derived from the data itself.
      if (is.numeric(breaks) && is.numeric(data[[v]])) {
        evaluate_at <- midpoints(breaks)
      } else {
        evaluate_at <- auto_cut(
          data[[v]], n_bins = n_bins, cut_type = cut_type, ...
        )$bin_means
      }
    }
    grid <- stats::setNames(data.frame(evaluate_at), v)
  } else {
    stopifnot("grid must be a data.frame" = is.data.frame(grid))
    # With an explicit grid, the profiled variables are its columns.
    v <- names(grid)
  }

  # Pick ids
  if (is.null(indices)) {
    if (n > n_max) {
      # Sampling is only needed if there are more rows than requested.
      # Note that passing 'seed' resets the global RNG state.
      if (!is.null(seed)) {
        set.seed(seed)
      }
      indices <- sample(n, n_max)
    } else {
      indices <- seq_len(n)
    }
  }
  data <- data[indices, , drop = FALSE]

  # Full outer join of data and grid: drop the profiled columns, tag each
  # selected row with its id, then cross every row with every grid row.
  cols <- colnames(data)
  data[, v] <- NULL
  data[[id_name]] <- indices
  data <- tidyr::expand_grid(data, grid)

  # Update flashlight: predictions are made on the expanded data restricted
  # to the original columns. Without 'use_linkinv', the identity is used.
  x <- flashlight(
    x,
    data = data[, cols, drop = FALSE],
    by = by,
    linkinv = if (use_linkinv) x$linkinv else function(z) z
  )

  # Add predictions and organize output
  data <- data[, c(id_name, by, v, x$w), drop = FALSE]
  data[[value_name]] <- stats::predict(x)

  # c-ICE curves
  if (center == "0") {
    # Center each profile (grouped by id) around 0.
    data[[value_name]] <- grouped_center(
      data, x = value_name, by = id_name, na.rm = TRUE
    )
  } else if (center == "mean") {
    # Center each profile, then shift by the (weighted) mean prediction,
    # computed either overall or within each 'by' group.
    centered_values <- grouped_center(
      data, x = value_name, by = id_name, na.rm = TRUE
    )
    if (is.null(by)) {
      data[[value_name]] <- centered_values +
        MetricsWeighted::weighted_mean(
          data[[value_name]], w = if (!is.null(x$w)) data[[x$w]], na.rm = TRUE
        )
    } else {
      group_means <- grouped_stats(
        data,
        x = value_name,
        w = x$w,
        by = by,
        counts = FALSE,
        value_name = "global_mean",
        na.rm = TRUE
      )
      # The temporary join column must not clash with an existing column.
      stopifnot(!("global_mean" %in% colnames(data)))
      data[[value_name]] <- centered_values +
        dplyr::left_join(data, group_means, by = by)[["global_mean"]]
    }
  } else if (center != "no") {
    # Anchor each profile at 0 at the first, middle, or last grid row.
    pos <- switch(
      center,
      first = 1,
      middle = floor((nrow(grid) + 1L) / 2),
      last = nrow(grid)
    )
    data[[value_name]] <- stats::ave(
      data[[value_name]], data[[id_name]], FUN = function(z) z - z[pos]
    )
  }

  # Finalize output
  data[[label_name]] <- x$label
  out <- list(data = data, by = by, v = names(grid), center = center)
  add_classes(out, c("light_ice", "light"))
}

#' @describeIn light_ice ICE profiles for a multiflashlight object.
#' @export
light_ice.multiflashlight <- function(x, ...) {
  # Compute the ICE profiles flashlight by flashlight (same arguments for
  # each), then merge the individual results into one combined object.
  ice_per_flashlight <- lapply(x, light_ice, ...)
  light_combine(ice_per_flashlight, new_class = "light_ice_multi")
}

Try the flashlight package in your browser

Any scripts or data that you put into this service are public.

flashlight documentation built on May 31, 2023, 6:19 p.m.