#' Individual Conditional Expectation (ICE)
#'
#' Generates Individual Conditional Expectation (ICE) profiles.
#' An ICE profile shows how the prediction of an observation changes if
#' one or multiple variables are systematically changed across their ranges,
#' holding all other values fixed (see the reference below for details).
#' The curves can be centered in order to increase visibility of interaction effects.
#'
#' @details
#' There are two ways to specify the variable(s) to be profiled.
#' 1. Pass the variable name via `v` and an optional vector with evaluation points
#' `evaluate_at` (or `breaks`). This works for dependence on a single variable.
#' 2. More general: Specify any `grid` as a `data.frame` with one or
#' more columns. For instance, it can be generated by a call to [expand.grid()].
#'
#' The minimum required elements in the (multi-)flashlight are "predict_function",
#' "model", "linkinv" and "data", where the latter can be passed on the fly.
#'
#' Which rows in `data` are profiled? This is specified by `indices`.
#' If not given and `n_max` is smaller than the number of rows in `data`,
#' then row indices will be sampled randomly from `data`.
#' If the same rows should be used for all flashlights in a multiflashlight,
#' there are two options: Either pass a `seed` or a vector of indices used to select rows.
#' In both cases, `data` should be the same for all flashlights considered.
#'
#' @param x An object of class "flashlight" or "multiflashlight".
#' @param v The variable name to be profiled.
#' @param data An optional `data.frame`.
#' @param by An optional vector of column names used to additionally group the results.
#' @param evaluate_at Vector with values of `v` used to evaluate the profile.
#' @param breaks Cut breaks for a numeric `v`. Used to overwrite automatic
#' binning via `n_bins` and `cut_type`. Ignored if `v` is not numeric or if `grid`
#' or `evaluate_at` are specified.
#' @param grid A `data.frame` with evaluation grid. For instance, can be generated by
#' [expand.grid()].
#' @param n_bins Approximate number of unique values to evaluate for numeric `v`.
#' Ignored if `v` is not numeric or if `breaks`, `grid` or `evaluate_at` are specified.
#' @param cut_type Should a numeric `v` be cut into "equal" or "quantile" bins?
#' Ignored if `v` is not numeric or if `breaks`, `grid` or `evaluate_at` are specified.
#' @param indices A vector of row numbers to consider.
#' @param n_max If `indices` is not given, maximum number of rows to consider.
#' Will be randomly picked from `data` if necessary.
#' @param seed An integer random seed.
#' @param use_linkinv Should retransformation function be applied? Default is `TRUE`.
#' @param center How should curves be centered?
#' - Default is "no".
#' - Choose "first", "middle", or "last" to 0-center at specific evaluation points.
#' - Choose "mean" to center all profiles at the within-group means.
#' - Choose "0" to mean-center curves at 0.
#' @param ... Further arguments passed to or from other methods.
#' @returns
#' An object of class "light_ice" with the following elements:
#' - `data` A tibble containing the results.
#' - `by` Same as input `by`.
#' - `v` The variable(s) evaluated.
#' - `center` How centering was done.
#' @export
#' @references
#' Goldstein, A. et al. (2015). Peeking inside the black box: Visualizing statistical
#' learning with plots of individual conditional expectation.
#' Journal of Computational and Graphical Statistics, 24:1
#' <https://doi.org/10.1080/10618600.2014.907095>.
#' @examples
#' fit_add <- lm(Sepal.Length ~ ., data = iris)
#' fl_add <- flashlight(model = fit_add, label = "additive", data = iris)
#'
#' plot(light_ice(fl_add, v = "Sepal.Width", n_max = 200), alpha = 0.2)
#' plot(light_ice(fl_add, v = "Sepal.Width", n_max = 200, center = "first"))
#'
#' # Second model with interactions
#' fit_nonadd <- lm(Sepal.Length ~ . + Sepal.Width:Species, data = iris)
#' fl_nonadd <- flashlight(model = fit_nonadd, label = "nonadditive", data = iris)
#' fls <- multiflashlight(list(fl_add, fl_nonadd))
#'
#' plot(light_ice(fls, v = "Sepal.Width", by = "Species", n_max = 200), alpha = 0.2)
#' plot(light_ice(fls, v = "Sepal.Width", by = "Species", n_max = 200, center = "mid"))
#' @seealso [light_profile()], [plot.light_ice()]
light_ice <- function(x, ...) UseMethod("light_ice")  # S3 generic: dispatch on class(x)
#' @describeIn light_ice Default method not implemented yet.
#' @export
light_ice.default <- function(x, ...) {
  # Reached for any `x` that is neither a flashlight nor a multiflashlight.
  msg <- "light_ice method is only available for objects of class flashlight or multiflashlight."
  stop(msg)
}
#' @describeIn light_ice ICE profiles for a flashlight object.
#' @export
light_ice.flashlight <- function(x, v = NULL, data = x$data, by = x$by,
                                 evaluate_at = NULL, breaks = NULL,
                                 grid = NULL, n_bins = 27L,
                                 cut_type = c("equal", "quantile"),
                                 indices = NULL, n_max = 20L,
                                 seed = NULL, use_linkinv = TRUE,
                                 center = c("no", "first", "middle",
                                            "last", "mean", "0"), ...) {
  cut_type <- match.arg(cut_type)
  center <- match.arg(center)
  stopifnot(
    "No data!" = is.data.frame(data) && nrow(data) >= 1L,
    "'by' not in 'data'!" = by %in% colnames(data),
    "'v' not in 'data'." = v %in% colnames(data),
    "'v' or 'grid' misses." = !is.null(grid) || !is.null(v),
    !any(c("value_", "label_", "id_") %in% c(by, v, names(grid)))
  )
  n <- nrow(data)

  # Complete/evaluate grid
  if (is.null(grid)) {
    if (is.null(evaluate_at)) {
      if (!is.null(breaks) && is.numeric(breaks)) {
        evaluate_at <- midpoints(breaks)
      } else {
        evaluate_at <- auto_cut(
          data[[v]], n_bins = n_bins, cut_type = cut_type, ...
        )$bin_means
      }
    }
    grid <- stats::setNames(data.frame(evaluate_at), v)
  } else {
    # Grid columns missing from `data` would be dropped from the prediction
    # data below (`data[, cols]`), silently producing flat curves.
    stopifnot(
      "grid must be a data.frame" = is.data.frame(grid),
      "'grid' columns not in 'data'." = all(names(grid) %in% colnames(data))
    )
    v <- names(grid)
  }

  # Pick ids. NOTE: a non-NULL `seed` calls set.seed(), i.e., it resets the
  # global RNG state as a side effect.
  if (is.null(indices)) {
    if (n > n_max) {
      if (!is.null(seed)) {
        set.seed(seed)
      }
      indices <- sample(n, n_max)
    } else {
      indices <- seq_len(n)
    }
  }
  data <- data[indices, , drop = FALSE]

  # Full outer join of data and grid: every selected row is repeated once per
  # grid row (expand_grid varies its first argument slowest, so the grid order
  # is kept within each id; the positional centering below relies on this).
  cols <- colnames(data)
  data[, v] <- NULL
  data$id_ <- indices
  data <- tidyr::expand_grid(data, grid)

  # Update flashlight
  x <- flashlight(
    x,
    data = data[, cols, drop = FALSE],
    by = by,
    linkinv = if (use_linkinv) x$linkinv else function(z) z
  )

  # Add predictions and organize output
  data <- data[, c("id_", by, v, x$w), drop = FALSE]
  data$value_ <- stats::predict(x)

  # c-ICE curves
  if (center == "0") {
    data$value_ <- grouped_center(data, x = "value_", by = "id_", na.rm = TRUE)
  } else if (center == "mean") {
    centered_values <- grouped_center(
      data, x = "value_", by = "id_", na.rm = TRUE
    )
    if (is.null(by)) {
      data$value_ <- centered_values +
        MetricsWeighted::weighted_mean(
          data$value_, w = if (!is.null(x$w)) data[[x$w]], na.rm = TRUE
        )
    } else {
      group_means <- grouped_stats(
        data,
        x = "value_",
        w = x$w,
        by = by,
        counts = FALSE,
        value_name = "global_mean_",
        na.rm = TRUE
      )
      stopifnot(!("global_mean_" %in% colnames(data)))
      data$value_ <- centered_values +
        dplyr::left_join(data, group_means, by = by)$global_mean_
    }
  } else if (center != "no") {
    pos <- switch(
      center,
      first = 1,
      middle = floor((nrow(grid) + 1L) / 2),
      last = nrow(grid)
    )
    # Subtract each curve's value at grid position `pos`. Direct assignment
    # (instead of transform()) keeps the tibble class and non-syntactic column
    # names intact, and prevents a data column named `pos` from shadowing the
    # local variable.
    data$value_ <- stats::ave(
      data$value_, data$id_, FUN = function(z) z - z[pos]
    )
  }

  # Finalize output
  data$label_ <- x$label
  out <- list(data = data, by = by, v = names(grid), center = center)
  add_classes(out, c("light_ice", "light"))
}
#' @describeIn light_ice ICE profiles for a multiflashlight object.
#' @export
light_ice.multiflashlight <- function(x, ...) {
  # Compute the ICE profiles of each flashlight, forwarding all arguments,
  # then stack the results into one "light_ice_multi" object.
  per_flashlight <- lapply(x, function(fl) light_ice(fl, ...))
  light_combine(per_flashlight, new_class = "light_ice_multi")
}
#' Visualize ICE profiles
#'
#' Minimal visualization of an object of class "light_ice" as [ggplot2::geom_line()].
#' The object returned is of class "ggplot" and can be further customized.
#'
#' Each observation is visualized by a line. The first "by" variable is represented
#' by the color, a second "by" variable or a multiflashlight by facets.
#'
#' @importFrom rlang .data
#'
#' @inheritParams plot.light_performance
#' @param x An object of class "light_ice".
#' @param ... Further arguments passed to [ggplot2::geom_line()].
#' @returns An object of class "ggplot".
#' @export
#' @seealso [light_ice()]
plot.light_ice <- function(x, facet_scales = "fixed", rotate_x = FALSE, ...) {
  nby <- length(x$by)
  multi <- is.light_ice_multi(x)
  if (nby + multi > 2L) {
    # Message built from two pieces so no newline/indentation leaks into it.
    stop(
      "Plot method not defined for more than two by variables or ",
      "multiflashlight with more than one by variable."
    )
  }
  if (length(x$v) >= 2L) {
    stop("No plot method defined for two or higher dimensional grids.")
  }
  data <- x$data

  # Distinguish cases
  if (nby == 0L) {
    p <- ggplot2::ggplot(
      data, ggplot2::aes(y = value_, x = .data[[x$v]], group = id_)
    ) +
      ggplot2::geom_line(...)
  } else {
    stopifnot(!("temp_" %in% colnames(data)))
    # One line per observation and value of the first "by" variable. Use the
    # column's values (not its name), and assign directly instead of using
    # transform(), which would coerce the tibble to a data.frame and mangle
    # non-syntactic column names (breaking `.data[[x$v]]` below).
    data$temp_ <- interaction(data$id_, data[[x$by[1L]]], drop = TRUE)
    p <- ggplot2::ggplot(
      data, ggplot2::aes(y = value_, x = .data[[x$v]], group = temp_)
    ) +
      ggplot2::geom_line(ggplot2::aes(color = .data[[x$by[1L]]]), ...) +
      override_alpha()
  }

  # Facets: multiflashlight labels take precedence over a second "by" variable
  if (nby > 1L || multi) {
    p <- p + ggplot2::facet_wrap(
      if (multi) "label_" else x$by[2L], scales = facet_scales
    )
  }
  if (rotate_x) {
    p <- p + rotate_x()
  }
  p
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.