# Metric set construction and validation utilities
# Metric set -------------------------------------------------------------------
#' Combine metric functions
#'
#' `metric_set()` allows you to combine multiple metric functions together
#' into a new function that calculates all of them at once.
#'
#' @param ... The bare names of the functions to be included in the metric set.
#'
#' @details
#' All functions must be either:
#' - Only numeric metrics
#' - A mix of class metrics or class prob metrics
#' - A mix of dynamic, integrated, and static survival metrics
#'
#' For instance, `rmse()` can be used with `mae()` because they
#' are numeric metrics, but not with `accuracy()` because it is a classification
#' metric. But `accuracy()` can be used with `roc_auc()`.
#'
#' The returned metric function will have a different argument list
#' depending on whether numeric metrics or a mix of class/prob metrics were
#' passed in.
#'
#' ```
#' # Numeric metric set signature:
#' fn(
#' data,
#' truth,
#' estimate,
#' na_rm = TRUE,
#' case_weights = NULL,
#' ...
#' )
#'
#' # Class / prob metric set signature:
#' fn(
#' data,
#' truth,
#' ...,
#' estimate,
#' estimator = NULL,
#' na_rm = TRUE,
#' event_level = yardstick_event_level(),
#' case_weights = NULL
#' )
#'
#' # Dynamic / integrated / static survival metric set signature:
#' fn(
#' data,
#' truth,
#' ...,
#' estimate,
#' na_rm = TRUE,
#' case_weights = NULL
#' )
#' ```
#'
#' @section Naming the `estimate` argument:
#'
#' Note that for class/prob metric sets and survival metric sets, the
#' `estimate` argument comes *after* `...` in the function signature. This
#' means you must always pass `estimate` as a named argument; otherwise
#' your column will be captured by `...` and you'll get an error.
#'
#' ```
#' # Correct - estimate is named:
#' my_metrics(two_class_example, truth, estimate = predicted)
#'
#' # Incorrect - estimate is not named and gets captured by `...`:
#' my_metrics(two_class_example, truth, predicted)
#' ```
#'
#' When mixing class and class prob metrics, pass in the hard predictions
#' (the factor column) as the named argument `estimate`, and the soft
#' predictions (the class probability columns) as bare column names or
#' `tidyselect` selectors to `...`.
#'
#' When mixing dynamic, integrated, and static survival metrics, pass in the
#' time predictions as the named argument `estimate`, and the survival
#' predictions as bare column names or `tidyselect` selectors to `...`.
#'
#' If `metric_tweak()` has been used to "tweak" one of these arguments, like
#' `estimator` or `event_level`, then the tweaked version wins. This allows you
#' to set the estimator on a metric by metric basis and still use it in a
#' `metric_set()`.
#'
#' @examples
#' library(dplyr)
#'
#' # Multiple regression metrics
#' multi_metric <- metric_set(rmse, rsq, ccc)
#'
#' # The returned function has arguments:
#' # fn(data, truth, estimate, na_rm = TRUE, case_weights = NULL, ...)
#' multi_metric(solubility_test, truth = solubility, estimate = prediction)
#'
#' # Groups are respected on the new metric function
#' class_metrics <- metric_set(accuracy, kap)
#'
#' hpc_cv |>
#' group_by(Resample) |>
#' class_metrics(obs, estimate = pred)
#'
#' # ---------------------------------------------------------------------------
#'
#' # If you need to set options for certain metrics, do so by using
#' # `metric_tweak()`. Here's an example where we use the `bias` option to the
#' # `ccc()` metric
#' ccc_with_bias <- metric_tweak("ccc_with_bias", ccc, bias = TRUE)
#'
#' multi_metric2 <- metric_set(rmse, rsq, ccc_with_bias)
#'
#' multi_metric2(solubility_test, truth = solubility, estimate = prediction)
#'
#' # ---------------------------------------------------------------------------
#' # A class probability example:
#'
#' # Note that, when given class or class prob functions,
#' # metric_set() returns a function with signature:
#' # fn(data, truth, ..., estimate)
#' # to be able to mix class and class prob metrics.
#'
#' # You must provide the `estimate` column by explicitly naming
#' # the argument
#'
#' class_and_probs_metrics <- metric_set(roc_auc, pr_auc, accuracy)
#'
#' hpc_cv |>
#' group_by(Resample) |>
#' class_and_probs_metrics(obs, VF:L, estimate = pred)
#'
#' @seealso [metrics()], [get_metrics()]
#'
#' @export
metric_set <- function(...) {
  quo_fns <- enquos(...)
  validate_not_empty(quo_fns)

  # Evaluate the captured expressions and ensure each one is a function
  fns <- lapply(quo_fns, eval_tidy)
  validate_inputs_are_functions(fns)

  # Name each metric: keep a user-supplied name, otherwise derive one
  # from the quosure's expression (stripping any namespace prefix)
  for (i in seq_along(fns)) {
    if (names(fns)[[i]] == "") {
      names(fns)[i] <- get_quo_label(quo_fns[[i]])
    }
  }

  # All functions must belong to a compatible set of metric classes
  validate_function_class(fns)
  fn_cls <- class1(fns[[1]])

  prob_class_classes <- c("prob_metric", "class_metric", "ordered_prob_metric")
  survival_classes <- c(
    "dynamic_survival_metric",
    "static_survival_metric",
    "integrated_survival_metric",
    "linear_pred_survival_metric"
  )

  # The generated function's signature depends on the metric type
  if (identical(fn_cls, "numeric_metric")) {
    make_numeric_metric_function(fns)
  } else if (fn_cls %in% prob_class_classes) {
    make_prob_class_metric_function(fns)
  } else if (identical(fn_cls, "quantile_metric")) {
    make_quantile_metric_function(fns)
  } else if (fn_cls %in% survival_classes) {
    make_survival_metric_function(fns)
  } else {
    # should not be reachable
    # nocov start
    cli::cli_abort(
      "{.fn validate_function_class} should have errored on unknown classes.",
      .internal = TRUE
    )
    # nocov end
  }
}
#' @export
print.metric_set <- function(x, ...) {
  # Delegate the rendering to format() and emit one line per element,
  # returning `x` invisibly so the method pipes cleanly
  formatted <- format(x)
  cat(formatted, sep = "\n")
  invisible(x)
}
#' @export
# Format a metric set as a bulleted, column-aligned listing.
#
# Fix: removed the unused local `names <- names(metrics)`, which also
# shadowed `base::names` inside this function.
format.metric_set <- function(x, ...) {
  metrics <- attributes(x)$metrics

  cli::cli_format_method({
    cli::cli_text("A metric set, consisting of:")

    # Each metric formats as "<type> | <description>"; split the two
    # halves so descriptions can be aligned into a single column
    metric_formats <- vapply(metrics, format, character(1))
    metric_formats <- strsplit(metric_formats, " | ", fixed = TRUE)
    metric_names <- names(metric_formats)
    metric_types <- vapply(
      metric_formats,
      `[`,
      character(1),
      1,
      USE.NAMES = FALSE
    )
    metric_descs <- vapply(metric_formats, `[`, character(1), 2)

    # Pad with non-breaking spaces so cli does not collapse the
    # alignment whitespace (see r-lib/cli#506)
    metric_nchars <- nchar(metric_names) + nchar(metric_types)
    metric_desc_paddings <- max(metric_nchars) - metric_nchars
    metric_desc_paddings <- lapply(metric_desc_paddings, rep, x = "\u00a0")
    metric_desc_paddings <- vapply(
      metric_desc_paddings,
      paste,
      character(1),
      collapse = ""
    )

    for (i in seq_along(metrics)) {
      cli::cli_text(
        "- {.fun {metric_names[i]}},
        {tolower(metric_types[i])}{metric_desc_paddings[i]} |
        {metric_descs[i]}"
      )
    }
  })
}
#' @export
as_tibble.metric_set <- function(x, ...) {
  # One row per metric: its name, most-specific class, and direction
  metrics <- attributes(x)$metrics
  metric_names <- names(metrics)
  metrics <- unname(metrics)

  dplyr::tibble(
    metric = metric_names,
    class = map_chr(metrics, class1),
    direction = map_chr(metrics, get_metric_fn_direction)
  )
}
map_chr <- function(x, f, ...) {
  # Type-stable character mapping: each call to `f` must return a
  # single string (vapply enforces this)
  vapply(x, FUN = f, FUN.VALUE = character(1), ...)
}
class1 <- function(x) {
  # Most specific (first) class of `x`
  class(x)[[1L]]
}
get_metric_fn_direction <- function(x) {
  # Metric functions carry their optimization direction as the
  # "direction" attribute; NULL if the attribute is absent
  direction <- attr(x, "direction")
  direction
}
get_quo_label <- function(quo) {
  # Derive a display name for an unnamed metric from its quosure,
  # e.g. quo(yardstick::rmse) -> "rmse"
  out <- as_label(quo)

  if (length(out) != 1L) {
    # should not be reachable
    # nocov start
    cli::cli_abort(
      "{.code as_label(quo)} resulted in a character vector of length >1.",
      .internal = TRUE
    )
    # nocov end
  }

  # Strip any namespace prefix: keep what follows the `::`
  if (grepl("::", out, fixed = TRUE)) {
    out <- strsplit(out, "::", fixed = TRUE)[[1]][[2]]
  }

  out
}
# Build the metric-set function for a mix of class and/or class
# probability metrics.
#
# The generated function places `estimate` after `...` so that hard class
# predictions (named `estimate`) and class probability columns (`...`)
# can be supplied together. Class metrics and prob metrics are evaluated
# separately and the results row-bound into one tibble.
make_prob_class_metric_function <- function(fns) {
  metric_function <- function(
    data,
    truth,
    ...,
    estimate,
    estimator = NULL,
    na_rm = TRUE,
    event_level = yardstick_event_level(),
    case_weights = NULL
  ) {
    # Find class vs prob metrics
    are_class_metrics <- vapply(
      X = fns,
      FUN = inherits,
      FUN.VALUE = logical(1),
      what = "class_metric"
    )
    class_fns <- fns[are_class_metrics]
    prob_fns <- fns[!are_class_metrics]

    # Columns passed positionally land in `...`; if `estimate` is also
    # missing, the user almost certainly forgot to name it -
    # fail early with a pointed message
    dots_not_empty <- length(match.call(expand.dots = FALSE)$...) > 0

    if (!is_empty(class_fns) && missing(estimate) && dots_not_empty) {
      cli::cli_abort(
        c(
          "!" = "{.arg estimate} is required for class metrics but was not
          provided.",
          "i" = "In a metric set, the {.arg estimate} argument must be named.",
          "i" = "Example: {.code my_metrics(data, truth, estimate = my_column)}"
        ),
        call = current_env()
      )
    }

    metric_list <- list()

    # Evaluate class metrics
    if (!is_empty(class_fns)) {
      # Capture the user's column expressions as quosures so each metric
      # call evaluates them in the right context
      class_args <- quos(
        data = data,
        truth = !!enquo(truth),
        estimate = !!enquo(estimate),
        estimator = estimator,
        na_rm = na_rm,
        event_level = event_level,
        case_weights = !!enquo(case_weights)
      )

      class_calls <- lapply(class_fns, call2, !!!class_args)
      # Drop arguments frozen by `metric_tweak()` so the tweaked values win
      class_calls <- mapply(
        call_remove_static_arguments,
        class_calls,
        class_fns
      )

      # Evaluate each call; on failure the error names the metric
      class_list <- mapply(
        FUN = eval_safely,
        class_calls, # .x
        names(class_calls), # .y
        SIMPLIFY = FALSE,
        USE.NAMES = FALSE
      )

      metric_list <- c(metric_list, class_list)
    }

    # Evaluate prob metrics
    if (!is_empty(prob_fns)) {
      # TODO - If prob metrics can all do micro, we can remove this
      if (!is.null(estimator) && estimator == "micro") {
        prob_estimator <- NULL
      } else {
        prob_estimator <- estimator
      }

      # Prob metrics take the probability columns through `...`
      prob_args <- quos(
        data = data,
        truth = !!enquo(truth),
        ... = ...,
        estimator = prob_estimator,
        na_rm = na_rm,
        event_level = event_level,
        case_weights = !!enquo(case_weights)
      )

      prob_calls <- lapply(prob_fns, call2, !!!prob_args)
      prob_calls <- mapply(call_remove_static_arguments, prob_calls, prob_fns)

      prob_list <- mapply(
        FUN = eval_safely,
        prob_calls, # .x
        names(prob_calls), # .y
        SIMPLIFY = FALSE,
        USE.NAMES = FALSE
      )

      metric_list <- c(metric_list, prob_list)
    }

    # One row per metric (per group, when `data` is grouped)
    dplyr::bind_rows(metric_list)
  }

  class(metric_function) <- c(
    "class_prob_metric_set",
    "metric_set",
    class(metric_function)
  )
  attr(metric_function, "metrics") <- fns

  metric_function
}
# Build the metric-set function for numeric metrics only.
make_numeric_metric_function <- function(fns) {
  metric_function <- function(
    data,
    truth,
    estimate,
    na_rm = TRUE,
    case_weights = NULL,
    ...
  ) {
    # Capture the shared arguments as quosures inside the generated
    # function so the user's column expressions resolve correctly
    shared_args <- quos(
      data = data,
      truth = !!enquo(truth),
      estimate = !!enquo(estimate),
      na_rm = na_rm,
      case_weights = !!enquo(case_weights),
      ... = ...
    )

    # One call per metric; arguments pinned by `metric_tweak()` are
    # removed so the tweaked values win
    metric_calls <- lapply(fns, call2, !!!shared_args)
    metric_calls <- mapply(call_remove_static_arguments, metric_calls, fns)

    # Evaluate each call; on failure the error names the metric
    results <- mapply(
      FUN = eval_safely,
      metric_calls,
      names(metric_calls),
      SIMPLIFY = FALSE,
      USE.NAMES = FALSE
    )

    dplyr::bind_rows(results)
  }

  class(metric_function) <- c(
    "numeric_metric_set",
    "metric_set",
    class(metric_function)
  )
  attr(metric_function, "metrics") <- fns

  metric_function
}
# Build the metric-set function for dynamic, integrated, static, and/or
# linear predictor survival metrics.
#
# Static and linear predictor metrics take their predictions through
# `estimate`; dynamic/integrated metrics take theirs through `...`.
make_survival_metric_function <- function(fns) {
  metric_function <- function(
    data,
    truth,
    ...,
    estimate,
    pred_time,
    na_rm = TRUE,
    case_weights = NULL
  ) {
    # NOTE(review): `pred_time` is accepted but never referenced in this
    # body - presumably retained for backward compatibility; confirm.

    # Construct common argument set for each metric call
    # Doing this dynamically inside the generated function means
    # we capture the correct arguments

    # Classify each metric so columns can be routed to the right argument
    is_static <- vapply(
      fns,
      inherits,
      logical(1),
      "static_survival_metric"
    )
    is_linear_pred <- vapply(
      fns,
      inherits,
      logical(1),
      "linear_pred_survival_metric"
    )
    is_dynamic_or_integrated <- vapply(
      fns,
      function(fn) {
        inherits(fn, "dynamic_survival_metric") ||
          inherits(fn, "integrated_survival_metric")
      },
      FUN.VALUE = logical(1)
    )

    needs_estimate <- any(is_static) || any(is_linear_pred)

    # Columns passed positionally land in `...`; if `estimate` is also
    # missing, the user likely forgot to name it - fail early
    dots_not_empty <- length(match.call(expand.dots = FALSE)$...) > 0

    if (needs_estimate && missing(estimate) && dots_not_empty) {
      cli::cli_abort(
        c(
          "!" = "{.arg estimate} is required for static or linear predictor
          survival metrics but was not provided.",
          "i" = "In a metric set, the {.arg estimate} argument must be named.",
          "i" = "Example: {.code my_metrics(data, truth, estimate = my_column)}"
        ),
        call = current_env()
      )
    }

    # Static and linear pred metrics both use the `estimate` argument
    # so we need route the columns to the correct metric functions
    is_set_of_static_and_linear_pred <- any(is_static) && any(is_linear_pred)

    if (is_set_of_static_and_linear_pred) {
      # `estimate` must be a named 2-column selection:
      # c(static = <col>, linear_pred = <col>)
      estimate_eval <- tidyselect::eval_select(
        expr = enquo(estimate),
        data = data,
        allow_rename = TRUE,
        allow_empty = FALSE,
        error_call = current_env()
      )
      validate_estimate_static_linear_pred(estimate_eval, call = current_env())

      # Resolve each selected position back to a column name in `data`
      static_col_name <- names(data)[estimate_eval["static"]]
      linear_pred_col_name <- names(data)[estimate_eval["linear_pred"]]

      args_static <- quos(
        data = data,
        truth = !!enquo(truth),
        estimate = !!sym(static_col_name),
        na_rm = na_rm,
        case_weights = !!enquo(case_weights),
        ... = ...
      )
      args_linear_pred <- quos(
        data = data,
        truth = !!enquo(truth),
        estimate = !!sym(linear_pred_col_name),
        na_rm = na_rm,
        case_weights = !!enquo(case_weights),
        ... = ...
      )

      calls_static <- lapply(fns[is_static], call2, !!!args_static)
      calls_linear_pred <- lapply(
        fns[is_linear_pred],
        call2,
        !!!args_linear_pred
      )
      calls_estimate <- c(calls_static, calls_linear_pred)
    } else {
      # Only one estimate-based metric type is present, so `estimate`
      # can be forwarded as-is
      args_estimate <- quos(
        data = data,
        truth = !!enquo(truth),
        estimate = !!enquo(estimate),
        na_rm = na_rm,
        case_weights = !!enquo(case_weights),
        ... = ...
      )
      needs_estimate_arg <- is_static | is_linear_pred
      calls_estimate <- lapply(fns[needs_estimate_arg], call2, !!!args_estimate)
    }

    # Dynamic/integrated metrics receive their predictions via `...`
    args_dots <- quos(
      data = data,
      truth = !!enquo(truth),
      ... = ...,
      na_rm = na_rm,
      case_weights = !!enquo(case_weights)
    )
    calls_dots <- lapply(fns[is_dynamic_or_integrated], call2, !!!args_dots)

    calls <- c(calls_dots, calls_estimate)
    # NOTE(review): `calls` is reordered (dots-based metrics first) while
    # `fns` keeps the user-supplied order - if metric types interleave,
    # this mapply may pair a call with the wrong function when removing
    # static (metric_tweak) arguments; verify against tests.
    calls <- mapply(call_remove_static_arguments, calls, fns)

    # Evaluate
    metric_list <- mapply(
      FUN = eval_safely,
      calls, # .x
      names(calls), # .y
      SIMPLIFY = FALSE,
      USE.NAMES = FALSE
    )

    dplyr::bind_rows(metric_list)
  }

  class(metric_function) <- c(
    "survival_metric_set",
    "metric_set",
    class(metric_function)
  )
  attr(metric_function, "metrics") <- fns

  metric_function
}
# Build the metric-set function for quantile metrics only.
make_quantile_metric_function <- function(fns) {
  metric_function <- function(
    data,
    truth,
    estimate,
    na_rm = TRUE,
    case_weights = NULL,
    ...
  ) {
    # Quosure-capture the shared arguments inside the generated function
    # so the user's column expressions resolve correctly
    shared_args <- quos(
      data = data,
      truth = !!enquo(truth),
      estimate = !!enquo(estimate),
      na_rm = na_rm,
      case_weights = !!enquo(case_weights),
      ... = ...
    )

    # One call per metric; arguments pinned by `metric_tweak()` are
    # removed so the tweaked values win
    metric_calls <- lapply(fns, call2, !!!shared_args)
    metric_calls <- mapply(call_remove_static_arguments, metric_calls, fns)

    # Evaluate each call; on failure the error names the metric
    results <- mapply(
      FUN = eval_safely,
      metric_calls,
      names(metric_calls),
      SIMPLIFY = FALSE,
      USE.NAMES = FALSE
    )

    dplyr::bind_rows(results)
  }

  class(metric_function) <- c(
    "quantile_metric_set",
    "metric_set",
    class(metric_function)
  )
  attr(metric_function, "metrics") <- fns

  metric_function
}
validate_not_empty <- function(x, call = caller_env()) {
  # `metric_set()` requires at least one function in `...`
  if (!is_empty(x)) {
    return(invisible(NULL))
  }
  cli::cli_abort(
    "At least 1 function must be supplied to {.code ...}.",
    call = call
  )
}
validate_inputs_are_functions <- function(fns, call = caller_env()) {
  # Every element passed to metric_set() must evaluate to a function
  is_fun_vec <- vapply(fns, is_function, logical(1))

  if (all(is_fun_vec)) {
    return(invisible(NULL))
  }

  # Report the positions of the offending inputs
  not_fn <- which(!is_fun_vec)
  cli::cli_abort(
    "All inputs to {.fn metric_set} must be functions.
    These inputs are not: {not_fn}.",
    call = call
  )
}
# Validate that all metric functions inherit from valid function classes or
# combinations of classes. Returns `fns` invisibly on success; aborts
# with a descriptive error otherwise.
validate_function_class <- function(fns) {
  # Most specific class of each metric function, e.g. "numeric_metric"
  fn_cls <- vapply(fns, function(fn) class(fn)[1], character(1))
  fn_cls_unique <- unique(fn_cls)
  n_unique <- length(fn_cls_unique)

  # Nothing to validate
  if (n_unique == 0L) {
    return(invisible(fns))
  }

  valid_cls <- c(
    "class_metric",
    "prob_metric",
    "ordered_prob_metric",
    "numeric_metric",
    "dynamic_survival_metric",
    "static_survival_metric",
    "integrated_survival_metric",
    "linear_pred_survival_metric",
    "quantile_metric"
  )

  # A homogeneous set of any known metric class is always allowed
  if (n_unique == 1L) {
    if (fn_cls_unique %in% valid_cls) {
      return(invisible(fns))
    }
  }

  # Class and class probability metrics may be mixed freely
  class_prob_cls <- c("class_metric", "prob_metric", "ordered_prob_metric")
  if (
    any(fn_cls_unique %in% class_prob_cls) &&
      all(fn_cls_unique %in% class_prob_cls)
  ) {
    return(invisible(fns))
  }

  # The survival metric types may be mixed freely
  surv_cls <- c(
    "dynamic_survival_metric",
    "static_survival_metric",
    "integrated_survival_metric",
    "linear_pred_survival_metric"
  )
  if (any(fn_cls_unique %in% surv_cls) && all(fn_cls_unique %in% surv_cls)) {
    return(invisible(fns))
  }

  # Special case unevaluated groupwise metric factories
  if ("metric_factory" %in% fn_cls) {
    factories <- fn_cls[fn_cls == "metric_factory"]
    cli::cli_abort(
      c(
        "{cli::qty(factories)}The input{?s} {.arg {names(factories)}}
        {?is a/are} {.help [groupwise metric](yardstick::new_groupwise_metric)}
        {?factory/factories} and must be passed a data-column before
        addition to a metric set.",
        "i" = "Did you mean to type e.g. `{names(factories)[1]}(col_name)`?"
      ),
      call = rlang::call2("metric_set")
    )
  }

  # Each element of the list contains the names of the fns
  # that inherit that specific class
  fn_bad_names <- lapply(fn_cls_unique, function(x) {
    names(fns)[fn_cls == x]
  })

  # clean up for nicer printing
  fn_cls_unique <- gsub("_metric", "", fn_cls_unique, fixed = TRUE)
  fn_cls_unique <- gsub("function", "other", fn_cls_unique, fixed = TRUE)

  # For plain functions ("other"), also show the environment they came
  # from to help the user identify them
  fn_cls_other <- fn_cls_unique == "other"
  if (any(fn_cls_other)) {
    fn_cls_other_loc <- which(fn_cls_other)
    fn_other_names <- fn_bad_names[[fn_cls_other_loc]]
    fns_other <- fns[fn_other_names]

    env_names_other <- vapply(
      fns_other,
      function(fn) env_name(fn_env(fn)),
      character(1)
    )

    fn_bad_names[[fn_cls_other_loc]] <- paste0(
      fn_other_names,
      " ",
      "<",
      env_names_other,
      ">"
    )
  }

  # Prints as:
  # - fn_type1 (fn_name1, fn_name2)
  # - fn_type2 (fn_name1)
  fn_pastable <- mapply(
    FUN = function(fn_type, fn_names) {
      fn_names <- paste0(fn_names, collapse = ", ")
      paste0("- ", fn_type, " (", fn_names, ")")
    },
    fn_type = fn_cls_unique,
    fn_names = fn_bad_names,
    USE.NAMES = FALSE
  )

  cli::cli_abort(
    c(
      "x" = "The combination of metric functions must be:",
      "*" = "only numeric metrics.",
      "*" = "a mix of class metrics and class probability metrics.",
      "*" = "a mix of dynamic and static survival metrics.",
      "i" = "The following metric function types are being mixed:",
      fn_pastable
    ),
    call = rlang::call2("metric_set")
  )
}
validate_estimate_static_linear_pred <- function(
  estimate_eval,
  call = caller_env()
) {
  # When static and linear predictor metrics are mixed, `estimate` must
  # be a named 2-column selection identifying which column is which
  n_selected <- length(estimate_eval)
  if (n_selected != 2L) {
    cli::cli_abort(
      "{.arg estimate} must select exactly 2 columns from {.arg data},
      not {length(estimate_eval)}.",
      call = call
    )
  }

  expected_names <- c("static", "linear_pred")
  estimate_names <- names(estimate_eval)
  if (!setequal(estimate_names, expected_names)) {
    cli::cli_abort(
      c(
        "When mixing static and linear predictor survival metrics,
        {.arg estimate} must use named selection.",
        "i" = "Use {.code estimate = c(static = col1, linear_pred = col2)}.",
        "i" = "Expected names: {.val {expected_names}}.",
        "x" = "Received names: {.val {estimate_names}}."
      ),
      call = call
    )
  }
}
# Safely evaluate metrics in such a way that we can capture the
# error and inform the user of the metric that failed
eval_safely <- function(expr, expr_nm, data = NULL, env = caller_env()) {
  # On failure, rethrow with the failing metric's name attached and the
  # original condition as the parent for a full error chain
  rethrow <- function(cnd) {
    cli::cli_abort(
      "Failed to compute {.fn {expr_nm}}.",
      parent = cnd,
      call = call("metric_set")
    )
  }
  tryCatch(eval_tidy(expr, data = data, env = env), error = rethrow)
}
call_remove_static_arguments <- function(call, fn) {
  # Arguments "frozen" via metric_tweak() must win over the ones the
  # metric set supplies, so drop the overlapping arguments from `call`
  static <- get_static_arguments(fn)
  if (length(static) == 0L) {
    # No static arguments
    return(call)
  }

  overlapping <- intersect(rlang::call_args_names(call), static)
  if (length(overlapping) == 0L) {
    # `static` arguments don't intersect with `call`
    return(call)
  }

  # zap() removes the named arguments from the call
  zaps <- rlang::rep_named(overlapping, list(rlang::zap()))
  call_modify(call, !!!zaps)
}
# NOTE(review): the lines below look like website-scrape artifacts, not
# package source; commented out so the file parses.
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.