# R/pl2_pipeline_main_rocprc.R
#
# Defines functions: .validate.crvgrp .validate.mmcurves .validate.smcurves .validate.mscurves .validate.sscurves .validate_curves_common .gather_aucs .summarize_curves .pl_main_rocprc

#
# Control the main pipeline iterations for ROC and Precision-Recall curves
#
# Arguments:
#   mdat          List-like object with one element per curve to compute.
#                 Every element must carry the attributes "nn", "np",
#                 "modname" and "dsid"; mdat itself supplies the attributes
#                 "data_info", "uniq_modnames" and "uniq_dsids".
#   model_type    Stored unchanged in the "model_type" attribute.
#   dataset_type  Compared against "single" and "multiple".
#   class_name_pf Prefix pasted in front of "curves" to form the S3 class.
#   calc_avg      Request average curves (forced to FALSE for a single
#                 dataset or when interpolate is FALSE).
#   cb_alpha      Forwarded to .summarize_curves().
#   raw_curves    Keep the individual curves (forced to TRUE for a single
#                 dataset).
#   x_bins        Forwarded to create_curves(); set to 0 when interpolate
#                 is FALSE.
#   interpolate   When FALSE, averaging and binning are switched off.
#
# Returns the value of .validate() applied to the assembled S3 object.
#
.pl_main_rocprc <- function(mdat, model_type, dataset_type, class_name_pf,
                            calc_avg = TRUE, cb_alpha = 0.05,
                            raw_curves = FALSE, x_bins = 1000,
                            interpolate = TRUE) {
  # === Normalize options ===
  # A single dataset has nothing to average, so raw curves are always kept
  if (!missing(dataset_type) && dataset_type == "single") {
    calc_avg <- FALSE
    raw_curves <- TRUE
  }

  # Without interpolation there are no bins and no averaging
  if (!interpolate) {
    calc_avg <- FALSE
    x_bins <- 0
  }

  # === Create ROC and Precision-Recall curves ===
  build_curves <- function(idx) {
    dat <- mdat[[idx]]
    n_neg <- attr(dat, "nn")
    n_pos <- attr(dat, "np")

    # Both classes are required; report which one is actually present
    if (n_neg == 0 || n_pos == 0) {
      found <- if (n_pos > 0) "positive" else "negative"
      stop(
        paste0(
          "Curves cannot be calculated. ",
          "Only a single class (", found, ") ",
          "found in dataset (modname: ",
          attr(dat, "modname"),
          ", dsid: ", attr(dat, "dsid"), ")."
        ),
        call. = FALSE
      )
    }

    confmats <- create_confmats(dat)
    measures <- calc_measures(confmats)
    create_curves(measures, x_bins = x_bins)
  }
  lcurves <- lapply(seq_along(mdat), build_curves)

  # === Summarize curves by line type ===
  summarize_type <- function(ctype) {
    .summarize_curves(
      lcurves, ctype, "crvgrp", mdat, dataset_type,
      calc_avg, cb_alpha, x_bins
    )
  }
  grp_curves <- list(
    rocs = summarize_type("roc"),
    prcs = summarize_type("prc")
  )

  # Summarize AUCs
  aucs <- .gather_aucs(lcurves, mdat)

  # Pull the average curves out of each group (names are carried over)
  grp_avg <- lapply(grp_curves, attr, which = "avgcurves")

  # === Create an S3 object ===
  # When only averages are wanted, swap the raw curves for placeholders
  if (dataset_type == "multiple" && calc_avg && !raw_curves) {
    placeholder <- function(ctype) {
      .summarize_curves(NULL, ctype, "crvgrp", mdat, NULL, NULL, NULL, NULL)
    }
    grp_curves <- list(
      rocs = placeholder("roc"),
      prcs = placeholder("prc")
    )
  }
  s3obj <- structure(
    grp_curves,
    class = c(paste0(class_name_pf, "curves"), "curve_info", "aucs")
  )

  # Set attributes (assigned one by one, in this order)
  meta <- list(
    aucs = aucs,
    paucs = NA,
    grp_avg = grp_avg,
    data_info = attr(mdat, "data_info"),
    uniq_modnames = attr(mdat, "uniq_modnames"),
    uniq_dsids = attr(mdat, "uniq_dsids"),
    model_type = model_type,
    dataset_type = dataset_type,
    partial = FALSE,
    args = list(
      mode = "rocprc",
      calc_avg = calc_avg,
      cb_alpha = cb_alpha,
      raw_curves = raw_curves,
      x_bins = x_bins,
      interpolate = interpolate
    ),
    validated = FALSE
  )
  for (key in names(meta)) {
    attr(s3obj, key) <- meta[[key]]
  }

  # Call .validate.class_name()
  .validate(s3obj)
}

#
# Get ROC or Precision-Recall curves from curves
#
# Arguments:
#   lcurves       List of per-dataset curve objects, or NULL to build an
#                 empty placeholder group.
#   curve_type    Name of the element extracted from each entry of lcurves
#                 (callers in this file pass "roc" or "prc").
#   class_name    S3 class assigned to the result.
#   mdat          Source of the "data_info", "uniq_modnames" and
#                 "uniq_dsids" attributes.
#   dataset_type, calc_avg
#                 Averages are calculated only when dataset_type is
#                 "multiple" and calc_avg is TRUE.
#   cb_alpha, x_bins
#                 Forwarded to calc_avg_rocprc().
#
# Returns the value of .validate() applied to the curve group.
#
.summarize_curves <- function(lcurves, curve_type, class_name, mdat,
                              dataset_type, calc_avg, cb_alpha, x_bins) {
  # Placeholders used when no curves are supplied or no average is wanted
  mc <- NA
  avgcurves <- NA

  if (!is.null(lcurves)) {
    # Pick the requested curve type out of every curve set
    mc <- lapply(
      seq_along(lcurves),
      function(pos) lcurves[[pos]][[curve_type]]
    )

    # Calculate the average curves
    if (dataset_type == "multiple" && calc_avg) {
      avgcurves <- calc_avg_rocprc(
        mc,
        attr(mdat, "data_info")[["modnames"]],
        attr(mdat, "uniq_modnames"),
        cb_alpha,
        x_bins
      )
    }
  }

  # === Create an S3 object ===
  grp <- structure(mc, class = class_name)

  # Set attributes
  attr(grp, "data_info") <- attr(mdat, "data_info")
  attr(grp, "curve_type") <- curve_type
  attr(grp, "xlim") <- c(0, 1)
  attr(grp, "ylim") <- c(0, 1)
  attr(grp, "uniq_modnames") <- attr(mdat, "uniq_modnames")
  attr(grp, "uniq_dsids") <- attr(mdat, "uniq_dsids")
  attr(grp, "avgcurves") <- avgcurves
  attr(grp, "validated") <- FALSE

  # Call .validate.class_name()
  checked <- .validate(grp)

  checked
}

#
# Get AUCs
#
# Builds a data frame with one row per (model/dataset entry, curve type)
# pair, holding the "auc" attribute of each ROC and Precision-Recall curve.
#
# Arguments:
#   lcurves  List whose elements each hold a "roc" and a "prc" item, both
#            carrying an "auc" attribute.
#   mdat     Object whose "data_info" attribute provides "modnames" and
#            "dsids".
#
# Returns a data frame with the columns modnames, dsids, curvetypes and
# aucs; rows alternate ROC, PRC for every entry.
#
.gather_aucs <- function(lcurves, mdat) {
  # Two curve types (ROC, PRC) per entry
  ct_len <- 2
  info <- attr(mdat, "data_info")
  modnames <- info[["modnames"]]
  dsids <- info[["dsids"]]
  n_rows <- length(modnames) * ct_len

  auc_tbl <- data.frame(
    modnames = rep(modnames, each = ct_len),
    dsids = rep(dsids, each = ct_len),
    curvetypes = rep(c("ROC", "PRC"), length(modnames)),
    aucs = rep(NA, n_rows),
    stringsAsFactors = FALSE
  )

  # Entry i owns the two consecutive rows ending at ct_len * i
  for (i in seq_along(lcurves)) {
    crv <- lcurves[[i]]
    last_row <- ct_len * i
    auc_tbl[["aucs"]][(last_row - 1):last_row] <- c(
      attr(crv[["roc"]], "auc"),
      attr(crv[["prc"]], "auc")
    )
  }

  auc_tbl
}

#
# Validate curves object generated by .pl_main_rocprc()
#
# Arguments:
#   curves      Object to validate.
#   class_name  Expected S3 class name of curves.
#
# Returns curves with its "validated" attribute set to TRUE. An object of
# the expected class that is already flagged as validated is returned
# unchanged without being checked again.
#
.validate_curves_common <- function(curves, class_name) {
  # Need to validate only once.
  # isTRUE() keeps the guard safe when the "validated" attribute is missing
  # (NULL) or NA, so that such objects reach .validate_basic() instead of
  # raising an error inside `&&`.
  if (methods::is(curves, class_name) &&
    isTRUE(attr(curves, "validated"))) {
    return(curves)
  }

  # Validate class items and attributes
  item_names <- c("rocs", "prcs")
  attr_names <- c(
    "aucs", "grp_avg", "data_info", "uniq_modnames",
    "uniq_dsids", "model_type", "dataset_type", "args",
    "validated"
  )
  arg_names <- c(
    "mode", "calc_avg", "cb_alpha", "raw_curves",
    "x_bins", "interpolate"
  )
  .validate_basic(
    curves, class_name, ".pl_main_rocprc", item_names, attr_names,
    arg_names
  )

  attr(curves, "validated") <- TRUE
  curves
}

#
# Validate 'sscurves' object generated by .pl_main_rocprc()
#
# Delegates to .validate_curves_common() with the class name fixed to
# "sscurves" and returns its result.
#
.validate.sscurves <- function(x) {
  .validate_curves_common(curves = x, class_name = "sscurves")
}

#
# Validate 'mscurves' object generated by .pl_main_rocprc()
#
# Delegates to .validate_curves_common() with the class name fixed to
# "mscurves" and returns its result.
#
.validate.mscurves <- function(x) {
  .validate_curves_common(curves = x, class_name = "mscurves")
}

#
# Validate 'smcurves' object generated by .pl_main_rocprc()
#
# Delegates to .validate_curves_common() with the class name fixed to
# "smcurves" and returns its result.
#
.validate.smcurves <- function(x) {
  .validate_curves_common(curves = x, class_name = "smcurves")
}

#
# Validate 'mmcurves' object generated by .pl_main_rocprc()
#
# Delegates to .validate_curves_common() with the class name fixed to
# "mmcurves" and returns its result.
#
.validate.mmcurves <- function(x) {
  .validate_curves_common(curves = x, class_name = "mmcurves")
}

#
# Validate 'crvgrp' object generated by .pl_main_rocprc()
#
# Arguments:
#   x  Object to validate.
#
# Returns x with its "validated" attribute set to TRUE. A 'crvgrp' object
# that is already flagged as validated is returned unchanged without being
# checked again.
#
.validate.crvgrp <- function(x) {
  # Need to validate only once.
  # isTRUE() keeps the guard safe when the "validated" attribute is missing
  # (NULL) or NA, so that such objects reach .validate_basic() instead of
  # raising an error inside `&&`.
  if (methods::is(x, "crvgrp") && isTRUE(attr(x, "validated"))) {
    return(x)
  }

  # Validate class items and attributes
  item_names <- NULL
  attr_names <- c(
    "data_info", "curve_type", "uniq_modnames", "uniq_dsids",
    "avgcurves", "validated"
  )
  arg_names <- NULL
  .validate_basic(
    x, "crvgrp", ".summarize_curves", item_names,
    attr_names, arg_names
  )

  attr(x, "validated") <- TRUE
  x
}

# Try the precrec package in your browser
#
# Any scripts or data that you put into this service are public.
#
# precrec documentation built on Oct. 12, 2023, 1:06 a.m.