R/pl2_pipeline_main_aucroc.R

Defines functions .summarize_uauc_results .validate.aucroc .pl_main_aucroc

#
# Control the main pipeline iterations for ROC (AUC) with the U statistic
#
# mdat           list of per-dataset inputs; each element carries "nn", "np",
#                "modname" and "dsid" attributes, and mdat itself carries
#                "data_info", "uniq_modnames" and "uniq_dsids"
# model_type,
# dataset_type   stored verbatim as attributes of the result
# class_name_pf  not used by this function
# calc_avg, cb_alpha, raw_curves
#                forwarded to .summarize_uauc_results() and recorded in "args"
# na_worst, ties_method
#                forwarded to calc_auc_with_u() and recorded in "args"
#
# Stops with an error when any dataset has no negatives or no positives.
# Returns the "aucroc" S3 object after passing it through .validate().
#
.pl_main_aucroc <- function(mdat, model_type, dataset_type, class_name_pf,
                            calc_avg = FALSE, cb_alpha = 0.05,
                            raw_curves = FALSE, na_worst = TRUE,
                            ties_method = "equiv") {
  # AUC via the U statistic for the idx-th dataset. The U statistic needs
  # both classes, so single-class datasets are rejected up front.
  auc_for_dataset <- function(idx) {
    sdat <- mdat[[idx]]
    n_neg <- attr(sdat, "nn")
    n_pos <- attr(sdat, "np")

    if (n_neg == 0 || n_pos == 0) {
      only_cls <- if (n_pos > 0) "positive" else "negative"
      stop(
        paste0(
          "AUCs with the U statistic cannot be calculated. ",
          "Only a single class (", only_cls, ") ",
          "found in dataset (modname: ",
          attr(sdat, "modname"),
          ", dsid: ", attr(sdat, "dsid"), ")."
        ),
        call. = FALSE
      )
    }

    calc_auc_with_u(sdat, na_worst = na_worst, ties_method = ties_method)
  }

  # One U-statistic result per dataset, computed eagerly
  per_dataset <- lapply(seq_along(mdat), auc_for_dataset)

  # Collapse the per-dataset results into the AUC table
  summary_list <- .summarize_uauc_results(
    per_dataset,
    attr(mdat, "uniq_modnames"),
    attr(mdat, "uniq_dsids"),
    calc_avg, cb_alpha, raw_curves
  )

  # Attributes attached to the S3 object, applied in this order. A NULL
  # value leaves the corresponding attribute unset.
  obj_attrs <- list(
    data_info = attr(mdat, "data_info"),
    uniq_modnames = attr(mdat, "uniq_modnames"),
    uniq_dsids = attr(mdat, "uniq_dsids"),
    model_type = model_type,
    dataset_type = dataset_type,
    args = list(
      mode = "aucroc",
      calc_avg = calc_avg,
      cb_alpha = cb_alpha,
      raw_curves = raw_curves,
      na_worst = na_worst,
      ties_method = ties_method
    ),
    validated = FALSE
  )

  aucroc_obj <- structure(summary_list, class = "aucroc")
  for (attr_name in names(obj_attrs)) {
    attr(aucroc_obj, attr_name) <- obj_attrs[[attr_name]]
  }

  # Generic validation; dispatch is on class "aucroc"
  .validate(aucroc_obj)
}

#
# Validate an 'aucroc' object generated by .pl_main_aucroc()
#
# x  the object to check
#
# The detailed checks are delegated to .validate_basic(), which receives the
# expected item, attribute and 'args' entry names listed below (presumably
# it verifies their presence). Objects already flagged as validated are
# returned unchanged, so validation runs only once.
# Returns x with its "validated" attribute set to TRUE.
#
.validate.aucroc <- function(x) {
  # Need to validate only once. isTRUE() treats a missing (NULL) or NA
  # "validated" attribute as "not yet validated". Otherwise `&&` would fail
  # with an uninformative type error. Such objects go through the full
  # check below instead.
  if (inherits(x, "aucroc") && isTRUE(attr(x, "validated"))) {
    return(x)
  }

  # Expected items, attributes and 'args' entries of an 'aucroc' object
  item_names <- NULL
  attr_names <- c(
    "data_info", "uniq_modnames", "uniq_dsids",
    "model_type", "dataset_type", "args", "validated"
  )
  arg_names <- c(
    "mode", "calc_avg", "cb_alpha", "raw_curves", "na_worst",
    "ties_method"
  )
  .validate_basic(
    x, "aucroc", ".pl_main_aucroc", item_names,
    attr_names, arg_names
  )

  attr(x, "validated") <- TRUE
  x
}

#
# Create a data frame with AUCs
#
# aucs           list of U-statistic results; each element is read for
#                `$auc`, `$ustat` and the "modname"/"dsid" attributes
# uniq_modnames  factor levels for the modnames column
# uniq_dsids     factor levels for the dsids column
# calc_avg, cb_alpha, raw_curves
#                accepted for interface compatibility; not used here
#
# Returns list(uaucs = <data.frame>) with one row per element of aucs, or
# list(uaucs = NA) when aucs is NULL.
#
.summarize_uauc_results <- function(aucs, uniq_modnames, uniq_dsids,
                                    calc_avg, cb_alpha, raw_curves) {
  # Nothing to summarise: keep the NA placeholder
  if (is.null(aucs)) {
    return(list(uaucs = NA))
  }

  n_res <- length(aucs)
  mod_col <- factor(character(n_res), levels = uniq_modnames)
  ds_col <- factor(character(n_res), levels = uniq_dsids)
  auc_col <- numeric(n_res)
  ustat_col <- numeric(n_res)

  # Element-wise factor assignment maps each label onto the fixed levels
  for (k in seq_len(n_res)) {
    res_k <- aucs[[k]]
    mod_col[k] <- attr(res_k, "modname")
    ds_col[k] <- attr(res_k, "dsid")
    auc_col[k] <- res_k$auc
    ustat_col[k] <- res_k$ustat
  }

  uauc_table <- data.frame(
    modnames = mod_col,
    dsids = ds_col,
    aucs = auc_col,
    ustats = ustat_col
  )
  list(uaucs = uauc_table)
}

Try the precrec package in your browser

Any scripts or data that you put into this service are public.

precrec documentation built on Oct. 12, 2023, 1:06 a.m.