R/aaa-metric_set.R

Defines functions call_remove_static_arguments eval_safely validate_estimate_static_linear_pred validate_function_class validate_inputs_are_functions validate_not_empty make_quantile_metric_function make_survival_metric_function make_numeric_metric_function make_prob_class_metric_function get_quo_label get_metric_fn_direction class1 map_chr as_tibble.metric_set format.metric_set print.metric_set metric_set

Documented in metric_set

# Metric set -------------------------------------------------------------------

#' Combine metric functions
#'
#' `metric_set()` allows you to combine multiple metric functions together
#' into a new function that calculates all of them at once.
#'
#' @param ... The bare names of the functions to be included in the metric set.
#'
#' @details
#' All functions must be either:
#' - Only numeric metrics
#' - Only quantile metrics
#' - A mix of class metrics, class prob metrics, and ordered prob metrics
#' - A mix of dynamic, integrated, static, and linear predictor survival metrics
#'
#' For instance, `rmse()` can be used with `mae()` because they
#' are numeric metrics, but not with `accuracy()` because it is a classification
#' metric. But `accuracy()` can be used with `roc_auc()`.
#'
#' The returned metric function will have a different argument list
#' depending on which kind of metrics were passed in.
#'
#' ```
#' # Numeric / quantile metric set signature:
#' fn(
#'   data,
#'   truth,
#'   estimate,
#'   na_rm = TRUE,
#'   case_weights = NULL,
#'   ...
#' )
#'
#' # Class / prob metric set signature:
#' fn(
#'   data,
#'   truth,
#'   ...,
#'   estimate,
#'   estimator = NULL,
#'   na_rm = TRUE,
#'   event_level = yardstick_event_level(),
#'   case_weights = NULL
#' )
#'
#' # Dynamic / integrated / static survival metric set signature:
#' fn(
#'   data,
#'   truth,
#'   ...,
#'   estimate,
#'   na_rm = TRUE,
#'   case_weights = NULL
#' )
#' ```
#'
#' @section Naming the `estimate` argument:
#'
#' Note that for class/prob metric sets and survival metric sets, the
#' `estimate` argument comes *after* `...` in the function signature. This
#' means you must always pass `estimate` as a named argument; otherwise
#' your column will be captured by `...` and you'll get an error.
#'
#' ```
#' # Correct - estimate is named:
#' my_metrics(two_class_example, truth, estimate = predicted)
#'
#' # Incorrect - estimate is not named and gets captured by `...`:
#' my_metrics(two_class_example, truth, predicted)
#' ```
#'
#' When mixing class and class prob metrics, pass in the hard predictions
#' (the factor column) as the named argument `estimate`, and the soft
#' predictions (the class probability columns) as bare column names or
#' `tidyselect` selectors to `...`.
#'
#' When mixing dynamic, integrated, and static survival metrics, pass in the
#' time predictions as the named argument `estimate`, and the survival
#' predictions as bare column names or `tidyselect` selectors to `...`.
#'
#' When mixing static and linear predictor survival metrics, both kinds read
#' the `estimate` argument, so `estimate` must select exactly two columns by
#' name: `estimate = c(static = col1, linear_pred = col2)`.
#'
#' If `metric_tweak()` has been used to "tweak" one of these arguments, like
#' `estimator` or `event_level`, then the tweaked version wins. This allows you
#' to set the estimator on a metric by metric basis and still use it in a
#' `metric_set()`.
#'
#' @examples
#' library(dplyr)
#'
#' # Multiple regression metrics
#' multi_metric <- metric_set(rmse, rsq, ccc)
#'
#' # The returned function has arguments:
#' # fn(data, truth, estimate, na_rm = TRUE, case_weights = NULL, ...)
#' multi_metric(solubility_test, truth = solubility, estimate = prediction)
#'
#' # Groups are respected on the new metric function
#' class_metrics <- metric_set(accuracy, kap)
#'
#' hpc_cv |>
#'   group_by(Resample) |>
#'   class_metrics(obs, estimate = pred)
#'
#' # ---------------------------------------------------------------------------
#'
#' # If you need to set options for certain metrics, do so by using
#' # `metric_tweak()`. Here's an example where we use the `bias` option to the
#' # `ccc()` metric
#' ccc_with_bias <- metric_tweak("ccc_with_bias", ccc, bias = TRUE)
#'
#' multi_metric2 <- metric_set(rmse, rsq, ccc_with_bias)
#'
#' multi_metric2(solubility_test, truth = solubility, estimate = prediction)
#'
#' # ---------------------------------------------------------------------------
#' # A class probability example:
#'
#' # Note that, when given class or class prob functions,
#' # metric_set() returns a function with signature:
#' # fn(data, truth, ..., estimate)
#' # to be able to mix class and class prob metrics.
#'
#' # You must provide the `estimate` column by explicitly naming
#' # the argument
#'
#' class_and_probs_metrics <- metric_set(roc_auc, pr_auc, accuracy)
#'
#' hpc_cv |>
#'   group_by(Resample) |>
#'   class_and_probs_metrics(obs, VF:L, estimate = pred)
#'
#' @seealso [metrics()], [get_metrics()]
#'
#' @export
metric_set <- function(...) {
  quo_fns <- enquos(...)
  validate_not_empty(quo_fns)

  # Get values and check that they are fns
  fns <- lapply(quo_fns, eval_tidy)
  validate_inputs_are_functions(fns)

  # Name every function. Names given explicitly in `...` are kept; unnamed
  # inputs fall back to the label of their expression (`get_quo_label()`
  # strips any `pkg::` prefix).
  fn_names <- names(fns)
  if (is.null(fn_names)) {
    fn_names <- rep("", length(fns))
  }
  needs_name <- fn_names == ""
  fn_names[needs_name] <- vapply(
    quo_fns[needs_name],
    get_quo_label,
    character(1),
    USE.NAMES = FALSE
  )
  names(fns) <- fn_names

  # All functions must belong to one compatible family of metric classes
  validate_function_class(fns)

  fn_cls <- class1(fns[[1]])

  # signature of the function is different depending on input functions
  if (fn_cls == "numeric_metric") {
    make_numeric_metric_function(fns)
  } else if (
    fn_cls %in% c("prob_metric", "class_metric", "ordered_prob_metric")
  ) {
    make_prob_class_metric_function(fns)
  } else if (fn_cls == "quantile_metric") {
    make_quantile_metric_function(fns)
  } else if (
    fn_cls %in%
      c(
        "dynamic_survival_metric",
        "static_survival_metric",
        "integrated_survival_metric",
        "linear_pred_survival_metric"
      )
  ) {
    make_survival_metric_function(fns)
  } else {
    # should not be reachable
    # nocov start
    cli::cli_abort(
      "{.fn validate_function_class} should have errored on unknown classes.",
      .internal = TRUE
    )
    # nocov end
  }
}

#' @export
print.metric_set <- function(x, ...) {
  # `format()` yields one string per output line; print them newline-separated
  # and hand the object back invisibly so printing composes in pipelines.
  formatted_lines <- format(x)
  cat(formatted_lines, sep = "\n")
  invisible(x)
}

#' @export
format.metric_set <- function(x, ...) {
  metrics <- attributes(x)$metrics

  cli::cli_format_method({
    cli::cli_text("A metric set, consisting of:")

    # Each metric's own `format()` is split on " | " into a type (first piece)
    # and a description (second piece).
    metric_formats <- vapply(metrics, format, character(1))
    metric_formats <- strsplit(metric_formats, " | ", fixed = TRUE)

    metric_names <- names(metric_formats)
    metric_types <- vapply(
      metric_formats,
      `[`,
      character(1),
      1,
      USE.NAMES = FALSE
    )
    metric_descs <- vapply(metric_formats, `[`, character(1), 2)

    # Pad every row so the `|` separators line up. Non-breaking spaces are
    # used instead of ordinary spaces, which cli would collapse
    # (see r-lib/cli#506).
    metric_nchars <- nchar(metric_names) + nchar(metric_types)
    metric_desc_paddings <- strrep("\u00a0", max(metric_nchars) - metric_nchars)

    for (i in seq_along(metrics)) {
      cli::cli_text(
        "- {.fun {metric_names[i]}},
           {tolower(metric_types[i])}{metric_desc_paddings[i]} |
           {metric_descs[i]}"
      )
    }
  })
}

#' @export
as_tibble.metric_set <- function(x, ...) {
  # One row per metric in the set: its name, its primary class, and the
  # optimisation direction stored on the metric function.
  metric_fns <- attributes(x)$metrics
  metric_names <- names(metric_fns)

  # Strip names so the class/direction columns are plain unnamed vectors
  metric_fns <- unname(metric_fns)
  fn_classes <- map_chr(metric_fns, class1)
  fn_directions <- map_chr(metric_fns, get_metric_fn_direction)

  dplyr::tibble(
    metric = metric_names,
    class = fn_classes,
    direction = fn_directions
  )
}

# Type-stable mapping: apply `f` to every element of `x`, requiring each
# result to be a single string. Extra arguments in `...` are passed through
# `vapply()`. Names of `x` are kept on the result.
map_chr <- function(x, f, ...) {
  vapply(X = x, FUN = f, FUN.VALUE = character(1L), ...)
}

# The first (most specific) entry of `x`'s class vector.
class1 <- function(x) {
  all_classes <- class(x)
  all_classes[[1L]]
}

# Return the `"direction"` attribute recorded on a metric function, or `NULL`
# if it has none. Matching is exact: base `attr()` partially matches by
# default, which would wrongly return e.g. a `"directions"` attribute.
get_metric_fn_direction <- function(x) {
  attr(x, "direction", exact = TRUE)
}

# Build a display name for an unnamed metric-set input from its captured
# expression. Namespace prefixes are dropped, so both `pkg::fn` and
# `pkg:::fn` are labelled `fn`.
get_quo_label <- function(quo) {
  out <- as_label(quo)

  if (length(out) != 1L) {
    # should not be reachable
    # nocov start
    cli::cli_abort(
      "{.code as_label(quo)} resulted in a character vector of length >1.",
      .internal = TRUE
    )
    # nocov end
  }

  is_namespaced <- grepl("::", out, fixed = TRUE)

  if (is_namespaced) {
    # Split on `::` or `:::` and take the second half. Splitting on a fixed
    # `::` would turn `pkg:::fn` into the label `:fn`.
    split <- strsplit(out, ":::?")[[1]]
    out <- split[[2]]
  }

  out
}

# Build the metric-set function for class / class prob / ordered prob metrics.
# Class metrics receive the hard predictions via `estimate`; every other
# metric receives the probability columns selected by `...`. The returned
# function row-binds the results of all metrics (class metrics first).
make_prob_class_metric_function <- function(fns) {
  # The split depends only on `fns`, so it is computed once up front.
  is_class_fn <- vapply(fns, inherits, logical(1), what = "class_metric")
  class_fns <- fns[is_class_fn]
  prob_fns <- fns[!is_class_fn]

  metric_function <- function(
    data,
    truth,
    ...,
    estimate,
    estimator = NULL,
    na_rm = TRUE,
    event_level = yardstick_event_level(),
    case_weights = NULL
  ) {
    # `estimate` missing while `...` is non-empty most likely means the
    # estimate column was passed positionally and was swallowed by `...`.
    has_dots <- length(match.call(expand.dots = FALSE)$...) > 0
    if (length(class_fns) > 0 && missing(estimate) && has_dots) {
      cli::cli_abort(
        c(
          "!" = "{.arg estimate} is required for class metrics but was not
                 provided.",
          "i" = "In a metric set, the {.arg estimate} argument must be named.",
          "i" = "Example: {.code my_metrics(data, truth, estimate = my_column)}"
        ),
        call = current_env()
      )
    }

    results <- list()

    if (length(class_fns) > 0) {
      class_args <- quos(
        data = data,
        truth = !!enquo(truth),
        estimate = !!enquo(estimate),
        estimator = estimator,
        na_rm = na_rm,
        event_level = event_level,
        case_weights = !!enquo(case_weights)
      )
      class_calls <- lapply(class_fns, call2, !!!class_args)
      class_calls <- mapply(call_remove_static_arguments, class_calls, class_fns)
      results <- c(
        results,
        mapply(
          FUN = eval_safely,
          class_calls,
          names(class_calls),
          SIMPLIFY = FALSE,
          USE.NAMES = FALSE
        )
      )
    }

    if (length(prob_fns) > 0) {
      # TODO - If prob metrics can all do micro, we can remove this
      # A "micro" estimator is not forwarded to prob metrics; they get `NULL`.
      prob_estimator <- if (is.null(estimator) || estimator != "micro") {
        estimator
      } else {
        NULL
      }

      prob_args <- quos(
        data = data,
        truth = !!enquo(truth),
        ... = ...,
        estimator = prob_estimator,
        na_rm = na_rm,
        event_level = event_level,
        case_weights = !!enquo(case_weights)
      )
      prob_calls <- lapply(prob_fns, call2, !!!prob_args)
      prob_calls <- mapply(call_remove_static_arguments, prob_calls, prob_fns)
      results <- c(
        results,
        mapply(
          FUN = eval_safely,
          prob_calls,
          names(prob_calls),
          SIMPLIFY = FALSE,
          USE.NAMES = FALSE
        )
      )
    }

    dplyr::bind_rows(results)
  }

  structure(
    metric_function,
    class = c("class_prob_metric_set", "metric_set", class(metric_function)),
    metrics = fns
  )
}

# Build the metric-set function for numeric metrics. Every metric is called
# with the same `data`, `truth`, `estimate`, `na_rm`, `case_weights`, and
# `...`, and the per-metric results are row-bound in the order of `fns`.
make_numeric_metric_function <- function(fns) {
  metric_function <- function(
    data,
    truth,
    estimate,
    na_rm = TRUE,
    case_weights = NULL,
    ...
  ) {
    # Built inside the generated function so it captures the arguments of
    # this particular call.
    shared_args <- quos(
      data = data,
      truth = !!enquo(truth),
      estimate = !!enquo(estimate),
      na_rm = na_rm,
      case_weights = !!enquo(case_weights),
      ... = ...
    )

    # One call per metric, minus any arguments the metric fixes itself
    metric_calls <- lapply(fns, call2, !!!shared_args)
    metric_calls <- mapply(call_remove_static_arguments, metric_calls, fns)

    results <- mapply(
      FUN = eval_safely,
      metric_calls,
      names(metric_calls),
      SIMPLIFY = FALSE,
      USE.NAMES = FALSE
    )

    dplyr::bind_rows(results)
  }

  structure(
    metric_function,
    class = c("numeric_metric_set", "metric_set", class(metric_function)),
    metrics = fns
  )
}

# Build the metric-set function for survival metrics.
# - Static and linear predictor metrics read the `estimate` argument.
# - Dynamic and integrated metrics read the columns selected by `...`.
# When static and linear predictor metrics are mixed, `estimate` must name
# one column of each kind (`c(static = ..., linear_pred = ...)`).
# Results are row-bound with dynamic/integrated metrics first, followed by
# the `estimate`-based metrics.
make_survival_metric_function <- function(fns) {
  metric_function <- function(
    data,
    truth,
    ...,
    estimate,
    pred_time,
    na_rm = TRUE,
    case_weights = NULL
  ) {
    # NOTE(review): `pred_time` is accepted but not referenced anywhere in
    # this function; it is kept so the signature stays unchanged.

    # Classify each metric by which argument carries its predictions
    is_static <- vapply(
      fns,
      inherits,
      logical(1),
      "static_survival_metric"
    )
    is_linear_pred <- vapply(
      fns,
      inherits,
      logical(1),
      "linear_pred_survival_metric"
    )
    is_dynamic_or_integrated <- vapply(
      fns,
      function(fn) {
        inherits(fn, "dynamic_survival_metric") ||
          inherits(fn, "integrated_survival_metric")
      },
      FUN.VALUE = logical(1)
    )

    needs_estimate <- any(is_static) || any(is_linear_pred)
    dots_not_empty <- length(match.call(expand.dots = FALSE)$...) > 0
    if (needs_estimate && missing(estimate) && dots_not_empty) {
      cli::cli_abort(
        c(
          "!" = "{.arg estimate} is required for static or linear predictor
                 survival metrics but was not provided.",
          "i" = "In a metric set, the {.arg estimate} argument must be named.",
          "i" = "Example: {.code my_metrics(data, truth, estimate = my_column)}"
        ),
        call = current_env()
      )
    }

    # Static and linear pred metrics both use the `estimate` argument,
    # so we need to route the columns to the correct metric functions.
    is_set_of_static_and_linear_pred <- any(is_static) && any(is_linear_pred)

    if (is_set_of_static_and_linear_pred) {
      estimate_eval <- tidyselect::eval_select(
        expr = enquo(estimate),
        data = data,
        allow_rename = TRUE,
        allow_empty = FALSE,
        error_call = current_env()
      )

      validate_estimate_static_linear_pred(estimate_eval, call = current_env())

      static_col_name <- names(data)[estimate_eval["static"]]
      linear_pred_col_name <- names(data)[estimate_eval["linear_pred"]]

      args_static <- quos(
        data = data,
        truth = !!enquo(truth),
        estimate = !!sym(static_col_name),
        na_rm = na_rm,
        case_weights = !!enquo(case_weights),
        ... = ...
      )

      args_linear_pred <- quos(
        data = data,
        truth = !!enquo(truth),
        estimate = !!sym(linear_pred_col_name),
        na_rm = na_rm,
        case_weights = !!enquo(case_weights),
        ... = ...
      )

      calls_static <- lapply(fns[is_static], call2, !!!args_static)
      calls_linear_pred <- lapply(
        fns[is_linear_pred],
        call2,
        !!!args_linear_pred
      )

      # Positions in `fns` of the metrics behind `calls_estimate`, in order
      estimate_idx <- c(which(is_static), which(is_linear_pred))
      calls_estimate <- c(calls_static, calls_linear_pred)
    } else {
      args_estimate <- quos(
        data = data,
        truth = !!enquo(truth),
        estimate = !!enquo(estimate),
        na_rm = na_rm,
        case_weights = !!enquo(case_weights),
        ... = ...
      )

      estimate_idx <- which(is_static | is_linear_pred)
      calls_estimate <- lapply(fns[estimate_idx], call2, !!!args_estimate)
    }

    args_dots <- quos(
      data = data,
      truth = !!enquo(truth),
      ... = ...,
      na_rm = na_rm,
      case_weights = !!enquo(case_weights)
    )
    dots_idx <- which(is_dynamic_or_integrated)
    calls_dots <- lapply(fns[dots_idx], call2, !!!args_dots)

    calls <- c(calls_dots, calls_estimate)

    # `calls` lists dynamic/integrated metrics first and the estimate-based
    # ones after, which is generally not the order of `fns`. Reorder `fns`
    # the same way so each call has the static arguments of its *own* metric
    # removed (presumably those fixed via `metric_tweak()`).
    calls_fns <- fns[c(dots_idx, estimate_idx)]
    calls <- mapply(
      call_remove_static_arguments,
      calls,
      calls_fns,
      SIMPLIFY = FALSE
    )

    # Evaluate
    metric_list <- mapply(
      FUN = eval_safely,
      calls, # .x
      names(calls), # .y
      SIMPLIFY = FALSE,
      USE.NAMES = FALSE
    )

    dplyr::bind_rows(metric_list)
  }

  class(metric_function) <- c(
    "survival_metric_set",
    "metric_set",
    class(metric_function)
  )

  attr(metric_function, "metrics") <- fns

  metric_function
}

# Build the metric-set function for quantile metrics. It has the same
# signature as the numeric metric set: every metric gets the same `data`,
# `truth`, `estimate`, `na_rm`, `case_weights`, and `...`, and the results
# are row-bound in the order of `fns`.
make_quantile_metric_function <- function(fns) {
  quantile_set <- function(
    data,
    truth,
    estimate,
    na_rm = TRUE,
    case_weights = NULL,
    ...
  ) {
    # Rebuilt on each call so it reflects this invocation's arguments
    metric_args <- quos(
      data = data,
      truth = !!enquo(truth),
      estimate = !!enquo(estimate),
      na_rm = na_rm,
      case_weights = !!enquo(case_weights),
      ... = ...
    )

    raw_calls <- lapply(fns, call2, !!!metric_args)
    per_metric_calls <- mapply(call_remove_static_arguments, raw_calls, fns)

    per_metric_results <- mapply(
      FUN = eval_safely,
      per_metric_calls,
      names(per_metric_calls),
      SIMPLIFY = FALSE,
      USE.NAMES = FALSE
    )

    dplyr::bind_rows(per_metric_results)
  }

  structure(
    quantile_set,
    class = c("quantile_metric_set", "metric_set", class(quantile_set)),
    metrics = fns
  )
}

# Abort (reporting against `call`) when no metric functions were supplied;
# otherwise return `NULL` invisibly.
validate_not_empty <- function(x, call = caller_env()) {
  if (!is_empty(x)) {
    return(invisible())
  }

  cli::cli_abort(
    "At least 1 function must be supplied to {.code ...}.",
    call = call
  )
}

# Abort (reporting against `call`) if any element of `fns` is not a function,
# listing the offending positions; otherwise return `NULL` invisibly.
validate_inputs_are_functions <- function(fns, call = caller_env()) {
  # Positions of inputs that are not functions; `not_fn` is interpolated
  # into the error message below, so the name matters.
  not_fn <- which(!vapply(fns, is_function, logical(1)))

  if (length(not_fn) == 0L) {
    return(invisible())
  }

  cli::cli_abort(
    "All inputs to {.fn metric_set} must be functions.
      These inputs are not: {not_fn}.",
    call = call
  )
}

# Validate that all metric functions inherit from valid function classes or
# combinations of classes. Valid sets are: a single known metric class;
# any mix of class / prob / ordered prob metrics; or any mix of dynamic /
# static / integrated / linear predictor survival metrics. Returns `fns`
# invisibly when valid; otherwise aborts with a message listing the
# offending function types (as if raised from `metric_set()`).
validate_function_class <- function(fns) {
  fn_cls <- vapply(fns, function(fn) class(fn)[1], character(1))
  fn_cls_unique <- unique(fn_cls)
  n_unique <- length(fn_cls_unique)

  if (n_unique == 0L) {
    return(invisible(fns))
  }
  valid_cls <- c(
    "class_metric",
    "prob_metric",
    "ordered_prob_metric",
    "numeric_metric",
    "dynamic_survival_metric",
    "static_survival_metric",
    "integrated_survival_metric",
    "linear_pred_survival_metric",
    "quantile_metric"
  )

  if (n_unique == 1L) {
    if (fn_cls_unique %in% valid_cls) {
      return(invisible(fns))
    }
  }

  class_prob_cls <- c("class_metric", "prob_metric", "ordered_prob_metric")
  if (
    any(fn_cls_unique %in% class_prob_cls) &&
      all(fn_cls_unique %in% class_prob_cls)
  ) {
    return(invisible(fns))
  }

  surv_cls <- c(
    "dynamic_survival_metric",
    "static_survival_metric",
    "integrated_survival_metric",
    "linear_pred_survival_metric"
  )
  if (any(fn_cls_unique %in% surv_cls) && all(fn_cls_unique %in% surv_cls)) {
    return(invisible(fns))
  }

  # Special case unevaluated groupwise metric factories
  if ("metric_factory" %in% fn_cls) {
    factories <- fn_cls[fn_cls == "metric_factory"]
    cli::cli_abort(
      c(
        "{cli::qty(factories)}The input{?s} {.arg {names(factories)}}
         {?is a/are} {.help [groupwise metric](yardstick::new_groupwise_metric)}
         {?factory/factories} and must be passed a data-column before
         addition to a metric set.",
        "i" = "Did you mean to type e.g. `{names(factories)[1]}(col_name)`?"
      ),
      call = rlang::call2("metric_set")
    )
  }

  # Each element of the list contains the names of the fns
  # that inherit that specific class
  fn_bad_names <- lapply(fn_cls_unique, function(x) {
    names(fns)[fn_cls == x]
  })

  # clean up for nicer printing
  fn_cls_unique <- gsub("_metric", "", fn_cls_unique, fixed = TRUE)
  fn_cls_unique <- gsub("function", "other", fn_cls_unique, fixed = TRUE)

  fn_cls_other <- fn_cls_unique == "other"

  if (any(fn_cls_other)) {
    # Plain (non-metric) functions: show the environment they come from to
    # help the user identify them.
    fn_cls_other_loc <- which(fn_cls_other)
    fn_other_names <- fn_bad_names[[fn_cls_other_loc]]
    fns_other <- fns[fn_other_names]

    env_names_other <- vapply(
      fns_other,
      function(fn) env_name(fn_env(fn)),
      character(1)
    )

    fn_bad_names[[fn_cls_other_loc]] <- paste0(
      fn_other_names,
      " ",
      "<",
      env_names_other,
      ">"
    )
  }

  # Prints as:
  # - fn_type1 (fn_name1, fn_name2)
  # - fn_type2 (fn_name1)
  fn_pastable <- mapply(
    FUN = function(fn_type, fn_names) {
      fn_names <- paste0(fn_names, collapse = ", ")
      paste0("- ", fn_type, " (", fn_names, ")")
    },
    fn_type = fn_cls_unique,
    fn_names = fn_bad_names,
    USE.NAMES = FALSE
  )

  # The listed combinations mirror the accepted cases checked above.
  cli::cli_abort(
    c(
      "x" = "The combination of metric functions must be:",
      "*" = "only numeric metrics.",
      "*" = "only quantile metrics.",
      "*" = "a mix of class metrics and class probability metrics.",
      "*" = "a mix of dynamic, integrated, static, and linear predictor
             survival metrics.",
      "i" = "The following metric function types are being mixed:",
      fn_pastable
    ),
    call = rlang::call2("metric_set")
  )
}

# Check the `estimate` selection used when a survival metric set mixes static
# and linear predictor metrics. It must select exactly two columns whose names
# are `static` and `linear_pred` (in either order). Returns `NULL` invisibly
# on success; otherwise aborts, reporting the error against `call`.
validate_estimate_static_linear_pred <- function(
  estimate_eval,
  call = caller_env()
) {
  n_cols <- length(estimate_eval)
  if (n_cols != 2L) {
    cli::cli_abort(
      "{.arg estimate} must select exactly 2 columns from {.arg data},
      not {length(estimate_eval)}.",
      call = call
    )
  }

  # Both names below are interpolated into the error message
  expected_names <- c("static", "linear_pred")
  estimate_names <- names(estimate_eval)

  names_ok <- setequal(estimate_names, expected_names)
  if (names_ok) {
    return(invisible())
  }

  cli::cli_abort(
    c(
      "When mixing static and linear predictor survival metrics,
             {.arg estimate} must use named selection.",
      "i" = "Use {.code estimate = c(static = col1, linear_pred = col2)}.",
      "i" = "Expected names: {.val {expected_names}}.",
      "x" = "Received names: {.val {estimate_names}}."
    ),
    call = call
  )
}

# Safely evaluate metrics in such a way that we can capture the
# error and inform the user of the metric that failed.
#
# `expr` is evaluated with `eval_tidy()` (using `data` and `env`). Any error
# is re-signalled as "Failed to compute `expr_nm()`", reported as coming from
# `metric_set()`, with the original error attached as `parent`.
#
# `withCallingHandlers()` is used instead of `tryCatch()` so the handler runs
# before the stack unwinds. That keeps the full backtrace into the failing
# metric, as recommended for rlang error chaining.
eval_safely <- function(expr, expr_nm, data = NULL, env = caller_env()) {
  withCallingHandlers(
    eval_tidy(expr, data = data, env = env),
    error = function(cnd) {
      cli::cli_abort(
        "Failed to compute {.fn {expr_nm}}.",
        parent = cnd,
        call = call("metric_set")
      )
    }
  )
}

# Drop from `call` every argument that `fn` declares as static (as reported
# by `get_static_arguments()`), so the value fixed on `fn` is the one used.
# `call` is returned unchanged when there is nothing to drop.
call_remove_static_arguments <- function(call, fn) {
  static_args <- get_static_arguments(fn)
  if (length(static_args) == 0L) {
    return(call)
  }

  to_drop <- intersect(rlang::call_args_names(call), static_args)
  if (length(to_drop) == 0L) {
    return(call)
  }

  # A `zap()` value tells `call_modify()` to remove that argument
  removals <- rlang::rep_named(to_drop, list(rlang::zap()))
  call_modify(call, !!!removals)
}

Try the yardstick package in your browser

Any scripts or data that you put into this service are public.

yardstick documentation built on April 8, 2026, 1:06 a.m.