#
# Control the main pipeline iterations for ROC and Precision-Recall curves
#
.pl_main_rocprc <- function(mdat, model_type, dataset_type, class_name_pf,
                            calc_avg = TRUE, cb_alpha = 0.05,
                            raw_curves = FALSE, x_bins = 1000) {
  # A single dataset has nothing to average over, so keep its raw curves
  if (!missing(dataset_type) && dataset_type == "single") {
    calc_avg <- FALSE
    raw_curves <- TRUE
  }

  # === Create ROC and Precision-Recall curves ===
  # For every item of mdat: confusion matrices -> measures -> curves
  lcurves <- lapply(seq_along(mdat), function(idx) {
    confmats <- create_confmats(mdat[[idx]])
    measures <- calc_measures(confmats)
    create_curves(measures, x_bins = x_bins)
  })

  # Group the per-item curves by curve type ('crvgrp' objects)
  grp_curves <- list(
    rocs = .summarize_curves(lcurves, "roc", "crvgrp", mdat, dataset_type,
                             calc_avg, cb_alpha, x_bins),
    prcs = .summarize_curves(lcurves, "prc", "crvgrp", mdat, dataset_type,
                             calc_avg, cb_alpha, x_bins)
  )

  # AUC table built from the individual curves
  aucs <- .gather_aucs(lcurves, mdat)

  # Averaged curves stored on each group (named "rocs" / "prcs")
  grp_avg <- lapply(grp_curves, function(grp) attr(grp, "avgcurves"))

  # === Create an S3 object ===
  # Averaged multi-dataset output without raw curves: replace each group by
  # an empty placeholder group; the averages are kept in grp_avg
  if (dataset_type == "multiple" && calc_avg && !raw_curves) {
    grp_curves <- list(
      rocs = .summarize_curves(NULL, "roc", "crvgrp", mdat,
                               NULL, NULL, NULL, NULL),
      prcs = .summarize_curves(NULL, "prc", "crvgrp", mdat,
                               NULL, NULL, NULL, NULL)
    )
  }
  obj_class <- c(paste0(class_name_pf, "curves"), "curve_info", "aucs")
  s3obj <- structure(grp_curves, class = obj_class)

  # Attributes, applied in the listed order
  obj_attrs <- list(
    aucs = aucs,
    paucs = NA,
    grp_avg = grp_avg,
    data_info = attr(mdat, "data_info"),
    uniq_modnames = attr(mdat, "uniq_modnames"),
    uniq_dsids = attr(mdat, "uniq_dsids"),
    model_type = model_type,
    dataset_type = dataset_type,
    partial = FALSE,
    args = list(mode = "rocprc",
                calc_avg = calc_avg,
                cb_alpha = cb_alpha,
                raw_curves = raw_curves,
                x_bins = x_bins),
    validated = FALSE
  )
  for (attr_name in names(obj_attrs)) {
    attr(s3obj, attr_name) <- obj_attrs[[attr_name]]
  }

  # Validate via the class-specific .validate method and return the result
  .validate(s3obj)
}
#
# Get ROC or Precision-Recall curves from curves
#
.summarize_curves <- function(lcurves, curve_type, class_name, mdat,
                              dataset_type, calc_avg, cb_alpha, x_bins) {
  # Placeholders used when lcurves is NULL or no averaging is requested
  crv_list <- NA
  avgcurves <- NA

  if (!is.null(lcurves)) {
    # Extract the requested curve type ("roc" or "prc") from each item
    crv_list <- lapply(seq_along(lcurves),
                       function(i) lcurves[[i]][[curve_type]])
    # Average curves are only computed for multiple datasets
    if (dataset_type == "multiple" && calc_avg) {
      all_modnames <- attr(mdat, "data_info")[["modnames"]]
      model_ids <- attr(mdat, "uniq_modnames")
      avgcurves <- calc_avg_rocprc(crv_list, all_modnames, model_ids,
                                   cb_alpha, x_bins)
    }
  }

  # === Create an S3 object ===
  s3obj <- structure(crv_list, class = class_name)
  attr(s3obj, "data_info") <- attr(mdat, "data_info")
  attr(s3obj, "curve_type") <- curve_type
  attr(s3obj, "xlim") <- c(0, 1)
  attr(s3obj, "ylim") <- c(0, 1)
  attr(s3obj, "uniq_modnames") <- attr(mdat, "uniq_modnames")
  attr(s3obj, "uniq_dsids") <- attr(mdat, "uniq_dsids")
  attr(s3obj, "avgcurves") <- avgcurves
  attr(s3obj, "validated") <- FALSE

  # Validate via the class-specific .validate method and return the result
  .validate(s3obj)
}
#
# Get AUCs
#
.gather_aucs <- function(lcurves, mdat) {
  # One ROC row and one PRC row per model/dataset pair
  n_types <- 2
  data_info <- attr(mdat, "data_info")
  modnames <- data_info[["modnames"]]
  dsids <- data_info[["dsids"]]
  n_rows <- length(modnames) * n_types

  aucs <- data.frame(modnames = rep(modnames, each = n_types),
                     dsids = rep(dsids, each = n_types),
                     curvetypes = rep(c("ROC", "PRC"), length(modnames)),
                     aucs = rep(NA, n_rows),
                     stringsAsFactors = FALSE)

  # Fill the AUC column block by block: rows 1-2 for item 1, 3-4 for item 2…
  for (k in seq_along(lcurves)) {
    rows <- (k - 1) * n_types + seq_len(n_types)
    crv <- lcurves[[k]]
    aucs[["aucs"]][rows] <- c(attr(crv[["roc"]], "auc"),
                              attr(crv[["prc"]], "auc"))
  }

  aucs
}
#
# Validate curves object generated by .pl_main_rocprc()
#
.validate_curves_common <- function(curves, class_name) {
  # Validate a curves object produced by .pl_main_rocprc(), once.
  #
  # curves:     object whose class should include class_name
  # class_name: expected S3 class, e.g. "sscurves" or "mmcurves"
  # Returns curves with its "validated" attribute set to TRUE.

  # Need to validate only once. isTRUE() guards against a missing (NULL) or
  # NA flag, which would otherwise make the condition error instead of
  # falling through to full validation.
  if (inherits(curves, class_name) && isTRUE(attr(curves, "validated"))) {
    return(curves)
  }

  # Validate class items and attributes
  item_names <- c("rocs", "prcs")
  attr_names <- c("aucs", "grp_avg", "data_info", "uniq_modnames",
                  "uniq_dsids", "model_type", "dataset_type", "args",
                  "validated")
  arg_names <- c("mode", "calc_avg", "cb_alpha", "raw_curves", "x_bins")
  .validate_basic(curves, class_name, ".pl_main_rocprc", item_names,
                  attr_names, arg_names)

  attr(curves, "validated") <- TRUE
  curves
}
#
# Validate 'sscurves' object generated by .pl_main_rocprc()
#
.validate.sscurves <- function(sscurves) {
  # Delegate to the shared curves validator with the matching class name
  .validate_curves_common(curves = sscurves, class_name = "sscurves")
}
#
# Validate 'mscurves' object generated by .pl_main_rocprc()
#
.validate.mscurves <- function(mscurves) {
  # Delegate to the shared curves validator with the matching class name
  .validate_curves_common(curves = mscurves, class_name = "mscurves")
}
#
# Validate 'smcurves' object generated by .pl_main_rocprc()
#
.validate.smcurves <- function(smcurves) {
  # Delegate to the shared curves validator with the matching class name
  .validate_curves_common(curves = smcurves, class_name = "smcurves")
}
#
# Validate 'mmcurves' object generated by .pl_main_rocprc()
#
.validate.mmcurves <- function(mmcurves) {
  # Delegate to the shared curves validator with the matching class name
  .validate_curves_common(curves = mmcurves, class_name = "mmcurves")
}
#
# Validate 'crvgrp' object generated by .pl_main_rocprc()
#
.validate.crvgrp <- function(crvgrp) {
  # Validate a 'crvgrp' object produced by .summarize_curves(), once.
  #
  # crvgrp: object expected to carry the "crvgrp" S3 class
  # Returns crvgrp with its "validated" attribute set to TRUE.

  # Need to validate only once. isTRUE() guards against a missing (NULL) or
  # NA flag, which would otherwise make the condition error.
  if (inherits(crvgrp, "crvgrp") && isTRUE(attr(crvgrp, "validated"))) {
    return(crvgrp)
  }

  # Validate attributes only; a crvgrp has no required named items
  item_names <- NULL
  attr_names <- c("data_info", "curve_type", "uniq_modnames", "uniq_dsids",
                  "avgcurves", "validated")
  arg_names <- NULL
  .validate_basic(crvgrp, "crvgrp", ".summarize_curves", item_names,
                  attr_names, arg_names)

  attr(crvgrp, "validated") <- TRUE
  crvgrp
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.