# tests/testthat/test.parse_dist.R

# Tests for parse_dist
#
# Author: mjskay
###############################################################################

suppressWarnings(suppressPackageStartupMessages({
  library(dplyr)
  library(distributional)
}))


test_that("parse_dist works on vectors", {
  # Specs with recognized names are normalized ("Normal" -> "norm",
  # "log-normal" -> "lnorm", "CHI square" -> "chisq"); unparseable specs
  # ("bad", and "bad2(1" with its unclosed parenthesis) are expected to yield
  # NA in every output column rather than an error.
  specs = c("Normal(0,1)", "log-normal(2,3)", "CHI square(4)", "bad", "bad2(1")

  ref = as.data.frame(tibble(
    .dist = c("norm", "lnorm", "chisq", NA, NA),
    .args = list(list(0, 1), list(2, 3), list(4), NA, NA),
    .dist_obj = c(
      dist_wrap("norm", 0, 1, package = "stats"),
      dist_wrap("lnorm", 2, 3, package = "stats"),
      dist_wrap("chisq", 4, package = "stats"),
      NA, NA
    )
  ))

  expect_equal(as.data.frame(parse_dist(specs, package = "stats")), ref)
})

test_that("parse_dist works on data frames", {
  dist_env = rlang::current_env()
  dists = factor(c("Normal(0,1)", "log-normal(2,3)", "Student's t(3,0,1)"))
  lower = c(NA_real_, NA_real_, 0)
  upper = rep(NA_real_, 3)

  # input: the distribution specs alongside lb/ub bound columns
  df = data.frame(p = dists, lb = lower, ub = upper)

  # expected output: the input columns followed by the parsed name, argument
  # list, and distribution object. The third row has lb = 0 and ub = NA, so
  # its distribution is expected to be truncated to [0, Inf).
  ref = as.data.frame(tibble(
    p = dists,
    lb = lower,
    ub = upper,
    .dist = c("norm", "lnorm", "student_t"),
    .args = list(list(0, 1), list(2, 3), list(3, 0, 1)),
    .dist_obj = c(
      dist_wrap("norm", 0, 1, package = dist_env),
      dist_wrap("lnorm", 2, 3, package = dist_env),
      dist_truncated(dist_wrap("student_t", 3, 0, 1, package = dist_env), 0, Inf)
    )
  ))

  # the default `package` and an explicit `package = dist_env` are both
  # expected to produce the same result
  expect_equal(parse_dist(df, p), ref)
  expect_equal(parse_dist(df, p, package = dist_env), ref)
})

test_that("parse_dist works on brmsprior objects", {
  dist_env = rlang::current_env()
  dists = factor(c("Normal(0,1)", "log-normal(2,3)", "Student's t(3,0,1)"))

  # a minimal brmsprior-like input: a data frame with only a `prior` column,
  # tagged with the "brmsprior" class. Unlike the plain data frame case,
  # parse_dist() is called below without naming the column to parse.
  brmsprior = data.frame(prior = dists)
  class(brmsprior) = c("brmsprior", "data.frame")

  expected_dist_objs = c(
    dist_wrap("norm", 0, 1, package = dist_env),
    dist_wrap("lnorm", 2, 3, package = dist_env),
    dist_wrap("student_t", 3, 0, 1, package = dist_env)
  )
  ref = as.data.frame(tibble(
    prior = dists,
    .dist = c("norm", "lnorm", "student_t"),
    .args = list(list(0, 1), list(2, 3), list(3, 0, 1)),
    .dist_obj = expected_dist_objs
  ))

  expect_equal(parse_dist(brmsprior), ref)
})

test_that("parse_dist + marginalize_lkjcorr produces correct results", {
  skip_if_no_vdiffr()

  # parse an LKJ(3) prior, then marginalize it to the distribution of a
  # single correlation in a 2x2 correlation matrix (K = 2)
  marginal_prior = tibble(prior = "lkjcorr(3)") %>%
    parse_dist(prior) %>%
    marginalize_lkjcorr(K = 2)

  lkj_plot = ggplot(marginal_prior, aes(y = prior, dist = .dist, args = .args)) +
    stat_dist_halfeye(n = 20) +
    xlim(-1, 1) +
    xlab("Marginal correlation for LKJ(3) prior on 2x2 correlation matrix")

  vdiffr::expect_doppelganger("LKJ marginalization works", lkj_plot)
})

test_that("unsupported objects throw error with parse_dist", {
  # a bare (empty) list is expected to raise a classed "unsupported type" error
  expect_error(parse_dist(list()), class = "ggdist_unsupported_type")
})


# check_dist_name ---------------------------------------------------------

test_that("check_dist_name ignores unknown distributions", {
  # known names ("norm", "t") pass through unchanged; unknown names ("foo",
  # "bar") are expected to become NA and to trigger a classed warning
  expect_warning(
    expect_equal(
      check_dist_name(c("norm", "foo", "bar", "t")),
      c("norm", NA, NA, "t")
    ),
    class = "ggdist_unsupported_distribution_name"
  )
})

# Try the ggdist package in your browser
#
# Any scripts or data that you put into this service are public.
#
# ggdist documentation built on July 4, 2024, 9:08 a.m.