# R/pl2_pipeline_main_aucroc.R

# Defines functions .summarize_uauc_results .validate.aucroc .pl_main_aucroc

#
# Control the main pipeline iterations for ROC (AUC) with the U statistic
#
# Calls calc_auc_with_u() once per element of 'mdat', passes the collected
# results to .summarize_uauc_results(), and wraps the returned summary in an
# S3 object of class "aucroc".
#
# Arguments:
#   mdat           Indexable collection of datasets. Its "data_info",
#                  "uniq_modnames" and "uniq_dsids" attributes are copied
#                  onto the result.
#   model_type     Stored as the "model_type" attribute of the result.
#   dataset_type   Stored as the "dataset_type" attribute of the result.
#   class_name_pf  Accepted but not used in this function.
#   calc_avg, cb_alpha, raw_curves
#                  Forwarded to .summarize_uauc_results() and recorded in the
#                  "args" attribute.
#   na_worst, ties_method
#                  Forwarded to calc_auc_with_u() and recorded in the "args"
#                  attribute.
#
# Returns whatever .validate() returns for the newly built "aucroc" object.
#
.pl_main_aucroc <- function(mdat, model_type, dataset_type, class_name_pf,
                            calc_avg = FALSE, cb_alpha = 0.05,
                            raw_curves = FALSE, na_worst = TRUE,
                            ties_method = "equiv") {

  modnames <- attr(mdat, "uniq_modnames")
  dsids <- attr(mdat, "uniq_dsids")

  # === Calculate AUC ROC ===
  # One U-statistic based AUC per element of mdat, in index order
  uauc_list <- lapply(seq_along(mdat), function(idx) {
    calc_auc_with_u(mdat[[idx]], na_worst = na_worst,
                    ties_method = ties_method)
  })
  s3obj <- .summarize_uauc_results(uauc_list, modnames, dsids, calc_avg,
                                   cb_alpha, raw_curves)

  # === Create an S3 object ===
  class(s3obj) <- "aucroc"

  # Set attributes in a fixed order; a NULL value leaves the attribute unset
  extra_attrs <- list(
    data_info = attr(mdat, "data_info"),
    uniq_modnames = modnames,
    uniq_dsids = dsids,
    model_type = model_type,
    dataset_type = dataset_type,
    args = list(mode = "aucroc",
                calc_avg = calc_avg,
                cb_alpha = cb_alpha,
                raw_curves = raw_curves,
                na_worst = na_worst,
                ties_method = ties_method),
    validated = FALSE
  )
  for (attr_name in names(extra_attrs)) {
    attr(s3obj, attr_name) <- extra_attrs[[attr_name]]
  }

  # Hand the object to .validate() and return its result
  .validate(s3obj)
}

#
# Validate 'aucroc' object generated by .pl_main_aucroc()
#
# Arguments:
#   aucroc  Object to validate; expected to carry class "aucroc".
#
# Returns 'aucroc' unchanged when it is already flagged as validated.
# Otherwise the object is passed to .validate_basic() together with the
# expected attribute and argument names, and is then returned with its
# "validated" attribute set to TRUE.
#
.validate.aucroc <- function(aucroc) {
  # Need to validate only once.
  # isTRUE() guards against a missing or malformed "validated" attribute:
  # 'TRUE && NULL' is an error in R, which would hide the real problem
  # instead of letting .validate_basic() inspect the object.
  if (methods::is(aucroc, "aucroc") && isTRUE(attr(aucroc, "validated"))) {
    return(aucroc)
  }

  # Validate class items and attributes
  item_names <- NULL
  attr_names <- c("data_info", "uniq_modnames", "uniq_dsids",
                  "model_type", "dataset_type", "args", "validated")
  arg_names <- c("mode", "calc_avg", "cb_alpha", "raw_curves", "na_worst",
                 "ties_method")
  .validate_basic(aucroc, "aucroc", ".pl_main_aucroc", item_names,
                  attr_names, arg_names)

  # Mark the object so later calls take the early return above
  attr(aucroc, "validated") <- TRUE
  aucroc
}

#
# Create a dataframe with AUCs
#
# Arguments:
#   aucs           List of AUC results. Each element provides '$auc' and
#                  '$ustat' values and carries "modname" and "dsid"
#                  attributes. May be NULL.
#   uniq_modnames  Levels used for the 'modnames' factor column.
#   uniq_dsids     Levels used for the 'dsids' factor column.
#   calc_avg, cb_alpha, raw_curves
#                  Accepted but not used in this function.
#
# Returns a list with a single element 'uaucs': a data frame with columns
# 'modnames', 'dsids', 'aucs' and 'ustats' (one row per element of 'aucs'),
# or NA when 'aucs' is NULL.
#
.summarize_uauc_results <- function(aucs, uniq_modnames, uniq_dsids,
                                    calc_avg, cb_alpha, raw_curves) {
  # Nothing to summarize: keep the list shape with an NA placeholder
  if (is.null(aucs)) {
    return(list(uaucs = NA))
  }

  n_res <- length(aucs)

  # Numeric columns: one scalar per result
  auc_col <- vapply(aucs, function(res) res$auc, numeric(1),
                    USE.NAMES = FALSE)
  ustat_col <- vapply(aucs, function(res) res$ustat, numeric(1),
                      USE.NAMES = FALSE)

  # Factor columns are filled element by element so every value is matched
  # against the predefined levels through factor assignment
  modname_col <- factor(character(n_res), levels = uniq_modnames)
  dsid_col <- factor(character(n_res), levels = uniq_dsids)
  for (k in seq_len(n_res)) {
    modname_col[k] <- attr(aucs[[k]], "modname")
    dsid_col[k] <- attr(aucs[[k]], "dsid")
  }

  list(uaucs = data.frame(modnames = modname_col,
                          dsids = dsid_col,
                          aucs = auc_col,
                          ustats = ustat_col))
}
# takayasaito/precrec documentation built on Aug. 24, 2017, 8:07 a.m.