# R/light_ice.R

# Defines functions plot.light_ice light_ice.multiflashlight light_ice.flashlight light_ice.default light_ice

# Documented in light_ice light_ice.default light_ice.flashlight light_ice.multiflashlight plot.light_ice

#' Individual Conditional Expectation (ICE)
#'
#' Generates Individual Conditional Expectation (ICE) profiles.
#' An ICE profile shows how the prediction of an observation changes if
#' one or multiple variables are systematically changed across their ranges,
#' holding all other values fixed (see the reference below for details).
#' The curves can be centered in order to increase visibility of interaction effects.
#'
#' @details
#' There are two ways to specify the variable(s) to be profiled.
#' 1. Pass the variable name via `v` and an optional vector with evaluation points
#'   `evaluate_at` (or `breaks`). This works for dependence on a single variable.
#' 2. More general: Specify any `grid` as a `data.frame` with one or
#'   more columns. For instance, it can be generated by a call to [expand.grid()].
#'
#' The minimum required elements in the (multi-)flashlight are "predict_function",
#' "model", "linkinv" and "data", where the latter can be passed on the fly.
#'
#' Which rows in `data` are profiled? This is specified by `indices`.
#' If not given and `n_max` is smaller than the number of rows in `data`,
#' then row indices will be sampled randomly from `data`.
#' If the same rows should be used for all flashlights in a multiflashlight,
#' there are two options: Either pass a `seed` or a vector of indices used to select rows.
#' In both cases, `data` should be the same for all flashlights considered.
#'
#' @param x An object of class "flashlight" or "multiflashlight".
#' @param v The variable name to be profiled.
#' @param data An optional `data.frame`.
#' @param by An optional vector of column names used to additionally group the results.
#' @param evaluate_at Vector with values of `v` used to evaluate the profile.
#' @param breaks Cut breaks for a numeric `v`. Used to overwrite automatic
#'   binning via `n_bins` and `cut_type`. Ignored if `v` is not numeric or if `grid`
#'   or `evaluate_at` are specified.
#' @param grid A `data.frame` with evaluation grid. For instance, can be generated by
#'   [expand.grid()].
#' @param n_bins Approximate number of unique values to evaluate for numeric `v`.
#'   Ignored if `v` is not numeric or if `breaks`, `grid` or `evaluate_at` are specified.
#' @param cut_type Should a numeric `v` be cut into "equal" or "quantile" bins?
#'   Ignored if `v` is not numeric or if `breaks`, `grid` or `evaluate_at` are specified.
#' @param indices A vector of row numbers to consider.
#' @param n_max If `indices` is not given, maximum number of rows to consider.
#'   Will be randomly picked from `data` if necessary.
#' @param seed An integer random seed.
#' @param use_linkinv Should retransformation function be applied? Default is `TRUE`.
#' @param center How should curves be centered?
#'   - Default is "no".
#'   - Choose "first", "middle", or "last" to 0-center at specific evaluation points.
#'   - Choose "mean" to center all profiles at the within-group means.
#'   - Choose "0" to mean-center curves at 0.
#' @param ... Further arguments passed to or from other methods.
#' @returns
#'   An object of class "light_ice" with the following elements:
#'   - `data` A tibble containing the results.
#'   - `by` Same as input `by`.
#'   - `v` The variable(s) evaluated.
#'   - `center` How centering was done.
#' @export
#' @references
#'   Goldstein, A. et al. (2015). Peeking inside the black box: Visualizing statistical
#'     learning with plots of individual conditional expectation.
#'     Journal of Computational and Graphical Statistics, 24:1
#'     \doi{10.1080/10618600.2014.907095}.
#' @examples
#' fit_add <- lm(Sepal.Length ~ ., data = iris)
#' fl_add <- flashlight(model = fit_add, label = "additive", data = iris)
#'
#' plot(light_ice(fl_add, v = "Sepal.Width", n_max = 200), alpha = 0.2)
#' plot(light_ice(fl_add, v = "Sepal.Width", n_max = 200, center = "first"))
#'
#' # Second model with interactions
#' fit_nonadd <- lm(Sepal.Length ~ . + Sepal.Width:Species, data = iris)
#' fl_nonadd <- flashlight(model = fit_nonadd, label = "nonadditive", data = iris)
#' fls <- multiflashlight(list(fl_add, fl_nonadd))
#'
#' plot(light_ice(fls, v = "Sepal.Width", by = "Species", n_max = 200), alpha = 0.2)
#' plot(light_ice(fls, v = "Sepal.Width", by = "Species", n_max = 200, center = "middle"))
#' @seealso [light_profile()], [plot.light_ice()]
light_ice <- function(x, ...) {
  # S3 generic: the class of `x` selects the method that is run.
  UseMethod("light_ice", x)
}

#' @describeIn light_ice Default method. Always signals an error, since ICE
#'   profiles are only available for (multi-)flashlight objects.
#' @export
light_ice.default <- function(x, ...) {
  # Fallback for unsupported classes: fail with an informative message.
  msg <- "light_ice method is only available for objects of class flashlight or multiflashlight."
  stop(msg)
}

#' @describeIn light_ice ICE profiles for a flashlight object.
#' @export
light_ice.flashlight <- function(x, v = NULL, data = x$data, by = x$by,
                                 evaluate_at = NULL, breaks = NULL,
                                 grid = NULL, n_bins = 27L,
                                 cut_type = c("equal", "quantile"),
                                 indices = NULL, n_max = 20L,
                                 seed = NULL, use_linkinv = TRUE,
                                 center = c("no", "first", "middle",
                                            "last", "mean", "0"), ...) {
  cut_type <- match.arg(cut_type)
  center <- match.arg(center)

  # Input validation. `by` and `v` may be NULL, in which case the `%in%` checks
  # are vacuously true. "value_", "label_" and "id_" are reserved because they
  # are added as columns to the result below.
  stopifnot(
    "No data!" = is.data.frame(data) && nrow(data) >= 1L,
    "'by' not in 'data'!" = by %in% colnames(data),
    "'v' not in 'data'." = v %in% colnames(data),
    "'v' or 'grid' misses." = !is.null(grid) || !is.null(v),
    !any(c("value_", "label_", "id_") %in% c(by, v, names(grid)))
  )

  n <- nrow(data)

  # Complete/evaluate grid. Priority: `grid`, then `evaluate_at`, then
  # `breaks` (midpoints), then automatic binning of `data[[v]]`.
  if (is.null(grid)) {
    if (is.null(evaluate_at)) {
      if (!is.null(breaks) && is.numeric(breaks)) {
        evaluate_at <- midpoints(breaks)
      } else {
        evaluate_at <- auto_cut(
          data[[v]], n_bins = n_bins, cut_type = cut_type, ...
        )$bin_means
      }
    }
    grid <- stats::setNames(data.frame(evaluate_at), v)
  } else {
    stopifnot("grid must be a data.frame" = is.data.frame(grid))
    # With an explicit grid, the profiled variables are its columns
    v <- names(grid)
  }

  # Pick ids: use `indices` if given, otherwise at most `n_max` random rows
  if (is.null(indices)) {
    if (n > n_max) {
      if (!is.null(seed)) {
        set.seed(seed)
      }
      indices <- sample(n, n_max)
    } else {
      indices <- seq_len(n)
    }
  }
  data <- data[indices, , drop = FALSE]

  # Full outer join of data and grid: drop the profiled column(s) from the
  # selected rows and combine every remaining row with every grid row.
  # `cols` remembers the original column set handed to the flashlight below.
  cols <- colnames(data)
  data[, v] <- NULL
  data$id_ <- indices
  data <- tidyr::expand_grid(data, grid)

  # Update flashlight so that predictions are made on the expanded data
  x <- flashlight(
    x,
    data = data[, cols, drop = FALSE],
    by = by,
    linkinv = if (use_linkinv) x$linkinv else function(z) z
  )

  # Add predictions and organize output
  data <- data[, c("id_", by, v, x$w), drop = FALSE]
  data$value_ <- stats::predict(x)

  # c-ICE curves
  if (center == "0") {
    data$value_ <- grouped_center(data, x = "value_", by = "id_", na.rm = TRUE)
  } else if (center == "mean") {
    # Center each profile, then shift by the overall (or per-`by`-group) mean
    centered_values <- grouped_center(data, x = "value_", by = "id_", na.rm = TRUE)
    if (is.null(by)) {
      data$value_ <- centered_values +
        MetricsWeighted::weighted_mean(
          data$value_, w = if (!is.null(x$w)) data[[x$w]], na.rm = TRUE
        )
    } else {
      group_means <- grouped_stats(
        data,
        x = "value_",
        w = x$w,
        by = by,
        counts = FALSE,
        value_name = "global_mean_",
        na.rm = TRUE
      )
      stopifnot(!("global_mean_" %in% colnames(data)))
      data$value_ <- centered_values +
        dplyr::left_join(data, group_means, by = by)$global_mean_
    }
  } else if (center != "no") {
    # Position (within each profile) of the grid point that is mapped to 0
    pos <- switch(
      center,
      first = 1,
      middle = floor((nrow(grid) + 1L) / 2),
      last = nrow(grid)
    )
    # Plain column assignment instead of transform(): transform() evaluates
    # its arguments inside `data` (so a column called `pos` would shadow the
    # local `pos`) and rebuilds the result via data.frame(), which drops the
    # tibble class and may alter non-syntactic column names.
    data$value_ <- stats::ave(
      data$value_, data$id_, FUN = function(z) z - z[pos]
    )
  }

  # Finalize output
  data$label_ <- x$label
  out <- list(data = data, by = by, v = names(grid), center = center)
  add_classes(out, c("light_ice", "light"))
}

#' @describeIn light_ice ICE profiles for a multiflashlight object.
#' @export
light_ice.multiflashlight <- function(x, ...) {
  # Compute the ICE profiles flashlight by flashlight, then merge the results
  per_flashlight <- lapply(x, light_ice, ...)
  light_combine(per_flashlight, new_class = "light_ice_multi")
}

#' Visualize ICE profiles
#'
#' Minimal visualization of an object of class "light_ice" as [ggplot2::geom_line()].
#' The object returned is of class "ggplot" and can be further customized.
#'
#' Each observation is visualized by a line. The first "by" variable is represented
#' by the color, a second "by" variable or a multiflashlight by facets.
#'
#' @importFrom rlang .data
#'
#' @inheritParams plot.light_performance
#' @param x An object of class "light_ice".
#' @param ... Further arguments passed to [ggplot2::geom_line()].
#' @returns An object of class "ggplot".
#' @export
#' @seealso [light_ice()]
plot.light_ice <- function(x, facet_scales = "fixed", rotate_x = FALSE, ...) {
  nby <- length(x$by)
  multi <- is.light_ice_multi(x)

  if (nby + multi > 2L) {
    stop("Plot method not defined for more than two by variables or
         multiflashlight with more than one by variable.")
  }
  if (length(x$v) >= 2L) {
    stop("No plot method defined for two or higher dimensional grids.")
  }

  data <- x$data

  # Resolve the column names once, outside of any data mask, and inject them
  # with `!!` below. Evaluating `x$v` or `x$by` inside aes() would look up `x`
  # among the data columns first and fail if the profiled or grouping
  # variable is itself called "x".
  v_name <- x$v
  by_name <- x$by[1L]

  # Distinguish cases
  if (nby == 0L) {
    # One line per observation
    p <- ggplot2::ggplot(
      data,
      ggplot2::aes(y = .data$value_, x = .data[[!!v_name]], group = .data$id_)
    ) +
      ggplot2::geom_line(...)
  } else {
    stopifnot(!("temp_" %in% colnames(data)))
    # One line per observation within each level of the first "by" variable.
    # Built from the actual "by" column by plain assignment, which keeps the
    # column names of `data` untouched.
    data$temp_ <- interaction(data$id_, data[[by_name]], drop = TRUE)
    p <- ggplot2::ggplot(
      data,
      ggplot2::aes(y = .data$value_, x = .data[[!!v_name]], group = .data$temp_)
    ) +
      ggplot2::geom_line(ggplot2::aes(color = .data[[!!by_name]]), ...) +
      override_alpha()
  }
  # Facets: by model label for a multiflashlight, else by the second "by"
  if (nby > 1L || multi) {
    p <- p + ggplot2::facet_wrap(
      if (multi) "label_" else x$by[2L], scales = facet_scales
    )
  }
  if (rotate_x) {
    p <- p + rotate_x()
  }
  p
}
# mayer79/flashlight documentation built on Feb. 13, 2024, 1:09 p.m.