# tests/testthat/test.curve_interval.R

# Tests for curve_interval
#
# Author: mjskay
###############################################################################

suppressPackageStartupMessages(suppressWarnings({
  library(dplyr)
  library(tidyr)
}))


test_that("curve_interval works with lineribbon", {
  skip_if_no_vdiffr()
  skip_if_not_installed("posterior")

  # Fixture: a family of normal density curves, one per draw, all evaluated
  # on the same x grid. The scaling uses the single largest density across
  # all curves, because the data is not yet grouped when y is computed.
  n_curves = 11
  n_points = 101
  curves = tibble(
    .draw = 1:n_curves,
    mean = seq(-5, 5, length.out = n_curves),
    x = list(seq(-15, 15, length.out = n_points))
  )
  curves = unnest(curves, x)
  curves = mutate(curves, y = dnorm(x, mean, 3) / max(dnorm(x, mean, 3)))
  curves = group_by(curves, x)

  # Shared plot recipe: ribbons for the summarized intervals plus the raw
  # curves drawn faintly (always taken from `curves`, not from `intervals`).
  lineribbon_plot = function(intervals) {
    ggplot(intervals, aes(x = x, y = y)) +
      geom_lineribbon(aes(ymin = .lower, ymax = .upper)) +
      geom_line(aes(group = .draw), alpha = 0.15, data = curves) +
      scale_fill_brewer()
  }

  vdiffr::expect_doppelganger("curve_interval with mhd",
    lineribbon_plot(
      curve_interval(curves, y, .width = c(.5, .8), .interval = "mhd")
    )
  )

  # conditioning works
  # Group "b" holds the same curves as group "a" but with relabelled draws.
  # The new labels are the first `n_curves` entries of a permutation of
  # 1:n_points, so they stay unique per curve.
  set.seed(1234)
  shuffled_labels = sample.int(n_points)
  curves_a = mutate(curves, group = "a")
  curves_b = mutate(curves, group = "b", .draw = shuffled_labels[.draw])
  curves_ab = arrange(bind_rows(curves_a, curves_b), .draw)

  vdiffr::expect_doppelganger("conditional curve_interval with mhd",
    lineribbon_plot(
      curve_interval(
        group_by(curves_ab, group),
        y,
        .along = x,
        .width = c(.5, .8),
        .interval = "mhd"
      )
    ) +
      facet_wrap(~ group)
  )

  # joint works
  # TODO: fix
  # vdiffr::expect_doppelganger("joint curve_interval with mhd",
  #   lineribbon_plot(
  #     curves_ab %>%
  #       arrange(.draw) %>%  # TODO: this should not be needed
  #       curve_interval(y, .along = c(group, x), .width = c(.5, .8), .interval = "mhd")
  #   ) +
  #     facet_wrap(~ group)
  # )

  # The band-depth variant is only exercised when fda is installed.
  skip_if_not_installed("fda")
  vdiffr::expect_doppelganger("curve_interval with bd-mbd",
    lineribbon_plot(
      curve_interval(curves, y, .width = c(.5, .8), .interval = "bd-mbd")
    )
  )

})


# basic cases: 95%, 0%, 100%, > 100% ---------------------------------------------

test_that("basic cases on single interval work", {
  df = data.frame(value = ppoints(1000))

  # Columns shared by every expected result; the interval columns start as
  # NA and are filled in per case by expected().
  template = data.frame(
    value = 0.4995,
    .lower = NA,
    .upper = NA,
    .actual_width = NA,
    .width = NA,
    .point = "mhd",
    .interval = "mhd",
    stringsAsFactors = FALSE
  )

  # Build the one-row reference result for a given interval.
  expected = function(lower, upper, actual_width, width) {
    mutate(
      template,
      .lower = lower,
      .upper = upper,
      .actual_width = actual_width,
      .width = width
    )
  }

  # 95%
  expect_equal(
    curve_interval(df, .width = .95),
    expected(.0255, .9745, .95, .95)
  )

  # 0%
  expect_equal(
    curve_interval(df, .width = 0),
    expected(0.4995, .5005, 0.002, 0)
  )

  # 100%
  expect_equal(
    curve_interval(df, .width = 1),
    expected(0.0005, .9995, 1, 1)
  )

  # > 100%: same bounds and actual width as the 100% case, with the
  # requested .width reported unchanged
  expect_equal(
    curve_interval(df, .width = 1.1),
    expected(0.0005, .9995, 1, 1.1)
  )
})

test_that("basic cases on single curve work", {
  widths = c(.95, 0, 1)
  # x values as they repeat across the three requested widths
  x_by_width = rep(1:3, 3)

  draws_df = data.frame(x = 1:3, y = rep(ppoints(1000), each = 3) + 1:3)

  # Reference result: one row per (width, x) pair, widths varying slowest.
  ref = data.frame(
    x = x_by_width,
    y = rep(0.4995 + 1:3, 3),
    .lower = rep(c(.0255, 0.4995, 0.0005), each = 3) + x_by_width,
    .upper = rep(c(.9745, 0.5005, 0.9995), each = 3) + x_by_width,
    .actual_width = rep(c(.95, 0.002, 1), each = 3),
    .width = rep(widths, each = 3),
    .point = "mhd",
    .interval = "mhd",
    stringsAsFactors = FALSE
  )

  # data frame of draws: .along and group_by() are interchangeable here
  expect_equal(curve_interval(draws_df, .along = x, .width = widths), ref)
  expect_equal(curve_interval(group_by(draws_df, x), .width = widths), ref)

  skip_if_not_installed("posterior")

  # data frame of rvars: the reference is compared as a tibble
  y_rvar = rep(posterior::rvar(ppoints(1000)), 3) + 1:3
  rvar_df = data.frame(x = 1:3, y = y_rvar)
  ref_tbl = as_tibble(ref)
  expect_equal(curve_interval(rvar_df, .along = x, .width = widths), ref_tbl)
  expect_equal(curve_interval(group_by(rvar_df, x), .width = widths), ref_tbl)

  # Without a data frame there is no x column, and the summarized column is
  # expected under the name .value.
  ref_value_only = select(ref, -x, .value = y)

  # rvar
  expect_equal(curve_interval(y_rvar, .width = widths), ref_value_only)

  # matrix
  expect_equal(
    curve_interval(posterior::draws_of(y_rvar), .width = widths),
    ref_value_only
  )
})

test_that("basic cases on multiple variables", {
  widths = c(.95, 0, 1)
  x_by_width = rep(1:3, 3)

  # Interval bounds for y1; the y2 bounds are expected to sit exactly one
  # unit above them.
  y1_lower = rep(c(.0255, 0.4995, 0.0005), each = 3) + x_by_width
  y1_upper = rep(c(.9745, 0.5005, 0.9995), each = 3) + x_by_width

  draws_df = data.frame(
    x = 1:3,
    y1 = rep(ppoints(1000), each = 3) + 1:3,
    y2 = rep(rev(ppoints(1000)), each = 3) + 2:4
  )

  # Built as a tibble first and then converted to a plain data frame.
  ref = as.data.frame(tibble(
    x = x_by_width,
    y1 = rep(0.4995 + 1:3, 3),
    y1.lower = y1_lower,
    y1.upper = y1_upper,
    .actual_width = rep(c(.95, 0.002, 1), each = 3),
    .width = rep(widths, each = 3),
    y2 = rep(0.5005 + 2:4, 3),
    y2.lower = y1_lower + 1,
    y2.upper = y1_upper + 1,
    .point = "mhd",
    .interval = "mhd"
  ))

  # implicit column selection with .along
  expect_equal(curve_interval(draws_df, .along = x, .width = widths), ref)

  # explicit column selection on grouped data
  expect_equal(
    curve_interval(group_by(draws_df, x), y1, y2, .width = widths),
    ref
  )

  # two .along columns: the extra column y is expected right after x
  draws_xy = mutate(draws_df, y = x + 1)
  ref_xy = mutate(ref, y = x + 1, .after = 1)
  expect_equal(
    curve_interval(draws_xy, .along = c(x, y), .width = widths),
    ref_xy
  )
})


# errors ------------------------------------------------------------------

test_that("error is thrown when no columns found to summarize", {
  value_only = data.frame(value = ppoints(10))

  # Excluding the only column leaves nothing to summarize.
  expect_error(
    curve_interval(value_only, .exclude = "value"),
    "No columns found to calculate point and interval summaries for"
  )

  # NOTE(review): this expectation accepts any error. The test that follows
  # checks the same call against a specific error class.
  expect_error(curve_interval(value_only, .along = x))
})

test_that("error is thrown when along does not match a column", {
  # The data has no column named x, so selecting it via .along must fail
  # with the classed condition checked below.
  value_only = data.frame(value = ppoints(10))
  expect_error(
    curve_interval(value_only, .along = x),
    class = "ggdist_invalid_column_selection"
  )
})

test_that("error is thrown with groups of different sizes", {
  # Nine values against a length-3 group vector: recycling gives six "a"
  # rows and three "b" rows, i.e. unbalanced groups.
  unbalanced = data.frame(
    value = ppoints(9),
    group = c("a", "a", "b"),
    stringsAsFactors = FALSE
  )

  expect_error(
    curve_interval(unbalanced, .along = group),
    "Must have the same number of values in each group"
  )
})

test_that("curve_interval(<rvar>) and curve_interval(<matrix>) do not support along", {
  # Both methods are expected to fail with the same message; the pattern
  # tolerates arbitrary whitespace and non-letter decoration around `.along`.
  along_unsupported = 'does\\s+not\\s+support\\s+the\\s+[^a-zA-Z]*\\.along[^a-zA-Z]*\\s+argument'

  # matrix
  expect_error(
    curve_interval(matrix(1:4, nrow = 2), .along = "x"),
    along_unsupported
  )

  # rvar
  skip_if_not_installed("posterior")
  expect_error(
    curve_interval(posterior::rvar(), .along = "x"),
    along_unsupported
  )
})

# Try the ggdist package in your browser
#
# Any scripts or data that you put into this service are public.
#
# ggdist documentation built on July 4, 2024, 9:08 a.m.