# R/pl5_create_curves.R

#
# Create ROC and Precision-Recall curves
#
# Arguments:
#   pevals       Source object handed to .create_src_obj(); per the note
#                below it is built from scores and labels when missing.
#   scores       Forwarded to .create_src_obj().
#   labels       Forwarded to .create_src_obj().
#   x_bins       Forwarded to create_roc() and create_prc(); NULL or NA is
#                replaced with 1 before validation.
#   keep_pevals  When TRUE, pevals is stored in the "src" attribute;
#                otherwise "src" is NA.
#   ...          Forwarded to .create_src_obj(), create_roc(), and
#                create_prc(), and recorded in the "args" attribute.
# Returns the result of .validate() applied to a list with items "roc" and
# "prc" carrying class "curves".
create_curves <- function(pevals, scores = NULL, labels = NULL,
                          x_bins = 1000, keep_pevals = FALSE, ...) {

  # === Validate input arguments ===
  # Create pevals from scores and labels if pevals is missing
  pevals <- .create_src_obj(pevals, "pevals", calc_measures, scores, labels,
                            ...)

  # An unset x_bins (NULL or NA) falls back to a single bin
  bins_unset <- is.null(x_bins) || is.na(x_bins)
  if (bins_unset) {
    x_bins <- 1
  }
  .validate_x_bins(x_bins)
  .validate(pevals)

  # === Create ROC and Precision-Recall curves ===
  roc <- create_roc(pevals, x_bins = x_bins, keep_pevals = keep_pevals, ...)
  prc <- create_prc(pevals, x_bins = x_bins, keep_pevals = keep_pevals, ...)

  # === Create an S3 object ===
  s3obj <- structure(list(roc = roc, prc = prc), class = "curves")

  # Copy the identifying attributes over from pevals
  for (attr_name in c("modname", "dsid", "nn", "np")) {
    attr(s3obj, attr_name) <- attr(pevals, attr_name)
  }
  attr(s3obj, "args") <- c(list(x_bins = x_bins), list(...))
  attr(s3obj, "src") <- if (keep_pevals) pevals else NA
  attr(s3obj, "validated") <- FALSE

  # Call .validate.curves()
  .validate(s3obj)
}

#
# Create a ROC curve
#
# All arguments are forwarded unchanged to .create_curve() together with the
# ROC-specific settings: the "specificity" and "sensitivity" item names, the
# create_roc_curve function and its name, and the "roc_curve" class name.
# Returns whatever .create_curve() returns.
create_roc <- function(pevals, scores = NULL, labels = NULL, x_bins = 1000,
                       keep_pevals = FALSE, ...) {

  # === Create a ROC curve ===
  .create_curve(
    x_name = "specificity",
    y_name = "sensitivity",
    func = create_roc_curve,
    func_name = "create_roc_curve",
    class_name = "roc_curve",
    pevals = pevals,
    scores = scores,
    labels = labels,
    x_bins = x_bins,
    keep_pevals = keep_pevals,
    ...
  )
}

#
# Create a Precision-Recall curve
#
# All arguments are forwarded unchanged to .create_curve() together with the
# Precision-Recall-specific settings: the "sensitivity" and "precision" item
# names, the create_prc_curve function and its name, and the "prc_curve"
# class name. Returns whatever .create_curve() returns.
create_prc <- function(pevals, scores = NULL, labels = NULL, x_bins = 1000,
                       keep_pevals = FALSE, ...) {

  # === Create a Precision-Recall curve ===
  .create_curve(
    x_name = "sensitivity",
    y_name = "precision",
    func = create_prc_curve,
    func_name = "create_prc_curve",
    class_name = "prc_curve",
    pevals = pevals,
    scores = scores,
    labels = labels,
    x_bins = x_bins,
    keep_pevals = keep_pevals,
    ...
  )
}

#
# Create ROC or Precision-Recall curve
#
# Arguments:
#   x_name, y_name  Names of the items of pevals[["basic"]] passed to func
#                   as its third and fourth arguments.
#   func            Curve-building function; called with the "tp" and "fp"
#                   items of attr(pevals, "src"), the two items above, and
#                   x_bins.
#   func_name       Name passed to .check_cpp_func_error() for func's result.
#   class_name      S3 class assigned to the resulting object.
#   pevals, scores, labels, ...
#                   Handed to .create_src_obj(); ... is also recorded in the
#                   "args" attribute.
#   x_bins          Checked by .validate_x_bins() and forwarded to func.
#   keep_pevals     When TRUE, pevals is stored in the "src" attribute;
#                   otherwise "src" is NA.
# Returns the result of .validate() applied to the new curve object.
.create_curve <- function(x_name, y_name, func, func_name, class_name,
                          pevals, scores = NULL, labels = NULL, x_bins = 1000,
                          keep_pevals = FALSE, ...) {

  # === Validate input arguments ===
  # Create pevals from scores and labels if pevals is missing
  pevals <- .create_src_obj(pevals, "pevals", calc_measures, scores, labels,
                            ...)
  .validate_x_bins(x_bins)
  .validate(pevals)

  # === Create a curve ===
  # Calculate a curve
  src_items <- attr(pevals, "src")
  basic_evals <- pevals[["basic"]]
  crv <- func(src_items[["tp"]], src_items[["fp"]], basic_evals[[x_name]],
              basic_evals[[y_name]], x_bins)
  .check_cpp_func_error(crv, func_name)

  # Calculate AUC
  curve_points <- crv[["curve"]]
  auc <- calc_auc(curve_points[["x"]], curve_points[["y"]])
  if (auc[["errmsg"]] != "invalid-x-vals") {
    .check_cpp_func_error(auc, "calc_auc")
  } else {
    # Invalid x values only trigger a warning; the AUC value is still used
    invalid_x_msg <- paste0("Invalid ", x_name,
                            " values detected. AUC can be inaccurate.")
    warning(invalid_x_msg)
  }

  # === Create an S3 object ===
  s3obj <- structure(curve_points, class = class_name)

  # Copy the identifying attributes over from pevals
  for (attr_name in c("modname", "dsid", "nn", "np")) {
    attr(s3obj, attr_name) <- attr(pevals, attr_name)
  }

  # Curve-specific attributes
  attr(s3obj, "auc") <- auc[["auc"]]
  attr(s3obj, "xlim") <- c(0, 1)
  attr(s3obj, "ylim") <- c(0, 1)
  attr(s3obj, "pauc") <- NA
  attr(s3obj, "spauc") <- NA
  attr(s3obj, "args") <- c(list(x_bins = x_bins), list(...))
  attr(s3obj, "cpp_errmsg1") <- crv[["errmsg"]]
  attr(s3obj, "cpp_errmsg2") <- auc[["errmsg"]]
  attr(s3obj, "src") <- if (keep_pevals) pevals else NA
  attr(s3obj, "validated") <- FALSE

  # Call .validate.roc_curve() or .validate.prc_curve()
  .validate(s3obj)
}

#
# Validate 'roc_curve' object generated by create_roc()
#
# Argument:
#   roc_curve  Object to validate.
# Returns roc_curve with its "validated" attribute set to TRUE. An object
# that already has class "roc_curve" and "validated" equal to TRUE is
# returned untouched; anything else is checked by .validate_curve() first.
.validate.roc_curve <- function(roc_curve) {
  # Need to validate only once.
  # isTRUE() keeps a missing (NULL) or NA "validated" attribute from
  # aborting inside `&&`; such objects go through full validation instead.
  already_validated <- methods::is(roc_curve, "roc_curve") &&
    isTRUE(attr(roc_curve, "validated"))
  if (already_validated) {
    return(roc_curve)
  }

  # Validate class items and attributes
  .validate_curve(roc_curve, "roc_curve", "create_roc")

  attr(roc_curve, "validated") <- TRUE
  roc_curve
}

#
# Validate 'prc_curve' object generated by create_prc()
#
# Argument:
#   prc_curve  Object to validate.
# Returns prc_curve with its "validated" attribute set to TRUE. An object
# that already has class "prc_curve" and "validated" equal to TRUE is
# returned untouched; anything else is checked by .validate_curve() first,
# which is told the object comes from create_prc().
.validate.prc_curve <- function(prc_curve) {
  # Need to validate only once.
  # isTRUE() keeps a missing (NULL) or NA "validated" attribute from
  # aborting inside `&&`; such objects go through full validation instead.
  already_validated <- methods::is(prc_curve, "prc_curve") &&
    isTRUE(attr(prc_curve, "validated"))
  if (already_validated) {
    return(prc_curve)
  }

  # Validate class items and attributes
  .validate_curve(prc_curve, "prc_curve", "create_prc")

  attr(prc_curve, "validated") <- TRUE
  prc_curve
}

#
# Validate 'roc_curve' or 'prc_curve'
#
# Arguments:
#   obj         Curve object to check.
#   class_name  Forwarded to .validate_basic().
#   func_name   Forwarded to .validate_basic().
# Stops when the x, y, and orig_points items differ in length or hold fewer
# than three elements, then asserts that the "auc" attribute lies in [0, 1].
# The value of that final assertion is the return value.
.validate_curve <- function(obj, class_name, func_name) {
  # Validate class items and attributes
  expected_items <- c("x", "y", "orig_points")
  expected_attrs <- c("modname", "dsid", "nn", "np", "auc", "args",
                      "cpp_errmsg1", "cpp_errmsg2", "src", "validated")
  expected_args <- c("x_bins", "na_worst", "na.last", "ties_method",
                     "ties.method", "modname", "dsid", "keep_fmdat",
                     "keep_cmats")
  .validate_basic(obj, class_name, func_name, expected_items, expected_attrs,
                  expected_args)

  # Check values of class items
  n_x <- length(obj[["x"]])
  if (n_x != length(obj[["y"]]) || n_x != length(obj[["orig_points"]])) {
    stop("x, y, and orig_points must be all the same lengths", call. = FALSE)
  }
  if (n_x < 3) {
    stop("The minimum length of x, y, and orig_points must be 3",
         call. = FALSE)
  }

  # Check values of class attributes
  # AUC (expression left as is: assert_that() reports it on failure)
  assertthat::assert_that((attr(obj, "auc") >= 0) && (attr(obj, "auc") <= 1))

}

#
# Validate 'curves' object generated by create_curves()
#
# Argument:
#   curves  Object to validate.
# Returns curves with its "roc" and "prc" items replaced by their
# .validate() results and its "validated" attribute set to TRUE. An object
# that already has class "curves" and "validated" equal to TRUE is returned
# untouched.
.validate.curves <- function(curves) {
  # Need to validate only once.
  # isTRUE() keeps a missing (NULL) or NA "validated" attribute from
  # aborting inside `&&`; such objects go through full validation instead.
  already_validated <- methods::is(curves, "curves") &&
    isTRUE(attr(curves, "validated"))
  if (already_validated) {
    return(curves)
  }

  # Validate class items and attributes
  item_names <- c("roc", "prc")
  attr_names <- c("modname", "dsid", "nn", "np", "args", "src",
                  "validated")
  arg_names <- c("x_bins", "na_worst", "na.last", "ties_method", "ties.method",
                 "modname", "dsid", "keep_fmdat", "keep_cmats")
  # The function name is the constructor of this class, create_curves(),
  # matching how the roc_curve and prc_curve validators name theirs.
  .validate_basic(curves, "curves", "create_curves", item_names, attr_names,
                  arg_names)

  # Check values of class items
  curves[["roc"]] <- .validate(curves[["roc"]])
  curves[["prc"]] <- .validate(curves[["prc"]])

  attr(curves, "validated") <- TRUE
  curves
}
# guillermozbta/precrec documentation built on May 11, 2019, 7:22 p.m.