R/class-accuracy.R

Defines functions accuracy_table_impl accuracy_vec accuracy.matrix accuracy.table accuracy.data.frame accuracy

Documented in accuracy accuracy.data.frame accuracy_vec

#' Accuracy
#'
#' Accuracy is the proportion of the data that are predicted correctly.
#'
#' @family class metrics
#' @seealso [All class metrics][class-metrics]
#' @templateVar fn accuracy
#' @template return
#'
#' @section Multiclass:
#'
#' Accuracy extends naturally to multiclass scenarios. Because
#' of this, macro and micro averaging are not implemented.
#'
#' @inheritParams sens
#'
#' @details
#' Suppose a 2x2 table with notation:
#'
#' \tabular{rcc}{ \tab Reference \tab \cr Predicted \tab Positive \tab Negative
#' \cr Positive \tab A \tab B \cr Negative \tab C \tab D \cr }
#'
#' The formula used here is:
#'
#' \deqn{\text{Accuracy} = \frac{A + D}{A + B + C + D}}
#'
#' Accuracy is a metric that should be `r attr(accuracy, "direction")`d. The
#' output ranges from `r metric_range_chr(accuracy, 1)` to
#' `r metric_range_chr(accuracy, 2)`, with `r metric_optimal(accuracy)` indicating
#' perfect predictions.
#'
#' @author Max Kuhn
#'
#' @export
#' @examples
#' library(dplyr)
#' data("two_class_example")
#' data("hpc_cv")
#'
#' # Two class
#' accuracy(two_class_example, truth, predicted)
#'
#' # Multiclass
#' # accuracy() has a natural multiclass extension
#' hpc_cv |>
#'   filter(Resample == "Fold01") |>
#'   accuracy(obs, pred)
#'
#' # Groups are respected
#' hpc_cv |>
#'   group_by(Resample) |>
#'   accuracy(obs, pred)
accuracy <- function(data, ...) {
  # S3 generic: dispatch on the class of `data` to one of the `accuracy.*`
  # methods (data frame, table, and matrix methods are defined in this file).
  UseMethod("accuracy")
}
# Re-bind the generic after passing it through `new_class_metric()`, declaring
# that larger values are better ("maximize") and that the metric lies in [0, 1].
# NOTE(review): the roxygen text for this function reads
# `attr(accuracy, "direction")`, so the constructor presumably stores these
# values as attributes on the function -- confirm against `new_class_metric()`.
accuracy <- new_class_metric(
  accuracy,
  direction = "maximize",
  range = c(0, 1)
)

#' @export
#' @rdname accuracy
accuracy.data.frame <- function(
  data,
  truth,
  estimate,
  na_rm = TRUE,
  case_weights = NULL,
  ...
) {
  # Data frame method for `accuracy()`.
  #
  # `truth`, `estimate` and `case_weights` are captured unevaluated as
  # quosures and unquoted into `class_metric_summarizer()`, which is handed
  # `accuracy_vec` as the function to apply to `data`. `na_rm` is passed
  # through unchanged; `...` is accepted but not forwarded.
  truth_quo <- enquo(truth)
  estimate_quo <- enquo(estimate)
  weights_quo <- enquo(case_weights)

  class_metric_summarizer(
    name = "accuracy",
    fn = accuracy_vec,
    data = data,
    truth = !!truth_quo,
    estimate = !!estimate_quo,
    na_rm = na_rm,
    case_weights = !!weights_quo
  )
}

#' @export
accuracy.table <- function(data, ...) {
  # Table method for `accuracy()`: `data` is an already-tabulated
  # confusion table. `...` is accepted but unused.

  # Validate the table before computing anything from it.
  check_table(data)

  # Estimator label reported alongside the metric value.
  estimator_type <- finalize_estimator(data, metric_class = "accuracy")

  # Package the metric name, estimator label and computed accuracy using the
  # shared `metric_tibbler()` helper.
  metric_tibbler(
    .metric = "accuracy",
    .estimator = estimator_type,
    .estimate = accuracy_table_impl(data)
  )
}

#' @export
accuracy.matrix <- function(data, ...) {
  # Matrix method for `accuracy()`: coerce the matrix to a table and delegate
  # to the table method. `...` is accepted but not forwarded.
  accuracy.table(as.table(data))
}

#' @export
#' @rdname accuracy
accuracy_vec <- function(
  truth,
  estimate,
  na_rm = TRUE,
  case_weights = NULL,
  ...
) {
  # Vector interface for accuracy: takes the true and predicted classes
  # directly (optionally with case weights) and returns a single numeric
  # value. `...` is accepted but unused.

  # Argument checks and coercion.
  # NOTE(review): going by the helper names, `truth` may not be a `class_pred`
  # while a `class_pred` `estimate` is converted to a factor -- confirm against
  # the helper definitions.
  check_bool(na_rm)
  abort_if_class_pred(truth)
  estimate <- as_factor_from_class_pred(estimate)

  estimator <- finalize_estimator(truth, metric_class = "accuracy")
  check_class_metric(truth, estimate, case_weights, estimator)

  if (!na_rm) {
    # Missing values are kept: if any input has one, the result is NA.
    if (yardstick_any_missing(truth, estimate, case_weights)) {
      return(NA_real_)
    }
  } else {
    # Missing values are dropped: continue with the cleaned inputs.
    cleaned <- yardstick_remove_missing(truth, estimate, case_weights)

    truth <- cleaned[["truth"]]
    estimate <- cleaned[["estimate"]]
    case_weights <- cleaned[["case_weights"]]
  }

  # Cross-tabulate truth against estimate and score the resulting table.
  accuracy_table_impl(
    yardstick_table(truth, estimate, case_weights = case_weights)
  )
}

accuracy_table_impl <- function(x) {
  # Accuracy from a square confusion table `x`: the diagonal cells (where
  # prediction and reference agree) divided by the sum of all cells.
  # A table whose cells sum to zero yields NaN (0 / 0).
  n_correct <- sum(diag(x))
  n_total <- sum(x)

  n_correct / n_total
}

Try the yardstick package in your browser

Any scripts or data that you put into this service are public.

yardstick documentation built on April 8, 2026, 1:06 a.m.