# Part of the rstanarm package for estimating model parameters
# Copyright (C) 2015, 2016, 2017 Trustees of Columbia University
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#' Information criteria and cross-validation
#'
#' @description For models fit using MCMC, compute approximate leave-one-out
#' cross-validation (LOO, LOOIC) or, less preferably, the Widely Applicable
#' Information Criterion (WAIC) using the \pkg{\link[=loo-package]{loo}}
#' package. Functions for \eqn{K}-fold cross-validation, model comparison,
#' and model weighting/averaging are also provided. \strong{Note}:
#' these functions are not guaranteed to work properly unless the \code{data}
#' argument was specified when the model was fit. Also, as of \pkg{loo}
#' version \code{2.0.0} the default number of cores is now only 1, but we
#' recommend using as many (or close to as many) cores as possible by setting
#' the \code{cores} argument or using \code{options(mc.cores = VALUE)} to set
#' it for an entire session.
#'
#' @aliases loo waic
#'
#' @export
#'
#' @param x For \code{loo}, \code{waic}, and \code{kfold} methods, a fitted
#' model object returned by one of the rstanarm modeling functions. See
#' \link{stanreg-objects}.
#'
#' For \code{loo_model_weights}, \code{x} should be a "stanreg_list"
#' object, which is a list of fitted model objects created by
#' \code{\link{stanreg_list}}.
#'
#' @param ... For \code{compare_models}, \code{...} should contain two or more
#' objects returned by the \code{loo}, \code{kfold}, or \code{waic} method
#' (see the \strong{Examples} section, below).
#'
#' For \code{loo_model_weights}, \code{...} should contain arguments
#' (e.g. \code{method}) to pass to the default
#' \code{\link[loo]{loo_model_weights}} method from the \pkg{loo} package.
#'
#' @param cores,save_psis Passed to \code{\link[loo]{loo}}.
#' @param k_threshold Threshold for flagging estimates of the Pareto shape
#' parameters \eqn{k} estimated by \code{loo}. See the \emph{How to proceed
#' when \code{loo} gives warnings} section, below, for details.
#'
#' @return The structure of the objects returned by \code{loo} and \code{waic}
#' methods are documented in detail in the \strong{Value} section in
#' \code{\link[loo]{loo}} and \code{\link[loo]{waic}} (from the \pkg{loo}
#' package).
#'
#' @section Approximate LOO CV: The \code{loo} method for stanreg objects
#' provides an interface to the \pkg{\link[=loo-package]{loo}} package for
#' approximate leave-one-out cross-validation (LOO). The LOO Information
#' Criterion (LOOIC) has the same purpose as the Akaike Information Criterion
#' (AIC) that is used by frequentists. Both are intended to estimate the
#' expected log predictive density (ELPD) for a new dataset. However, the AIC
#' ignores priors and assumes that the posterior distribution is multivariate
#' normal, whereas the functions from the \pkg{loo} package do not make this
#' distributional assumption and integrate over uncertainty in the parameters.
#' This only assumes that any one observation can be omitted without having a
#' major effect on the posterior distribution, which can be judged using the
#' diagnostic plot provided by the \code{\link[loo]{plot.loo}} method and the
#' warnings provided by the \code{\link[loo]{print.loo}} method (see the
#' \emph{How to Use the rstanarm Package} vignette for an example of this
#' process).
#'
#' \subsection{How to proceed when \code{loo} gives warnings (k_threshold)}{
#' The \code{k_threshold} argument to the \code{loo} method for \pkg{rstanarm}
#' models is provided as a possible remedy when the diagnostics reveal
#' problems stemming from the posterior's sensitivity to particular
#' observations. Warnings about Pareto \eqn{k} estimates indicate observations
#' for which the approximation to LOO is problematic (this is described in
#' detail in Vehtari, Gelman, and Gabry (2017) and the
#' \pkg{\link[=loo-package]{loo}} package documentation). The
#' \code{k_threshold} argument can be used to set the \eqn{k} value above
#' which an observation is flagged. If \code{k_threshold} is not \code{NULL}
#' and there are \eqn{J} observations with \eqn{k} estimates above
#' \code{k_threshold} then when \code{loo} is called it will refit the
#' original model \eqn{J} times, each time leaving out one of the \eqn{J}
#' problematic observations. The pointwise contributions of these observations
#' to the total ELPD are then computed directly and substituted for the
#' previous estimates from these \eqn{J} observations that are stored in the
#' object created by \code{loo}.
#'
#' \strong{Note}: in the warning messages issued by \code{loo} about large
#' Pareto \eqn{k} estimates we recommend setting \code{k_threshold} to at
#' least \eqn{0.7}. There is a theoretical reason, explained in Vehtari,
#' Gelman, and Gabry (2017), for setting the threshold to the stricter value
#' of \eqn{0.5}, but in practice the authors find that errors in the LOO
#' approximation start to increase non-negligibly when \eqn{k > 0.7}.
#' }
#'
#' @seealso
#' \itemize{
#' \item The new \href{http://mc-stan.org/loo/articles/}{\pkg{loo} package vignettes}
#' and various \href{http://mc-stan.org/rstanarm/articles/}{\pkg{rstanarm} vignettes}
#' for more examples using \code{loo} and related functions with \pkg{rstanarm} models.
#' \item \code{\link[loo]{pareto-k-diagnostic}} in the \pkg{loo} package for
#' more on Pareto \eqn{k} diagnostics.
#' \item \code{\link{log_lik.stanreg}} to directly access the pointwise
#' log-likelihood matrix.
#' }
#'
#' @examples
#' \donttest{
#' fit1 <- stan_glm(mpg ~ wt, data = mtcars)
#' fit2 <- stan_glm(mpg ~ wt + cyl, data = mtcars)
#'
#' # compare on LOOIC
#' # (for bigger models use as many cores as possible)
#' loo1 <- loo(fit1, cores = 2)
#' print(loo1)
#' loo2 <- loo(fit2, cores = 2)
#' print(loo2)
#'
#' # when comparing exactly two models, the reported 'elpd_diff'
#' # will be positive if the expected predictive accuracy for the
#' # second model is higher. the approximate standard error of the
#' # difference is also reported.
#' compare_models(loo1, loo2)
#' compare_models(loos = list(loo1, loo2)) # can also provide list
#'
#' # when comparing three or more models they are ordered by
#' # expected predictive accuracy. elpd_diff and se_diff are relative
#' # to the model with best elpd_loo (first row)
#' fit3 <- stan_glm(mpg ~ disp * as.factor(cyl), data = mtcars)
#' loo3 <- loo(fit3, cores = 2, k_threshold = 0.7)
#' compare_models(loo1, loo2, loo3)
#'
#' # setting detail=TRUE will also print model formulas
#' compare_models(loo1, loo2, loo3, detail=TRUE)
#'
#' # Computing model weights
#' model_list <- stanreg_list(fit1, fit2, fit3)
#' loo_model_weights(model_list, cores = 2) # can specify k_threshold=0.7 if necessary
#'
#' # if you have already computed loo then it's more efficient to pass a list
#' # of precomputed loo objects than a "stanreg_list", avoiding the need
#' # for loo_model_weights to call loo() internally
#' loo_list <- list(fit1 = loo1, fit2 = loo2, fit3 = loo3) # names optional (affects printing)
#' loo_model_weights(loo_list)
#'
#' # 10-fold cross-validation
#' (kfold1 <- kfold(fit1, K = 10))
#' kfold2 <- kfold(fit2, K = 10)
#' compare_models(kfold1, kfold2, detail=TRUE)
#'
#' # Cross-validation stratifying by a grouping variable
#' # (note: might get some divergence warnings with this model but
#' # this is just intended as a quick example of how to code this)
#' library(loo)
#' fit4 <- stan_lmer(mpg ~ disp + (1|cyl), data = mtcars)
#' table(mtcars$cyl)
#' folds_cyl <- kfold_split_stratified(K = 3, x = mtcars$cyl)
#' table(cyl = mtcars$cyl, fold = folds_cyl)
#' kfold4 <- kfold(fit4, K = 3, folds = folds_cyl)
#' }
#'
#' @importFrom loo loo loo.function loo.matrix
#'
loo.stanreg <- function(x,
                        ...,
                        cores = getOption("mc.cores", 1),
                        save_psis = FALSE,
                        k_threshold = NULL) {
  # PSIS-LOO needs MCMC draws and is refused for weighted models.
  if (!used.sampling(x)) {
    STOP_sampling_only("loo")
  }
  if (model_has_weights(x)) {
    recommend_exact_loo(reason = "model has weights")
  }

  # An explicit k_threshold turns on refitting (reloo) of flagged
  # observations; without one, 0.7 only decides which warning is issued.
  user_threshold <- !is.null(k_threshold)
  if (!user_threshold) {
    k_threshold <- 0.7
  } else {
    validate_k_threshold(k_threshold)
  }

  # Chain membership of each posterior draw, used by loo::relative_eff().
  chain_id <- chain_id_for_loo(x)

  # Common path for model types whose pointwise log-likelihood is computed
  # up front as a matrix (one row per posterior draw).
  loo_from_ll_matrix <- function(ll_mat) {
    rel_eff <- loo::relative_eff(exp(ll_mat), chain_id = chain_id, cores = cores)
    suppressWarnings(loo.matrix(
      ll_mat,
      r_eff = rel_eff,
      cores = cores,
      save_psis = save_psis
    ))
  }

  if (is.stanjm(x)) {
    loo_x <- loo_from_ll_matrix(log_lik(x))
  } else if (is.stanmvreg(x)) {
    # Column-bind the log-likelihood matrices of all M submodels.
    M <- get_M(x)
    per_submodel <- lapply(1:M, function(m) log_lik(x, m = m))
    loo_x <- loo_from_ll_matrix(do.call("cbind", per_submodel))
  } else if (is_clogit(x)) {
    ll <- log_lik.stanreg(x)
    # Columns whose log-likelihood does not vary across draws are dropped.
    is_constant <- apply(ll, MARGIN = 2, FUN = function(col) sd(col) < 1e-15)
    if (any(is_constant)) {
      message(
        "The following strata were dropped from the ",
        "loo calculation because log-lik is constant: ",
        paste(which(is_constant), collapse = ", ")
      )
      ll <- ll[, !is_constant, drop = FALSE]
    }
    loo_x <- loo_from_ll_matrix(ll)
  } else if (is.stansurv(x) && x$has_quadrature) {
    loo_x <- loo_from_ll_matrix(log_lik.stanreg(x))
  } else {
    # Function method: the log-likelihood is supplied as ll_fun(x) together
    # with the data and draws returned by ll_args(x).
    args <- ll_args(x)
    llfun <- ll_fun(x)
    r_eff <- loo::relative_eff(
      x = function(data_i, draws) exp(llfun(data_i, draws)),
      chain_id = chain_id,
      data = args$data,
      draws = args$draws,
      cores = cores,
      ...
    )
    loo_x <- suppressWarnings(loo.function(
      llfun,
      data = args$data,
      draws = args$draws,
      r_eff = r_eff,
      ...,
      cores = cores,
      save_psis = save_psis
    ))
  }

  bad_obs <- loo::pareto_k_ids(loo_x, k_threshold)
  n_bad <- length(bad_obs)
  # Attributes used later by compare_models()/validate_loos().
  out <- structure(
    loo_x,
    name = deparse(substitute(x)),
    discrete = is_discrete(x),
    yhash = hash_y(x),
    formula = loo_model_formula(x)
  )

  if (n_bad == 0) {
    if (user_threshold) {
      message(
        "All pareto_k estimates below user-specified threshold of ",
        k_threshold,
        ". \nReturning loo object."
      )
    }
    return(out)
  }

  # No explicit threshold: only warn and suggest how to proceed.
  if (!user_threshold) {
    if (n_bad > 10) recommend_kfold(n_bad) else recommend_reloo(n_bad)
    return(out)
  }

  # Explicit threshold: refit once per flagged observation and patch the
  # PSIS-based estimates with exact leave-one-out values.
  reloo_out <- reloo(x, loo_x, obs = bad_obs)
  structure(
    reloo_out,
    name = attr(out, "name"),
    discrete = attr(out, "discrete"),
    yhash = attr(out, "yhash"),
    formula = loo_model_formula(x)
  )
}
# WAIC
#
#' @rdname loo.stanreg
#' @export
#' @importFrom loo waic waic.function waic.matrix
#'
waic.stanreg <- function(x, ...) {
  # WAIC, like LOO, is only available for models fit by MCMC.
  if (!used.sampling(x)) {
    STOP_sampling_only("waic")
  }

  # Matrix method where a log-likelihood matrix is available directly,
  # otherwise the function method driven by ll_fun()/ll_args().
  out <- if (is.stanjm(x)) {
    waic.matrix(log_lik(x))
  } else if (is.stanmvreg(x)) {
    per_submodel <- lapply(1:get_M(x), function(m) log_lik(x, m = m))
    waic.matrix(do.call("cbind", per_submodel))
  } else if (is_clogit(x) || (is.stansurv(x) && x$has_quadrature)) {
    waic.matrix(log_lik(x))
  } else {
    args <- ll_args(x)
    waic.function(ll_fun(x), data = args$data, draws = args$draws)
  }

  structure(
    out,
    class = c("waic", "loo"),
    name = deparse(substitute(x)),
    discrete = is_discrete(x),
    yhash = hash_y(x),
    formula = loo_model_formula(x)
  )
}
# K-fold CV
#
#' @rdname loo.stanreg
#' @export
#' @param K For \code{kfold}, the number of subsets (folds)
#' into which the data will be partitioned for performing
#' \eqn{K}-fold cross-validation. The model is refit \code{K} times, each time
#' leaving out one of the \code{K} folds. If \code{K} is equal to the total
#' number of observations in the data then \eqn{K}-fold cross-validation is
#' equivalent to exact leave-one-out cross-validation.
#' @param save_fits For \code{kfold}, if \code{TRUE}, a component \code{'fits'}
#' is added to the returned object to store the cross-validated
#' \link[=stanreg-objects]{stanreg} objects and the indices of the omitted
#' observations for each fold. Defaults to \code{FALSE}.
#' @param folds For \code{kfold}, an optional integer vector with one element
#' per observation in the data used to fit the model. Each element of the
#' vector is an integer in \code{1:K} indicating to which of the \code{K}
#' folds the corresponding observation belongs. There are some convenience
#' functions available in the \pkg{loo} package that create integer vectors to
#' use for this purpose (see the \strong{Examples} section below and also the
#' \link[loo]{kfold-helpers} page).
#'
#' If \code{folds} is not specified then the default is to call
#' \code{loo::\link[loo]{kfold_split_random}} to randomly partition the data
#' into \code{K} subsets of equal (or as close to equal as possible) size.
#'
#' @return \code{kfold} returns an object with classes 'kfold' and 'loo' that
#' has a similar structure as the objects returned by the \code{loo} and
#' \code{waic} methods.
#'
#' @section K-fold CV: The \code{kfold} function performs exact \eqn{K}-fold
#' cross-validation. First the data are randomly partitioned into \eqn{K}
#' subsets of equal (or as close to equal as possible) size (unless the folds
#' are specified manually). Then the model is refit \eqn{K} times, each time
#' leaving out one of the \eqn{K} subsets. If \eqn{K} is equal to the total
#' number of observations in the data then \eqn{K}-fold cross-validation is
#' equivalent to exact leave-one-out cross-validation (to which \code{loo} is
#' an efficient approximation). The \code{compare_models} function is also
#' compatible with the objects returned by \code{kfold}.
#'
kfold <- function(x, K = 10, save_fits = FALSE, folds = NULL) {
# Exact K-fold cross-validation: refit `x` K times, each time without one
# fold, and score the held-out rows with that refit's posterior draws.
# Returns a list (elpd_kfold, se_elpd_kfold, pointwise, estimates, and
# fits when save_fits = TRUE) with class c("kfold", "loo").
validate_stanreg_object(x)
stopifnot(K > 1, K <= nobs(x))
# Only MCMC fits of non-multivariate, unweighted models are supported.
if (!used.sampling(x)) {
STOP_sampling_only("kfold")
}
if (is.stanmvreg(x)) {
STOP_if_stanmvreg("kfold")
}
if (model_has_weights(x)) {
stop("kfold is not currently available for models fit using weights.")
}
# Data as used in the original fit: subset applied, NA rows removed.
d <- kfold_and_reloo_data(x)
N <- nrow(d)
K <- as.integer(K)
# Random folds by default; user-supplied folds must be integers in 1:K,
# one per row of `d`, with every fold non-empty.
if (is.null(folds)) {
folds <- loo::kfold_split_random(K = K, N = N)
} else {
stopifnot(
length(folds) == N,
all(folds == as.integer(folds)),
all(folds %in% 1L:K),
all(1:K %in% folds)
)
folds <- as.integer(folds)
}
lppds <- list()
# K x 2 list-matrix holding each refit and its omitted row indices.
fits <- array(list(), c(K, 2), list(NULL, c("fit", "omitted")))
for (k in 1:K) {
message("Fitting model ", k, " out of ", K)
omitted <- which(folds == k)
# Build (but do not evaluate) the refit call on `d` without fold k.
# `subset` is reset to all-TRUE because `d` is already subsetted, and
# weights are dropped (weighted models were rejected above).
fit_k_call <- update.stanreg(
object = x,
data = d[-omitted,, drop=FALSE],
subset = rep(TRUE, nrow(d) - length(omitted)),
weights = NULL,
refresh = 0,
open_progress = FALSE,
evaluate = FALSE
)
# Pass the training rows' offsets explicitly when the original call used one.
if (!is.null(getCall(x)$offset)) {
fit_k_call$offset <- x$offset[-omitted]
}
# Replace the subset/data expressions with concrete values; stansurv
# models get no subset argument at all.
fit_k_call$subset <- if (!is.stansurv(x)) eval(fit_k_call$subset) else NULL
fit_k_call$data <- eval(fit_k_call$data)
# NOTE(review): eval() runs in this function's frame. Free names in the
# stored call (e.g. a prior object) resolve against kfold()'s locals
# (x, d, N, K, k, omitted, ...) first, then the rstanarm namespace and
# search path. They do not resolve against the environment where the
# model was originally fit. capture.output() discards printed output.
capture.output(
fit_k <- eval(fit_k_call)
)
# Log-likelihood of the held-out rows under fit_k's draws. Covariates
# (get_x(x), and x$z where present) come from the original fit.
lppds[[k]] <-
log_lik.stanreg(
fit_k,
newdata = d[omitted, , drop = FALSE],
offset = x$offset[omitted],
newx = get_x(x)[omitted, , drop = FALSE],
newz = x$z[omitted, , drop = FALSE], # NULL other than for some stan_betareg models
stanmat = as.matrix.stanreg(fit_k)
)
if (save_fits) {
fits[k, ] <- list(fit = fit_k, omitted = omitted)
}
}
# Per column of each held-out log-lik matrix: log of the mean likelihood
# over draws. Values are concatenated in fold order.
elpds_unord <- unlist(lapply(lppds, function(x) {
apply(x, 2, log_mean_exp)
}))
# make sure elpds are put back in the right order
obs_order <- unlist(lapply(1:K, function(k) which(folds == k)))
elpds <- rep(NA, length(elpds_unord))
elpds[obs_order] <- elpds_unord
# Total ELPD and its standard error sqrt(N * var(pointwise)).
out <- list(
elpd_kfold = sum(elpds),
se_elpd_kfold = sqrt(N * var(elpds)),
pointwise = cbind(elpd_kfold = elpds)
)
# for compatibility with new structure of loo package objects
out$estimates <- cbind(Estimate = out$elpd_kfold, SE = out$se_elpd_kfold)
rownames(out$estimates) <- c("elpd_kfold")
if (save_fits) {
out$fits <- fits
}
structure(out,
class = c("kfold", "loo"),
K = K,
name = deparse(substitute(x)),
discrete = is_discrete(x),
yhash = hash_y(x),
formula = loo_model_formula(x))
}
#' Various print methods
#'
#' @keywords internal
#' @export
#' @method print kfold
#' @param x,digits,... See \code{\link{print}}.
print.kfold <- function(x, digits = 1, ...) {
  # Header naming the number of folds, then a one-row table with the
  # elpd_kfold estimate and its standard error.
  fold_label <- paste0(attr(x, "K"), "-fold")
  cat("\n", fold_label, "cross-validation\n\n")
  summary_tbl <- data.frame(
    Estimate = x$elpd_kfold,
    SE = x$se_elpd_kfold,
    row.names = "elpd_kfold"
  )
  .printfr(summary_tbl, digits)
  invisible(x)
}
#' @rdname loo.stanreg
#' @export
#'
#' @param loos For \code{compare_models}, a list of two or more objects returned
#' by the \code{loo}, \code{kfold}, or \code{waic} method. This argument can
#' be used as an alternative to passing these objects via \code{...}.
#' @param detail For \code{compare_models}, if \code{TRUE} then extra
#' information about each model (currently just the model formulas) will be
#' printed with the output.
#'
#' @return \code{compare_models} returns a vector or matrix with class
#' 'compare.loo'. See the \strong{Comparing models} section below for more
#' details.
#'
#' @section Comparing models: \code{compare_models} is a method for the
#' \code{\link[loo]{compare}} function in the \pkg{loo} package that
#' performs some extra checks to make sure the \pkg{rstanarm} models are
#' suitable for comparison. These extra checks include verifying that all
#' models to be compared were fit using the same outcome variable and
#' likelihood family.
#'
#' If exactly two models are being compared then \code{compare_models} returns
#' a vector containing the difference in expected log predictive density
#' (ELPD) between the models and the standard error of that difference (the
#' documentation for \code{\link[loo]{compare}} in the \pkg{loo}
#' package has additional details about the calculation of the standard error
#' of the difference). The difference in ELPD will be negative if the expected
#' out-of-sample predictive accuracy of the first model is higher. If the
#' difference is positive then the second model is preferred.
#'
#' If more than two models are being compared then \code{compare_models}
#' returns a matrix with one row per model. This matrix summarizes the objects
#' and arranges them in descending order according to expected out-of-sample
#' predictive accuracy. That is, the first row of the matrix will be
#' for the model with the largest ELPD (smallest LOOIC).
#' The columns containing the ELPD difference and the standard error of the
#' difference contain values relative to the model with the best ELPD.
#' See the \strong{Details} section at the \code{\link[loo]{compare}}
#' page in the \pkg{loo} package for more information.
#'
compare_models <- function(..., loos = list(), detail = FALSE) {
  # Objects to compare come either through `...` or through `loos`,
  # never both.
  dots <- list(...)
  if (length(dots) > 0) {
    if (length(loos) > 0) {
      stop("'...' and 'loos' can't both be specified.", call. = FALSE)
    }
    loos <- dots
  } else {
    stopifnot(is.list(loos))
  }

  # validate_loos() checks comparability and names the list after each
  # object's "name" attribute.
  loos <- validate_loos(loos)
  comparison <- loo::compare(x = loos)
  formulas <- if (!detail) NULL else lapply(loos, attr, "formula")

  structure(
    comparison,
    class = c("compare_rstanarm_loos", class(comparison)),
    model_names = names(loos),
    formulas = formulas
  )
}
#' @rdname print.kfold
#' @keywords internal
#' @export
#' @method print compare_rstanarm_loos
print.compare_rstanarm_loos <- function(x, ...) {
  # Optionally list each model's formula, then print the comparison table
  # with a header explaining how to read it.
  formulas <- attr(x, "formulas")
  if (!is.null(formulas)) {
    model_names <- attr(x, "model_names")
    cat("Model formulas: ")
    for (j in seq_len(NROW(x))) {
      cat("\n ", paste0(model_names[j], ": "),
          formula_string(formulas[[j]]))
    }
    cat("\n")
  }

  guide <- if (NROW(x) == 2) {
    "\n(negative 'elpd_diff' favors 1st model, positive favors 2nd) \n\n"
  } else {
    "\n(ordered by highest ELPD)\n\n"
  }
  cat("\nModel comparison: ")
  cat(guide)

  # Delegate table printing to the plain "compare.loo" method.
  as_compare_loo <- x
  class(as_compare_loo) <- "compare.loo"
  print(as_compare_loo, ...)
  invisible(x)
}
#' @rdname loo.stanreg
#' @aliases loo_model_weights
#'
#' @importFrom loo loo_model_weights
#' @export loo_model_weights
#'
#' @export
#'
#'
#' @section Model weights: The \code{loo_model_weights} method can be used to
#' compute model weights for a \code{"stanreg_list"} object, which is a list
#' of fitted model objects made with \code{\link{stanreg_list}}. The end of
#' the \strong{Examples} section has a demonstration. For details see the
#' \code{\link[loo]{loo_model_weights}} documentation in the \pkg{loo}
#' package.
#'
loo_model_weights.stanreg_list <- function(x,
                                           ...,
                                           cores = getOption("mc.cores", 1),
                                           k_threshold = NULL) {
  # Run loo() on every model in the list, then hand the loo objects to the
  # loo package's default weighting method (arguments in `...` go there).
  loo_list <- lapply(seq_along(x), function(j) {
    loo.stanreg(x[[j]], cores = cores, k_threshold = k_threshold)
  })
  model_wts <- loo::loo_model_weights.default(x = loo_list, ...)
  setNames(model_wts, names(x))
}
# internal ----------------------------------------------------------------
validate_k_threshold <- function(k) {
  # Validate a user-supplied Pareto-k threshold for loo(). It must be a
  # single, non-missing number >= 0. Values above 1 are accepted with a
  # warning.
  # The is.na() check matters: without it NA/NaN reach `k < 0` below and
  # `if` fails with the uninformative "missing value where TRUE/FALSE
  # needed" error.
  if (!is.numeric(k) || length(k) != 1 || is.na(k)) {
    stop("'k_threshold' must be a single numeric value.",
         call. = FALSE)
  } else if (k < 0) {
    stop("'k_threshold' < 0 not allowed.",
         call. = FALSE)
  } else if (k > 1) {
    warning(
      "Setting 'k_threshold' > 1 is not recommended.",
      "\nFor details see the PSIS-LOO section in help('loo-package', 'loo').",
      call. = FALSE
    )
  }
}
recommend_kfold <- function(n) {
  # Warn that n observations exceed the 0.7 Pareto-k cutoff, which is too
  # many to refit individually, and suggest 10-fold CV instead.
  pieces <- c(
    "Found ", n, " observations with a pareto_k > 0.7. ",
    "With this many problematic observations we recommend calling ",
    "'kfold' with argument 'K=10' to perform 10-fold cross-validation ",
    "rather than LOO.\n"
  )
  warning(paste(pieces, collapse = ""), call. = FALSE)
}
recommend_reloo <- function(n) {
  # Warn about n observations above the 0.7 Pareto-k cutoff and suggest
  # rerunning loo() with k_threshold = 0.7 so they are refit exactly.
  pieces <- c(
    "Found ", n, " observation(s) with a pareto_k > 0.7. ",
    "We recommend calling 'loo' again with argument 'k_threshold = 0.7' ",
    "in order to calculate the ELPD without the assumption that ",
    "these observations are negligible. ", "This will refit the model ",
    n, " times to compute the ELPDs for the problematic observations directly.\n"
  )
  warning(paste(pieces, collapse = ""), call. = FALSE)
}
recommend_exact_loo <- function(reason) {
  # Abort loo() for an unsupported model, stating `reason`, and point the
  # user to kfold() with K = number of observations (exact LOO-CV).
  pieces <- c(
    "'loo' is not supported if ", reason, ". ",
    "If refitting the model 'nobs(x)' times is feasible, ",
    "we recommend calling 'kfold' with K equal to the ",
    "total number of observations in the data to perform exact LOO-CV.\n"
  )
  stop(paste(pieces, collapse = ""), call. = FALSE)
}
# Refit model leaving out specific observations
#
# @param x stanreg object
# @param loo_x the result of loo(x)
# @param obs vector of observation indexes. the model will be refit length(obs)
# times, each time leaving out one of the observations specified in 'obs'.
# @param ... unused currently
# @param refit logical, to toggle whether refitting actually happens (only used
# to avoid refitting in tests)
#
# @return A modified version of 'loo_x'.
# @importFrom utils capture.output
reloo <- function(x, loo_x, obs, ..., refit = TRUE) {
# Multivariate models are not supported. The model must carry its data and
# `loo_x` must be a loo object.
if (is.stanmvreg(x))
STOP_if_stanmvreg("reloo")
stopifnot(!is.null(x$data), is.loo(loo_x))
J <- length(obs)
# Data as used in the original fit: subset applied, NA rows removed.
d <- kfold_and_reloo_data(x)
lls <- vector("list", J)
message(
J, " problematic observation(s) found.",
"\nModel will be refit ", J, " times."
)
# With refit = FALSE (test-only), stop after the announcement.
if (!refit)
return(NULL)
for (j in 1:J) {
message(
"\nFitting model ", j, " out of ", J,
" (leaving out observation ", obs[j], ")"
)
omitted <- obs[j]
# For conditional logit models the whole stratum containing obs[j] is
# left out. Stratum ids are read from the model weights of the model frame.
if (is_clogit(x)) {
strata_id <- model.weights(model.frame(x))
omitted <- which(strata_id == strata_id[obs[j]])
}
# Build (but do not evaluate) the refit call without the omitted row(s).
# `subset` is reset to all-TRUE because `d` is already subsetted.
fit_j_call <-
update(
x,
data = d[-omitted, , drop = FALSE],
subset = rep(TRUE, nrow(d) - length(omitted)),
evaluate = FALSE,
refresh = 0,
open_progress = FALSE
)
# Replace subset/data with concrete values (no subset for stansurv).
fit_j_call$subset <- if (!is.stansurv(x)) eval(fit_j_call$subset) else NULL
fit_j_call$data <- eval(fit_j_call$data)
# Pass the training rows' offsets explicitly when the original call had one.
if (!is.null(getCall(x)$offset)) {
fit_j_call$offset <- x$offset[-omitted]
}
# NOTE(review): eval() runs in this function's frame, so free names in the
# stored call can be shadowed by reloo()'s locals (x, d, J, j, omitted...).
# Warnings from the refit are suppressed; printed output is discarded.
capture.output(
fit_j <- suppressWarnings(eval(fit_j_call))
)
# Log-likelihood of the omitted row(s) under the refit's draws.
lls[[j]] <-
log_lik.stanreg(
fit_j,
newdata = d[omitted, , drop = FALSE],
offset = x$offset[omitted],
newx = get_x(x)[omitted, , drop = FALSE],
newz = x$z[omitted, , drop = FALSE], # NULL other than for some stan_betareg models
stanmat = as.matrix.stanreg(fit_j)
)
}
# compute elpd_{loo,j} for each of the held out observations
# NOTE(review): log_mean_exp() is applied to each whole matrix, so this is a
# single-observation value only when lls[[j]] has one column.
elpd_loo <- unlist(lapply(lls, log_mean_exp))
# compute \hat{lpd}_j for each of the held out observations (using log-lik
# matrix from full posterior, not the leave-one-out posteriors)
# NOTE(review): unlike the refits above, newx/newz are not passed here.
ll_x <- log_lik(
object = x,
newdata = d[obs,, drop=FALSE],
offset = x$offset[obs]
)
hat_lpd <- apply(ll_x, 2, log_mean_exp)
# compute effective number of parameters
p_loo <- hat_lpd - elpd_loo
# replace parts of the loo object with these computed quantities
# (pointwise rows for `obs`, then totals and SEs sqrt(N * var) recomputed)
sel <- c("elpd_loo", "p_loo", "looic")
loo_x$pointwise[obs, sel] <- cbind(elpd_loo, p_loo, -2 * elpd_loo)
loo_x$estimates[sel, "Estimate"] <- with(loo_x, colSums(pointwise[, sel]))
loo_x$estimates[sel, "SE"] <- with(loo_x, {
N <- nrow(pointwise)
sqrt(N * apply(pointwise[, sel], 2, var))
})
# Refit observations no longer have a PSIS Pareto-k estimate.
loo_x$diagnostics$pareto_k[obs] <- NA
return(loo_x)
}
# Numerically stable log(exp(a) + exp(b)), computed by factoring out the
# maximum. A non-finite maximum (both -Inf, any Inf, or NA/NaN) is returned
# as-is, so e.g. log_sum_exp2(-Inf, -Inf) is -Inf rather than the NaN that
# (-Inf) - (-Inf) would produce.
log_sum_exp2 <- function(a, b) {
  m <- max(a, b)
  if (!is.finite(m)) {
    return(m)
  }
  m + log(sum(exp(c(a, b) - m)))
}
# Numerically stable log(sum(exp(x))), computed by factoring out max(x).
# @param x numeric vector
# @return A single number. A non-finite max(x) (all -Inf, any Inf, or
#   NA/NaN) is returned directly: log_sum_exp(c(-Inf, -Inf)) is -Inf rather
#   than the NaN that (-Inf) - (-Inf) would otherwise produce.
log_sum_exp <- function(x) {
  max_x <- max(x)
  if (!is.finite(max_x)) {
    return(max_x)
  }
  max_x + log(sum(exp(x - max_x)))
}
# Numerically stable log(mean(exp(x))), i.e. log_sum_exp(x) - log(length(x)).
# @param x numeric vector
log_mean_exp <- function(x) {
  n <- length(x)
  log_sum_exp(x) - log(n)
}
# Get correct data to use for kfold and reloo
#
# @param x stanreg object
# @return data frame
kfold_and_reloo_data <- function(x) {
# Rebuild the data rows/columns used to fit `x` so they can be re-split:
# apply the original `subset`, keep only formula variables, drop NA rows,
# and (for clogit) restore the strata column.
# either data frame or environment
d <- x[["data"]]
sub <- getCall(x)[["subset"]]
if (!is.null(sub)) {
# `sub` is an ordinary local holding the unevaluated subset expression, so
# substitute(sub) yields that expression. When `d` is a data frame, names
# not found in `d` are looked up from this function's frame (eval()'s
# default enclosure).
# NOTE(review): a subset expression that refers to a variable from the
# user's fitting environment may resolve differently here; confirm.
keep <- eval(substitute(sub), envir = d)
}
if (is.environment(d)) {
# make data frame
d <- get_all_vars(formula(x), data = d)
} else {
# already a data frame
all_vars <- all.vars(formula(x))
# A "." in the formula means every column is used.
if ("." %in% all_vars) {
all_vars <- seq_len(ncol(d))
}
# NOTE(review): errors with "undefined columns selected" if the formula
# uses a variable that is not a column of `d`.
d <- d[, all_vars, drop=FALSE]
}
# Rows are subset first, then incomplete rows are dropped.
if (!is.null(sub)) {
d <- d[keep,, drop=FALSE]
}
d <- na.omit(d)
if (is_clogit(x)) {
# The column named by the call's `strata` argument is overwritten with
# the model weights of x's model frame.
# NOTE(review): assumes that model frame has the same rows, in the same
# order, as `d` after na.omit().
strata_var <- as.character(getCall(x)$strata)
d[[strata_var]] <- model.weights(model.frame(x))
}
return(d)
}
# Calculate a SHA1 hash of y
# @param x stanreg object
# @param ... Passed to digest::sha1
#
hash_y <- function(x, ...) {
  # SHA1 fingerprint of the outcome with all attributes (names, class, dims)
  # stripped. Used to check that compared models share the same y.
  has_digest <- requireNamespace("digest", quietly = TRUE)
  if (!has_digest) {
    stop("Please install the 'digest' package.")
  }
  validate_stanreg_object(x)
  bare_y <- get_y(x)
  attributes(bare_y) <- NULL
  digest::sha1(x = bare_y, ...)
}
# check if discrete or continuous
# @param object stanreg object
is_discrete <- function(object) {
  # polr models count as discrete and stansurv models as continuous.
  # Otherwise the family decides: binomial, poisson or negative binomial.
  if (inherits(object, "polr")) {
    return(TRUE)
  }
  if (inherits(object, "stansurv")) {
    return(FALSE)
  }
  discrete_family <- function(fam) is.binomial(fam) || is.poisson(fam) || is.nb(fam)
  if (inherits(object, "stanmvreg")) {
    # One flag per submodel family.
    return(sapply(fetch(family(object), "family"), discrete_family))
  }
  discrete_family(family(object)$family)
}
# Class predicates for objects returned by loo(), kfold() and waic().
is.loo <- function(x) {
  inherits(x, what = "loo")
}
is.kfold <- function(x) {
  is.loo(x) && inherits(x, what = "kfold")
}
is.waic <- function(x) {
  is.loo(x) && inherits(x, what = "waic")
}
# validate objects for model comparison
validate_loos <- function(loos = list()) {
  # Check that the loo/waic/kfold objects given to compare_models() can be
  # compared, then return them named by their "name" attribute.
  if (length(loos) < 2) {
    stop("At least two objects are required for model comparison.",
         call. = FALSE)
  }
  flags_for <- function(pred) vapply(loos, pred, logical(1))
  if (!all(flags_for(is.loo))) {
    stop("All objects must have class 'loo'", call. = FALSE)
  }
  is_mixed <- function(flags) any(flags) && !all(flags)
  if (is_mixed(flags_for(is.waic)) || is_mixed(flags_for(is.kfold))) {
    stop("Can't mix objects computed using 'loo', 'waic', and 'kfold'.",
         call. = FALSE)
  }

  # Every model must have been fit to the same outcome.
  first_hash <- attr(loos[[1]], which = "yhash")
  hashes_match <- vapply(
    loos,
    function(obj) isTRUE(all.equal(attr(obj, which = "yhash"), first_hash)),
    logical(1)
  )
  if (!all(hashes_match)) {
    stop("Not all models have the same y variable.", call. = FALSE)
  }

  discrete <- sapply(loos, attr, which = "discrete")
  if (!all(discrete == discrete[1])) {
    stop("Discrete and continuous observation models can't be compared.",
         call. = FALSE)
  }
  setNames(loos, nm = lapply(loos, attr, which = "name"))
}
# chain_id to pass to loo::relative_eff
chain_id_for_loo <- function(object) {
  # The first two dimensions of object$stanfit are taken as (iterations,
  # chains). Returns the chain index of every draw, chain by chain.
  sizes <- dim(object$stanfit)[1:2]
  rep(1:sizes[2], each = sizes[1])
}
# model formula to store in loo object
# @param x stanreg object
loo_model_formula <- function(x) {
  # Return formula(x), or the placeholder "formula not found" when
  # formula(x) errors or returns NULL.
  form <- tryCatch(formula(x), error = function(e) NULL)
  if (is.null(form)) "formula not found" else form
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.