# Nothing
#' Partial Dependence and other Profiles
#'
#' @description
#' Calculates different types of profiles across covariable values.
#' By default, partial dependence profiles are calculated (see Friedman).
#' Other options are profiles of ALE (accumulated local effects, see Apley),
#' response, predicted values ("M plots" or "marginal plots", see Apley),
#' residuals, and shap. The results are aggregated either by (weighted) means or by
#' (weighted) quartiles.
#'
#' Note that ALE profiles are calibrated by (weighted) average predictions.
#' In contrast to the suggestions in Apley, we calculate ALE profiles of factors
#' in the same order as the factor levels.
#' They are not being reordered based on similarity of other variables.
#'
#' @details
#' Numeric covariables `v` with more than `n_bins` distinct values
#' are binned into `n_bins` bins. Alternatively, `breaks` can be provided
#' to specify the binning. For partial dependence profiles
#' (and partly also ALE profiles), this behaviour can be overwritten either
#' by providing a vector of evaluation points (`pd_evaluate_at`) or an
#' evaluation `pd_grid`. By the latter we mean a data frame with column name(s)
#' with a (multi-)variate evaluation grid.
#'
#' For partial dependence, ALE, and prediction profiles, "model", "predict_function",
#' "linkinv" and "data" are required. For response profiles it is "y", "linkinv" and
#' "data", and for shap profiles it is just "shap". "data" can be passed on the fly.
#'
#' @param x An object of class "flashlight" or "multiflashlight".
#' @param v The variable name to be profiled.
#' @param data An optional `data.frame`. Not used for `type = "shap"`.
#' @param by An optional vector of column names used to additionally group the results.
#' @param type Type of the profile: Either "partial dependence", "ale", "predicted",
#' "response", "residual", or "shap".
#' @param stats Statistic to calculate: "mean" or "quartiles". For ALE profiles,
#' only "mean" makes sense.
#' @param breaks Cut breaks for a numeric `v`. Used to overwrite automatic binning via
#' `n_bins` and `cut_type`. Ignored if `v` is not numeric.
#' @param n_bins Approximate number of unique values to evaluate for numeric `v`.
#' Ignored if `v` is not numeric or if `breaks` is specified.
#' @param cut_type Should a numeric `v` be cut into "equal" or "quantile" bins?
#' Ignored if `v` is not numeric or if `breaks` is specified.
#' @param use_linkinv Should retransformation function be applied? Default is `TRUE`.
#' Not used for type "shap".
#' @param counts Should observation counts be added?
#' @param counts_weighted If `counts = TRUE`: Should counts be weighted by the
#' case weights? If `TRUE`, the sum of `w` is returned by group.
#' @param v_labels If `FALSE`, return group centers of `v` instead of labels.
#' Only relevant for types "response", "predicted" or "residual" and if `v`
#' is being binned. In that case useful, for instance, if different flashlights
#' use different data sets and bin labels would not match.
#' @param pred Optional vector with predictions (after application of inverse link).
#' Can be used to avoid recalculation of predictions over and over if the function
#' is to be called repeatedly for different `v` and predictions are computationally
#' expensive to make. Not implemented for multiflashlight.
#' @param pd_evaluate_at Vector with values of `v` used to evaluate the profile.
#' Only relevant for type = "partial dependence" and "ale".
#' @param pd_grid A `data.frame` with grid values, e.g., generated by [expand.grid()].
#' Only used for type = "partial dependence".
#' @param pd_indices A vector of row numbers to consider in calculating
#' partial dependence profiles and "ale".
#' @param pd_n_max Maximum number of ICE profiles to calculate (will be randomly
#' picked from `data`) for partial dependence and ALE.
#' @param pd_seed Integer random seed used to select ICE profiles for partial dependence
#' and ALE.
#' @param pd_center How should ICE curves be centered?
#' - Default is "no".
#' - Choose "first", "middle", or "last" to 0-center at specific evaluation points.
#' - Choose "mean" to center all profiles at the within-group means.
#' - Choose "0" to mean-center curves at 0. Only relevant for partial dependence.
#' @param ale_two_sided If `TRUE`, `v` is continuous and `breaks`
#' are passed or being calculated, then two-sided derivatives are calculated
#' for ALE instead of left derivatives. More specifically: Usually, local effects
#' at value x are calculated using points in \eqn{[x-e, x]}.
#' Set `ale_two_sided = TRUE` to use points in \eqn{[x-e/2, x+e/2]}.
#' @param ... Further arguments passed to [cut3()] in forming the
#' cut breaks of the `v` variable.
#' @returns
#' An object of class "light_profile" with the following elements:
#' - `data` A tibble containing results. Can be used to build fully customized
#' visualizations. Column names can be controlled by `options(flashlight.column_name)`.
#' - `by` Names of group by variable.
#' - `v` The variable(s) evaluated.
#' - `type` Same as input `type`. For information only.
#' - `stats` Same as input `stats`.
#' @export
#' @references
#' - Friedman J. H. (2001). Greedy function approximation: A gradient boosting machine. The Annals of Statistics, 29:1189–1232.
#' - Apley D. W. (2016). Visualizing the effects of predictor variables in black box supervised learning models.
#' @examples
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' fl <- flashlight(model = fit, label = "iris", data = iris, y = "Sepal.Length")
#' light_profile(fl, v = "Species")
#' light_profile(fl, v = "Petal.Width", type = "residual")
#' @seealso [light_effects()], [plot.light_profile()]
light_profile <- function(x, ...) {
# S3 generic: dispatch to the `light_profile` method for the class of `x`.
UseMethod("light_profile")
}
#' @describeIn light_profile Default method not implemented yet.
#' @export
light_profile.default <- function(x, ...) {
  # Reached only when no class-specific method exists for `x`: always an error.
  stop("No default method available yet.")
}
#' @describeIn light_profile Profiles for a flashlight object.
#' @export
light_profile.flashlight <- function(x, v = NULL, data = NULL, by = x$by,
                                     type = c("partial dependence", "ale",
                                              "predicted", "response",
                                              "residual", "shap"),
                                     stats = c("mean", "quartiles"),
                                     breaks = NULL, n_bins = 11L,
                                     cut_type = c("equal", "quantile"),
                                     use_linkinv = TRUE, counts = TRUE,
                                     counts_weighted = FALSE,
                                     v_labels = TRUE, pred = NULL,
                                     pd_evaluate_at = NULL, pd_grid = NULL,
                                     pd_indices = NULL, pd_n_max = 1000L,
                                     pd_seed = NULL,
                                     pd_center = c("no", "first", "middle",
                                                   "last", "mean", "0"),
                                     ale_two_sided = FALSE, ...) {
  # Resolve the enumerated arguments to a single valid choice each
  type <- match.arg(type)
  stats <- match.arg(stats)
  cut_type <- match.arg(cut_type)
  pd_center <- match.arg(pd_center)

  is_shap <- type == "shap"
  want_quartiles <- stats == "quartiles"

  # Deprecation notices
  if (want_quartiles) {
    message("stats = 'quartiles' is deprecated and will be removed in flashlight 1.0.0.")
  }
  if (is_shap) {
    message("type = 'shap' is deprecated and will be removed in flashlight 1.0.0.")
  }

  # Column names are no longer arguments; check whether they were passed in `...`
  warning_on_names(
    c("value_name", "label_name", "q1_name", "q3_name", "type_name", "counts_name"),
    ...
  )

  # Output column names are taken from the package options
  opt <- function(nm) getOption(paste0("flashlight.", nm))
  value_name <- opt("value_name")
  label_name <- opt("label_name")
  q1_name <- opt("q1_name")
  q3_name <- opt("q3_name")
  type_name <- opt("type_name")
  counts_name <- opt("counts_name")

  # Pick the data: SHAP rows belonging to `v`, or the flashlight's own data
  # when none was passed
  if (is_shap) {
    if (!is.shap(x$shap)) {
      stop("No shap values calculated. Run 'add_shap' for the flashlight first.")
    }
    stopifnot(v %in% colnames(x$shap$data))
    variable_name <- opt("variable_name")
    rows_of_v <- x$shap$data[[variable_name]] == v
    data <- x$shap$data[rows_of_v, ]
  } else if (is.null(data)) {
    data <- x$data
  }

  # Input checks (further checks happen below or in the called functions)
  stopifnot(
    "No data!" = is.data.frame(data) && nrow(data) >= 1L,
    "'by' not in 'data'!" = by %in% colnames(data),
    "'v' not in 'data'." = v %in% colnames(data),
    "'v' or 'pd_grid' misses." = !is.null(pd_grid) || !is.null(v)
  )
  check_unique(
    c(by, union(v, names(pd_grid))),
    opt_names = c(
      if (counts) counts_name,
      if (want_quartiles) c(q1_name, q3_name),
      value_name,
      label_name,
      type_name
    )
  )
  if (type == "predicted" && !is.null(pred) && length(pred) != nrow(data)) {
    stop("Wrong number of predicted values passed.")
  }
  if (type == "ale" && want_quartiles) {
    stop("The cumsum step of ALE does not make sense for quartiles, so this option is not available.")
  }

  # Refresh the flashlight with the data, grouping and link function in use
  if (!is_shap) {
    x <- flashlight(
      x, data = data, by = by, linkinv = if (use_linkinv) x$linkinv else function(z) z
    )
  }

  # Arguments shared by the partial dependence and ALE calculations.
  # Lists are combined with c() below so that NULL entries stay in place.
  shared_args <- list(
    x = x,
    v = v,
    evaluate_at = pd_evaluate_at,
    breaks = breaks,
    n_bins = n_bins,
    cut_type = cut_type,
    indices = pd_indices,
    n_max = pd_n_max,
    seed = pd_seed
  )

  if (type == "ale") {
    # ALE profiles come back already aggregated
    ale_extra <- list(
      counts = counts,
      counts_weighted = counts_weighted,
      pred = pred,
      two_sided = ale_two_sided
    )
    agg <- do.call(ale_profile, c(shared_args, ale_extra))
  } else {
    if (type == "partial dependence") {
      # ICE profiles are calculated with a fixed id column name,
      # which is safer than the default
      pd_extra <- list(grid = pd_grid, center = pd_center)
      ice <- withr::with_options(
        list(flashlight.id_name = "id_xxx"),
        do.call(light_ice, c(shared_args, pd_extra))
      )
      v <- ice$v
      data <- ice$data
    } else {
      stopifnot(
        "'v' misses." = !is.null(v),
        "'v' not in data." = v %in% colnames(data)
      )
      if (type %in% c("response", "residual") && is.null(x$y)) {
        stop("You need to specify 'y' in flashlight.")
      }

      # Value to be profiled: response, prediction, residual or SHAP value
      values <- if (type == "response") {
        response(x)
      } else if (type == "predicted") {
        if (is.null(pred)) stats::predict(x) else pred
      } else if (type == "residual") {
        stats::residuals(x)
      } else {
        data[["shap_"]]
      }
      data[[value_name]] <- values

      # Replace the values of v by their binned counterparts
      binned <- auto_cut(
        data[[v]], breaks = breaks, n_bins = n_bins, cut_type = cut_type, ...
      )
      bin_column <- if (v_labels) "level" else "value"
      data[[v]] <- binned$data[[bin_column]]
    }

    # Aggregate the values within each combination of `by` and `v`
    agg <- grouped_stats(
      data = data,
      x = value_name,
      w = x$w,
      by = c(by, v),
      stats = stats,
      counts = counts,
      counts_weighted = counts_weighted,
      counts_name = counts_name,
      q1_name = q1_name,
      q3_name = q3_name,
      na.rm = TRUE
    )
  }

  # Attach the flashlight label and the profile type;
  # the type is coded as a factor (relevant for light_effects)
  agg[[label_name]] <- x$label
  type_levels <- c(
    "response", "predicted", "partial dependence", "ale", "residual", "shap"
  )
  agg[[type_name]] <- factor(type, type_levels)

  add_classes(
    list(data = agg, by = by, v = v, type = type, stats = stats),
    c("light_profile", "light")
  )
}
#' @describeIn light_profile Profiles for a multiflashlight object.
#' @export
light_profile.multiflashlight <- function(x, v = NULL, data = NULL,
                                          type = c("partial dependence", "ale",
                                                   "predicted", "response",
                                                   "residual", "shap"),
                                          breaks = NULL, n_bins = 11L,
                                          cut_type = c("equal", "quantile"),
                                          pd_evaluate_at = NULL,
                                          pd_grid = NULL, ...) {
  type <- match.arg(type)
  cut_type <- match.arg(cut_type)

  # Passing 'pred' through `...` is rejected for multiflashlights
  if ("pred" %in% names(list(...))) {
    stop("'pred' not implemented for multiflashlight")
  }

  # Unless a partial dependence grid is in use, a single numeric `v` gets the
  # same breaks in every flashlight.
  grid_in_use <- type == "partial dependence" && !is.null(pd_grid)
  if (!grid_in_use) {
    stopifnot("Need exactly one 'v'." = length(v) == 1L)
    # Evaluation points only matter for partial dependence and ALE profiles
    points_in_use <- type %in% c("partial dependence", "ale") &&
      !is.null(pd_evaluate_at)
    if (is.null(breaks) && !points_in_use) {
      breaks <- common_breaks(
        x = x, v = v, data = data, n_bins = n_bins, cut_type = cut_type
      )
    }
  }

  # One profile per flashlight, combined into a single object at the end
  profiles <- lapply(
    x,
    light_profile,
    v = v,
    data = data,
    type = type,
    breaks = breaks,
    n_bins = n_bins,
    cut_type = cut_type,
    pd_evaluate_at = pd_evaluate_at,
    pd_grid = pd_grid,
    ...
  )
  light_combine(profiles, new_class = "light_profile_multi")
}
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.