R/h2_overall.R

Defines functions h2_overall_raw h2_overall.hstats h2_overall.default h2_overall

Documented in h2_overall h2_overall.default h2_overall.hstats

#' Overall Interaction Strength
#' 
#' Friedman and Popescu's statistic of overall interaction strength per 
#' feature, see Details. Use `plot()` to get a barplot.
#' 
#' The logic of Friedman and Popescu (2008) is as follows: 
#' If there are no interactions involving feature \eqn{x_j}, we can decompose the 
#' (centered) prediction function \eqn{F} into the sum of the (centered) partial 
#' dependence \eqn{F_j} on \eqn{x_j} and the (centered) partial dependence 
#' \eqn{F_{\setminus j}} on all other features \eqn{\mathbf{x}_{\setminus j}}, i.e.,
#' \deqn{
#'   F(\mathbf{x}) = F_j(x_j) + F_{\setminus j}(\mathbf{x}_{\setminus j}).
#' }
#' Correspondingly, Friedman and Popescu's statistic of overall interaction 
#' strength of \eqn{x_j} is given by
#' \deqn{
#'   H_j^2 = \frac{\frac{1}{n} \sum_{i = 1}^n\big[F(\mathbf{x}_i) - 
#'   \hat F_j(x_{ij}) - \hat F_{\setminus j}(\mathbf{x}_{i\setminus j})
#'   \big]^2}{\frac{1}{n} \sum_{i = 1}^n\big[F(\mathbf{x}_i)\big]^2}
#' }
#' (see [partial_dep()] for all definitions).
#' 
#' **Remarks:**
#' 
#' 1. Partial dependence functions (and \eqn{F}) are all centered to 
#'   (possibly weighted) mean 0.
#' 2. Partial dependence functions (and \eqn{F}) are evaluated over the data distribution. 
#'   This differs from partial dependence plots, which use a fixed grid.
#' 3. Weighted versions follow by replacing all arithmetic means by corresponding
#'   weighted means.
#' 4. Multivariate predictions can be treated in a component-wise manner.
#' 5. Due to (typically undesired) extrapolation effects of partial dependence functions, 
#'   depending on the model, values above 1 may occur.
#' 6. \eqn{H^2_j = 0} means there are no interactions associated with \eqn{x_j}. 
#'   The higher the value, the more prediction variability comes from interactions 
#'   with \eqn{x_j}.
#' 7. Since the denominator is the same for all features, the values of the test 
#'   statistics can be compared across features.
#' 
#' @param object Object of class "hstats".
#' @param normalize Should statistics be normalized, i.e., divided by the 
#'   denominator of the statistic? Default is `TRUE`.
#' @param squared Should *squared* statistics be returned? Default is `TRUE`. 
#' @param sort Should results be sorted? Default is `TRUE`.
#'   (Multi-output is sorted by row means.)
#' @param zero Should rows with all 0 be shown? Default is `TRUE`.
#' @param ... Currently unused.
#' @returns 
#'   An object of class "hstats_matrix" containing these elements:
#'   - `M`: Matrix of statistics (one column per prediction dimension), or `NULL`.
#'   - `SE`: Matrix with standard errors of `M`, or `NULL`. 
#'     Multiply by `sqrt(m_rep)` to get *standard deviations* instead.
#'     Currently supported only for [perm_importance()].
#'   - `m_rep`: The number of repetitions behind standard errors `SE`, or `NULL`.
#'     Currently supported only for [perm_importance()].
#'   - `statistic`: Name of the function that generated the statistic.
#'   - `description`: Description of the statistic.
#' @inherit hstats references
#' @seealso [hstats()], [h2()], [h2_pairwise()], [h2_threeway()]
#' @export
#' @examples
#' # MODEL 1: Linear regression
#' fit <- lm(Sepal.Length ~ . + Petal.Width:Species, data = iris)
#' s <- hstats(fit, X = iris[, -1])
#' h2_overall(s)
#' plot(h2_overall(s))
#' 
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' s <- hstats(fit, X = iris[, 3:5], verbose = FALSE)
#' plot(h2_overall(s, zero = FALSE))
h2_overall <- function(object, ...) {
  UseMethod("h2_overall")
}

#' @describeIn h2_overall Default method of overall interaction strength.
#' @export
h2_overall.default <- function(object, ...) {
  stop("No default method implemented.")
}

#' @describeIn h2_overall Overall interaction strength from "hstats" object.
#' @export
h2_overall.hstats <- function(object, normalize = TRUE, squared = TRUE, 
                              sort = TRUE, zero = TRUE, ...) {
  get_hstats_matrix(
    statistic = "h2_overall",
    object = object,
    normalize = normalize, 
    squared = squared, 
    sort = sort,
    zero = zero
  )
}

# Helper function

#' Raw H2 Overall
#' 
#' Internal helper function that calculates the numerator and denominator of
#' the overall interaction statistic \eqn{H_j^2}.
#' 
#' @noRd
#' @keywords internal
#' @param x A list containing the elements "v", "K", "pred_names", 
#'   "f", "F_not_j", "F_j", "mean_f2", "eps", and "w".
#' @returns A list with the numerator and denominator statistics.
h2_overall_raw <- function(x) {
  num <- init_numerator(x, way = 1L)
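  # For each feature z: (weighted) column means of the squared residuals that
  # remain after subtracting both partial dependence terms from the predictions
  for (z in x[["v"]]) {
    num[z, ] <- with(x, wcolMeans((f - F_j[[z]] - F_not_j[[z]])^2, w = w))
  }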
  num <- .zap_small(num, eps = x[["eps"]])  # Numeric precision
  
  list(num = num, denom = x[["mean_f2"]])
}
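
# Illustrative sketch only, not part of the package: the unweighted H_j^2 from
# the formula documented above, computed by brute force in base R. The names
# `h2_overall_sketch` and `pred_fun` are hypothetical. The package itself works
# from precomputed partial dependence values, which may be approximated; this
# sketch needs n^2 predictions per term, so it suits small data only.
h2_overall_sketch <- function(pred_fun, X, j) {
  n <- nrow(X)
  center <- function(z) z - mean(z)

  # Partial dependence on columns `cols`, evaluated at every observed row:
  # fix `cols` at the values of row i and average predictions over all rows.
  pd <- function(cols) {
    vapply(seq_len(n), function(i) {
      Xi <- X
      for (cn in cols) Xi[[cn]] <- rep(X[[cn]][i], n)
      mean(pred_fun(Xi))
    }, numeric(1))
  }

  f <- center(pred_fun(X))                        # centered predictions
  F_j <- center(pd(j))                            # centered PD on x_j
  F_not_j <- center(pd(setdiff(colnames(X), j)))  # centered PD on the rest
  mean((f - F_j - F_not_j)^2) / mean(f^2)
}

# Usage (not run): x1 interacts with x2, while x3 enters additively, so the
# statistic should be clearly positive for x1 and numerically zero for x3.
#   set.seed(1)
#   X <- data.frame(x1 = runif(50), x2 = runif(50), x3 = runif(50))
#   pf <- function(d) d$x1 * d$x2 + d$x3
#   h2_overall_sketch(pf, X, "x1")
#   h2_overall_sketch(pf, X, "x3")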

hstats documentation built on May 29, 2024, 6:43 a.m.