R/class-ppv.R

Defines functions ppv_multiclass ppv_binary ppv_table_impl ppv_vec ppv.matrix ppv.table ppv.data.frame ppv

Documented in ppv ppv.data.frame ppv_vec

#' Positive predictive value
#'
#' These functions calculate the [ppv()] (positive predictive value) of a
#' measurement system compared to a reference result (the "truth" or gold standard).
#' Highly related functions are [spec()], [sens()], and [npv()].
#'
#' The positive predictive value ([ppv()]) is defined as the percent of
#' predicted positives that are actually positive, while the
#' negative predictive value ([npv()]) is defined as the percent of predicted
#' negatives that are actually negative.
#'
#' @family class metrics
#' @family sensitivity metrics
#' @seealso [All class metrics][class-metrics]
#' @templateVar fn ppv
#' @template event_first
#' @template multiclass
#' @template return
#'
#' @inheritParams sens
#'
#' @param prevalence A numeric value between 0 and 1 for the rate of the
#'  "positive" class of the data. If `NULL` (the default), the prevalence is
#'  estimated from the data.
#'
#' @details
#' Suppose a 2x2 table with notation:
#'
#' \tabular{rcc}{ \tab Reference \tab \cr Predicted \tab Positive \tab Negative
#' \cr Positive \tab A \tab B \cr Negative \tab C \tab D \cr }
#'
#' The formulas used here are:
#'
#' \deqn{\text{Sensitivity} = \frac{A}{A + C}}
#'
#' \deqn{\text{Specificity} = \frac{D}{B + D}}
#'
#' \deqn{\text{Prevalence} = \frac{A + C}{A + B + C + D}}
#'
#' \deqn{\text{PPV} = \frac{\text{Sensitivity} \cdot \text{Prevalence}}{(\text{Sensitivity} \cdot \text{Prevalence}) + ((1 - \text{Specificity}) \cdot (1 - \text{Prevalence}))}}
#'
#' PPV is a metric that should be `r attr(ppv, "direction")`d. The output
#' ranges from `r metric_range_chr(ppv, 1)` to `r metric_range_chr(ppv, 2)`, with
#' `r metric_optimal(ppv)` indicating all predicted positives are true
#' positives.
#'
#' @author Max Kuhn
#'
#' @references
#'
#' Altman, D.G., Bland, J.M. (1994) ``Diagnostic tests 2:
#' predictive values,'' *British Medical Journal*, vol 309,
#' 102.
#'
#' @template examples-class
#' @examples
#' # Using a different value of 'prevalence'... if you are adding the metric to a
#' # metric set, you can create a new metric function with the updated argument
#' # value:
#'
#' ppv_alt_prev  <- metric_tweak("ppv_alt_prev", ppv, prevalence = 0.40)
#' multi_metrics <- metric_set(ppv, ppv_alt_prev)
#' multi_metrics(two_class_example, truth, estimate = predicted)
#'
#' @examples
#' # But what if we think that Class 1 only occurs 40% of the time?
#' ppv(two_class_example, truth, predicted, prevalence = 0.40)
#'
#' @export
ppv <- function(data, ...) UseMethod("ppv")
# Attach the metric metadata read by metric_set() and the documentation
# helpers: PPV lies in [0, 1] and larger values are better.
ppv <- new_class_metric(ppv, direction = "maximize", range = c(0, 1))

#' @rdname ppv
#' @export
ppv.data.frame <- function(
  data,
  truth,
  estimate,
  prevalence = NULL,
  estimator = NULL,
  na_rm = TRUE,
  case_weights = NULL,
  event_level = yardstick_event_level(),
  ...
) {
  # Column arguments are forwarded unevaluated (embraced) so that the
  # summarizer can resolve them inside `data`; `prevalence` is the only
  # ppv-specific option and reaches `ppv_vec()` through `fn_options`.
  class_metric_summarizer(
    name = "ppv",
    fn = ppv_vec,
    data = data,
    truth = {{ truth }},
    estimate = {{ estimate }},
    case_weights = {{ case_weights }},
    estimator = estimator,
    na_rm = na_rm,
    event_level = event_level,
    fn_options = list(prevalence = prevalence)
  )
}

#' @export
ppv.table <- function(
  data,
  prevalence = NULL,
  estimator = NULL,
  event_level = yardstick_event_level(),
  ...
) {
  check_table(data)
  # The table (and, through it, the matrix) method calls `ppv_table_impl()`
  # directly and never passes through `ppv_vec()`. Validate `prevalence`
  # here with the same rule `ppv_vec()` applies; otherwise an out-of-range
  # value would be plugged straight into the PPV formula.
  check_number_decimal(prevalence, min = 0, max = 1, allow_null = TRUE)
  estimator <- finalize_estimator(data, estimator)

  metric_tibbler(
    .metric = "ppv",
    .estimator = estimator,
    .estimate = ppv_table_impl(
      data,
      estimator = estimator,
      event_level = event_level,
      prevalence = prevalence
    )
  )
}

#' @export
ppv.matrix <- function(
  data,
  prevalence = NULL,
  estimator = NULL,
  event_level = yardstick_event_level(),
  ...
) {
  # Convert the matrix to a table and delegate to the table method.
  # Note that `...` is not forwarded.
  confusion <- as.table(data)

  ppv.table(
    confusion,
    event_level = event_level,
    estimator = estimator,
    prevalence = prevalence
  )
}

#' @export
#' @rdname ppv
ppv_vec <- function(
  truth,
  estimate,
  prevalence = NULL,
  estimator = NULL,
  na_rm = TRUE,
  case_weights = NULL,
  event_level = yardstick_event_level(),
  ...
) {
  # Validate scalar arguments before touching the data.
  check_bool(na_rm)
  check_number_decimal(prevalence, min = 0, max = 1, allow_null = TRUE)
  abort_if_class_pred(truth)
  estimate <- as_factor_from_class_pred(estimate)

  estimator <- finalize_estimator(truth, estimator)

  check_class_metric(truth, estimate, case_weights, estimator)

  # Missing values are handled exactly once. With `na_rm = TRUE` incomplete
  # observations are dropped. Otherwise any missing value makes the result
  # `NA`. (A second, identical copy of this block was previously repeated
  # here. It could never change anything and has been removed.)
  if (na_rm) {
    result <- yardstick_remove_missing(truth, estimate, case_weights)

    truth <- result$truth
    estimate <- result$estimate
    case_weights <- result$case_weights
  } else if (yardstick_any_missing(truth, estimate, case_weights)) {
    return(NA_real_)
  }

  data <- yardstick_table(truth, estimate, case_weights = case_weights)
  ppv_table_impl(data, estimator, event_level, prevalence = prevalence)
}

ppv_table_impl <- function(data, estimator, event_level, prevalence = NULL) {
  # Binary estimator: a single two-class PPV using the event level.
  # Otherwise: one PPV per class, averaged with the estimator's weights.
  if (!is_binary(estimator)) {
    class_weights <- get_weights(data, estimator)
    per_class <- ppv_multiclass(data, estimator, prevalence)
    return(stats::weighted.mean(per_class, class_weights))
  }

  ppv_binary(data, event_level, prevalence)
}

ppv_binary <- function(data, event_level, prevalence = NULL) {
  # Column of the table that holds the event class.
  event_col <- pos_val(data, event_level)

  # Without a user-supplied prevalence, use the share of the table's
  # total counted in the event column.
  if (is.null(prevalence)) {
    prevalence <- sum(data[, event_col]) / sum(data)
  }

  sensitivity <- sens_binary(data, event_level)
  specificity <- spec_binary(data, event_level)

  # Bayes' rule: P(true event | predicted event).
  hit_mass <- sensitivity * prevalence
  false_alarm_mass <- (1 - specificity) * (1 - prevalence)
  hit_mass / (hit_mass + false_alarm_mass)
}

ppv_multiclass <- function(data, estimator, prevalence = NULL) {
  # When `prevalence` is not supplied, it is derived from the table, and the
  # result matches precision. An explicitly supplied prevalence replaces
  # the derived value and therefore changes the result.
  if (is.null(prevalence)) {
    class_totals <- colSums(data)
    grand_totals <- rep(sum(data), times = nrow(data))

    # Micro averaging pools the counts across classes.
    if (is_micro(estimator)) {
      class_totals <- sum(class_totals)
      grand_totals <- sum(grand_totals)
    }

    prevalence <- class_totals / grand_totals
  }

  class_sens <- recall_multiclass(data, estimator)
  class_spec <- spec_multiclass(data, estimator)

  numer <- class_sens * prevalence
  denom <- numer + (1 - class_spec) * (1 - prevalence)

  # A non-positive denominator means PPV is undefined for that class.
  denom <- replace(denom, denom <= 0, NA_real_)

  numer / denom
}

Try the yardstick package in your browser

Any scripts or data that you put into this service are public.

yardstick documentation built on April 8, 2026, 1:06 a.m.