R/prob-binary-thresholds.R

Defines functions binary_threshold_curve

# Internal helper intended for use by `pr_curve()` and `roc_curve()`.
#
# Builds the running confusion counts needed to trace a binary threshold
# curve. Observations are ranked by decreasing `estimate`; every distinct
# `estimate` value acts as a threshold, and for each one the helper reports
# the case-weighted number of events (`tp`) and non-events (`fp`) whose
# estimate is greater than or equal to that threshold.
#
# Arguments:
# - `truth`: a factor with exactly two levels.
# - `estimate`: a numeric vector of scores, the same length as `truth`.
# - `...`: must be empty.
# - `event_level`: selects the event level. When `is_event_first()` is true
#   for it, the first level of `truth` is the event; otherwise the second.
# - `case_weights`: `NULL` (every observation weighs 1) or a vector castable
#   to double, the same length as `truth`.
#
# Returns a tibble with one row per unique threshold, in decreasing order,
# and the columns `threshold`, `tp` and `fp`.
#
# Observations whose weight is exactly zero are discarded before any
# counting; they cannot change the totals but can cause division by zero
# further downstream. A warning is raised when the events carry zero total
# weight.
#
# NOTE(review): missing values are not handled here; presumably callers
# remove them beforehand -- confirm.
binary_threshold_curve <- function(truth,
                                   estimate,
                                   ...,
                                   event_level = yardstick_event_level(),
                                   case_weights = NULL) {
  check_dots_empty()

  # Unweighted input is treated as "every observation weighs 1".
  if (is.null(case_weights)) {
    case_weights <- rep_len(1, length(truth))
  }
  case_weights <- vec_cast(case_weights, to = double())

  # Input contract checks; these are programmer errors, hence `.internal`.
  if (!is.factor(truth)) {
    abort("`truth` must be a factor.", .internal = TRUE)
  }
  if (nlevels(truth) != 2L) {
    abort("`truth` must have two levels.", .internal = TRUE)
  }
  if (!is.numeric(estimate)) {
    abort("`estimate` must be numeric.", .internal = TRUE)
  }
  n_obs <- length(truth)
  if (length(estimate) != n_obs) {
    abort("`truth` and `estimate` must be the same length.", .internal = TRUE)
  }
  if (length(case_weights) != n_obs) {
    abort("`truth` and `case_weights` must be the same length.", .internal = TRUE)
  }

  # Recode the factor as an integer indicator: 1 = event, 0 = non-event.
  event_code <- if (is_event_first(event_level)) 1L else 2L
  is_event <- as.integer(unclass(truth) == event_code)

  # Discard zero-weight observations (only subset when at least one exists).
  keep <- case_weights != 0
  if (!all(keep)) {
    is_event <- is_event[keep]
    estimate <- estimate[keep]
    case_weights <- case_weights[keep]
  }

  # Rank everything by decreasing score.
  ranking <- order(estimate, decreasing = TRUE)
  sorted_estimate <- estimate[ranking]
  sorted_event <- is_event[ranking]
  sorted_weight <- case_weights[ranking]

  event_weight <- sorted_event * sorted_weight
  non_event_weight <- (1 - sorted_event) * sorted_weight

  if (sum(event_weight) == 0L) {
    warn("There are `0` event cases in `truth`, results will be meaningless.")
  }

  # For tied estimates keep only the last occurrence, so the running totals
  # read at that position already include every observation at the threshold.
  last_of_ties <- which(!duplicated(sorted_estimate, fromLast = TRUE))

  dplyr::tibble(
    threshold = sorted_estimate[last_of_ties],
    tp = cumsum(event_weight)[last_of_ties],
    fp = cumsum(non_event_weight)[last_of_ties]
  )
}

Try the yardstick package in your browser

Any scripts or data that you put into this service are public.

yardstick documentation built on April 21, 2023, 9:08 a.m.