tests/testthat/test_plots.R

# Part of the rstanarm package for estimating model parameters
# Copyright (C) 2015, 2016, 2017 Trustees of Columbia University
# 
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 3
# of the License, or (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


suppressPackageStartupMessages(library(rstanarm))
SEED <- 123
ITER <- 10
CHAINS <- 2
CORES <- 1

if (!exists("example_model")) {
  example_model <- run_example_model()
}
fit <- example_model
SW(fito <- stan_glm(mpg ~ ., data = mtcars, algorithm = "optimizing", seed = SEED, refresh = 0))
SW(fitvb <- update(fito, algorithm = "meanfield"))

# plot.stanreg ------------------------------------------------------------
context("plot.stanreg")
test_that("plot.stanreg errors if chains = 1 but needs multiple", {
  multiple_chain_plots <- c("trace_highlight",
                            "hist_by_chain",
                            "dens_overlay",
                            "violin")
  SW(fit_1chain <- stan_glm(mpg ~ wt, data = mtcars, chains = 1, iter = 100, refresh = 0))
  for (f in multiple_chain_plots) {
    expect_error(plot(fit_1chain, plotfun = f), info = f, 
                 regexp = "requires multiple chains")
  }
})

test_that("other plot.stanreg errors thrown correctly", {
  expect_error(plot(fit, plotfun = "9999"), 
               "not a valid MCMC function name")
  expect_error(plot(fit, plotfun = "ppc_hist"), 
               "use the 'pp_check' method")
  expect_error(plot(fit, plotfun = "stan_diag"), 
               "help('NUTS', 'bayesplot')", fixed = TRUE)
})

test_that("plot.stanreg returns correct object", {
  # ggplot objects
  ggplot_object_plots <- c(
    "intervals", "areas",
    "dens", "dens_overlay", 
    "hist", "hist_by_chain",
    "trace", "trace_highlight",
    "violin", 
    "rhat", "rhat_hist", 
    "neff", "neff_hist", "ess",
    "acf", "acf_bar", "ac"
  )
  for (f in ggplot_object_plots)
    expect_gg(plot(fit, f))
  
  # requires exactly 2 parameters
  expect_gg(plot(fit, "scat", pars = c("period2", "period3")))
})

test_that("plot method returns correct object for nuts diagnostic plots", {
  # energy plot returns ggplot object
  expect_gg(plot(fit, "nuts_energy"))
  
  # others return gtable objects
  gtable_object_plots <-
    paste0("nuts_",
           c("stepsize", "acceptance", "divergence", "treedepth"))
  for (f in gtable_object_plots)
    expect_s3_class(plot(fit, plotfun = f), "gtable")
})

test_that("plot.stanreg ok for optimization", {
  expect_gg(plot(fito))
  expect_gg(plot(fito, "areas"))
  expect_gg(plot(fito, "dens"))
  expect_gg(plot(fito, "scatter", pars = c("wt", "cyl")))
  expect_gg(plot(fito, pars = c("alpha", "beta")))
  
  expect_warning(plot(fito, regex_pars = "wt"), 
                 regexp = "'regex_pars' ignored")
  expect_error(plot(fito, "trace"), 
               regexp = "only available for models fit using MCMC")
  expect_error(plot(fito, "nuts_acceptance"), 
               regexp = "only available for models fit using MCMC")
  expect_error(plot(fito, "rhat_hist"), 
               regexp = "only available for models fit using MCMC")
})

test_that("plot.stanreg ok for vb", {
  expect_gg(plot(fitvb))
  expect_gg(plot(fitvb, "areas"))
  expect_gg(plot(fitvb, "dens"))
  expect_gg(plot(fitvb, "scatter", pars = c("wt", "cyl")))
  expect_gg(plot(fitvb, pars = c("alpha", "beta")))
  
  expect_error(plot(fitvb, "trace"), 
               regexp = "only available for models fit using MCMC")
  expect_error(plot(fitvb, "nuts_acceptance"), 
               regexp = "only available for models fit using MCMC")
  expect_error(plot(fitvb, "rhat_hist"), 
               regexp = "only available for models fit using MCMC")
  expect_error(plot(fitvb, "mcmc_neff"), 
               regexp = "only available for models fit using MCMC")
})


# pairs.stanreg -----------------------------------------------------------
context("pairs.stanreg")
test_that("pairs method ok", {
  expect_silent(pairs(fit, pars = c("period2", "log-posterior")))
  expect_silent(pairs(fit, pars = "b[(Intercept) herd:15]", regex_pars = "Sigma"))
  expect_silent(pairs(fit, pars = "b[(Intercept) herd:15]", regex_pars = "Sigma", 
                      condition = pairs_condition(nuts = "lp__")))
  expect_error(pairs(fitvb), regexp = "only available for models fit using MCMC")
  expect_error(pairs(fito), regexp = "only available for models fit using MCMC")
})



# posterior_vs_prior ------------------------------------------------------
context("posterior_vs_prior")
test_that("posterior_vs_prior ok", {
  SW(p1 <- posterior_vs_prior(fit, pars = "beta"))
  expect_gg(p1)
  
  SW(p2 <- posterior_vs_prior(fit, pars = "varying", group_by_parameter = TRUE, 
                              color_by = "vs"))
  expect_gg(p2)
  SW(p3 <- posterior_vs_prior(fit, regex_pars = "period", 
                              group_by_parameter = FALSE, 
                              color_by = "none", 
                              facet_args = list(scales = "free", nrow = 2)))
  expect_gg(p3)
  
  SW(fit_polr <- stan_polr(tobgp ~ agegp, data = esoph, method = "probit",
                          prior = R2(0.2, "mean"), init_r = 0.1, 
                          seed = SEED, chains = CHAINS, cores = CORES, 
                          iter = 100, refresh = 0))
  SW(p4 <- posterior_vs_prior(fit_polr))
  SW(p5 <- posterior_vs_prior(fit_polr, regex_pars = "\\|", 
                              group_by_parameter = TRUE, 
                              color_by = "vs"))
  expect_gg(p4)
  expect_gg(p5)
})

test_that("posterior_vs_prior throws errors", {
  lmfit <- lm(mpg ~ wt, data = mtcars)
  expect_error(posterior_vs_prior(lmfit), "no applicable method")
  expect_error(posterior_vs_prior(fit, prob = 1), "prob < 1")
  expect_error(posterior_vs_prior(fito), 
               "only available for models fit using MCMC")
  expect_error(posterior_vs_prior(fitvb), 
               "only available for models fit using MCMC")
})

Try the rstanarm package in your browser

Any scripts or data that you put into this service are public.

rstanarm documentation built on Sept. 14, 2023, 1:07 a.m.