# R/light_interaction.R

# Defines functions .zap_small light_interaction.multiflashlight light_interaction.flashlight light_interaction.default light_interaction

# Documented in light_interaction light_interaction.default light_interaction.flashlight light_interaction.multiflashlight

#' Interaction Strength
#'
#' This function provides Friedman's H statistic for overall interaction strength per
#' covariable as well as its version for pairwise interactions, see the reference below.
#'
#' As a fast alternative to assess overall interaction strength, with `type = "ice"`,
#' the function offers a method based on centered ICE curves:
#' The corresponding H* statistic measures how much of the variability of a c-ICE curve
#' is unexplained by the main effect. As for Friedman's H statistic, it can be useful
#' to consider unnormalized or squared values (see Details below).
#'
#' Friedman's H statistic relates the interaction strength of a variable (pair)
#' to the total effect strength of that variable (pair) based on partial dependence
#' curves. Due to this normalization step, even variables with low importance can
#' have high values for H. The function [light_interaction()] offers the option
#' to skip normalization in order to have a more direct comparison of the interaction
#' effects across variable (pairs). The values of such unnormalized H statistics are
#' on the scale of the response variable. Use `take_sqrt = FALSE` to return
#' squared values of H. Note that in general, for each variable (pair), predictions
#' are done on a data set with `grid_size * n_max` rows, so be cautious with
#' increasing the defaults too much. Still, even with larger `grid_size`
#' and `n_max`, there might be considerable variation across different runs,
#' thus, setting a seed is recommended.
#'
#' The minimum required elements in the (multi-) flashlight are a "predict_function",
#' "model", and "data".
#'
#' @param x An object of class "flashlight" or "multiflashlight".
#' @param data An optional `data.frame`.
#' @param by An optional vector of column names used to additionally group the results.
#' @param v Vector of variable names to be assessed.
#' @param pairwise Should overall interaction strength per variable be shown or
#'   pairwise interactions? Defaults to `FALSE`.
#' @param type Are measures based on Friedman's H statistic ("H") or on "ice" curves?
#'   Option "ice" is available only if `pairwise = FALSE`.
#' @param normalize Should the variances explained be normalized?
#'   Default is `TRUE` in order to reproduce Friedman's H statistic.
#' @param take_sqrt In order to reproduce Friedman's H statistic,
#'   resulting values are root transformed. Set to `FALSE` if squared values
#'   should be returned.
#' @param grid_size Grid size used to form the outer product. Will be randomly
#'   picked from data (after limiting to `n_max`).
#' @param n_max Maximum number of data rows to consider. Will be randomly picked
#'   from `data` if necessary.
#' @param seed An integer random seed used for subsampling.
#' @param use_linkinv Should retransformation function be applied? Default is `FALSE`.
#' @param ... Further arguments passed to or from other methods.
#' @returns
#'   An object of class "light_importance" with the following elements:
#'   - `data` A tibble containing the results. Can be used to build fully customized
#'     visualizations. Column names can be controlled by
#'     `options(flashlight.column_name)`.
#'   - `by` Same as input `by`.
#'   - `type` Same as input `type`. For information only.
#' @export
#' @references
#'   Friedman, J. H. and Popescu, B. E. (2008). "Predictive learning via rule
#'     ensembles." The Annals of Applied Statistics. JSTOR, 916–54.
#' @examples
#' # First model with interactions
#' fit_nonadd <- lm(
#'   Sepal.Length ~ . + Sepal.Width:Species + Petal.Width:Species, data = iris
#' )
#' fl_nonadd <- flashlight(
#'   model = fit_nonadd, label = "nonadditive", data = iris, y = "Sepal.Length"
#' )
#'
#' # Friedman's H per feature
#' plot(light_interaction(fl_nonadd), fill = "chartreuse4")
#'
#' # Unnormalized H^2: squared interaction effect, not scaled by the total effect
#' plot(
#'   light_interaction(fl_nonadd, normalize = FALSE, take_sqrt = FALSE),
#'   fill = "chartreuse4"
#' )
#'
#' # Pairwise H
#' plot(light_interaction(fl_nonadd, pairwise = TRUE), fill = "chartreuse4")
#'
#' # Second model without interactions
#' fit_add <- lm(Sepal.Length ~ ., data = iris)
#' fl_add <- flashlight(
#'   model = fit_add, label = "additive", data = iris, y = "Sepal.Length"
#' )
#' fls <- multiflashlight(list(fl_add, fl_nonadd))
#'
#' plot(light_interaction(fls), fill = "chartreuse4")
#' @seealso [light_ice()]
light_interaction <- function(x, ...) {
  # S3 generic: the class of `x` decides which method is run
  UseMethod("light_interaction", x)
}

#' @describeIn light_interaction Default method not implemented yet.
#' @export
light_interaction.default <- function(x, ...) {
  # Fallback for unsupported classes: always signals an error
  msg <- "light_interaction method is only available for objects of class flashlight or multiflashlight."
  stop(msg)
}

#' @describeIn light_interaction Interaction strengths for a flashlight object.
#' @export
light_interaction.flashlight <- function(x, data = x$data, by = x$by,
                                         v = NULL, pairwise = FALSE,
                                         type = c("H", "ice"),
                                         normalize = TRUE, take_sqrt = TRUE,
                                         grid_size = 200L, n_max = 1000L,
                                         seed = NULL,
                                         use_linkinv = FALSE, ...) {
  # Computes one interaction statistic per variable (or per variable pair if
  # `pairwise = TRUE`), separately for each level of `by`, and returns the
  # results wrapped in an object of class "light_importance".
  type <- match.arg(type)
  if (length(by) >= 2L) {
    stop("light_interaction() does not support more than one by variable.")
  }
  # Column names used internally; 'by' and 'v' must not clash with them
  temp_vars <- c(
    "value_", "label_", "variable_", "error_", "w_", "id_", "id_curve_",
    "value_2_", "value_i_", "value_j_", "denom_"
  )
  stopifnot(
    "No data!" = is.data.frame(data) && nrow(data) >= 1L,
    "'by' not in 'data'!" = by %in% colnames(data),
    "Not all 'v' in 'data'" = v %in% colnames(data),
    !(c("id_", "id_curve_", "w_") %in% colnames(data)),
    !any(temp_vars %in% c(by, v))
  )

  if (type == "ice" && pairwise) {
    stop("Pairwise interactions are implemented only for type = 'H'.")
  }
  # Original columns; predictions are always made on exactly these columns
  cols <- colnames(data)

  # Note: this changes the state of the global random number generator
  if (!is.null(seed)) {
    set.seed(seed)
  }

  # Determine v: by default all columns except response, 'by', and weights
  if (is.null(v)) {
    v <- setdiff(cols, c(x$y, by, x$w))
  }
  stopifnot(length(v) >= 1L + pairwise)
  if (pairwise) {
    # List of all unordered variable pairs
    v <- utils::combn(v, 2, simplify = FALSE)
  }

  # Sampling weights have to be dealt with since they can appear in both grid and sample
  has_w <- !is.null(x$w)
  w <- if (has_w) "w_"  # NULL if there are no case weights
  if (has_w) {
    # Copy of the weights under an internal name that is not part of 'cols'
    data[[w]] <- data[[x$w]]
  }

  # Update flashlight (except for data)
  x <- flashlight(
    x, by = by, linkinv = if (use_linkinv) x$linkinv else function(z) z
  )

  # HELPER FUNCTIONS
  # Version of light_profile and light_ice.
  # The columns 'z' of X are replaced by the values of each grid row 'gid'
  # (full cross product), predictions are stored in column 'vn', and then
  # - agg = FALSE: predictions are centered within each row of X ("id_curve_")
  #   and the expanded data is returned,
  # - agg = TRUE: predictions are averaged per grid row ("id_") and centered.
  #   With only_values = TRUE, only the column 'vn' is returned.
  call_pd <- function(X, z, vn = "value_2_", gid, only_values = FALSE, agg = TRUE) {
    # Weights of the grid ids
    if (has_w) {
      ww <- X[gid, w, drop = FALSE]
      ww$id_ <- gid
    }
    grid <- X[gid, z, drop = FALSE]
    grid$id_ <- gid
    X[, z] <- NULL
    X$id_curve_ <- seq_len(nrow(X))
    X <- tidyr::expand_grid(X, grid)
    X[[vn]] <- stats::predict(x, data = X[, cols, drop = FALSE])
    if (!agg) {
      X[[vn]] <- grouped_center(X, x = vn, by = "id_curve_", na.rm = TRUE)
      return(X)
    }
    out <- grouped_weighted_mean(X, x = vn, w = w, by = "id_")
    # Sort by grid id so that rows line up with the output of call_f()
    out <- out[order(out$id_), ]
    if (has_w) {
      # Attach the weight of the grid row belonging to each id
      out[[w]] <- ww[[w]][match(out$id_, ww$id_)]
    }
    out[[vn]] <- grouped_center(out, x = vn, w = w)
    if (only_values) out[, vn, drop = FALSE] else out
  }

  # Get predictions on grid in the same order as through call_pd
  call_f <- function(X, vn = "value_2_", gid) {
    out <- X[gid, ]
    out[[vn]] <- stats::predict(x, data = out[, cols, drop = FALSE])
    out[[vn]] <- grouped_center(out, x = vn, w = w)
    out$id_ <- gid
    out[order(out$id_), c("id_", vn, w)]
  }

  # Function that calculates the test statistic for one variable (pair) 'z'.
  # Returns a one-row data.frame with the single column "value_".
  statistic <- function(z, dat, grid_id) {
    # Too few rows: report 0
    if (nrow(dat) <= 2L) {
      return(stats::setNames(data.frame(0), "value_"))
    }
    if (type == "H") {
      z_i <- z[1L]
      # Pairwise: the second variable; otherwise: all remaining columns
      z_j <- if (pairwise) z[2L] else setdiff(cols, z_i)
      if (pairwise) {
        pd_f <- call_pd(dat, z = z, gid = grid_id)
      } else {
        pd_f <- call_f(dat, gid = grid_id)
      }
      pd_i <- call_pd(dat, z = z_i, vn = "value_i_", gid = grid_id, only_values = TRUE)
      pd_j <- call_pd(dat, z = z_j, vn = "value_j_", gid = grid_id, only_values = TRUE)
      # All three parts are sorted by grid id, so they can be bound column-wise
      dat <- dplyr::bind_cols(pd_f, pd_i, pd_j)
      # Squared part of the joint effect not explained by the two separate effects
      dat$value_ <- (dat$value_2_ - dat$value_i_ - dat$value_j_)^2
    } else if (type == "ice") {
      dat <- call_pd(dat, z = z, gid = grid_id, agg = FALSE)
      # Squared deviation of the centered ICE values from their mean per grid id
      dat$value_ <- grouped_center(dat, x = "value_2_", w = w, by = "id_")^2
    } else {
      stop("Only type H or ice implemented.")
    }
    # Aggregate & normalize
    num <- MetricsWeighted::weighted_mean(
      dat$value_, w = if (has_w) dat[[w]], na.rm = TRUE
    )
    if (normalize) {
      # A zero denominator gives a non-finite ratio, which is zapped to 0 below
      num <- .zap_small(num) /
        MetricsWeighted::weighted_mean(
          dat$value_2_^2, w = if (has_w) dat[[w]], na.rm = TRUE
        )
    }
    stats::setNames(
      data.frame(.zap_small(if (take_sqrt) sqrt(num) else num)), "value_"
    )
  }

  # Calculate statistic for each variable (pair) and combine results
  core_func <- function(X) {
    # Reduce data size and select grid values
    n <- nrow(X)
    if (n_max < n) {
      X <- X[sample(n, n_max), , drop = FALSE]
      n <- n_max
    }
    if (grid_size < n) {
      grid_id <- sample(n, grid_size)
    } else {
      # Fewer rows than requested grid points: use every row as grid point
      grid_size <- n
      grid_id <- seq_len(grid_size)
    }

    # Calculate Friedman's H statistic for each variable (pair)
    out <- lapply(v, statistic, dat = X, grid_id = grid_id)
    # Pairs are labelled "a:b"; vapply() guarantees a character vector of names
    names(out) <- if (pairwise) {
      vapply(v, paste, character(1L), collapse = ":", USE.NAMES = FALSE)
    } else {
      v
    }
    dplyr::bind_rows(out, .id = "variable_")
  }

  # Call core function for each "by" group (should rework code...)
  if (is.null(by)) {
    agg <- core_func(data)
  } else {
    agg_l <- lapply(split(data, f = data[[by]]), core_func)
    for (nm in names(agg_l)) {
      # The group value is stored as character (the name given by split())
      agg_l[[nm]][, by] <- nm
    }
    agg <- dplyr::bind_rows(agg_l)
  }

  # Prepare output. Columns are added by direct assignment instead of
  # transform(): the latter rebuilds the object via data.frame(), which drops
  # the tibble class and passes all column names through make.names(), so a
  # non-syntactic 'by' name would not be found in the selection below. It would
  # also evaluate 'x$label' inside the data, where a 'by' column named "x"
  # would mask the flashlight.
  agg$label_ <- x$label
  agg$error_ <- NA
  add_classes(
    list(
      data = tibble::as_tibble(
        agg[, c("label_", by, "variable_", "value_", "error_")]
      ),
      by = by,
      type = type
    ),
    c("light_importance", "light")
  )
}


#' @describeIn light_interaction Interaction strengths for a multiflashlight object.
#' @export
light_interaction.multiflashlight <- function(x, ...) {
  # Evaluate every element of the multiflashlight separately ...
  per_model <- lapply(x, light_interaction, ...)
  # ... and merge the individual results into one object
  light_combine(per_model, new_class = "light_importance_multi")
}

# Replaces tiny (absolute value below `eps`) and non-finite entries
# (NA, NaN, Inf, -Inf) of `x` by `val`. Other entries are left untouched.
.zap_small <- function(x, eps = 1e-12, val = 0) {
  to_replace <- !is.finite(x) | abs(x) < eps
  # Nothing to do: return the input unchanged (including its storage type)
  if (!any(to_replace)) {
    return(x)
  }
  x[to_replace] <- val
  x
}
# mayer79/flashlight documentation built on Feb. 13, 2024, 1:09 p.m.