# Part of the rstanarm package for estimating model parameters
# Copyright (C) 2015, 2016, 2017 Trustees of Columbia University
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
suppressPackageStartupMessages(library(rstanarm))
# Constants for the model fits in this file. ITER is not referenced in the
# visible part of the file.
SEED <- 123
ITER <- 10
CHAINS <- 2
CORES <- 1

# Reuse `example_model` when it already exists; otherwise create it once via
# run_example_model().
if (!exists("example_model")) {
  example_model <- run_example_model()
}
fit <- example_model

# Two non-MCMC fits of the same model: one by optimization, one refit with the
# "meanfield" algorithm.
# NOTE(review): SW() is defined outside this file -- presumably it silences
# warnings/console output from the wrapped expression; confirm.
SW(
  fito <- stan_glm(
    mpg ~ .,
    data = mtcars,
    algorithm = "optimizing",
    seed = SEED,
    refresh = 0
  )
)
SW(fitvb <- update(fito, algorithm = "meanfield"))
# plot.stanreg ------------------------------------------------------------
context("plot.stanreg")

# Plot functions listed here are expected to reject a fit that has only one
# chain with a "requires multiple chains" error.
test_that("plot.stanreg errors if chains = 1 but needs multiple", {
  multiple_chain_plots <- c(
    "trace_highlight",
    "hist_by_chain",
    "dens_overlay",
    "violin"
  )
  # Seeded like the other fits in this file so the single-chain fit is the
  # same from run to run.
  SW(
    fit_1chain <- stan_glm(
      mpg ~ wt,
      data = mtcars,
      chains = 1,
      iter = 100,
      seed = SEED,
      refresh = 0
    )
  )
  for (f in multiple_chain_plots) {
    # `info = f` names the plot function in the report if the expectation fails.
    expect_error(
      plot(fit_1chain, plotfun = f),
      regexp = "requires multiple chains",
      info = f
    )
  }
})
# Each unsupported `plotfun` value below is expected to raise an error whose
# message matches the given pattern.
test_that("other plot.stanreg errors thrown correctly", {
  # A name that is not a plotting function at all.
  expect_error(
    plot(fit, plotfun = "9999"),
    regexp = "not a valid MCMC function name"
  )
  # A ppc_ name passed to plot().
  expect_error(
    plot(fit, plotfun = "ppc_hist"),
    regexp = "use the 'pp_check' method"
  )
  # The expected text contains parentheses, so it is matched literally.
  expect_error(
    plot(fit, plotfun = "stan_diag"),
    regexp = "help('NUTS', 'bayesplot')",
    fixed = TRUE
  )
})
# Every plot function named here is checked with expect_gg() on the MCMC fit.
test_that("plot.stanreg returns correct object", {
  gg_plotfuns <- c(
    "intervals",
    "areas",
    "dens",
    "dens_overlay",
    "hist",
    "hist_by_chain",
    "trace",
    "trace_highlight",
    "violin",
    "rhat",
    "rhat_hist",
    "neff",
    "neff_hist",
    "ess",
    "acf",
    "acf_bar",
    "ac"
  )
  for (plotfun_name in gg_plotfuns) {
    expect_gg(plot(fit, plotfun_name))
  }

  # The scatter plot needs exactly two parameters, so it is called separately.
  expect_gg(plot(fit, "scat", pars = c("period2", "period3")))
})
# NUTS diagnostic plots: the energy plot is checked with expect_gg(), the
# remaining four are checked for class "gtable".
test_that("plot method returns correct object for nuts diagnostic plots", {
  expect_gg(plot(fit, "nuts_energy"))

  gtable_plotfuns <- c(
    "nuts_stepsize",
    "nuts_acceptance",
    "nuts_divergence",
    "nuts_treedepth"
  )
  for (plotfun_name in gtable_plotfuns) {
    expect_s3_class(plot(fit, plotfun = plotfun_name), "gtable")
  }
})
# Plots of the optimization fit: the supported ones are checked with
# expect_gg(); `regex_pars` should warn; MCMC-only plots should error.
test_that("plot.stanreg ok for optimization", {
  mcmc_only_msg <- "only available for models fit using MCMC"

  # Supported plots
  expect_gg(plot(fito))
  expect_gg(plot(fito, "areas"))
  expect_gg(plot(fito, "dens"))
  expect_gg(plot(fito, "scatter", pars = c("wt", "cyl")))
  expect_gg(plot(fito, pars = c("alpha", "beta")))

  # `regex_pars` is expected to be ignored with a warning
  expect_warning(
    plot(fito, regex_pars = "wt"),
    regexp = "'regex_pars' ignored"
  )

  # Plots that need MCMC draws
  expect_error(plot(fito, "trace"), regexp = mcmc_only_msg)
  expect_error(plot(fito, "nuts_acceptance"), regexp = mcmc_only_msg)
  expect_error(plot(fito, "rhat_hist"), regexp = mcmc_only_msg)
})
# Plots of the "meanfield" fit: the supported ones are checked with
# expect_gg(); MCMC-only plots should error.
test_that("plot.stanreg ok for vb", {
  mcmc_only_msg <- "only available for models fit using MCMC"

  # Supported plots
  expect_gg(plot(fitvb))
  expect_gg(plot(fitvb, "areas"))
  expect_gg(plot(fitvb, "dens"))
  expect_gg(plot(fitvb, "scatter", pars = c("wt", "cyl")))
  expect_gg(plot(fitvb, pars = c("alpha", "beta")))

  # Plots that need MCMC draws
  expect_error(plot(fitvb, "trace"), regexp = mcmc_only_msg)
  expect_error(plot(fitvb, "nuts_acceptance"), regexp = mcmc_only_msg)
  expect_error(plot(fitvb, "rhat_hist"), regexp = mcmc_only_msg)
  expect_error(plot(fitvb, "mcmc_neff"), regexp = mcmc_only_msg)
})
# pairs.stanreg -----------------------------------------------------------
context("pairs.stanreg")

# pairs() should run silently on the MCMC fit and error for the two
# non-MCMC fits.
test_that("pairs method ok", {
  herd_par <- "b[(Intercept) herd:15]"
  mcmc_only_msg <- "only available for models fit using MCMC"

  expect_silent(pairs(fit, pars = c("period2", "log-posterior")))
  expect_silent(pairs(fit, pars = herd_par, regex_pars = "Sigma"))
  expect_silent(
    pairs(
      fit,
      pars = herd_par,
      regex_pars = "Sigma",
      condition = pairs_condition(nuts = "lp__")
    )
  )

  expect_error(pairs(fitvb), regexp = mcmc_only_msg)
  expect_error(pairs(fito), regexp = mcmc_only_msg)
})
# posterior_vs_prior ------------------------------------------------------
context("posterior_vs_prior")

# posterior_vs_prior() is called with several argument combinations on the
# example model and on a stan_polr fit; each result is checked with expect_gg().
test_that("posterior_vs_prior ok", {
  # Assignments stay inside SW() because SW()'s return value is not visible
  # from this file.
  SW(plot_beta <- posterior_vs_prior(fit, pars = "beta"))
  expect_gg(plot_beta)

  SW(
    plot_varying <- posterior_vs_prior(
      fit,
      pars = "varying",
      group_by_parameter = TRUE,
      color_by = "vs"
    )
  )
  expect_gg(plot_varying)

  SW(
    plot_period <- posterior_vs_prior(
      fit,
      regex_pars = "period",
      group_by_parameter = FALSE,
      color_by = "none",
      facet_args = list(scales = "free", nrow = 2)
    )
  )
  expect_gg(plot_period)

  # A stan_polr fit, plotted with defaults and then restricted to parameters
  # whose names contain a literal "|".
  SW(
    fit_polr <- stan_polr(
      tobgp ~ agegp,
      data = esoph,
      method = "probit",
      prior = R2(0.2, "mean"),
      init_r = 0.1,
      seed = SEED,
      chains = CHAINS,
      cores = CORES,
      iter = 100,
      refresh = 0
    )
  )
  SW(plot_polr_default <- posterior_vs_prior(fit_polr))
  SW(
    plot_polr_bar <- posterior_vs_prior(
      fit_polr,
      regex_pars = "\\|",
      group_by_parameter = TRUE,
      color_by = "vs"
    )
  )
  expect_gg(plot_polr_default)
  expect_gg(plot_polr_bar)
})
# posterior_vs_prior() should error for a non-stanreg object, for prob = 1,
# and for the two non-MCMC fits.
test_that("posterior_vs_prior throws errors", {
  mcmc_only_msg <- "only available for models fit using MCMC"

  # A plain lm object has no posterior_vs_prior method.
  lm_fit <- lm(mpg ~ wt, data = mtcars)
  expect_error(posterior_vs_prior(lm_fit), regexp = "no applicable method")

  expect_error(posterior_vs_prior(fit, prob = 1), regexp = "prob < 1")

  expect_error(posterior_vs_prior(fito), regexp = mcmc_only_msg)
  expect_error(posterior_vs_prior(fitvb), regexp = mcmc_only_msg)
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.