R/pl3_create_confmats.R

#
# Calculate confusion matrices for all ranks
#
# fmdat      : formatted data; when missing it is built from 'scores' and
#              'labels' by .create_src_obj() (which is handed reformat_data)
# keep_fmdat : if TRUE the formatted data is stored in the "src" attribute,
#              otherwise "src" is set to NA
# ...        : forwarded to .create_src_obj() and recorded in the "args"
#              attribute
#
# Returns a 'cmats' S3 object (the list produced by the C++ routine without
# its "errmsg" item), after it has been passed through .validate().
#
create_confmats <- function(fmdat, scores = NULL, labels = NULL,
                            keep_fmdat = FALSE, ...) {

  # --- Input: build the formatted data if needed, then validate it ---
  fmdat <- .create_src_obj(fmdat, "fmdat", reformat_data, scores, labels, ...)
  .validate(fmdat)

  # --- Confusion matrices for every rank, computed in C++ via Rcpp ---
  cpp_out <- create_confusion_matrices(fmdat[["labels"]], fmdat[["ranks"]],
                                       fmdat[["rank_idx"]])
  .check_cpp_func_error(cpp_out, "create_confusion_matrices")

  # The C++ error message is moved out of the item list and kept as an
  # attribute instead
  errmsg <- cpp_out[["errmsg"]]
  cpp_out[["errmsg"]] <- NULL
  obj <- structure(cpp_out, class = "cmats")

  # --- Attributes; list() evaluates its elements left to right ---
  new_attrs <- list(modname = attr(fmdat, "modname"),
                    dsid = attr(fmdat, "dsid"),
                    nn = attr(fmdat, "nn"),
                    np = attr(fmdat, "np"),
                    args = list(...),
                    cpp_errmsg = errmsg,
                    src = if (keep_fmdat) fmdat else NA,
                    validated = FALSE)
  for (attr_name in names(new_attrs)) {
    attr(obj, attr_name) <- new_attrs[[attr_name]]
  }

  # Dispatches to .validate.cmats()
  .validate(obj)
}

#
# Validate 'cmats' object generated by create_confmats()
#
# Returns 'cmats' unchanged when it already carries validated = TRUE.
# Otherwise it runs the checks below, sets the "validated" attribute to TRUE
# and returns the object. Any failed check stops with an error.
#
.validate.cmats <- function(cmats) {
  # Need to validate only once. isTRUE() keeps this guard from erroring when
  # the "validated" attribute is missing (NULL) or NA; such objects simply
  # fall through to the full checks below.
  if (inherits(cmats, "cmats") && isTRUE(attr(cmats, "validated"))) {
    return(cmats)
  }

  # Validate class items and attributes (checks implemented by
  # .validate_basic(), defined elsewhere)
  item_names <- c("pos_num", "neg_num", "tp", "fp", "tn", "fn", "ranks")
  attr_names <- c("modname", "dsid", "nn", "np", "args", "cpp_errmsg",
                  "src", "validated")
  arg_names <- c("na_worst", "na.last", "ties_method", "ties.method",
                 "modname", "dsid", "keep_fmdat")
  .validate_basic(cmats, "cmats", "create_confmats", item_names, attr_names,
                  arg_names)

  # Check values of class items
  n <- length(cmats[["tp"]])
  if (length(cmats[["fp"]]) != n || length(cmats[["tn"]]) != n
      || length(cmats[["fn"]]) != n) {
    stop("tp, fp, tn, and fn in cmats must be all the same lengths",
         call. = FALSE)
  }
  # The first/last-element checks below index [1] and [n]; with n == 0 they
  # would yield NA/empty comparisons and an opaque assertion error
  if (n == 0) {
    stop("tp, fp, tn, and fn in cmats must not be empty", call. = FALSE)
  }

  # TP: numeric vector that starts at 0 and ends at pos_num
  assertthat::assert_that(is.atomic(cmats[["tp"]]),
                          is.vector(cmats[["tp"]]),
                          is.numeric(cmats[["tp"]]),
                          cmats[["tp"]][1] == 0,
                          cmats[["tp"]][n] == cmats[["pos_num"]])

  # FP: numeric vector that starts at 0 and ends at neg_num
  assertthat::assert_that(is.atomic(cmats[["fp"]]),
                          is.vector(cmats[["fp"]]),
                          is.numeric(cmats[["fp"]]),
                          cmats[["fp"]][1] == 0,
                          cmats[["fp"]][n] == cmats[["neg_num"]])

  # FN: numeric vector that starts at pos_num and ends at 0
  assertthat::assert_that(is.atomic(cmats[["fn"]]),
                          is.vector(cmats[["fn"]]),
                          is.numeric(cmats[["fn"]]),
                          cmats[["fn"]][1] == cmats[["pos_num"]],
                          cmats[["fn"]][n] == 0)

  # TN: numeric vector that starts at neg_num and ends at 0
  assertthat::assert_that(is.atomic(cmats[["tn"]]),
                          is.vector(cmats[["tn"]]),
                          is.numeric(cmats[["tn"]]),
                          cmats[["tn"]][1] == cmats[["neg_num"]],
                          cmats[["tn"]][n] == 0)

  attr(cmats, "validated") <- TRUE
  cmats
}
guillermozbta/precrec documentation built on May 11, 2019, 7:22 p.m.