thresholder: Generate Data to Choose a Probability Threshold


thresholder    R Documentation


This function uses the resampling results from a train object to generate performance statistics over a set of probability thresholds for two-class problems.


thresholder(x, threshold, final = TRUE, statistics = "all")



x: A train object for which savePredictions was set to TRUE, "all", or "final" in trainControl. The control argument classProbs must also have been TRUE.


threshold: A numeric vector of candidate probability thresholds in [0, 1]. If the class probability corresponding to the first level of the outcome is greater than the threshold, the data point is classified as that level; otherwise it is assigned the second level.


final: A logical: should only the final tuning parameters chosen by train be used when savePredictions = 'all'?


statistics: A character vector indicating which statistics to calculate. See details below for possible choices; the default value "all" computes all of these.
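
The classification rule described for the threshold argument can be sketched in base R. This is only an illustration: the probabilities, the level names, and the helper classify_at are made up and are not part of caret, and the strict "greater than" comparison simply follows the wording above.

```r
# Hypothetical class probabilities for the first outcome level ("Class1")
prob_class1 <- c(0.92, 0.61, 0.50, 0.35, 0.08)
lvls <- c("Class1", "Class2")

# Illustrative helper (not a caret function): assign the first level when
# its probability is strictly greater than the threshold, else the second
classify_at <- function(prob, threshold, levels) {
  factor(ifelse(prob > threshold, levels[1], levels[2]), levels = levels)
}

pred_50 <- classify_at(prob_class1, 0.50, lvls)
pred_70 <- classify_at(prob_class1, 0.70, lvls)

# Raising the threshold moves points from the first level to the second
table(pred_50)
table(pred_70)
```

Raising the threshold trades sensitivity for specificity with respect to the first level, which is the trade-off the statistics from thresholder are meant to quantify.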


The argument statistics designates the statistics to compute for each probability threshold. One or more of the following statistics can be selected:

  • Sensitivity

  • Specificity

  • Pos Pred Value

  • Neg Pred Value

  • Precision

  • Recall

  • F1

  • Prevalence

  • Detection Rate

  • Detection Prevalence

  • Balanced Accuracy

  • Accuracy

  • Kappa

  • J

  • Dist

For a description of these statistics (except the last two), see the documentation of confusionMatrix. The last two statistics are Youden's J statistic and the distance to the best possible cutoff (i.e., perfect sensitivity and specificity).
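
As a rough illustration of the last two statistics, the sketch below computes J and Dist from sensitivity and specificity using their usual definitions: J is sensitivity + specificity - 1, and Dist is taken as the Euclidean distance from the point (sensitivity, specificity) to (1, 1). The input values are invented, and the formulas are assumed, not confirmed, to match what thresholder computes internally.

```r
# Made-up sensitivity and specificity at three candidate thresholds
sens <- c(0.95, 0.80, 0.60)
spec <- c(0.55, 0.80, 0.97)

# Youden's J: 0 for an uninformative cutoff, 1 for a perfect one
J <- sens + spec - 1

# Euclidean distance to the ideal point (sensitivity = 1, specificity = 1);
# smaller is better
Dist <- sqrt((1 - sens)^2 + (1 - spec)^2)

data.frame(sens, spec, J, Dist)
```

Larger J and smaller Dist both favor cutoffs that balance the two error rates, though they need not select the same threshold in general.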


A data frame with columns for each of the tuning parameters from the model along with an additional column called prob_threshold for the probability threshold. There are also columns for summary statistics averaged over resamples with column names corresponding to the input argument statistics.
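
One possible way to use the returned data frame is to pick the threshold that maximizes a chosen statistic. The sketch below works on a hand-built stand-in for the output: the values are invented, and a real result would also carry the model's tuning parameter columns.

```r
# Stand-in for thresholder() output; values are invented for illustration
resample_stats <- data.frame(
  prob_threshold = c(0.50, 0.60, 0.70, 0.80),
  Sensitivity    = c(0.95, 0.90, 0.78, 0.60),
  Specificity    = c(0.55, 0.70, 0.85, 0.93),
  J              = c(0.50, 0.60, 0.63, 0.53)
)

# Row with the largest Youden's J
best <- resample_stats[which.max(resample_stats$J), ]
best$prob_threshold
```

Other criteria, such as the smallest Dist or a minimum acceptable Sensitivity, can be applied to the same data frame in the same way.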


## Not run: 
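library(caret)
library(ggplot2)

# Simulate a two-class data set; method = "rda" below requires the klaR package
dat <- twoClassSim(500, intercept = -10)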

ctrl <- trainControl(method = "cv", 
                     classProbs = TRUE,
                     savePredictions = "all",
                     summaryFunction = twoClassSummary)

mod <- train(Class ~ ., data = dat, 
             method = "rda",
             tuneLength = 4,
             metric = "ROC",
             trControl = ctrl)

resample_stats <- thresholder(mod, 
                              threshold = seq(.5, 1, by = 0.05), 
                              final = TRUE)

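# Youden's J across thresholds
ggplot(resample_stats, aes(x = prob_threshold, y = J)) + 
  geom_point()

# Distance to the ideal cutoff across thresholds
ggplot(resample_stats, aes(x = prob_threshold, y = Dist)) + 
  geom_point()

# Sensitivity (black) and specificity (red) across thresholds
ggplot(resample_stats, aes(x = prob_threshold, y = Sensitivity)) + 
  geom_point() + 
  geom_point(aes(y = Specificity), col = "red")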

## End(Not run)

caret documentation built on March 31, 2023, 9:49 p.m.