#' Partial Dependence and other Profiles
#'
#' @description
#' Calculates different types of profiles across covariable values.
#' By default, partial dependence profiles are calculated (see Friedman).
#' Other options are profiles of ALE (accumulated local effects, see Apley),
#' response, predicted values ("M plots" or "marginal plots", see Apley), and residuals.
#' The results are aggregated either by (weighted) means or by (weighted) quartiles.
#'
#' Note that ALE profiles are calibrated by (weighted) average predictions.
#' In contrast to the suggestions in Apley, we calculate ALE profiles of factors
#' in the same order as the factor levels.
#' They are not reordered based on the similarity of other variables.
#'
#' @details
#' Numeric covariables `v` with more than `n_bins` disjoint values
#' are binned into `n_bins` bins. Alternatively, `breaks` can be provided
#' to specify the binning. For partial dependence profiles
#' (and partly also ALE profiles), this behaviour can be overridden either
#' by providing a vector of evaluation points (`pd_evaluate_at`) or an
#' evaluation `pd_grid`. By the latter we mean a data frame with column name(s)
#' with a (multi-)variate evaluation grid.
#'
#' For partial dependence, ALE, and prediction profiles, "model", "predict_function",
#' "linkinv" and "data" are required. For response profiles, "y", "linkinv" and
#' "data" are required. "data" can also be passed on the fly.
#'
#' @param x An object of class "flashlight" or "multiflashlight".
#' @param v The variable name to be profiled.
#' @param data An optional `data.frame`.
#' @param by An optional vector of column names used to additionally group the results.
#' @param type Type of the profile: Either "partial dependence", "ale", "predicted",
#' "response", or "residual".
#' @param stats Deprecated. Will be removed in version 1.1.0.
#' @param breaks Cut breaks for a numeric `v`. Used to override automatic binning via
#' `n_bins` and `cut_type`. Ignored if `v` is not numeric.
#' @param n_bins Approximate number of unique values to evaluate for numeric `v`.
#' Ignored if `v` is not numeric or if `breaks` is specified.
#' @param cut_type Should a numeric `v` be cut into "equal" or "quantile" bins?
#' Ignored if `v` is not numeric or if `breaks` is specified.
#' @param use_linkinv Should retransformation function be applied? Default is `TRUE`.
#' @param counts Should observation counts be added?
#' @param counts_weighted If `counts = TRUE`: Should counts be weighted by the
#' case weights? If `TRUE`, the sum of `w` is returned by group.
#' @param v_labels If `FALSE`, return group centers of `v` instead of labels.
#' Only relevant for types "response", "predicted" or "residual" and if `v`
#' is being binned. In that case useful, for instance, if different flashlights
#' use different data sets and bin labels would not match.
#' @param pred Optional vector with predictions (after application of inverse link).
#' Can be used to avoid recalculation of predictions over and over if the function
#' is to be repeatedly called for different `v` and predictions are computationally
#' expensive to make. Not implemented for multiflashlight.
#' @param pd_evaluate_at Vector with values of `v` used to evaluate the profile.
#' Only relevant for type = "partial dependence" and "ale".
#' @param pd_grid A `data.frame` with grid values, e.g., generated by [expand.grid()].
#' Only used for type = "partial dependence".
#' @param pd_indices A vector of row numbers to consider in calculating
#' partial dependence profiles and "ale".
#' @param pd_n_max Maximum number of ICE profiles to calculate (will be randomly
#' picked from `data`) for partial dependence and ALE.
#' @param pd_seed Integer random seed used to select ICE profiles for partial dependence
#' and ALE.
#' @param pd_center How should ICE curves be centered?
#' - Default is "no".
#' - Choose "first", "middle", or "last" to 0-center at specific evaluation points.
#' - Choose "mean" to center all profiles at the within-group means.
#' - Choose "0" to mean-center curves at 0. Only relevant for partial dependence.
#' @param ale_two_sided If `TRUE`, `v` is continuous and `breaks`
#' are passed or being calculated, then two-sided derivatives are calculated
#' for ALE instead of left derivatives. More specifically: Usually, local effects
#' at value x are calculated using points in \eqn{[x-e, x]}.
#' Set `ale_two_sided = TRUE` to use points in \eqn{[x-e/2, x+e/2]}.
#' @param ... Further arguments passed to [formatC()] in forming the
#' cut breaks of the `v` variable.
#' @returns
#' An object of class "light_profile" with the following elements:
#' - `data` A tibble containing results.
#' - `by` Names of group by variable.
#' - `v` The variable(s) evaluated.
#' - `type` Same as input `type`. For information only.
#' @export
#' @references
#' - Friedman J. H. (2001). Greedy function approximation: A gradient boosting machine.
#' The Annals of Statistics, 29:1189–1232.
#' - Apley D. W. (2016). Visualizing the effects of predictor variables in black box
#' supervised learning models.
#' @examples
#' fit_lin <- lm(Sepal.Length ~ ., data = iris)
#' fl_lin <- flashlight(model = fit_lin, label = "lin", data = iris, y = "Sepal.Length")
#'
#' # PDP by Species
#' plot(light_profile(fl_lin, v = "Petal.Length", by = "Species"))
#'
#' # Average predicted
#' plot(light_profile(fl_lin, v = "Petal.Length", type = "pred"))
#'
#' # Second model with non-linear Petal.Length effect
#' fit_nonlin <- lm(Sepal.Length ~ . + I(Petal.Length^2), data = iris)
#' fl_nonlin <- flashlight(
#' model = fit_nonlin, label = "nonlin", data = iris, y = "Sepal.Length"
#' )
#' fls <- multiflashlight(list(fl_lin, fl_nonlin))
#'
#' # PDP by Species
#' plot(light_profile(fls, v = "Petal.Length", by = "Species"))
#' plot(light_profile(fls, v = "Petal.Length", by = "Species"), swap_dim = TRUE)
#'
#' # Average residuals (calibration)
#' plot(light_profile(fls, v = "Petal.Length", type = "residual"))
#' @seealso [light_effects()], [plot.light_profile()]
light_profile <- function(x, ...) UseMethod("light_profile") # dispatch on class(x)
#' @describeIn light_profile Default method not implemented yet.
#' @export
light_profile.default <- function(x, ...) {
  # Reached for objects without a dedicated light_profile() method
  msg <- "No default method available yet."
  stop(msg)
}
#' @describeIn light_profile Profiles for flashlight.
#' @export
light_profile.flashlight <- function(x, v = NULL, data = NULL, by = x$by,
                                     type = c("partial dependence", "ale",
                                              "predicted", "response",
                                              "residual", "shap"),
                                     stats = "mean",
                                     breaks = NULL, n_bins = 11L,
                                     cut_type = c("equal", "quantile"),
                                     use_linkinv = TRUE, counts = TRUE,
                                     counts_weighted = FALSE,
                                     v_labels = TRUE, pred = NULL,
                                     pd_evaluate_at = NULL, pd_grid = NULL,
                                     pd_indices = NULL, pd_n_max = 1000L,
                                     pd_seed = NULL,
                                     pd_center = c("no", "first", "middle",
                                                   "last", "mean", "0"),
                                     ale_two_sided = FALSE, ...) {
  # Steps: validate input, update the flashlight (data, by, link), compute the
  # raw profile values, aggregate them by `by` and `v`, and label the result.
  type <- match.arg(type)
  cut_type <- match.arg(cut_type)
  pd_center <- match.arg(pd_center)
  # `%in%` (instead of `==` inside `if`) keeps this check safe for a NULL or
  # multi-element `stats`
  if ("quartiles" %in% stats) {
    stop("stats = 'quartiles' is deprecated. The argument 'stats' will be removed in version 1.1.0.")
  }
  if (type == "shap") {
    stop("type = 'shap' is deprecated.")
  }
  if (is.null(data)) {
    data <- x$data
  }
  # Checks (more will be done below or in the called functions)
  temp_vars <- c("value_", "label_", "type_", "counts_")
  stopifnot(
    "No data!" = is.data.frame(data) && nrow(data) >= 1L,
    "'by' not in 'data'!" = by %in% colnames(data),
    "'v' not in 'data'." = v %in% colnames(data),
    "'v' or 'pd_grid' misses." = !is.null(pd_grid) || !is.null(v),
    "Reserved column names used in 'v', 'by' or 'pd_grid'." =
      !any(temp_vars %in% c(by, v, names(pd_grid)))
  )
  # `pred` must cover all rows of `data`: it is used as-is for "predicted", and
  # for "ale" it is subsetted by `pd_indices` and used for calibration
  if (!is.null(pred) && type %in% c("predicted", "ale") &&
      length(pred) != nrow(data)) {
    stop("Wrong number of predicted values passed.")
  }
  # Update flashlight; use the identity as link if `use_linkinv = FALSE`
  x <- flashlight(
    x, data = data, by = by, linkinv = if (use_linkinv) x$linkinv else function(z) z
  )
  # Arguments shared by the partial dependence and ALE calculations
  arg_list <- list(
    x = x,
    v = v,
    evaluate_at = pd_evaluate_at,
    breaks = breaks,
    n_bins = n_bins,
    cut_type = cut_type,
    indices = pd_indices,
    n_max = pd_n_max,
    seed = pd_seed
  )
  if (type == "partial dependence") {
    arg_list <- c(arg_list, list(grid = pd_grid, center = pd_center))
    cp_profiles <- do.call(light_ice, arg_list)
    # Take over the evaluated variable(s) from light_ice() (relevant with `pd_grid`)
    v <- cp_profiles$v
    data <- cp_profiles$data
  } else if (type == "ale") {
    arg_list <- c(
      arg_list,
      list(
        counts = counts,
        counts_weighted = counts_weighted,
        pred = pred,
        two_sided = ale_two_sided
      )
    )
    agg <- do.call(ale_profile, arg_list)
  } else {
    stopifnot(
      "'v' misses." = !is.null(v),
      "'v' not in data." = v %in% colnames(data)
    )
    if (type %in% c("response", "residual") && is.null(x$y)) {
      stop("You need to specify 'y' in flashlight.")
    }
    # Add predictions/response to data
    data$value_ <- switch(
      type,
      response = response(x),
      predicted = if (is.null(pred)) stats::predict(x) else pred,
      residual = stats::residuals(x)
    )
    # Replace v values by binned ones (bin labels or bin values, see `v_labels`)
    cuts <- auto_cut(
      data[[v]], breaks = breaks, n_bins = n_bins, cut_type = cut_type, ...
    )
    data[[v]] <- cuts$data[[if (v_labels) "level" else "value"]]
  }
  # Aggregate values (ALE results come back already aggregated)
  if (type != "ale") {
    agg <- grouped_stats(
      data = data,
      x = "value_",
      w = x$w,
      by = c(by, v),
      counts = counts,
      counts_weighted = counts_weighted,
      counts_name = "counts_",
      na.rm = TRUE
    )
  }
  # Finalize results
  type_lev <- c("response", "predicted", "partial dependence", "ale", "residual")
  agg <- transform(agg, label_ = x$label, type_ = factor(type, type_lev))
  out <- list(data = agg, by = by, v = v, type = type)
  add_classes(out, c("light_profile", "light"))
}
#' @describeIn light_profile Profiles for multiflashlight.
#' @export
light_profile.multiflashlight <- function(x, v = NULL, data = NULL,
                                          type = c("partial dependence", "ale",
                                                   "predicted", "response",
                                                   "residual", "shap"),
                                          breaks = NULL, n_bins = 11L,
                                          cut_type = c("equal", "quantile"),
                                          pd_evaluate_at = NULL,
                                          pd_grid = NULL, ...) {
  # Profiles every flashlight with the same settings. Unless a PD grid is used,
  # bin breaks for `v` are shared across flashlights so results line up.
  type <- match.arg(type)
  cut_type <- match.arg(cut_type)
  dot_names <- names(list(...))
  if ("pred" %in% dot_names) {
    stop("'pred' not implemented for multiflashlight")
  }
  uses_grid <- type == "partial dependence" && !is.null(pd_grid)
  if (!uses_grid) {
    stopifnot("Need exactly one 'v'." = length(v) == 1L)
    # PD/ALE profiles evaluated at user-supplied points need no common breaks
    uses_eval_points <- !is.null(pd_evaluate_at) &&
      type %in% c("partial dependence", "ale")
    if (is.null(breaks) && !uses_eval_points) {
      breaks <- common_breaks(
        x = x, v = v, data = data, n_bins = n_bins, cut_type = cut_type
      )
    }
  }
  profiles <- lapply(
    x, light_profile,
    v = v, data = data, type = type, breaks = breaks, n_bins = n_bins,
    cut_type = cut_type, pd_evaluate_at = pd_evaluate_at, pd_grid = pd_grid,
    ...
  )
  light_combine(profiles, new_class = "light_profile_multi")
}
#' Visualize Profiles, e.g. Partial Dependence
#'
#' Minimal visualization of an object of class "light_profile".
#' The object returned is of class "ggplot" and can be further customized.
#'
#' Lines are plotted, optionally with points (see `show_points`).
#' If there is a "by" variable or a multiflashlight, this first dimension
#' is represented by color (or if `swap_dim = TRUE` by facets).
#' If there are two "by" variables or a multiflashlight with one "by" variable,
#' the first "by" variable is visualized as color, while the second one
#' or the multiflashlight is shown via facet (change with `swap_dim`).
#'
#' @importFrom rlang .data
#'
#' @inheritParams plot.light_performance
#' @param x An object of class "light_profile".
#' @param swap_dim If multiflashlight and one "by" variable or
#' single flashlight with two "by" variables, swap the roles of the color variable
#' and the facet variable. If multiflashlight or one "by" variable,
#' use facets instead of colors.
#' @param show_points Should points be added to the line (default is `TRUE`).
#' @param ... Further arguments passed to [ggplot2::geom_point()] or
#' [ggplot2::geom_line()].
#' @returns An object of class "ggplot".
#' @export
#' @seealso [light_profile()], [plot.light_effects()]
plot.light_profile <- function(x, swap_dim = FALSE, facet_scales = "free_x",
                               rotate_x = x$type != "partial dependence",
                               show_points = TRUE, ...) {
  nby <- length(x$by)
  multi <- is.light_profile_multi(x)
  # Grouping dimensions to display: the "by" variables plus, for a
  # multiflashlight, the flashlight label ("label_")
  ndim <- nby + multi
  if (ndim > 2L) {
    stop(
      "Plot method not defined for more than two by variables or ",
      "multiflashlight with more than one by variable."
    )
  }
  if (length(x$v) >= 2L) {
    stop("No plot method defined for two or higher dimensional grids.")
  }
  # Base plot; `.data[[...]]` refers to columns without global-variable notes
  p <- ggplot2::ggplot(
    x$data, ggplot2::aes(y = .data[["value_"]], x = .data[[x$v]])
  )
  if (ndim == 0L) {
    p <- p + ggplot2::geom_line(ggplot2::aes(group = 1), ...)
    if (show_points) {
      p <- p + ggplot2::geom_point(...)
    }
  } else if (ndim == 1L) {
    # One grouping dimension: color by default, facets if `swap_dim = TRUE`
    first_dim <- if (multi) "label_" else x$by[1L]
    if (!swap_dim) {
      p <- p + ggplot2::geom_line(
        ggplot2::aes(color = .data[[first_dim]], group = .data[[first_dim]]), ...
      )
      if (show_points) {
        p <- p + ggplot2::geom_point(ggplot2::aes(color = .data[[first_dim]]), ...)
      }
    } else {
      p <- p +
        ggplot2::facet_wrap(first_dim, scales = facet_scales) +
        ggplot2::geom_line(ggplot2::aes(group = 1), ...)
      if (show_points) {
        p <- p + ggplot2::geom_point(...)
      }
    }
  } else if (ndim == 2L) {
    # Two grouping dimensions: one shown as color, the other as facets
    second_dim <- if (multi) "label_" else x$by[2L]
    wrap_var <- if (swap_dim) x$by[1L] else second_dim
    col_var <- if (swap_dim) second_dim else x$by[1L]
    p <- p + ggplot2::geom_line(
      ggplot2::aes(color = .data[[col_var]], group = .data[[col_var]]), ...
    )
    if (show_points) {
      p <- p + ggplot2::geom_point(ggplot2::aes(color = .data[[col_var]]), ...)
    }
    p <- p + ggplot2::facet_wrap(wrap_var, scales = facet_scales)
  }
  # `rotate_x()` resolves to a function (defined elsewhere): R's function
  # lookup skips the logical argument of the same name
  if (rotate_x) {
    p <- p + rotate_x()
  }
  p + ggplot2::ylab(x$type)
}
#' ALE profile
#'
#' Internal function used by [light_profile()] to calculate ALE profiles.
#'
#' @noRd
#' @param x An object of class "flashlight".
#' @param v The variable to be profiled.
#' @param breaks Cut breaks for a numeric `v`. Only used if no `evaluate_at` is specified.
#' @param n_bins Maximum number of unique values to evaluate for numeric `v`.
#' Only used if no `evaluate_at` is specified.
#' @param cut_type For the default "equal", bins of equal width are created for `v`
#' by [pretty()]. Choose "quantile" to create quantile bins.
#' @param counts Should counts be added?
#' @param counts_weighted If `counts = TRUE`: Should counts be weighted by the
#' case weights? If `TRUE`, the sum of `w` is returned by group.
#' @param pred Optional vector with predictions.
#' @param evaluate_at Vector with values of `v` used to evaluate the ALE profile.
#' If `NULL`, derived from `breaks` or by binning `v`.
#' @param indices A vector of row numbers to consider.
#' @param n_max Maximum number of ICE profiles to calculate within interval (not within data).
#' @param seed Integer random seed, set via [set.seed()] before the ICE profiles are computed.
#' @param two_sided Standard ALE profiles are calculated via left derivatives.
#' Set to `TRUE` if two-sided derivatives should be calculated.
#' Only works for continuous `v`. More specifically: Usually, local effects at
#' value x are calculated using points in \eqn{[x-e, x]}. Set `two_sided = TRUE`
#' to use points in \eqn{[x-e/2, x+e/2]} instead.
#' @param calibrate Should values be calibrated based on average predictions?
#' Default is `TRUE`.
#' @param ... Other arguments passed to this function (currently unused).
#' @returns A tibble containing results.
ale_profile <- function(x, v, breaks = NULL, n_bins = 11L,
                        cut_type = c("equal", "quantile"),
                        counts = TRUE, counts_weighted = FALSE,
                        pred = NULL, evaluate_at = NULL,
                        indices = NULL, n_max = 1000L, seed = NULL,
                        two_sided = FALSE, calibrate = TRUE, ...) {
  cut_type <- match.arg(cut_type)
  data <- x$data
  stopifnot(
    "No data!" = is.data.frame(data) && nrow(data) >= 1L,
    "'v' not specified." = !is.null(v),
    "'v' not in 'data'." = v %in% colnames(data),
    !any(c("value_", "counts_", "id_") %in% c(x$by, v))
  )
  if (!is.null(seed)) {
    set.seed(seed)
  }
  # Restrict to the requested rows; `pred` (if given) is assumed to be aligned
  # with the full `x$data`, so it is subsetted the same way
  if (!is.null(indices)) {
    data <- data[indices, , drop = FALSE]
    if (!is.null(pred)) {
      pred <- pred[indices]
    }
  }
  is_num <- is.numeric(data[[v]])
  # Breaks are only meaningful for numeric `v`; ignore them otherwise
  if (!is_num) {
    breaks <- NULL
  }
  # Evaluation points: midpoints of `breaks`, or bin means from auto-binning
  if (is.null(evaluate_at)) {
    if (!is.null(breaks)) {
      evaluate_at <- midpoints(breaks)
    } else {
      cuts <- auto_cut(data[[v]], breaks = breaks, n_bins = n_bins, cut_type = cut_type)
      breaks <- cuts$breaks
      evaluate_at <- cuts$bin_means
    }
  }
  # Two-sided derivatives (numeric `v` only): evaluate at the upper bin
  # boundaries and map the results back to the original points at the end
  if (two_sided) {
    if (is_num && !is.null(breaks)) {
      evaluate_at_orig <- evaluate_at
      evaluate_at <- breaks[-1L]
    } else {
      two_sided <- FALSE
    }
  }
  stopifnot("No evaluation points." = length(evaluate_at) >= 1L)

  # Average local effect (per `x$by` group) for one pair c(from, to)
  ale_core <- function(from_to) {
    # apply() passes rows named c("from", "to"); drop the names so that the
    # identical() check below compares values only
    from_to <- unname(from_to)
    from <- from_to[1L]
    to <- from_to[2L]
    in_interval <- if (is_num) {
      data[[v]] >= from & data[[v]] <= to
    } else {
      data[[v]] %in% from_to
    }
    # which() drops NA comparisons (missing `v`) that would yield all-NA rows
    dat_i <- data[which(in_interval), , drop = FALSE]
    if (nrow(dat_i) == 0L) {
      return(NULL)
    }
    ice <- light_ice(x, v = v, data = dat_i, evaluate_at = from_to, n_max = n_max)$data
    if (is_num) {
      # Difference quotient; a point paired with itself has zero effect
      ice <- transform(ice, value_ = if (identical(to, from)) 0 else value_/(to - from))
    }
    # Safe reshaping: match "to" and "from" predictions by ICE id
    dat_to <- ice[ice[[v]] %in% to, ]
    dat_from <- ice[ice[[v]] %in% from, ]
    dat_to <- transform(
      dat_to, value_ = value_ - dat_from$value_[match(id_, dat_from$id_)]
    )
    # Aggregation and output
    out <- grouped_stats(
      dat_to,
      x = "value_",
      w = x$w,
      by = x$by,
      counts_weighted = counts_weighted,
      na.rm = TRUE
    )
    out[[v]] <- to
    out
  }

  # Consecutive pairs (previous point, point); the first point is paired with
  # itself. seq_len() avoids duplicating the pair when there is only one point.
  n_eval <- length(evaluate_at)
  eval_pair <- data.frame(
    from = evaluate_at[c(1L, seq_len(n_eval - 1L))], to = evaluate_at
  )
  ale <- dplyr::bind_rows(apply(eval_pair, 1L, ale_core))
  # Remove missing values before accumulation
  bad <- is.na(ale$value_)
  if (any(bad)) {
    ale <- ale[!bad, , drop = FALSE]
  }
  # Accumulate effects (weighted by step width for numeric `v`). Integrate out gaps
  wcumsum <- function(X) {
    transform(
      X,
      value_ = as.numeric(cumsum(value_ * (if (is_num) c(0, diff(X[[v]])) else 1)))
    )
  }
  ale <- Reframe(ale, FUN = wcumsum, .by = x$by)
  if (is.factor(data[[v]])) {
    ale[[v]] <- factor(ale[[v]], levels = levels(data[[v]]))
  }
  # Calibrate effects so their (count-weighted) mean equals the mean prediction
  if (calibrate) {
    # Predict on `data` (possibly subsetted by `indices`) so predictions line
    # up row by row with `data` and its case weights
    preds <- if (is.null(pred)) stats::predict(x, data = data) else pred
    if (is.null(x$by)) {
      pred_mean <- MetricsWeighted::weighted_mean(
        preds, if (!is.null(x$w)) data[[x$w]], na.rm = TRUE
      )
      ale_mean <- MetricsWeighted::weighted_mean(
        ale$value_, w = ale$counts_, na.rm = TRUE
      )
      ale <- tibble::as_tibble(transform(ale, value_ = value_ - ale_mean + pred_mean))
    } else {
      stopifnot(!(c("cal_xx", "shift_xx") %in% colnames(data)))
      dat_pred <- grouped_stats(
        cbind(data, cal_xx = preds),
        x = "cal_xx",
        w = x$w,
        by = x$by,
        counts = FALSE,
        na.rm = TRUE
      )
      dat_ale <- grouped_stats(
        ale, x = "value_", w = "counts_", by = x$by, counts = FALSE, na.rm = TRUE
      )
      dat_shift <- dplyr::left_join(dat_ale, dat_pred, by = x$by)
      dat_shift <- transform(dat_shift, shift_xx = cal_xx - value_)
      ale <- dplyr::left_join(
        ale, dat_shift[, c(x$by, "shift_xx"), drop = FALSE], by = x$by
      )
      ale <- transform(ale, value_ = value_ + shift_xx, shift_xx = NULL)
    }
  }
  # Revert shift for two-sided derivatives
  if (two_sided) {
    ale[[v]] <- evaluate_at_orig[match(ale[[v]], breaks[-1L])]
  }
  ale[, c(x$by, v, if (counts && "counts_" %in% colnames(ale)) "counts_", "value_")]
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.