R/prob-brier_class.R

Defines functions brier_factor brier_ind brier_class_binary brier_class_estimator_impl brier_class_vec brier_class.data.frame brier_class

Documented in brier_class brier_class.data.frame brier_class_vec

#' Brier score for classification models
#'
#' Compute the Brier score for a classification model.
#'
#' @family class probability metrics
#' @seealso [All probability metrics][prob-metrics]
#' @templateVar fn brier_class
#' @template return
#' @details
#' The Brier score is a metric that should be `r attr(brier_class, "direction")`d.
#' The output ranges from `r metric_range_chr(brier_class, 1)` to
#' `r metric_range_chr(brier_class, 2)`, with `r metric_optimal(brier_class)`
#' indicating perfect predictions.
#'
#' The Brier score is analogous to the mean squared error in regression models.
#' The differences between the binary class indicators and the corresponding
#' class probabilities are squared, summed over the classes, and averaged over
#' the observations.
#'
#' The formula used here is:
#'
#' \deqn{\text{Brier} = \frac{1}{2N} \sum_{i=1}^{N} \sum_{j=1}^{K} (y_{ij} - p_{ij})^2}
#'
#' where \eqn{N} is the number of observations, \eqn{K} is the number of classes,
#' \eqn{y_{ij}} is 1 if observation \eqn{i} belongs to class \eqn{j} and 0
#' otherwise, and \eqn{p_{ij}} is the predicted probability of observation
#' \eqn{i} for class \eqn{j}.
#'
#' This function uses the convention in Kruppa _et al_ (2014) and divides the
#' result by two. For class probabilities that sum to one, this keeps the
#' score between 0 and 1 regardless of the number of classes.
#'
#' Smaller values of the score are associated with better model performance.
#'
#' @section Multiclass:
#' Brier scores can be computed in the same way for any number of classes.
#' Because of this, no averaging types are supported.
#'
#' @inheritParams pr_auc
#' @template event_first
#'
#' @author Max Kuhn
#'
#' @references Kruppa, J., Liu, Y., Diener, H.-C., Holste, T., Weimar, C.,
#' König, I. R., and Ziegler, A. (2014) Probability estimation with machine
#' learning methods for dichotomous and multicategory outcome: Applications.
#' Biometrical Journal, 56 (4): 564-583.
#' @examples
#' # Two class
#' data("two_class_example")
#' brier_class(two_class_example, truth, Class1)
#'
#' # Multiclass
#' library(dplyr)
#' data(hpc_cv)
#'
#' # You can use the col1:colN tidyselect syntax
#' hpc_cv |>
#'   filter(Resample == "Fold01") |>
#'   brier_class(obs, VF:L)
#'
#' # Groups are respected
#' hpc_cv |>
#'   group_by(Resample) |>
#'   brier_class(obs, VF:L)
#'
#' @export
brier_class <- function(data, ...) {
  UseMethod("brier_class")
}
# Wrap the generic as a probability metric, recording its optimization
# direction and its output range (used by the templated docs above).
brier_class <- new_prob_metric(
  brier_class,
  direction = "minimize",
  range = c(0, 1)
)

#' @export
#' @rdname brier_class
brier_class.data.frame <- function(
  data,
  truth,
  ...,
  na_rm = TRUE,
  event_level = yardstick_event_level(),
  case_weights = NULL
) {
  # Capture the user-supplied column expressions once, then inject them into
  # the shared probability-metric summarizer.
  truth_quo <- enquo(truth)
  weights_quo <- enquo(case_weights)

  prob_metric_summarizer(
    name = "brier_class",
    fn = brier_class_vec,
    data = data,
    truth = !!truth_quo,
    ...,
    na_rm = na_rm,
    event_level = event_level,
    case_weights = !!weights_quo
  )
}

#' @rdname brier_class
#' @export
brier_class_vec <- function(
  truth,
  estimate,
  na_rm = TRUE,
  event_level = yardstick_event_level(),
  case_weights = NULL,
  ...
) {
  check_bool(na_rm)
  abort_if_class_pred(truth)

  # Decide between the binary and multiclass paths from the outcome levels,
  # then validate the inputs against that estimator.
  estimator <- finalize_estimator(truth, metric_class = "brier_class")
  check_prob_metric(truth, estimate, case_weights, estimator)

  if (!na_rm) {
    # Without NA removal, any missing input makes the score undefined.
    if (yardstick_any_missing(truth, estimate, case_weights)) {
      return(NA_real_)
    }
  } else {
    complete <- yardstick_remove_missing(truth, estimate, case_weights)
    truth <- complete$truth
    estimate <- complete$estimate
    case_weights <- complete$case_weights
  }

  brier_class_estimator_impl(
    truth = truth,
    estimate = estimate,
    estimator = estimator,
    event_level = event_level,
    case_weights = case_weights
  )
}

# Dispatch to the binary or multiclass Brier computation based on the
# finalized `estimator`. Only the binary path needs `event_level`.
brier_class_estimator_impl <- function(
  truth,
  estimate,
  estimator,
  event_level,
  case_weights
) {
  if (!is_binary(estimator)) {
    return(brier_factor(truth, estimate, case_weights))
  }

  brier_class_binary(truth, estimate, event_level, case_weights)
}

# Binary Brier score. `estimate` holds the event-class probability. When the
# event is the second factor level, `truth` is releveled so the event becomes
# the first level. The estimate is then expanded into a two-column
# (event, non-event) probability matrix and scored like the multiclass case.
brier_class_binary <- function(truth, estimate, event_level, case_weights) {
  event_is_second <- !is_event_first(event_level)
  if (event_is_second) {
    truth <- stats::relevel(truth, ref = levels(truth)[[2]])
  }

  prob_matrix <- matrix(c(estimate, 1 - estimate), ncol = 2)

  brier_factor(truth, prob_matrix, case_weights)
}

# Brier score when `truth` is already a vector or matrix of 0/1 indicators.
#
# `truth` and `estimate` are numeric vectors or matrices with one column per
# class. A plain vector is treated as a single column. In the binary case
# (`estimate` has one column, `truth` has two), `estimate` is taken as the
# probability for the first column and its complement is appended as the
# second column.
#
# Returns the case-weighted mean over observations of the per-row summed
# squared residuals, halved (Kruppa et al. convention):
#   sum_i w_i * sum_j (y_ij - p_ij)^2 / (2 * sum_i w_i)
# With `case_weights = NULL` every row gets weight 1. Rows whose case weight is
# `NA` are dropped.
brier_ind <- function(truth, estimate, case_weights = NULL) {
  if (is.vector(truth)) {
    truth <- matrix(truth, ncol = 1)
  }

  if (is.vector(estimate)) {
    estimate <- matrix(estimate, ncol = 1)
  }

  # In the binary case, expand the single probability column into
  # (p, 1 - p) so it lines up with the two indicator columns.
  if (ncol(estimate) == 1 && ncol(truth) == 2) {
    estimate <- unname(estimate)
    estimate <- vec_cbind(estimate, 1 - estimate, .name_repair = "unique_quiet")
  }

  resids <- (truth - estimate)^2

  if (is.null(case_weights)) {
    case_weights <- rep(1, nrow(resids))
  }

  not_missing <- !is.na(case_weights)
  resids <- resids[not_missing, , drop = FALSE]
  case_weights <- case_weights[not_missing]

  # Weighted mean of the per-row squared errors. Dividing by the total weight
  # already normalizes, so the weights are used as given. Any nonlinear
  # transform (such as a softmax) would change their relative sizes and give
  # zero-weight rows nonzero influence. `resids * case_weights` recycles the
  # length-`nrow` weight vector down each column, so row i is scaled by
  # weight i.
  sum(resids * case_weights) / (2 * sum(case_weights))
}

# Brier score when `truth` is a factor. The outcome is one-hot encoded (one
# indicator column per level), any case weights are cast to double, and the
# indicator-based computation in `brier_ind()` does the rest.
brier_factor <- function(truth, estimate, case_weights = NULL) {
  one_hot <- hardhat::fct_encode_one_hot(truth)
  weights <- vctrs::vec_cast(case_weights, to = double())

  brier_ind(truth = one_hot, estimate = estimate, case_weights = weights)
}

Try the yardstick package in your browser

Any scripts or data that you put into this service are public.

yardstick documentation built on April 8, 2026, 1:06 a.m.