R/light_profile.R

Defines functions ale_profile plot.light_profile light_profile.multiflashlight light_profile.flashlight light_profile.default light_profile

Documented in light_profile light_profile.default light_profile.flashlight light_profile.multiflashlight plot.light_profile

#' Partial Dependence and other Profiles
#'
#' @description
#' Calculates different types of profiles across covariable values.
#' By default, partial dependence profiles are calculated (see Friedman).
#' Other options are profiles of ALE (accumulated local effects, see Apley),
#' response, predicted values ("M plots" or "marginal plots", see Apley), and residuals.
#' The results are aggregated either by (weighted) means or by (weighted) quartiles.
#'
#' Note that ALE profiles are calibrated by (weighted) average predictions.
#' In contrast to the suggestions in Apley, we calculate ALE profiles of factors
#' in the same order as the factor levels.
#' They are not reordered based on similarity of other variables.
#'
#' @details
#' Numeric covariables `v` with more than `n_bins` disjoint values
#' are binned into `n_bins` bins. Alternatively, `breaks` can be provided
#' to specify the binning. For partial dependence profiles
#' (and partly also ALE profiles), this behaviour can be overwritten either
#' by providing a vector of evaluation points (`pd_evaluate_at`) or an
#' evaluation `pd_grid`. By the latter we mean a data frame with column name(s)
#' with a (multi-)variate evaluation grid.
#'
#' For partial dependence, ALE, and prediction profiles, "model", "predict_function",
#' "linkinv", and "data" are required. For response profiles, "y", "linkinv", and
#' "data" are required. "data" can also be passed on the fly.
#'
#' @param x An object of class "flashlight" or "multiflashlight".
#' @param v The variable name to be profiled.
#' @param data An optional `data.frame`.
#' @param by An optional vector of column names used to additionally group the results.
#' @param type Type of the profile: Either "partial dependence", "ale", "predicted",
#'   "response", or "residual".
#' @param stats Deprecated. Will be removed in version 1.1.0.
#' @param breaks Cut breaks for a numeric `v`. Used to overwrite automatic binning via
#'   `n_bins` and `cut_type`. Ignored if `v` is not numeric.
#' @param n_bins Approximate number of unique values to evaluate for numeric `v`.
#'   Ignored if `v` is not numeric or if `breaks` is specified.
#' @param cut_type Should a numeric `v` be cut into "equal" or "quantile" bins?
#'   Ignored if `v` is not numeric or if `breaks` is specified.
#' @param use_linkinv Should retransformation function be applied? Default is `TRUE`.
#' @param counts Should observation counts be added?
#' @param counts_weighted If `counts = TRUE`: Should counts be weighted by the
#'   case weights? If `TRUE`, the sum of `w` is returned by group.
#' @param v_labels If `FALSE`, return group centers of `v` instead of labels.
#'   Only relevant for types "response", "predicted" or "residual" and if `v`
#'   is being binned. In that case useful, for instance, if different flashlights
#'   use different data sets and bin labels would not match.
#' @param pred Optional vector with predictions (after application of inverse link),
#'   used for types "predicted" and "ale".
#'   Can be used to avoid recalculating predictions if the function
#'   is called repeatedly for different `v` and predictions are computationally
#'   expensive to make. Not implemented for multiflashlight.
#' @param pd_evaluate_at Vector with values of `v` used to evaluate the profile.
#'   Only relevant for type = "partial dependence" and "ale".
#' @param pd_grid A `data.frame` with grid values, e.g., generated by [expand.grid()].
#'   Only used for type = "partial dependence".
#' @param pd_indices A vector of row numbers to consider in calculating
#'   partial dependence profiles and "ale".
#' @param pd_n_max Maximum number of ICE profiles to calculate (will be randomly
#'   picked from `data`) for partial dependence and ALE.
#' @param pd_seed Integer random seed used to select ICE profiles for partial dependence
#'   and ALE.
#' @param pd_center How should ICE curves be centered?
#'   - Default is "no".
#'   - Choose "first", "middle", or "last" to 0-center at specific evaluation points.
#'   - Choose "mean" to center all profiles at the within-group means.
#'   - Choose "0" to mean-center curves at 0. Only relevant for partial dependence.
#' @param ale_two_sided If `TRUE`, `v` is continuous and `breaks`
#'   are passed or being calculated, then two-sided derivatives are calculated
#'   for ALE instead of left derivatives. More specifically: Usually, local effects
#'   at value x are calculated using points in \eqn{[x-e, x]}.
#'   Set `ale_two_sided = TRUE` to use points in \eqn{[x-e/2, x+e/2]}.
#' @param ... Further arguments passed to [formatC()] in forming the
#'   cut breaks of the `v` variable.
#' @returns
#'   An object of class "light_profile" with the following elements:
#'   - `data` A tibble containing results.
#'   - `by` Names of group by variable.
#'   - `v` The variable(s) evaluated.
#'   - `type` Same as input `type`. For information only.
#' @export
#' @references
#' - Friedman J. H. (2001). Greedy function approximation: A gradient boosting machine.
#'   The Annals of Statistics, 29:1189–1232.
#' - Apley D. W. and Zhu J. (2016). Visualizing the effects of predictor variables
#'   in black box supervised learning models. arXiv:1612.08468.
#' @examples
#' fit_lin <- lm(Sepal.Length ~ ., data = iris)
#' fl_lin <- flashlight(model = fit_lin, label = "lin", data = iris, y = "Sepal.Length")
#'
#' # PDP by Species
#' plot(light_profile(fl_lin, v = "Petal.Length", by = "Species"))
#'
#' # Average predicted
#' plot(light_profile(fl_lin, v = "Petal.Length", type = "pred"))
#'
#' # Second model with non-linear Petal.Length effect
#' fit_nonlin <- lm(Sepal.Length ~ . + I(Petal.Length^2), data = iris)
#' fl_nonlin <- flashlight(
#'   model = fit_nonlin, label = "nonlin", data = iris, y = "Sepal.Length"
#' )
#' fls <- multiflashlight(list(fl_lin, fl_nonlin))
#'
#' # PDP by Species
#' plot(light_profile(fls, v = "Petal.Length", by = "Species"))
#' plot(light_profile(fls, v = "Petal.Length", by = "Species"), swap_dim = TRUE)
#'
#' # Average residuals (calibration)
#' plot(light_profile(fls, v = "Petal.Length", type = "residual"))
#' @seealso [light_effects()], [plot.light_profile()]
light_profile <- function(x, ...) {
  # S3 generic: the method is selected from the class of `x`
  # (e.g. "flashlight" or "multiflashlight").
  UseMethod("light_profile", x)
}

#' @describeIn light_profile Default method not implemented yet.
#' @export
light_profile.default <- function(x, ...) {
  # Only flashlight-like objects have methods; everything else lands here
  # and is rejected with an informative error.
  msg <- "No default method available yet."
  stop(msg)
}

#' @describeIn light_profile Profiles for flashlight.
#' @export
light_profile.flashlight <- function(x, v = NULL, data = NULL, by = x$by,
                                     type = c("partial dependence", "ale",
                                              "predicted", "response",
                                              "residual", "shap"),
                                     stats = "mean",
                                     breaks = NULL, n_bins = 11L,
                                     cut_type = c("equal", "quantile"),
                                     use_linkinv = TRUE, counts = TRUE,
                                     counts_weighted = FALSE,
                                     v_labels = TRUE, pred = NULL,
                                     pd_evaluate_at = NULL, pd_grid = NULL,
                                     pd_indices = NULL, pd_n_max = 1000L,
                                     pd_seed = NULL,
                                     pd_center = c("no", "first", "middle",
                                                   "last", "mean", "0"),
                                     ale_two_sided = FALSE, ...) {
  type <- match.arg(type)
  cut_type <- match.arg(cut_type)
  pd_center <- match.arg(pd_center)

  # Deprecated options. `%in%` keeps the check well-defined even if `stats`
  # is NULL or has length != 1 (a plain `==` inside `if` would error).
  if ("quartiles" %in% stats) {
    stop("stats = 'quartiles' is deprecated. The argument 'stats' will be removed in version 1.1.0.")
  }
  if (type == "shap") {
    stop("type = 'shap' is deprecated.")
  }

  if (is.null(data)) {
    data <- x$data
  }

  # Checks (more will be done below or in the called functions).
  # The temporary column names below are created internally and must not
  # clash with user-supplied columns.
  temp_vars <- c("value_", "label_", "type_", "counts_")
  stopifnot(
    "No data!" = is.data.frame(data) && nrow(data) >= 1L,
    "'by' not in 'data'!" = by %in% colnames(data),
    "'v' not in 'data'." = v %in% colnames(data),
    "'v' or 'pd_grid' misses." = !is.null(pd_grid) || !is.null(v),
    !any(temp_vars %in% c(by, v, names(pd_grid)))
  )

  # `pred` must align row-wise with `data`: it replaces the predictions for
  # type "predicted" and is used for calibration (subset by `pd_indices`)
  # for type "ale". A wrong length would otherwise be recycled silently.
  if (!is.null(pred) && type %in% c("predicted", "ale") &&
      length(pred) != nrow(data)) {
    stop("Wrong number of predicted values passed.")
  }

  # Update flashlight with the (possibly new) data, grouping, and link
  x <- flashlight(
    x, data = data, by = by, linkinv = if (use_linkinv) x$linkinv else function(z) z
  )

  # Calculate profiles
  arg_list <- list(
    x = x,
    v = v,
    evaluate_at = pd_evaluate_at,
    breaks = breaks,
    n_bins = n_bins,
    cut_type = cut_type,
    indices = pd_indices,
    n_max = pd_n_max,
    seed = pd_seed
  )
  if (type == "partial dependence") {
    # Partial dependence = average of (possibly centered) ICE profiles
    arg_list <- c(arg_list, list(grid = pd_grid, center = pd_center))
    cp_profiles <- do.call(light_ice, arg_list)
    v <- cp_profiles$v
    data <- cp_profiles$data
  } else if (type == "ale") {
    arg_list <- c(
      arg_list,
      list(
        counts = counts,
        counts_weighted = counts_weighted,
        pred = pred,
        two_sided = ale_two_sided
      )
    )
    agg <- do.call(ale_profile, arg_list)
  } else {
    stopifnot(
      "'v' misses." = !is.null(v),
      "'v' not in data." = v %in% colnames(data)
    )
    if (type %in% c("response", "residual") && is.null(x$y)) {
      stop("You need to specify 'y' in flashlight.")
    }

    # Add predictions/response to data
    data$value_ <- switch(
      type,
      response = response(x),
      predicted = if (is.null(pred)) stats::predict(x) else pred,
      residual = stats::residuals(x)
    )

    # Replace v values by binned ones
    cuts <- auto_cut(
      data[[v]], breaks = breaks, n_bins = n_bins, cut_type = cut_type, ...
    )
    data[[v]] <- cuts$data[[if (v_labels) "level" else "value"]]
  }

  # Aggregate values per group (ALE results are already aggregated)
  if (type != "ale") {
    agg <- grouped_stats(
      data = data,
      x = "value_",
      w = x$w,
      by = c(by, v),
      counts = counts,
      counts_weighted = counts_weighted,
      counts_name = "counts_",
      na.rm = TRUE
    )
  }

  # Finalize results
  type_lev <- c("response", "predicted", "partial dependence", "ale", "residual")
  agg <- transform(agg, label_ = x$label, type_ = factor(type, type_lev))
  out <- list(data = agg, by = by, v = v, type = type)
  add_classes(out, c("light_profile", "light"))
}

#' @describeIn light_profile Profiles for multiflashlight.
#' @export
light_profile.multiflashlight <- function(x, v = NULL, data = NULL,
                                          type = c("partial dependence", "ale",
                                                   "predicted", "response",
                                                   "residual", "shap"),
                                          breaks = NULL, n_bins = 11L,
                                          cut_type = c("equal", "quantile"),
                                          pd_evaluate_at = NULL,
                                          pd_grid = NULL, ...) {
  type <- match.arg(type)
  cut_type <- match.arg(cut_type)

  # Precomputed predictions belong to one model and cannot be shared
  # across the flashlights of a multiflashlight.
  if ("pred" %in% names(list(...))) {
    stop("'pred' not implemented for multiflashlight")
  }

  # Grid-based partial dependence needs neither a single `v` nor common
  # breaks. Otherwise, all flashlights get the same binning of `v` so that
  # their profiles line up when combined.
  grid_pd <- !is.null(pd_grid) && type == "partial dependence"
  if (!grid_pd) {
    stopifnot("Need exactly one 'v'." = length(v) == 1L)
    # Explicit evaluation points replace binning for PD and ALE only
    own_points <- !is.null(pd_evaluate_at) &&
      type %in% c("partial dependence", "ale")
    if (is.null(breaks) && !own_points) {
      breaks <- common_breaks(
        x = x, v = v, data = data, n_bins = n_bins, cut_type = cut_type
      )
    }
  }

  per_flashlight <- lapply(
    X = x,
    FUN = light_profile,
    v = v,
    data = data,
    type = type,
    breaks = breaks,
    n_bins = n_bins,
    cut_type = cut_type,
    pd_evaluate_at = pd_evaluate_at,
    pd_grid = pd_grid,
    ...
  )
  light_combine(per_flashlight, new_class = "light_profile_multi")
}

#' Visualize Profiles, e.g. Partial Dependence
#'
#' Minimal visualization of an object of class "light_profile".
#' The object returned is of class "ggplot" and can be further customized.
#'
#' Profiles are drawn as lines, optionally with points (see `show_points`).
#' If there is a "by" variable or a multiflashlight, this first dimension
#' is represented by color (or if `swap_dim = TRUE` by facets).
#' If there are two "by" variables or a multiflashlight with one "by" variable,
#' the first "by" variable is visualized as color, while the second one
#' or the multiflashlight is shown via facet (change with `swap_dim`).
#'
#' @importFrom rlang .data
#'
#' @inheritParams plot.light_performance
#' @param x An object of class "light_profile".
#' @param swap_dim If multiflashlight and one "by" variable or
#'   single flashlight with two "by" variables, swap the role of color variable
#'   and facet variable. If multiflashlight or one "by" variable,
#'   use facets instead of colors.
#' @param show_points Should points be added to the line (default is `TRUE`).
#' @param ... Further arguments passed to [ggplot2::geom_point()] or
#'   [ggplot2::geom_line()].
#' @returns An object of class "ggplot".
#' @export
#' @seealso [light_profile()], [plot.light_effects()]
plot.light_profile <- function(x, swap_dim = FALSE, facet_scales = "free_x",
                               rotate_x = x$type != "partial dependence",
                               show_points = TRUE, ...) {
  multi <- is.light_profile_multi(x)
  # Grouping dimensions to display: the "by" variables plus, for a
  # multiflashlight, the flashlight label ("label_").
  ndim <- length(x$by) + multi
  if (ndim > 2L) {
    stop(
      "Plot method not defined for more than two by variables or ",
      "multiflashlight with more than one by variable."
    )
  }
  if (length(x$v) >= 2L) {
    stop("No plot method defined for two or higher dimensional grids.")
  }

  p <- ggplot2::ggplot(
    x$data, ggplot2::aes(y = .data[["value_"]], x = .data[[x$v]])
  )
  if (ndim == 0L) {
    # Single curve
    p <- p + ggplot2::geom_line(ggplot2::aes(group = 1), ...)
    if (show_points) {
      p <- p + ggplot2::geom_point(...)
    }
  } else if (ndim == 1L) {
    # One grouping dimension: color by default, facets if swapped
    first_dim <- if (multi) "label_" else x$by[1L]
    if (!swap_dim) {
      p <- p + ggplot2::geom_line(
        ggplot2::aes(color = .data[[first_dim]], group = .data[[first_dim]]), ...
      )
      if (show_points) {
        p <- p + ggplot2::geom_point(ggplot2::aes(color = .data[[first_dim]]), ...)
      }
    } else {
      p <- p +
        ggplot2::facet_wrap(first_dim, scales = facet_scales) +
        ggplot2::geom_line(ggplot2::aes(group = 1), ...)
      if (show_points) {
        p <- p + ggplot2::geom_point(...)
      }
    }
  } else if (ndim == 2L) {
    # Two grouping dimensions: one as color, the other as facets
    second_dim <- if (multi) "label_" else x$by[2L]
    wrap_var <- if (swap_dim) x$by[1L] else second_dim
    col_var <- if (swap_dim) second_dim else x$by[1L]
    p <- p + ggplot2::geom_line(
      ggplot2::aes(color = .data[[col_var]], group = .data[[col_var]]), ...
    )
    if (show_points) {
      p <- p + ggplot2::geom_point(ggplot2::aes(color = .data[[col_var]]), ...)
    }
    p <- p + ggplot2::facet_wrap(wrap_var, scales = facet_scales)
  }
  if (rotate_x) {
    # `rotate_x` is also a helper function: in call position R skips the
    # logical argument and finds the function.
    p <- p + rotate_x()
  }
  p + ggplot2::ylab(x$type)
}

#' ALE profile
#'
#' Internal function used by [light_profile()] to calculate ALE profiles.
#'
#' @noRd
#' @param x An object of class "flashlight".
#' @param v The variable to be profiled.
#' @param breaks Cut breaks for a numeric `v`. Only used if no `evaluate_at` is specified.
#' @param n_bins Maximum number of unique values to evaluate for numeric `v`.
#'   Only used if no `evaluate_at` is specified.
#' @param cut_type For the default "equal", bins of equal width are created for `v`
#'   by [pretty()]. Choose "quantile" to create quantile bins.
#' @param counts Should counts be added?
#' @param counts_weighted If `counts = TRUE`: Should counts be weighted by the
#'   case weights? If `TRUE`, the sum of `w` is returned by group.
#' @param pred Optional vector with predictions, one per row of `x$data`
#'   (subset by `indices` if these are given). Used for calibration.
#' @param evaluate_at Vector with values of `v` used to evaluate the profile.
#' @param indices A vector of row numbers to consider.
#' @param n_max Maximum number of ICE profiles to calculate within interval (not within data).
#' @param seed Integer random seed. If not `NULL`, [set.seed()] is called with it
#'   before any ICE profiles are calculated.
#' @param two_sided Standard ALE profiles are calculated via left derivatives.
#'   Set to `TRUE` if two-sided derivatives should be calculated.
#'   Only works for continuous `v`. More specifically: Usually, local effects at
#'   value x are calculated using points in \eqn{[x-e, x]}. Set `two_sided = TRUE`
#'   to use points in \eqn{[x-e/2, x+e/2]} instead.
#' @param calibrate Should values be calibrated based on average predictions?
#'   Default is `TRUE`.
#' @param ... Other arguments passed to this function (currently unused).
#' @returns A tibble containing results.
ale_profile <- function(x, v, breaks = NULL, n_bins = 11L,
                        cut_type = c("equal", "quantile"),
                        counts = TRUE, counts_weighted = FALSE,
                        pred = NULL, evaluate_at = NULL,
                        indices = NULL, n_max = 1000L, seed = NULL,
                        two_sided = FALSE, calibrate = TRUE, ...) {
  cut_type <- match.arg(cut_type)
  data <- x$data
  stopifnot(
    "No data!" = is.data.frame(data) && nrow(data) >= 1L,
    "'v' not specified." = !is.null(v),
    "'v' not in 'data'." = v %in% colnames(data),
    !any(c("value_", "counts_", "id_") %in% c(x$by, v))
  )
  if (!is.null(seed)) {
    set.seed(seed)
  }
  # Restrict to the requested rows; `pred` refers to rows of `x$data`
  # and is therefore subset the same way.
  if (!is.null(indices)) {
    data <- data[indices, , drop = FALSE]
    if (!is.null(pred)) {
      pred <- pred[indices]
    }
  }
  is_num <- is.numeric(data[[v]])

  # Evaluation points (including shift for two-sided derivatives)
  if (is.null(evaluate_at)) {
    if (!is.null(breaks)) {
      evaluate_at <- midpoints(breaks)
    } else {
      cuts <- auto_cut(data[[v]], breaks = breaks, n_bins = n_bins, cut_type = cut_type)
      breaks <- cuts$breaks
      evaluate_at <- cuts$bin_means
    }
  }
  if (two_sided) {
    if (!is.null(breaks)) {
      evaluate_at_orig <- evaluate_at
      evaluate_at <- breaks[-1L]
    } else {
      two_sided <- FALSE
    }
  }
  stopifnot("No evaluation points." = length(evaluate_at) >= 1L)

  # Helper function used to calculate local effects for a pair of x values
  ale_core <- function(from_to) {
    # apply() hands over each row named c("from", "to"). Drop the names so
    # that identical(to, from) compares values only; otherwise the first pair
    # (from == to) is never recognized and gets divided by zero.
    from <- unname(from_to[1L])
    to <- unname(from_to[2L])
    if (is_num) {
      .s <- data[[v]] >= from & data[[v]] <= to
    } else {
      .s <- data[[v]] %in% from_to
    }
    dat_i <- data[.s, , drop = FALSE]
    if (nrow(dat_i) == 0L) {
      return(NULL)
    }
    ice <- light_ice(x, v = v, data = dat_i, evaluate_at = from_to, n_max = n_max)$data
    if (is_num) {
      # Scale to a derivative; it is multiplied by the step width again when
      # accumulating, which lets gaps between evaluation points be bridged.
      ice <- transform(ice, value_ = if (identical(to, from)) 0 else value_/(to - from))
    }
    # Safe reshaping: match "to" and "from" values per ICE profile id
    dat_to <- ice[ice[[v]] %in% to, ]
    dat_from <- ice[ice[[v]] %in% from, ]
    dat_to <- transform(
      dat_to, value_ = value_ - dat_from$value_[match(id_, dat_from$id_)]
    )

    # Aggregation and output
    out <- grouped_stats(
      dat_to,
      x = "value_",
      w = x$w,
      by = x$by,
      counts_weighted = counts_weighted,
      counts_name = "counts_",
      na.rm = TRUE
    )
    out[[v]] <- to
    out
  }

  # Consecutive pairs of evaluation points; the first pair is (first, first).
  # seq_len() avoids 1:0 == c(1, 0), which would duplicate the pair when
  # there is only one evaluation point.
  n_eval <- length(evaluate_at)
  eval_pair <- data.frame(
    from = evaluate_at[c(1L, seq_len(n_eval - 1L))], to = evaluate_at
  )
  ale <- dplyr::bind_rows(apply(eval_pair, 1L, ale_core))

  # Remove missing values before accumulation
  bad <- is.na(ale$value_)
  if (any(bad)) {
    ale <- ale[!bad, , drop = FALSE]
  }

  # Accumulate effects. Integrate out gaps
  wcumsum <- function(X) {
    transform(
      X,
      value_ = as.numeric(cumsum(value_ * (if (is_num) c(0, diff(X[[v]])) else 1)))
    )
  }
  ale <- Reframe(ale, FUN = wcumsum, .by = x$by)

  if (is.factor(data[[v]])) {
    ale[[v]] <- factor(ale[[v]], levels = levels(data[[v]]))
  }

  # Calibrate effects to the (weighted) average prediction
  if (calibrate) {
    if (is.null(pred)) {
      # predict() returns one value per row of the full `x$data`;
      # align it with the rows kept in `data`.
      preds <- stats::predict(x)
      if (!is.null(indices)) {
        preds <- preds[indices]
      }
    } else {
      preds <- pred # already subset by `indices` above
    }
    if (is.null(x$by)) {
      pred_mean <- MetricsWeighted::weighted_mean(
        preds, if (!is.null(x$w)) data[[x$w]], na.rm = TRUE
      )
      ale_mean <- MetricsWeighted::weighted_mean(
        ale$value_, w = ale$counts_, na.rm = TRUE
      )
      ale <- tibble::as_tibble(transform(ale, value_ = value_ - ale_mean + pred_mean))
    } else {
      stopifnot(!(c("cal_xx", "shift_xx") %in% colnames(data)))
      dat_pred <- grouped_stats(
        cbind(data, cal_xx = preds),
        x = "cal_xx",
        w = x$w,
        by = x$by,
        counts = FALSE,
        na.rm = TRUE
      )
      dat_ale <- grouped_stats(
        ale, x = "value_", w = "counts_", by = x$by, counts = FALSE, na.rm = TRUE
      )
      dat_shift <- dplyr::left_join(dat_ale, dat_pred, by = x$by)
      dat_shift <- transform(dat_shift, shift_xx = cal_xx - value_)
      ale <- dplyr::left_join(
        ale, dat_shift[, c(x$by, "shift_xx"), drop = FALSE], by = x$by
      )
      ale <- transform(ale, value_ = value_ + shift_xx, shift_xx = NULL)
    }
  }
  # Revert shift for two-sided derivatives
  if (two_sided) {
    ale[[v]] <- evaluate_at_orig[match(ale[[v]], breaks[-1L])]
  }

  ale[, c(x$by, v, if (counts && "counts_" %in% colnames(ale)) "counts_", "value_")]
}
mayer79/flashlight documentation built on April 12, 2025, 3:49 p.m.