R/generateThreshVsPerf.R

#' @title Generate threshold vs. performance(s) for 2-class classification.
#'
#' @description
#' Generates data on threshold vs. performance(s) for 2-class classification that can be used for plotting.
#'
#' @family generate_plot_data
#' @family thresh_vs_perf
#' @aliases ThreshVsPerfData
#'
#' @template arg_plotroc_obj
#' @template arg_measures
#' @param gridsize (`integer(1)`)\cr
#'   Grid resolution for x-axis (threshold).
#'   Default is 100.
#' @param aggregate (`logical(1)`)\cr
#'   Whether to aggregate [ResamplePrediction]s or to plot the performance
#'   of each iteration separately.
#'   Default is `TRUE`.
#' @param task.id (`character(1)`)\cr
#'   Selected task in [BenchmarkResult] to do plots for, ignored otherwise.
#'   Default is first task.
#' @return ([ThreshVsPerfData]). A named list containing the measured performance
#'   across the threshold grid, the measures, and whether the performance estimates were
#'   aggregated (only applicable for (list of) [ResampleResult]s).
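#' @examples
#' # Illustrative sketch: sweep thresholds for a probabilistic prediction on
#' # the bundled sonar.task; the measures and gridsize here are arbitrary choices.
#' lrn = makeLearner("classif.rpart", predict.type = "prob")
#' mod = train(lrn, sonar.task)
#' pred = predict(mod, sonar.task)
#' tvp = generateThreshVsPerfData(pred, measures = list(fpr, fnr, mmce), gridsize = 25L)
#' head(tvp$data)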
#' @export
generateThreshVsPerfData = function(obj, measures, gridsize = 100L, aggregate = TRUE, task.id = NULL) {
  UseMethod("generateThreshVsPerfData")
}
#' @export
generateThreshVsPerfData.Prediction = function(obj, measures, gridsize = 100L, aggregate = TRUE,
  task.id = NULL) {
  checkPrediction(obj, task.type = "classif", binary = TRUE, predict.type = "prob")
  generateThreshVsPerfData.list(namedList("prediction", obj), measures, gridsize, aggregate, task.id)
}
#' @export
generateThreshVsPerfData.ResampleResult = function(obj, measures, gridsize = 100L, aggregate = TRUE,
  task.id = NULL) {
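  ## extract the ResamplePrediction and dispatch to the Prediction method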
  obj = getRRPredictions(obj)
  checkPrediction(obj, task.type = "classif", binary = TRUE, predict.type = "prob")
  generateThreshVsPerfData.Prediction(obj, measures, gridsize, aggregate)
}
#' @export
generateThreshVsPerfData.BenchmarkResult = function(obj, measures, gridsize = 100L, aggregate = TRUE,
  task.id = NULL) {
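  ## default to the first task in the BenchmarkResult if no task.id is given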
  tids = getBMRTaskIds(obj)
  if (is.null(task.id)) {
    task.id = tids[1L]
  } else {
    assertChoice(task.id, tids)
  }
  obj = getBMRPredictions(obj, task.ids = task.id, as.df = FALSE)[[1L]]

  for (x in obj) {
    checkPrediction(x, task.type = "classif", binary = TRUE, predict.type = "prob")
  }
  generateThreshVsPerfData.list(obj, measures, gridsize, aggregate, task.id)
}
#' @export
generateThreshVsPerfData.list = function(obj, measures, gridsize = 100L, aggregate = TRUE, task.id = NULL) {

  assertList(obj, c("Prediction", "ResampleResult"), min.len = 1L)
  ## unwrap ResampleResult to Prediction and set default names
  if (inherits(obj[[1L]], "ResampleResult")) {
    if (is.null(names(obj))) {
      names(obj) = extractSubList(obj, "learner.id")
    }
    obj = extractSubList(obj, "pred", simplify = FALSE)
  }

  assertList(obj, names = "unique")
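  ## resolve and validate measures against the task description of the first prediction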
  td = extractSubList(obj, "task.desc", simplify = FALSE)[[1L]]
  measures = checkMeasures(measures, td)
  mids = replaceDupeMeasureNames(measures, "id")
  names(measures) = mids
  grid = data.frame(threshold = seq(0, 1, length.out = gridsize))
  resamp = all(vlapply(obj, function(x) inherits(x, "ResamplePrediction")))
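  ## sweep the threshold grid for every prediction; with aggregate = FALSE and
  ## resampled predictions, performance is computed separately for each iteration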
  out = lapply(obj, function(x) {
    do.call("rbind", lapply(grid$threshold, function(th) {
      pp = setThreshold(x, threshold = th)
      if (!aggregate && resamp) {
        iter = seq_len(pp$instance$desc$iters)
        asMatrixRows(lapply(iter, function(i) {
          pp$data = pp$data[pp$data$iter == i, ]
          c(setNames(performance(pp, measures = measures), mids), "iter" = i, "threshold" = th)
        }))
      } else {
        c(setNames(performance(pp, measures = measures), mids), "threshold" = th)
      }
    }))
  })

  if (length(obj) == 1L && inherits(obj[[1L]], "Prediction")) {
    out = out[[1L]]
    colnames(out)[!colnames(out) %in% c("iter", "threshold", "learner")] = mids
  } else {
    out = setDF(rbindlist(lapply(out, as.data.table), fill = TRUE, idcol = "learner", use.names = TRUE))
    colnames(out)[!colnames(out) %in% c("iter", "threshold", "learner")] = mids
  }

  makeS3Obj("ThreshVsPerfData",
    measures = measures,
    data = as.data.frame(out),
    aggregate = aggregate)
}

#' @title Plot threshold vs. performance(s) for 2-class classification using ggplot2.
#'
#' @description
#' Plots threshold vs. performance(s) data that has been generated with [generateThreshVsPerfData].
#'
#' @family plot
#' @family thresh_vs_perf
#'
#' @param obj ([ThreshVsPerfData])\cr
#'   Result of [generateThreshVsPerfData].
#' @param measures ([Measure] | list of [Measure])\cr
#'   Performance measure(s) to plot.
#'   Must be a subset of those used in [generateThreshVsPerfData].
#'   Default is all the measures stored in `obj` generated by
#'   [generateThreshVsPerfData].
#' @param facet (`character(1)`)\cr
#'   Selects \dQuote{measure} or \dQuote{learner} to be the facetting variable.
#'   The variable mapped to `facet` must have more than one unique value, otherwise it will
#'   be ignored. The variable not chosen is mapped to color if it has more than one unique value.
#'   The default is \dQuote{measure}.
#' @param mark.th (`numeric(1)`)\cr
#'   Threshold to mark with a vertical line.
#'   Default is `NA`, which means no threshold is marked.
#' @param pretty.names (`logical(1)`)\cr
#'   Whether to use the [Measure] name instead of the id in the plot.
#'   Default is `TRUE`.
#' @template arg_facet_nrow_ncol
#' @template ret_gg2
#' @export
#' @examples
#' lrn = makeLearner("classif.rpart", predict.type = "prob")
#' mod = train(lrn, sonar.task)
#' pred = predict(mod, sonar.task)
#' pvs = generateThreshVsPerfData(pred, list(acc, setAggregation(acc, train.mean)))
#' plotThreshVsPerf(pvs)
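#' # Optionally mark a candidate threshold with a vertical line
#' # (0.5 is an arbitrary illustrative value)
#' plotThreshVsPerf(pvs, mark.th = 0.5)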
plotThreshVsPerf = function(obj, measures = obj$measures,
  facet = "measure", mark.th = NA_real_,
  pretty.names = TRUE, facet.wrap.nrow = NULL, facet.wrap.ncol = NULL) {

  assertClass(obj, classes = "ThreshVsPerfData")
  mappings = c("measure", "learner")
  assertChoice(facet, mappings)
  color = mappings[mappings != facet]
  measures = checkMeasures(measures, obj)
  checkSubset(extractSubList(measures, "id"), extractSubList(obj$measures, "id"))
  mids = replaceDupeMeasureNames(measures, "id")
  names(measures) = mids

  id.vars = "threshold"
  resamp = "iter" %in% colnames(obj$data)
  if (resamp) id.vars = c(id.vars, "iter")
  if ("learner" %in% colnames(obj$data)) id.vars = c(id.vars, "learner")
  obj$data = obj$data[, c(id.vars, names(measures))]

  if (pretty.names) {
    mnames = replaceDupeMeasureNames(measures, "name")
    colnames(obj$data) = mapValues(colnames(obj$data), names(measures), mnames)
  } else {
    mnames = names(measures)
  }

  data = setDF(melt(as.data.table(obj$data), measure.vars = mnames,
    variable.name = "measure", value.name = "performance", id.vars = id.vars))
  if (!is.null(data$learner)) {
    nlearn = length(unique(data$learner))
  } else {
    nlearn = 1L
  }
  nmeas = length(unique(data$measure))

  if ((color == "learner" && nlearn == 1L) || (color == "measure" && nmeas == 1L)) {
    color = NULL
  }

  if ((facet == "learner" && nlearn == 1L) || (facet == "measure" && nmeas == 1L)) {
    facet = NULL
  }

  if (resamp && !obj$aggregate && is.null(color)) {
    group = "iter"
  } else if (resamp && !obj$aggregate && !is.null(color)) {
    data$int = interaction(data[["iter"]], data[[color]])
    group = "int"
  } else {
    group = NULL
  }

  plt = ggplot(data, aes_string(x = "threshold", y = "performance"))
  plt = plt + geom_line(aes_string(group = group, color = color))

  if (!is.na(mark.th)) {
    plt = plt + geom_vline(xintercept = mark.th)
  }

  if (!is.null(facet)) {
    plt = plt + facet_wrap(facet, scales = "free_y", nrow = facet.wrap.nrow,
      ncol = facet.wrap.ncol)
  } else if (length(obj$measures) == 1L) {
    plt = plt + ylab(obj$measures[[1]]$name)
  } else {
    plt = plt + ylab("performance")
  }

  return(plt)
}
#' @title Plots a ROC curve using ggplot2.
#'
#' @description
#' Plots a ROC curve from predictions.
#'
#' @family plot
#' @family thresh_vs_perf
#'
#' @param obj ([ThreshVsPerfData])\cr
#'   Result of [generateThreshVsPerfData].
#' @param measures (`list(2)` of [Measure])\cr
#'   Exactly two performance measures, mapped to the x- and y-axis respectively.
#'   Default is the first two measures stored in `obj`, i.e. the first two
#'   measures passed to [generateThreshVsPerfData].
#' @param diagonal (`logical(1)`)\cr
#'   Whether to plot a dashed diagonal line.
#'   Default is `TRUE`.
#' @param pretty.names (`logical(1)`)\cr
#'   Whether to use the [Measure] name instead of the id in the plot.
#'   Default is `TRUE`.
#' @param facet.learner (`logical(1)`)\cr
#'   Whether to use facetting (instead of color) to compare multiple learners.
#'   Default is `FALSE`.
#' @template ret_gg2
#' @export
#' @examples
#' \donttest{
#' lrn = makeLearner("classif.rpart", predict.type = "prob")
#' fit = train(lrn, sonar.task)
#' pred = predict(fit, task = sonar.task)
#' roc = generateThreshVsPerfData(pred, list(fpr, tpr))
#' plotROCCurves(roc)
#'
#' r = bootstrapB632plus(lrn, sonar.task, iters = 3)
#' roc_r = generateThreshVsPerfData(r, list(fpr, tpr), aggregate = FALSE)
#' plotROCCurves(roc_r)
#'
#' r2 = crossval(lrn, sonar.task, iters = 3)
#' roc_l = generateThreshVsPerfData(list(boot = r, cv = r2), list(fpr, tpr), aggregate = FALSE)
#' plotROCCurves(roc_l)
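#'
#' # Facet per learner instead of distinguishing learners by color
#' plotROCCurves(roc_l, facet.learner = TRUE)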
#' }
plotROCCurves = function(obj, measures, diagonal = TRUE, pretty.names = TRUE, facet.learner = FALSE) {

  assertClass(obj, "ThreshVsPerfData")

  if (missing(measures)) {
    measures = obj$measures[1:2]
  }

  assertList(measures, "Measure", len = 2)
  assertFlag(diagonal)
  assertFlag(pretty.names)
  assertFlag(facet.learner)

  if (is.null(names(measures))) {
    names(measures) = extractSubList(measures, "id")
  }

  if (pretty.names) {
    mnames = replaceDupeMeasureNames(measures, "name")
  } else {
    mnames = names(measures)
  }

  if (!is.null(obj$data$learner)) {
    mlearn = length(unique(obj$data$learner)) > 1L
  } else {
    mlearn = FALSE
  }
  resamp = "iter" %in% colnames(obj$data)


  aes = list(x = names(measures)[1], y = names(measures)[2])

  if (!obj$aggregate && mlearn && resamp) {
    obj$data$int = interaction(obj$data$learner, obj$data$iter)
    aes$group = "int"
  } else if (!obj$aggregate && !mlearn && resamp) {
    aes$group = "iter"
  } else if (obj$aggregate && mlearn && !resamp) {
    aes$group = "learner"
  } else {
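    ## no explicit grouping needed; sort by threshold so geom_path connects points in order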
    obj$data = obj$data[order(obj$data$threshold), ]
  }

  if (mlearn && !facet.learner) {
    aes$color = "learner"
  }

  p = ggplot(obj$data, do.call(aes_string, aes)) + geom_path() + labs(x = mnames[1], y = mnames[2])

  if (mlearn && facet.learner) {
    p = p + facet_wrap(~learner)
  }

  if (diagonal && all(vlapply(obj$data[, names(measures)], function(x) max(x, na.rm = TRUE) <= 1))) {
    p = p + geom_abline(aes(intercept = 0, slope = 1), linetype = "dashed", alpha = .5)
  }
  p
}