R/individual_plots.R

Defines functions recall_plot accuracy_plot propensity_plot sigmoid check_classifier_input_and_init check_predictions

Documented in accuracy_plot propensity_plot recall_plot sigmoid

# Validate that every predicted probability lies in the closed interval [0, 1].
#
# pred.prob: vector of predicted probabilities.
# Returns NULL invisibly when the input passes; otherwise stops with an error
# describing the offending extreme (or the presence of missing values).
check_predictions <- function(pred.prob) {

  # Nothing to validate. Returning early also avoids the "no non-missing
  # arguments" warnings that max()/min() raise on zero-length input.
  if (length(pred.prob) == 0) {
    return(invisible(NULL))
  }

  # A single NA/NaN would turn the range comparisons below into NA, and `if`
  # would then fail with "missing value where TRUE/FALSE needed". Report the
  # real problem instead.
  if (anyNA(pred.prob)) {
    stop(paste("Pred.prob contains", sum(is.na(pred.prob)),
      "missing (NA/NaN) value(s). Remove or impute them first."))
  }

  # Compute both extremes once instead of re-scanning the vector per message.
  bounds <- range(pred.prob)

  if (bounds[2] > 1) {
    stop(paste("Pred.prob not in [0,1]. Max:", bounds[2],
      ". You can use the sigmoid(x) function in this package to map to [0,1]."))
  }

  if (bounds[1] < 0) {
    stop(paste("Pred.prob not in [0,1]. Min:", bounds[1],
      ". You can use the sigmoid(x) function in this package to map to [0,1]."))
  }

  invisible(NULL)
}

# Validate the inputs shared by the classifier plots.
#
# test.y:    vector of ground-truth labels; must hold exactly two distinct
#            values, one of which is 1 (the positive class).
# pred.prob: vector of predicted probabilities, same length as test.y.
#
# Stops with a descriptive error on the first violated condition; otherwise
# delegates the range check of pred.prob to check_predictions().
check_classifier_input_and_init <- function(test.y, pred.prob) {

  if (length(test.y) != length(pred.prob)) {
    stop(paste("Length of test.y:", length(test.y), "did not match pred.prob:", length(pred.prob)))
  }

  yvals <- unique(test.y)
  # The condition rejects fewer than two classes as well as more than two,
  # so the message must not claim "more than 2".
  if (length(yvals) != 2) {
    stop(paste("test.y must have exactly 2 unique values, but had:", length(yvals)))
  }

  # %in% never yields NA, so an NA label cannot crash this `if` the way
  # `sum(yvals == 1) != 1` would.
  if (!(1 %in% yvals)) {
    stop("This code expects test.y to be numerical, with the positive class indicated by '1'. There was no 1 in test.y!")
  }

  check_predictions(pred.prob)
}

#' @title sigmoid
#' @description Logistic sigmoid, 1 / (1 + exp(-x)), which squashes any real number into the [0,1] interval. Works element-wise on numeric vectors.
#' @param x Numeric value or vector to transform
#' @export
sigmoid <- function(x) {
  # exp(-x) shrinks toward 0 for large positive x and grows for large negative x.
  decay <- exp(-x)
  1.0 / (1.0 + decay)
}

#' @title propensity_plot
#' @description Returns a ggplot2 plot object containing a propensity @@ percentile plot
#' @param test.y List of known labels on the test set
#' @param pred.prob List of probability predictions on the test set
#' @param granularity Default 0.02, probability step between points in plot.
#' @return A ggplot2 plot object.
#' @export
propensity_plot <- function(test.y, pred.prob, granularity=0.02) {
  check_classifier_input_and_init(test.y, pred.prob)

  # Quantile levels over [0, 1] and the prediction value found at each level.
  step_array <- seq(0.0, 1.0, by=granularity)
  thresh_steps <- round(quantile(pred.prob, step_array), digits=4)
  # Positions of the predictions sorted from highest to lowest probability.
  pred.order <- order(pred.prob, decreasing=TRUE)

  propensity_tbl_perc <- data.table(
    part=seq_along(step_array), percentage=100 - 100*step_array,
    threshold=thresh_steps, step_array=step_array)
  # `part` is unique per row, so grouping by it invokes the helper once per
  # quantile step with a scalar index.
  propensity_tbl_perc[, propensity :=
    propensity_at_threshold(test.y, pred.prob, part, pred.order, thresh_steps), by=c("part")]

  ggplot(propensity_tbl_perc, aes(x=percentage, y=100.0*propensity)) +
    geom_line(color=green_str, size=1.5) + classifier_theme + classifier_colours +
    scale_x_continuous(name="Instance decile (non-cumulative %)", breaks=seq(0.0, 100.0, 10.0)) +
    scale_y_continuous(name="Smoothed positive (%)") +
    ggtitle("Positive rate (rolling window)")
}

#' @title accuracy_plot
#' @description Returns a ggplot2 plot object containing an accuracy @@ percentile plot
#' @param test.y List of known labels on the test set
#' @param pred.prob List of probability predictions on the test set
#' @param granularity Default 0.02, probability step between points in plot.
#' @param show_numbers Show values as numbers above the plot line
#' @return A ggplot2 plot object.
#' @export
accuracy_plot <- function(test.y, pred.prob, granularity=0.02, show_numbers=TRUE) {
  check_classifier_input_and_init(test.y, pred.prob)

  # Quantile levels over [0, 1] and the prediction value found at each level.
  step_array <- seq(0.0, 1.0, by=granularity)
  thresh_steps <- round(quantile(pred.prob, step_array), digits=4)
  accuracy_tbl_perc <- data.table(percentage=100 - 100*step_array, threshold=thresh_steps)

  # Point estimate plus a band evaluated at the 0.025 and 0.975 levels
  # (presumably a 95% interval -- confirm against accuracy_at_threshold_p).
  accuracy_tbl_perc[, accuracy := sapply(threshold, function(x) accuracy_at_threshold(x, test.y, pred.prob))]
  accuracy_tbl_perc[, accuracy_lb := sapply(threshold, function(x) accuracy_at_threshold_p(0.025, x, test.y, pred.prob))]
  accuracy_tbl_perc[, accuracy_ub := sapply(threshold, function(x) accuracy_at_threshold_p(0.975, x, test.y, pred.prob))]

  if (show_numbers) {
    deciles <- seq(0, 100, 10)
    # `percentage` comes from floating-point arithmetic (e.g. 100 - 100*0.3 can
    # be 69.99999999999999), so round before matching; otherwise some decile
    # labels silently disappear for certain granularities.
    accuracy_tbl_perc[round(percentage, 6) %in% deciles,
      dec_lbl := paste0(format(100*accuracy, digits=2), "%")]
    numbers <- geom_text(aes(x=percentage, y=102*accuracy, label=dec_lbl),
              hjust=0.3, vjust=-1.0, size=4, color=I(blue_str))
  } else {
    # Adding NULL to a ggplot is a no-op, so the plot expression stays uniform.
    numbers <- NULL
  }

  ggplot(accuracy_tbl_perc, aes(x=percentage, y=100.0*accuracy)) +
    geom_ribbon(aes(ymin=100.0*accuracy_lb, ymax=100.0*accuracy_ub), fill=green_str, alpha=0.2) +
    geom_line(color=green_str, size=1.5) + classifier_theme + classifier_colours +
    scale_x_continuous(name="k% (thresholded to positive class)", breaks=seq(0.0, 100.0, 10.0)) +
    scale_y_continuous(name="Accuracy (%)", limits=c(0,100), breaks=seq(0.0, 100.0, 10.0)) +
    numbers +
    ggtitle("Accuracy @ k")
}

#' @title recall_plot
#' @description Returns a ggplot2 plot object containing a sensitivity @@ percentile plot
#' @param test.y List of known labels on the test set
#' @param pred.prob List of probability predictions on the test set
#' @param granularity Default 0.02, probability step between points in plot.
#' @param show_numbers Show numbers at deciles TRUE/FALSE, default TRUE.
#' @return A ggplot2 plot object.
#' @export
recall_plot <- function(test.y, pred.prob, granularity=0.02, show_numbers=TRUE) {
  check_classifier_input_and_init(test.y, pred.prob)

  # Quantile levels over [0, 1] and the prediction value found at each level.
  step_array <- seq(0.0, 1.0, by=granularity)
  thresh_steps <- round(quantile(pred.prob, step_array), digits=4)
  tbl <- data.table(percentage=100 - 100*step_array, threshold=thresh_steps)

  # Point estimate plus a band evaluated at the 0.025 and 0.975 levels
  # (presumably a 95% interval -- confirm against sensitivity_at_threshold_p).
  tbl[, sensitivity := sapply(threshold, function(x) sensitivity_at_threshold(x, test.y, pred.prob))]
  tbl[, sensitivity_lb := sapply(threshold, function(x) sensitivity_at_threshold_p(0.025, x, test.y, pred.prob))]
  tbl[, sensitivity_ub := sapply(threshold, function(x) sensitivity_at_threshold_p(0.975, x, test.y, pred.prob))]

  if (show_numbers) {
    # The 0% point is intentionally not labelled (deciles start at 10).
    deciles <- seq(10, 100, 10)
    # `percentage` comes from floating-point arithmetic (e.g. 100 - 100*0.3 can
    # be 69.99999999999999), so round before matching; otherwise some decile
    # labels silently disappear for certain granularities.
    tbl[round(percentage, 6) %in% deciles,
      dec_lbl := paste0(format(100*sensitivity, digits=2), "%")]
    numbers <- geom_text(aes(x=percentage, y=100*sensitivity+2*sensitivity, label=dec_lbl),
              hjust=0.3, vjust=3.0, size=4, color=I(blue_str))
  } else {
    # Adding NULL to a ggplot is a no-op, so the plot expression stays uniform.
    numbers <- NULL
  }

  ggplot(tbl, aes(x=percentage, y=100.0*sensitivity)) +
    geom_ribbon(aes(ymin=100.0*sensitivity_lb, ymax=100.0*sensitivity_ub), fill=green_str, alpha=0.2) +
    geom_line(color=green_str, size=1.5) + classifier_theme + classifier_colours +
    scale_x_continuous(name="k% (thresholded to positive class)", breaks=seq(0.0, 100.0, 10.0), limits=c(0,100), expand=c(0, 0.3)) +
    scale_y_continuous(name="Recall (%)", breaks=seq(0.0, 100.0, 10.0), limits=c(0,100), expand=c(0, 0.3)) +
    numbers +
    ggtitle("Recall @ k")
}

# Column names referenced through non-standard evaluation (data.table `j`
# expressions and ggplot2 aes()) must be registered here so that R CMD check
# does not report "no visible binding for global variable".
# accuracy_lb / accuracy_ub are used by accuracy_plot() in the same way as
# sensitivity_lb / sensitivity_ub are by recall_plot().
utils::globalVariables(c(
  "Prediction", "Ground Truth", "accuracy", "threshold",
  "precision", "sensitivity", "percentage", "fpr", "tpr",
  "propensity", "positive_perc", "bucket", "dec_lbl", "part",
  "ymin", "ymax", "sensitivity_lb", "sensitivity_ub",
  "accuracy_lb", "accuracy_ub"))

Try the classifierplots package in your browser

Any scripts or data that you put into this service are public.

classifierplots documentation built on Jan. 13, 2021, 5:23 p.m.