R/thresholder.R

Defines functions summ_stats stats recode expand_preds thresholder

Documented in thresholder

#' Generate Data to Choose a Probability Threshold
#' 
#' This function uses the resampling results from a \code{\link{train}}
#'  object to generate performance statistics over a set of probability
#'  thresholds for two-class problems. 
#' 
#' @param x A \code{\link{train}} object where the value of
#'  \code{savePredictions} was either \code{TRUE}, \code{"all"},
#'  or \code{"final"} in \code{\link{trainControl}}. Also, the
#'  control argument \code{classProbs} should have been \code{TRUE}.
#' @param threshold A numeric vector of candidate probability thresholds
#'  between [0,1]. If the class probability corresponding to the first
#'  level of the outcome is greater than the threshold, the data point
#'  is classified as that level. 
#' @param final A logical: should only the final tuning parameters
#'   chosen by \code{\link{train}} be used when 
#'   \code{savePredictions = 'all'}?
#' @param statistics A character vector indicating which statistics to
#'   calculate. See details below for possible choices; the default value
#'   \code{"all"} computes all of these.
#' @return A data frame with columns for each of the tuning parameters
#'  from the model along with an additional column called
#'  \code{prob_threshold} for the probability threshold. There are
#'  also columns for summary statistics averaged over resamples with
#'  column names corresponding to the input argument \code{statistics}. 
#' @details The argument \code{statistics} designates the statistics to compute
#'  for each probability threshold. One or more of the following statistics can
#'  be selected:
#'  \itemize{
#'  \item Sensitivity
#'  \item Specificity
#'  \item Pos Pred Value
#'  \item Neg Pred Value
#'  \item Precision
#'  \item Recall
#'  \item F1
#'  \item Prevalence
#'  \item Detection Rate
#'  \item Detection Prevalence
#'  \item Balanced Accuracy
#'  \item Accuracy
#'  \item Kappa
#'  \item J
#'  \item Dist
#' }
#' For a description of these statistics (except the last two), see the
#' documentation of \code{\link{confusionMatrix}}. The last two statistics
#' are Youden's J statistic and the distance to the best possible cutoff (i.e.
#' perfect sensitivity and specificity).
#' @export
#' @importFrom plyr ddply
#' @examples 
#' \dontrun{
#' set.seed(2444)
#' dat <- twoClassSim(500, intercept = -10)
#' table(dat$Class)
#' 
#' ctrl <- trainControl(method = "cv", 
#'                      classProbs = TRUE,
#'                      savePredictions = "all",
#'                      summaryFunction = twoClassSummary)
#' 
#' set.seed(2863)
#' mod <- train(Class ~ ., data = dat, 
#'              method = "rda",
#'              tuneLength = 4,
#'              metric = "ROC",
#'              trControl = ctrl)
#' 
#' resample_stats <- thresholder(mod, 
#'                               threshold = seq(.5, 1, by = 0.05), 
#'                               final = TRUE)
#' 
#' ggplot(resample_stats, aes(x = prob_threshold, y = J)) + 
#'   geom_point()
#' ggplot(resample_stats, aes(x = prob_threshold, y = Dist)) + 
#'   geom_point()
#' ggplot(resample_stats, aes(x = prob_threshold, y = Sensitivity)) + 
#'   geom_point() + 
#'   geom_point(aes(y = Specificity), col = "red")
#' }
thresholder <- function(x, threshold, final = TRUE, statistics = "all") {
  ## ---- Validate the inputs -------------------------------------------------
  if (!inherits(x, "train"))
    stop("`x` should be an object of class 'train'",
         call. = FALSE)
  ## `isTRUE()` also covers a missing (NULL) `classProbs` entry, which would
  ## otherwise trigger an "argument is of length zero" error in `if`.
  if (!isTRUE(x$control$classProbs))
    stop("`classProbs` must be TRUE in `trainControl`",
         call. = FALSE)
  if (is.null(threshold))
    stop("Please supply probability threshold values.",
         call. = FALSE)
  ## Non-numeric or missing thresholds cannot lie on [0,1]; reject them here
  ## rather than letting `if (NA)` fail with an unrelated message.
  if (!is.numeric(threshold) || anyNA(threshold) ||
      any(threshold > 1 | threshold < 0))
    stop("`threshold` should be on [0,1]", call. = FALSE)

  ## `savePredictions` may be a logical or one of the strings
  ## 'all' / 'final' / 'none'; held-out predictions are required here.
  if (is.logical(x$control$savePredictions)) {
    if (!x$control$savePredictions)
      stop("`savePredictions` should be TRUE, 'all', or 'final'",
           call. = FALSE)
  } else {
    if (x$control$savePredictions == "none")
      stop("`savePredictions` should be TRUE, 'all', or 'final'",
           call. = FALSE)
  }
  if (length(levels(x$pred$obs)) > 2)
    stop("For two class problems only", call. = FALSE)

  stat_names <- c("Sensitivity", "Specificity", "Pos Pred Value",
                  "Neg Pred Value", "Precision", "Recall", "F1", "Prevalence",
                  "Detection Rate", "Detection Prevalence", "Balanced Accuracy",
                  "Accuracy", "Kappa", "J", "Dist")
  ## Every requested statistic must be recognised (not merely one of them),
  ## and 'all' may only be used on its own.
  if (length(statistics) == 0 ||
      !all(statistics %in% c("all", stat_names)) ||
      ("all" %in% statistics && length(statistics) > 1))
    stop("`statistics` should be either 'all', or one or more of '",
         paste0(stat_names, collapse = "', '"), "'.")

  if (length(statistics) == 1 && statistics == "all")
    statistics <- stat_names

  ## Columns dropped before expansion: the stored hard class prediction, the
  ## row index, and the probability columns of every level except the first
  ## (only the first level's probability is compared with the threshold).
  disc <- c("pred", "rowIndex", x$levels[-1])

  ## Optionally keep only the rows matching the final tuning parameters,
  ## then replicate the predictions once per candidate threshold.
  preds <- if (final) merge(x$pred, x$bestTune) else x$pred
  pred_dat <- expand_preds(preds, threshold, disc)

  param <- c("Resample", names(x$bestTune), "prob_threshold")

  ## Based on the threshold, recode the predicted classes
  pred_dat <- ddply(pred_dat, .variables = param, recode)

  ## Compute statistics per resample, tuning parameters and threshold
  pred_stats <- ddply(pred_dat, .variables = param, stats)

  ## Summarize over resamples (drop "Resample" from the grouping variables)
  pred_resamp <- ddply(pred_stats, .variables = param[-1],
                       summ_stats, statistics)
  pred_resamp
}

## Replicate a data frame of predictions once per candidate threshold.
##
## df:   data frame of predictions.
## th:   vector of candidate thresholds; duplicates are removed.
## excl: optional character vector of column names to drop from `df`.
##
## Returns `df` (minus the `excl` columns) stacked once for each unique
## threshold, with a `prob_threshold` column giving the threshold that
## applies to each copy.
expand_preds <- function(df, th, excl = NULL) {
  th <- unique(th)
  nth <- length(th)
  ndf <- nrow(df)
  if (!is.null(excl))
    df <- df[, !(names(df) %in% excl), drop = FALSE]
  ## `seq_len()` is safe for zero-row input (`1:0` would yield c(1, 0)), and
  ## `drop = FALSE` keeps a data frame even when one column remains.
  df <- df[rep(seq_len(ndf), times = nth), , drop = FALSE]
  df$prob_threshold <- rep(th, each = ndf)
  df
}


## Re-derive the predicted class from the class probability and threshold.
##
## dat: data frame with a factor column `obs`, a numeric column named after
##      the first level of `obs`, and a `prob_threshold` column.
##
## Returns `dat` with `pred` replaced by a factor (same levels as `obs`)
## that is the first level when that level's probability is strictly
## greater than the threshold, and the second level otherwise.
recode <- function(dat) {
  classes <- levels(dat$obs)
  above <- dat[, classes[1]] > dat$prob_threshold
  ## 2 - TRUE = 1 (first level), 2 - FALSE = 2 (second level); NA stays NA.
  dat$pred <- factor(classes[2 - above], levels = classes)
  dat
}

## Compute performance statistics for one set of recoded predictions.
##
## dat: data frame with factor columns `pred` and `obs`; the first level of
##      `obs` is passed to confusionMatrix() as the positive class.
##
## Returns a named numeric vector holding the `byClass` statistics, the
## overall Accuracy and Kappa, plus two derived values:
##   J    = Sensitivity + Specificity - 1 (Youden's J)
##   Dist = Euclidean distance from (Sensitivity, Specificity) to (1, 1)
stats <- function(dat) {
  event <- levels(dat$obs)[1]
  cm <- caret::confusionMatrix(dat$pred, dat$obs, positive = event)
  out <- c(cm$byClass, cm$overall[c("Accuracy", "Kappa")])

  ## Strip the names so the derived entries carry only their own labels.
  sens <- unname(out["Sensitivity"])
  spec <- unname(out["Specificity"])

  c(out,
    J = sens + spec - 1,
    Dist = sqrt((sens - 1) ^ 2 + (spec - 1) ^ 2))
}

## Average the selected statistics over the rows of `x`.
##
## x:    data frame of statistics (one row per resample).
## cols: character vector naming the columns to average.
##
## Returns a named numeric vector of column means, computed with missing
## values removed. A warning is raised when any of the requested columns
## contains missing values.
summ_stats <- function(x, cols) {
  has_na <- colSums(is.na(x)) > 0
  flagged <- intersect(colnames(x)[has_na], cols)
  if (length(flagged) > 0) {
    warning("The following columns have missing values (NA), which have been ",
            "removed: '", paste0(flagged, collapse = "', '"),
            "'.\n")
  }
  colMeans(x[, cols, drop = FALSE], na.rm = TRUE)
}

Try the caret package in your browser

Any scripts or data that you put into this service are public.

caret documentation built on Aug. 9, 2022, 5:11 p.m.