R/g_part.R

Defines functions .gather_paucs_avg .gather_paucs .calc_pauc .prepare_part_calc part.mmcurves part.smcurves part.mscurves part.sscurves part.default part

Documented in part part.mmcurves part.mscurves part.smcurves part.sscurves

#' Calculate partial AUCs
#'
#' The \code{part} function takes an \code{S3} object generated by
#'   \code{\link{evalmod}} and calculates partial AUCs and standardized partial
#'   AUCs of ROC and Precision-Recall curves.
#'   Standardized pAUCs are standardized to the range between 0 and 1.
#'
#' @param curves An \code{S3} object generated by \code{\link{evalmod}}.
#'   The \code{part} function accepts the following S3 objects.
#'
#'   \tabular{lll}{
#'     \strong{\code{S3} object}
#'     \tab \strong{# of models}
#'     \tab \strong{# of test datasets} \cr
#'
#'     sscurves \tab single   \tab single   \cr
#'     mscurves \tab multiple \tab single   \cr
#'     smcurves \tab single   \tab multiple \cr
#'     mmcurves \tab multiple \tab multiple
#'   }
#'
#'    See the \strong{Value} section of \code{\link{evalmod}} for more details.
#'
#' @param xlim A numeric vector of length two to specify x range between
#'   two points in [0, 1]
#'
#' @param ylim A numeric vector of length two to specify y range between
#'   two points in [0, 1]
#'
#' @param curvetype A character vector with the following curve types.
#'   \tabular{ll}{
#'     \strong{curvetype} \tab \strong{description} \cr
#'     ROC \tab ROC curve \cr
#'     PRC \tab Precision-Recall curve
#'   }
#'   Multiple \code{curvetype} can be combined, such as
#'   \code{c("ROC", "PRC")}.
#'
#' @return The \code{part} function returns the same S3 object specified as
#'   input with calculated pAUCs and standardized pAUCs.
#'
#' @seealso \code{\link{evalmod}} for generating \code{S3} objects with
#'   performance evaluation measures. \code{\link{pauc}} for retrieving
#'   a dataset of pAUCs.
#'
#' @examples
#' \dontrun{
#'
#' ## Load library
#' library(ggplot2)
#'
#' ##################################################
#' ### Single model & single test dataset
#' ###
#'
#' ## Load a dataset with 10 positives and 10 negatives
#' data(P10N10)
#'
#' ## Generate an sscurve object that contains ROC and Precision-Recall curves
#' sscurves <- evalmod(scores = P10N10$scores, labels = P10N10$labels)
#'
#' ## Calculate partial AUCs
#' sscurves.part <- part(sscurves, xlim = c(0.25, 0.75))
#'
#' ## Show AUCs
#' sscurves.part
#'
#' ## Plot partial curve
#' plot(sscurves.part)
#'
#' ## Plot partial curve with ggplot
#' autoplot(sscurves.part)
#'
#'
#' ##################################################
#' ### Multiple models & single test dataset
#' ###
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(1, 100, 100, "all")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#'   modnames = samps[["modnames"]]
#' )
#'
#' ## Generate an mscurve object that contains ROC and Precision-Recall curves
#' mscurves <- evalmod(mdat)
#'
#' ## Calculate partial AUCs
#' mscurves.part <- part(mscurves, xlim = c(0, 0.75), ylim = c(0.25, 0.75))
#'
#' ## Show AUCs
#' mscurves.part
#'
#' ## Plot partial curves
#' plot(mscurves.part)
#'
#' ## Plot partial curves with ggplot
#' autoplot(mscurves.part)
#'
#'
#' ##################################################
#' ### Single model & multiple test datasets
#' ###
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(4, 100, 100, "good_er")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#'   modnames = samps[["modnames"]],
#'   dsids = samps[["dsids"]]
#' )
#'
#' ## Generate an smcurve object that contains ROC and Precision-Recall curves
#' smcurves <- evalmod(mdat)
#'
#' ## Calculate partial AUCs
#' smcurves.part <- part(smcurves, xlim = c(0.25, 0.75))
#'
#' ## Show AUCs
#' smcurves.part
#'
#' ## Plot partial curve
#' plot(smcurves.part)
#'
#' ## Plot partial curve with ggplot
#' autoplot(smcurves.part)
#'
#'
#' ##################################################
#' ### Multiple models & multiple test datasets
#' ###
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(4, 100, 100, "all")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#'   modnames = samps[["modnames"]],
#'   dsids = samps[["dsids"]]
#' )
#'
#' ## Generate an mmcurves object that contains ROC and Precision-Recall curves
#' mmcurves <- evalmod(mdat, raw_curves = TRUE)
#'
#' ## Calculate partial AUCs
#' mmcurves.part <- part(mmcurves, xlim = c(0, 0.25))
#'
#' ## Show AUCs
#' mmcurves.part
#'
#' ## Plot partial curves
#' plot(mmcurves.part)
#'
#' ## Plot partial curves with ggplot
#' autoplot(mmcurves.part)
#' }
#'
#' @export
part <- function(curves, xlim = NULL, ylim = NULL, curvetype = NULL) {
  # S3 generic: the method is selected from the class of the first
  # argument, `curves`.
  UseMethod("part")
}

#' @export
part.default <- function(curves, xlim = NULL, ylim = NULL, curvetype = NULL) {
  # Fallback method: reached when no class-specific `part` method matches
  # the class of `curves`. Always raises an error.
  msg <- "An object of unknown class is specified"
  stop(msg)
}

#' @rdname part
#' @export
part.sscurves <- function(curves, xlim = c(0, 1), ylim = c(0, 1),
                          curvetype = c("ROC", "PRC")) {
  # Partial AUCs are computed on the curves themselves, never on averages
  # (avg_only = FALSE).
  .prepare_part_calc(
    curves = curves,
    xlim = xlim,
    ylim = ylim,
    curvetype = curvetype,
    avg_only = FALSE
  )
}

#' @rdname part
#' @export
part.mscurves <- function(curves, xlim = c(0, 1), ylim = c(0, 1),
                          curvetype = c("ROC", "PRC")) {
  # Partial AUCs are computed on the curves themselves, never on averages
  # (avg_only = FALSE).
  .prepare_part_calc(
    curves = curves,
    xlim = xlim,
    ylim = ylim,
    curvetype = curvetype,
    avg_only = FALSE
  )
}

#' @rdname part
#' @export
part.smcurves <- function(curves, xlim = c(0, 1), ylim = c(0, 1),
                          curvetype = c("ROC", "PRC")) {
  # The `raw_curves` flag stored in the "args" attribute decides the target:
  # when it is TRUE the individual curves are processed, otherwise only the
  # average curves are (avg_only = TRUE).
  has_raw_curves <- attr(curves, "args")$raw_curves
  avg_only <- if (has_raw_curves) FALSE else TRUE
  .prepare_part_calc(curves, xlim, ylim, curvetype, avg_only)
}

#' @rdname part
#' @export
part.mmcurves <- function(curves, xlim = c(0, 1), ylim = c(0, 1),
                          curvetype = c("ROC", "PRC")) {
  # The `raw_curves` flag stored in the "args" attribute decides the target:
  # when it is TRUE the individual curves are processed, otherwise only the
  # average curves are (avg_only = TRUE).
  has_raw_curves <- attr(curves, "args")$raw_curves
  avg_only <- if (has_raw_curves) FALSE else TRUE
  .prepare_part_calc(curves, xlim, ylim, curvetype, avg_only)
}

#
# Validate arguments, compute partial AUCs, and attach them to `curves`
#
# curves:    curve object to update
# xlim:      x range passed on to .calc_pauc
# ylim:      y range passed on to .calc_pauc
# curvetype: requested curve types, resolved by .pmatch_curvetype_rocprc
# avg_only:  if TRUE, process the curves stored in the "grp_avg" attribute
#            instead of the elements of `curves`
#
# Returns `curves` with a "paucs" attribute and the "partial" attribute set
# to TRUE.
#
.prepare_part_calc <- function(curves, xlim, ylim, curvetype, avg_only) {
  # Validation
  .validate(curves)
  .check_limits(xlim, ylim)
  .check_curvetype(curvetype)
  selected_types <- .pmatch_curvetype_rocprc(curvetype)

  # Curve type -> element name; ROC is handled before PRC
  crv_slots <- c(ROC = "rocs", PRC = "prcs")

  for (ctype in names(crv_slots)) {
    if (!(ctype %in% selected_types)) {
      next
    }
    slot <- crv_slots[[ctype]]

    if (avg_only) {
      # Update the average curves kept in the "grp_avg" attribute
      grp_avg <- attr(curves, "grp_avg")
      grp_avg[[slot]] <- .calc_pauc(grp_avg[[slot]], xlim, ylim, avg_only)
      attr(curves, "grp_avg") <- grp_avg
    } else {
      curves[[slot]] <- .calc_pauc(curves[[slot]], xlim, ylim, avg_only)
    }

    # The limits are recorded on the curve list in both cases
    attr(curves[[slot]], "xlim") <- xlim
    attr(curves[[slot]], "ylim") <- ylim
  }

  # Summary table of pAUCs
  attr(curves, "paucs") <- if (avg_only) {
    .gather_paucs_avg(curves)
  } else {
    .gather_paucs(curves)
  }

  attr(curves, "partial") <- TRUE
  curves
}

#
# Compute the partial AUC (pAUC) and standardized pAUC of each curve
#
# curves:   list of curves; each element provides `x` and either `y` or,
#           when avg_only is TRUE, `y_avg`
# xlim:     numeric vector of length two, lower and upper bound on x
# ylim:     numeric vector of length two, lower and upper bound on y
# avg_only: if TRUE, read `y_avg` instead of `y`
#
# Returns `curves` with the attributes "pauc", "spauc", "xlim" and "ylim"
# set on every element.
#
.calc_pauc <- function(curves, xlim, ylim, avg_only) {
  # Force values into the closed interval [lims[1], lims[2]]
  clamp <- function(vals, lims) {
    vals[vals < lims[1]] <- lims[1]
    vals[vals > lims[2]] <- lims[2]
    vals
  }

  y_name <- if (avg_only) "y_avg" else "y"

  # Quantities that do not depend on the individual curve
  x_width <- xlim[2] - xlim[1]
  box_area <- x_width * (ylim[2] - ylim[1])
  full_range <- all(c(0, 1) == xlim) && all(c(0, 1) == ylim)

  for (k in seq_along(curves)) {
    crv <- curves[[k]]

    # Trim x and y to the requested window
    xs <- clamp(crv[["x"]], xlim)
    ys <- clamp(crv[[y_name]], ylim)

    # Calculate pAUC and standardized pAUC
    if (full_range) {
      # Whole unit square: reuse the AUC already stored on the curve
      area <- attr(crv, "auc")
      std_area <- area
    } else {
      area <- calc_auc(xs, ys)[["auc"]]
      if (ylim[1] != 0) {
        # Remove the strip below ylim[1]
        area <- area - (x_width * ylim[1])
      }
      std_area <- area / box_area
    }

    # Cap both scores at 1
    if (area > 1) {
      area <- 1
    }
    if (std_area > 1) {
      std_area <- 1
    }

    attr(crv, "pauc") <- area
    attr(crv, "spauc") <- std_area
    attr(crv, "xlim") <- xlim
    attr(crv, "ylim") <- ylim

    curves[[k]] <- crv
  }

  curves
}

#
# Collect pAUCs and standardized pAUCs of all curves into a data frame
#
# curves: curve object with an "aucs" attribute (providing `modnames`,
#         `dsids` and `curvetypes`) and the curve lists `rocs` and `prcs`
#
# Returns a data frame with columns modnames, dsids, curvetypes, paucs and
# spaucs. Rows are filled in pairs per curve index: ROC first, then PRC.
# A curve without a "pauc"/"spauc" attribute (its curve type was not
# processed) gets NA.
#
.gather_paucs <- function(curves) {
  # Read a score attribute; NA when it was never set. Without this guard
  # c(value, NULL) has length one and is recycled over both rows, so the
  # score of one curve type would be reported for the other as well.
  attr_or_na <- function(crv, name) {
    val <- attr(crv, name, exact = TRUE)
    if (is.null(val)) NA_real_ else val
  }

  # Collect AUCs of ROC or PRC curves
  ct_len <- 2
  aucs <- attr(curves, "aucs")
  n_rows <- length(aucs[["modnames"]])
  paucs <- data.frame(
    modnames = aucs[["modnames"]],
    dsids = aucs[["dsids"]],
    curvetypes = aucs[["curvetypes"]],
    paucs = rep(NA_real_, n_rows),
    spaucs = rep(NA_real_, n_rows),
    stringsAsFactors = FALSE
  )

  rocs <- curves[["rocs"]]
  prcs <- curves[["prcs"]]
  for (i in seq_along(rocs)) {
    # Rows 2i - 1 (ROC) and 2i (PRC) belong to the i-th curve
    rows <- c(ct_len * i - 1, ct_len * i)
    paucs[["paucs"]][rows] <- c(
      attr_or_na(rocs[[i]], "pauc"),
      attr_or_na(prcs[[i]], "pauc")
    )
    paucs[["spaucs"]][rows] <- c(
      attr_or_na(rocs[[i]], "spauc"),
      attr_or_na(prcs[[i]], "spauc")
    )
  }

  paucs
}

#
# Collect pAUCs and standardized pAUCs of the average curves
#
# curves: curve object whose "grp_avg" attribute holds the curve lists
#         `rocs` and `prcs`; model names are read from the "uniq_modnames"
#         attribute of `rocs`
#
# Returns a data frame with columns modnames, curvetypes, paucs and spaucs,
# two rows per model (ROC first, then PRC). A curve without a
# "pauc"/"spauc" attribute (its curve type was not processed) gets NA.
#
.gather_paucs_avg <- function(curves) {
  # Read a score attribute; NA when it was never set. Without this guard
  # c(value, NULL) has length one and is recycled over both rows, so the
  # score of one curve type would be reported for the other as well.
  attr_or_na <- function(crv, name) {
    val <- attr(crv, name, exact = TRUE)
    if (is.null(val)) NA_real_ else val
  }

  avg_crvs <- attr(curves, "grp_avg")
  avg_rocs <- avg_crvs[["rocs"]]
  avg_prcs <- avg_crvs[["prcs"]]

  # Collect AUCs of ROC or PRC curves
  ct_len <- 2
  modnames <- attr(avg_rocs, "uniq_modnames")
  n_rows <- length(modnames) * ct_len
  paucs <- data.frame(
    modnames = rep(modnames, each = ct_len),
    curvetypes = rep(c("ROC", "PRC"), length(modnames)),
    paucs = rep(NA_real_, n_rows),
    spaucs = rep(NA_real_, n_rows),
    stringsAsFactors = FALSE
  )

  for (i in seq_along(avg_rocs)) {
    # Rows 2i - 1 (ROC) and 2i (PRC) belong to the i-th model
    rows <- c(ct_len * i - 1, ct_len * i)
    paucs[["paucs"]][rows] <- c(
      attr_or_na(avg_rocs[[i]], "pauc"),
      attr_or_na(avg_prcs[[i]], "pauc")
    )
    paucs[["spaucs"]][rows] <- c(
      attr_or_na(avg_rocs[[i]], "spauc"),
      attr_or_na(avg_prcs[[i]], "spauc")
    )
  }

  paucs
}

Try the precrec package in your browser

Any scripts or data that you put into this service are public.

precrec documentation built on Oct. 12, 2023, 1:06 a.m.