# tests/testthat/test-cal-estimate-multinomial.R

test_that("Multinomial estimates work - data.frame", {
  skip_if_not_installed("modeldata")
  skip_if_not_installed("nnet")

  # Ungrouped, linear (multinomial regression) calibration.
  sp_multi <- cal_estimate_multinomial(species_probs, Species, smooth = FALSE)
  expect_cal_type(sp_multi, "multiclass")
  expect_cal_method(sp_multi, "Multinomial regression calibration")
  expect_cal_rows(sp_multi, n = 110)
  expect_snapshot(print(sp_multi))

  # Ungrouped, smoothed (GAM) calibration.
  sp_smth_multi <- cal_estimate_multinomial(species_probs, Species, smooth = TRUE)
  expect_cal_type(sp_smth_multi, "multiclass")
  expect_cal_method(sp_smth_multi, "Generalized additive model calibration")
  expect_cal_rows(sp_smth_multi, n = 110)
  expect_snapshot(print(sp_smth_multi))

  # Grouped calibration via a single `.by` column. The object name is kept
  # as-is because `expect_snapshot()` records the printed expression.
  sl_multi_group <- species_probs |>
    dplyr::mutate(group = .pred_bobcat > 0.5) |>
    cal_estimate_multinomial(Species, smooth = FALSE, .by = group)

  expect_cal_type(sl_multi_group, "multiclass")
  expect_cal_method(sl_multi_group, "Multinomial regression calibration")
  expect_cal_rows(sl_multi_group, n = 110)
  expect_snapshot(print(sl_multi_group))

  # Supplying two `.by` columns is expected to error.
  expect_snapshot_error(
    species_probs |>
      dplyr::mutate(group1 = 1, group2 = 2) |>
      cal_estimate_multinomial(Species, smooth = FALSE, .by = c(group1, group2))
  )

  # Tidyselect range for `estimate`. `mnl_with_configs()` is a test helper
  # defined elsewhere; presumably it returns predictions for several tuning
  # configurations -- confirm. Previously this result was never checked, so
  # only "does not error" was being tested.
  mltm_configs <-
    mnl_with_configs() |>
    cal_estimate_multinomial(truth = obs, estimate = c(VF:L), smooth = FALSE)
  expect_cal_type(mltm_configs, "multiclass")
  expect_cal_method(mltm_configs, "Multinomial regression calibration")
})

test_that("Multinomial estimates work - tune_results", {
  skip_if_not_installed("modeldata")
  skip_if_not_installed("nnet")

  # Build the tune_results fixture once instead of re-invoking the helper
  # for every estimate/apply/collect step.
  tune_res <- testthat_cal_multiclass()

  tl_multi <- cal_estimate_multinomial(tune_res, smooth = FALSE)
  expect_cal_type(tl_multi, "multiclass")
  expect_cal_method(tl_multi, "Multinomial regression calibration")
  expect_snapshot(print(tl_multi))

  # Applying a calibration must yield one row per summarized prediction;
  # this reference count is shared by both calibration methods below.
  n_preds <- tune_res |>
    tune::collect_predictions(summarize = TRUE) |>
    nrow()

  expect_equal(nrow(cal_apply(tune_res, tl_multi)), n_preds)

  tl_smth_multi <- cal_estimate_multinomial(tune_res, smooth = TRUE)
  expect_cal_type(tl_smth_multi, "multiclass")
  expect_cal_method(tl_smth_multi, "Generalized additive model calibration")
  expect_snapshot(print(tl_smth_multi))

  expect_equal(nrow(cal_apply(tune_res, tl_smth_multi)), n_preds)
})

test_that("Multinomial estimates errors - grouped_df", {
  skip_if_not_installed("modeldata")
  skip_if_not_installed("nnet")

  # A dplyr grouped_df is not accepted as input; the error is snapshotted.
  # The native pipe below parses to the same call as the nested form.
  expect_snapshot_error(
    mtcars |>
      dplyr::group_by(vs) |>
      cal_estimate_multinomial()
  )
})

test_that("Passing a binary outcome causes error", {
  # `Class` in `segment_logistic` is a two-level outcome, so multinomial
  # calibration should refuse it.
  expect_error(
    segment_logistic |> cal_estimate_multinomial(Class)
  )
})

test_that("Multinomial spline switches to linear if too few unique", {
  skip_if_not_installed("modeldata")
  # `smooth = FALSE` fits a multinomial regression; gate on nnet like the
  # other tests in this file that do so.
  skip_if_not_installed("nnet")

  # Two rows per species: too few unique values for a smoother, so the
  # smoothed estimate is expected to match the linear one.
  smol_species_probs <-
    species_probs |>
    dplyr::slice_head(n = 2, by = Species)

  # Object names are kept verbatim: `expect_snapshot()` records this code.
  expect_snapshot(
    sl_gam <- cal_estimate_multinomial(smol_species_probs, Species, smooth = TRUE)
  )
  sl_glm <- cal_estimate_multinomial(smol_species_probs, Species, smooth = FALSE)

  expect_identical(
    sl_gam$estimates,
    sl_glm$estimates
  )

  # Same check with a `.by` grouping. `id` alternates 1, 2, ... over all rows;
  # `rep_len()` sizes this from the data instead of hard-coding the row count.
  smol_by_species_probs <-
    species_probs |>
    dplyr::slice_head(n = 4, by = Species) |>
    dplyr::mutate(id = rep_len(1:2, dplyr::n()))

  expect_snapshot(
    sl_gam <- cal_estimate_multinomial(smol_by_species_probs, Species, .by = id, smooth = TRUE)
  )
  sl_glm <- cal_estimate_multinomial(smol_by_species_probs, Species, .by = id, smooth = FALSE)

  expect_identical(
    sl_gam$estimates,
    sl_glm$estimates
  )
})
# Source: topepo/probably (documentation built on June 8, 2025, 4:23 a.m.)