Nothing
#' Individual Conditional Expectation (ICE)
#'
#' Generates Individual Conditional Expectation (ICE) profiles.
#' An ICE profile shows how the prediction of a single observation changes when
#' one or more variables are systematically varied across their ranges,
#' while all other values are held fixed (see the reference below for details).
#' The curves can be centered to make interaction effects easier to see.
#'
#' @details
#' There are two ways to specify the variable(s) to be profiled.
#' 1. Pass the variable name via `v` and, optionally, a vector of evaluation
#'    points `evaluate_at` (or `breaks`). This works for dependence on a single
#'    variable.
#' 2. More general: Specify any `grid` as a `data.frame` with one or
#'    more columns. For instance, it can be generated by a call to [expand.grid()].
#'
#' The minimum required elements in the (multi-)flashlight are "predict_function",
#' "model", "linkinv" and "data", where the latter can also be passed on the fly.
#'
#' Which rows in `data` are profiled? This is specified by `indices`.
#' If not given and `n_max` is smaller than the number of rows in `data`,
#' then `n_max` row indices will be sampled randomly from `data`.
#' If the same rows should be used for all flashlights in a multiflashlight,
#' there are two options: Either pass a `seed` or pass a vector of `indices`
#' used to select rows. In both cases, `data` should be the same for all
#' flashlights considered.
#'
#' @param x An object of class "flashlight" or "multiflashlight".
#' @param v The variable name to be profiled.
#' @param data An optional `data.frame`.
#' @param by An optional vector of column names used to additionally group the results.
#' @param evaluate_at Vector with values of `v` used to evaluate the profile.
#' @param breaks Cut breaks for a numeric `v`. Used to overwrite automatic
#' binning via `n_bins` and `cut_type`. Ignored if `v` is not numeric or if `grid`
#' or `evaluate_at` are specified.
#' @param grid A `data.frame` with evaluation grid. For instance, can be generated by
#' [expand.grid()].
#' @param n_bins Approximate number of unique values to evaluate for numeric `v`.
#' Ignored if `v` is not numeric or if `breaks`, `grid` or `evaluate_at` are specified.
#' @param cut_type Should a numeric `v` be cut into "equal" or "quantile" bins?
#' Ignored if `v` is not numeric or if `breaks`, `grid` or `evaluate_at` are specified.
#' @param indices A vector of row numbers to consider.
#' @param n_max If `indices` is not given, maximum number of rows to consider.
#' Will be randomly picked from `data` if necessary.
#' @param seed An integer random seed. If given, it is passed to [set.seed()]
#' right before rows are sampled (this changes the global random number stream).
#' @param use_linkinv Should retransformation function be applied? Default is `TRUE`.
#' @param center How should curves be centered?
#' - Default is "no".
#' - Choose "first", "middle", or "last" to 0-center at specific evaluation points.
#' - Choose "mean" to center all profiles at the within-group means.
#' - Choose "0" to mean-center curves at 0.
#' @param ... Further arguments passed to or from other methods.
#' @returns
#' An object of class "light_ice" with the following elements:
#' - `data` A tibble containing the results. Can be used to build fully customized
#'   visualizations. Column names can be controlled by `options(flashlight.column_name)`.
#' - `by` Same as input `by`.
#' - `v` The variable(s) evaluated.
#' - `center` How centering was done.
#' @export
#' @references
#' Goldstein, A. et al. (2015). Peeking inside the black box: Visualizing statistical
#' learning with plots of individual conditional expectation.
#' Journal of Computational and Graphical Statistics, 24(1), 44-65.
#' \doi{10.1080/10618600.2014.907095}.
#' @examples
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' fl <- flashlight(model = fit, label = "lm", data = iris)
#' light_ice(fl, v = "Species")
#' @seealso [light_profile()], [plot.light_ice()]
light_ice <- function(x, ...) UseMethod("light_ice")
#' @describeIn light_ice Default method not implemented yet.
#' @export
light_ice.default <- function(x, ...) {
  # Only (multi)flashlight objects have a dedicated method; everything else
  # ends up here and is rejected.
  msg <- "light_ice method is only available for objects of class flashlight or multiflashlight."
  stop(msg)
}
#' @describeIn light_ice ICE profiles for a flashlight object.
#' @export
light_ice.flashlight <- function(x, v = NULL, data = x$data, by = x$by,
                                 evaluate_at = NULL, breaks = NULL,
                                 grid = NULL, n_bins = 27L,
                                 cut_type = c("equal", "quantile"),
                                 indices = NULL, n_max = 20L,
                                 seed = NULL, use_linkinv = TRUE,
                                 center = c("no", "first", "middle",
                                            "last", "mean", "0"), ...) {
  cut_type <- match.arg(cut_type)
  center <- match.arg(center)
  warning_on_names(c("value_name", "label_name", "id_name"), ...)

  # Output column names are controlled via options()
  value_name <- getOption("flashlight.value_name")
  label_name <- getOption("flashlight.label_name")
  id_name <- getOption("flashlight.id_name")

  stopifnot(
    "No data!" = is.data.frame(data) && nrow(data) >= 1L,
    "'by' not in 'data'!" = by %in% colnames(data),
    "'v' not in 'data'." = v %in% colnames(data),
    "'v' or 'grid' misses." = !is.null(grid) || !is.null(v)
  )
  check_unique(c(by, union(v, names(grid))), c(value_name, label_name, id_name))
  n <- nrow(data)

  # Complete/evaluate grid
  if (is.null(grid)) {
    # Without an explicit grid, exactly one variable can be profiled
    # (data[[v]] and setNames(..., v) below assume a single name).
    stopifnot(
      "Without 'grid', 'v' must be a single column name." = length(v) == 1L
    )
    if (is.null(evaluate_at)) {
      # As documented, 'breaks' only applies to a numeric 'v'. For any other
      # type, numeric midpoints would be invalid values of 'v', so fall back
      # to automatic binning instead.
      if (is.numeric(breaks) && is.numeric(data[[v]])) {
        evaluate_at <- midpoints(breaks)
      } else {
        evaluate_at <- auto_cut(
          data[[v]], n_bins = n_bins, cut_type = cut_type, ...
        )$bin_means
      }
    }
    grid <- stats::setNames(data.frame(evaluate_at), v)
  } else {
    stopifnot("grid must be a data.frame" = is.data.frame(grid))
    v <- names(grid)
  }

  # Pick ids
  if (is.null(indices)) {
    if (n > n_max) {
      # NOTE: set.seed() modifies the global RNG state of the session
      if (!is.null(seed)) {
        set.seed(seed)
      }
      indices <- sample(n, n_max)
    } else {
      indices <- seq_len(n)
    }
  }
  data <- data[indices, , drop = FALSE]

  # Full outer join of data and grid: every selected row is combined with
  # every grid row. tidyr::expand_grid() varies its first argument slowest,
  # so within each id the rows follow the grid order (used by positional
  # centering below).
  cols <- colnames(data)
  data[, v] <- NULL
  data[[id_name]] <- indices
  data <- tidyr::expand_grid(data, grid)

  # Update flashlight
  x <- flashlight(
    x,
    data = data[, cols, drop = FALSE],
    by = by,
    linkinv = if (use_linkinv) x$linkinv else function(z) z
  )

  # Add predictions and organize output
  data <- data[, c(id_name, by, v, x$w), drop = FALSE]
  data[[value_name]] <- stats::predict(x)

  # c-ICE curves
  if (center == "0") {
    # Each curve (one per id) is shifted to mean 0
    data[[value_name]] <- grouped_center(
      data, x = value_name, by = id_name, na.rm = TRUE
    )
  } else if (center == "mean") {
    centered_values <- grouped_center(
      data, x = value_name, by = id_name, na.rm = TRUE
    )
    if (is.null(by)) {
      # Shift all 0-centered curves to the overall (weighted) mean prediction
      data[[value_name]] <- centered_values +
        MetricsWeighted::weighted_mean(
          data[[value_name]], w = if (!is.null(x$w)) data[[x$w]], na.rm = TRUE
        )
    } else {
      # Shift 0-centered curves to the (weighted) mean of their 'by' group
      group_means <- grouped_stats(
        data,
        x = value_name,
        w = x$w,
        by = by,
        counts = FALSE,
        value_name = "global_mean",
        na.rm = TRUE
      )
      stopifnot(!("global_mean" %in% colnames(data)))
      data[[value_name]] <- centered_values +
        dplyr::left_join(data, group_means, by = by)[["global_mean"]]
    }
  } else if (center != "no") {
    # Subtract the value at a fixed grid position from each curve
    pos <- switch(
      center,
      first = 1,
      middle = floor((nrow(grid) + 1L) / 2),
      last = nrow(grid)
    )
    data[[value_name]] <- stats::ave(
      data[[value_name]], data[[id_name]], FUN = function(z) z - z[pos]
    )
  }

  # Finalize output
  data[[label_name]] <- x$label
  out <- list(data = data, by = by, v = names(grid), center = center)
  add_classes(out, c("light_ice", "light"))
}
#' @describeIn light_ice ICE profiles for a multiflashlight object.
#' @export
light_ice.multiflashlight <- function(x, ...) {
  # Each flashlight is profiled independently with the same arguments;
  # the per-flashlight results are then stacked into one object.
  per_flashlight <- lapply(x, function(fl) light_ice(fl, ...))
  light_combine(per_flashlight, new_class = "light_ice_multi")
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.