# tests/testthat/test-probcurve-pr_curve.R

test_that("Calculations are correct", {
  # Hand-computed PR curve for a tiny binary example
  known <- data.frame(
    lab = factor(c("Yes", "Yes", "No", "Yes"), levels = c("Yes", "No")),
    score = c(0.9, 0.4, 0.35, 0.7)
  )

  # Expected curve: an `Inf` starting threshold, then one point per
  # unique probability value
  expected <- list(
    .threshold = c(Inf, 0.9, 0.7, 0.4, 0.35),
    recall = c(0, 1 / 3, 2 / 3, 1, 1),
    precision = c(1, 1, 1, 1, 0.75)
  )

  expect_equal(as.list(pr_curve(known, truth = "lab", "score")), expected)

  # Multiclass input produces one-vs-all curves, one per outcome level
  multi <- pr_curve(hpc_cv, obs, VF:L)

  expect_equal(colnames(multi), c(".level", ".threshold", "recall", "precision"))
  expect_equal(unique(multi$.level), levels(hpc_cv$obs))
})

test_that("na_rm = FALSE errors if missing values are present", {
  # Inject a single missing probability into the example data
  df <- two_class_example
  df[["Class1"]][1] <- NA

  # The error message is pinned via snapshot
  expect_snapshot(
    error = TRUE,
    pr_curve_vec(df$truth, df$Class1, na_rm = FALSE)
  )
})

test_that("Case weights calculations are correct", {
  # grouped multiclass (one-vs-all) weighted example matches expanded equivalent
  w <- rep(1, nrow(hpc_cv))
  w[c(100, 200, 150, 2)] <- 5
  hpc_cv$weight <- w

  hpc_cv <- dplyr::group_by(hpc_cv, Resample)

  # An integer case weight of `k` should behave exactly like repeating
  # that row `k` times
  rows <- vec_rep_each(seq_len(nrow(hpc_cv)), times = hpc_cv$weight)
  hpc_cv_expanded <- hpc_cv[rows, ]

  expect_identical(
    pr_curve(hpc_cv, obs, VF:L, case_weights = weight),
    pr_curve(hpc_cv_expanded, obs, VF:L)
  )

  # zero weights don't affect the curve
  # If they weren't removed, we'd get a `NaN` from a division by zero issue
  df <- dplyr::tibble(
    truth = factor(c("b", "a", "b", "a", "a"), levels = c("a", "b")),
    a = c(0.75, 0.7, 0.4, 0.9, 0.8),
    weight = c(0, 1, 3, 0, 5)
  )

  nonzero <- df[df$weight != 0, ]

  expect_identical(
    pr_curve(df, truth, a, case_weights = weight),
    pr_curve(nonzero, truth, a, case_weights = weight)
  )

  # Weighted binary curve agrees with the stored sklearn reference values
  two_class_example$weight <- read_weights_two_class_example()

  expect_identical(
    pr_curve(two_class_example, truth, Class1, case_weights = weight),
    read_pydata("py-pr-curve")$case_weight$binary
  )
})

test_that("works with hardhat case weights", {
  # Plain numeric weights and hardhat frequency weights must give the
  # same curve
  df <- data.frame(
    truth = factor(c("Yes", "Yes", "No", "Yes", "No"), levels = c("Yes", "No")),
    estimate = c(0.9, 0.8, 0.7, 0.68, 0.5),
    weight = c(2, 1, 1, 3, 2)
  )

  numeric_curve <- pr_curve(df, truth, estimate, case_weights = weight)

  df$weight <- hardhat::frequency_weights(df$weight)
  hardhat_curve <- pr_curve(df, truth, estimate, case_weights = weight)

  expect_identical(numeric_curve, hardhat_curve)
})

test_that("errors with class_pred input", {
  skip_if_not_installed("probably")

  # `class_pred` truth vectors are not supported and should error.
  # NOTE: the unused `fct_truth` setup (a factor with an injected `NA`)
  # was dead code left over from a related test and has been removed.
  cp_truth <- probably::as_class_pred(two_class_example$truth, which = 1)
  estimate <- two_class_example$Class1

  expect_snapshot(
    error = TRUE,
    pr_curve_vec(cp_truth, estimate)
  )
})

test_that("na_rm argument check", {
  # `na_rm` must be a single logical; a string triggers an argument-check
  # error whose message is pinned via snapshot
  expect_snapshot(
    error = TRUE,
    pr_curve_vec(1, 1, na_rm = "yes")
  )
})

test_that("PR - perfect separation (#93)", {
  # Every "x" is scored above every "y": precision stays at 1 until
  # recall reaches 1, and the AUC is exactly 1
  df <- data.frame(
    truth = factor(c("x", "x", "y", "y")),
    prob = c(0.9, 0.8, 0.4, 0.3)
  )

  curve <- pr_curve(df, truth, prob)

  expect_equal(curve$.threshold, c(Inf, 0.9, 0.8, 0.4, 0.3))
  expect_equal(curve$recall, c(0, 0.5, 1, 1, 1))
  expect_equal(curve$precision, c(1, 1, 1, 2 / 3, 1 / 2))

  expect_equal(pr_auc(df, truth, prob)$.estimate, 1)
})

test_that("PR - perfect separation - duplicates probs at the end (#93)", {
  # Tied probabilities on the negative class collapse into a single
  # threshold entry at the tail of the curve
  df <- data.frame(
    truth = factor(c("x", "x", "y", "y")),
    prob = c(0.9, 0.8, 0.3, 0.3)
  )

  curve <- pr_curve(df, truth, prob)

  expect_equal(curve$.threshold, c(Inf, 0.9, 0.8, 0.3))
  expect_equal(curve$recall, c(0, 0.5, 1, 1))
  expect_equal(curve$precision, c(1, 1, 1, 1 / 2))

  expect_equal(pr_auc(df, truth, prob)$.estimate, 1)
})

test_that("PR - perfect separation - duplicates probs at the start (#93)", {
  # Tied probabilities on the positive class collapse into a single
  # threshold entry at the head of the curve
  df <- data.frame(
    truth = factor(c("x", "x", "y", "y")),
    prob = c(0.9, 0.9, 0.4, 0.3)
  )

  curve <- pr_curve(df, truth, prob)

  expect_equal(curve$.threshold, c(Inf, 0.9, 0.4, 0.3))
  expect_equal(curve$recall, c(0, 1, 1, 1))
  expect_equal(curve$precision, c(1, 1, 2 / 3, 1 / 2))

  expect_equal(pr_auc(df, truth, prob)$.estimate, 1)
})

test_that("PR - same class prob, different prediction value (#93)", {
  # One "x" and one "y" share the top probability (0.9), so precision
  # dips immediately before recovering as recall grows
  df <- data.frame(
    truth = factor(c("x", "y", "y", "x", "x")),
    prob = c(0.9, 0.9, 0.8, 0.4, 0.3)
  )

  curve <- pr_curve(df, truth, prob)

  expect_equal(curve$.threshold, c(Inf, 0.9, 0.8, 0.4, 0.3))
  expect_equal(curve$recall, c(0, 1 / 3, 1 / 3, 2 / 3, 1))
  expect_equal(curve$precision, c(1, 1 / 2, 1 / 3, 1 / 2, 3 / 5))

  expect_equal(pr_auc(df, truth, prob)$.estimate, 0.572222222222222)
})

test_that("PR - zero row data frame works", {
  df <- data.frame(y = factor(levels = c("a", "b")), x = double())

  # With no data the curve degenerates to the single starting point:
  # recall 0, precision 1 at an `Inf` threshold
  expected <- dplyr::tibble(
    .threshold = Inf,
    recall = 0,
    precision = 1
  )
  class(expected) <- c("pr_df", class(expected))

  # A warning/message on empty input is pinned via snapshot
  expect_snapshot(
    out <- pr_curve(df, y, x)
  )

  expect_identical(out, expected)
})

test_that("PR - No `truth` gives `NaN` recall values", {
  # With zero events in `truth`, recall divides by zero and becomes `NaN`
  df <- data.frame(
    y = factor(c("b", "b"), levels = c("a", "b")),
    x = c(0.1, 0.2)
  )

  expect_warning(curve <- pr_curve(df, y, x))

  expect_identical(curve$recall, c(0, NaN, NaN))
})

test_that("sklearn equivalent", {
  # Unweighted binary curve must match sklearn's precision_recall_curve()
  # output stored in the python reference data
  res <- pr_curve(two_class_example, truth, Class1)

  expect_identical(res, read_pydata("py-pr-curve")$binary)
})

# NOTE(review): the lines below are a scraped web-page footer, not part of
# the test file; commented out so the file parses as valid R.
# Try the yardstick package in your browser
# Any scripts or data that you put into this service are public.
# yardstick documentation built on April 8, 2026, 1:06 a.m.