# Source file: R/pl5_create_curves.R

# Defines functions: .validate.curves .validate_curve .validate.prc_curve
#   .validate.roc_curve .create_curve create_prc create_roc create_curves

#
# Create ROC and Precision-Recall curves
#
create_curves <- function(pevals, scores = NULL, labels = NULL,
                          x_bins = 1000, keep_pevals = FALSE, ...) {
  # Build a 'curves' S3 object that bundles a ROC curve and a
  # Precision-Recall curve computed from the same evaluation measures.
  #
  # Args:
  #   pevals:      evaluation-measure object; handed to .create_src_obj()
  #                together with scores/labels and calc_measures.
  #   scores:      forwarded to .create_src_obj().
  #   labels:      forwarded to .create_src_obj().
  #   x_bins:      bin count handed to both curve constructors; NULL or NA
  #                is replaced by 1 before validation.
  #   keep_pevals: when TRUE the pevals object is stored in the "src"
  #                attribute, otherwise "src" is NA.
  #   ...:         forwarded to .create_src_obj(), create_roc() and
  #                create_prc(), and recorded in the "args" attribute.
  #
  # Returns:
  #   Whatever .validate() returns for the assembled 'curves' object.

  # --- Prepare and check the inputs ---
  pevals <- .create_src_obj(
    pevals, "pevals", calc_measures, scores, labels,
    ...
  )

  bins_unset <- is.null(x_bins) || any(is.na(x_bins))
  if (bins_unset) {
    x_bins <- 1
  }
  .validate_x_bins(x_bins, allow_zero = TRUE)
  .validate(pevals)

  # --- Compute both curves (ROC first, then PRC) and wrap them ---
  s3obj <- structure(
    list(
      roc = create_roc(pevals,
        x_bins = x_bins,
        keep_pevals = keep_pevals, ...
      ),
      prc = create_prc(pevals,
        x_bins = x_bins,
        keep_pevals = keep_pevals, ...
      )
    ),
    class = "curves"
  )

  # --- Attach attributes ---
  # Identification and count attributes are copied straight from pevals
  for (attr_name in c("modname", "dsid", "nn", "np")) {
    attr(s3obj, attr_name) <- attr(pevals, attr_name)
  }
  attr(s3obj, "args") <- c(list(x_bins = x_bins), list(...))
  attr(s3obj, "src") <- if (keep_pevals) pevals else NA
  attr(s3obj, "validated") <- FALSE

  # Dispatches to .validate.curves()
  .validate(s3obj)
}

#
# Create a ROC curve
#
create_roc <- function(pevals, scores = NULL, labels = NULL, x_bins = 1000,
                       keep_pevals = FALSE, ...) {
  # Thin wrapper around .create_curve() configured for a ROC curve:
  # the x measure is "specificity", the y measure is "sensitivity", the
  # curve is computed by create_roc_curve and the result gets class
  # 'roc_curve'. All remaining arguments are passed through unchanged.
  .create_curve(
    "specificity",
    "sensitivity",
    create_roc_curve,
    "create_roc_curve",
    "roc_curve",
    pevals,
    scores,
    labels,
    x_bins,
    keep_pevals,
    ...
  )
}

#
# Create a Precision-Recall curve
#
create_prc <- function(pevals, scores = NULL, labels = NULL, x_bins = 1000,
                       keep_pevals = FALSE, ...) {
  # Thin wrapper around .create_curve() configured for a Precision-Recall
  # curve: the x measure is "sensitivity", the y measure is "precision", the
  # curve is computed by create_prc_curve and the result gets class
  # 'prc_curve'. All remaining arguments are passed through unchanged.
  .create_curve(
    "sensitivity",
    "precision",
    create_prc_curve,
    "create_prc_curve",
    "prc_curve",
    pevals,
    scores,
    labels,
    x_bins,
    keep_pevals,
    ...
  )
}

#
# Create ROC or Precision-Recall curve
#
.create_curve <- function(x_name, y_name, func, func_name, class_name,
                          pevals, scores = NULL, labels = NULL, x_bins = 1000,
                          keep_pevals = FALSE, ...) {
  # Shared implementation behind create_roc() and create_prc().
  #
  # Args:
  #   x_name, y_name: names of the entries in pevals[["basic"]] used as the
  #                   x and y measures of the curve.
  #   func:           function that computes the curve; it is called with
  #                   the "tp" and "fp" entries of attr(pevals, "src"), the
  #                   two measures and x_bins, and its result is read via
  #                   [["curve"]] and [["errmsg"]].
  #   func_name:      name of 'func', handed to .check_cpp_func_error().
  #   class_name:     S3 class assigned to the resulting curve object.
  #   pevals, scores, labels, ...: handed to .create_src_obj() together with
  #                   calc_measures; '...' is also recorded in "args".
  #   x_bins:         bin count passed to 'func'. NULL or NA is replaced by
  #                   1, the same fallback create_curves() applies, so that
  #                   direct create_roc()/create_prc() calls behave alike.
  #   keep_pevals:    when TRUE the pevals object is stored in the "src"
  #                   attribute, otherwise "src" is NA.
  #
  # Returns:
  #   Whatever .validate() returns for the assembled curve object.

  # === Validate input arguments ===
  # Create pevals from scores and labels if pevals is missing
  pevals <- .create_src_obj(
    pevals, "pevals", calc_measures, scores, labels,
    ...
  )

  # Same NULL/NA fallback as create_curves(); without it an unset x_bins
  # would be forwarded to 'func' and recorded in "args" as-is
  if (is.null(x_bins) || any(is.na(x_bins))) {
    x_bins <- 1
  }
  .validate_x_bins(x_bins, allow_zero = TRUE)
  .validate(pevals)

  # === Create a curve ===
  # Calculate a curve
  pb <- pevals[["basic"]]
  crv <- func(
    attr(pevals, "src")[["tp"]], attr(pevals, "src")[["fp"]],
    pb[[x_name]], pb[[y_name]], x_bins
  )
  .check_cpp_func_error(crv, func_name)

  # Calculate AUC
  auc <- calc_auc(crv[["curve"]][["x"]], crv[["curve"]][["y"]])
  if (auc[["errmsg"]] == "invalid-x-vals") {
    # This particular message is downgraded to a warning; the AUC value is
    # still stored below
    warning(paste0(
      "Invalid ", x_name,
      " values detected. AUC can be inaccurate."
    ))
  } else {
    .check_cpp_func_error(auc, "calc_auc")
  }

  # === Create an S3 object ===
  s3obj <- structure(crv[["curve"]], class = class_name)

  # Set attributes
  attr(s3obj, "modname") <- attr(pevals, "modname")
  attr(s3obj, "dsid") <- attr(pevals, "dsid")
  attr(s3obj, "nn") <- attr(pevals, "nn")
  attr(s3obj, "np") <- attr(pevals, "np")
  attr(s3obj, "auc") <- auc[["auc"]]
  attr(s3obj, "xlim") <- c(0, 1)
  attr(s3obj, "ylim") <- c(0, 1)
  attr(s3obj, "pauc") <- NA
  attr(s3obj, "spauc") <- NA
  attr(s3obj, "args") <- c(list(x_bins = x_bins), list(...))
  attr(s3obj, "cpp_errmsg1") <- crv[["errmsg"]]
  attr(s3obj, "cpp_errmsg2") <- auc[["errmsg"]]
  if (keep_pevals) {
    attr(s3obj, "src") <- pevals
  } else {
    attr(s3obj, "src") <- NA
  }
  attr(s3obj, "validated") <- FALSE

  # Call .validate.roc_curve() or .validate.prc_curve()
  .validate(s3obj)
}

#
# Validate 'roc_curve' object generated by create_roc()
#
.validate.roc_curve <- function(x) {
  # Validation method for 'roc_curve' objects.
  # An object already flagged as validated is returned untouched; otherwise
  # it is checked by .validate_curve() and then flagged.
  already_checked <- methods::is(x, "roc_curve") && attr(x, "validated")

  if (!already_checked) {
    # Stops with an error when the object is malformed
    .validate_curve(x, "roc_curve", "create_roc")
    attr(x, "validated") <- TRUE
  }

  x
}

#
# Validate 'prc_curve' object generated by create_prc()
#
.validate.prc_curve <- function(x) {
  # Validation method for 'prc_curve' objects.
  # An object already flagged as validated is returned untouched; otherwise
  # it is checked by .validate_curve() and then flagged.
  already_checked <- methods::is(x, "prc_curve") && attr(x, "validated")

  if (!already_checked) {
    # Stops with an error when the object is malformed
    .validate_curve(x, "prc_curve", "create_prc")
    attr(x, "validated") <- TRUE
  }

  x
}

#
# Validate 'roc_curve' or 'prc_curve'
#
.validate_curve <- function(obj, class_name, func_name) {
  # Common checks for 'roc_curve' and 'prc_curve' objects.
  #
  # Args:
  #   obj:        curve object to check.
  #   class_name: class name handed to .validate_basic().
  #   func_name:  creator function name handed to .validate_basic().
  #
  # Stops with an error when the item lengths are inconsistent or too short,
  # or when the "auc" attribute is outside [0, 1].

  # --- Names handed to the generic structural check ---
  expected_items <- c("x", "y", "orig_points")
  expected_attrs <- c(
    "modname", "dsid", "nn", "np", "auc", "args",
    "cpp_errmsg1", "cpp_errmsg2", "src", "validated"
  )
  expected_args <- c(
    "x_bins", "na_worst", "na.last", "ties_method", "ties.method",
    "modname", "dsid", "keep_fmdat", "keep_cmats"
  )
  .validate_basic(
    obj, class_name, func_name, expected_items, expected_attrs,
    expected_args
  )

  # --- Item lengths: all equal, and at least three points ---
  n_x <- length(obj[["x"]])
  n_y <- length(obj[["y"]])
  n_orig <- length(obj[["orig_points"]])

  if (n_x != n_y || n_x != n_orig) {
    stop("x, y, and orig_points must be all the same lengths", call. = FALSE)
  }
  if (n_x <= 2) {
    stop("The minimum length of x, y, and orig_points must be 3",
      call. = FALSE
    )
  }

  # --- AUC must lie within [0, 1] ---
  # The expression is kept verbatim because assert_that() reports it on failure
  assertthat::assert_that((attr(obj, "auc") >= 0) && (attr(obj, "auc") <= 1))
}

#
# Validate 'curves' object generated by create_curves()
#
.validate.curves <- function(x) {
  # Validation method for 'curves' objects built by create_curves().
  #
  # Args:
  #   x: object to check; expected to hold the items "roc" and "prc".
  #
  # Returns:
  #   x with both items re-validated and its "validated" attribute set to
  #   TRUE. An object already flagged as validated is returned untouched.

  # Need to validate only once
  if (methods::is(x, "curves") && attr(x, "validated")) {
    return(x)
  }

  # Validate class items and attributes
  item_names <- c("roc", "prc")
  attr_names <- c(
    "modname", "dsid", "nn", "np", "args", "src",
    "validated"
  )
  arg_names <- c(
    "x_bins", "na_worst", "na.last", "ties_method", "ties.method",
    "modname", "dsid", "keep_fmdat", "keep_cmats"
  )
  # The third argument names the function that creates this class, matching
  # the "create_roc"/"create_prc" names used for the single-curve classes.
  # NOTE(review): presumably .validate_basic() uses it in messages - confirm
  .validate_basic(
    x, "curves", "create_curves", item_names, attr_names,
    arg_names
  )

  # Check values of class items
  x[["roc"]] <- .validate(x[["roc"]])
  x[["prc"]] <- .validate(x[["prc"]])

  attr(x, "validated") <- TRUE
  x
}
# takayasaito/precrec documentation built on Oct. 19, 2023, 7:28 p.m.