R/pl3_create_confmats.R

Defines functions .validate.cmats create_confmats

#
# Build a 'cmats' S3 object holding confusion-matrix counts for all ranks
#
# fmdat      : formatted source data; when it is missing, .create_src_obj()
#              is asked to build it from 'scores' and 'labels' via
#              reformat_data()
# scores     : forwarded to .create_src_obj()
# labels     : forwarded to .create_src_obj()
# keep_fmdat : TRUE stores fmdat in the "src" attribute; otherwise "src"
#              is NA
# ...        : forwarded to .create_src_obj() and recorded verbatim in the
#              "args" attribute
#
# Returns the new object after it has been passed through .validate()
#
create_confmats <- function(fmdat, scores = NULL, labels = NULL,
                            keep_fmdat = FALSE, ...) {
  # Source data: build it if needed, then check it
  fmdat <- .create_src_obj(fmdat, "fmdat", reformat_data, scores, labels, ...)
  .validate(fmdat)

  # Per-rank counts are computed by the Rcpp routine; any error it reports
  # is surfaced immediately
  cm_list <- create_confusion_matrices(
    fmdat[["labels"]], fmdat[["ranks"]], fmdat[["rank_idx"]]
  )
  .check_cpp_func_error(cm_list, "create_confusion_matrices")

  # The C++ error message is moved out of the list items and kept as an
  # attribute instead
  errmsg_from_cpp <- cm_list[["errmsg"]]
  cm_list[["errmsg"]] <- NULL

  out <- structure(cm_list, class = "cmats")

  # Identifying attributes are copied straight from the source data
  for (attr_name in c("modname", "dsid", "nn", "np")) {
    attr(out, attr_name) <- attr(fmdat, attr_name)
  }
  attr(out, "args") <- list(...)
  attr(out, "cpp_errmsg") <- errmsg_from_cpp
  attr(out, "src") <- if (keep_fmdat) fmdat else NA
  # Marked unvalidated so that the checks below actually run
  attr(out, "validated") <- FALSE

  # .validate() is expected to dispatch to .validate.cmats()
  .validate(out)
}

#
# Validate a 'cmats' object generated by create_confmats()
#
# x : object expected to be of class "cmats"
#
# Checks required items, attributes and recorded args via .validate_basic().
# Then checks that tp, fp, tn and fn are non-empty numeric vectors of equal
# length whose end points agree with pos_num/neg_num:
#   tp: 0 -> pos_num    fp: 0 -> neg_num
#   fn: pos_num -> 0    tn: neg_num -> 0
# Returns x with attribute "validated" set to TRUE. An object already marked
# as validated is returned unchanged.
#
.validate.cmats <- function(x) {
  # Need to validate only once. isTRUE() keeps a missing or malformed
  # "validated" attribute from erroring inside `&&`; such objects fall
  # through to .validate_basic(), which reports the missing attribute.
  if (methods::is(x, "cmats") && isTRUE(attr(x, "validated"))) {
    return(x)
  }

  # Validate class items and attributes
  item_names <- c("pos_num", "neg_num", "tp", "fp", "tn", "fn", "ranks")
  attr_names <- c(
    "modname", "dsid", "nn", "np", "args", "cpp_errmsg",
    "src", "validated"
  )
  arg_names <- c(
    "na_worst", "na.last", "ties_method", "ties.method",
    "modname", "dsid", "keep_fmdat"
  )
  .validate_basic(
    x, "cmats", "create_confmats", item_names, attr_names,
    arg_names
  )

  # Check values of class items
  n <- length(x[["tp"]])
  if (length(x[["fp"]]) != n || length(x[["tn"]]) != n ||
    length(x[["fn"]]) != n) {
    stop("tp, fp, tn, and fn in cmats must be all the same lengths",
      call. = FALSE
    )
  }
  # The [1] / [n] boundary checks below need at least one element. Without
  # this guard they would compare NA or zero-length values and fail without
  # naming the real problem.
  if (n == 0) {
    stop("tp, fp, tn, and fn in cmats must not be empty", call. = FALSE)
  }

  # TP: starts at 0 and ends at the number of positives
  assertthat::assert_that(
    is.atomic(x[["tp"]]),
    is.vector(x[["tp"]]),
    is.numeric(x[["tp"]]),
    x[["tp"]][1] == 0,
    x[["tp"]][n] == x[["pos_num"]]
  )

  # FP: starts at 0 and ends at the number of negatives
  assertthat::assert_that(
    is.atomic(x[["fp"]]),
    is.vector(x[["fp"]]),
    is.numeric(x[["fp"]]),
    x[["fp"]][1] == 0,
    x[["fp"]][n] == x[["neg_num"]]
  )

  # FN: starts at the number of positives and ends at 0
  assertthat::assert_that(
    is.atomic(x[["fn"]]),
    is.vector(x[["fn"]]),
    is.numeric(x[["fn"]]),
    x[["fn"]][1] == x[["pos_num"]],
    x[["fn"]][n] == 0
  )

  # TN: starts at the number of negatives and ends at 0
  assertthat::assert_that(
    is.atomic(x[["tn"]]),
    is.vector(x[["tn"]]),
    is.numeric(x[["tn"]]),
    x[["tn"]][1] == x[["neg_num"]],
    x[["tn"]][n] == 0
  )

  attr(x, "validated") <- TRUE
  x
}
takayasaito/precrec documentation built on Oct. 19, 2023, 7:28 p.m.