# dplot3_calibration.R
# ::rtemis::
# 2023 EDG lambdamd.original
#' Draw calibration plot
#'
#' Bins the estimated probabilities and plots, for each bin, the mean estimated
#' probability against the observed proportion of positive-class cases, using
#' [dplot3_xy].
#'
#' Bins are closed on the left and open on the right, except for the last bin,
#' which also includes its upper break so that the largest estimated
#' probability (e.g. exactly 1 with `bin.method = "equidistant"`) is not
#' dropped. Bins that contain no cases yield `NaN` for both coordinates.
#'
#' @param true.labels Factor or list of factors with true class labels
#' @param est.prob Numeric vector or list of numeric vectors with predicted
#' probabilities. Must pair up with `true.labels`: same number of elements,
#' and each element the same length as the corresponding labels.
#' @param n.bins Integer: Number of windows to split the data into
#' @param bin.method Character: "quantile" or "equidistant": Method to bin the
#' estimated probabilities.
#' @param pos.class.idi Integer: Index of the positive class among the levels
#' of `true.labels`
#' @param xlab Character: x-axis label
#' @param ylab Character: y-axis label
#' @param mode Character: Plot mode, passed to [dplot3_xy]
#' @param ... Additional arguments passed to [dplot3_xy]
#'
#' @return The value returned by [dplot3_xy]
#' @author EDG
#' @export
#' @examples
#' \dontrun{
#' data(segment_logistic, package = "probably")
#'
#' # Plot the calibration curve of the original predictions
#' dplot3_calibration(
#'   true.labels = segment_logistic$Class,
#'   est.prob = segment_logistic$.pred_poor,
#'   n.bins = 10,
#'   pos.class.idi = 2
#' )
#'
#' # Plot the calibration curve of the calibrated predictions
#' dplot3_calibration(
#'   true.labels = segment_logistic$Class,
#'   est.prob = calibrate(
#'     segment_logistic$Class,
#'     segment_logistic$.pred_poor
#'   )$fitted.values,
#'   n.bins = 10,
#'   pos.class.idi = 2
#' )
#' }
dplot3_calibration <- function(true.labels, est.prob,
                               n.bins = 10,
                               bin.method = c("equidistant", "quantile"),
                               pos.class.idi = 1,
                               xlab = "Mean estimated probability",
                               ylab = "Empirical risk",
                               # conf_level = .95,
                               mode = "markers+lines", ...) {
  bin.method <- match.arg(bin.method)
  stopifnot(
    is.numeric(n.bins), length(n.bins) == 1, !is.na(n.bins), n.bins >= 1
  )

  # Wrap single inputs so that everything below works on lists
  if (!is.list(true.labels)) {
    true.labels <- list(true_labels = true.labels)
  }
  if (!is.list(est.prob)) {
    est.prob <- list(estimated_prob = est.prob)
  }

  # Ensure same number of inputs
  stopifnot(length(true.labels) == length(est.prob))

  # Ensure each set of labels pairs up with its probabilities; a mismatch
  # would otherwise be silently recycled or produce NA when subsetting
  if (any(unname(lengths(true.labels)) != unname(lengths(est.prob)))) {
    stop(
      "Each element of true.labels must have the same length as the ",
      "corresponding element of est.prob",
      call. = FALSE
    )
  }

  pos_class <- lapply(true.labels, \(x) {
    levels(x)[pos.class.idi]
  })
  pos_levels <- unlist(pos_class, use.names = FALSE)
  # levels() is NULL for non-factors and indexing past the last level gives NA;
  # either would otherwise lead to all-NA empirical risks
  if (length(pos_levels) != length(true.labels) || anyNA(pos_levels)) {
    stop(
      "true.labels must be factors and pos.class.idi must index one of ",
      "their levels",
      call. = FALSE
    )
  }
  # Ensure same positive class
  stopifnot(length(unique(pos_levels)) == 1)

  # Create windows
  bin_probs <- seq(0, 1, length.out = n.bins + 1)
  breaks <- switch(bin.method,
    equidistant = lapply(est.prob, \(x) bin_probs),
    quantile = lapply(est.prob, \(x) {
      quantile(x, probs = bin_probs, names = FALSE)
    })
  )

  # Logical index of the cases falling in window j: [lower, upper) for all
  # windows but the last, which is [lower, upper] so the maximum is kept
  in_bin <- function(p, brk, j) {
    above_lower <- p >= brk[j]
    if (j == n.bins) {
      above_lower & p <= brk[j + 1]
    } else {
      above_lower & p < brk[j + 1]
    }
  }

  # For each input, a 2 x n.bins matrix: row 1 is the mean estimated
  # probability, row 2 the proportion of positive cases in each window
  bin_stats <- lapply(seq_along(est.prob), \(i) {
    p <- est.prob[[i]]
    is_pos <- true.labels[[i]] == pos_class[[i]]
    vapply(seq_len(n.bins), \(j) {
      idl <- in_bin(p, breaks[[i]], j)
      c(mean(p[idl]), sum(is_pos[idl]) / sum(idl))
    }, numeric(2))
  })
  mean_bin_prob <- lapply(bin_stats, \(s) s[1, ])
  window_empirical_risk <- lapply(bin_stats, \(s) s[2, ])

  # TODO: per-window confidence intervals (e.g. via prop.test() with a
  # `conf_level` argument) are not implemented yet

  # Plot
  dplot3_xy(
    mean_bin_prob, window_empirical_risk,
    xlab = xlab,
    ylab = ylab,
    axes.square = TRUE, diagonal = TRUE,
    xlim = c(0, 1), ylim = c(0, 1),
    mode = mode, ...
  )
} # rtemis::dplot3_calibration
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.