# R/dplot3_calibration.R
#
# Defines functions dplot3_calibration
#
# Documented in dplot3_calibration

# dplot3_calibration.R
# ::rtemis::
# 2023 EDG lambdamd.original

#' Draw calibration plot
#'
#' Bins the estimated probabilities, then plots the mean estimated probability
#' of each bin against the observed proportion of positive-class cases in that
#' bin (the empirical risk), together with a diagonal reference line.
#'
#' @param true.labels Factor or list of factors with true class labels.
#' @param est.prob Numeric vector or list of numeric vectors with predicted
#' probabilities. Each element must have the same length as the corresponding
#' element of `true.labels`.
#' @param bin.method Character: "equidistant" or "quantile": Method to bin the
#' estimated probabilities. "equidistant" uses `n.bins` equal-width bins over
#' the interval from 0 to 1; "quantile" uses the sample quantiles of each
#' element of `est.prob` as bin edges.
#' @param n.bins Integer: Number of bins to split the data into.
#' @param pos.class.idi Integer: Index of the positive class in
#' `levels(true.labels)`. It must resolve to the same level name for every
#' element of `true.labels`.
#' @param xlab Character: x-axis label
#' @param ylab Character: y-axis label
#' @param mode Character: Plot mode, passed to [dplot3_xy]
#' @param ... Additional arguments passed to [dplot3_xy]
#'
#' @details
#' Bins are half-open, `[lower, upper)`, except the last bin, which is closed,
#' `[lower, upper]`, so that a probability equal to the top edge (1 for
#' "equidistant", the sample maximum for "quantile") is counted. Empty bins
#' yield `NaN`. Missing values are not removed: `NA`s in `est.prob` or
#' `true.labels` propagate to `NA` bin statistics, and with "quantile" binning
#' an `NA` in `est.prob` makes [stats::quantile] raise an error.
#'
#' @return The object returned by [dplot3_xy].
#' @author EDG
#' @export
#' @examples
#' \dontrun{
#' data(segment_logistic, package = "probably")
#'
#' # Plot the calibration curve of the original predictions
#' dplot3_calibration(
#'   true.labels = segment_logistic$Class,
#'   est.prob = segment_logistic$.pred_poor,
#'   n.bins = 10,
#'   pos.class.idi = 2
#' )
#'
#' # Plot the calibration curve of the calibrated predictions
#' dplot3_calibration(
#'   true.labels = segment_logistic$Class,
#'   est.prob = calibrate(
#'     segment_logistic$Class,
#'     segment_logistic$.pred_poor
#'   )$fitted.values,
#'   n.bins = 10,
#'   pos.class.idi = 2
#' )
#' }
dplot3_calibration <- function(true.labels, est.prob,
                               n.bins = 10,
                               bin.method = c("equidistant", "quantile"),
                               pos.class.idi = 1,
                               xlab = "Mean estimated probability",
                               ylab = "Empirical risk",
                               #    conf_level = .95,
                               mode = "markers+lines", ...) {
  bin.method <- match.arg(bin.method)

  # Accept a single vector or a list of vectors (one curve per element)
  if (!is.list(true.labels)) {
    true.labels <- list(true_labels = true.labels)
  }
  if (!is.list(est.prob)) {
    est.prob <- list(estimated_prob = est.prob)
  }

  # Validate inputs ----
  if (length(true.labels) != length(est.prob)) {
    stop(
      "'true.labels' and 'est.prob' must have the same number of elements.",
      call. = FALSE
    )
  }
  if (!is.numeric(n.bins) || length(n.bins) != 1L || is.na(n.bins) ||
    n.bins < 1 || n.bins != round(n.bins)) {
    stop("'n.bins' must be a single positive whole number.", call. = FALSE)
  }
  if (!all(vapply(true.labels, is.factor, logical(1)))) {
    stop("'true.labels' must be a factor or a list of factors.", call. = FALSE)
  }
  same_length <- vapply(
    seq_along(est.prob),
    \(i) length(true.labels[[i]]) == length(est.prob[[i]]),
    logical(1)
  )
  if (!all(same_length)) {
    stop(
      "Each element of 'true.labels' must have the same length as the ",
      "corresponding element of 'est.prob'.",
      call. = FALSE
    )
  }
  if (!is.numeric(pos.class.idi) || length(pos.class.idi) != 1L ||
    is.na(pos.class.idi) || pos.class.idi < 1 ||
    any(pos.class.idi > vapply(true.labels, nlevels, integer(1)))) {
    stop(
      "'pos.class.idi' must be a single index between 1 and the number of ",
      "levels of 'true.labels'.",
      call. = FALSE
    )
  }

  # Ensure same positive class across all inputs
  pos_class <- vapply(
    true.labels, \(x) levels(x)[pos.class.idi], character(1)
  )
  if (length(unique(pos_class)) != 1L) {
    stop(
      "All elements of 'true.labels' must have the same positive class.",
      call. = FALSE
    )
  }
  pos_class <- pos_class[[1L]]

  # Create bin edges ----
  # Equal-width edges on [0, 1] double as the probabilities for quantile bins
  bin_edges <- seq(0, 1, length.out = n.bins + 1)
  breaks <- switch(bin.method,
    equidistant = rep(list(bin_edges), length(est.prob)),
    quantile = lapply(est.prob, \(x) quantile(x, probs = bin_edges))
  )

  # Per-bin mean estimated probability and empirical risk ----
  bin_stats <- lapply(seq_along(est.prob), \(i) {
    prob <- est.prob[[i]]
    is_pos <- true.labels[[i]] == pos_class
    brk <- breaks[[i]]
    vapply(seq_len(n.bins), \(j) {
      # Half-open [lo, hi) bins, except the last, which is closed [lo, hi]
      # so the top edge (1, or the sample maximum for quantile bins) counts.
      below_upper <- if (j == n.bins) {
        prob <= brk[j + 1L]
      } else {
        prob < brk[j + 1L]
      }
      idl <- prob >= brk[j] & below_upper
      c(mean = mean(prob[idl]), risk = sum(is_pos[idl]) / sum(idl))
    }, numeric(2))
  })
  mean_bin_prob <- lapply(bin_stats, \(s) unname(s["mean", ]))
  window_empirical_risk <- lapply(bin_stats, \(s) unname(s["risk", ]))

  # Calculate confidence intervals
  # NOTE(review): unfinished sketch; `total` must count cases with sum() of
  # the bin mask, not length() of it, and the last bin should be closed.
  # confint <- sapply(seq_len(n.bins), \(i) {
  #     events <- length(true.labels[true.labels == pos_class & est.prob >= breaks[i] & est.prob < breaks[i + 1]])
  #     total <- sum(est.prob >= breaks[i] & est.prob < breaks[i + 1])
  #     suppressWarnings(pt <- prop.test(
  #         events, total,
  #         conf.level = conf_level
  #     ))
  #     pt$conf.int
  # })

  # Plot ----
  dplot3_xy(
    mean_bin_prob, window_empirical_risk,
    xlab = xlab,
    ylab = ylab,
    axes.square = TRUE, diagonal = TRUE,
    xlim = c(0, 1), ylim = c(0, 1),
    mode = mode, ...
  )
} # rtemis::dplot3_calibration
# egenn/rtemis documentation built on May 4, 2024, 7:40 p.m.