# R/loss_functions.R
#
# Defines functions: cv_risk, custom_ROCR_risk, risk, loss_squared_error_multivariate, loss_loglik_multinomial, loss_loglik_binomial, loss_loglik_true_cat, loss_squared_error
#
# Documented in: custom_ROCR_risk, cv_risk, loss_loglik_binomial, loss_loglik_multinomial, loss_loglik_true_cat, loss_squared_error, loss_squared_error_multivariate, risk

# Register names that are referenced without a visible binding (for example
# the data.table columns `id`, `obs`, `pred` and `wts` used inside cv_risk())
# so that R CMD check does not report them as undefined global variables.
utils::globalVariables(c("id", "loss", "obs", "pred", "wts"))

#' Loss Function Definitions
#'
#' Loss functions for use in evaluating learner fits.
#'
#' @param pred A vector of predicted values
#' @param observed A vector of observed values
#'
#' @return A vector of loss values
#'
#' @name loss_functions

# SQUARED ERROR LOSS
#
# Element-wise squared difference between predictions and observations. The
# result carries a "name" attribute of "MSE".
#
#' @rdname loss_functions
#'
#' @export
loss_squared_error <- function(pred, observed) {
  residual <- pred - observed
  sq_err <- residual^2
  # label the losses so downstream summaries can name the metric
  attr(sq_err, "name") <- "MSE"
  sq_err
}

# NEGATIVE LOG-LIKELIHOOD LOSS
#
# Treats pred as p(Y = observed), i.e. the probability already assigned to the
# observed category. The observed argument is therefore never read; it is kept
# so that the signature matches the other loss functions.
#
#' @rdname loss_functions
#'
#' @export
loss_loglik_true_cat <- function(pred, observed) {
  # bound() is applied to the probabilities before taking the log
  log_lik <- log(bound(pred))
  nll <- -log_lik
  attr(nll, "name") <- "NLL"
  nll
}

# NEGATIVE LOG-LIKELIHOOD LOSS FOR BINOMIAL OUTCOMES
#
# For observations equal to 1 the loss is -log(p); for all others it is
# -log(1 - p), where p is the bounded prediction.
#
#' @rdname loss_functions
#'
#' @export
loss_loglik_binomial <- function(pred, observed) {
  is_event <- observed == 1
  log_lik <- ifelse(is_event, log(bound(pred)), log(1 - bound(pred)))
  nll <- -1 * log_lik
  attr(nll, "name") <- "NLL"
  nll
}

# NEGATIVE LOG-LIKELIHOOD LOSS FOR MULTINOMIAL OUTCOMES
#
# Assumes predicted probabilities are "packed" into a single vector; they are
# expanded with unpack_predictions() before use.
#
#' @rdname loss_functions
#'
#' @export
loss_loglik_multinomial <- function(pred, observed) {
  # one (position, observed value) pair per observation, used below to pick a
  # single entry per observation out of the unpacked predictions
  # NOTE(review): presumably unpack_predictions() returns an
  # observations-by-classes matrix and observed holds class indices -- confirm
  row_class <- cbind(seq_along(observed), observed)
  prob_mat <- unpack_predictions(pred)
  true_class_prob <- prob_mat[row_class]
  nll <- -1 * log(bound(true_class_prob))
  attr(nll, "name") <- "NLL"
  nll
}

#' SQUARED ERROR LOSS FOR MULTIVARIATE (LOSS AVERAGED ACROSS OUTCOMES)
#'
#' @note Assumes predictions are "packed" into a single vector.
#'
#' @rdname loss_functions
#'
#' @export
loss_squared_error_multivariate <- function(pred, observed) {
  pred_mat <- unpack_predictions(pred)
  sq_diff <- (pred_mat - observed)^2
  # average the squared errors across columns, giving one loss per row
  per_row <- rowMeans(sq_diff)
  attr(per_row, "name") <- "MSE"
  per_row
}

#' Risk Estimation
#'
#' Estimates a risk for a given set of predictions and loss function, computed
#' as the weighted mean of the losses returned by \code{loss}.
#'
#' @param pred A vector of predicted values.
#' @param observed A vector of observed values.
#' @param loss A loss function. For options, see \link{loss_functions}.
#' @param weights A vector of weights, one per loss value. When \code{NULL}
#'  (the default), every loss value is given a weight of one.
#'
#' @return A single numeric value, the weighted mean of the losses.
#'
#' @importFrom stats weighted.mean
#'
#' @export
risk <- function(pred, observed, loss = loss_squared_error, weights = NULL) {
  # get losses
  losses <- loss(pred, observed)

  # Default to equal weights. Size them by the losses, not by `observed`: a
  # loss may return one value per row of a matrix-valued `observed`, in which
  # case length(observed) does not match and weighted.mean() would stop.
  if (is.null(weights)) {
    weights <- rep(1, length(losses))
  }

  # calculate risk
  stats::weighted.mean(losses, weights)
}

#' FACTORY RISK FUNCTION FOR ROCR PERFORMANCE MEASURES WITH BINARY OUTCOMES
#'
#' Factory function for estimating an ROCR-based risk for a given ROCR measure.
#' The function it produces takes \code{pred} and \code{observed} and returns
#' the value of the requested performance measure, carrying the attributes
#' \code{loss = FALSE}, \code{optimize = "maximize"} and \code{name}.
#'
#' @note This risk does not take into account weights. In order to use this
#' risk, it must first be instantiated with respect to the \pkg{ROCR}
#' performance measure of interest, and then the user-defined function can be
#' used. The produced function returns the performance measure itself; it does
#' not convert it to one minus the performance measure. Predictions that are
#' \code{NA}, \code{NaN} or infinite are dropped, together with the
#' corresponding observed values, before the measure is computed.
#'
#' @name risk_functions
#'
#' @param measure A character indicating which \code{ROCR} performance measure
#'  to use for evaluation. The \code{measure} must be either cutoff-dependent
#'  so a single value can be selected (e.g., "tpr"), or its value is a scalar
#'  (e.g., "aucpr"). For more information, see \code{\link[ROCR]{performance}}.
#' @param cutoff A numeric value specifying the cutoff for choosing a single
#'  performance measure from the returned set. Only used for performance
#'  measures that are cutoff-dependent and default is 0.5. See
#'  \code{\link[ROCR]{performance}} for more detail.
#' @param name An optional character string for user to supply their desired
#'  name for the performance measure, which will be used for naming subsequent
#'  risk-related tables and metrics (e.g., \code{cv_risk} column names). When
#'  \code{name} is not supplied, the \code{measure} will be used for naming.
#' @param ... Optional arguments to specific \code{ROCR} performance
#'  measures. See \code{\link[ROCR]{performance}} for more detail.
#'
#' @return A function of \code{pred} and \code{observed} that returns a single
#'  performance value with the attributes described above.
#'
#' @rdname risk_functions
#'
#' @importFrom ROCR prediction performance
#'
#' @export
custom_ROCR_risk <- function(measure, cutoff = 0.5, name = NULL, ...) {
  # NOTE: arguments to factory-produced function goes undocumented
  function(pred, observed) {

    # remove NA, NaN, Inf values (is.na() is also TRUE for NaN)
    unusable <- is.na(pred) | is.infinite(pred)
    if (any(unusable)) {
      pred <- pred[!unusable]
      observed <- observed[!unusable]
    }

    # get ROCR performance object; extra arguments captured by the factory
    # are forwarded to ROCR::performance()
    args <- list(...)
    args$measure <- measure
    args$prediction.obj <- ROCR::prediction(pred, observed)
    performance_object <- do.call(ROCR::performance, args)

    # get single performance value from the performance object
    performance <- unlist(performance_object@y.values)
    if (length(performance) != 1) {
      if (performance_object@x.name != "Cutoff") {
        stop(
          "Multiple performance values returned, but the measure is not ",
          "cutoff-dependent. Check that the supplied measure is valid."
        )
      } else {
        # select the performance closest to the supplied cutoff
        performance <- performance[
          which.min(abs(unlist(performance_object@x.values) - cutoff))
        ]
      }
    }

    # tag the value so callers can tell it is a performance measure to be
    # maximized (not a per-observation loss) and know how to label it
    attr(performance, "loss") <- FALSE
    attr(performance, "optimize") <- "maximize"
    attr(performance, "name") <- if (is.null(name)) measure else name
    performance
  }
}

#' Cross-validated Risk Estimation
#'
#' Estimates the cross-validated risk for a given learner and evaluation
#' function, which can be either a loss or a risk function.
#'
#' @param learner A trained learner object.
#' @param eval_fun A valid loss or risk function. See
#'  \code{\link{loss_functions}} and \code{\link{risk_functions}}.
#' @param coefs A \code{numeric} vector of coefficients.
#'
#' @return A \code{data.table} with one row per learner and the columns
#'  \code{learner}, \code{coefficients}, \code{risk}, \code{se},
#'  \code{fold_sd}, \code{fold_min_risk} and \code{fold_max_risk}. When the
#'  result of \code{eval_fun} carries a \code{name} attribute, "risk" in the
#'  column names is replaced by that name.
#'
#' @importFrom assertthat assert_that
#' @importFrom data.table data.table ":=" set setnames setorderv
#' @importFrom origami cross_validate validation fold_index
#' @importFrom stats sd
cv_risk <- function(learner, eval_fun = NULL, coefs = NULL) {
  assertthat::assert_that(
    "cv" %in% learner$properties,
    msg = "learner is not cv-aware"
  )
  # eval_fun defaults to NULL but is always called below; fail early with a
  # clear message rather than after the predictions have been computed
  assertthat::assert_that(
    is.function(eval_fun),
    msg = "eval_fun must be a loss or risk function"
  )

  task <- learner$training_task$revere_fold_task("validation")
  preds <- learner$predict_fold(task, "validation")
  if (!data.table::is.data.table(preds)) {
    # single set of predictions: wrap it in a one-column table named after
    # the learner
    preds <- data.table::data.table(preds)
    data.table::setnames(preds, names(preds), learner$name)
  }

  # collect, for every fold, the validation-set indices, outcomes, ids and
  # weights
  get_obsdata <- function(fold, task) {
    list(dt = data.table::data.table(
      fold_index = origami::fold_index(),
      index = origami::validation(),
      obs = origami::validation(task$Y),
      id = origami::validation(task$id),
      wts = origami::validation(task$weights)
    ))
  }
  dt <- origami::cross_validate(get_obsdata, task$folds, task)$dt
  data.table::setorderv(dt, c("index", "fold_index"))
  dt <- cbind(dt, preds)
  # long format: one row per (observation, learner) pair
  dt_long <- data.table::melt(
    dt,
    measure.vars = colnames(preds),
    variable.name = "learner",
    value.name = "pred"
  )

  eval_result <- eval_fun(dt_long[["pred"]], dt_long[["obs"]])

  if (!is.null(attr(eval_result, "loss")) && !attr(eval_result, "loss")) {
    # eval_fun is flagged as not being a loss (attribute loss = FALSE), so it
    # is re-evaluated within groups; weights are not used in this branch
    # try stratifying by id
    eval_by_id <- tryCatch(
      {
        dt_long[, list(
          eval_result = eval_fun(pred, obs)
        ), by = list(learner, id, fold_index)]
      },
      error = function(e) {
        # fall back to evaluating within learner and fold only
        dt_long[, list(
          eval_result = eval_fun(pred, obs)
        ), by = list(learner, fold_index)]
      }
    )
  } else {
    # loss values: weight each one, then average within learner, id and fold
    dt_long <- cbind(dt_long, eval_result = dt_long[["wts"]] * eval_result)

    eval_by_id <- dt_long[, list(
      eval_result = mean(eval_result, na.rm = TRUE)
    ), by = list(learner, id, fold_index)]
  }

  # get learner-level evaluation statistics
  # NOTE(review): risk drops NAs but sd() below does not, so se is NA whenever
  # any eval_result is NA -- confirm this is intended
  eval_stats <- eval_by_id[, list(
    coefficients = NA_real_,
    risk = mean(eval_result, na.rm = TRUE),
    se = (1 / sqrt(.N)) * stats::sd(eval_result)
  ), by = list(learner)]

  # get learner-level evaluation statistics by fold
  eval_fold_stats <- eval_by_id[, list(
    risk = mean(eval_result, na.rm = TRUE)
  ), by = list(learner, fold_index)]
  eval_stats_fold <- eval_fold_stats[, list(
    fold_sd = stats::sd(risk, na.rm = TRUE),
    fold_min_risk = min(risk, na.rm = TRUE),
    fold_max_risk = max(risk, na.rm = TRUE)
  ), by = list(learner)]

  risk_dt <- merge(eval_stats, eval_stats_fold, by = "learner")

  if (!is.null(coefs)) {
    data.table::set(risk_dt, , "coefficients", coefs)
  }

  # relabel the "risk" columns with the metric name, when one is provided
  if (!is.null(attr(eval_result, "name"))) {
    colnames(risk_dt) <- gsub(
      "risk", attr(eval_result, "name"), colnames(risk_dt)
    )
  }
  return(risk_dt)
}
# jeremyrcoyle/sl3 documentation built on Feb. 3, 2022, 9:12 a.m.