# tests/testthat/test.predicted_draws.R

# Tests for [add_]predicted_draws
#
# Author: mjskay
###############################################################################

suppressWarnings(suppressMessages({
  library(dplyr)
  library(tidyr)
  library(magrittr)
}))




# Shared fixture: mtcars with its row names replaced by 1..nrow, converted to
# a tibble.
mtcars_tbl = as_tibble(set_rownames(mtcars, seq_len(nrow(mtcars))))

# Restrict to a single core: with more than one, random numbers do not seem
# to always be reproducible on brms, which makes the tests unreliable.
options(mc.cores = 1)


test_that("[add_]predicted_draws throws an error on unsupported models", {
  # RankCorr is loaded from ggdist and passed where a model is expected; both
  # entry points are expected to fail with an error matching this message.
  data("RankCorr", package = "ggdist")
  no_method_msg = 'no applicable method'

  # model-first form
  expect_error(predicted_draws(RankCorr, data.frame()), no_method_msg)
  # data-first form
  expect_error(add_predicted_draws(data.frame(), RankCorr), no_method_msg)
})


test_that("[add_]predicted_draws and basic arguments works on a simple rstanarm model", {
  skip_if_not_installed("rstanarm")
  m_hp_wt = readRDS(test_path("../models/models.rstanarm.m_hp_wt.rds"))

  # Reference draws taken directly from rstanarm. After as.data.frame() there
  # is one row per draw; the remaining columns are pivoted into `.row` /
  # `.prediction` pairs so they can be joined back onto the rows of mtcars_tbl.
  preds = rstanarm::posterior_predict(m_hp_wt, mtcars_tbl, draws = 10, seed = 123) %>%
    as.data.frame() %>%
    mutate(
      .chain = NA_integer_,
      .iteration = NA_integer_,
      .draw = seq_len(n())
    ) %>%
    # pivot_longer() replaces the superseded gather(). It emits rows in a
    # different order, but ref is ordered by the join below (rows of
    # mtcars_tbl, then matches in order of appearance), so ref is unchanged.
    pivot_longer(
      cols = -c(.chain, .iteration, .draw),
      names_to = ".row",
      values_to = ".prediction"
    ) %>%
    as_tibble()

  # Expected result: every row of mtcars_tbl repeated once per draw, with an
  # integer .row index, grouped by the original columns plus .row.
  ref = mtcars_tbl %>%
    mutate(.row = rownames(.)) %>%
    inner_join(preds, by = ".row", multiple = "all") %>%
    mutate(.row = as.integer(.row)) %>%
    group_by(mpg, cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb, .row)

  expect_equal(predicted_draws(m_hp_wt, mtcars_tbl, ndraws = 10, seed = 123), ref)
  expect_equal(add_predicted_draws(mtcars_tbl, m_hp_wt, ndraws = 10, seed = 123), ref)

  # the old `n` argument must still give the same result, with a deprecation
  # warning pointing at `ndraws`
  expect_warning(
    expect_equal(predicted_draws(m_hp_wt, mtcars_tbl, n = 10, seed = 123), ref),
    "`n`.*deprecated.*`ndraws`"
  )
  expect_warning(
    expect_equal(add_predicted_draws(mtcars_tbl, m_hp_wt, n = 10, seed = 123), ref),
    "`n`.*deprecated.*`ndraws`"
  )

  # should still work using default implementation
  expect_equal(predicted_draws.default(m_hp_wt, mtcars_tbl, draws = 10, seed = 123), ref)
})


test_that("[add_]predicted_draws and basic arguments works on an rstanarm model with random effects", {
  skip_if_not_installed("rstanarm")
  m_cyl = readRDS(test_path("../models/models.rstanarm.m_cyl.rds"))

  # Reference draws taken directly from rstanarm: one row per draw, with the
  # remaining columns pivoted into `.row` / `.prediction` pairs.
  preds = rstanarm::posterior_predict(m_cyl, mtcars_tbl, draws = 10, seed = 123) %>%
    as.data.frame() %>%
    mutate(
      .chain = NA_integer_,
      .iteration = NA_integer_,
      .draw = seq_len(n())
    ) %>%
    # pivot_longer() replaces the superseded gather(). Its row order differs,
    # but ref takes its order from the join below (rows of mtcars_tbl, then
    # matches in order of appearance), so ref is unchanged.
    pivot_longer(
      cols = -c(.chain, .iteration, .draw),
      names_to = ".row",
      values_to = ".prediction"
    ) %>%
    as_tibble()

  # Expected result: rows of mtcars_tbl repeated once per draw, integer .row,
  # grouped by the original columns plus .row.
  ref = mtcars_tbl %>%
    mutate(.row = rownames(.)) %>%
    inner_join(preds, by = ".row", multiple = "all") %>%
    mutate(.row = as.integer(.row)) %>%
    group_by(mpg, cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb, .row)

  expect_equal(predicted_draws(m_cyl, mtcars_tbl, ndraws = 10, seed = 123), ref)
  expect_equal(add_predicted_draws(mtcars_tbl, m_cyl, ndraws = 10, seed = 123), ref)
})


test_that("[add_]predicted_draws works on a simple brms model", {
  skip_if_not_installed("brms")
  m_hp = readRDS(test_path("../models/models.brms.m_hp.rds"))

  # Seed immediately before predict() so the reference uses the same seed
  # (123) that is passed to predicted_draws() below.
  set.seed(123)
  preds = predict(m_hp, mtcars_tbl, summary = FALSE, ndraws = 10) %>%
    as.data.frame() %>%
    # name the columns 1..ncol so they can be matched to row names below
    set_names(seq_len(ncol(.))) %>%
    mutate(
      .chain = NA_integer_,
      .iteration = NA_integer_,
      .draw = seq_len(n())
    ) %>%
    # pivot_longer() replaces the superseded gather(). Its row order differs,
    # but ref takes its order from the join below (rows of mtcars_tbl, then
    # matches in order of appearance), so ref is unchanged.
    pivot_longer(
      cols = -c(.chain, .iteration, .draw),
      names_to = ".row",
      values_to = ".prediction"
    ) %>%
    as_tibble()

  # Expected result: rows of mtcars_tbl repeated once per draw, integer .row,
  # grouped by the original columns plus .row.
  ref = mtcars_tbl %>%
    mutate(.row = rownames(.)) %>%
    inner_join(preds, by = ".row", multiple = "all") %>%
    mutate(.row = as.integer(.row)) %>%
    group_by(mpg, cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb, .row)

  expect_equal(predicted_draws(m_hp, mtcars_tbl, ndraws = 10, seed = 123), ref)
  expect_equal(add_predicted_draws(mtcars_tbl, m_hp, ndraws = 10, seed = 123), ref)
})

test_that("[add_]predicted_draws works on brms models with categorical outcomes", {
  skip_if_not_installed("brms")
  m_cyl_mpg = readRDS(test_path("../models/models.brms.m_cyl_mpg.rds"))

  # Seed immediately before predict() so the reference uses the same seed
  # (1234) that is passed to predicted_draws() below.
  set.seed(1234)
  raw_preds = predict(m_cyl_mpg, mtcars_tbl, summary = FALSE, ndraws = 10)

  # Flatten the prediction array into a long data frame indexed by .draw/.row.
  long_preds = array2df(raw_preds, list(.draw = NA, .row = NA), label.x = ".prediction")
  preds = mutate(
    long_preds,
    .chain = NA_integer_,
    .iteration = NA_integer_,
    .draw = as.integer(.draw),
    .row = as.character(.row)
  )

  # get this test to pass for now on brms versions that don't output levels anymore
  has_levels = !is.null(attr(raw_preds, "levels"))
  if (has_levels) {
    preds = mutate(
      preds,
      .prediction = factor(.prediction, levels = 1:3, labels = paste0("c", c(4,6,8)))
    )
  }

  # Attach the predictions to the rows of mtcars_tbl by row name, then put the
  # columns and groups in the layout the reference result is compared in.
  indexed_data = mutate(mtcars_tbl, .row = rownames(mtcars_tbl))
  joined = inner_join(indexed_data, preds, by = ".row", multiple = "all")
  ref = joined %>%
    mutate(.row = as.integer(.row)) %>%
    select(mpg:.row, .chain, .iteration, .draw, everything()) %>%
    group_by(mpg, cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb, .row)

  expect_equal(predicted_draws(m_cyl_mpg, mtcars_tbl, seed = 1234, ndraws = 10), ref)
  expect_equal(add_predicted_draws(mtcars_tbl, m_cyl_mpg, seed = 1234, ndraws = 10), ref)
})

test_that("[add_]predicted_draws works on brms models with dirichlet responses", {
  skip_if_not_installed("brms")
  skip_if_not(getRversion() >= "4")

  m_dirich = readRDS(test_path("../models/models.brms.m_dirich.rds"))

  # Seed immediately before predict() so the reference uses the same seed
  # (1234) that is passed to the functions under test below.
  set.seed(1234)
  grid = tibble(x = c("A", "B"))
  # Flatten the prediction array into a long data frame indexed by
  # .draw / .row / .category.
  preds = predict(m_dirich, grid, summary = FALSE, ndraws = 10) %>%
    array2df(list(.draw = NA, .row = NA, .category = TRUE), label.x = ".prediction") %>%
    mutate(
      .chain = NA_integer_,
      .iteration = NA_integer_,
      .row = as.integer(.row),
      .draw = as.integer(.draw)
    )

  # Expected result: each grid row joined to its predictions, grouped by the
  # grid column, .row and .category.
  ref = grid %>%
    mutate(.row = as.integer(rownames(.))) %>%
    inner_join(preds, by = ".row", multiple = "all") %>%
    select(x, .row, .chain, .iteration, .draw, .category, everything()) %>%
    group_by(x, .row, .category)

  expect_equal(predicted_draws(m_dirich, grid, seed = 1234, ndraws = 10), ref)
  # the test name covers the add_ variant too, so exercise the data-first
  # form as the other tests in this file do
  expect_equal(add_predicted_draws(grid, m_dirich, seed = 1234, ndraws = 10), ref)
})

test_that("[add_]predicted_draws works on brms models with multinomial responses", {
  skip_if_not_installed("brms")
  skip_if_not(getRversion() >= "4")

  m_multinom = readRDS(test_path("../models/models.brms.m_multinom.rds"))

  # Seed immediately before predict() so the reference uses the same seed
  # (1234) that is passed to the functions under test below.
  set.seed(1234)
  # use a low number for total so there are some 0s
  grid = tibble(total = c(10, 20))
  # Flatten the prediction array into a long data frame indexed by
  # .draw / .row / .category.
  preds = predict(m_multinom, grid, summary = FALSE, ndraws = 10) %>%
    array2df(list(.draw = NA, .row = NA, .category = TRUE), label.x = ".prediction") %>%
    mutate(
      .chain = NA_integer_,
      .iteration = NA_integer_,
      .row = as.integer(.row),
      .draw = as.integer(.draw)
    )

  # Expected result: each grid row joined to its predictions, grouped by the
  # grid column, .row and .category.
  ref = grid %>%
    mutate(.row = as.integer(rownames(.))) %>%
    inner_join(preds, by = ".row", multiple = "all") %>%
    select(total, .row, .chain, .iteration, .draw, .category, everything()) %>%
    group_by(total, .row, .category)

  expect_equal(predicted_draws(m_multinom, grid, seed = 1234, ndraws = 10), ref)
  # the test name covers the add_ variant too, so exercise the data-first
  # form as the other tests in this file do
  expect_equal(add_predicted_draws(grid, m_multinom, seed = 1234, ndraws = 10), ref)
})

test_that("[add_]predicted_draws throws an error when draws is called instead of ndraws in rstanarm", {
  skip_if_not_installed("rstanarm")
  m_hp_wt = readRDS(test_path("../models/models.rstanarm.m_hp_wt.rds"))

  # Both entry points must reject `draws` with an error whose message matches
  # this pattern (it mentions `draws`, then `ndraws`, then the documentation).
  wrong_arg_pattern = "`draws.*.`ndraws`.*.See the documentation for additional details."

  expect_error(predicted_draws(m_hp_wt, mtcars_tbl, draws = 10), wrong_arg_pattern)
  expect_error(add_predicted_draws(mtcars_tbl, m_hp_wt, draws = 10), wrong_arg_pattern)
})

test_that("[add_]predicted_draws throws an error when re.form is called instead of re_formula in rstanarm", {
  skip_if_not_installed("rstanarm")
  m_hp_wt = readRDS(test_path("../models/models.rstanarm.m_hp_wt.rds"))

  # Both entry points must reject `re.form` with an error whose message matches
  # this pattern (it mentions `re.form`, then `re_formula`, then the
  # documentation).
  wrong_arg_pattern = "`re.form.*.`re_formula`.*.See the documentation for additional details."

  expect_error(predicted_draws(m_hp_wt, mtcars_tbl, re.form = NULL), wrong_arg_pattern)
  expect_error(add_predicted_draws(mtcars_tbl, m_hp_wt, re.form = NULL), wrong_arg_pattern)
})

# Try the tidybayes package in your browser
#
# Any scripts or data that you put into this service are public.
#
# tidybayes documentation built on Sept. 15, 2024, 9:08 a.m.