# R/pl3_calc_auc_with_u.R

#
# Calculate AUCs with U statistic
#
# Computes the ROC AUC of one model/dataset from the Mann-Whitney U
# statistic by delegating to compiled (Rcpp) helpers.
#
# Args:
#   sdat          Source data object. It is passed through .create_src_obj()
#                 together with `scores` and `labels` (mode = "aucroc"),
#                 which builds it when it is missing.
#   scores        Optional scores used to build `sdat`.
#   labels        Optional labels used to build `sdat`.
#   na_worst      For the "frank" method, TRUE maps to na.last = FALSE and
#                 FALSE maps to na.last = TRUE in data.table::frank(). For
#                 the "sort" method it is passed to calc_uauc() unchanged.
#   ties_method   For the "frank" method, "random" keeps random tie-breaking
#                 and any other value uses "average" ranks. For the "sort"
#                 method it is passed to calc_uauc() unchanged.
#   keep_sdat     If TRUE, `sdat` is stored in the "src" attribute of the
#                 result; otherwise "src" is NA.
#   ustat_method  Either "frank" or "sort".
#                 - "frank" ranks with data.table::frank(). If
#                   .load_data_table() returns FALSE, it falls back to the
#                   "sort" path.
#                 - "sort" calls calc_uauc().
#   ...           Forwarded to .create_src_obj() and stored in the "args"
#                 attribute of the result.
#
# Returns:
#   An S3 object of class "uauc" that has been passed through .validate().
#
calc_auc_with_u <- function(sdat, scores = NULL, labels = NULL, na_worst = TRUE,
                            ties_method = "equiv", keep_sdat = FALSE,
                            ustat_method = "frank", ...) {

  # === Validate input arguments ===
  # Reject unknown methods up front. Otherwise `uauc` is never assigned and
  # the failure only shows up later as "object 'uauc' not found".
  if (!(is.character(ustat_method) && length(ustat_method) == 1L &&
        ustat_method %in% c("frank", "sort"))) {
    stop("'ustat_method' must be either \"frank\" or \"sort\"",
         call. = FALSE)
  }

  # Create sdat from scores and labels if sdat is missing
  sdat <- .create_src_obj(sdat, "sdat", reformat_data, scores, labels,
                          mode = "aucroc", ...)
  .validate(sdat)

  # === Calculate AUCs (ROC) ===
  # Call a cpp function via Rcpp interface. The "frank" path needs
  # data.table; .load_data_table() is only attempted for that method.
  use_frank <- ustat_method == "frank" && .load_data_table()

  if (use_frank) {
    # Translate precrec's options into data.table::frank() arguments
    na_last <- if (na_worst) FALSE else TRUE
    frank_ties <- if (ties_method == "random") "random" else "average"

    frank_func <- function(x) {
      data.table::frank(x, na.last = na_last, ties.method = frank_ties)
    }

    uauc <- calc_uauc_frank(attr(sdat, "np"), attr(sdat, "nn"),
                            sdat[["scores"]], sdat[["labels"]],
                            na_last, frank_ties, frank_func)
    # Report errors under the name of the cpp function actually called
    .check_cpp_func_error(uauc, "calc_uauc_frank")
  } else {
    # "sort" method, or "frank" requested but data.table is unavailable
    uauc <- calc_uauc(attr(sdat, "np"), attr(sdat, "nn"), sdat[["scores"]],
                      sdat[["labels"]], na_worst, ties_method)
    .check_cpp_func_error(uauc, "calc_uauc")
  }

  # === Create an S3 object ===
  # Move the cpp error message out of the list items into an attribute
  cpp_errmsg <- uauc[["errmsg"]]
  uauc[["errmsg"]] <- NULL
  s3obj <- structure(uauc, class = "uauc")

  # Set attributes
  attr(s3obj, "modname") <- attr(sdat, "modname")
  attr(s3obj, "dsid") <- attr(sdat, "dsid")
  attr(s3obj, "nn") <- attr(sdat, "nn")
  attr(s3obj, "np") <- attr(sdat, "np")
  attr(s3obj, "args") <- list(...)
  attr(s3obj, "cpp_errmsg") <- cpp_errmsg
  if (keep_sdat) {
    attr(s3obj, "src") <- sdat
  } else {
    attr(s3obj, "src") <- NA
  }
  attr(s3obj, "validated") <- FALSE

  # Call .validate.uauc()
  .validate(s3obj)
}

#
# Validate a 'uauc' object generated by calc_auc_with_u()
#
# Objects already flagged as validated are returned unchanged. Otherwise
# the item names ("auc"), attribute names and stored-argument names are
# handed to .validate_basic() for structural checks. The AUC is then
# asserted to be a single number in [0, 1].
#
# Returns:
#   `uauc` with its "validated" attribute set to TRUE.
#
.validate.uauc <- function(uauc) {
  # Need to validate only once. isTRUE() treats a missing or malformed
  # "validated" attribute as "not yet validated". Using the attribute
  # directly in && errors on NULL before .validate_basic() can report
  # the problem.
  if (methods::is(uauc, "uauc") && isTRUE(attr(uauc, "validated"))) {
    return(uauc)
  }

  # Validate class items and attributes
  item_names <- c("auc")
  attr_names <- c("modname", "dsid", "nn", "np", "args", "cpp_errmsg",
                  "src", "validated")
  arg_names <- c("na_worst", "na.last", "ties_method", "ties.method",
                 "modname", "dsid", "keep_fmdat")
  .validate_basic(uauc, "uauc", "calc_auc_with_u", item_names, attr_names,
                  arg_names)

  # AUC must be a single number within [0, 1]
  auc <- uauc[["auc"]]
  assertthat::assert_that(assertthat::is.number(auc),
                          auc >= 0, auc <= 1)

  attr(uauc, "validated") <- TRUE
  uauc
}
# guillermozbta/precrec documentation built on May 11, 2019, 7:22 p.m.