R/pl4_calc_measures.R

Defines functions .validate.pevals calc_measures

#
# Calculate basic evaluation measures from confusion matrices
#
# cmats:          a 'cmats' object. When it is missing, one is created from
#                 'scores' and 'labels' (and '...') with create_confmats(),
#                 via .create_src_obj().
# scores, labels: forwarded to create_confmats() when 'cmats' is missing.
# ...:            extra arguments; also stored in the "args" attribute.
#
# Returns an S3 object of class "pevals" whose only item, "basic", holds the
# measures returned by calc_basic_measures() plus the "score" and "label"
# vectors added here. The object is passed through .validate() (expected
# to dispatch to .validate.pevals()) and that result is returned.
#
calc_measures <- function(cmats, scores = NULL, labels = NULL, ...) {
  # --- Input: build 'cmats' from scores/labels if needed, then check it ---
  cmats <- .create_src_obj(
    cmats, "cmats", create_confmats, scores, labels, ...
  )
  .validate(cmats)

  # --- Measures: computed by a C++ routine through the Rcpp interface ---
  n_pos <- attr(cmats, "np")
  n_neg <- attr(cmats, "nn")
  pevals <- calc_basic_measures(
    n_pos, n_neg,
    cmats[["tp"]], cmats[["fp"]], cmats[["tn"]], cmats[["fn"]]
  )
  .check_cpp_func_error(pevals, "calc_basic_measures")

  # --- Result object ---
  s3obj <- structure(pevals["basic"], class = "pevals")
  attr(s3obj, "modname") <- attr(cmats, "modname")

  # Score/label columns. When source data is available they follow the
  # order given by src$rank_idx, with a leading NA (presumably to line up
  # with an extra first point in the measure vectors -- verify against
  # calc_basic_measures). Otherwise both are all-NA, one per rank.
  src <- attr(cmats, "src")
  if (!all(is.na(src))) {
    ord <- src[["rank_idx"]]
    # Factor codes 1/2 become -1/1 (assumes a two-level label factor whose
    # second level is the positive class -- confirm against the caller)
    lab <- as.numeric(src[["labels"]])[ord] - 1
    lab[lab == 0] <- -1
    score_col <- c(NA, src[["scores"]][ord])
    label_col <- c(NA, lab)
  } else {
    score_col <- label_col <- rep(NA, length(s3obj[["basic"]][["rank"]]))
  }
  s3obj[["basic"]][["score"]] <- score_col
  s3obj[["basic"]][["label"]] <- label_col

  attr(s3obj, "dsid") <- attr(cmats, "dsid")
  attr(s3obj, "nn") <- n_neg
  attr(s3obj, "np") <- n_pos
  attr(s3obj, "args") <- list(...)
  attr(s3obj, "cpp_errmsg") <- pevals[["errmsg"]]
  attr(s3obj, "src") <- cmats
  attr(s3obj, "validated") <- FALSE

  # Validate the new "pevals" object and return the validated result
  .validate(s3obj)
}

#
# Validate a 'pevals' object generated by calc_measures()
#
# Checks the expected items, attributes and arguments (via .validate_basic()),
# that every evaluation vector in x$basic has the same length, the vector
# types, and a few end-point invariants (first point: sensitivity 0 and
# specificity 1; last point: sensitivity 1 and specificity 0; error rate and
# accuracy summing to 1 at both ends).
#
# x: the object to validate. If it already has class "pevals" and a TRUE
#    "validated" attribute, it is returned unchanged.
#
# Returns x with its "validated" attribute set to TRUE. Stops with an error
# on the first failed check.
#
.validate.pevals <- function(x) {
  # Need to validate only once. isTRUE() keeps `&&` from erroring when the
  # "validated" attribute is missing (NULL) or not a single logical value;
  # such objects fall through to the full checks below instead.
  if (inherits(x, "pevals") && isTRUE(attr(x, "validated"))) {
    return(x)
  }

  # Validate class items and attributes
  item_names <- "basic"
  attr_names <- c(
    "modname", "dsid", "nn", "np", "args", "cpp_errmsg",
    "src", "validated"
  )
  arg_names <- c(
    "na_worst", "na.last", "ties.method", "ties_method",
    "modname", "dsid", "keep_fmdat"
  )
  .validate_basic(
    x, "pevals", "calc_measures", item_names, attr_names,
    arg_names
  )

  pb <- x[["basic"]]

  # Every evaluation vector must have the same length as the error rate
  n <- length(pb[["error"]])
  other_names <- c(
    "accuracy", "specificity", "sensitivity", "precision", "mcc",
    "fscore", "score", "label"
  )
  if (any(lengths(pb[other_names]) != n)) {
    stop("Evaluation vectors must be all the same lengths", call. = FALSE)
  }

  # Scores (may be all NA, so no numeric check)
  assertthat::assert_that(
    is.atomic(pb[["score"]]),
    is.vector(pb[["score"]])
  )

  # Labels (may be all NA, so no numeric check)
  assertthat::assert_that(
    is.atomic(pb[["label"]]),
    is.vector(pb[["label"]])
  )

  # Error rate
  assertthat::assert_that(
    is.atomic(pb[["error"]]),
    is.vector(pb[["error"]]),
    is.numeric(pb[["error"]])
  )

  # Accuracy
  assertthat::assert_that(
    is.atomic(pb[["accuracy"]]),
    is.vector(pb[["accuracy"]]),
    is.numeric(pb[["accuracy"]])
  )

  # Error rate & accuracy must sum to 1 at both end points. These are
  # floating-point values, so an exact `==` can fail from rounding error
  # alone. Compare within a tolerance instead. NA still fails the assertion.
  tol <- sqrt(.Machine$double.eps)
  assertthat::assert_that(
    abs(pb[["error"]][1] + pb[["accuracy"]][1] - 1) < tol,
    abs(pb[["error"]][n] + pb[["accuracy"]][n] - 1) < tol
  )

  # SP
  assertthat::assert_that(
    is.atomic(pb[["specificity"]]),
    is.vector(pb[["specificity"]]),
    is.numeric(pb[["specificity"]]),
    pb[["specificity"]][1] == 1,
    pb[["specificity"]][n] == 0
  )

  # SN
  assertthat::assert_that(
    is.atomic(pb[["sensitivity"]]),
    is.vector(pb[["sensitivity"]]),
    is.numeric(pb[["sensitivity"]]),
    pb[["sensitivity"]][1] == 0,
    pb[["sensitivity"]][n] == 1
  )

  # PREC: the first value must equal the second (presumably the undefined
  # first-point precision is filled from the next point -- confirm in
  # calc_basic_measures)
  assertthat::assert_that(
    is.atomic(pb[["precision"]]),
    is.vector(pb[["precision"]]),
    is.numeric(pb[["precision"]]),
    pb[["precision"]][1] == pb[["precision"]][2]
  )

  # Matthews correlation coefficient
  assertthat::assert_that(
    is.atomic(pb[["mcc"]]),
    is.vector(pb[["mcc"]]),
    is.numeric(pb[["mcc"]])
  )

  # F-score
  assertthat::assert_that(
    is.atomic(pb[["fscore"]]),
    is.vector(pb[["fscore"]]),
    is.numeric(pb[["fscore"]])
  )

  attr(x, "validated") <- TRUE
  x
}

Try the precrec package in your browser

Any scripts or data that you put into this service are public.

precrec documentation built on Oct. 12, 2023, 1:06 a.m.