R/estimator-helpers.R

Defines functions is_multiclass.factor is_multiclass.table is_multiclass.default is_multiclass make_dummy finalize_estimator_default.factor finalize_estimator_default.table finalize_estimator_default.Surv finalize_estimator_default.numeric finalize_estimator_default.matrix finalize_estimator_default.default finalize_estimator_default finalize_estimator_internal.pr_auc finalize_estimator_internal.roc_auc finalize_estimator_internal.accuracy finalize_estimator_internal.default finalize_estimator_internal finalize_estimator get_weights

Documented in finalize_estimator finalize_estimator_internal get_weights

#' @section Weight Calculation:
#' `get_weights()` accepts a confusion matrix and an `estimator` of type
#' `"macro"`, `"micro"`, or `"macro_weighted"` and returns the correct weights.
#' It is useful when creating multiclass metrics.
#'
#' * `"macro"` gives every column (truth class) the same weight,
#'   `1 / ncol(data)`.
#' * `"micro"` returns the single weight `1`.
#' * `"macro_weighted"` weights each column by its share of the total count,
#'   `colSums(data) / sum(colSums(data))`.
#'
#' Any other `estimator`, including `NULL`, `NA`, or a vector with more than
#' one element, raises an "unknown estimator" error.
#'
#' @export
#' @rdname developer-helpers
#' @param data A table with truth values as columns and predicted values
#' as rows.
get_weights <- function(data, estimator) {
  # `if ()` needs a single non-missing condition. Without this guard a NULL,
  # NA, or length > 1 `estimator` fails with a cryptic base R error instead of
  # reaching the informative error below.
  is_scalar <- length(estimator) == 1L && !is.na(estimator)

  if (is_scalar && estimator == "macro") {
    n <- ncol(data)
    rep(1 / n, times = n)
  } else if (is_scalar && estimator == "micro") {
    1
  } else if (is_scalar && estimator == "macro_weighted") {
    .col_sums <- colSums(data)
    .col_sums / sum(.col_sums)
  } else {
    cli::cli_abort(
      "{.arg estimator} type {.val {estimator}} is unknown."
    )
  }
}

# ------------------------------------------------------------------------------

#' @section Estimator Selection:
#'
#' `finalize_estimator()` is the engine for auto-selection of `estimator` based
#' on the type of `x`. Generally `x` is the `truth` column. This function
#' is called from the vector method of your metric.
#'
#' `finalize_estimator_internal()` is an S3 generic that you should extend for
#'  your metric if it does not support exactly the following estimator types:
#'  `"binary"`, `"macro"`, `"micro"`, and `"macro_weighted"`.
#'  If your metric does support all of these, the default version of
#'  `finalize_estimator_internal()` will autoselect `estimator` appropriately.
#'  If you need to create a method, it should take the form:
#' `finalize_estimator_internal.metric_name`. Your method for
#' `finalize_estimator_internal()` should do two things:
#'
#' 1) If `estimator` is `NULL`, autoselect the `estimator` based on the
#' type of `x` and return a single character for the `estimator`.
#'
#' 2) If `estimator` is not `NULL`, validate that it is an allowed `estimator`
#' for your metric and return it.
#'
#' If you are using the default for `finalize_estimator_internal()`, the
#' `estimator` is selected using the following heuristics:
#'
#' 1) If `estimator` is not `NULL`, it is validated and returned immediately
#' as no auto-selection is needed.
#'
#' 2) If `x` is a:
#'
#'    * `factor` - Then `"binary"` is returned if it has 2 or fewer levels,
#'      otherwise `"macro"` is returned.
#'
#'    * `numeric` - Then `"standard"` is returned. This covers the numeric
#'      metrics.
#'
#'    * `Surv` - Then `"standard"` is returned. This covers the dynamic
#'      survival metrics.
#'
#'    * `table` - Then `"binary"` is returned if it has 2 or fewer columns,
#'      otherwise `"macro"` is returned. This is useful if you have `table`
#'      methods.
#'
#'    * `matrix` - Then `"macro"` is returned.
#'
#'    * anything else - Then `"binary"` is returned.
#'
#' @rdname developer-helpers
#'
#' @inheritParams rlang::args_error_context
#'
#' @param metric_class A single character of the name of the metric to autoselect
#' the estimator for. This should match the method name created for
#' `finalize_estimator_internal()`.
#'
#' @param x The column used to autoselect the estimator. This is generally
#' the `truth` column, but can also be a table if your metric has table methods.
#'
#' @param estimator Either `NULL` for auto-selection, or a single character
#' for the type of estimator to use.
#'
#' @seealso [metric-summarizers] [check_metric] [yardstick_remove_missing]
#'
#' @export
finalize_estimator <- function(x,
                               estimator = NULL,
                               metric_class = "default",
                               call = caller_env()) {
  # Wrap the metric name in a throwaway object so that S3 dispatch of
  # `finalize_estimator_internal()` selects the metric-specific method.
  metric_dispatcher <- make_dummy(metric_class)
  finalize_estimator_internal(metric_dispatcher, x, estimator, call = call)
}

#' @rdname developer-helpers
#' @param metric_dispatcher A dummy object (an empty list) whose class is the
#' value given to `metric_class`. `finalize_estimator()` builds it and passes
#' it along for you.
#' @export
finalize_estimator_internal <- function(metric_dispatcher,
                                        x,
                                        estimator,
                                        call = caller_env()) {
  # Dispatch is driven by the dummy's class (the metric name), never by `x`.
  UseMethod("finalize_estimator_internal", metric_dispatcher)
}

#' @export
finalize_estimator_internal.default <- function(metric_dispatcher,
                                                x,
                                                estimator,
                                                call = caller_env()) {
  # Metrics without a dedicated method use the type-based heuristics on `x`.
  # The dummy dispatcher has done its job and is not forwarded.
  finalize_estimator_default(x = x, estimator = estimator, call = call)
}

# Accuracy, Kappa, Mean Log Loss, and MCC extend naturally to multiclass data,
# and they give the same result whichever level is treated as the "event".
# For that reason the user cannot choose an estimator for them. The result is
# always "binary" or "multiclass". Several other metrics in this file reuse
# this method by assignment.
#' @export
finalize_estimator_internal.accuracy <- function(metric_dispatcher,
                                                 x,
                                                 estimator,
                                                 call = caller_env()) {
  # `estimator` is deliberately ignored; only the number of classes in `x`
  # decides the answer.
  multiclass <- is_multiclass(x)
  if (!multiclass) {
    return("binary")
  }
  "multiclass"
}

# The metrics below reuse `finalize_estimator_internal.accuracy()` unchanged.
# Any user-supplied `estimator` is ignored, and "multiclass" or "binary" is
# returned according to `is_multiclass(x)`. Each alias is a plain assignment
# so that roxygen2 can attach the `@export` tag to the object.
#' @export
finalize_estimator_internal.kap <- finalize_estimator_internal.accuracy

#' @export
finalize_estimator_internal.mcc <- finalize_estimator_internal.accuracy

#' @export
finalize_estimator_internal.mn_log_loss <- finalize_estimator_internal.accuracy

#' @export
finalize_estimator_internal.brier_class <- finalize_estimator_internal.accuracy


# Classification cost extends naturally to multiclass and produces the same
# result regardless of which level is the "event".
#' @export
finalize_estimator_internal.classification_cost <- finalize_estimator_internal.accuracy

# Curve methods don't show the estimator when printing, but they do dispatch
# on it to decide whether to compute one-vs-all curves.
#' @export
finalize_estimator_internal.gain_curve <- finalize_estimator_internal.accuracy

#' @export
finalize_estimator_internal.lift_curve <- finalize_estimator_internal.accuracy

#' @export
finalize_estimator_internal.roc_curve <- finalize_estimator_internal.accuracy

#' @export
finalize_estimator_internal.pr_curve <- finalize_estimator_internal.accuracy

# Hand-Till is the default multiclass extension here. The original author
# considers it the best choice because, like binary roc_auc, it is immune to
# class imbalance.
#' @export
finalize_estimator_internal.roc_auc <- function(metric_dispatcher,
                                                x,
                                                estimator,
                                                call = caller_env()) {
  # Forward `call` so that a rejected estimator is reported against the
  # user-facing call, as `finalize_estimator_default()` already does.
  validate_estimator(
    estimator = estimator,
    estimator_override = c("binary", "macro", "macro_weighted", "hand_till"),
    call = call
  )

  # A valid, explicitly supplied estimator wins over auto-selection.
  if (!is.null(estimator)) {
    return(estimator)
  }

  if (is_multiclass(x)) {
    "hand_till"
  } else {
    "binary"
  }
}

# PR AUC and Gain Capture don't have micro methods currently, so "micro" is
# not an accepted estimator.
#' @export
finalize_estimator_internal.pr_auc <- function(metric_dispatcher,
                                               x,
                                               estimator,
                                               call = caller_env()) {
  # Forward `call` so that a rejected estimator is reported against the
  # user-facing call, as `finalize_estimator_default()` already does.
  validate_estimator(
    estimator = estimator,
    estimator_override = c("binary", "macro", "macro_weighted"),
    call = call
  )

  # A valid, explicitly supplied estimator wins over auto-selection.
  if (!is.null(estimator)) {
    return(estimator)
  }

  if (is_multiclass(x)) {
    "macro"
  } else {
    "binary"
  }
}

# Gain capture shares pr_auc's rule. Only "binary", "macro" or
# "macro_weighted" are accepted, and "macro" is auto-selected for
# multiclass data.
#' @export
finalize_estimator_internal.gain_capture <- finalize_estimator_internal.pr_auc

# Default ----------------------------------------------------------------------

finalize_estimator_default <- function(x, estimator, call = caller_env()) {
  # Auto-selection (S3 dispatch on the class of `x`) only happens when the
  # caller left `estimator` unset. `UseMethod()` does not return here.
  if (is.null(estimator)) {
    UseMethod("finalize_estimator_default")
  }

  # An explicit estimator is validated and handed back unchanged.
  validate_estimator(estimator, call = call)
  estimator
}

finalize_estimator_default.default <- function(x,
                                               estimator,
                                               call = caller_env()) {
  # Fallback for any `x` type without a dedicated method.
  selected <- "binary"
  selected
}

finalize_estimator_default.matrix <- function(x,
                                              estimator,
                                              call = caller_env()) {
  # Matrix input always gets the macro average.
  selected <- "macro"
  selected
}

# Numeric `x` (used by the numeric metric functions) gets the "standard"
# estimator.
finalize_estimator_default.numeric <- function(x,
                                               estimator,
                                               call = caller_env()) {
  selected <- "standard"
  selected
}

# `Surv` objects (used by the dynamic survival metric functions) get the
# "standard" estimator.
finalize_estimator_default.Surv <- function(x,
                                            estimator,
                                            call = caller_env()) {
  selected <- "standard"
  selected
}

finalize_estimator_default.table <- function(x,
                                             estimator,
                                             call = caller_env()) {
  # Two-class tables are binary; anything `is_multiclass()` flags is averaged
  # with the macro estimator.
  if (!is_multiclass(x)) {
    return("binary")
  }
  "macro"
}

finalize_estimator_default.factor <- function(x,
                                              estimator,
                                              call = caller_env()) {
  # Two-class factors are binary; anything `is_multiclass()` flags is averaged
  # with the macro estimator.
  if (!is_multiclass(x)) {
    return("binary")
  }
  "macro"
}


# Util -------------------------------------------------------------------------

make_dummy <- function(metric_class) {
  # An empty list is enough: the object is only used for its class, which
  # drives S3 dispatch of `finalize_estimator_internal()`.
  dummy <- list()
  class(dummy) <- metric_class
  dummy
}

is_multiclass <- function(x) {
  # S3 generic: methods decide, per input type, whether `x` has more than two
  # classes.
  UseMethod("is_multiclass", x)
}

is_multiclass.default <- function(x) {
  # Unknown input types are deliberately not rejected here. A more helpful
  # error is raised later, so for now they are simply "not multiclass".
  FALSE
}

is_multiclass.table <- function(x) {
  # A table with more than two columns (truth classes) is multiclass.
  # `ncol()` is NULL for a one-dimensional table. `isTRUE()` maps that case to
  # FALSE instead of erroring inside `if ()`, which matches
  # `is_multiclass.default()` and leaves validation to later checks.
  isTRUE(ncol(x) > 2)
}

is_multiclass.factor <- function(x) {
  # Counts the declared levels, not the observed values, so unused levels
  # still make a factor multiclass.
  length(levels(x)) > 2
}

Try the yardstick package in your browser

Any scripts or data that you put into this service are public.

yardstick documentation built on June 22, 2024, 7:07 p.m.