R/plotThreshVsPerf.R

#' @title Plot threshold vs. performance(s) for 2-class classification using ggplot2.
#'
#' @description
#' Plots threshold vs. performance(s) data that has been generated with
#' generateThreshVsPerfData. The data is reshaped to long format so that each
#' measure column becomes one facet column. If a `task_id` column is present,
#' the facets are additionally split into rows by task. If a `learner_id`
#' column is present, the lines are coloured by learner.
#'
#' @param performanceByThreshold object (data.table / data.frame) generated by
#'   generateThreshVsPerfData. It must contain a `threshold` column. The columns
#'   `hash`, `task_id`, `learner_id` and `resampling_id` are treated as
#'   identifiers when present; every other column is plotted as a measure.
#' @return The result of passing the assembled ggplot object to
#'   `plotWithTheme()`.
#' @export
plotThreshVsPerf = function(performanceByThreshold) {
  if (!is.data.frame(performanceByThreshold)) {
    stop("'performanceByThreshold' must be a data.frame or data.table.", call. = FALSE)
  }
  cols = colnames(performanceByThreshold)
  if (!("threshold" %in% cols)) {
    stop("'performanceByThreshold' must contain a 'threshold' column.", call. = FALSE)
  }

  facet_x = "measure"
  facet_y = if ("task_id" %in% cols) "task_id" else NULL
  colorVar = if ("learner_id" %in% cols) "learner_id" else NULL

  idVars = intersect(c("threshold", "hash", "task_id", "learner_id", "resampling_id"), cols)
  # Name the long-format columns directly instead of renaming the last two
  # columns by position afterwards; this does not depend on column placement.
  dataForPlot = reshape2::melt(performanceByThreshold, id.vars = idVars,
    variable.name = facet_x, value.name = "performance")

  # Qualify every ggplot2 call so the function does not rely on ggplot2 being
  # attached or imported into the package namespace.
  plt = ggplot2::ggplot(dataForPlot, ggplot2::aes_string(x = "threshold", y = "performance")) +
    ggplot2::geom_line(ggplot2::aes_string(colour = colorVar))

  if (is.null(facet_y)) {
    # One panel per measure; each measure keeps its own y range.
    plt = plt + ggplot2::facet_wrap(paste0("~", facet_x), scales = "free_y")
  } else {
    # Rows = tasks, columns = measures; spell out `scales` rather than relying
    # on partial argument matching of `scale`.
    plt = plt + ggplot2::facet_grid(paste0(facet_y, "~", facet_x), scales = "free")
  }
  plotWithTheme(plt)
}
LinlinYin/mlr3vis documentation built on July 7, 2019, 12:11 p.m.