# R/generateCalibration.R

#' @title Generate classifier calibration data.
#'
#' @description
#' A calibrated classifier is one where the predicted probability of a class closely matches the
#' rate at which that class actually occurs, e.g., among data points assigned a predicted probability
#' of 0.8 for class A, approximately 80 percent should belong to class A if the classifier is
#' well calibrated. This is estimated empirically by binning data points with similar predicted
#' probabilities for each class and plotting the observed rate of each class within each bin against
#' the predicted probability bins.
#'
#' @family generate_plot_data
#' @family calibration
#' @aliases CalibrationData
#'
#' @template arg_plotroc_obj
#' @param breaks [\code{character(1)} | \code{numeric}]\cr
#'   If \code{character(1)}, the algorithm to use in generating probability bins.
#'   See \code{\link{hist}} for details.
#'   If \code{numeric}, the cut points for the bins.
#'   Default is \dQuote{Sturges}.
#' @param groups [\code{integer(1)}]\cr
#'   The number of bins to construct.
#'   If specified, \code{breaks} is ignored.
#'   Default is \code{NULL}.
#' @param task.id [\code{character(1)}]\cr
#'   Selected task in the \code{\link{BenchmarkResult}} to generate data for;
#'   ignored for other input types.
#'   Default is the first task.
#'
#' @return [CalibrationData]. A \code{list} containing:
#'   \item{proportion}{[\code{data.frame}] with columns:
#'     \itemize{
#'       \item \code{Learner} Name of learner.
#'       \item \code{bin} Bins calculated according to the \code{breaks} or \code{groups} argument.
#'       \item \code{Class} Class labels (for binary classification only the positive class).
#'       \item \code{Proportion} Proportion of observations from class \code{Class} among all
#'         observations with posterior probabilities of class \code{Class} within the
#'         interval given in \code{bin}.
#'     }}
#'   \item{data}{[\code{data.frame}] with columns:
#'     \itemize{
#'       \item \code{Learner} Name of learner.
#'       \item \code{truth} True class label.
#'       \item \code{Class} Class labels (for binary classification only the positive class).
#'       \item \code{Probability} Predicted posterior probability of \code{Class}.
#'       \item \code{bin} Bin corresponding to \code{Probability}.
#'     }}
#'   \item{task}{[\code{\link{TaskDesc}}]\cr
#'     Task description.}
#'
#' @references Vuk, Miha, and Curk, Tomaz. \dQuote{ROC Curve, Lift Chart, and Calibration Plot.} Metodoloski zvezki. Vol. 3. No. 1 (2006): 89-108.
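#' @examples
#' \dontrun{
#' ## A minimal sketch (not taken from the original docs): generate calibration
#' ## data from a cross-validated prediction on mlr's built-in sonar.task.
#' lrn = makeLearner("classif.rpart", predict.type = "prob")
#' res = resample(lrn, sonar.task, makeResampleDesc("CV", iters = 3))
#' calib = generateCalibrationData(res, groups = 5)
#' head(calib$proportion)
#' plotCalibration(calib)
#' }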
#' @export
generateCalibrationData = function(obj, breaks = "Sturges", groups = NULL, task.id = NULL)
  UseMethod("generateCalibrationData")
#' @export
generateCalibrationData.Prediction = function(obj, breaks = "Sturges", groups = NULL, task.id = NULL) {
  checkPrediction(obj, task.type = "classif", predict.type = "prob")
  generateCalibrationData.list(namedList("prediction", obj), breaks, groups, task.id)
}
#' @export
generateCalibrationData.ResampleResult = function(obj, breaks = "Sturges", groups = NULL, task.id = NULL) {
  obj = getRRPredictions(obj)
  checkPrediction(obj, task.type = "classif", predict.type = "prob")
  generateCalibrationData.Prediction(obj, breaks, groups, task.id)
}
#' @export
generateCalibrationData.BenchmarkResult = function(obj, breaks = "Sturges", groups = NULL, task.id = NULL) {
  tids = getBMRTaskIds(obj)
  if (is.null(task.id))
    task.id = tids[1L]
  else
    assertChoice(task.id, tids)
  obj = getBMRPredictions(obj, task.ids = task.id, as.df = FALSE)[[1L]]

  for (x in obj)
    checkPrediction(x, task.type = "classif", predict.type = "prob")
  generateCalibrationData.list(obj, breaks, groups, task.id)
}
#' @export
generateCalibrationData.list = function(obj, breaks = "Sturges", groups = NULL, task.id = NULL) {
  assertList(obj, c("Prediction", "ResampleResult"), min.len = 1L)
  ## unwrap ResampleResult to Prediction and set default names
  if (inherits(obj[[1L]], "ResampleResult")) {
    if (is.null(names(obj)))
      names(obj) = extractSubList(obj, "learner.id")
    obj = extractSubList(obj, "pred", simplify = FALSE)
  }
  assertList(obj, names = "unique")
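  ## all predictions are assumed to stem from the same task,
  ## so take the task description from the first one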
  td = obj[[1L]]$task.desc

  out = lapply(obj, function(pred) {
    df = data.table("truth" = getPredictionTruth(pred),
                    getPredictionProbabilities(pred, cl = getTaskClassLevels(td)))
    df = melt(df, id.vars = "truth", value.name = "Probability", variable.name = "Class")

    if (is.null(groups)) {
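      ## use hist() only to compute break points (plot = FALSE), then assign
      ## each probability to an ordered bin with cut()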
      break.points = hist(df$Probability, breaks = breaks, plot = FALSE)$breaks
      df$bin = cut(df$Probability, break.points, include.lowest = TRUE, ordered_result = TRUE)
    } else {
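      ## bin probabilities into 'groups' quantile groups of roughly equal size
      ## using Hmisc::cut2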
      requirePackages("Hmisc", default.method = "load", why = "Equal width binning of probabilities.")
      assertInt(groups, lower = 2, upper = length(unique(df$Probability)))
      df$bin = Hmisc::cut2(df$Probability, g = groups, digits = 3)
    }
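    ## per bin: the observed proportion of observations of each class among all
    ## observations whose predicted probability for that class falls in the bin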
    fun = function(x) {
      tab = table(x$Class, x$truth)
      s = rowSums(tab)
      as.list(ifelse(s == 0, 0, diag(tab) / s))
    }
    list(data = df, proportion = df[, fun(.SD), by = "bin"])
  })
  data = rbindlist(lapply(out, function(x) x$data), idcol = "Learner")
  proportion = rbindlist(lapply(out, function(x) x$proportion), idcol = "Learner")
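  ## for binary tasks keep only the positive class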
  if (length(td$class.levels) == 2L) {
    proportion = proportion[, !td$negative, with = FALSE]
    data = data[data$Class != td$negative, ]
  }
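  ## order the bin levels by their upper interval endpoints parsed from the bin labels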
  max.bin = sapply(stri_split(levels(proportion$bin), regex = ",|]|\\)"),
                   function(x) as.numeric(x[length(x)]))
  proportion$bin = ordered(proportion$bin, levels = levels(proportion$bin)[order(max.bin)])
  proportion = melt(proportion, id.vars = c("Learner", "bin"), value.name = "Proportion", variable.name = "Class")
  data$bin = ordered(data$bin, levels = levels(data$bin)[order(max.bin)])
  setDF(data)
  setDF(proportion)

  makeS3Obj("CalibrationData",
            proportion = proportion,
            data = data,
            task = td)
}
#' @title Plot calibration data using ggplot2.
#'
#' @description
#' Plots calibration data from \code{\link{generateCalibrationData}}.
#'
#' @family plot
#' @family calibration
#'
#' @param obj [\code{CalibrationData}]\cr
#'   Result of \code{\link{generateCalibrationData}}.
#' @param smooth [\code{logical(1)}]\cr
#'   Whether to use a loess smoother.
#'   Default is \code{FALSE}.
#' @param reference [\code{logical(1)}]\cr
#'   Whether to plot a reference line showing perfect calibration.
#'   Default is \code{TRUE}.
#' @param rag [\code{logical(1)}]\cr
#'   Whether to include a rag plot: a rug along the top of the panel marking
#'   observations that belong to the plotted class and a rug along the bottom
#'   marking observations that do not (for binary tasks, positive and negative
#'   cases respectively).
#'   Default is \code{TRUE}.
#' @template arg_facet_nrow_ncol
#' @template ret_gg2
#' @export
#' @examples
#' \dontrun{
#' lrns = list(makeLearner("classif.rpart", predict.type = "prob"),
#'             makeLearner("classif.nnet", predict.type = "prob"))
#' fit = lapply(lrns, train, task = iris.task)
#' pred = lapply(fit, predict, task = iris.task)
#' names(pred) = c("rpart", "nnet")
#' out = generateCalibrationData(pred, groups = 3)
#' plotCalibration(out)
#'
#' fit = lapply(lrns, train, task = sonar.task)
#' pred = lapply(fit, predict, task = sonar.task)
#' names(pred) = c("rpart", "lda")
#' out = generateCalibrationData(pred)
#' plotCalibration(out)
#' }
plotCalibration = function(obj, smooth = FALSE, reference = TRUE, rag = TRUE, facet.wrap.nrow = NULL, facet.wrap.ncol = NULL) {
  assertClass(obj, "CalibrationData")
  assertFlag(smooth)
  assertFlag(reference)
  assertFlag(rag)

  obj$proportion$xend = length(levels(obj$proportion$bin))
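  ## index of the last bin; used as the x endpoint of the diagonal reference segment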

  p = ggplot(obj$proportion, aes_string("bin", "Proportion", color = "Class", group = "Class"))
  p = p + scale_x_discrete(drop = FALSE)

  if (smooth)
    p = p + stat_smooth(se = FALSE, span = 2, method = "loess")
  else
    p = p + geom_point() + geom_line()

  if (length(unique(obj$proportion$Learner)) > 1L) {
    p = p + facet_wrap(~ Learner, nrow = facet.wrap.nrow, ncol = facet.wrap.ncol)
  }

  if (reference)
    p = p + geom_segment(aes_string(1, 0, xend = "xend", yend = 1), colour = "black", linetype = "dashed")

  if (rag) {
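    ## rug on top: observations that belong to the plotted class;
    ## rug on bottom: observations that do not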
    top.data = obj$data[obj$data$truth == obj$data$Class, ]
    top.data$x = jitter(as.numeric(top.data$bin))
    p = p + geom_rug(data = top.data, aes_string("x", y = 1), sides = "t", alpha = .25)
    bottom.data = obj$data[obj$data$truth != obj$data$Class, ]
    bottom.data$x = jitter(as.numeric(bottom.data$bin))
    p = p + geom_rug(data = bottom.data, aes_string("x", y = 1), sides = "b", alpha = .25)
  }
  p = p + labs(x = "Probability Bin", y = "Class Proportion")
  p + theme(axis.text.x = element_text(angle = 90, hjust = 1))
}