R/pl3_calc_auc_with_u.R

Defines functions .validate.uauc calc_auc_with_u

#
# Calculate AUCs with U statistic
#
# Computes the ROC AUC via a U statistic in compiled code (Rcpp) and wraps
# the result in an S3 object of class "uauc".
#
# Args:
#   sdat: Source data object. If missing, it is created from `scores` and
#     `labels` by .create_src_obj() with reformat_data(mode = "aucroc").
#   scores, labels: Used to build `sdat` when it is not supplied.
#   na_worst: Passed to calc_uauc() on the "sort" path. On the "frank" path it
#     is translated to na.last = !na_worst for data.table::frank().
#   ties_method: Passed unchanged to calc_uauc() on the "sort" path. On the
#     "frank" path, "random" is kept and any other value becomes "average".
#   keep_sdat: If TRUE, `sdat` is stored in the "src" attribute; otherwise
#     "src" is NA.
#   ustat_method: "frank" (rank with data.table::frank(); falls back to
#     "sort" when .load_data_table() does not return TRUE) or "sort".
#   ...: Forwarded to .create_src_obj() and recorded in the "args" attribute.
#
# Returns:
#   The "uauc" object after it has been passed through .validate().
#
calc_auc_with_u <- function(sdat, scores = NULL, labels = NULL, na_worst = TRUE,
                            ties_method = "equiv", keep_sdat = FALSE,
                            ustat_method = "frank", ...) {
  # === Validate input arguments ===
  # Reject unsupported methods up front. Otherwise neither branch below would
  # run and the failure would surface as "object 'uauc' not found".
  ustat_method <- match.arg(ustat_method, c("frank", "sort"))

  # Create sdat from scores and labels if sdat is missing
  sdat <- .create_src_obj(sdat, "sdat", reformat_data, scores, labels,
    mode = "aucroc", ...
  )
  .validate(sdat)

  # === Calculate AUCs (ROC) ===
  # "frank" needs data.table; `&&` short-circuits, so .load_data_table() is
  # only called for "frank". If it fails, fall back to the "sort" routine.
  use_frank <- ustat_method == "frank" && .load_data_table()

  # Call a cpp function via Rcpp interface
  if (use_frank) {
    # NA scores are ranked first (worst) when na_worst is TRUE
    na_last <- !na_worst
    # data.table::frank() only gets "random" or "average" here
    frank_ties <- if (ties_method == "random") "random" else "average"

    frank_func <- function(x) {
      data.table::frank(x, na.last = na_last, ties.method = frank_ties)
    }

    uauc <- calc_uauc_frank(
      attr(sdat, "np"), attr(sdat, "nn"),
      sdat[["scores"]], sdat[["labels"]],
      na_last, frank_ties, frank_func
    )
    .check_cpp_func_error(uauc, "calc_uauc_frank")
  } else {
    uauc <- calc_uauc(
      attr(sdat, "np"), attr(sdat, "nn"), sdat[["scores"]],
      sdat[["labels"]], na_worst, ties_method
    )
    .check_cpp_func_error(uauc, "calc_uauc")
  }

  # === Create an S3 object ===
  # Move the C++ error message from the result list into an attribute
  cpp_errmsg <- uauc[["errmsg"]]
  uauc[["errmsg"]] <- NULL
  s3obj <- structure(uauc, class = "uauc")

  # Set attributes
  attr(s3obj, "modname") <- attr(sdat, "modname")
  attr(s3obj, "dsid") <- attr(sdat, "dsid")
  attr(s3obj, "nn") <- attr(sdat, "nn")
  attr(s3obj, "np") <- attr(sdat, "np")
  attr(s3obj, "args") <- list(...)
  attr(s3obj, "cpp_errmsg") <- cpp_errmsg
  attr(s3obj, "src") <- if (keep_sdat) sdat else NA
  attr(s3obj, "validated") <- FALSE

  # Call .validate.uauc()
  .validate(s3obj)
}

#
# Validate a 'uauc' object generated by calc_auc_with_u()
#
# Objects already flagged as validated are returned unchanged. Otherwise the
# expected item, attribute and argument names are passed to
# .validate_basic(). The "auc" item is then asserted to be a single number
# in [0, 1].
#
# Args:
#   x: Object to validate (normally the "uauc" object built by
#     calc_auc_with_u()).
#
# Returns:
#   x with its "validated" attribute set to TRUE.
#
.validate.uauc <- function(x) {
  # Need to validate only once. isTRUE() keeps a missing or NA "validated"
  # attribute from making `&&` error; such objects are fully validated.
  if (inherits(x, "uauc") && isTRUE(attr(x, "validated"))) {
    return(x)
  }

  # Validate class items and attributes
  item_names <- "auc"
  attr_names <- c(
    "modname", "dsid", "nn", "np", "args", "cpp_errmsg",
    "src", "validated"
  )
  arg_names <- c(
    "na_worst", "na.last", "ties_method", "ties.method",
    "modname", "dsid", "keep_fmdat"
  )
  .validate_basic(
    x, "uauc", "calc_auc_with_u", item_names, attr_names,
    arg_names
  )

  # AUC must be a single number within [0, 1]
  auc <- x[["auc"]]
  assertthat::assert_that(
    assertthat::is.number(auc),
    auc >= 0, auc <= 1
  )

  attr(x, "validated") <- TRUE
  x
}

Try the precrec package in your browser

Any scripts or data that you put into this service are public.

precrec documentation built on Oct. 12, 2023, 1:06 a.m.