#' Plot performance evaluation measures with ggplot2
#'
#' The \code{autoplot} function plots performance evaluation measures
#' by using \pkg{ggplot2} instead of the general R plot.
#'
#' @param object An \code{S3} object generated by \code{\link{evalmod}}.
#' The \code{autoplot} function accepts the following \code{S3} objects for two
#' different modes, "rocprc" and "basic".
#'
#' \enumerate{
#'
#' \item ROC and Precision-Recall curves (\code{mode = "rocprc"})
#'
#' \tabular{lll}{
#' \strong{\code{S3} object}
#' \tab \strong{# of models}
#' \tab \strong{# of test datasets} \cr
#'
#' sscurves \tab single \tab single \cr
#' mscurves \tab multiple \tab single \cr
#' smcurves \tab single \tab multiple \cr
#' mmcurves \tab multiple \tab multiple
#' }
#'
#' \item Basic evaluation measures (\code{mode = "basic"})
#'
#' \tabular{lll}{
#' \strong{\code{S3} object}
#' \tab \strong{# of models}
#' \tab \strong{# of test datasets} \cr
#'
#' sspoints \tab single \tab single \cr
#' mspoints \tab multiple \tab single \cr
#' smpoints \tab single \tab multiple \cr
#' mmpoints \tab multiple \tab multiple
#' }
#' }
#'
#' See the \strong{Value} section of \code{\link{evalmod}} for more details.
#'
#' @param curvetype A character vector with the following curve types.
#' \enumerate{
#'
#' \item ROC and Precision-Recall curves (mode = "rocprc")
#' \tabular{ll}{
#' \strong{curvetype}
#' \tab \strong{description} \cr
#'
#' ROC \tab ROC curve \cr
#' PRC \tab Precision-Recall curve
#' }
#' Multiple \code{curvetype} values can be combined, such as
#' \code{c("ROC", "PRC")}.
#'
#' \item Basic evaluation measures (mode = "basic")
#' \tabular{ll}{
#' \strong{curvetype}
#' \tab \strong{description} \cr
#'
#' error \tab Normalized ranks vs. error rate \cr
#' accuracy \tab Normalized ranks vs. accuracy \cr
#' specificity \tab Normalized ranks vs. specificity \cr
#' sensitivity \tab Normalized ranks vs. sensitivity \cr
#' precision \tab Normalized ranks vs. precision \cr
#' mcc \tab Normalized ranks vs. Matthews correlation coefficient \cr
#' fscore \tab Normalized ranks vs. F-score
#' }
#' Multiple \code{curvetype} values can be combined, such as
#' \code{c("precision", "sensitivity")}.
#' }
#'
#' @param ... The following additional arguments can be specified.
#'
#' \describe{
#' \item{type}{
#' A character to specify the line type as follows.
#' \describe{
#' \item{"l"}{lines}
#' \item{"p"}{points}
#' \item{"b"}{both lines and points}
#' }
#' }
#' \item{show_cb}{
#' A Boolean value to specify whether point-wise confidence
#' bounds are drawn. It is effective only when \code{calc_avg} of the
#' \code{\link{evalmod}} function is set to \code{TRUE}.
#' }
#' \item{raw_curves}{
#' A Boolean value to specify whether raw curves are
#' shown instead of the average curve. It is effective only
#' when \code{raw_curves} of the \code{\link{evalmod}} function is set to
#' \code{TRUE}.
#' }
#' \item{show_legend}{
#' A Boolean value to specify whether the legend is shown.
#' }
#' \item{ret_grob}{
#' A logical value to indicate whether
#' \code{autoplot} returns a \code{grob} object. The \code{grob} object
#' is internally generated by \code{\link[gridExtra]{arrangeGrob}}.
#' The \code{\link[grid]{grid.draw}} function takes a \code{grob} object and
#' shows a plot. It is effective only when a multiple-panel plot is
#' generated, for example, when \code{curvetype} is \code{c("ROC", "PRC")},
#' and when \code{multiplot_lib} is \code{"grid"} (or \pkg{patchwork} is
#' not installed).
#' }
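#' \item{reduce_points}{
#' A Boolean value to decide whether the points should be reduced
#' when \code{mode = "rocprc"}. The points are reduced according to
#' \code{x_bins} of the \code{\link{evalmod}} function.
#' The default value is \code{TRUE}.
#' }
#' \item{add_np_nn}{
#' A Boolean value to specify whether the numbers of positives and
#' negatives are added to the titles of ROC and Precision-Recall plots,
#' for example, "ROC - P: 10, N: 10".
#' }
#' \item{multiplot_lib}{
#' A character to specify the library used to combine multiple plots,
#' either \code{"patchwork"} or \code{"grid"}. When \code{"patchwork"} is
#' specified but the \pkg{patchwork} package is not installed, \pkg{grid}
#' and \pkg{gridExtra} are used instead with a warning.
#' }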
#' }
#'
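#' @return The \code{autoplot} function returns a \code{ggplot} object
#' for a single-panel plot. For a multiple-panel plot, it returns a
#' \code{patchwork} object when \pkg{patchwork} is used; when \pkg{grid} and
#' \pkg{gridExtra} are used, it returns a frame-grob object if
#' \code{ret_grob = TRUE} and otherwise draws the plot directly.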
#'
#' @seealso \code{\link{evalmod}} for generating an \code{S3} object.
#' \code{\link{fortify}} for converting a curves or points object
#' to a data frame. \code{\link{plot}} for plotting the equivalent curves
#' with the general R plot.
#'
#' @examples
#' \dontrun{
#'
#' ## Load libraries
#' library(ggplot2)
#' library(grid)
#'
#' ##################################################
#' ### Single model & single test dataset
#' ###
#'
#' ## Load a dataset with 10 positives and 10 negatives
#' data(P10N10)
#'
#' ## Generate an sscurve object that contains ROC and Precision-Recall curves
#' sscurves <- evalmod(scores = P10N10$scores, labels = P10N10$labels)
#'
#' ## Plot both ROC and Precision-Recall curves
#' autoplot(sscurves)
#'
#' ## Reduced/Full supporting points
#' sampss <- create_sim_samples(1, 50000, 50000)
#' evalss <- evalmod(scores = sampss$scores, labels = sampss$labels)
#'
#' # Reduced supporting points
#' system.time(autoplot(evalss))
#'
#' # Full supporting points
#' system.time(autoplot(evalss, reduce_points = FALSE))
#'
#' ## Get a grob object for multiple plots
#' pp1 <- autoplot(sscurves, ret_grob = TRUE)
#' plot.new()
#' grid.draw(pp1)
#'
#' ## A ROC curve
#' autoplot(sscurves, curvetype = "ROC")
#'
#' ## A Precision-Recall curve
#' autoplot(sscurves, curvetype = "PRC")
#'
#' ## Generate an sspoints object that contains basic evaluation measures
#' sspoints <- evalmod(
#' mode = "basic", scores = P10N10$scores,
#' labels = P10N10$labels
#' )
#'
#' ## Normalized ranks vs. basic evaluation measures
#' autoplot(sspoints)
#'
#' ## Normalized ranks vs. precision
#' autoplot(sspoints, curvetype = "precision")
#'
#'
#' ##################################################
#' ### Multiple models & single test dataset
#' ###
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(1, 100, 100, "all")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#' modnames = samps[["modnames"]]
#' )
#'
#' ## Generate an mscurve object that contains ROC and Precision-Recall curves
#' mscurves <- evalmod(mdat)
#'
#' ## ROC and Precision-Recall curves
#' autoplot(mscurves)
#'
#' ## Reduced/Full supporting points
#' sampms <- create_sim_samples(5, 50000, 50000)
#' evalms <- evalmod(scores = sampms$scores, labels = sampms$labels)
#'
#' # Reduced supporting points
#' system.time(autoplot(evalms))
#'
#' # Full supporting points
#' system.time(autoplot(evalms, reduce_points = FALSE))
#'
#' ## Hide the legend
#' autoplot(mscurves, show_legend = FALSE)
#'
#' ## Generate an mspoints object that contains basic evaluation measures
#' mspoints <- evalmod(mdat, mode = "basic")
#'
#' ## Normalized ranks vs. basic evaluation measures
#' autoplot(mspoints)
#'
#' ## Hide the legend
#' autoplot(mspoints, show_legend = FALSE)
#'
#'
#' ##################################################
#' ### Single model & multiple test datasets
#' ###
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(10, 100, 100, "good_er")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#' modnames = samps[["modnames"]],
#' dsids = samps[["dsids"]]
#' )
#'
#' ## Generate an smcurve object that contains ROC and Precision-Recall curves
#' smcurves <- evalmod(mdat, raw_curves = TRUE)
#'
#' ## Average ROC and Precision-Recall curves
#' autoplot(smcurves, raw_curves = FALSE)
#'
#' ## Hide confidence bounds
#' autoplot(smcurves, raw_curves = FALSE, show_cb = FALSE)
#'
#' ## Raw ROC and Precision-Recall curves
#' autoplot(smcurves, raw_curves = TRUE, show_cb = FALSE)
#'
#' ## Reduced/Full supporting points
#' sampsm <- create_sim_samples(4, 5000, 5000)
#' mdatsm <- mmdata(sampsm$scores, sampsm$labels, expd_first = "dsids")
#' evalsm <- evalmod(mdatsm, raw_curves = TRUE)
#'
#' # Reduced supporting points
#' system.time(autoplot(evalsm, raw_curves = TRUE))
#'
#' # Full supporting points
#' system.time(autoplot(evalsm, raw_curves = TRUE, reduce_points = FALSE))
#'
#' ## Generate an smpoints object that contains basic evaluation measures
#' smpoints <- evalmod(mdat, mode = "basic")
#'
#' ## Normalized ranks vs. average basic evaluation measures
#' autoplot(smpoints)
#'
#'
#' ##################################################
#' ### Multiple models & multiple test datasets
#' ###
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(10, 100, 100, "all")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#' modnames = samps[["modnames"]],
#' dsids = samps[["dsids"]]
#' )
#'
#' ## Generate an mmcurve object that contains ROC and Precision-Recall curves
#' mmcurves <- evalmod(mdat, raw_curves = TRUE)
#'
#' ## Average ROC and Precision-Recall curves
#' autoplot(mmcurves, raw_curves = FALSE)
#'
#' ## Show confidence bounds
#' autoplot(mmcurves, raw_curves = FALSE, show_cb = TRUE)
#'
#' ## Raw ROC and Precision-Recall curves
#' autoplot(mmcurves, raw_curves = TRUE)
#'
#' ## Reduced/Full supporting points
#' sampmm <- create_sim_samples(4, 5000, 5000)
#' mdatmm <- mmdata(sampmm$scores, sampmm$labels,
#' modnames = c("m1", "m2"),
#' dsids = c(1, 2), expd_first = "modnames"
#' )
#' evalmm <- evalmod(mdatmm, raw_curves = TRUE)
#'
#' # Reduced supporting points
#' system.time(autoplot(evalmm, raw_curves = TRUE))
#'
#' # Full supporting points
#' system.time(autoplot(evalmm, raw_curves = TRUE, reduce_points = FALSE))
#'
#' ## Generate an mmpoints object that contains basic evaluation measures
#' mmpoints <- evalmod(mdat, mode = "basic")
#'
#' ## Normalized ranks vs. average basic evaluation measures
#' autoplot(mmpoints)
#'
#'
#' ##################################################
#' ### N-fold cross validation datasets
#' ###
#'
#' ## Load test data
#' data(M2N50F5)
#'
#' ## Specify necessary columns to create mdat
#' cvdat <- mmdata(
#' nfold_df = M2N50F5, score_cols = c(1, 2),
#' lab_col = 3, fold_col = 4,
#' modnames = c("m1", "m2"), dsids = 1:5
#' )
#'
#' ## Generate an mmcurve object that contains ROC and Precision-Recall curves
#' cvcurves <- evalmod(cvdat)
#'
#' ## Average ROC and Precision-Recall curves
#' autoplot(cvcurves)
#'
#' ## Show confidence bounds
#' autoplot(cvcurves, show_cb = TRUE)
#'
#' ## Generate an mmpoints object that contains basic evaluation measures
#' cvpoints <- evalmod(cvdat, mode = "basic")
#'
#' ## Normalized ranks vs. average basic evaluation measures
#' autoplot(cvpoints)
#' }
#'
#' @name autoplot
NULL # Placeholder object carrying the roxygen2 block above (@name autoplot)
#
# Process ... for curve objects
#
.get_autoplot_arglist <- function(evalmod_args,
                                  def_curvetype, def_type, def_show_cb,
                                  def_raw_curves, def_add_np_nn,
                                  def_show_legend, def_ret_grob,
                                  def_reduce_points, def_multiplot_lib, ...) {
  # Build the option list used by the autoplot methods.
  #
  # evalmod_args: arguments recorded by evalmod(); `calc_avg` and
  #   `raw_curves` are read to reject options the object cannot honor.
  # def_*: fallback values used when the matching option is absent from
  #   `...`. def_raw_curves may be NULL, in which case evalmod's raw_curves
  #   (or FALSE when that is NULL too) is used.
  # ...: user options; they must be passed by name.
  # Returns a named list holding every option.
  arglist <- list(...)
  if (is.null(arglist[["curvetype"]])) {
    arglist[["curvetype"]] <- def_curvetype
  }
  if (is.null(arglist[["type"]])) {
    arglist[["type"]] <- def_type
  }
  if (is.null(arglist[["show_cb"]])) {
    arglist[["show_cb"]] <- def_show_cb
  }
  # Confidence bounds need averaged curves. isTRUE() keeps the check
  # well-defined when evalmod did not record calc_avg (NULL).
  if (!isTRUE(evalmod_args[["calc_avg"]]) && isTRUE(arglist[["show_cb"]])) {
    stop("Invalid show_cb. Inconsistent with calc_avg of evalmod.",
      call. = FALSE
    )
  }
  if (is.null(arglist[["raw_curves"]])) {
    if (!is.null(def_raw_curves)) {
      arglist[["raw_curves"]] <- def_raw_curves
    } else if (!is.null(evalmod_args[["raw_curves"]])) {
      arglist[["raw_curves"]] <- evalmod_args[["raw_curves"]]
    } else {
      arglist[["raw_curves"]] <- FALSE
    }
  }
  # Raw curves can be drawn only if evalmod kept them. A NULL evalmod value
  # is treated as FALSE, matching the fallback just above.
  if (!isTRUE(evalmod_args[["raw_curves"]]) &&
    isTRUE(arglist[["raw_curves"]])) {
    stop("Invalid raw_curves. Inconsistent with the value of evalmod.",
      call. = FALSE
    )
  }
  if (is.null(arglist[["add_np_nn"]])) {
    arglist[["add_np_nn"]] <- def_add_np_nn
  }
  if (is.null(arglist[["show_legend"]])) {
    arglist[["show_legend"]] <- def_show_legend
  }
  if (is.null(arglist[["ret_grob"]])) {
    arglist[["ret_grob"]] <- def_ret_grob
  }
  if (is.null(arglist[["reduce_points"]])) {
    arglist[["reduce_points"]] <- def_reduce_points
  }
  if (is.null(arglist[["multiplot_lib"]])) {
    arglist[["multiplot_lib"]] <- def_multiplot_lib
  }
  arglist
}
#
# Prepare autoplot and return a data frame
#
.prepare_autoplot <- function(object, curve_df = NULL, curvetype = NULL, ...) {
  # Return the plotting data frame for `object`, optionally restricted to
  # the requested curve type(s).
  #
  # curve_df: pre-computed data frame; when NULL it is built with
  #   ggplot2::fortify(object, ...).
  # curvetype: NULL (keep all rows) or one or more curve type names.
  # === Check package availability ===
  .load_ggplot2()
  # === Validate input arguments ===
  .validate(object)
  # === Prepare a data frame for ggplot2 ===
  if (is.null(curve_df)) {
    curve_df <- ggplot2::fortify(object, ...)
  }
  if (!is.null(curvetype)) {
    # Standard indexing instead of subset(): with subset()'s non-standard
    # evaluation a missing column would silently resolve to this
    # function's `curvetype` argument and keep every row.
    if (!("curvetype" %in% names(curve_df))) {
      stop("curve_df must have a curvetype column.", call. = FALSE)
    }
    keep <- curve_df[["curvetype"]] %in% curvetype
    curve_df <- curve_df[keep, , drop = FALSE]
  }
  curve_df
}
#
# Load ggplot2
#
.load_ggplot2 <- function() {
  # Make sure ggplot2 can be loaded; otherwise stop with an install hint.
  # Returns NULL invisibly on success.
  is_available <- requireNamespace("ggplot2", quietly = TRUE)
  if (is_available) {
    return(invisible(NULL))
  }
  stop(
    paste(
      "ggplot2 is required to perform this function.",
      "Please install it."
    ),
    call. = FALSE
  )
}
#
# Load grid
#
.load_grid <- function() {
  # Make sure grid can be loaded; otherwise stop with an install hint.
  # Returns NULL invisibly on success.
  if (requireNamespace("grid", quietly = TRUE)) {
    return(invisible(NULL))
  }
  stop("grid is required to perform this function. Please install it.",
    call. = FALSE
  )
}
#
# Load gridExtra
#
.load_gridextra <- function() {
  # Make sure gridExtra can be loaded; otherwise stop with an install hint.
  # Returns NULL invisibly on success.
  gridextra_ok <- requireNamespace("gridExtra", quietly = TRUE)
  if (!gridextra_ok) {
    stop("gridExtra is required to perform this function. Please install it.",
      call. = FALSE
    )
  }
  invisible(NULL)
}
#
# Load patchwork
#
.load_patchwork <- function() {
  # Report whether patchwork can be loaded. When it cannot, warn that the
  # grid/gridExtra fallback will be used. Returns TRUE or FALSE.
  available <- requireNamespace("patchwork", quietly = TRUE)
  if (!available) {
    warning(
      paste0(
        "patchwork is not installed. ",
        "grid and gridExtra will be used instead."
      ),
      call. = FALSE
    )
  }
  available
}
#
# Plot ROC and Precision-Recall
#
.autoplot_multi <- function(object, arglist) {
  # Draw one panel per entry of arglist$curvetype and combine the panels
  # when there is more than one.
  #
  # object: evaluation object from evalmod().
  # arglist: option list as produced by .get_autoplot_arglist().
  # Returns a single ggplot object, or whatever .combine_plots() returns.
  ctypes <- arglist[["curvetype"]]
  line_type <- arglist[["type"]]
  np_nn <- arglist[["add_np_nn"]]
  legend_on <- arglist[["show_legend"]]
  want_grob <- arglist[["ret_grob"]]
  reduce <- arglist[["reduce_points"]]
  mp_lib <- arglist[["multiplot_lib"]]

  # Bounds are dropped when evalmod did not average the curves, and raw
  # curves are dropped whenever bounds are drawn.
  show_cb <- if (!attr(object, "args")$calc_avg) FALSE else arglist[["show_cb"]]
  raw_curves <- if (show_cb) FALSE else arglist[["raw_curves"]]

  # === Check package availability and arguments ===
  .load_ggplot2()
  .validate(object)
  .check_curvetype(ctypes, object)
  .check_type(line_type)
  .check_show_cb(show_cb, object)
  .check_raw_curves(raw_curves, object)
  .check_show_legend(legend_on)
  .check_add_np_nn(np_nn)
  .check_ret_grob(want_grob)
  .check_multiplot_lib(mp_lib)

  # === Fortify once and reuse the data frame for every panel ===
  curve_df <- ggplot2::fortify(object,
    raw_curves = raw_curves,
    reduce_points = reduce
  )
  panels <- lapply(ctypes, function(ctype) {
    .autoplot_single(object, curve_df,
      curvetype = ctype, type = line_type,
      show_cb = show_cb, raw_curves = raw_curves,
      reduce_points = reduce, show_legend = legend_on,
      add_np_nn = np_nn
    )
  })
  names(panels) <- ctypes

  n_panels <- length(panels)
  if (n_panels <= 1) {
    return(panels[[1]])
  }
  # Panels travel as named arguments through `...` of .combine_plots()
  do.call(.combine_plots, c(panels,
    show_legend = legend_on,
    ret_grob = want_grob,
    multiplot_lib = mp_lib,
    nplots = n_panels
  ))
}
#
# .grid_arrange_shared_legend
#
# Modified version of grid_arrange_shared_legend from RPubs
# URL of the original version:
# http://rpubs.com/sjackman/grid_arrange_shared_legend
#
.grid_arrange_shared_legend <- function(..., main_ncol = 2) {
  # Arrange ggplot objects in `main_ncol` columns, stripping each panel's
  # legend and placing one legend, taken from the first plot, below them.
  # Returns a gtable built by gridExtra::arrangeGrob().
  plots <- list(...)
  # Render the first plot with a bottom legend so that its legend grob can
  # be extracted and shared.
  gt <- ggplot2::ggplotGrob(
    plots[[1]] + ggplot2::theme(legend.position = "bottom")
  )
  # Locate the legend through the gtable layout names: "guide-box" in older
  # ggplot2 releases, "guide-box-<position>" in ggplot2 >= 3.5.0, where the
  # unused positions hold empty zeroGrob placeholders.
  is_guide <- grepl("^guide-box", gt$layout[["name"]])
  guides <- Filter(
    function(x) !inherits(x, "zeroGrob"),
    gt$grobs[is_guide]
  )
  fncol <- function(...) gridExtra::arrangeGrob(..., ncol = main_ncol)
  fnolegend <- function(x) x + ggplot2::theme(legend.position = "none")
  panels <- do.call(fncol, lapply(plots, fnolegend))
  if (length(guides) == 0L) {
    # Nothing to share (the plot has no legend): return the panels only
    return(panels)
  }
  legend <- guides[[1]]
  # Spell out `heights`; `legend$height` only worked via partial matching
  lheight <- sum(legend[["heights"]])
  gridExtra::arrangeGrob(
    panels,
    legend,
    heights = grid::unit.c(grid::unit(1, "npc") - lheight, lheight),
    ncol = 1
  )
}
#
# Combine ROC and Precision-Recall plots by grid and gridExtra
#
.combine_plots_grid <- function(..., show_legend, ret_grob, nplots) {
  # Combine ggplot panels with grid/gridExtra.
  #
  # ...: ggplot objects. show_legend: share one legend below the panels.
  # ret_grob: return the arranged grob instead of drawing it.
  # nplots: number of panels; 2 or 4 panels use two columns, others three.
  two_cols <- nplots == 2 || nplots == 4
  ncol <- if (two_cols) 2 else 3
  grobframe <- if (show_legend) {
    .grid_arrange_shared_legend(..., main_ncol = ncol)
  } else {
    gridExtra::arrangeGrob(..., ncol = ncol)
  }
  if (ret_grob) {
    return(grobframe)
  }
  graphics::plot.new()
  grid::grid.draw(grobframe)
}
#
# Combine ROC and Precision-Recall plots by patchwork
#
.combine_plots_patchwork <- function(..., show_legend) {
  # Combine ggplot panels into one patchwork object.
  #
  # ...: ggplot objects; 2 or 4 panels use two columns, others three.
  # show_legend: collect the panels' legends into a single guide area,
  #   moved below the panels when there are more than two panels.
  # Returns the patchwork object.
  plotlist <- list(...)
  nplots <- length(plotlist)
  ncol <- if (nplots == 2 || nplots == 4) 2 else 3
  p <- patchwork::wrap_plots(plotlist, ncol = ncol)
  if (show_legend) {
    p <- p + patchwork::plot_layout(guides = "collect")
    if (nplots > 2) {
      # `&` applies the theme to every panel of the patchwork; `+` would
      # modify only the last panel.
      p <- p & ggplot2::theme(legend.position = "bottom")
    }
  }
  p
}
#
# Combine ROC and Precision-Recall plots
#
.combine_plots <- function(..., show_legend, ret_grob, multiplot_lib, nplots) {
  # Dispatch panel combination to patchwork or grid/gridExtra.
  #
  # A "patchwork" request falls back to grid when .load_patchwork() reports
  # the package as unavailable. An unknown library yields invisible NULL.
  lib <- multiplot_lib
  if (lib == "patchwork") {
    patchwork_ok <- .load_patchwork()
    if (patchwork_ok) {
      return(.combine_plots_patchwork(..., show_legend = show_legend))
    }
    lib <- "grid"
  }
  if (lib != "grid") {
    return(invisible(NULL))
  }
  .load_grid()
  .load_gridextra()
  .combine_plots_grid(...,
    show_legend = show_legend,
    ret_grob = ret_grob,
    nplots = nplots
  )
}
#
# Plot ROC or Precision-Recall
#
.autoplot_single <- function(object, curve_df, curvetype = "ROC", type = "l",
                             show_cb = FALSE, raw_curves = FALSE,
                             reduce_points = TRUE, show_legend = FALSE,
                             add_np_nn = TRUE, ...) {
  # Build one ggplot panel for a single curve type.
  #
  # type: "l" lines, "p" points, "b" both (lines drawn translucent).
  # show_cb: draw average curves with ymin/ymax bounds.
  # raw_curves: draw one line per dsid_modname instead of per model.
  # Returns the decorated ggplot object.
  curve_df <- .prepare_autoplot(object,
    curve_df = curve_df,
    curvetype = curvetype,
    raw_curves = raw_curves,
    reduce_points = reduce_points, ...
  )

  # Column symbols injected into aes() with !!
  col_x <- rlang::sym("x")
  col_y <- rlang::sym("y")
  col_ymin <- rlang::sym("ymin")
  col_ymax <- rlang::sym("ymax")
  col_mod <- rlang::sym("modname")
  col_dsmod <- rlang::sym("dsid_modname")

  if (show_cb) {
    p <- ggplot2::ggplot(
      curve_df,
      ggplot2::aes(
        x = !!col_x, y = !!col_y,
        ymin = !!col_ymin, ymax = !!col_ymax
      )
    )
    if (type == "l") {
      p <- p + ggplot2::geom_smooth(ggplot2::aes(color = !!col_mod),
        stat = "identity", na.rm = TRUE,
        linewidth = 0.5
      )
    } else if (type == "b" || type == "p") {
      band <- ggplot2::geom_ribbon(
        ggplot2::aes(
          ymin = !!col_ymin,
          ymax = !!col_ymax,
          group = !!col_mod
        ),
        stat = "identity", alpha = 0.25,
        fill = "grey25", na.rm = TRUE
      )
      p <- p + band
      if (type == "b") {
        p <- p + ggplot2::geom_line(ggplot2::aes(color = !!col_mod),
          alpha = 0.25, na.rm = TRUE
        )
      }
      p <- p + ggplot2::geom_point(
        ggplot2::aes(x = !!col_x, y = !!col_y, color = !!col_mod),
        na.rm = TRUE
      )
    }
  } else {
    # Raw curves are grouped per dataset/model pair, averages per model
    mapping <- if (raw_curves) {
      ggplot2::aes(
        x = !!col_x, y = !!col_y,
        group = !!col_dsmod,
        color = !!col_mod
      )
    } else {
      ggplot2::aes(x = !!col_x, y = !!col_y, color = !!col_mod)
    }
    p <- ggplot2::ggplot(curve_df, mapping)
    if (type == "l") {
      p <- p + ggplot2::geom_line(na.rm = TRUE)
    } else if (type == "b" || type == "p") {
      if (type == "b") {
        p <- p + ggplot2::geom_line(alpha = 0.25, na.rm = TRUE)
      }
      p <- p + ggplot2::geom_point(na.rm = TRUE)
    }
  }

  # Decoration helper, axis limits and aspect ratio per curve type
  if (curvetype == "ROC" || curvetype == "PRC") {
    is_roc <- curvetype == "ROC"
    func_g <- if (is_roc) .geom_basic_roc else .geom_basic_prc
    slot_name <- if (is_roc) "rocs" else "prcs"
    curves <- object[[slot_name]]
    xlim <- attr(curves, "xlim")
    ylim <- attr(curves, "ylim")
    # Square panel only when both axes share the same limits
    ratio <- if (all(xlim == ylim)) 1 else NULL
  } else {
    func_g <- .geom_basic_point
    xlim <- c(0, 1)
    if (curvetype == "mcc" || curvetype == "label") {
      ylim <- c(-1, 1)
      ratio <- 0.5
    } else if (curvetype == "score") {
      ylim <- NULL
      ratio <- NULL
    } else {
      ylim <- c(0, 1)
      ratio <- 1
    }
  }
  func_g(p, object,
    show_legend = show_legend, add_np_nn = add_np_nn,
    curve_df = curve_df, xlim = xlim, ylim = ylim, ratio = ratio, ...
  )
}
#
# Geom basic
#
.geom_basic <- function(p, main, xlab, ylab, show_legend) {
  # Shared panel styling: black-and-white theme, title, axis labels and an
  # untitled legend, which is hidden unless show_legend is TRUE.
  styled <- p +
    ggplot2::theme_bw() +
    ggplot2::ggtitle(main) +
    ggplot2::xlab(xlab) +
    ggplot2::ylab(ylab) +
    ggplot2::theme(legend.title = ggplot2::element_blank())
  if (show_legend) {
    return(styled)
  }
  styled + ggplot2::theme(legend.position = "none")
}
#
# Make main title
#
.make_rocprc_title <- function(object, pt) {
  # Title of the form "<pt> - P: <avg_np>, N: <avg_nn>"
  counts <- .get_pn_info(object)
  paste0(pt, " - P: ", counts$avg_np, ", N: ", counts$avg_nn)
}
#
# Geom basic for ROC
#
.geom_basic_roc <- function(p, object, show_legend = TRUE, add_np_nn = TRUE,
                            xlim, ylim, ratio, ...) {
  # Decorate a ROC panel: title (optionally with P/N counts), dotted grey
  # y = x reference line, coordinates and shared styling.
  pn <- .get_pn_info(object)
  main <- if (add_np_nn && pn$is_consistant) {
    .make_rocprc_title(object, "ROC")
  } else {
    "ROC"
  }
  diagonal <- ggplot2::geom_abline(
    intercept = 0, slope = 1, colour = "grey",
    linetype = 3
  )
  p <- .set_coords(p + diagonal, xlim, ylim, ratio)
  .geom_basic(p, main, "1 - Specificity", "Sensitivity", show_legend)
}
#
# Geom basic for Precision-Recall
#
.geom_basic_prc <- function(p, object, show_legend = TRUE, add_np_nn = TRUE,
                            xlim, ylim, ratio, ...) {
  # Decorate a Precision-Recall panel: title (optionally with P/N counts),
  # dotted grey horizontal line at pn$prc_base, coordinates and styling.
  pn <- .get_pn_info(object)
  main <- "Precision-Recall"
  if (add_np_nn && pn$is_consistant) {
    main <- .make_rocprc_title(object, main)
  }
  baseline <- ggplot2::geom_hline(
    yintercept = pn$prc_base, colour = "grey",
    linetype = 3
  )
  p <- .set_coords(p + baseline, xlim, ylim, ratio)
  .geom_basic(p, main, "Recall", "Precision", show_legend)
}
#
# Set coordinates for ROC and precision-recall
#
.set_coords <- function(p, xlim, ylim, ratio) {
  # Add Cartesian coordinates with the given limits; a non-NULL ratio fixes
  # the aspect ratio (y units per x unit) via coord_fixed().
  coords <- if (is.null(ratio)) {
    ggplot2::coord_cartesian(xlim = xlim, ylim = ylim)
  } else {
    ggplot2::coord_fixed(ratio = ratio, xlim = xlim, ylim = ylim)
  }
  p + coords
}
#
# Geom basic for basic evaluation measures (normalized rank vs. measure)
#
.geom_basic_point <- function(p, object, show_legend = TRUE,
                              curve_df = NULL, xlim, ylim, ratio, ...) {
  # Decorate a "normalized rank vs. measure" panel.
  #
  # p: ggplot object to decorate.
  # object: not used here; accepted so that all .geom_basic_* helpers can
  #   be called the same way.
  # curve_df: data frame whose first `curvetype` value names the measure.
  #   Required: there is no meaningful default (the former self-referential
  #   default `curve_df = curve_df` raised a recursive-default error).
  # xlim, ylim, ratio: passed on to .set_coords().
  # ...: other caller arguments (e.g. add_np_nn) are ignored.
  # Returns the decorated ggplot object.
  s <- if (NROW(curve_df) > 0L) {
    as.character(curve_df[["curvetype"]][1])
  } else {
    character(0)
  }
  if (length(s) != 1L || is.na(s)) {
    stop("curve_df must contain a non-missing curvetype value.",
      call. = FALSE
    )
  }
  main <- if (s == "mcc") {
    "MCC"
  } else if (s == "label") {
    "Label (1:pos, -1:neg)"
  } else {
    # Capitalize the first letter, e.g. "precision" -> "Precision"
    paste0(toupper(substring(s, 1, 1)), substring(s, 2))
  }
  p <- .set_coords(p, xlim, ylim, ratio)
  p <- .geom_basic(p, main, "normalized rank", s, show_legend)
  p
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.