# tests/testthat/test-mcmc-scatter-and-parcoord.R

library(bayesplot)
context("MCMC: scatter, hex, and parallel coordinates plots")

source(test_path("data-for-mcmc-tests.R"))

if (requireNamespace("rstanarm", quietly = TRUE)) {
  suppressPackageStartupMessages(library(rstanarm))
  # Fit a small rstanarm model whose draws (`post`), log-posterior (`lp`) and
  # NUTS parameters (`np`) are shared by the NUTS-related tests below
  # (mcmc_scatter, mcmc_pairs, mcmc_parcoord). A fixed seed keeps these
  # fixtures identical from run to run, so any failure can be reproduced.
  fit <- stan_glm(
    mpg ~ wt + am,
    data = mtcars,
    iter = 1000,
    chains = 2,
    refresh = 0,
    seed = 1234
  )
  post <- as.array(fit)
  lp <- log_posterior(fit)
  # ensure_divergences() is defined in data-for-mcmc-tests.R; presumably it
  # guarantees that `np` contains some divergent transitions -- TODO confirm
  np <- ensure_divergences(nuts_params(fit))
}


# mcmc_scatter/hex --------------------------------------------------------
test_that("mcmc_scatter returns a ggplot object", {
  # parameter pairs reused across the different input formats
  beta_pair <- c("beta[1]", "beta[2]")
  theta_pair <- c("theta[1]", "theta[2]")
  sigma_intercept <- c("sigma", "(Intercept)")

  # 3-D arrays: several chains, then a single chain selected by regex
  expect_gg(mcmc_scatter(arr, pars = beta_pair))
  expect_gg(mcmc_scatter(arr1chain, regex_pars = "beta", size = 3, alpha = 0.5))

  # draws-array fixture
  expect_gg(mcmc_scatter(drawsarr, pars = theta_pair))

  # matrix and data frame inputs
  expect_gg(mcmc_scatter(mat, pars = sigma_intercept))
  expect_gg(mcmc_scatter(dframe, regex_pars = "x:[2,4]"))
  expect_gg(mcmc_scatter(dframe_multiple_chains, pars = sigma_intercept))
})

test_that("mcmc_scatter throws error if number of parameters is not 2", {
  wrong_count <- "exactly 2 parameters"

  # too many / too few parameters selected explicitly
  expect_error(
    mcmc_scatter(arr, pars = c("sigma", "beta[1]", "beta[2]")),
    wrong_count
  )
  expect_error(mcmc_scatter(arr, pars = "sigma"), wrong_count)
  expect_error(mcmc_scatter(drawsarr, pars = "mu"), wrong_count)

  # fixtures without explicit `pars` whose parameter count is not 2
  # (presumably single-parameter objects, judging by the `*1` names)
  expect_error(mcmc_scatter(arr1), wrong_count)
  expect_error(mcmc_scatter(drawsarr1), wrong_count)
  expect_error(mcmc_scatter(mat1), wrong_count)
})

test_that("mcmc_scatter accepts NUTS info", {
  skip_if_not_installed("rstanarm")
  wt_sigma <- c("wt", "sigma")

  # default divergence styling
  expect_gg(mcmc_scatter(post, pars = wt_sigma, np = np))

  # custom divergence styling; the plot data carries a `Divergent` column
  custom_style <- scatter_style_np(
    div_color = "orange",
    div_size = 2,
    div_shape = 3,
    div_alpha = 0.5
  )
  plot_np <- mcmc_scatter(post, pars = wt_sigma, np = np,
                          np_style = custom_style)
  expect_gg(plot_np)
  expect_named(plot_np$data, c("x", "y", "Divergent"))
})

test_that("mcmc_hex returns a ggplot object", {
  # the hex tests are skipped when hexbin is unavailable
  skip_if_not_installed("hexbin")

  # multi-chain array with explicitly named parameters
  expect_gg(mcmc_hex(arr, pars = c("beta[1]", "beta[2]")))
  # single-chain array selected by regex, with a custom bin width
  expect_gg(mcmc_hex(arr1chain, regex_pars = "beta", binwidth = c(0.5, 0.5)))
  # draws-array fixture
  expect_gg(mcmc_hex(drawsarr, pars = c("theta[1]", "theta[2]")))
})

test_that("mcmc_hex throws error if number of parameters is not 2", {
  skip_if_not_installed("hexbin")
  wrong_count <- "exactly 2 parameters"

  # too many / too few parameters selected explicitly
  expect_error(
    mcmc_hex(arr, pars = c("sigma", "beta[1]", "beta[2]")),
    wrong_count
  )
  expect_error(mcmc_hex(arr, pars = "sigma"), wrong_count)
  expect_error(mcmc_hex(drawsarr, pars = "mu"), wrong_count)

  # fixtures without explicit `pars` whose parameter count is not 2; these are
  # the same inputs used in the mcmc_scatter error test, including the
  # draws-array case, so both functions are held to the same contract
  expect_error(mcmc_hex(arr1), wrong_count)
  expect_error(mcmc_hex(drawsarr1), wrong_count)
  expect_error(mcmc_hex(mat1), wrong_count)
})



# mcmc_pairs  -------------------------------------------------------------
test_that("mcmc_pairs returns a bayesplot_grid object", {
  intercept_sigma <- c("(Intercept)", "sigma")

  grid <- mcmc_pairs(arr, pars = intercept_sigma)
  expect_bayesplot_grid(grid)
  # print() and plot() on the grid are expected to give equal results
  expect_equal(print(grid), plot(grid))

  # explicit and regex parameter selection combined
  expect_bayesplot_grid(mcmc_pairs(arr, pars = "sigma", regex_pars = "beta"))
  expect_bayesplot_grid(mcmc_pairs(drawsarr, pars = "mu", regex_pars = "theta"))

  # these inputs may warn (e.g. about having a single chain); only the
  # returned class matters here, so warnings are silenced
  expect_bayesplot_grid(suppressWarnings(
    mcmc_pairs(arr1chain, regex_pars = "beta")
  ))
  expect_bayesplot_grid(suppressWarnings(
    mcmc_pairs(drawsarr1chain, regex_pars = "theta")
  ))
  expect_bayesplot_grid(suppressWarnings(
    mcmc_pairs(mat, pars = intercept_sigma)
  ))
  expect_bayesplot_grid(suppressWarnings(
    mcmc_pairs(dframe, pars = intercept_sigma)
  ))

  expect_bayesplot_grid(mcmc_pairs(dframe_multiple_chains, regex_pars = "beta"))
})

test_that("mcmc_pairs using hexbin works", {
  skip_if_not_installed("hexbin")

  # untrimmed density diagonals with hex-binned off-diagonals, applied to
  # exp-transformed parameters
  hex_grid <- mcmc_pairs(
    arr,
    regex_pars = "x:[1-3]",
    transformations = "exp",
    diag_fun = "dens",
    diag_args = list(trim = FALSE),
    off_diag_fun = "hex",
    off_diag_args = list(binwidth = c(0.5, 0.5))
  )
  expect_bayesplot_grid(hex_grid)
})

test_that("no mcmc_pairs non-NUTS 'condition's fail", {
  # one logical entry per draw: iterations x chains of `arr`
  n_draws <- prod(dim(arr)[1:2])

  # chains given as a list of two groups
  expect_bayesplot_grid(
    mcmc_pairs(arr, pars = "sigma", regex_pars = "beta",
               condition = pairs_condition(chains = list(1, 2:4)))
  )
  # explicit logical selection of draws, alternating TRUE/FALSE
  expect_bayesplot_grid(
    mcmc_pairs(arr, pars = "sigma", regex_pars = "beta",
               condition = pairs_condition(
                 draws = rep(c(TRUE, FALSE), length.out = n_draws)
               ))
  )
  # a proportion of draws
  expect_bayesplot_grid(
    mcmc_pairs(arr, pars = "sigma", regex_pars = "beta",
               condition = pairs_condition(draws = 1/3))
  )
  # chains given as a single vector
  expect_bayesplot_grid(
    mcmc_pairs(arr, pars = "sigma", regex_pars = "beta",
               condition = pairs_condition(chains = c(1, 3)))
  )
})

test_that("mcmc_pairs works with NUTS info", {
  skip_if_not_installed("rstanarm")
  wt_am <- c("wt", "am")

  expect_bayesplot_grid(mcmc_pairs(post, pars = c(wt_am, "sigma"), np = np))

  # conditioning on NUTS quantities rather than on chains or draws
  expect_bayesplot_grid(mcmc_pairs(
    post, pars = wt_am, np = np,
    condition = pairs_condition(nuts = "energy__")
  ))
  expect_bayesplot_grid(mcmc_pairs(
    post, pars = wt_am, np = np,
    condition = pairs_condition(nuts = "divergent__")
  ))
  expect_bayesplot_grid(mcmc_pairs(
    post, pars = wt_am, np = np, lp = lp,
    condition = pairs_condition(nuts = "lp__"),
    max_treedepth = 2
  ))

  # custom divergence/treedepth styling, with max_treedepth set one below the
  # largest observed treedepth__ (presumably so some draws reach the limit)
  td_limit <- with(np, max(Value[Parameter == "treedepth__"]) - 1)
  styled_np <- pairs_style_np(
    div_color = "firebrick",
    td_color = "dodgerblue",
    div_size = 2,
    td_size = 2
  )
  styled_grid <- mcmc_pairs(
    post,
    pars = wt_am,
    off_diag_fun = "scatter",
    condition = pairs_condition(nuts = "lp__"),
    lp = lp,
    np = np,
    np_style = styled_np,
    max_treedepth = td_limit
  )
  expect_bayesplot_grid(styled_grid)
})


test_that("mcmc_pairs throws correct warnings and errors", {
  # The checks below use only the `arr` fixtures and do not need rstanarm.
  # They run before the rstanarm skip so they are not skipped unnecessarily.
  expect_warning(mcmc_pairs(arr1chain, regex_pars = "beta"),
                 "This plot is more useful with multiple chains")
  expect_error(mcmc_pairs(arr, pars = "sigma"),
               "requires at least two parameters")

  # a logical `draws` condition must have one entry per iteration and chain
  expect_error(
    mcmc_pairs(arr, condition = pairs_condition(draws = c(TRUE, FALSE))),
    "length(condition) == (n_iter * n_chain) is not TRUE",
    fixed = TRUE
  )
  # NUTS-based conditions require the matching `np` / `lp` argument
  expect_error(
    mcmc_pairs(arr, condition = pairs_condition(nuts = "accept_stat__")),
    "the 'np' argument to 'mcmc_pairs' must also be specified"
  )
  expect_error(
    mcmc_pairs(arr, condition = pairs_condition(nuts = "lp__")),
    "the 'lp' argument to 'mcmc_pairs' must also be specified"
  )
  # `condition` must be built with pairs_condition(), not given as a string
  expect_error(
    mcmc_pairs(arr, condition = "lp__"),
    'inherits(condition, "pairs_condition") is not TRUE',
    fixed = TRUE
  )

  # The remaining checks use the rstanarm fit (`post`, `np`).
  skip_if_not_installed("rstanarm")

  # `np_style` must be a "nuts_style" object, not a plain list
  expect_error(
    mcmc_pairs(post, pars = c("wt", "am"), max_treedepth = 2, np = np,
               np_style = list(color = "green")),
    'inherits(np_style, "nuts_style") is not TRUE',
    fixed = TRUE
  )

  # Modified copies are used so the shared `post` fixture is never reassigned.
  post_constant <- post
  post_constant[, 1:2, "wt"] <- 0
  expect_warning(
    mcmc_pairs(post_constant, pars = c("wt", "am", "sigma")),
    "parameters were dropped because they are constant: wt"
  )

  post_duplicate <- post
  post_duplicate[, , "sigma"] <- post_duplicate[, , "am"]
  expect_warning(
    mcmc_pairs(post_duplicate, pars = c("wt", "sigma", "am")),
    "parameters were dropped because they are duplicative: am"
  )
})


# pairs_style_np -------------------------------------------------------
test_that("pairs_style_np returns correct structure", {
  style <- pairs_style_np(div_size = 3, td_color = "gray", td_shape = 1)
  expect_s3_class(style, "nuts_style")
  expect_named(style, c("color", "shape", "size", "alpha"), ignore.order = TRUE)

  # every styling component has a "div" entry followed by a "td" entry
  for (component in c("color", "size", "shape", "alpha")) {
    expect_named(style[[component]], c("div", "td"))
  }
})

test_that("pairs_style_np throws correct errors", {
  # the expected messages follow stopifnot()'s "<expr> is not TRUE" format
  expect_error(pairs_style_np(div_size = "3"),
               "is.numeric(div_size) is not TRUE", fixed = TRUE)
  expect_error(pairs_style_np(td_color = 1),
               "is.character(td_color) is not TRUE", fixed = TRUE)
})


# pairs_condition ---------------------------------------------------------
test_that("pairs_condition returns correct structure", {
  # Non-default conditions keep the class of their input in addition to the
  # "pairs_condition" class, and record how they were specified in the
  # "type" attribute.

  # default: no conditioning arguments
  cond0 <- pairs_condition()
  expect_s3_class(cond0, "pairs_condition")
  expect_equivalent(unclass(cond0), list())
  expect_equal(attr(cond0, "type"), "default")

  # chains given as a single vector
  cond1 <- pairs_condition(chains = 1:4)
  expect_s3_class(cond1, "integer")
  expect_s3_class(cond1, "pairs_condition")
  expect_equivalent(unclass(cond1), 1:4)
  expect_equal(attr(cond1, "type"), "chain_vector")

  # chains given as a list of two groups, stored as `upper` and `lower`
  cond2 <- pairs_condition(chains = list(1:4, 5:6))
  expect_s3_class(cond2, "list")
  expect_s3_class(cond2, "pairs_condition")
  expect_equivalent(unclass(cond2), list(upper = 1:4, lower = 5:6))
  expect_equal(attr(cond2, "type"), "chain_list")

  # draws given as a proportion
  cond3 <- pairs_condition(draws = 0.7)
  expect_s3_class(cond3, "numeric")
  expect_s3_class(cond3, "pairs_condition")
  expect_equivalent(unclass(cond3), 0.7)
  expect_equal(attr(cond3, "type"), "draws_proportion")

  # draws given as a logical selection
  cond4 <- pairs_condition(draws = c(TRUE, FALSE, TRUE))
  expect_s3_class(cond4, "logical")
  expect_s3_class(cond4, "pairs_condition")
  expect_equivalent(unclass(cond4), c(TRUE, FALSE, TRUE))
  expect_equal(attr(cond4, "type"), "draws_selection")

  # a NUTS quantity given by name
  cond5 <- pairs_condition(nuts = "lp__")
  expect_s3_class(cond5, "character")
  expect_s3_class(cond5, "pairs_condition")
  expect_equivalent(unclass(cond5), "lp__")
  expect_equal(attr(cond5, "type"), "nuts")
})

test_that("pairs_condition throws correct errors", {
  # --- chains ---------------------------------------------------------------
  expect_error(
    pairs_condition(chains = "abc"),
    "must be an integer vector or a list of two integer vectors"
  )
  expect_error(
    pairs_condition(chains = list(1:2, 3:4, 5:6)),
    "length(chains) == 2 is not TRUE",
    fixed = TRUE
  )
  # a chain may not repeat, whether across list elements or within a vector
  duplicate_chain <- "Each chain can only be specified once"
  expect_error(pairs_condition(chains = list(1:2, 2:3)), duplicate_chain)
  expect_error(pairs_condition(chains = c(1:3, 2)), duplicate_chain)

  # --- draws ----------------------------------------------------------------
  expect_error(
    pairs_condition(draws = "abc"),
    "must be a single proportion or a logical vector"
  )
  expect_error(
    pairs_condition(draws = 2),
    "draws > 0 && draws < 1 is not TRUE",
    fixed = TRUE
  )

  # --- nuts -----------------------------------------------------------------
  not_single_string <- "must be a single string"
  expect_error(pairs_condition(nuts = 2), not_single_string)
  expect_error(pairs_condition(nuts = c("lp__", "energy__")), not_single_string)
  # an unrecognised name errors with a message mentioning "stepsize__"
  expect_error(pairs_condition(nuts = "step_size__"), "stepsize__")
})

test_that("pairs_condition message if multiple args specified", {
  # The expected messages contain straight single quotes, so fancy (curly)
  # quotes are turned off. The previous setting is restored when the test
  # finishes so the option does not leak into later tests.
  old_opts <- options(useFancyQuotes = FALSE)
  on.exit(options(old_opts), add = TRUE)

  # `chains` takes precedence over both `draws` and `nuts`
  expect_message(
    pairs_condition(chains = 2, draws = 0.5, nuts = "lp__"),
    "because they are superseded by 'chains': 'draws', 'nuts'",
    fixed = TRUE
  )
  expect_message(
    pairs_condition(chains = 2, nuts = "lp__"),
    "because they are superseded by 'chains': 'nuts'",
    fixed = TRUE
  )
  # `draws` takes precedence over `nuts`
  expect_message(
    pairs_condition(draws = 0.5, nuts = "lp__"),
    "because they are superseded by 'draws': 'nuts'",
    fixed = TRUE
  )
})



# mcmc_parcoord -----------------------------------------------------------
test_that("mcmc_parcoord returns a ggplot object", {
  # explicitly named parameters
  expect_gg(mcmc_parcoord(arr, pars = c("(Intercept)", "sigma")))
  # explicit plus regex parameter selection
  expect_gg(mcmc_parcoord(arr, pars = "sigma", regex_pars = "beta"))
})

test_that("mcmc_parcoord with nuts info returns a ggplot object", {
  skip_if_not_installed("rstanarm")
  # NUTS information from the rstanarm fixture is passed through `np`
  parcoord_np <- mcmc_parcoord(post, pars = c("wt", "am", "sigma"), np = np)
  expect_gg(parcoord_np)
})

test_that("mcmc_parcoord throws correct warnings and errors", {
  # This check only uses the `arr` fixture, so it runs before the rstanarm
  # skip instead of being skipped with the rest of the test.
  expect_error(mcmc_parcoord(arr, pars = "sigma"),
               "requires at least two parameters")

  # The remaining checks use the rstanarm fit (`post`, `np`).
  skip_if_not_installed("rstanarm")

  # `np` with its first column removed lacks a required column
  expect_error(
    mcmc_parcoord(post, np = np[, -1]),
    "NUTS parameter data frame must have columns: Chain, Iteration, Parameter, Value",
    fixed = TRUE
  )

  # `np_style` must be a "nuts_style" object, not a plain list
  expect_error(
    mcmc_parcoord(post, np = np, np_style = list(div_color = "green")),
    'inherits(np_style, "nuts_style") is not TRUE',
    fixed = TRUE
  )
})


# parcoord_style_np -------------------------------------------------------
test_that("parcoord_style_np returns correct structure", {
  style <- parcoord_style_np()
  expect_s3_class(style, "nuts_style")
  expect_named(style, c("color", "alpha", "size"), ignore.order = TRUE)

  # unlike pairs_style_np(), each component only has a "div" entry
  for (component in c("color", "size", "alpha")) {
    expect_named(style[[component]], "div")
  }
})

test_that("parcoord_style_np throws correct errors", {
  # wrongly typed argument
  expect_error(parcoord_style_np(div_size = "3"),
               "is.numeric(div_size) is not TRUE", fixed = TRUE)
  # `td_color` is not an argument of parcoord_style_np(), so R rejects it
  expect_error(parcoord_style_np(td_color = 1),
               "unused argument (td_color = 1)", fixed = TRUE)
})


# Visual tests -----------------------------------------------------------------

test_that("mcmc_scatter renders correctly", {
  skip_on_cran()
  skip_if_not_installed("vdiffr")

  # default appearance
  plot_default <- mcmc_scatter(vdiff_dframe_chains)
  vdiffr::expect_doppelganger("mcmc_scatter (default)", plot_default)

  # custom point size and transparency
  plot_size_alpha <- mcmc_scatter(vdiff_dframe_chains, size = 2, alpha = 0.2)
  vdiffr::expect_doppelganger("mcmc_scatter (size, alpha)", plot_size_alpha)

  # NUTS divergence fixture supplied through `np`
  plot_np <- mcmc_scatter(
    vdiff_dframe_chains,
    np = vdiff_dframe_chains_divergences
  )
  vdiffr::expect_doppelganger("mcmc_scatter (np)", plot_np)
})

test_that("mcmc_hex renders correctly", {
  skip_on_cran()
  skip_if_not_installed("vdiffr")
  skip_if_not_installed("hexbin")

  hex_pars <- c("V1", "V2")

  # default bins
  plot_default <- mcmc_hex(vdiff_dframe, pars = hex_pars)
  vdiffr::expect_doppelganger("mcmc_hex (default)", plot_default)

  # custom bin width
  plot_binwidth <- mcmc_hex(vdiff_dframe, pars = hex_pars, binwidth = 0.1)
  vdiffr::expect_doppelganger("mcmc_hex (bw)", plot_binwidth)
})

test_that("mcmc_parcoord renders correctly", {
  skip_on_cran()
  skip_if_not_installed("vdiffr")

  # default appearance
  plot_default <- mcmc_parcoord(vdiff_dframe_chains)
  vdiffr::expect_doppelganger("mcmc_parcoord (default)", plot_default)

  # NUTS divergence fixture with a custom divergence size
  div_style <- parcoord_style_np(div_size = 2)
  plot_np <- mcmc_parcoord(
    vdiff_dframe_chains,
    np = vdiff_dframe_chains_divergences,
    np_style = div_style
  )
  vdiffr::expect_doppelganger("mcmc_parcoord (np)", plot_np)
})


test_that("mcmc_pairs renders correctly", {
  skip_on_cran()
  skip_if_not_installed("vdiffr")
  skip_if_not_installed("hexbin")

  # default pairs plot
  plot_default <- mcmc_pairs(vdiff_dframe_chains)
  vdiffr::expect_doppelganger("mcmc_pairs (default)", plot_default)

  # divergence fixture with hex-binned off-diagonal panels
  plot_divs_hex <- mcmc_pairs(
    vdiff_dframe_chains,
    off_diag_fun = "hex",
    np = vdiff_dframe_chains_divergences,
    np_style = pairs_style_np(div_size = 2)
  )
  vdiffr::expect_doppelganger("mcmc_pairs (divs, hex)", plot_divs_hex)

  # treedepth fixture with max_treedepth = 9
  plot_td <- mcmc_pairs(
    vdiff_dframe_chains,
    max_treedepth = 9,
    np = vdiff_dframe_chains_treedepth,
    np_style = pairs_style_np(td_color = "green")
  )
  vdiffr::expect_doppelganger("mcmc_pairs (td)", plot_td)

  # combined NUTS fixture: divergences and treedepth styled together
  plot_divs_td <- mcmc_pairs(
    vdiff_dframe_chains,
    max_treedepth = 9,
    np = vdiff_dframe_chains_np,
    np_style = pairs_style_np(div_size = 3, td_color = "green")
  )
  vdiffr::expect_doppelganger("mcmc_pairs (divs, td)", plot_divs_td)
})
# jgabry/ppcheck documentation built on Feb. 17, 2024, 5:35 a.m.