# tests/testthat/test-conf_mat.R

test_that("Three class format", {
  fixtures <- data_three_class()

  # `conf_mat()` always yields a double table (a consequence of case weight
  # support), so coerce the reference table before comparing.
  expected <- fixtures$three_class_tb
  storage.mode(expected) <- "double"

  actual <- conf_mat(
    fixtures$three_class,
    truth = "obs",
    estimate = "pred",
    dnn = c("", "")
  )$table

  expect_identical(actual, expected)
})

test_that("Summary method", {
  lst <- data_three_class()
  three_class <- lst$three_class
  three_class_tb <- lst$three_class_tb

  # Metrics `summary()` of a `conf_mat` is expected to report, in this order.
  # Defined once and shared by the 3-class and 2-class checks below so the
  # two expectations cannot drift apart.
  expected_metrics <- c(
    "accuracy", "kap", "sens", "spec", "ppv", "npv", "mcc", "j_index",
    "bal_accuracy", "detection_prevalence", "precision", "recall",
    "f_meas"
  )

  # The top-left 2x2 sub-table exercises the two-class case.
  two_class_tb <- three_class_tb[1:2, 1:2]

  sum_obj_3 <- summary(conf_mat(three_class, obs, pred))
  sum_obj_2 <- summary(conf_mat(two_class_tb))

  # Three classes: metric names, and the first row equals a direct accuracy().
  expect_equal(sum_obj_3$.metric, expected_metrics)
  expect_equal(
    dplyr::slice(sum_obj_3, 1),
    accuracy(three_class_tb)
  )

  # Two classes: metric names, and the sens row equals a direct sens().
  expect_equal(sum_obj_2$.metric, expected_metrics)
  expect_equal(
    dplyr::filter(sum_obj_2, .metric == "sens"),
    sens(two_class_tb)
  )
})

test_that("Summary method - estimators pass through", {
  lst <- data_three_class()
  three_class <- lst$three_class

  # Build the confusion matrix once; only the `estimator` passed to
  # `summary()` differs between the two checks.
  cm <- conf_mat(three_class, obs, pred)

  sum_obj_micro <- summary(cm, estimator = "micro")
  sum_obj_macrow <- summary(cm, estimator = "macro_weighted")

  # Every row's estimator is either "multiclass" or the requested one.
  # `%in%` is vectorized, so no per-element `vapply()` is needed.
  expect_true(
    all(sum_obj_micro$.estimator %in% c("multiclass", "micro"))
  )
  expect_true(
    all(sum_obj_macrow$.estimator %in% c("multiclass", "macro_weighted"))
  )
})

test_that("summary method - `event_level` passes through (#160)", {
  df <- data_powers()$df_2_1

  # Make "Irrelevant" the first level of both the truth and prediction
  # columns. The summary computed with `event_level = "second"` on this
  # releveled copy should match the default summary on the original data.
  df_rev <- df
  for (col in c("truth", "prediction")) {
    df_rev[[col]] <- stats::relevel(df_rev[[col]], "Irrelevant")
  }

  default_summary <- as.data.frame(summary(conf_mat(df, truth, prediction)))
  second_summary <- as.data.frame(
    summary(conf_mat(df_rev, truth, prediction), event_level = "second")
  )

  expect_equal(default_summary, second_summary)
})

test_that("Grouped conf_mat() handler works", {
  # Reference: the confusion matrix of the "Fold01" resample on its own.
  fold01 <- dplyr::filter(hpc_cv, Resample == "Fold01")
  expected_first <- conf_mat(fold01, obs, pred)

  by_resample <- dplyr::group_by(hpc_cv, Resample)
  res <- conf_mat(by_resample, obs, pred)

  # One row per group, with the matrices in a list-column.
  expect_s3_class(res, "tbl_df")
  expect_type(res$conf_mat, "list")
  expect_equal(res$conf_mat[[1]], expected_first)
})

test_that("Multilevel table -> conf_mat", {
  # Cross-tabulate by hand, with predictions on rows and truth on columns,
  # using the same dimension names the data frame method produces.
  tab <- table(hpc_cv$pred, hpc_cv$obs, dnn = c("Prediction", "Truth"))

  from_table <- conf_mat(tab)
  from_df <- conf_mat(hpc_cv, obs, pred)

  expect_identical(from_table, from_df)
})

test_that("Multilevel matrix -> conf_mat", {
  # Same hand-built cross-tabulation, but stripped to a plain matrix so the
  # matrix method is exercised instead of the table method.
  tab <- table(hpc_cv$pred, hpc_cv$obs, dnn = c("Prediction", "Truth"))
  mat <- as.matrix(tab)

  from_matrix <- conf_mat(mat)
  from_df <- conf_mat(hpc_cv, obs, pred)

  expect_identical(from_matrix, from_df)
})

test_that("Tidy method", {
  tidied <- tidy(conf_mat(hpc_cv, obs, pred))

  # The first tidied row is the cell named `cell_1_1` and holds its count.
  expect_equal(tidied$name[[1]], "cell_1_1")
  expect_equal(tidied$value[[1]], 1620)
})

test_that("can change the dimnames names", {
  custom_dnn <- c("Foo", "Bar")
  cm <- conf_mat(two_class_example, truth, predicted, dnn = custom_dnn)

  # `dnn` labels the two dimensions of the underlying table.
  expect_identical(names(dimnames(cm$table)), custom_dnn)
})

test_that("case weights are supported in data frame method", {
  two_class_example$weight <- read_weights_two_class_example()

  # Reference: a weighted cross-tabulation built with `yardstick_table()`,
  # coerced to a `table`.
  expected <- yardstick_table(
    truth = two_class_example$truth,
    estimate = two_class_example$predicted,
    case_weights = two_class_example$weight
  )
  expected <- as.table(expected)

  actual <- conf_mat(
    two_class_example,
    truth,
    predicted,
    case_weights = weight
  )

  expect_identical(actual$table, expected)
})

test_that("case weights propagate through to summary method metrics", {
  two_class_example$weight <- read_weights_two_class_example()

  cm <- conf_mat(two_class_example, truth, predicted, case_weights = weight)
  metrics <- summary(cm)

  # Locals use distinct names so they do not mask the `accuracy()` metric
  # function, which is called directly below.
  from_summary <- metrics$.estimate[metrics$.metric == "accuracy"]

  direct <- accuracy(two_class_example, truth, predicted, case_weights = weight)
  from_metric <- direct[[".estimate"]]

  # The weighted accuracy reported by `summary()` must equal the weighted
  # accuracy computed directly from the data frame.
  expect_identical(from_summary, from_metric)
})

test_that("case weights are supported in grouped-df method", {
  hpc_cv$weight <- read_weights_hpc_cv()

  # Reference for the first group: a weighted cross-tabulation of the
  # "Fold01" rows, passed to `conf_mat()`.
  fold01 <- dplyr::filter(hpc_cv, Resample == "Fold01")
  expected <- conf_mat(
    yardstick_table(
      truth = fold01$obs,
      estimate = fold01$pred,
      case_weights = fold01$weight
    )
  )

  by_resample <- dplyr::group_by(hpc_cv, Resample)
  result <- conf_mat(by_resample, obs, pred, case_weights = weight)

  expect_identical(result$conf_mat[[1]], expected)
})

test_that("`conf_mat()` returns a double table with or without case weights", {
  two_class_example$weight <- read_weights_two_class_example()

  unweighted <- conf_mat(two_class_example, truth, predicted)
  weighted <- conf_mat(
    two_class_example,
    truth,
    predicted,
    case_weights = weight
  )

  # Both variants must expose a `table` whose storage is double.
  for (cm in list(unweighted, weighted)) {
    expect_s3_class(cm$table, "table")
    expect_type(cm$table, "double")
  }
})

test_that("`as.data.frame.table()` method is run on the underlying `table` object", {
  # tune relies on this. `out$table` must stay a `table` (even though case
  # weights can make its counts non-integer) so that `as.data.frame()`
  # dispatches to the table method and returns Prediction/Truth/Freq columns.
  two_class_example$weight <- read_weights_two_class_example()

  tables <- list(
    conf_mat(two_class_example, truth, predicted)$table,
    conf_mat(two_class_example, truth, predicted, case_weights = weight)$table
  )

  for (tab in tables) {
    expect_named(as.data.frame(tab), c("Prediction", "Truth", "Freq"))
  }
})

test_that("`...` is deprecated with a warning", {
  # Older base R signalled this warning with a different condition class,
  # which would change the recorded snapshot.
  skip_if(getRversion() <= "3.5.3", "Base R used a different deprecated warning class.")
  # Ask lifecycle to signal deprecations as warnings for the duration of this
  # test, so the snapshot captures the warning.
  rlang::local_options(lifecycle_verbosity = "warning")

  # `foo = 1` is not a named argument, so it lands in `...`; the data frame
  # method should emit the deprecation warning recorded in the snapshot.
  expect_snapshot(conf_mat(two_class_example, truth, predicted, foo = 1))

  # Same check through the grouped data frame method.
  hpc_cv <- dplyr::group_by(hpc_cv, Resample)
  expect_snapshot(conf_mat(hpc_cv, obs, pred, foo = 1))
})

test_that("Errors are thrown correctly", {
  lst <- data_three_class()
  three_class <- lst$three_class
  # Copy of `obs` with its levels relabelled in reverse order. `levels<-`
  # renames levels by position, so the level set is unchanged but its order
  # is reversed.
  three_class$obs_rev <- three_class$obs
  levels(three_class$obs_rev) <- rev(levels(three_class$obs))
  # A factor with a single level, recycled to every row.
  three_class$onelevel <- factor(1)

  # Each call below is expected to error; the exact messages are recorded in
  # the snapshot file.

  # Truth levels are in a different order from the estimate levels.
  expect_snapshot(
    error = TRUE,
    conf_mat(three_class, truth = obs_rev, estimate = pred, dnn = c("", ""))
  )

  # Truth has a single level while the estimate does not.
  expect_snapshot(
    error = TRUE,
    conf_mat(three_class, truth = onelevel, estimate = pred, dnn = c("", ""))
  )

  # Truth and estimate both have only a single level.
  expect_snapshot(
    error = TRUE,
    conf_mat(three_class, truth = onelevel, estimate = onelevel, dnn = c("", ""))
  )
})

test_that("Errors are thrown correctly - grouped", {
  lst <- data_three_class()
  three_class <- lst$three_class
  # Copy of `obs` with its levels relabelled in reverse order. `levels<-`
  # renames levels by position, so the level set is unchanged but its order
  # is reversed.
  three_class$obs_rev <- three_class$obs
  levels(three_class$obs_rev) <- rev(levels(three_class$obs))
  # A factor with a single level, recycled to every row.
  three_class$onelevel <- factor(1)
  # Route the calls below through the grouped data frame method.
  three_class <- dplyr::group_by(three_class, pred)

  # Each call below is expected to error; the exact messages are recorded in
  # the snapshot file.

  # Truth levels are in a different order from the estimate levels.
  expect_snapshot(
    error = TRUE,
    conf_mat(three_class, truth = obs_rev, estimate = pred, dnn = c("", ""))
  )

  # Truth has a single level while the estimate does not.
  expect_snapshot(
    error = TRUE,
    conf_mat(three_class, truth = onelevel, estimate = pred, dnn = c("", ""))
  )

  # Truth and estimate both have only a single level.
  expect_snapshot(
    error = TRUE,
    conf_mat(three_class, truth = onelevel, estimate = onelevel, dnn = c("", ""))
  )
})

test_that("conf_mat()'s errors when wrong things are passes", {
  # NOTE(review): "passes" in the test name reads as a typo for "passed", but
  # testthat keys recorded snapshots by test name, so renaming it requires
  # re-recording the snapshot file.

  # `truth` refers to a column that does not exist.
  expect_snapshot(
    error = TRUE,
    conf_mat(two_class_example, not_truth, predicted)
  )

  # `estimate` refers to a column that does not exist.
  expect_snapshot(
    error = TRUE,
    conf_mat(two_class_example, truth, not_predicted)
  )

  # Missing `truth` column, via the grouped data frame method.
  expect_snapshot(
    error = TRUE,
    conf_mat(
      dplyr::group_by(two_class_example, truth),
      truth = not_truth,
      estimate = predicted
    )
  )

  # Missing `estimate` column, via the grouped data frame method.
  expect_snapshot(
    error = TRUE,
    conf_mat(
      dplyr::group_by(two_class_example, truth),
      truth = truth,
      estimate = not_predicted
    )
  )
})

# Try the yardstick package in your browser
#
# Any scripts or data that you put into this service are public.
#
# yardstick documentation built on June 22, 2024, 7:07 p.m.