R/mm3_reformat_data.R

Defines functions .validate.sdat .validate.fmdat .validate_reformat_data_args .rank_scores .factor_labels reformat_data

#
# Reformat input data for Precision-Recall and ROC evaluation
#
# Normalises and validates the arguments, encodes the labels with
# .factor_labels(), and builds one of two S3 objects:
#   - "sdat"  (normalised mode "aucroc"): scores and encoded labels only
#   - "fmdat" (any other mode): additionally the score ranks and rank
#     indices computed by .rank_scores()
# The object gets the attributes modname, dsid, nn, np (taken from the
# .factor_labels() result), args and validated, and is passed through
# .validate() (i.e. .validate.fmdat() / .validate.sdat()) before return.
#
# '...' is forwarded to .pmatch_tiesmethod(), .get_new_naworst() and
# .validate_reformat_data_args().
#
reformat_data <- function(scores, labels,
                          modname = as.character(NA), dsid = 1L,
                          posclass = NULL, na_worst = TRUE,
                          ties_method = "equiv", mode = "rocprc", ...) {
  # === Validate input arguments ===
  # Normalise the option values first, then validate the normalised values
  new_ties_method <- .pmatch_tiesmethod(ties_method, ...)
  new_na_worst <- .get_new_naworst(na_worst, ...)
  new_mode <- .pmatch_mode(mode)
  .validate_reformat_data_args(scores, labels,
    modname = modname, dsid = dsid,
    posclass = posclass, na_worst = new_na_worst,
    ties_method = new_ties_method, mode = new_mode,
    ...
  )

  # === Reformat input data ===
  # Encode the labels; the result supplies "labels", "nn" and "np"
  fmtlabs <- .factor_labels(labels, posclass, validate = FALSE)

  # Branch on the normalised mode. The raw 'mode' argument can differ from
  # what .pmatch_mode() returns, and the other options already use their
  # normalised values.
  if (new_mode == "aucroc") {
    # === Create an S3 object ===
    s3obj <- structure(
      list(
        scores = scores,
        labels = fmtlabs[["labels"]]
      ),
      class = "sdat"
    )
  } else {
    # Get score ranks and sorted indices
    sranks <- .rank_scores(scores, new_na_worst, new_ties_method,
      validate = FALSE
    )
    ranks <- sranks[["ranks"]]
    rank_idx <- sranks[["rank_idx"]]

    # === Create an S3 object ===
    s3obj <- structure(
      list(
        scores = scores,
        labels = fmtlabs[["labels"]],
        ranks = ranks,
        rank_idx = rank_idx
      ),
      class = "fmdat"
    )
  }

  # Set attributes
  attr(s3obj, "modname") <- modname
  attr(s3obj, "dsid") <- dsid
  attr(s3obj, "nn") <- fmtlabs[["nn"]]
  attr(s3obj, "np") <- fmtlabs[["np"]]
  attr(s3obj, "args") <- list(
    posclass = posclass, na_worst = new_na_worst,
    ties_method = new_ties_method,
    modname = modname, dsid = dsid
  )
  attr(s3obj, "validated") <- FALSE

  # Call .validate.fmdat() / .validate.sdat()
  .validate(s3obj)
}

#
# Factor labels
#
# Encode the observed labels with the compiled helper format_labels().
#
#   labels   : observed labels (vector or factor)
#   posclass : value identifying the positive class, or NULL (passed on to
#              format_labels() as NA)
#   validate : if TRUE, run .validate_labels() and .validate_posclass() first
#
# Returns the list produced by format_labels(); reformat_data() reads its
# "labels", "nn" and "np" elements. Stops if posclass is not a level of
# factor labels, or if its type differs from that of labels.
#
.factor_labels <- function(labels, posclass, validate = TRUE) {
  # === Validate input arguments ===
  if (validate) {
    .validate_labels(labels)
    .validate_posclass(posclass)
  }

  # Update posclass if necessary
  if (is.null(posclass)) {
    posclass <- NA
  } else if (is.factor(labels)) {
    # For factor labels, replace posclass by its position among the levels
    # so that its type matches the factor's integer codes (see the typeof
    # check below)
    level_pos <- which(levels(labels) == posclass)
    if (length(level_pos) == 0L) {
      # Without this, integer(0) reaches the && test below and fails with a
      # cryptic internal error
      stop("posclass must be one of the levels of labels", call. = FALSE)
    }
    posclass <- level_pos
  }

  # Check the data type of posclass
  if (!is.na(posclass) && typeof(posclass) != typeof(labels[1])) {
    stop("posclass must be the same data type as labels", call. = FALSE)
  }

  # === Generate label factors ===
  flabels <- format_labels(labels, posclass)
  .check_cpp_func_error(flabels, "format_labels")

  flabels
}

#
# Rank scores
#
# Wrapper around the compiled get_score_ranks(). When 'validate' is TRUE,
# scores, na_worst and ties_method are each checked before the call. The
# result is checked by .check_cpp_func_error() and returned unchanged;
# reformat_data() reads its "ranks" and "rank_idx" elements.
#
.rank_scores <- function(scores, na_worst = TRUE, ties_method = "equiv",
                         validate = TRUE) {
  # Optional argument checks
  if (validate) {
    .validate_scores(scores)
    .validate_na_worst(na_worst)
    .validate_ties_method(ties_method)
  }

  # Compute ranks in compiled code and surface any error it reported
  ranked <- get_score_ranks(scores, na_worst, ties_method)
  .check_cpp_func_error(ranked, "get_score_ranks")
  ranked
}

#
# Validate arguments of reformat_data()
#
# Stops if '...' carries any names (a list of only unnamed extras passes),
# then runs each argument validator in turn. The first validator that fails
# raises the error. The value of the last check (.validate_mode()) is
# returned.
#
.validate_reformat_data_args <- function(scores, labels, modname, dsid,
                                         posclass, na_worst, ties_method,
                                         mode, ...) {
  # Unexpected named arguments in '...'
  extra_names <- names(list(...))
  if (!is.null(extra_names)) {
    offending <- paste(extra_names, collapse = ", ")
    stop(paste0("Invalid arguments: ", offending), call. = FALSE)
  }

  # Scores and labels are checked together
  .validate_scores_and_labels(NULL, NULL, scores, labels)

  # Remaining options, in a fixed order
  .validate_modname(modname)
  .validate_dsid(dsid)
  .validate_posclass(posclass)
  .validate_na_worst(na_worst)
  .validate_ties_method(ties_method)
  .validate_mode(mode)
}

#
# Validate 'fmdat' object generated by reformat_data()
#
# The checks run in this order:
#   - .validate_basic() checks the expected items, attributes and args.
#   - labels, ranks and rank_idx must be non-empty and of equal length.
#   - labels and ranks must be plain numeric vectors.
#   - rank_idx must be a plain integer vector.
# On success the "validated" attribute is set to TRUE and the object is
# returned. An object already marked as validated is returned immediately.
#
.validate.fmdat <- function(x) {
  # Need to validate only once. isTRUE() makes a missing or non-TRUE
  # "validated" attribute fall through to the full checks, instead of
  # failing inside the && expression.
  if (inherits(x, "fmdat") && isTRUE(attr(x, "validated"))) {
    return(x)
  }

  # Validate class items and attributes
  item_names <- c("scores", "labels", "ranks", "rank_idx")
  attr_names <- c("modname", "dsid", "nn", "np", "args", "validated")
  arg_names <- c("posclass", "na_worst", "ties_method", "modname", "dsid")
  .validate_basic(
    x, "fmdat", "reformat_data", item_names, attr_names,
    arg_names
  )

  # Check values of class items
  if (length(x[["labels"]]) == 0 ||
    length(x[["labels"]]) != length(x[["ranks"]]) ||
    length(x[["labels"]]) != length(x[["rank_idx"]])) {
    stop("List items in fmdat must be all the same lengths", call. = FALSE)
  }

  # Labels
  assertthat::assert_that(
    is.atomic(x[["labels"]]),
    is.vector(x[["labels"]]),
    is.numeric(x[["labels"]])
  )

  # Ranks
  assertthat::assert_that(
    is.atomic(x[["ranks"]]),
    is.vector(x[["ranks"]]),
    is.numeric(x[["ranks"]])
  )

  # Rank index
  assertthat::assert_that(
    is.atomic(x[["rank_idx"]]),
    is.vector(x[["rank_idx"]]),
    is.integer(x[["rank_idx"]])
  )

  attr(x, "validated") <- TRUE
  x
}

#
# Validate 'sdat' object generated by reformat_data()
#
# The checks run in this order:
#   - .validate_basic() checks the expected items, attributes and args.
#   - labels and scores must be non-empty and of equal length.
#   - labels must be a plain numeric vector.
# On success the "validated" attribute is set to TRUE and the object is
# returned. An object already marked as validated is returned immediately.
#
.validate.sdat <- function(x) {
  # Need to validate only once. isTRUE() makes a missing or non-TRUE
  # "validated" attribute fall through to the full checks, instead of
  # failing inside the && expression.
  if (inherits(x, "sdat") && isTRUE(attr(x, "validated"))) {
    return(x)
  }

  # Validate class items and attributes
  item_names <- c("scores", "labels")
  attr_names <- c("modname", "dsid", "nn", "np", "args", "validated")
  arg_names <- c("posclass", "na_worst", "ties_method", "modname", "dsid")
  .validate_basic(
    x, "sdat", "reformat_data", item_names, attr_names,
    arg_names
  )

  # Check values of class items
  if (length(x[["labels"]]) == 0 ||
    length(x[["labels"]]) != length(x[["scores"]])) {
    stop("List items in sdat must be all the same lengths", call. = FALSE)
  }

  # Labels
  assertthat::assert_that(
    is.atomic(x[["labels"]]),
    is.vector(x[["labels"]]),
    is.numeric(x[["labels"]])
  )

  attr(x, "validated") <- TRUE
  x
}

Try the precrec package in your browser

Any scripts or data that you put into this service are public.

precrec documentation built on Oct. 12, 2023, 1:06 a.m.