thresholder: Generate Data to Choose a Probability Threshold

Description

This function uses the resampling results from a train object to generate performance statistics over a set of probability thresholds for two-class problems.

Usage

thresholder(x, threshold, final = TRUE, statistics = "all")

Arguments

x

A train object fitted with savePredictions set to TRUE, "all", or "final" in trainControl. The trainControl argument classProbs must also have been TRUE.

threshold

A numeric vector of candidate probability thresholds in the interval [0, 1]. If the class probability for the first level of the outcome is greater than the threshold, the data point is classified as that level; otherwise it is classified as the second level.
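
As a minimal sketch of this rule in base R (the probabilities, the level names "Class1" and "Class2", and the helper classify are made up for illustration; this is not caret's internal code):

```r
# Hypothetical class probabilities for the first outcome level ("Class1")
prob_class1 <- c(0.92, 0.55, 0.40, 0.71, 0.08)
lev <- c("Class1", "Class2")  # the first level is the one being thresholded

# Hypothetical helper: strictly greater than the threshold -> first level,
# otherwise the second level
classify <- function(prob, threshold, lev) {
  factor(ifelse(prob > threshold, lev[1], lev[2]), levels = lev)
}

pred_50 <- classify(prob_class1, 0.50, lev)
pred_75 <- classify(prob_class1, 0.75, lev)

# Raising the threshold moves borderline points to the second level
table(pred_50)
table(pred_75)
```

Raising the threshold makes the first level harder to predict, which typically trades sensitivity for specificity when the first level is treated as the event of interest.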

final

A logical: should only the final tuning parameters chosen by train be used when savePredictions = 'all'?

statistics

A character vector indicating which statistics to calculate. See details below for possible choices; the default value "all" computes all of these.

Details

The argument statistics designates the statistics to compute for each probability threshold. One or more of the following statistics can be selected:

  • Sensitivity

  • Specificity

  • Pos Pred Value

  • Neg Pred Value

  • Precision

  • Recall

  • F1

  • Prevalence

  • Detection Rate

  • Detection Prevalence

  • Balanced Accuracy

  • Accuracy

  • Kappa

  • J

  • Dist

For a description of these statistics (except the last two), see the documentation of confusionMatrix. The last two statistics are Youden's J statistic and the distance to the best possible cutoff (i.e., perfect sensitivity and specificity).
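
The last two statistics can be sketched in base R as follows. The sensitivity and specificity values are invented for illustration, and the distance is assumed here to be the Euclidean distance to the point where both equal 1; the variable names are not part of caret:

```r
# Hypothetical sensitivity/specificity pairs at three candidate thresholds
sens <- c(0.90, 0.80, 0.60)
spec <- c(0.50, 0.75, 0.95)

# Youden's J statistic: larger is better, 1 is a perfect cutoff
j <- sens + spec - 1

# Assumed Euclidean distance to perfect sensitivity and specificity (1, 1):
# smaller is better, 0 is a perfect cutoff
dist_to_best <- sqrt((1 - sens)^2 + (1 - spec)^2)

data.frame(Sensitivity = sens, Specificity = spec, J = j, Dist = dist_to_best)
```

The two criteria need not agree: in this made-up example the second and third thresholds tie on J, while the second has the smaller distance.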

Value

A data frame with columns for each of the tuning parameters from the model along with an additional column called prob_threshold for the probability threshold. There are also columns for summary statistics averaged over resamples with column names corresponding to the input argument statistics.
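
One way to use the result is to pick the threshold that optimizes a chosen statistic. The sketch below uses a small hand-made data frame as a stand-in for the returned object; the numbers are invented, and a real result would also contain the model's tuning parameter columns:

```r
# Hypothetical stand-in for the data frame returned by thresholder()
resample_stats <- data.frame(
  prob_threshold = c(0.50, 0.60, 0.70, 0.80),
  Sensitivity    = c(0.95, 0.90, 0.80, 0.60),
  Specificity    = c(0.40, 0.60, 0.80, 0.90),
  J              = c(0.35, 0.50, 0.60, 0.50)
)

# Select the row with the largest Youden's J
best <- resample_stats[which.max(resample_stats$J), ]
best$prob_threshold
```

Note that which.max returns the first maximum, so ties are resolved in favor of the lowest threshold listed.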

Examples

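## Not run: 
library(caret)    # twoClassSim, trainControl, train, thresholder
library(ggplot2)  # for the plots below

set.seed(2444)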
dat <- twoClassSim(500, intercept = -10)
table(dat$Class)

ctrl <- trainControl(method = "cv", 
                     classProbs = TRUE,
                     savePredictions = "all",
                     summaryFunction = twoClassSummary)

set.seed(2863)
mod <- train(Class ~ ., data = dat, 
             method = "rda",
             tuneLength = 4,
             metric = "ROC",
             trControl = ctrl)

resample_stats <- thresholder(mod, 
                              threshold = seq(.5, 1, by = 0.05), 
                              final = TRUE)

ggplot(resample_stats, aes(x = prob_threshold, y = J)) + 
  geom_point()
ggplot(resample_stats, aes(x = prob_threshold, y = Dist)) + 
  geom_point()
ggplot(resample_stats, aes(x = prob_threshold, y = Sensitivity)) + 
  geom_point() + 
  geom_point(aes(y = Specificity), col = "red")

## End(Not run)

caret documentation built on March 31, 2023, 9:49 p.m.