tests/testthat/test-class-weight.R

test_that("setting class weights", {
 skip_if_not(torch::torch_is_installed())
 skip_if_not_installed("modeldata")

 suppressPackageStartupMessages(library(dplyr))

 # ------------------------------------------------------------------------------

 set.seed(585)
 mnl_tr <-
  modeldata::sim_multinomial(
   1000,
   ~  -0.5    +  0.6 * abs(A),
   ~ ifelse(A > 0 & B > 0, 1.0 + 0.2 * A / B, - 2),
   ~ -0.6 * A + 0.50 * B -  A * B)

 lvls <- levels(mnl_tr$class)
 num_class <- length(lvls)
 cls_xtab <- table(mnl_tr$class)
 min_class <- names(sort(cls_xtab))[1]

 cls_wts <- rep(1, num_class)
 names(cls_wts) <- lvls
 cls_wts[names(cls_wts) == min_class] <- 10

 bad_wts <- cls_wts
 names(bad_wts) <- letters[1:num_class]

 # ------------------------------------------------------------------------------

 expect_equal(
  brulee:::check_class_weights(1.0, lvls, cls_xtab, "fabulous") %>%
   as.numeric(),
  rep(1, num_class)
 )

 expect_s3_class(
  brulee:::check_class_weights(1.0, lvls, cls_xtab, "fabulous"),
  "torch_tensor"
 )

 expect_equal(
  brulee:::check_class_weights(NULL, lvls, cls_xtab, "fabulous") %>%
   as.numeric(),
  rep(1, num_class)
 )

 expect_equal(
  brulee:::check_class_weights(6.25, lvls, cls_xtab, "fabulous") %>%
   as.numeric(),
  c(1, 6.25, 1)
 )

 expect_equal(
  brulee:::check_class_weights(c(1, 6.25, 1), lvls, cls_xtab, "fabulous") %>%
   as.numeric(),
  c(1, 6.25, 1)
 )

 expect_null(
  brulee:::check_class_weights(1, character(0), cls_xtab, "fabulous")
 )

 expect_snapshot(
  brulee:::check_class_weights("a", lvls, cls_xtab, "fabulous"),
  error = TRUE
 )

 expect_snapshot(
  brulee:::check_class_weights(c(1, 6.25), lvls, cls_xtab, "fabulous"),
  error = TRUE
 )

 expect_snapshot(
  brulee:::check_class_weights(bad_wts, lvls, cls_xtab, "fabulous"),
  error = TRUE
 )

})
tidymodels/lantern documentation built on Feb. 28, 2024, 12:59 a.m.