#---------------------------------- Methods ------------------------------------
# TODO for regression; update a lot of help pages to talk more generally about multiple types
#' Applies a calibration to a set of existing predictions
#' @details
#'
#' `cal_apply()` currently supports data frames and `tune_results` objects. It
#' extracts the `truth` and the estimate column names from the calibration
#' object.
#'
#' When a hard class prediction column is available, it is rebuilt from the
#' calibrated probability columns: each row is assigned the class whose
#' calibrated probability is largest, with ties going to the first class level.
#'
#' @param .data The predictions to calibrate: either a data frame, or a
#' `tune_results` object that contains a `.predictions` column (i.e. it was
#' fit with the control argument `save_pred = TRUE`). The calibration object
#' belongs in the second argument, not here.
#' @param object The calibration object (`cal_object`).
#' @param pred_class (Optional, classification only) Column identifier for the
#' hard class predictions (a factor vector). This column will be adjusted based
#' on changes to the calibrated probability columns. It is ignored for
#' regression calibrations. For `tune_results` objects it defaults to
#' `.pred_class`.
#' @param parameters (Optional) An optional tibble of tuning parameter values
#' that can be used to filter the predicted values before processing. Applies
#' only to `tune_results` objects.
#' @param ... Must be empty. Supplying any argument here is an error (checked
#' with `rlang::check_dots_empty()`).
#' @return A data frame of predictions with the calibration applied. For
#' `tune_results` input, the predictions are first gathered with
#' `tune::collect_predictions(summarize = TRUE)`.
#' @seealso
#' \url{https://www.tidymodels.org/learn/models/calibration/},
#' [cal_estimate_beta()], [cal_estimate_isotonic()],
#' [cal_estimate_isotonic_boot()], [cal_estimate_linear()],
#' [cal_estimate_logistic()], [cal_estimate_multinomial()]
#' @examples
#'
#' # ------------------------------------------------------------------------------
#' # classification example
#'
#' w_calibration <- cal_estimate_logistic(segment_logistic, Class)
#'
#' cal_apply(segment_logistic, w_calibration)
#' @export
cal_apply <- function(.data,
                      object,
                      pred_class = NULL,
                      parameters = NULL,
                      ...) {
  # `...` is reserved: fail loudly if anything is passed through it, before
  # dispatching to a method.
  rlang::check_dots_empty()
  UseMethod("cal_apply")
}
#' @export
#' @rdname cal_apply
cal_apply.data.frame <- function(.data,
                                 object,
                                 pred_class = NULL,
                                 parameters = NULL,
                                 ...) {
  # Presumably verifies that the packages the calibration relies on are
  # available (see cal_pkg_check()).
  needed_pkgs <- required_pkgs(object)
  cal_pkg_check(needed_pkgs)

  # `parameters` only applies to tune_results; this helper presumably rejects
  # a non-NULL value for plain data frames.
  stop_null_parameters(parameters)

  # Forward the class-column selector unevaluated so the updater can capture
  # it with enquo().
  adjusted <- cal_adjust_update(
    .data = .data,
    object = object,
    pred_class = {{ pred_class }}
  )
  adjusted
}
#' @export
#' @rdname cal_apply
cal_apply.tune_results <- function(.data,
                                   object,
                                   pred_class = NULL,
                                   parameters = NULL,
                                   ...) {
  cal_pkg_check(required_pkgs(object))

  # Without saved predictions there is nothing to calibrate.
  has_predictions <- ".predictions" %in% colnames(.data)
  if (!has_predictions) {
    cli::cli_abort(
      c(
        "The {.arg .data} object does not contain columns with predictions.",
        "i" = "Refit with the control argument {.code save_pred = TRUE} to save
these columns."
      )
    )
  }

  # Default the hard-class column to `.pred_class` when the caller gave none.
  class_expr <- enquo(pred_class)
  if (rlang::quo_is_null(class_expr)) {
    class_expr <- rlang::sym(".pred_class")
  }

  # Collapse resampled predictions (summarize = TRUE), optionally restricted
  # to the requested tuning parameter values.
  collected <- tune::collect_predictions(
    x = .data,
    summarize = TRUE,
    parameters = parameters,
    ...
  )

  cal_adjust_update(
    object = object,
    .data = collected,
    pred_class = !!class_expr
  )
}
#' @export
#' @rdname cal_apply
cal_apply.cal_object <- function(.data,
                                 object,
                                 pred_class = NULL,
                                 parameters = NULL,
                                 ...) {
  # A calibration object arriving as `.data` is always a usage error. The most
  # common cause is swapped arguments, which gets a targeted hint.
  if (inherits(object, "data.frame")) {
    cli::cli_abort(
      c(
        "{.fn cal_apply} expects the data as the first argument, and the object
as the second argument.",
        "i" = "Please reverse the order of the arguments and try again."
      )
    )
  }
  # Any other `object` must also error. Without this abort the method would
  # fall through and silently return NULL.
  cli::cli_abort(
    c(
      "{.arg .data} must be a data frame or a {.cls tune_results} object.",
      "x" = "A calibration object was supplied instead."
    )
  )
}
#---------------------------------- Adjust -------------------------------------
# Internal S3 generic: dispatches on the class of the calibration `object` to
# the method that rewrites the estimate columns of `.data`.
cal_adjust <- function(object, .data, pred_class) UseMethod("cal_adjust")
#' @export
cal_adjust.cal_estimate_isotonic <- function(object, .data, pred_class) {
  # `pred_class` is not used here. `multi = FALSE` presumably means a single
  # fitted model (contrast the `_boot` method) -- confirm in
  # apply_interval_impl().
  apply_interval_impl(.data = .data, object = object, multi = FALSE)
}
#' @export
cal_adjust.cal_estimate_isotonic_boot <- function(object, .data, pred_class) {
  # Same interval helper as the plain isotonic method, but with `multi = TRUE`
  # (presumably several bootstrap fits -- confirm in apply_interval_impl()).
  # `pred_class` is not used here.
  apply_interval_impl(.data = .data, object = object, multi = TRUE)
}
#' @export
cal_adjust.cal_estimate_beta <- function(object,
                                         .data,
                                         pred_class = NULL,
                                         ...) {
  # Only the calibration object and the data reach the beta helper;
  # `pred_class` and `...` are accepted but unused.
  apply_beta_impl(.data = .data, object = object)
}
#' @export
cal_adjust.cal_multi <- function(object, .data, pred_class) {
  # Embrace the class-column selector so it reaches the helper unevaluated.
  cal_apply_multi(
    .data = .data,
    object = object,
    pred_class = {{ pred_class }}
  )
}
#' @export
cal_adjust.cal_binary <- function(object, .data, pred_class) {
  # Embrace the class-column selector so it reaches the helper unevaluated.
  cal_apply_binary(
    .data = .data,
    object = object,
    pred_class = {{ pred_class }}
  )
}
#' @export
cal_adjust.cal_regression <- function(object, .data, pred_class) {
  # Regression has no hard-class column: the incoming `pred_class` is ignored
  # and NULL is always forwarded.
  cal_apply_regression(.data = .data, object = object, pred_class = NULL)
}
#' @export
cal_adjust.cal_estimate_none <- function(object, .data, pred_class) {
  # The "none" calibration is the identity: the data is returned untouched.
  unchanged <- .data
  unchanged
}
# Applies the calibration via cal_adjust() and, when a hard-class column is
# requested, rebuilds it from the calibrated probability columns.
#
# .data      Predictions to calibrate.
# object     Calibration object; `type`, `levels` are read here.
# pred_class Class-column selector (tidy-eval); ignored for regression.
# parameters, ... Accepted but unused.
#
# Returns the adjusted predictions. A rebuilt class column is dropped and
# re-added, so it ends up as the last column.
cal_adjust_update <- function(.data,
                              object,
                              pred_class = NULL,
                              parameters = NULL,
                              ...) {
  is_regression <- object$type == "regression"
  class_quo <- if (is_regression) quo(NULL) else enquo(pred_class)

  adjusted <- cal_adjust(
    object = object,
    .data = .data,
    pred_class = !!class_quo
  )

  # No class column requested: the adjusted data is the final result.
  if (rlang::quo_is_null(class_quo)) {
    return(adjusted)
  }

  class_col <- as_name(class_quo)
  if (class_col %in% colnames(adjusted)) {
    adjusted[, class_col] <- NULL
  }

  prob_cols <- nm_levels(object$levels)
  level_labels <- names(object$levels)
  # Row-wise argmax over the probability columns; ties go to the first level.
  winner <- max.col(adjusted[, prob_cols], ties.method = "first")
  adjusted[, class_col] <- factor(level_labels[winner], levels = level_labels)
  adjusted
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.