#
# Create ROC and Precision-Recall curves
#
# Builds a ROC curve and a Precision-Recall curve from the same evaluation
# measures and bundles them into an S3 object of class "curves".
#
# Args:
#   pevals      - evaluation-measure object; when missing it is created from
#                 'scores' and 'labels' with calc_measures().
#   scores      - scores used to create 'pevals' when it is missing.
#   labels      - labels used to create 'pevals' when it is missing.
#   x_bins      - number of x-axis bins; NULL or a single NA is treated as 1.
#   keep_pevals - if TRUE, 'pevals' is stored in the "src" attribute,
#                 otherwise "src" is NA.
#   ...         - forwarded to .create_src_obj(), create_roc() and
#                 create_prc(), and recorded in the "args" attribute.
#
# Returns:
#   The "curves" object (items "roc" and "prc") as returned by .validate().
#
create_curves <- function(pevals, scores = NULL, labels = NULL,
                          x_bins = 1000, keep_pevals = FALSE, ...) {
  # === Validate input arguments ===
  # Create pevals from scores and labels if pevals is missing
  pevals <- .create_src_obj(pevals, "pevals", calc_measures, scores, labels,
                            ...)

  # Only a scalar NA means "unset". Other lengths are left for
  # .validate_x_bins() to reject instead of erroring inside `||`.
  if (is.null(x_bins) || (length(x_bins) == 1L && is.na(x_bins))) {
    x_bins <- 1
  }
  .validate_x_bins(x_bins)
  .validate(pevals)

  # === Create ROC and Precision-Recall curves ===
  roc_curve <- create_roc(pevals, x_bins = x_bins,
                          keep_pevals = keep_pevals, ...)
  prc_curve <- create_prc(pevals, x_bins = x_bins,
                          keep_pevals = keep_pevals, ...)
  curves <- list(roc = roc_curve, prc = prc_curve)

  # === Create an S3 object ===
  s3obj <- structure(curves, class = "curves")

  # Set attributes copied from pevals plus the call arguments
  attr(s3obj, "modname") <- attr(pevals, "modname")
  attr(s3obj, "dsid") <- attr(pevals, "dsid")
  attr(s3obj, "nn") <- attr(pevals, "nn")
  attr(s3obj, "np") <- attr(pevals, "np")
  attr(s3obj, "args") <- c(list(x_bins = x_bins), list(...))
  if (keep_pevals) {
    attr(s3obj, "src") <- pevals
  } else {
    attr(s3obj, "src") <- NA
  }
  attr(s3obj, "validated") <- FALSE

  # Call .validate.curves()
  .validate(s3obj)
}
#
# Create a ROC curve
#
# Delegates to .create_curve() with specificity on the x-axis and
# sensitivity on the y-axis, computed by create_roc_curve(). The returned
# object has class "roc_curve".
#
create_roc <- function(pevals, scores = NULL, labels = NULL, x_bins = 1000,
                       keep_pevals = FALSE, ...) {
  .create_curve(
    x_name = "specificity",
    y_name = "sensitivity",
    func = create_roc_curve,
    func_name = "create_roc_curve",
    class_name = "roc_curve",
    pevals = pevals,
    scores = scores,
    labels = labels,
    x_bins = x_bins,
    keep_pevals = keep_pevals,
    ...
  )
}
#
# Create a Precision-Recall curve
#
# Delegates to .create_curve() with sensitivity (recall) on the x-axis and
# precision on the y-axis, computed by create_prc_curve(). The returned
# object has class "prc_curve".
#
create_prc <- function(pevals, scores = NULL, labels = NULL, x_bins = 1000,
                       keep_pevals = FALSE, ...) {
  .create_curve(
    x_name = "sensitivity",
    y_name = "precision",
    func = create_prc_curve,
    func_name = "create_prc_curve",
    class_name = "prc_curve",
    pevals = pevals,
    scores = scores,
    labels = labels,
    x_bins = x_bins,
    keep_pevals = keep_pevals,
    ...
  )
}
#
# Create ROC or Precision-Recall curve
#
# Shared worker behind create_roc() and create_prc().
#
# Args:
#   x_name, y_name - names of the items in pevals[["basic"]] used as x and y
#                    values (e.g. "specificity" and "sensitivity").
#   func           - curve routine, called as
#                    func(tp, fp, x_values, y_values, x_bins); its result is
#                    read through the items "curve" and "errmsg".
#   func_name      - name of 'func', passed to .check_cpp_func_error().
#   class_name     - S3 class given to the returned object.
#   pevals         - evaluation-measure object; when missing it is created
#                    from 'scores' and 'labels' with calc_measures().
#   scores, labels - used to create 'pevals' when it is missing.
#   x_bins         - number of x-axis bins; NULL or a single NA is treated
#                    as 1 (the same rule create_curves() applies).
#   keep_pevals    - if TRUE, 'pevals' is stored in the "src" attribute,
#                    otherwise "src" is NA.
#   ...            - forwarded to .create_src_obj() and recorded in "args".
#
# Returns:
#   The curve object (items "x", "y", "orig_points" from func) with AUC and
#   metadata attributes, as returned by .validate().
#
.create_curve <- function(x_name, y_name, func, func_name, class_name,
                          pevals, scores = NULL, labels = NULL, x_bins = 1000,
                          keep_pevals = FALSE, ...) {
  # === Validate input arguments ===
  # Create pevals from scores and labels if pevals is missing
  pevals <- .create_src_obj(pevals, "pevals", calc_measures, scores, labels,
                            ...)

  # Normalise an unset x_bins exactly as create_curves() does, so that
  # create_roc()/create_prc() accept the same values
  if (is.null(x_bins) || (length(x_bins) == 1L && is.na(x_bins))) {
    x_bins <- 1
  }
  .validate_x_bins(x_bins)
  .validate(pevals)

  # === Create a curve ===
  pb <- pevals[["basic"]]
  src <- attr(pevals, "src")
  crv <- func(src[["tp"]], src[["fp"]], pb[[x_name]], pb[[y_name]], x_bins)
  .check_cpp_func_error(crv, func_name)

  # Calculate AUC. An "invalid-x-vals" status only triggers a warning;
  # any other status goes to .check_cpp_func_error(). isTRUE() keeps a
  # NULL/NA status from breaking the `if`.
  auc <- calc_auc(crv[["curve"]][["x"]], crv[["curve"]][["y"]])
  if (isTRUE(auc[["errmsg"]] == "invalid-x-vals")) {
    warning(paste0("Invalid ", x_name,
                   " values detected. AUC can be inaccurate."))
  } else {
    .check_cpp_func_error(auc, "calc_auc")
  }

  # === Create an S3 object ===
  s3obj <- structure(crv[["curve"]], class = class_name)

  # Set attributes
  attr(s3obj, "modname") <- attr(pevals, "modname")
  attr(s3obj, "dsid") <- attr(pevals, "dsid")
  attr(s3obj, "nn") <- attr(pevals, "nn")
  attr(s3obj, "np") <- attr(pevals, "np")
  attr(s3obj, "auc") <- auc[["auc"]]
  attr(s3obj, "xlim") <- c(0, 1)
  attr(s3obj, "ylim") <- c(0, 1)
  attr(s3obj, "pauc") <- NA
  attr(s3obj, "spauc") <- NA
  attr(s3obj, "args") <- c(list(x_bins = x_bins), list(...))
  attr(s3obj, "cpp_errmsg1") <- crv[["errmsg"]]
  attr(s3obj, "cpp_errmsg2") <- auc[["errmsg"]]
  if (keep_pevals) {
    attr(s3obj, "src") <- pevals
  } else {
    attr(s3obj, "src") <- NA
  }
  attr(s3obj, "validated") <- FALSE

  # Call .validate.roc_curve() or .validate.prc_curve()
  .validate(s3obj)
}
#
# Validate 'roc_curve' object generated by create_roc()
#
# Runs .validate_curve() once and then sets the "validated" attribute to
# TRUE. An object already marked as validated is returned unchanged.
#
.validate.roc_curve <- function(roc_curve) {
  # Need to validate only once. isTRUE() tolerates a missing or NA
  # "validated" attribute: such objects go through the full checks
  # instead of failing inside `&&`.
  already_done <- inherits(roc_curve, "roc_curve") &&
    isTRUE(attr(roc_curve, "validated"))
  if (already_done) {
    return(roc_curve)
  }

  # Validate class items and attributes
  .validate_curve(roc_curve, "roc_curve", "create_roc")
  attr(roc_curve, "validated") <- TRUE
  roc_curve
}
#
# Validate 'prc_curve' object generated by create_prc()
#
# Runs .validate_curve() once and then sets the "validated" attribute to
# TRUE. An object already marked as validated is returned unchanged.
#
.validate.prc_curve <- function(prc_curve) {
  # Need to validate only once. isTRUE() tolerates a missing or NA
  # "validated" attribute: such objects go through the full checks
  # instead of failing inside `&&`.
  already_done <- inherits(prc_curve, "prc_curve") &&
    isTRUE(attr(prc_curve, "validated"))
  if (already_done) {
    return(prc_curve)
  }

  # Validate class items and attributes
  .validate_curve(prc_curve, "prc_curve", "create_prc")
  attr(prc_curve, "validated") <- TRUE
  prc_curve
}
#
# Validate 'roc_curve' or 'prc_curve'
#
# Checks the expected items, attributes and recorded argument names through
# .validate_basic(). It then requires "x", "y" and "orig_points" to have one
# common length of at least 3, and the "auc" attribute to lie in [0, 1].
#
.validate_curve <- function(obj, class_name, func_name) {
  # Names that must be present on the object, its attributes and "args"
  expected_items <- c("x", "y", "orig_points")
  expected_attrs <- c("modname", "dsid", "nn", "np", "auc", "args",
                      "cpp_errmsg1", "cpp_errmsg2", "src", "validated")
  expected_args <- c("x_bins", "na_worst", "na.last", "ties_method",
                     "ties.method", "modname", "dsid", "keep_fmdat",
                     "keep_cmats")
  .validate_basic(obj, class_name, func_name, expected_items, expected_attrs,
                  expected_args)

  # All three point vectors must line up
  n_points <- length(obj[["x"]])
  lengths_match <- n_points == length(obj[["y"]]) &&
    n_points == length(obj[["orig_points"]])
  if (!lengths_match) {
    stop("x, y, and orig_points must be all the same lengths", call. = FALSE)
  }
  if (n_points < 3) {
    stop("The minimum length of x, y, and orig_points must be 3",
         call. = FALSE)
  }

  # The AUC must be a proportion
  assertthat::assert_that((attr(obj, "auc") >= 0) && (attr(obj, "auc") <= 1))
}
#
# Validate 'curves' object generated by create_curves()
#
# Checks the object's items, attributes and recorded argument names through
# .validate_basic(). It then validates the nested "roc" and "prc" curves and
# sets the "validated" attribute to TRUE. An object already marked as
# validated is returned unchanged.
#
.validate.curves <- function(curves) {
  # Need to validate only once. isTRUE() tolerates a missing or NA
  # "validated" attribute instead of failing inside `&&`.
  already_done <- inherits(curves, "curves") &&
    isTRUE(attr(curves, "validated"))
  if (already_done) {
    return(curves)
  }

  # Validate class items and attributes
  item_names <- c("roc", "prc")
  attr_names <- c("modname", "dsid", "nn", "np", "args", "src",
                  "validated")
  arg_names <- c("x_bins", "na_worst", "na.last", "ties_method",
                 "ties.method", "modname", "dsid", "keep_fmdat",
                 "keep_cmats")
  # "curves" objects are built by create_curves(), so name it as the
  # creator (it was previously misreported as calc_measures)
  .validate_basic(curves, "curves", "create_curves", item_names, attr_names,
                  arg_names)

  # Validate the nested curves through their own .validate methods
  curves[["roc"]] <- .validate(curves[["roc"]])
  curves[["prc"]] <- .validate(curves[["prc"]])

  attr(curves, "validated") <- TRUE
  curves
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.