R/light_profile2d.R

Defines functions light_profile2d.multiflashlight light_profile2d.flashlight light_profile2d.default light_profile2d

Documented in light_profile2d light_profile2d.default light_profile2d.flashlight light_profile2d.multiflashlight

#' 2D Partial Dependence and other 2D Profiles
#'
#' Calculates different types of 2D profiles across two variables.
#' By default, partial dependence profiles are calculated (Friedman, 2001).
#' Other options are response, predicted values, residuals, and SHAP values.
#' The results are aggregated by (weighted) means.
#'
#' Several binning options are available; see the arguments below.
#' For high-resolution partial dependence plots, it might be necessary to specify
#' `breaks`, `pd_evaluate_at`, or `pd_grid` to avoid empty regions
#' in the plot. A high value of `n_bins` might not have the desired effect, as it is
#' internally capped at the number of distinct values of a variable.
#'
#' For partial dependence and prediction profiles, "model", "predict_function",
#' "linkinv", and "data" are required. Response profiles require "y", "linkinv",
#' and "data"; SHAP profiles require only "shap". "data" can also be passed on the fly.
#'
#' @param x An object of class "flashlight" or "multiflashlight".
#' @param v A vector of exactly two variable names to be profiled.
#' @param data An optional `data.frame`. Not used for `type = "shap"`.
#' @param by An optional vector of column names used to additionally group the results.
#' @param type Type of the profile: Either "partial dependence", "predicted",
#'   "response", "residual", or "shap".
#' @param breaks Named list of cut breaks specifying how to bin one or more numeric
#'   variables. Overrides automatic binning via `n_bins` and `cut_type`.
#'   Ignored for non-numeric `v`.
#' @param n_bins Approximate number of unique values to evaluate for numeric `v`.
#'   Can be an unnamed vector of length 2 to set a separate value for each variable in `v`.
#' @param cut_type Should numeric `v` be cut into "equal" or "quantile" bins?
#'   Can be an unnamed vector of length 2 to set a separate value for each variable in `v`.
#' @param use_linkinv Should the retransformation function `linkinv` be applied?
#'   Default is `TRUE`. Not used for type "shap".
#' @param counts Should observation counts be added?
#' @param counts_weighted If `counts` is `TRUE`, should counts be weighted by the
#'   case weights? If `TRUE`, the sum of `w` is returned by group.
#' @param pd_evaluate_at A named list of evaluation points for one or more variables.
#'   Only relevant for type = "partial dependence".
#' @param pd_grid An evaluation `data.frame` with exactly two columns,
#'   e.g., generated by [expand.grid()]. Only used for type = "partial dependence".
#'   Offers maximal flexibility.
#' @param pd_indices A vector of row numbers to consider in calculating partial
#'   dependence profiles. Only used for type = "partial dependence".
#' @param pd_n_max Maximum number of ICE profiles to calculate
#'   (will be randomly picked from `data`). Only used for type = "partial dependence".
#' @param pd_seed Integer random seed used to select ICE profiles.
#'   Only used for type = "partial dependence".
#' @param ... Further arguments passed to [cut3()] in forming
#'   the cut breaks of the `v` variables. Not relevant for partial dependence profiles.
#' @returns
#'   An object of class "light_profile2d" with the following elements:
#'   - `data` A tibble containing results. Can be used to build fully customized
#'     visualizations. Column names can be controlled by
#'     `options(flashlight.column_name)`.
#'   - `by` Names of group by variables.
#'   - `v` The two variable names evaluated.
#'   - `type` Same as input `type`. For information only.
#' @export
#' @references
#'   Friedman J. H. (2001). Greedy function approximation: A gradient boosting machine.
#'     The Annals of Statistics, 29:1189–1232.
#' @examples
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' fl <- flashlight(model = fit, label = "iris", data = iris, y = "Sepal.Length")
#' light_profile2d(fl, v = c("Petal.Length", "Species"))
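#'
#' # Further sketches of the arguments documented above. The values (bins,
#' # breaks, grid points) and the objects grid, pd, fit2, fl2 and fls are
#' # illustrative only, not recommended settings.
#'
#' # Predicted values with per-variable binning: 5 bins for Petal.Length,
#' # 3 for Petal.Width, both cut at quantiles
#' light_profile2d(
#'   fl, v = c("Petal.Length", "Petal.Width"), type = "predicted",
#'   n_bins = c(5, 3), cut_type = "quantile"
#' )
#'
#' # Response profile with user-specified breaks for Petal.Length
#' # (the breaks cover the observed range 1.0 to 6.9)
#' light_profile2d(
#'   fl, v = c("Petal.Length", "Species"), type = "response",
#'   breaks = list(Petal.Length = c(0, 2, 4, 7))
#' )
#'
#' # Partial dependence on an explicit evaluation grid; 'v' is taken from
#' # the column names of the grid
#' grid <- expand.grid(Petal.Length = c(2, 4, 6), Species = levels(iris$Species))
#' pd <- light_profile2d(fl, pd_grid = grid)
#' plot(pd)
#'
#' # Comparing two models via a multiflashlight (common breaks are used)
#' fit2 <- lm(Sepal.Length ~ Petal.Length + Petal.Width, data = iris)
#' fl2 <- flashlight(model = fit2, label = "small", data = iris, y = "Sepal.Length")
#' fls <- multiflashlight(list(fl, fl2))
#' light_profile2d(fls, v = c("Petal.Length", "Petal.Width"))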
#' @seealso [light_profile()], [plot.light_profile2d()]
light_profile2d <- function(x, ...) {
  UseMethod("light_profile2d")
}

#' @describeIn light_profile2d Default method not implemented yet.
#' @export
light_profile2d.default <- function(x, ...) {
  stop("No default method available yet.")
}

#' @describeIn light_profile2d 2D profiles for flashlight.
#' @export
light_profile2d.flashlight <- function(x, v = NULL,
                                       data = NULL, by = x$by,
                                       type = c("partial dependence",
                                                "predicted", "response",
                                                "residual", "shap"),
                                       breaks = NULL, n_bins = 11L,
                                       cut_type = "equal",
                                       use_linkinv = TRUE, counts = TRUE,
                                       counts_weighted = FALSE,
                                       pd_evaluate_at = NULL, pd_grid = NULL,
                                       pd_indices = NULL, pd_n_max = 1000L,
                                       pd_seed = NULL, ...) {
  type <- match.arg(type)

  if (type == "shap") {
    message("type = 'shap' is deprecated and will be removed in flashlight 1.0.0.")
  }

  value_name <- getOption("flashlight.value_name")
  label_name <- getOption("flashlight.label_name")
  type_name <- getOption("flashlight.type_name")
  counts_name <- getOption("flashlight.counts_name")

  # Check if exactly two variables are specified
  if (type == "partial dependence" && !is.null(pd_grid)) {
    stopifnot(
      "pd_grid must be a data.frame" = is.data.frame(pd_grid),
      "pd_grid must have exactly two columns" = ncol(pd_grid) == 2L
    )
    v <- colnames(pd_grid)
  } else {
    stopifnot("Need exactly two 'v'." = length(v) == 2L)
  }

  # Turn binning arguments into a list of lists
  strategy <- fix_strategy(
    v,
    n_bins = n_bins,
    cut_type = cut_type,
    breaks = breaks,
    pd_evaluate_at = pd_evaluate_at
  )

  # If SHAP, extract data
  if (type == "shap") {
    if (!is.shap(x$shap)) {
      stop("No shap values calculated. Run 'add_shap' for the flashlight first.")
    }
    stopifnot(v %in% colnames(x$shap$data))
    variable_name <- getOption("flashlight.variable_name")
    data <- x$shap$data[x$shap$data[[variable_name]] %in% v, ]
  } else if (is.null(data)) {
    data <- x$data
  }

  # Checks on data and column names
  stopifnot(
    "No data!" = is.data.frame(data) && nrow(data) >= 1L,
    "'by' not in 'data'!" = by %in% colnames(data),
    "'v' not in 'data'." = v %in% colnames(data)
  )
  check_unique(c(by, v), c(value_name, label_name, type_name))

  # Update flashlight
  if (type != "shap") {
    x <- flashlight(
      x, data = data, by = by, linkinv = if (use_linkinv) x$linkinv else function(z) z)
  }

  # Calculate profiles
  if (type == "partial dependence") {
    # Construct pd_grid from strategy
    if (is.null(pd_grid)) {
      for (vv in v) {
        st <- strategy[[vv]]
        if (is.null(st$pd_evaluate_at)) {
          if (!is.null(st$breaks) && is.numeric(st$breaks)) {
            strategy[[vv]]$pd_evaluate_at <- midpoints(st$breaks)
          } else {
            strategy[[vv]]$pd_evaluate_at <- auto_cut(
              data[[vv]], n_bins = st$n_bins, cut_type = st$cut_type, ...
            )$bin_means
          }
        }
      }
      pd_grid <- expand.grid(lapply(strategy, `[[`, "pd_evaluate_at"))
    }

    # Calculate 2D ICE profiles
    data <- withr::with_options(
      list(flashlight.id_name = "id_xxx"),
      light_ice(
        x = x,
        grid = pd_grid,
        indices = pd_indices,
        n_max = pd_n_max,
        seed = pd_seed
      )$data
    )
  } else {
    if (type %in% c("response", "residual") && is.null(x$y)) {
      stop("You need to specify 'y' in flashlight.")
    }

    # Add predictions/response to data
    data[[value_name]] <- switch(
      type,
      response = response(x),
      predicted = stats::predict(x),
      residual = stats::residuals(x),
      shap = data[["shap_"]]
    )

    # Replace v values by binned values
    for (vv in v) {
      st <- strategy[[vv]]
      data[[vv]] <- auto_cut(
        data[[vv]],
        n_bins = st$n_bins,
        cut_type = st$cut_type,
        breaks = st$breaks,
        ...
      )$data$level
    }
  }

  # Aggregate values by 'by' and the two 'v' variables
  agg <- grouped_stats(
    data = data,
    x = value_name,
    w = x$w,
    by = c(by, v),
    na.rm = TRUE,
    counts = counts,
    counts_weighted = counts_weighted,
    counts_name = counts_name
  )

  # Finalize results
  agg[[label_name]] <- x$label
  agg[[type_name]] <- type

  # Collect results
  out <- list(data = agg, by = by, v = v, type = type)
  add_classes(out, c("light_profile2d", "light"))
}

#' @describeIn light_profile2d 2D profiles for multiflashlight.
#' @export
light_profile2d.multiflashlight <- function(x, v = NULL, data = NULL,
                                            type = c("partial dependence",
                                                     "predicted", "response",
                                                     "residual", "shap"),
                                            breaks = NULL, n_bins = 11L,
                                            cut_type = "equal",
                                            pd_evaluate_at = NULL,
                                            pd_grid = NULL, ...) {
  type <- match.arg(type)
  is_pd <- type == "partial dependence"

  if (is.null(pd_grid) || !is_pd) {
    stopifnot("Need exactly two 'v'." = length(v) == 2L)

    # Turn binning arguments into a list of lists
    strategy <- fix_strategy(
      v,
      n_bins = n_bins,
      cut_type = cut_type,
      breaks = breaks,
      pd_evaluate_at = pd_evaluate_at
    )

    # Calculate breaks shared across all flashlights, separately for each variable
    for (vv in v) {
      st <- strategy[[vv]]
      if (is.null(st$breaks) && (is.null(st$pd_evaluate_at) || !is_pd)) {
        strategy[[vv]]$breaks <- common_breaks(
          x, v = vv, data = data, n_bins = st$n_bins, cut_type = st$cut_type
        )
      }
    }
    breaks <- lapply(strategy, `[[`, "breaks")
  }

  # Call light_profile2d for all flashlights
  all_profiles <- lapply(
    x,
    light_profile2d,
    v = v,
    data = data,
    type = type,
    breaks = breaks,
    n_bins = n_bins,
    cut_type = cut_type,
    pd_evaluate_at = pd_evaluate_at,
    pd_grid = pd_grid,
    ...
  )
  light_combine(all_profiles, new_class = "light_profile2d_multi")
}


flashlight documentation built on May 31, 2023, 6:19 p.m.