R/validation.R

Defines functions validate_numeric_truth_quantile_estimate validate_case_weights validate_estimator validate_binary_estimator validate_surv_truth_numeric_estimate validate_surv_truth_list_estimate validate_ordered_truth_matrix_estimate validate_factor_truth_matrix_estimate validate_factor_truth_factor_estimate validate_numeric_truth_numeric_estimate

Documented in validate_estimator

# Validate a numeric `truth` / numeric `estimate` pair.
#
# The checks run in a fixed order and the first one that fails aborts through
# `cli::cli_abort()`, with the error reported against `call`:
#   1. `truth` is numeric.
#   2. `estimate` is numeric.
#   3. `estimate` is not a matrix.
#   4. `truth` is not a matrix.
#   5. `truth` and `estimate` have the same length.
#
# Returns `NULL` invisibly when every check passes.
validate_numeric_truth_numeric_estimate <- function(
  truth,
  estimate,
  call = caller_env()
) {
  # Type and shape guards; every branch aborts, so at most one can fire.
  if (!is.numeric(truth)) {
    cli::cli_abort(
      "{.arg truth} should be a numeric vector,
      not {.obj_type_friendly {truth}}.",
      call = call
    )
  } else if (!is.numeric(estimate)) {
    cli::cli_abort(
      "{.arg estimate} should be a numeric vector,
      not {.obj_type_friendly {estimate}}.",
      call = call
    )
  } else if (is.matrix(estimate)) {
    cli::cli_abort(
      "{.arg estimate} should be a numeric vector, not a numeric matrix.",
      call = call
    )
  } else if (is.matrix(truth)) {
    cli::cli_abort(
      "{.arg truth} should be a numeric vector, not a numeric matrix.",
      call = call
    )
  }

  # The names `n_truth` / `n_estimate` are interpolated by the message below.
  n_truth <- length(truth)
  n_estimate <- length(estimate)
  same_length <- n_truth == n_estimate

  if (!same_length) {
    cli::cli_abort(
      "{.arg truth} ({n_truth}) and
      {.arg estimate} ({n_estimate}) must be the same length.",
      call = call
    )
  }

  invisible(NULL)
}

# Validate that `truth` and `estimate` are factors with identical levels and
# equal length.
#
# The checks run in order and the first failure aborts through
# `cli::cli_abort()`, with the error reported against `call`:
#   1. `truth` is a factor.
#   2. `estimate` is a factor.
#   3. `levels(truth)` and `levels(estimate)` are `identical()`, i.e. the same
#      values in the same order.
#   4. `length(truth)` equals `length(estimate)`.
#
# Returns `NULL` invisibly when every check passes.
validate_factor_truth_factor_estimate <- function(
  truth,
  estimate,
  call = caller_env()
) {
  # `.obj_type_friendly` already renders its own article ("a number",
  # "an integer vector"), so the templates must not prepend another "a".
  if (!is.factor(truth)) {
    cli::cli_abort(
      "{.arg truth} should be a factor,
      not {.obj_type_friendly {truth}}.",
      call = call
    )
  }

  if (!is.factor(estimate)) {
    cli::cli_abort(
      "{.arg estimate} should be a factor,
      not {.obj_type_friendly {estimate}}.",
      call = call
    )
  }

  lvls_t <- levels(truth)
  lvls_e <- levels(estimate)

  if (!identical(lvls_t, lvls_e)) {
    cli::cli_abort(
      c(
        "x" = "{.arg truth} and {.arg estimate} levels must be equivalent.",
        "*" = "{.arg truth}: {lvls_t}.",
        "*" = "{.arg estimate}: {lvls_e}."
      ),
      call = call
    )
  }

  n_truth <- length(truth)
  n_estimate <- length(estimate)

  if (n_truth != n_estimate) {
    cli::cli_abort(
      "{.arg truth} ({n_truth}) and
      {.arg estimate} ({n_estimate}) must be the same length.",
      call = call
    )
  }

  invisible(NULL)
}

# Validate a factor `truth` against a numeric `estimate` that is either a
# vector (binary estimator) or a matrix with one column per level.
#
# Arguments:
#   truth     - must be a factor.
#   estimate  - numeric vector when `estimator == "binary"`; otherwise a
#               numeric matrix whose column count equals `nlevels(truth)`
#               (a non-matrix `estimate` counts as a single column).
#   estimator - estimator name; only the value "binary" is special-cased.
#   call      - environment reported as the origin of any error.
#
# The first failing check aborts through `cli::cli_abort()`. Returns `NULL`
# invisibly when every check passes.
validate_factor_truth_matrix_estimate <- function(
  truth,
  estimate,
  estimator,
  call = caller_env()
) {
  # `.obj_type_friendly` already renders its own article, so the template must
  # not prepend another "a".
  if (!is.factor(truth)) {
    cli::cli_abort(
      "{.arg truth} should be a factor,
      not {.obj_type_friendly {truth}}.",
      call = call
    )
  }

  if (estimator == "binary") {
    if (is.matrix(estimate)) {
      cli::cli_abort(
        "You are using a binary metric but have passed multiple columns to
        {.arg ...}.",
        call = call
      )
    }

    if (!is.numeric(estimate)) {
      cli::cli_abort(
        "{.arg estimate} should be a numeric vector,
        not {.obj_type_friendly {estimate}}.",
        call = call
      )
    }

    n_lvls <- nlevels(truth)
    if (n_lvls != 2) {
      cli::cli_abort(
        "{.arg estimator} is binary, only two class {.arg truth} factors are
        allowed. A factor with {n_lvls} levels was provided.",
        call = call
      )
    }
  } else {
    n_lvls <- nlevels(truth)
    n_cols <- if (is.matrix(estimate)) ncol(estimate) else 1L

    if (n_lvls != n_cols) {
      cli::cli_abort(
        "The number of levels in `truth` ({n_lvls})
        must match the number of columns supplied in `...` ({n_cols}).",
        call = call
      )
    }

    if (!is.numeric(as.vector(estimate))) {
      # Report the class of the underlying values. `as.vector()` drops the
      # matrix dimensions so the class is e.g. "character", not "matrix".
      cls <- class(as.vector(estimate))[[1]]
      cli::cli_abort(
        "The columns supplied in {.arg ...} should be numerics,
        not {.cls {cls}}.",
        call = call
      )
    }
  }

  invisible(NULL)
}

# Validate an ordered-factor `truth` against a numeric `estimate` that is
# either a vector (binary estimator) or a matrix with one column per level.
#
# Arguments:
#   truth     - must be an ordered factor (`is.ordered()`).
#   estimate  - numeric vector when `estimator == "binary"`; otherwise a
#               numeric matrix whose column count equals `nlevels(truth)`
#               (a non-matrix `estimate` counts as a single column).
#   estimator - estimator name; only the value "binary" is special-cased.
#   call      - environment reported as the origin of any error.
#
# The first failing check aborts through `cli::cli_abort()`. Returns `NULL`
# invisibly when every check passes.
validate_ordered_truth_matrix_estimate <- function(
  truth,
  estimate,
  estimator,
  call = caller_env()
) {
  # `.obj_type_friendly` already renders its own article, so the template must
  # not prepend another "a".
  if (!is.ordered(truth)) {
    cli::cli_abort(
      "{.arg truth} should be an ordered factor,
      not {.obj_type_friendly {truth}}.",
      call = call
    )
  }

  if (estimator == "binary") {
    if (is.matrix(estimate)) {
      cli::cli_abort(
        "You are using a binary metric but have passed multiple columns to
        {.arg ...}.",
        call = call
      )
    }

    if (!is.numeric(estimate)) {
      cli::cli_abort(
        "{.arg estimate} should be a numeric vector,
        not {.obj_type_friendly {estimate}}.",
        call = call
      )
    }

    n_lvls <- nlevels(truth)
    if (n_lvls != 2) {
      cli::cli_abort(
        "{.arg estimator} is binary, only two class {.arg truth} factors are
        allowed. A factor with {n_lvls} levels was provided.",
        call = call
      )
    }
  } else {
    n_lvls <- nlevels(truth)
    n_cols <- if (is.matrix(estimate)) ncol(estimate) else 1L

    if (n_lvls != n_cols) {
      cli::cli_abort(
        "The number of levels in `truth` ({n_lvls})
        must match the number of columns supplied in `...` ({n_cols}).",
        call = call
      )
    }

    if (!is.numeric(as.vector(estimate))) {
      # Report the class of the underlying values. `as.vector()` drops the
      # matrix dimensions so the class is e.g. "character", not "matrix".
      cls <- class(as.vector(estimate))[[1]]
      cli::cli_abort(
        "The columns supplied in {.arg ...} should be numerics,
        not {.cls {cls}}.",
        call = call
      )
    }
  }

  invisible(NULL)
}

# Validate a `Surv` `truth` against a list-of-data-frames `estimate`.
#
# The checks run in order and the first failure aborts through
# `cli::cli_abort()`, with the error reported against `call`:
#   1. `truth` inherits from "Surv".
#   2. `estimate` is a list.
#   3. The first element of `estimate` is a data frame.
#   4. That data frame has the columns `.eval_time`, `.pred_survival` and
#      `.weight_censored`.
#   5. `nrow(truth)` equals `length(estimate)`.
#   6. `.eval_time` of the first element has no missing, negative, infinite
#      or duplicated values.
#
# Only the first element of `estimate` is inspected for structure and for
# `.eval_time` values; the remaining elements are not examined here.
#
# Returns `NULL` invisibly when every check passes.
validate_surv_truth_list_estimate <- function(
  truth,
  estimate,
  call = caller_env()
) {
  # `.obj_type_friendly` already renders its own article, so the templates
  # must not prepend another "a".
  if (!inherits(truth, "Surv")) {
    cli::cli_abort(
      "{.arg truth} should be a Surv object,
      not {.obj_type_friendly {truth}}.",
      call = call
    )
  }

  if (!is.list(estimate)) {
    cli::cli_abort(
      "{.arg estimate} should be a list,
      not {.obj_type_friendly {estimate}}.",
      call = call
    )
  }

  if (!is.data.frame(estimate[[1]])) {
    cli::cli_abort(
      "All elements of {.arg estimate} should be data.frames.",
      call = call
    )
  }

  valid_names <- c(".eval_time", ".pred_survival", ".weight_censored")
  has_names <- all(valid_names %in% names(estimate[[1]]))
  if (!has_names) {
    cli::cli_abort(
      "All data.frames of {.arg estimate} should include column names:
      {.field (.eval_time)}, {.field (.pred_survival)}, and
      {.field (.weight_censored)}.",
      call = call
    )
  }

  n_truth <- nrow(truth)
  n_estimate <- length(estimate)

  if (n_truth != n_estimate) {
    cli::cli_abort(
      "{.arg truth} ({n_truth}) and
      {.arg estimate} ({n_estimate}) must be the same length.",
      call = call
    )
  }

  eval_time <- estimate[[1]]$.eval_time

  # The missing-value check comes first so the comparisons below never see NA.
  if (anyNA(eval_time)) {
    cli::cli_abort(
      c(
        x = "Missing values in {.field .eval_time} are not allowed."
      ),
      call = call
    )
  }

  if (any(eval_time < 0)) {
    offenders <- unique(eval_time[eval_time < 0])

    cli::cli_abort(
      c(
        x = "Negative values of {.field .eval_time} are not allowed.",
        i = "The following negative values were found: {.val {offenders}}."
      ),
      call = call
    )
  }

  if (any(is.infinite(eval_time))) {
    cli::cli_abort(
      c(
        x = "Infinite values of {.field .eval_time} are not allowed."
      ),
      call = call
    )
  }

  if (anyDuplicated(eval_time) > 0) {
    cli::cli_abort(
      c(
        x = "Duplicate values of {.field .eval_time} are not allowed."
      ),
      call = call
    )
  }

  invisible(NULL)
}

# Validate a `Surv` `truth` against a numeric vector `estimate`.
#
# The checks run in order and the first failure aborts through
# `cli::cli_abort()`, with the error reported against `call`:
#   1. `.is_surv(truth, fail = FALSE)` is true.
#   2. `estimate` is numeric.
#   3. `estimate` is not a matrix.
#   4. `nrow(truth)` equals `length(estimate)`.
#
# Returns `NULL` invisibly when every check passes.
validate_surv_truth_numeric_estimate <- function(
  truth,
  estimate,
  call = caller_env()
) {
  # Type and shape guards; every branch aborts, so at most one can fire.
  if (!.is_surv(truth, fail = FALSE)) {
    cli::cli_abort(
      "{.arg truth} should be a Surv object,
      not {.obj_type_friendly {truth}}.",
      call = call
    )
  } else if (!is.numeric(estimate)) {
    cli::cli_abort(
      "{.arg estimate} should be a numeric vector,
      not {.obj_type_friendly {estimate}}.",
      call = call
    )
  } else if (is.matrix(estimate)) {
    cli::cli_abort(
      "{.arg estimate} should be a numeric vector, not a numeric matrix.",
      call = call
    )
  }

  # The names `n_truth` / `n_estimate` are interpolated by the message below.
  n_truth <- nrow(truth)
  n_estimate <- length(estimate)
  same_length <- n_truth == n_estimate

  if (!same_length) {
    cli::cli_abort(
      "{.arg truth} ({n_truth}) and
      {.arg estimate} ({n_estimate}) must be the same length.",
      call = call
    )
  }

  invisible(NULL)
}

# Check that a binary estimator is paired with a two-level `truth`.
#
# Does nothing (returns `NULL`) when `estimator` is not "binary". Otherwise
# aborts through `cli::cli_abort()`, reported against `call`, unless
# `levels(truth)` has exactly two entries.
validate_binary_estimator <- function(truth, estimator, call = caller_env()) {
  is_binary <- estimator == "binary"
  if (!is_binary) {
    return(NULL)
  }

  # `lvls` is interpolated by the message below.
  lvls <- levels(truth)
  has_two_levels <- length(lvls) == 2

  if (!has_two_levels) {
    cli::cli_abort(
      "{.arg estimator} is binary, only two class {.arg truth} factors are
      allowed. A factor with {length(lvls)} levels was provided.",
      call = call
    )
  }
}

#' @section Estimator Validation:
#' `validate_estimator()` is called from your metric-specific method of
#' `finalize_estimator_internal()` and ensures that a user-provided estimator
#' is of the right format and is one of the allowed values.
#'
#' @param estimator_override A character vector overriding the default allowed
#' estimator list of
#' `c("binary", "macro", "micro", "macro_weighted")`. Set
#' this if your classification estimator does not support all of these methods.
#' @rdname developer-helpers
#' @export
validate_estimator <- function(
  estimator,
  estimator_override = NULL,
  call = caller_env()
) {
  # Nothing to validate when no estimator was supplied.
  if (is.null(estimator)) {
    return(NULL)
  }

  default_estimators <- c("binary", "macro", "micro", "macro_weighted")
  allowed <- if (is.null(estimator_override)) {
    default_estimators
  } else {
    estimator_override
  }

  # Called only for its error: the matched value is not used afterwards.
  rlang::arg_match(estimator, allowed, error_call = call)
  invisible(NULL)
}

# Check that `case_weights`, when supplied, has exactly `size` elements.
#
# A `NULL` `case_weights` is accepted without further checks. A length
# mismatch aborts through `cli::cli_abort()`, reported against `call`.
# Returns `NULL` invisibly otherwise.
validate_case_weights <- function(case_weights, size, call = caller_env()) {
  no_weights <- is.null(case_weights)
  if (no_weights) {
    return(invisible(NULL))
  }

  # `size_case_weights` and `size` are interpolated by the message below.
  size_case_weights <- length(case_weights)
  sizes_agree <- size_case_weights == size

  if (!sizes_agree) {
    cli::cli_abort(
      "{.arg truth} ({size}) and
      {.arg case_weights} ({size_case_weights}) must be the same length.",
      call = call
    )
  }

  invisible(NULL)
}

# Validate a numeric `truth` against a `quantile_pred` `estimate`.
#
# The checks run in order and the first failure aborts through
# `cli::cli_abort()`, with the error reported against `call`:
#   1. `truth` is numeric.
#   2. `estimate` inherits from "quantile_pred".
#   3. `truth` is not a matrix.
#   4. `length(truth)` equals `vctrs::vec_size(estimate)`.
#
# Returns `NULL` invisibly when every check passes.
validate_numeric_truth_quantile_estimate <- function(
  truth,
  estimate,
  call = caller_env()
) {
  # Type and shape guards; every branch aborts, so at most one can fire.
  if (!is.numeric(truth)) {
    cli::cli_abort(
      "{.arg truth} should be a numeric vector,
      not {.obj_type_friendly {truth}}.",
      call = call
    )
  } else if (!inherits(estimate, "quantile_pred")) {
    cli::cli_abort(
      "{.arg estimate} should be a {.cls quantile_pred} object,
      not {.obj_type_friendly {estimate}}.",
      call = call
    )
  } else if (is.matrix(truth)) {
    cli::cli_abort(
      "{.arg truth} should be a numeric vector, not a numeric matrix.",
      call = call
    )
  }

  # The names `n_truth` / `n_estimate` are interpolated by the message below.
  n_truth <- length(truth)
  n_estimate <- vctrs::vec_size(estimate)
  same_length <- n_truth == n_estimate

  if (!same_length) {
    cli::cli_abort(
      "{.arg truth} ({n_truth}) and
      {.arg estimate} ({n_estimate}) must be the same length.",
      call = call
    )
  }

  invisible(NULL)
}

Try the yardstick package in your browser

Any scripts or data that you put into this service are public.

yardstick documentation built on April 8, 2026, 1:06 a.m.