#' Model comparison
#'
#' @description Compare fitted models based on [ELPD][loo-glossary].
#'
#' By default the print method shows only the most important information. Use
#' `print(..., simplify=FALSE)` to print a more detailed summary.
#'
#' @export
#' @param x An object of class `"loo"` or a list of such objects. If a list is
#' used then the list names will be used as the model names in the output. See
#' **Examples**.
#' @param ... Additional objects of class `"loo"`, if not passed in as a single
#' list.
#'
#' @return A matrix with class `"compare.loo"` that has its own
#' print method. See the **Details** section.
#'
#' @details
#' When comparing two fitted models, we can estimate the difference in their
#' expected predictive accuracy by the difference in
#' [`elpd_loo`][loo-glossary] or `elpd_waic` (or multiplied by \eqn{-2}, if
#' desired, to be on the deviance scale).
#'
#' When using `loo_compare()`, the returned matrix will have one row per model
#' and several columns of estimates. The values in the
#' [`elpd_diff`][loo-glossary] and [`se_diff`][loo-glossary] columns of the
#' returned matrix are computed by making pairwise comparisons between each
#' model and the model with the largest ELPD (the model in the first row). For
#' this reason the `elpd_diff` column will always have the value `0` in the
#' first row (i.e., the difference between the preferred model and itself) and
#' negative values in subsequent rows for the remaining models.
#'
#' To compute the standard error of the difference in [ELPD][loo-glossary] ---
#' which should not be expected to equal the difference of the standard errors
#' --- we use a paired estimate to take advantage of the fact that the same
#' set of \eqn{N} data points was used to fit both models. These calculations
#' should be most useful when \eqn{N} is large, because then non-normality of
#' the distribution is not such an issue when estimating the uncertainty in
#' these sums. These standard errors, for all their flaws, should give a
#' better sense of uncertainty than what is obtained using the current
#' standard approach of comparing differences of deviances to a Chi-squared
#' distribution, a practice derived for Gaussian linear models or
#' asymptotically, and which only applies to nested models in any case.
#' Sivula et al. (2022) discuss the conditions when the normal
#' approximation used for SE and `se_diff` is good.
#'
#' If more than \eqn{11} models are compared, we internally recompute the model
#' differences using the median model by ELPD as the baseline model. We then
#' estimate whether the differences in predictive performance are potentially
#' due to chance as described by McLatchie and Vehtari (2023). This will flag
#' a warning if it is deemed that there is a risk of over-fitting due to the
#' selection process. In that case users are recommended to avoid model
#' selection based on LOO-CV, and instead to favor model averaging/stacking or
#' projection predictive inference.
#'
#' @seealso
#' * The [FAQ page](https://mc-stan.org/loo/articles/online-only/faq.html) on
#' the __loo__ website for answers to frequently asked questions.
#' @template loo-and-compare-references
#'
#' @examples
#' # very artificial example, just for demonstration!
#' LL <- example_loglik_array()
#' loo1 <- loo(LL) # should be worst model when compared
#' loo2 <- loo(LL + 1) # should be second best model when compared
#' loo3 <- loo(LL + 2) # should be best model when compared
#'
#' comp <- loo_compare(loo1, loo2, loo3)
#' print(comp, digits = 2)
#'
#' # show more details with simplify=FALSE
#' # (will be the same for all models in this artificial example)
#' print(comp, simplify = FALSE, digits = 3)
#'
#' # can use a list of objects with custom names
#' # will use apple, banana, and cherry, as the names in the output
#' loo_compare(list("apple" = loo1, "banana" = loo2, "cherry" = loo3))
#'
#' \dontrun{
#' # works for waic (and kfold) too
#' loo_compare(waic(LL), waic(LL - 10))
#' }
#'
loo_compare <- function(x, ...) {
  # S3 generic: pick the method according to the class of `x`.
  UseMethod("loo_compare", x)
}
#' @rdname loo_compare
#' @export
loo_compare.default <- function(x, ...) {
  # Collect the models into a single list, whichever way they were supplied.
  if (!is.loo(x)) {
    if (!(is.list(x) && length(x) > 0)) {
      stop("'x' must be a list if not a 'loo' object.")
    }
    if (length(list(...)) > 0) {
      stop("If 'x' is a list then '...' should not be specified.")
    }
    loos <- x
  } else {
    loos <- c(list(x), list(...))
  }

  # Objects created with subsampling are handled by a dedicated routine.
  has_subsampling <- vapply(
    loos,
    function(obj) inherits(obj, "psis_loo_ss"),
    logical(1)
  )
  if (any(has_subsampling)) {
    return(loo_compare.psis_loo_ss_list(loos))
  }

  loo_compare_checks(loos)

  estimates <- loo_compare_matrix(loos)
  ord <- loo_compare_order(loos)
  model_names <- rownames(estimates)

  # Pointwise differences of every model (in ranked order) against the
  # top-ranked model; one column per model.
  best_model <- loos[ord[1]]
  ranked_models <- loos[ord]
  pointwise_diffs <- mapply(elpd_diffs, best_model, ranked_models)

  comp <- cbind(
    elpd_diff = apply(pointwise_diffs, 2, sum),
    se_diff = apply(pointwise_diffs, 2, se_elpd_diff),
    estimates
  )
  # cbind() may pick up names from the difference vectors, so restore the
  # model names explicitly.
  rownames(comp) <- model_names

  # Order-statistics based check; may emit a warning.
  loo_order_stat_check(loos, ord)

  structure(comp, class = c("compare.loo", class(comp)))
}
#' @rdname loo_compare
#' @export
#' @param digits For the print method only, the number of digits to use when
#' printing.
#' @param simplify For the print method only, should only the essential columns
#' of the summary matrix be printed? The entire matrix is always returned, but
#' by default only the most important columns are printed.
print.compare.loo <- function(x, ..., digits = 1, simplify = TRUE) {
  shown <- x
  # Columns are only dropped when there is more than one and `simplify` is set.
  drop_columns <- NCOL(x) >= 2 && simplify
  if (drop_columns) {
    shown <- if (inherits(x, "old_compare.loo")) {
      # Objects with the "old_compare.loo" class: select columns by name pattern.
      keep <- grepl("^elpd_|^se_diff|^p_|^waic$|^looic$", colnames(x))
      x[, keep]
    } else {
      x[, c("elpd_diff", "se_diff")]
    }
  }
  print(.fr(shown, digits), quote = FALSE)
  # Return the full, unmodified object so the method can be used in pipelines.
  invisible(x)
}
# internal ----------------------------------------------------------------
#' Compute pointwise elpd differences
#' @noRd
#' @param loo_a,loo_b Two `"loo"` objects, each with a `pointwise` matrix that
#'   has a column whose name starts with `"elpd"`.
#' @return The pointwise elpd values of `loo_b` minus those of `loo_a`.
elpd_diffs <- function(loo_a, loo_b) {
  pt_a <- loo_a$pointwise
  pt_b <- loo_b$pointwise
  # Locate the elpd column separately in each object: the two pointwise
  # matrices need not share the same column layout (e.g. when objects from
  # different criteria are compared).
  elpd_a <- grep("^elpd", colnames(pt_a))
  elpd_b <- grep("^elpd", colnames(pt_b))
  if (!length(elpd_b)) {
    # No matching column name in `loo_b`: fall back to the position found in
    # `loo_a`.
    elpd_b <- elpd_a
  }
  pt_b[, elpd_b] - pt_a[, elpd_a]
}
#' Standard error of a summed elpd difference
#' @noRd
#' @param diffs Vector of pointwise elpd differences
#' @return The standard deviation of `diffs` scaled by the square root of its
#'   length.
se_elpd_diff <- function(diffs) {
  # `elpd_diff` is the sum of the pointwise components, so the standard error
  # of that sum is the component standard deviation times the square root of
  # the number of components.
  n_points <- length(diffs)
  scale <- sqrt(n_points)
  scale * sd(diffs)
}
#' Validate a list of `"loo"` objects before they are compared
#' @noRd
#' @param loos List of `"loo"` objects.
#' @return Nothing, just possibly throws errors/warnings.
loo_compare_checks <- function(loos) {
  ## errors
  n_models <- length(loos)
  if (n_models <= 1L) {
    stop("'loo_compare' requires at least two models.", call. = FALSE)
  }

  are_loo <- sapply(loos, is.loo)
  if (!all(are_loo)) {
    stop("All inputs should have class 'loo'.", call. = FALSE)
  }

  n_points <- sapply(loos, function(obj) nrow(obj$pointwise))
  if (!all(n_points == n_points[1L])) {
    stop("Not all models have the same number of data points.", call. = FALSE)
  }

  ## warnings
  # Every 'yhash' attribute must equal the first one (all NULL is fine).
  yhash <- lapply(loos, function(obj) attr(obj, which = "yhash"))
  reference_hash <- yhash[[1]]
  hash_matches <- vapply(
    yhash,
    function(h) isTRUE(all.equal(h, reference_hash)),
    logical(1)
  )
  if (!all(hash_matches)) {
    warning("Not all models have the same y variable. ('yhash' attributes do not match)",
            call. = FALSE)
  }

  kfold_flags <- sapply(loos, is.kfold)
  if (all(kfold_flags)) {
    # Only K-fold objects: they should all use the same number of folds.
    Ks <- unlist(lapply(loos, function(obj) attr(obj, which = "K")))
    if (!all(Ks == Ks[1])) {
      warning("Not all kfold objects have the same K value. ",
              "For a more accurate comparison use the same number of folds. ",
              call. = FALSE)
    }
  } else if (any(kfold_flags) && any(sapply(loos, is.psis_loo))) {
    # Mixture of K-fold and PSIS-LOO objects.
    warning("Comparing LOO-CV to K-fold-CV. ",
            "For a more accurate comparison use the same number of folds ",
            "or loo for all models compared.",
            call. = FALSE)
  }
}
#' Find the model names associated with `"loo"` objects
#'
#' For each element of `x` the name is taken from the first available source:
#' the (non-empty) list name, the `"model_name"` attribute, the `model_name`
#' element, and finally a generated name of the form `"model<i>"`.
#'
#' @export
#' @keywords internal
#' @param x List of `"loo"` objects.
#' @return Character vector of model names the same length as `x`.
#'
find_model_names <- function(x) {
  stopifnot(is.list(x))
  n <- length(x)

  # Candidate names, from highest to lowest priority.
  from_list <- names(x)
  from_attr <- lapply(x, function(obj) attr(obj, "model_name", exact = TRUE))
  from_elem <- lapply(x, function(obj) obj[["model_name"]])
  fallback <- paste0("model", seq_len(n))

  resolved <- character(n)
  for (i in seq_len(n)) {
    resolved[i] <-
      if (isTRUE(nzchar(from_list[i]))) {
        from_list[i]
      } else if (length(from_attr[[i]])) {
        from_attr[[i]]
      } else if (length(from_elem[[i]])) {
        from_elem[[i]]
      } else {
        fallback[i]
      }
  }
  resolved
}
#' Build the matrix of estimates used by `loo_compare()`
#' @keywords internal
#' @noRd
#' @param loos List of `"loo"` objects.
#' @return Matrix with one row per model, rows ordered by `loo_compare_order()`
#'   and columns grouped as elpd, p, then the information criterion columns.
loo_compare_matrix <- function(loos) {
  # Flatten each `estimates` matrix into one named vector: the first column's
  # values keep the row names, the second column's get an "se_" prefix.
  flatten_estimates <- function(obj) {
    est <- obj$estimates
    est_names <- rownames(est)
    setNames(c(est), nm = c(est_names, paste0("se_", est_names)))
  }
  est_by_model <- sapply(loos, flatten_estimates)
  colnames(est_by_model) <- find_model_names(loos)

  # One row per model, in ranked order.
  ranking <- loo_compare_order(loos)
  comp <- t(est_by_model)[ranking, ]

  # Arrange the columns by matching each pattern in turn.
  col_patterns <- c("elpd", "p_", "^waic$|^looic$", "^se_waic$|^se_looic$")
  col_ord <- unlist(
    lapply(col_patterns, function(p) grep(p, colnames(comp))),
    use.names = FALSE
  )
  comp[, col_ord]
}
#' Rank `"loo"` objects by their elpd estimate
#' @noRd
#' @keywords internal
#' @param loos List of `"loo"` objects.
#' @return Integer vector of indices into `loos`, from largest to smallest
#'   elpd estimate.
loo_compare_order <- function(loos) {
  # One column per model holding the flattened `estimates` matrix.
  est_by_model <- sapply(loos, function(obj) {
    est <- obj$estimates
    est_names <- rownames(est)
    setNames(c(est), nm = c(est_names, paste0("se_", est_names)))
  })
  colnames(est_by_model) <- find_model_names(loos)

  # Rows whose name starts with "elpd" hold the point estimates to sort on.
  elpd_rows <- grep("^elpd", rownames(est_by_model))
  order(est_by_model[elpd_rows, ], decreasing = TRUE)
}
#' Perform checks on `"loo"` objects __after__ comparison
#'
#' Compares the largest elpd difference (relative to the median-ranked model)
#' with a heuristic for the largest difference expected by chance, and warns
#' if the observed difference does not exceed it.
#'
#' @noRd
#' @keywords internal
#' @param loos List of `"loo"` objects.
#' @param ord Integer vector of indices giving the ordering of `loos`, as
#'   returned by `loo_compare_order()`.
#' @return Nothing, just possibly throws a warning.
loo_order_stat_check <- function(loos, ord) {
  ## breaks
  if (length(loos) <= 11L) {
    # the check is only performed when more than 11 models are compared;
    # break from function otherwise
    return(NULL)
  }

  ## warnings
  # compute the elpd differences from the median model
  baseline_idx <- middle_idx(ord)
  diffs <- mapply(FUN = elpd_diffs, loos[ord[baseline_idx]], loos[ord])
  elpd_diff <- apply(diffs, 2, sum)

  # estimate the standard deviation of the upper-half-normal
  diff_median <- stats::median(elpd_diff)
  elpd_diff_trunc <- elpd_diff[elpd_diff >= diff_median]
  n_models <- sum(!is.na(elpd_diff_trunc))
  candidate_sd <- sqrt(1 / n_models * sum(elpd_diff_trunc^2, na.rm = TRUE))

  # estimate expected best diff under null hypothesis
  K <- length(loos) - 1
  order_stat <- order_stat_heuristic(K, candidate_sd)

  if (max(elpd_diff) <= order_stat) {
    # flag warning if we suspect no model is theoretically better than the
    # baseline; note the trailing space so the two sentences do not run
    # together in the emitted message
    warning("Difference in performance potentially due to chance. ",
            "See McLatchie and Vehtari (2023) for details.",
            call. = FALSE)
  }
}
#' Middle position of a vector
#' @noRd
#' @keywords internal
#' @param vec A vector.
#' @return Half the length of `vec`, rounded down.
middle_idx <- function(vec) {
  length(vec) %/% 2
}
#' Heuristic for the maximum order statistic of K Gaussians
#' @noRd
#' @keywords internal
#' @param K Number of Gaussians.
#' @param c Scaling of the order statistic.
#' @return Numeric: the `1 - 1 / (2 * K)` quantile of a Gaussian with mean
#'   zero and scale `"c"`.
order_stat_heuristic <- function(K, c) {
  upper_prob <- 1 - 1 / (K * 2)
  qnorm(p = upper_prob, mean = 0, sd = c)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.