# Part of the rstanarm package for estimating model parameters
# Copyright (C) 2015, 2016, 2017 Trustees of Columbia University
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
suppressPackageStartupMessages(library(rstanarm))
# Shared settings for the tests in this file.
# One core for loo() on Windows, two on every other platform. A plain
# if/else is used because the condition is a single value.
LOO.CORES <- if (.Platform$OS.type == "windows") 1 else 2
SEED <- 1234L
set.seed(SEED)
CHAINS <- 2
ITER <- 40 # small iter for speed but large enough for psis
REFRESH <- 0
# Only fit the example model when one is not already available.
if (!exists("example_model")) {
  example_model <- run_example_model()
}
# loo and waic ------------------------------------------------------------
context("loo and waic")

# Check that loo() and waic() on a fitted model agree with the matrix methods.
#
# These checks only confirm that the loo.stanreg method (which calls the
# loo.function method) gives results identical to loo.matrix. Because the
# log-likelihood matrix is itself computed with the log-likelihood function,
# all this really verifies is that loo.stanreg and the log-likelihood
# functions run without errors; it does not check that the values returned by
# loo are actually correct.
#
# @param fit A fitted model object accepted by loo(), waic() and log_lik().
# @return Called for its testthat expectations.
expect_equivalent_loo <- function(fit) {
  l <- suppressWarnings(loo(fit, cores = LOO.CORES))
  w <- suppressWarnings(waic(fit))
  expect_s3_class(l, "loo")
  expect_s3_class(w, "loo")
  expect_s3_class(w, "waic")

  att_names <- c("names", "dims", "class", "model_name", "discrete", "yhash",
                 "formula")
  expect_named(attributes(l), att_names)
  expect_named(attributes(w), att_names)

  # 'discrete' must be a single, non-missing logical. The type and length are
  # checked first so that `&&` is never handed a non-scalar operand.
  discrete <- attr(l, "discrete")
  expect_true(
    is.logical(discrete) && length(discrete) == 1L && !is.na(discrete)
  )

  # Compute the log-likelihood matrix once and reuse it for both the loo and
  # the waic comparison.
  llik <- log_lik(fit)
  r <- loo::relative_eff(exp(llik), chain_id = rstanarm:::chain_id_for_loo(fit))
  l2 <- suppressWarnings(loo(llik, r_eff = r, cores = LOO.CORES))
  expect_equal(l$estimates, l2$estimates)
  expect_equivalent(w, suppressWarnings(waic(llik)))
}
test_that("loo & waic do something for non mcmc models", {
  # Fit by optimization with prior, prior_intercept and prior_aux set to NULL,
  # then re-run the same call with each variational algorithm.
  SW(fito <- stan_glm(mpg ~ wt, data = mtcars, algorithm = "optimizing",
                      seed = SEED, prior_intercept = NULL, refresh = 0,
                      prior = NULL, prior_aux = NULL))
  SW(fitvb1 <- update(fito, algorithm = "meanfield", iter = ITER))
  SW(fitvb2 <- update(fito, algorithm = "fullrank", iter = ITER))

  SW(loo1 <- loo(fito))
  SW(loo2 <- loo(fitvb1))
  SW(loo3 <- loo(fitvb2))

  # expect_s3_class() matches the class checks used elsewhere in this file and
  # reports the actual class on failure.
  expect_s3_class(loo1, "importance_sampling_loo")
  expect_s3_class(loo2, "importance_sampling_loo")
  expect_s3_class(loo3, "importance_sampling_loo")
})
test_that("loo errors if model has weights", {
  # Weights alternate between 1 and 2 across the rows of mtcars.
  SW(
    fit <- stan_glm(
      mpg ~ wt,
      data = mtcars,
      weights = rep_len(c(1,2), nrow(mtcars)),
      seed = SEED,
      refresh = 0,
      iter = 50
    )
  )

  # The error must say loo is unsupported and must mention 'kfold'.
  expect_error(loo(fit), regexp = "not supported")
  expect_error(loo(fit), regexp = "'kfold'")
})
test_that("loo can handle empty interaction levels", {
  # Keep only the first five of the six group1 x group2 combinations, so one
  # interaction level has no observations.
  d <- expand.grid(group1 = c("A", "B"), group2 = c("a", "b", "c"))[1:5, ]
  d$y <- c(0, 1, 0, 1, 0)

  SW(
    fit <- rstanarm::stan_glm(
      y ~ group1:group2,
      data = d,
      family = "binomial",
      refresh = 0,
      iter = 20,
      chains = 1
    )
  )
  SW(loo1 <- loo(fit))

  expect_output(
    print(loo1),
    regexp = "Computed from 10 by 5 log-likelihood matrix"
  )
})
# loo with refitting ------------------------------------------------------
context("loo then refitting")

test_that("loo issues errors/warnings", {
  # Validation of the 'k_threshold' argument
  expect_warning(
    loo(example_model, cores = LOO.CORES, k_threshold = 2),
    regexp = "Setting 'k_threshold' > 1 is not recommended"
  )
  expect_error(
    loo(example_model, k_threshold = -1),
    regexp = "'k_threshold' < 0 not allowed."
  )
  expect_error(
    loo(example_model, k_threshold = 1:2),
    regexp = "'k_threshold' must be a single numeric value"
  )

  # Internal helpers that recommend kfold or reloo
  expect_warning(rstanarm:::recommend_kfold(5), regexp = "Found 5")
  expect_warning(rstanarm:::recommend_kfold(5), regexp = "10-fold")
  expect_warning(rstanarm:::recommend_reloo(7), regexp = "Found 7")
})
test_that("loo with k_threshold works", {
  SW(
    fit <- stan_glm(
      mpg ~ wt,
      prior = normal(0, 500),
      data = mtcars[25:32, ],
      seed = 12345,
      iter = 5,
      chains = 1,
      cores = 1,
      refresh = 0
    )
  )
  expect_message(loo(fit, k_threshold = 0.5), regexp = "Model will be refit")

  # A binomial model is exercised too: no errors should occur there, since
  # getting the data right internally in reloo is trickier (matrix outcome).
  SW(loo_x <- loo(example_model))
  expect_message(
    rstanarm:::reloo(example_model, loo_x, obs = 1),
    regexp = "Model will be refit 1 times"
  )
})
test_that("loo with k_threshold works for edge case(s)", {
  # Model fitted without a 'data' argument
  y <- mtcars$mpg[1:10]
  # NOTE(review): 'x' is not used again in this test; the draw does advance
  # the random number stream, so confirm before removing it.
  x <- rexp(length(y))
  SW(fit <- stan_glm(y ~ 1, refresh = 0, iter = 50))

  # A low k_threshold makes sure reloo is triggered.
  expect_message(
    SW(res <- loo(fit, k_threshold = 0.1, cores = LOO.CORES)),
    regexp = "problematic observation\\(s\\) found"
  )
  expect_s3_class(res, "loo")
})
# kfold -------------------------------------------------------------------
context("kfold")

test_that("kfold does not throw an error for non mcmc models", {
  # Model fitted by optimization rather than MCMC
  SW(fito <- stan_glm(mpg ~ wt, data = mtcars, algorithm = "optimizing",
                      seed = SEED, refresh = 0))
  SW(k <- kfold(fito, K = 2))

  # expect_s3_class() matches the class checks used elsewhere in this file and
  # reports the actual class on failure.
  expect_s3_class(k, "kfold")
})
test_that("kfold throws error if K <= 1 or K > N", {
  # fixed = TRUE: the expected messages are matched as literal text.
  expect_error(
    kfold(example_model, K = 1),
    regexp = "K > 1",
    fixed = TRUE
  )
  expect_error(
    kfold(example_model, K = 1e5),
    regexp = "K <= nobs(x)",
    fixed = TRUE
  )
})
test_that("kfold throws error if folds arg is bad", {
  # Messages are matched as literal text (fixed = TRUE).
  wrong_length_msg <- "length(folds) == N is not TRUE"
  non_integer_msg <- "all(folds == as.integer(folds)) is not TRUE"

  # 'folds' with the wrong length
  expect_error(
    kfold(example_model, K = 2, folds = 1:100),
    regexp = wrong_length_msg,
    fixed = TRUE
  )
  expect_error(
    kfold(example_model, K = 2, folds = 1:2),
    regexp = wrong_length_msg,
    fixed = TRUE
  )

  # 'folds' containing non-integer values
  expect_error(
    kfold(example_model, K = 2, folds = seq(1, 100, length.out = 56)),
    regexp = non_integer_msg,
    fixed = TRUE
  )
})
test_that("kfold throws error if model has weights", {
  # One random weight per row of mtcars, drawn from [0.5, 1.5]
  SW(
    fit <- stan_glm(
      mpg ~ wt,
      data = mtcars,
      iter = ITER,
      chains = CHAINS,
      refresh = 0,
      weights = runif(nrow(mtcars), 0.5, 1.5)
    )
  )

  expect_error(
    kfold(fit),
    regexp = "not currently available for models fit using weights"
  )
})
test_that("kfold works on some examples", {
  mtcars2 <- mtcars
  mtcars2$wt[1] <- NA # make sure kfold works if NAs are dropped from original data
  SW(
    fit_gaus <- stan_glm(mpg ~ wt, data = mtcars2, refresh = 0,
                         chains = 1, iter = 10)
  )
  SW(kf <- kfold(fit_gaus, 2))
  SW(kf2 <- kfold(example_model, 2))

  # Expected element and attribute names, shared by both kfold results
  kfold_names <- c("estimates", "pointwise", "elpd_kfold", "se_elpd_kfold",
                   "p_kfold", "se_p_kfold")
  kfold_attrs <- c("names", "class", "K", "dims", "model_name", "discrete",
                   "yhash", "formula")
  expect_named(kf, kfold_names)
  expect_named(kf2, kfold_names)
  expect_named(attributes(kf), kfold_attrs)
  expect_named(attributes(kf2), kfold_attrs)

  # One expectation per class: passing a vector of classes without
  # exact = TRUE only requires the object to inherit from one of them.
  expect_s3_class(kf, "kfold")
  expect_s3_class(kf, "loo")
  expect_s3_class(kf2, "kfold")
  expect_s3_class(kf2, "loo")

  expect_false(is.na(kf$p_kfold))
  expect_false(is.na(kf2$p_kfold))

  # With save_fits = TRUE the result also carries the per-fold fits and the
  # indices of the observations omitted in each fold.
  SW(kf <- kfold(fit_gaus, K = 2, save_fits = TRUE))
  expect_true("fits" %in% names(kf))
  expect_s3_class(kf$fits[[1, "fit"]], "stanreg")
  expect_type(kf$fits[[2, "omitted"]], "integer")
  expect_length(kf$fits[[2, "omitted"]], 16)
})
# loo_compare ----------------------------------------------------------
test_that("loo_compare throws correct errors", {
  # Four fits that differ in data, outcome transformation or family, plus
  # their loo/waic/kfold results.
  SW(capture.output({
    mtcars$mpg <- as.integer(mtcars$mpg)
    fit1 <- stan_glm(mpg ~ wt, data = mtcars, iter = 5, chains = 2,
                     refresh = 0)
    fit2 <- update(fit1, data = mtcars[-1, ])
    fit3 <- update(fit1, formula. = log(mpg) ~ .)
    fit4 <- update(fit1, family = poisson("log"))
    l1 <- loo(fit1, cores = LOO.CORES)
    l2 <- loo(fit2, cores = LOO.CORES)
    l3 <- loo(fit3, cores = LOO.CORES)
    l4 <- loo(fit4, cores = LOO.CORES)
    w1 <- waic(fit1)
    k1 <- kfold(fit1, K = 3)
  }))

  n_obs_msg <- "Not all models have the same number of data points"

  # this uses loo::loo_compare
  expect_error(loo_compare(l1, l2), regexp = n_obs_msg)
  expect_error(loo_compare(list(l4, l2, l3)), regexp = n_obs_msg)

  # using loo_compare.stanreg (can do extra checks)
  fit1$loo <- l1
  fit2$loo <- l2
  fit3$loo <- l3
  fit4$loo <- l4
  expect_error(loo_compare(fit1, fit2), regexp = n_obs_msg)
  expect_warning(
    loo_compare(fit1, fit3),
    regexp = "Not all models have the same y variable"
  )
  expect_error(
    loo_compare(fit1, fit4),
    regexp = "Discrete and continuous observation models can't be compared"
  )
  expect_error(
    loo_compare(l1, fit1),
    regexp = "All inputs should have class 'loo'"
  )
  expect_error(
    loo_compare(l1),
    regexp = "requires at least two models"
  )
})
test_that("loo_compare works", {
  # Three gaussian fits with a growing set of predictors and two count-family
  # fits, each with a stored loo result, plus a kfold result per fit.
  suppressWarnings(capture.output({
    mtcars$mpg <- as.integer(mtcars$mpg)
    fit1 <- stan_glm(mpg ~ wt, data = mtcars, iter = 40, chains = 2,
                     refresh = 0)
    fit2 <- update(fit1, formula. = . ~ . + cyl)
    fit3 <- update(fit2, formula. = . ~ . + gear)
    fit4 <- update(fit1, family = "poisson")
    fit5 <- update(fit1, family = "neg_binomial_2")

    fit1$loo <- loo(fit1, cores = LOO.CORES)
    fit2$loo <- loo(fit2, cores = LOO.CORES)
    fit3$loo <- loo(fit3, cores = LOO.CORES)
    fit4$loo <- loo(fit4, cores = LOO.CORES)
    fit5$loo <- loo(fit5, cores = LOO.CORES)

    k1 <- kfold(fit1, K = 2)
    k2 <- kfold(fit2, K = 2)
    k3 <- kfold(fit3, K = 3)
    k4 <- kfold(fit4, K = 2)
    k5 <- kfold(fit5, K = 2)
  }))

  # The 'discrete' attribute is FALSE for fit1-fit3 and TRUE for the
  # poisson and neg_binomial_2 fits.
  expect_false(attr(fit1$loo, "discrete"))
  expect_false(attr(fit2$loo, "discrete"))
  expect_false(attr(fit3$loo, "discrete"))
  expect_true(attr(fit4$loo, "discrete"))
  expect_true(attr(fit5$loo, "discrete"))

  comp1 <- loo_compare(fit1, fit2)
  comp2 <- loo_compare(fit1, fit2, fit3)
  expect_s3_class(comp1, "compare.loo")
  expect_s3_class(comp2, "compare.loo")

  # Comparing the fits must give the same elpd differences as comparing the
  # stored loo objects directly.
  expect_equal(
    comp1[, "elpd_diff"],
    loo_compare(list(fit1$loo, fit2$loo))[, "elpd_diff"]
  )
  expect_equal(
    comp2[, "elpd_diff"],
    loo_compare(list(fit1$loo, fit2$loo, fit3$loo))[, "elpd_diff"]
  )

  comp1_detail <- loo_compare(fit1, fit2, detail = TRUE)
  expect_output(print(comp1_detail), regexp = "Model formulas")

  # equivalent to stanreg_list method
  expect_equivalent(comp2, loo_compare(stanreg_list(fit1, fit2, fit3)))

  # for kfold
  expect_warning(
    comp3 <- loo_compare(k1, k2, k3),
    regexp = "Not all kfold objects have the same K value"
  )
  expect_true(attr(k4, "discrete"))
  expect_true(attr(k5, "discrete"))
  expect_s3_class(loo_compare(k4, k5), "compare.loo")
})
# helpers -----------------------------------------------------------------
context("loo and waic helpers")

test_that("kfold_and_reloo_data works", {
  f <- rstanarm:::kfold_and_reloo_data

  # For the example model the helper should return the matching columns of
  # lme4::cbpp.
  d <- f(example_model)
  expect_identical(d, lme4::cbpp[, colnames(d)])

  # if 'data' arg not originally specified when fitting the model
  y <- rnorm(40)
  SW(fit <- stan_glm(y ~ 1, iter = ITER, chains = CHAINS, refresh = 0))
  expect_equivalent(f(fit), model.frame(fit))

  # if 'subset' arg specified when fitting the model
  SW(fit2 <- stan_glm(mpg ~ wt, data = mtcars, subset = gear != 5, iter = ITER,
                      chains = CHAINS, refresh = 0))
  # The expected rows and columns are selected by indexing directly; no extra
  # subset() call is needed around an already-filtered data frame.
  expect_equivalent(f(fit2), mtcars[mtcars$gear != 5, c("mpg", "wt")])
})
test_that(".weighted works", {
  f <- rstanarm:::.weighted

  # NULL weights leave the value unchanged.
  expect_equal(f(2, NULL), 2)

  # Otherwise the result equals the product of the two arguments.
  expect_equal(f(2, 3), 6)
  expect_equal(f(8, 0.25), 2)

  # Omitting the second argument is an error.
  expect_error(f(2), regexp = "missing, with no default")
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.