# tests/testthat/test-mcmc-nuts.R

library(bayesplot)
context("MCMC: nuts")

# Shared fixtures: fit one small model up front and extract the NUTS
# diagnostics and log-posterior draws used by the tests below. Everything is
# skipped when the optional rstanarm dependency is missing; the tests guard
# against that case with skip_if_not_installed(). No `seed` is passed to
# stan_glm() here.
if (requireNamespace("rstanarm", quietly = TRUE)) {
  ITER <- 1000
  CHAINS <- 3
  fit <- rstanarm::stan_glm(
    mpg ~ wt + am,
    data = mtcars,
    iter = ITER,
    chains = CHAINS,
    refresh = 0
  )
  np <- nuts_params(fit)
  lp <- log_posterior(fit)
}

# Every non-energy NUTS plot should come back as a gtable, both for all chains
# and when a single chain is highlighted via `chain`.
test_that("all mcmc_nuts_* (except energy) return gtable objects", {
  skip_if_not_installed("rstanarm")

  # These three plots can use the NUTS parameters as extracted.
  for (plot_fun in list(mcmc_nuts_acceptance,
                        mcmc_nuts_treedepth,
                        mcmc_nuts_stepsize)) {
    expect_gtable(plot_fun(np, lp))
    expect_gtable(plot_fun(np, lp, chain = CHAINS))
  }

  # The divergence plot is checked on parameters passed through
  # ensure_divergences() first.
  np_div <- ensure_divergences(np)
  expect_gtable(mcmc_nuts_divergence(np_div, lp))
  expect_gtable(mcmc_nuts_divergence(np_div, lp, chain = CHAINS))
})

# An out-of-range `chain` (too large or zero) must raise an informative error
# in each of the non-energy NUTS plotting functions.
test_that("all mcmc_nuts_* (except energy) error if chain argument is bad", {
  skip_if_not_installed("rstanarm")

  suffixes <- c("acceptance", "divergence", "treedepth", "stepsize")
  too_many_msg <- paste("only", CHAINS, "chains found")
  too_few_msg <- "chain >= 1"

  for (fun_name in paste0("mcmc_nuts_", suffixes)) {
    # One past the number of chains that were actually run.
    args_too_many <- list(x = np, lp = lp, chain = CHAINS + 1)
    expect_error(
      do.call(fun_name, args_too_many),
      regexp = too_many_msg,
      info = fun_name
    )

    # Chain indices start at 1, so 0 is invalid.
    args_zero <- list(x = np, lp = lp, chain = 0)
    expect_error(
      do.call(fun_name, args_zero),
      regexp = too_few_msg,
      info = fun_name
    )
  }
})

# The energy plot is a ggplot: faceted by chain by default, unfaceted when the
# chains are merged.
test_that("mcmc_nuts_energy returns a ggplot object", {
  skip_if_not_installed("rstanarm")

  by_chain <- mcmc_nuts_energy(np)
  expect_gg(by_chain)
  expect_s3_class(by_chain$facet, "FacetWrap")
  facet_vars <- names(by_chain$facet$params$facets)
  expect_equal(facet_vars, "Chain")

  merged <- mcmc_nuts_energy(np, merge_chains = TRUE)
  expect_gg(merged)
  expect_s3_class(merged$facet, "FacetNull")
})

# Passing `chain` to the energy plot should warn that the argument is ignored.
test_that("mcmc_nuts_energy throws correct warnings", {
  skip_if_not_installed("rstanarm")

  expect_warning(
    mcmc_nuts_energy(np, chain = 1),
    regexp = "ignored: chain"
  )
})


# validate_nuts_data_frame() must reject malformed NUTS-parameter and
# log-posterior inputs with specific error messages.
test_that("validate_nuts_data_frame throws errors", {
  skip_if_not_installed("rstanarm")

  # NUTS parameters supplied as a plain list instead of a data frame.
  not_a_df <- list(Iteration = 1, Chain = 1)
  expect_error(
    validate_nuts_data_frame(not_a_df),
    regexp = "NUTS parameters should be in a data frame"
  )

  # Data frame lacking the required columns.
  wrong_cols <- data.frame(Iteration = 1, apple = 2)
  expect_error(
    validate_nuts_data_frame(wrong_cols),
    regexp = "NUTS parameter data frame must have columns: Chain, Iteration, Parameter, Value"
  )

  # lp supplied as a matrix instead of a data frame.
  lp_matrix <- as.matrix(lp)
  expect_error(
    validate_nuts_data_frame(np, lp_matrix),
    regexp = "lp should be in a data frame"
  )

  # lp whose third column has been renamed, so a required name is missing.
  lp_renamed <- lp
  names(lp_renamed)[3] <- "Chains"
  expect_error(
    validate_nuts_data_frame(np, lp_renamed),
    regexp = "lp data frame must have columns: Chain, Iteration, Value"
  )

  # lp restricted to chains 1 and 2 while np still holds every chain.
  lp_two_chains <- lp[lp$Chain %in% 1:2, ]
  expect_error(
    validate_nuts_data_frame(np, lp_two_chains),
    regexp = "Number of chains"
  )
})



# Visual tests -----------------------------------------------------------------

# Load the shared fixtures for the snapshot tests below. Presumably this file
# defines `vdiff_dframe_chains_np` and `vdiff_dframe_chains_lp`, which are used
# by the visual tests but not created in this file -- TODO confirm.
source(test_path("data-for-mcmc-tests.R"))

# Snapshot checks for the acceptance plot: all chains, then chain 1 highlighted.
test_that("mcmc_nuts_acceptance renders correctly", {
  skip_on_cran()
  skip_if_not_installed("vdiffr")

  vdiffr::expect_doppelganger(
    title = "mcmc_nuts_acceptance (default)",
    fig = mcmc_nuts_acceptance(vdiff_dframe_chains_np, vdiff_dframe_chains_lp)
  )

  vdiffr::expect_doppelganger(
    title = "mcmc_nuts_acceptance (chain)",
    fig = mcmc_nuts_acceptance(
      vdiff_dframe_chains_np,
      vdiff_dframe_chains_lp,
      chain = 1
    )
  )
})

# Snapshot checks for the divergence plot: all chains, then chain 1 highlighted.
test_that("mcmc_nuts_divergence renders correctly", {
  skip_on_cran()
  skip_if_not_installed("vdiffr")

  vdiffr::expect_doppelganger(
    title = "mcmc_nuts_divergence (default)",
    fig = mcmc_nuts_divergence(vdiff_dframe_chains_np, vdiff_dframe_chains_lp)
  )

  vdiffr::expect_doppelganger(
    title = "mcmc_nuts_divergence (chain)",
    fig = mcmc_nuts_divergence(
      vdiff_dframe_chains_np,
      vdiff_dframe_chains_lp,
      chain = 1
    )
  )
})

# Snapshot checks for the treedepth plot: all chains, then chain 1 highlighted.
test_that("mcmc_nuts_treedepth renders correctly", {
  skip_on_cran()
  skip_if_not_installed("vdiffr")

  vdiffr::expect_doppelganger(
    title = "mcmc_nuts_treedepth (default)",
    fig = mcmc_nuts_treedepth(vdiff_dframe_chains_np, vdiff_dframe_chains_lp)
  )

  vdiffr::expect_doppelganger(
    title = "mcmc_nuts_treedepth (chain)",
    fig = mcmc_nuts_treedepth(
      vdiff_dframe_chains_np,
      vdiff_dframe_chains_lp,
      chain = 1
    )
  )
})

# Snapshot checks for the stepsize plot: all chains, then chain 1 highlighted.
test_that("mcmc_nuts_stepsize renders correctly", {
  skip_on_cran()
  skip_if_not_installed("vdiffr")

  vdiffr::expect_doppelganger(
    title = "mcmc_nuts_stepsize (default)",
    fig = mcmc_nuts_stepsize(vdiff_dframe_chains_np, vdiff_dframe_chains_lp)
  )

  vdiffr::expect_doppelganger(
    title = "mcmc_nuts_stepsize (chain)",
    fig = mcmc_nuts_stepsize(
      vdiff_dframe_chains_np,
      vdiff_dframe_chains_lp,
      chain = 1
    )
  )
})

# Snapshot checks for the energy plot: per-chain facets, then merged chains.
# NOTE(review): other calls in this file invoke mcmc_nuts_energy() with only
# the NUTS parameters; the second positional argument passed here may simply
# be swallowed by `...` -- confirm it is intentional.
test_that("mcmc_nuts_energy renders correctly", {
  skip_on_cran()
  skip_if_not_installed("vdiffr")

  vdiffr::expect_doppelganger(
    title = "mcmc_nuts_energy (default)",
    fig = mcmc_nuts_energy(
      vdiff_dframe_chains_np,
      vdiff_dframe_chains_lp,
      binwidth = 10
    )
  )

  vdiffr::expect_doppelganger(
    title = "mcmc_nuts_energy (merged)",
    fig = mcmc_nuts_energy(
      vdiff_dframe_chains_np,
      vdiff_dframe_chains_lp,
      binwidth = 10,
      merge_chains = TRUE
    )
  )
})

# Try the bayesplot package in your browser
#
# Any scripts or data that you put into this service are public.
#
# bayesplot documentation built on Nov. 17, 2022, 1:08 a.m.