# R/etc_utils_autoplot.R

# Defines functions .geom_basic_point .set_coords .geom_basic_prc .geom_basic_roc .make_rocprc_title .geom_basic .autoplot_single .combine_plots .combine_plots_patchwork .combine_plots_grid .grid_arrange_shared_legend .autoplot_multi .load_patchwork .load_gridextra .load_grid .load_ggplot2 .prepare_autoplot .get_autoplot_arglist

#' Plot performance evaluation measures with ggplot2
#'
#' The \code{autoplot} function plots performance evaluation measures
#'   by using \pkg{ggplot2} instead of the general R plot.
#'
#' @param object An \code{S3} object generated by \code{\link{evalmod}}.
#'   The \code{autoplot} function accepts the following \code{S3} objects for two
#'   different modes, "rocprc" and "basic".
#'
#' \enumerate{
#'
#'   \item ROC and Precision-Recall curves (\code{mode = "rocprc"})
#'
#'   \tabular{lll}{
#'     \strong{\code{S3} object}
#'     \tab \strong{# of models}
#'     \tab \strong{# of test datasets} \cr
#'
#'     sscurves \tab single   \tab single   \cr
#'     mscurves \tab multiple \tab single   \cr
#'     smcurves \tab single   \tab multiple \cr
#'     mmcurves \tab multiple \tab multiple
#'   }
#'
#'   \item Basic evaluation measures (\code{mode = "basic"})
#'
#'   \tabular{lll}{
#'     \strong{\code{S3} object}
#'     \tab \strong{# of models}
#'     \tab \strong{# of test datasets} \cr
#'
#'     sspoints \tab single   \tab single   \cr
#'     mspoints \tab multiple \tab single   \cr
#'     smpoints \tab single   \tab multiple \cr
#'     mmpoints \tab multiple \tab multiple
#'   }
#' }
#'
#' See the \strong{Value} section of \code{\link{evalmod}} for more details.
#'
#' @param curvetype A character vector with the following curve types.
#' \enumerate{
#'
#'   \item ROC and Precision-Recall curves (mode = "rocprc")
#'     \tabular{ll}{
#'       \strong{curvetype}
#'       \tab \strong{description} \cr
#'
#'       ROC \tab ROC curve \cr
#'       PRC \tab Precision-Recall curve
#'     }
#'     Multiple \code{curvetype} can be combined, such as
#'     \code{c("ROC", "PRC")}.
#'
#'   \item Basic evaluation measures (mode = "basic")
#'     \tabular{ll}{
#'       \strong{curvetype}
#'       \tab \strong{description} \cr
#'
#'       error \tab Normalized ranks vs. error rate \cr
#'       accuracy \tab Normalized ranks vs. accuracy \cr
#'       specificity \tab Normalized ranks vs. specificity \cr
#'       sensitivity \tab Normalized ranks vs. sensitivity \cr
#'       precision \tab Normalized ranks vs. precision \cr
#'       mcc \tab Normalized ranks vs. Matthews correlation coefficient \cr
#'       fscore \tab Normalized ranks vs. F-score
#'    }
#'    Multiple \code{curvetype} can be combined, such as
#'    \code{c("precision", "sensitivity")}.
#' }
#'
#' @param ... The following additional arguments can be specified.
#'
#' \describe{
#'   \item{type}{
#'     A character to specify the line type as follows.
#'     \describe{
#'       \item{"l"}{lines}
#'       \item{"p"}{points}
#'       \item{"b"}{both lines and points}
#'     }
#'   }
#'   \item{show_cb}{
#'     A Boolean value to specify whether point-wise confidence
#'     bounds are drawn. It is effective only when \code{calc_avg} of the
#'    \code{\link{evalmod}} function is set to \code{TRUE}.
#'   }
#'   \item{raw_curves}{
#'     A Boolean value to specify whether raw curves are
#'     shown instead of the average curve. It is effective only
#'     when \code{raw_curves} of the \code{\link{evalmod}} function is set to
#'     \code{TRUE}.
#'   }
#'   \item{show_legend}{
#'     A Boolean value to specify whether the legend is shown.
#'   }
#'   \item{ret_grob}{
#'     A logical value to indicate whether
#'     \code{autoplot} returns a \code{grob} object. The \code{grob} object
#'     is internally generated by \code{\link[gridExtra]{arrangeGrob}}.
#'     The \code{\link[grid]{grid.draw}} function takes a \code{grob} object and
#'     shows a plot. It is effective only when a multiple-panel plot is
#'     generated, for example, when \code{curvetype} is \code{c("ROC", "PRC")}.
#'   }
#'   \item{reduce_points}{
#'     A Boolean value to decide whether the points should be reduced
#'     when \code{mode = "rocprc"}. The points are reduced according to
#'     \code{x_bins} of the \code{\link{evalmod}} function.
#'     The default value is \code{TRUE}.
#'   }
#' }
#'
#' @return The \code{autoplot} function returns a \code{ggplot} object
#'   for a single-panel plot and a frame-grob object for a multiple-panel plot.
#'
#' @seealso \code{\link{evalmod}} for generating an \code{S3} object.
#'   \code{\link{fortify}} for converting a curves or points object
#'   to a data frame.  \code{\link{plot}} for plotting the equivalent curves
#'   with the general R plot.
#'
#' @examples
#' \dontrun{
#'
#' ## Load libraries
#' library(ggplot2)
#' library(grid)
#'
#' ##################################################
#' ### Single model & single test dataset
#' ###
#'
#' ## Load a dataset with 10 positives and 10 negatives
#' data(P10N10)
#'
#' ## Generate an sscurve object that contains ROC and Precision-Recall curves
#' sscurves <- evalmod(scores = P10N10$scores, labels = P10N10$labels)
#'
#' ## Plot both ROC and Precision-Recall curves
#' autoplot(sscurves)
#'
#' ## Reduced/Full supporting points
#' sampss <- create_sim_samples(1, 50000, 50000)
#' evalss <- evalmod(scores = sampss$scores, labels = sampss$labels)
#'
#' # Reduced supporting points
#' system.time(autoplot(evalss))
#'
#' # Full supporting points
#' system.time(autoplot(evalss, reduce_points = FALSE))
#'
#' ## Get a grob object for multiple plots
#' pp1 <- autoplot(sscurves, ret_grob = TRUE)
#' plot.new()
#' grid.draw(pp1)
#'
#' ## A ROC curve
#' autoplot(sscurves, curvetype = "ROC")
#'
#' ## A Precision-Recall curve
#' autoplot(sscurves, curvetype = "PRC")
#'
#' ## Generate an sspoints object that contains basic evaluation measures
#' sspoints <- evalmod(
#'   mode = "basic", scores = P10N10$scores,
#'   labels = P10N10$labels
#' )
#'
#' ## Normalized ranks vs. basic evaluation measures
#' autoplot(sspoints)
#'
#' ## Normalized ranks vs. precision
#' autoplot(sspoints, curvetype = "precision")
#'
#'
#' ##################################################
#' ### Multiple models & single test dataset
#' ###
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(1, 100, 100, "all")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#'   modnames = samps[["modnames"]]
#' )
#'
#' ## Generate an mscurve object that contains ROC and Precision-Recall curves
#' mscurves <- evalmod(mdat)
#'
#' ## ROC and Precision-Recall curves
#' autoplot(mscurves)
#'
#' ## Reduced/Full supporting points
#' sampms <- create_sim_samples(5, 50000, 50000)
#' evalms <- evalmod(scores = sampms$scores, labels = sampms$labels)
#'
#' # Reduced supporting points
#' system.time(autoplot(evalms))
#'
#' # Full supporting points
#' system.time(autoplot(evalms, reduce_points = FALSE))
#'
#' ## Hide the legend
#' autoplot(mscurves, show_legend = FALSE)
#'
#' ## Generate an mspoints object that contains basic evaluation measures
#' mspoints <- evalmod(mdat, mode = "basic")
#'
#' ## Normalized ranks vs. basic evaluation measures
#' autoplot(mspoints)
#'
#' ## Hide the legend
#' autoplot(mspoints, show_legend = FALSE)
#'
#'
#' ##################################################
#' ### Single model & multiple test datasets
#' ###
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(10, 100, 100, "good_er")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#'   modnames = samps[["modnames"]],
#'   dsids = samps[["dsids"]]
#' )
#'
#' ## Generate an smcurve object that contains ROC and Precision-Recall curves
#' smcurves <- evalmod(mdat, raw_curves = TRUE)
#'
#' ## Average ROC and Precision-Recall curves
#' autoplot(smcurves, raw_curves = FALSE)
#'
#' ## Hide confidence bounds
#' autoplot(smcurves, raw_curves = FALSE, show_cb = FALSE)
#'
#' ## Raw ROC and Precision-Recall curves
#' autoplot(smcurves, raw_curves = TRUE, show_cb = FALSE)
#'
#' ## Reduced/Full supporting points
#' sampsm <- create_sim_samples(4, 5000, 5000)
#' mdatsm <- mmdata(sampsm$scores, sampsm$labels, expd_first = "dsids")
#' evalsm <- evalmod(mdatsm, raw_curves = TRUE)
#'
#' # Reduced supporting points
#' system.time(autoplot(evalsm, raw_curves = TRUE))
#'
#' # Full supporting points
#' system.time(autoplot(evalsm, raw_curves = TRUE, reduce_points = FALSE))
#'
#' ## Generate an smpoints object that contains basic evaluation measures
#' smpoints <- evalmod(mdat, mode = "basic")
#'
#' ## Normalized ranks vs. average basic evaluation measures
#' autoplot(smpoints)
#'
#'
#' ##################################################
#' ### Multiple models & multiple test datasets
#' ###
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(10, 100, 100, "all")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#'   modnames = samps[["modnames"]],
#'   dsids = samps[["dsids"]]
#' )
#'
#' ## Generate an mmcurve object that contains ROC and Precision-Recall curves
#' mmcurves <- evalmod(mdat, raw_curves = TRUE)
#'
#' ## Average ROC and Precision-Recall curves
#' autoplot(mmcurves, raw_curves = FALSE)
#'
#' ## Show confidence bounds
#' autoplot(mmcurves, raw_curves = FALSE, show_cb = TRUE)
#'
#' ## Raw ROC and Precision-Recall curves
#' autoplot(mmcurves, raw_curves = TRUE)
#'
#' ## Reduced/Full supporting points
#' sampmm <- create_sim_samples(4, 5000, 5000)
#' mdatmm <- mmdata(sampmm$scores, sampmm$labels,
#'   modnames = c("m1", "m2"),
#'   dsids = c(1, 2), expd_first = "modnames"
#' )
#' evalmm <- evalmod(mdatmm, raw_curves = TRUE)
#'
#' # Reduced supporting points
#' system.time(autoplot(evalmm, raw_curves = TRUE))
#'
#' # Full supporting points
#' system.time(autoplot(evalmm, raw_curves = TRUE, reduce_points = FALSE))
#'
#' ## Generate an mmpoints object that contains basic evaluation measures
#' mmpoints <- evalmod(mdat, mode = "basic")
#'
#' ## Normalized ranks vs. average basic evaluation measures
#' autoplot(mmpoints)
#'
#'
#' ##################################################
#' ### N-fold cross validation datasets
#' ###
#'
#' ## Load test data
#' data(M2N50F5)
#'
#' ## Specify necessary columns to create mdat
#' cvdat <- mmdata(
#'   nfold_df = M2N50F5, score_cols = c(1, 2),
#'   lab_col = 3, fold_col = 4,
#'   modnames = c("m1", "m2"), dsids = 1:5
#' )
#'
#' ## Generate an mmcurve object that contains ROC and Precision-Recall curves
#' cvcurves <- evalmod(cvdat)
#'
#' ## Average ROC and Precision-Recall curves
#' autoplot(cvcurves)
#'
#' ## Show confidence bounds
#' autoplot(cvcurves, show_cb = TRUE)
#'
#' ## Generate an mmpoints object that contains basic evaluation measures
#' cvpoints <- evalmod(cvdat, mode = "basic")
#'
#' ## Normalized ranks vs. average basic evaluation measures
#' autoplot(cvpoints)
#' }
#'
#' @name autoplot
NULL # Placeholder object that carries the roxygen block above (@name autoplot)

#
# Process ... for curve objects
#
#   Merges the arguments supplied in `...` with the `def_*` defaults and
#   checks `show_cb` and `raw_curves` against the arguments that were given
#   to evalmod (`evalmod_args`). Values supplied in `...` always win over the
#   defaults.
#
#   `raw_curves` falls back to `def_raw_curves`, then to
#   `evalmod_args$raw_curves`, and finally to FALSE.
#
#   Returns the resolved arguments as a named list. Stops with an error when
#   `show_cb` or `raw_curves` request data that evalmod did not compute.
#
.get_autoplot_arglist <- function(evalmod_args,
                                  def_curvetype, def_type, def_show_cb,
                                  def_raw_curves, def_add_np_nn,
                                  def_show_legend, def_ret_grob,
                                  def_reduce_points, def_multiplot_lib, ...) {
  arglist <- list(...)

  # Use `default` when `name` is absent from (or NULL in) `args`
  fill_default <- function(args, name, default) {
    if (is.null(args[[name]])) {
      args[[name]] <- default
    }
    args
  }

  arglist <- fill_default(arglist, "curvetype", def_curvetype)
  arglist <- fill_default(arglist, "type", def_type)

  arglist <- fill_default(arglist, "show_cb", def_show_cb)
  if (!evalmod_args[["calc_avg"]] && arglist[["show_cb"]]) {
    stop("Invalid show_cb. Inconsistent with calc_avg of evalmod.",
      call. = FALSE
    )
  }

  if (is.null(arglist[["raw_curves"]])) {
    if (!is.null(def_raw_curves)) {
      arglist[["raw_curves"]] <- def_raw_curves
    } else if (!is.null(evalmod_args[["raw_curves"]])) {
      arglist[["raw_curves"]] <- evalmod_args[["raw_curves"]]
    } else {
      arglist[["raw_curves"]] <- FALSE
    }
  }
  # evalmod_args$raw_curves may be NULL (see the fallback above); `!NULL` is
  # logical(0), which cannot be used as an `if` condition, so use isTRUE()
  if (!isTRUE(evalmod_args[["raw_curves"]]) &&
    isTRUE(arglist[["raw_curves"]])) {
    stop("Invalid raw_curves. Inconsistent with the value of evalmod.",
      call. = FALSE
    )
  }

  arglist <- fill_default(arglist, "add_np_nn", def_add_np_nn)
  arglist <- fill_default(arglist, "show_legend", def_show_legend)
  arglist <- fill_default(arglist, "ret_grob", def_ret_grob)
  arglist <- fill_default(arglist, "reduce_points", def_reduce_points)
  arglist <- fill_default(arglist, "multiplot_lib", def_multiplot_lib)

  arglist
}

#
# Prepare autoplot and return a data frame
#
#   Checks that ggplot2 is available and validates `object`. Unless
#   `curve_df` is supplied, the data frame is created by fortify(), which
#   receives `...`. When `curvetype` is given, only the rows whose
#   `curvetype` column equals it are kept; rows where the comparison is NA
#   are dropped.
#
.prepare_autoplot <- function(object, curve_df = NULL, curvetype = NULL, ...) {
  # === Check package availability  ===
  .load_ggplot2()

  # === Validate input arguments ===
  .validate(object)

  # === Prepare a data frame for ggplot2 ===
  if (is.null(curve_df)) {
    curve_df <- ggplot2::fortify(object, ...)
  }

  # Standard subsetting instead of subset(): subset() evaluates its condition
  # with the data columns first, so a column shadowing the local comparison
  # value, or a missing `curvetype` column, silently changes the selection
  if (!is.null(curvetype)) {
    if (is.null(curve_df[["curvetype"]])) {
      stop("curve_df has no curvetype column.", call. = FALSE)
    }
    matched <- curve_df[["curvetype"]] == curvetype
    curve_df <- curve_df[matched & !is.na(matched), , drop = FALSE]
  }

  curve_df
}

#
# Load ggplot2
#
#   Returns invisible NULL when the ggplot2 namespace can be loaded and
#   stops with an installation hint otherwise.
#
.load_ggplot2 <- function() {
  if (requireNamespace("ggplot2", quietly = TRUE)) {
    return(invisible(NULL))
  }
  msg <- paste0(
    "ggplot2 is required to perform this function. ",
    "Please install it."
  )
  stop(msg, call. = FALSE)
}

#
# Load grid
#
#   Returns invisible NULL when the grid namespace can be loaded and stops
#   with an installation hint otherwise.
#
.load_grid <- function() {
  if (requireNamespace("grid", quietly = TRUE)) {
    return(invisible(NULL))
  }
  stop("grid is required to perform this function. Please install it.",
    call. = FALSE
  )
}

#
# Load gridExtra
#
#   Returns invisible NULL when the gridExtra namespace can be loaded and
#   stops with an installation hint otherwise.
#
.load_gridextra <- function() {
  if (requireNamespace("gridExtra", quietly = TRUE)) {
    return(invisible(NULL))
  }
  stop("gridExtra is required to perform this function. Please install it.",
    call. = FALSE
  )
}

#
# Load patchwork
#
#   Returns TRUE when the patchwork namespace can be loaded. Otherwise warns
#   that grid and gridExtra will be used instead and returns FALSE.
#
.load_patchwork <- function() {
  has_patchwork <- requireNamespace("patchwork", quietly = TRUE)
  if (!has_patchwork) {
    warning(
      "patchwork is not installed. grid and gridExtra will be used instead.",
      call. = FALSE
    )
  }
  has_patchwork
}

#
# Plot ROC and Precision-Recall curves (or basic measures)
#
#   Reads the plotting options from `arglist`, validates them, converts
#   `object` to a data frame once with fortify(), and builds one plot per
#   element of `arglist$curvetype` with .autoplot_single(). A single plot is
#   returned directly; several plots are combined by .combine_plots().
#
.autoplot_multi <- function(object, arglist) {
  curvetype <- arglist[["curvetype"]]
  type <- arglist[["type"]]
  add_np_nn <- arglist[["add_np_nn"]]
  show_legend <- arglist[["show_legend"]]
  ret_grob <- arglist[["ret_grob"]]
  reduce_points <- arglist[["reduce_points"]]
  multiplot_lib <- arglist[["multiplot_lib"]]

  # Confidence bounds are turned off when evalmod did not average the curves
  show_cb <- if (attr(object, "args")$calc_avg) arglist[["show_cb"]] else FALSE

  # Raw curves are turned off whenever confidence bounds are shown
  raw_curves <- if (show_cb) FALSE else arglist[["raw_curves"]]

  # === Check package availability and arguments ===
  .load_ggplot2()
  .validate(object)
  .check_curvetype(curvetype, object)
  .check_type(type)
  .check_show_cb(show_cb, object)
  .check_raw_curves(raw_curves, object)
  .check_show_legend(show_legend)
  .check_add_np_nn(add_np_nn)
  .check_ret_grob(ret_grob)
  .check_multiplot_lib(multiplot_lib)

  # === One data frame shared by every panel ===
  curve_df <- ggplot2::fortify(object,
    raw_curves = raw_curves,
    reduce_points = reduce_points
  )

  plots <- lapply(curvetype, function(one_type) {
    .autoplot_single(object, curve_df,
      curvetype = one_type, type = type,
      show_cb = show_cb, raw_curves = raw_curves,
      reduce_points = reduce_points, show_legend = show_legend,
      add_np_nn = add_np_nn
    )
  })
  names(plots) <- curvetype

  if (length(plots) < 2L) {
    return(plots[[1]])
  }

  do.call(.combine_plots, c(plots,
    show_legend = show_legend,
    ret_grob = ret_grob,
    multiplot_lib = multiplot_lib,
    nplots = length(plots)
  ))
}

#
# .grid_arrange_shared_legend
#
#   Modified version of grid_arrange_shared_legend from RPubs
#   URL of the original version:
#     http://rpubs.com/sjackman/grid_arrange_shared_legend
#
#   Arranges the plots in `...` in `main_ncol` columns with their legends
#   removed, and places the legend of the first plot (drawn at the bottom)
#   underneath them. Returns the grob built by gridExtra::arrangeGrob().
#   Stops with an error when no legend can be found in the first plot.
#
.grid_arrange_shared_legend <- function(..., main_ncol = 2) {
  plots <- list(...)

  # Render the first plot with a bottom legend so its legend can be reused
  gt <- ggplot2::ggplotGrob(
    plots[[1]] + ggplot2::theme(legend.position = "bottom")
  )

  # Grob names as a character vector; vapply avoids comparing a list
  grob_names <- vapply(
    gt$grobs,
    function(x) {
      nm <- x$name
      if (is.character(nm) && length(nm) == 1L) nm else NA_character_
    },
    character(1)
  )
  legend_idx <- which(grob_names == "guide-box")

  # Fallback: some ggplot2 versions name legend slots by position in the
  # gtable layout (e.g. "guide-box-bottom"); skip the empty zeroGrob slots
  if (length(legend_idx) == 0L) {
    is_box <- startsWith(gt$layout$name, "guide-box")
    is_empty <- vapply(gt$grobs, inherits, logical(1), what = "zeroGrob")
    legend_idx <- which(is_box & !is_empty)
  }
  if (length(legend_idx) == 0L) {
    stop("Failed to extract the legend from the plot.", call. = FALSE)
  }

  legend <- gt$grobs[[legend_idx[1]]]
  # Spell out `heights`; `$height` only worked through partial matching
  lheight <- sum(legend$heights)

  fncol <- function(...) gridExtra::arrangeGrob(..., ncol = main_ncol)
  fnolegend <- function(x) x + ggplot2::theme(legend.position = "none")

  gridExtra::arrangeGrob(
    do.call(fncol, lapply(plots, fnolegend)),
    legend,
    heights = grid::unit.c(grid::unit(1, "npc") - lheight, lheight),
    ncol = 1
  )
}

#
# Combine plots by grid and gridExtra
#
#   Two or four plots are laid out in two columns; any other count uses
#   three columns. With `show_legend` the plots share one bottom legend.
#   Returns the arranged grob when `ret_grob` is TRUE; otherwise draws it on
#   a new page.
#
.combine_plots_grid <- function(..., show_legend, ret_grob, nplots) {
  ncol <- 3
  if (nplots == 2 || nplots == 4) {
    ncol <- 2
  }

  grobframe <- if (show_legend) {
    .grid_arrange_shared_legend(..., main_ncol = ncol)
  } else {
    gridExtra::arrangeGrob(..., ncol = ncol)
  }

  if (ret_grob) {
    return(grobframe)
  }
  graphics::plot.new()
  grid::grid.draw(grobframe)
}

#
# Combine ROC and Precision-Recall plots by patchwork
#
#   Two or four plots are laid out in two columns; any other count uses
#   three columns. With `show_legend` the legends are collected into one. For
#   more than two plots the collected legend is moved to the bottom.
#   Returns the patchwork object.
#
.combine_plots_patchwork <- function(..., show_legend) {
  plotlist <- list(...)
  nplots <- length(plotlist)

  ncol <- if (nplots == 2 || nplots == 4) 2 else 3

  p <- patchwork::wrap_plots(plotlist, ncol = ncol)
  if (show_legend) {
    p <- p + patchwork::plot_layout(guides = "collect")
    if (nplots > 2) {
      # `&` applies the theme to every subplot; `+` would modify only the
      # last subplot, so the bottom position was not applied consistently
      p <- p & ggplot2::theme(legend.position = "bottom")
    }
  }

  p
}

#
# Combine ROC and Precision-Recall plots
#
#   Uses patchwork when `multiplot_lib` is "patchwork" and the package can be
#   loaded. Otherwise, or when `multiplot_lib` is "grid", it uses grid and
#   gridExtra. Any other `multiplot_lib` yields invisible NULL.
#
.combine_plots <- function(..., show_legend, ret_grob, multiplot_lib, nplots) {
  wants_patchwork <- multiplot_lib == "patchwork"

  if (wants_patchwork) {
    if (.load_patchwork()) {
      return(.combine_plots_patchwork(..., show_legend = show_legend))
    }
  }

  # grid is the explicit choice or the fallback for a missing patchwork
  if (wants_patchwork || multiplot_lib == "grid") {
    .load_grid()
    .load_gridextra()
    .combine_plots_grid(...,
      show_legend = show_legend,
      ret_grob = ret_grob, nplots = nplots
    )
  }
}

#
# Plot a single ROC, Precision-Recall, or basic-measure panel
#
#   The data are filtered to `curvetype` by .prepare_autoplot(). Layers depend
#   on `show_cb` (confidence band around the average curve), `raw_curves`
#   (one line per dsid_modname group), and `type` ("l" lines, "p" points,
#   "b" both). Axis limits and the aspect ratio are chosen per curve type.
#   The title, labels, and reference lines come from the matching
#   .geom_basic_*() function.
#
.autoplot_single <- function(object, curve_df, curvetype = "ROC", type = "l",
                             show_cb = FALSE, raw_curves = FALSE,
                             reduce_points = TRUE, show_legend = FALSE,
                             add_np_nn = TRUE, ...) {
  plot_df <- .prepare_autoplot(object,
    curve_df = curve_df,
    curvetype = curvetype,
    raw_curves = raw_curves,
    reduce_points = reduce_points, ...
  )

  # Column symbols used in the tidy-eval aesthetics below
  sx <- rlang::sym("x")
  sy <- rlang::sym("y")
  symin <- rlang::sym("ymin")
  symax <- rlang::sym("ymax")
  smod <- rlang::sym("modname")
  sdsmod <- rlang::sym("dsid_modname")

  # Line and/or point layers that inherit the base plot's aesthetics
  add_lines_points <- function(g) {
    if (type == "l") {
      return(g + ggplot2::geom_line(na.rm = TRUE))
    }
    if (type == "b" || type == "p") {
      if (type == "b") {
        g <- g + ggplot2::geom_line(alpha = 0.25, na.rm = TRUE)
      }
      g <- g + ggplot2::geom_point(na.rm = TRUE)
    }
    g
  }

  if (show_cb) {
    p <- ggplot2::ggplot(plot_df, ggplot2::aes(
      x = !!sx, y = !!sy, ymin = !!symin, ymax = !!symax
    ))
    if (type == "l") {
      p <- p + ggplot2::geom_smooth(ggplot2::aes(color = !!smod),
        stat = "identity", na.rm = TRUE, linewidth = 0.5
      )
    } else if (type == "b" || type == "p") {
      p <- p + ggplot2::geom_ribbon(
        ggplot2::aes(ymin = !!symin, ymax = !!symax, group = !!smod),
        stat = "identity", alpha = 0.25, fill = "grey25", na.rm = TRUE
      )
      if (type == "b") {
        p <- p + ggplot2::geom_line(ggplot2::aes(color = !!smod),
          alpha = 0.25, na.rm = TRUE
        )
      }
      p <- p + ggplot2::geom_point(
        ggplot2::aes(x = !!sx, y = !!sy, color = !!smod),
        na.rm = TRUE
      )
    }
  } else if (raw_curves) {
    p <- add_lines_points(ggplot2::ggplot(plot_df, ggplot2::aes(
      x = !!sx, y = !!sy, group = !!sdsmod, color = !!smod
    )))
  } else {
    p <- add_lines_points(ggplot2::ggplot(plot_df, ggplot2::aes(
      x = !!sx, y = !!sy, color = !!smod
    )))
  }

  # Axis limits and aspect ratio per curve type
  limits <- switch(curvetype,
    ROC = ,
    PRC = {
      src <- object[[if (curvetype == "ROC") "rocs" else "prcs"]]
      lx <- attr(src, "xlim")
      ly <- attr(src, "ylim")
      # Fixed 1:1 ratio only when both axes have identical limits
      list(xlim = lx, ylim = ly, ratio = if (all(lx == ly)) 1 else NULL)
    },
    mcc = ,
    label = list(xlim = c(0, 1), ylim = c(-1, 1), ratio = 0.5),
    score = list(xlim = c(0, 1), ylim = NULL, ratio = NULL),
    list(xlim = c(0, 1), ylim = c(0, 1), ratio = 1)
  )

  geom_fn <- switch(curvetype,
    ROC = .geom_basic_roc,
    PRC = .geom_basic_prc,
    .geom_basic_point
  )

  geom_fn(p, object,
    show_legend = show_legend, add_np_nn = add_np_nn,
    curve_df = plot_df, xlim = limits[["xlim"]], ylim = limits[["ylim"]],
    ratio = limits[["ratio"]], ...
  )
}

#
# Geom basic
#
#   Applies the black-and-white theme, the title `main`, and the axis labels
#   `xlab`/`ylab`, and blanks the legend title. The legend itself is hidden
#   unless `show_legend` is TRUE.
#
.geom_basic <- function(p, main, xlab, ylab, show_legend) {
  # Adding a list adds each component in order
  p <- p + list(
    ggplot2::theme_bw(),
    ggplot2::ggtitle(main),
    ggplot2::xlab(xlab),
    ggplot2::ylab(ylab),
    ggplot2::theme(legend.title = ggplot2::element_blank())
  )

  if (show_legend) {
    return(p)
  }
  p + ggplot2::theme(legend.position = "none")
}

#
# Make main title
#
#   Builds "<pt> - P: <avg_np>, N: <avg_nn>" from the average numbers of
#   positives and negatives reported by .get_pn_info().
#
.make_rocprc_title <- function(object, pt) {
  info <- .get_pn_info(object)
  paste0(pt, " - P: ", info$avg_np, ", N: ", info$avg_nn)
}

#
# Geom basic for ROC
#
#   Adds a dotted grey y = x reference line, sets the coordinates, and
#   applies the common theme. The title includes the positive/negative
#   counts when `add_np_nn` is TRUE and .get_pn_info() reports them as
#   consistent.
#
.geom_basic_roc <- function(p, object, show_legend = TRUE, add_np_nn = TRUE,
                            xlim, ylim, ratio, ...) {
  pn_info <- .get_pn_info(object)
  with_counts <- add_np_nn && pn_info$is_consistant
  main <- if (with_counts) .make_rocprc_title(object, "ROC") else "ROC"

  diagonal <- ggplot2::geom_abline(
    intercept = 0, slope = 1, colour = "grey", linetype = 3
  )
  p <- .set_coords(p + diagonal, xlim, ylim, ratio)
  .geom_basic(p, main, "1 - Specificity", "Sensitivity", show_legend)
}

#
# Geom basic for Precision-Recall
#
#   Adds a dotted grey horizontal reference line at pn_info$prc_base, sets
#   the coordinates, and applies the common theme. The title includes the
#   positive/negative counts when `add_np_nn` is TRUE and .get_pn_info()
#   reports them as consistent.
#
.geom_basic_prc <- function(p, object, show_legend = TRUE, add_np_nn = TRUE,
                            xlim, ylim, ratio, ...) {
  pn_info <- .get_pn_info(object)

  title <- "Precision-Recall"
  if (add_np_nn && pn_info$is_consistant) {
    title <- .make_rocprc_title(object, "Precision-Recall")
  }

  baseline <- ggplot2::geom_hline(
    yintercept = pn_info$prc_base, colour = "grey", linetype = 3
  )
  p <- .set_coords(p + baseline, xlim, ylim, ratio)
  .geom_basic(p, title, "Recall", "Precision", show_legend)
}

#
# Set coordinates for ROC, precision-recall, and basic-measure plots
#
#   Uses a Cartesian coordinate system when `ratio` is NULL and a fixed
#   aspect ratio otherwise; both are limited to `xlim` and `ylim`.
#
.set_coords <- function(p, xlim, ylim, ratio) {
  coord <- if (is.null(ratio)) {
    ggplot2::coord_cartesian(xlim = xlim, ylim = ylim)
  } else {
    ggplot2::coord_fixed(ratio = ratio, xlim = xlim, ylim = ylim)
  }
  p + coord
}

#
# Geom basic for basic evaluation measures
#
#   Takes the curve type from the first row of `curve_df`:
#     - "mcc" gets the title "MCC".
#     - "label" gets the title "Label (1:pos, -1:neg)".
#     - Any other name is used with its first letter capitalized.
#   The raw curve type is the y-axis label. Then the axis limits and the
#   common theme are applied.
#
#   Stops with an error when `curve_df` is missing, empty, or has no
#   `curvetype` column.
#
.geom_basic_point <- function(p, object, show_legend = TRUE,
                              curve_df = NULL, xlim, ylim, ratio, ...) {
  # The former default `curve_df = curve_df` referred to itself and failed
  # with a recursive-default error whenever the argument was omitted
  if (is.null(curve_df) || nrow(curve_df) == 0L ||
    is.null(curve_df[["curvetype"]])) {
    stop("curve_df must be a non-empty data frame with a curvetype column.",
      call. = FALSE
    )
  }

  # as.character() so a factor column yields its label, not a factor
  s <- as.character(curve_df[["curvetype"]][1])
  if (s == "mcc") {
    main <- "MCC"
  } else if (s == "label") {
    main <- "Label (1:pos, -1:neg)"
  } else {
    main <- paste0(toupper(substring(s, 1, 1)), substring(s, 2))
  }
  p <- .set_coords(p, xlim, ylim, ratio)
  p <- .geom_basic(p, main, "normalized rank", s, show_legend)

  p
}
# takayasaito/precrec documentation built on Oct. 19, 2023, 7:28 p.m.