loo.stanmstte: Information criteria and cross-validation


Description

For models fit using MCMC, compute approximate leave-one-out cross-validation (LOO, LOOIC) or, less preferably, the Widely Applicable Information Criterion (WAIC) using the loo package. Functions for K-fold cross-validation, model comparison, and model weighting/averaging are also provided. Note: these functions are not guaranteed to work properly unless the data argument was specified when the model was fit. Also, as of loo version 2.0.0 the default number of cores is now only 1, but we recommend using as many (or close to as many) cores as possible by setting the cores argument or using options(mc.cores = VALUE) to set it for an entire session.
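
For example, the number of cores can be set once per session rather than passed to every call. A minimal sketch (the value 2 and the object fit are only placeholders; choose a value appropriate for your machine):

options(mc.cores = 2)   # picked up as the default by loo, waic, kfold, etc.
loo(fit, cores = 2)     # or set it explicitly for a single call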

Usage

## S3 method for class 'stanmstte'
loo(x, ..., cores = getOption("mc.cores", 1),
  save_psis = FALSE, k_threshold = NULL)

Arguments

x

For loo, waic, and kfold methods, a fitted model object returned by one of the rstanarm modeling functions. See stanreg-objects.

For loo_model_weights, x should be a "stanreg_list" object, which is a list of fitted model objects created by stanreg_list.

...

For compare_models, ... should contain two or more objects returned by the loo, kfold, or waic method (see the Examples section, below).

For loo_model_weights, ... should contain arguments (e.g. method) to pass to the default loo_model_weights method from the loo package.

cores, save_psis

Passed to loo. See the short sketch following the argument descriptions for an example using save_psis.

k_threshold

Threshold for flagging estimates of the Pareto shape parameters k estimated by loo. See the How to proceed when loo gives warnings section, below, for details.
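
As a rough sketch of how these arguments fit together (assuming a fitted model object such as fit1 from the Examples below; the values are illustrative and the $psis_object element assumes the object structure used by loo >= 2.0.0):

loo_ps <- loo(fit1, cores = 2, save_psis = TRUE)
# With save_psis = TRUE the PSIS results are kept in the returned object and
# can be reused, e.g. to extract the smoothed log importance weights:
lw <- weights(loo_ps$psis_object, log = TRUE)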

Value

The structure of the objects returned by the loo and waic methods is documented in detail in the Value sections of loo and waic in the loo package.
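
For instance, the pieces most often inspected by hand can be pulled out of a fitted loo object directly (a sketch assuming the loo >= 2.0.0 object structure and the loo1 object created in the Examples below):

loo1$estimates        # elpd_loo, p_loo, and looic with their standard errors
head(loo1$pointwise)  # per-observation contributions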

Approximate LOO CV

The loo method for stanreg objects provides an interface to the loo package for approximate leave-one-out cross-validation (LOO). The LOO Information Criterion (LOOIC) has the same purpose as the Akaike Information Criterion (AIC) used by frequentists. Both are intended to estimate the expected log predictive density (ELPD) for a new dataset. However, the AIC ignores priors and assumes that the posterior distribution is multivariate normal, whereas the functions from the loo package do not make this distributional assumption and integrate over uncertainty in the parameters. The only assumption is that any single observation can be omitted without having a major effect on the posterior distribution, which can be checked using the diagnostic plot provided by the plot.loo method and the warnings provided by the print.loo method (see the How to Use the rstanarm Package vignette for an example of this process).
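
A quick way to run these checks on a computed loo object (for example, loo1 from the Examples below) is:

print(loo1)                       # prints estimates plus a Pareto k diagnostic summary
plot(loo1, label_points = TRUE)   # diagnostic plot of the Pareto k estimates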

How to proceed when loo gives warnings (k_threshold)

The k_threshold argument to the loo method for rstanarm models is provided as a possible remedy when the diagnostics reveal problems stemming from the posterior's sensitivity to particular observations. Warnings about Pareto k estimates indicate observations for which the approximation to LOO is problematic (this is described in detail in Vehtari, Gelman, and Gabry (2017) and the loo package documentation). The k_threshold argument sets the k value above which an observation is flagged. If k_threshold is not NULL and there are J observations with k estimates above k_threshold, then loo will refit the original model J times, each time leaving out one of the J problematic observations. The pointwise contributions of these observations to the total ELPD are then computed directly and substituted for the previous estimates from these J observations stored in the object created by loo.
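
For example, to refit rather than approximate the problematic observations (fit1 stands in for any fitted model object):

loo_reloo <- loo(fit1, k_threshold = 0.7)
# Each observation with Pareto k > 0.7 is left out in turn, the model is refit,
# and the exact pointwise ELPD contribution replaces the PSIS approximation.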

Note: in the warning messages issued by loo about large Pareto k estimates we recommend setting k_threshold to at least 0.7. There is a theoretical reason, explained in Vehtari, Gelman, and Gabry (2017), for setting the threshold to the stricter value of 0.5, but in practice the authors find that errors in the LOO approximation start to increase non-negligibly when k > 0.7.
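
The loo package also provides small helpers for inspecting the k estimates directly, for example (assuming a loo object such as loo1):

loo::pareto_k_table(loo1)                 # counts of observations in each k range
loo::pareto_k_ids(loo1, threshold = 0.7)  # indices of observations with k > 0.7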

See Also

Examples

fit1 <- stan_glm(mpg ~ wt, data = mtcars)
fit2 <- stan_glm(mpg ~ wt + cyl, data = mtcars)

# compare on LOOIC
# (for bigger models use as many cores as possible)
loo1 <- loo(fit1, cores = 2)
print(loo1)
loo2 <- loo(fit2, cores = 2)
print(loo2)

# when comparing exactly two models, the reported 'elpd_diff'
# will be positive if the expected predictive accuracy for the
# second model is higher. the approximate standard error of the
# difference is also reported.
compare_models(loo1, loo2)
compare_models(loos = list(loo1, loo2)) # can also provide list

# when comparing three or more models they are ordered by
# expected predictive accuracy. elpd_diff and se_diff are relative
# to the model with best elpd_loo (first row)
fit3 <- stan_glm(mpg ~ disp * as.factor(cyl), data = mtcars)
loo3 <- loo(fit3, cores = 2, k_threshold = 0.7)
compare_models(loo1, loo2, loo3)

# setting detail = TRUE will also print model formulas
compare_models(loo1, loo2, loo3, detail = TRUE)

# Computing model weights
model_list <- stanreg_list(fit1, fit2, fit3)
loo_model_weights(model_list, cores = 2) # can specify k_threshold = 0.7 if necessary

# if you have already computed loo then it's more efficient to pass a list
# of precomputed loo objects than a "stanreg_list", avoiding the need
# for loo_model_weights to call loo() internally
loo_list <- list(fit1 = loo1, fit2 = loo2, fit3 = loo3) # names optional (affects printing)
loo_model_weights(loo_list)

# 10-fold cross-validation
(kfold1 <- kfold(fit1, K = 10))
kfold2 <- kfold(fit2, K = 10)
compare_models(kfold1, kfold2, detail = TRUE)

# Cross-validation stratifying by a grouping variable
# (note: might get some divergence warnings with this model but
# this is just intended as a quick example of how to code this)
library(loo)
fit4 <- stan_lmer(mpg ~ disp + (1|cyl), data = mtcars)
table(mtcars$cyl)
folds_cyl <- kfold_split_stratified(K = 3, x = mtcars$cyl)
table(cyl = mtcars$cyl, fold = folds_cyl)
kfold4 <- kfold(fit4, K = 3, folds = folds_cyl)
