R/mm3_reformat_data.R

Defines functions validate.sdat validate.fmdat validate_reformat_data_args rank_scores factor_labels reformat_data

#
# Reformat input data for Precision-Recall and ROC evaluation
#
# Validates the arguments and converts 'labels' with .factor_labels(). Then it
# builds one of two S3 objects:
#   * mode "aucroc": class 'sdat' with items 'scores' and 'labels'
#   * any other mode: class 'fmdat' with items 'scores', 'labels', 'ranks'
#     and 'rank_idx' (ranks come from .rank_scores())
# Both carry the attributes modname, dsid, nn, np, args and validated.
# The object is returned after passing through .validate().
#
reformat_data <- function(scores, labels,
                          modname = as.character(NA), dsid = 1L,
                          posclass = NULL, na_worst = TRUE,
                          ties_method = "equiv", mode = "rocprc", ...) {

  # === Validate input arguments ===
  # Resolve argument values first (the .pmatch_* helpers presumably expand
  # abbreviations). Only the resolved values are validated, used and stored.
  new_ties_method <- .pmatch_tiesmethod(ties_method, ...)
  new_na_worst <- .get_new_naworst(na_worst, ...)
  new_mode <- .pmatch_mode(mode)
  .validate_reformat_data_args(scores, labels, modname = modname, dsid = dsid,
                               posclass = posclass, na_worst = new_na_worst,
                               ties_method = new_ties_method, mode = new_mode,
                               ...)

  # === Reformat input data ===
  # Get a factor with "positive" and "negative"
  fmtlabs <- .factor_labels(labels, posclass, validate = FALSE)

  # Branch on the resolved mode. Comparing the raw 'mode' would send a
  # value that only resolves to "aucroc" down the ranking branch.
  if (new_mode == "aucroc") {
    # === Create an S3 object ===
    # AUC-only evaluation needs no score ranks
    s3obj <- structure(list(scores = scores,
                            labels = fmtlabs[["labels"]]),
                       class = "sdat")
  } else {
    # Get score ranks and sorted indices
    sranks <- .rank_scores(scores, new_na_worst, new_ties_method,
                           validate = FALSE)
    ranks <- sranks[["ranks"]]
    rank_idx <- sranks[["rank_idx"]]

    # === Create an S3 object ===
    s3obj <- structure(list(scores = scores,
                            labels = fmtlabs[["labels"]],
                            ranks = ranks,
                            rank_idx = rank_idx),
                       class = "fmdat")
  }

  # Set attributes
  # nn / np are taken from the formatted-labels result (presumably the
  # negative / positive counts -- confirm against format_labels())
  attr(s3obj, "modname") <- modname
  attr(s3obj, "dsid") <- dsid
  attr(s3obj, "nn") <- fmtlabs[["nn"]]
  attr(s3obj, "np") <- fmtlabs[["np"]]
  attr(s3obj, "args") <- list(posclass = posclass, na_worst = new_na_worst,
                              ties_method = new_ties_method,
                              modname = modname, dsid = dsid)
  attr(s3obj, "validated") <- FALSE

  # Call .validate.fmdat() / .validate.sdat()
  .validate(s3obj)
}

#
# Factor labels
#
# Converts 'labels' into the internal label representation via the
# format_labels() routine and returns its result. Callers read the
# "labels", "nn" and "np" elements of that result.
#
# 'posclass' selects the positive class:
#   * NULL is passed on to format_labels() as NA.
#   * For factor labels, posclass is matched against levels(labels) and
#     replaced by the index of that level. A factor's storage type is
#     integer, so the index type then matches typeof(labels[1]).
#   * Otherwise posclass must have the same storage type as 'labels'.
# Input checks run only when 'validate' is TRUE.
#
.factor_labels <- function(labels, posclass, validate = TRUE) {
  # === Validate input arguments ===
  if (validate) {
    .validate_labels(labels)
    .validate_posclass(posclass)
  }

  # Update posclass if necessary
  if (is.null(posclass)) {
    posclass <- NA
  } else if (is.factor(labels)) {
    posclass <- which(levels(labels) == posclass)
    # which() yields integer(0) when posclass is not a level. Without this
    # check, the '&&' below would receive a zero-length operand. Depending
    # on the R version, that either errors unhelpfully or lets an empty
    # posclass reach format_labels().
    if (length(posclass) != 1L) {
      stop("posclass must be one of the levels of labels", call. = FALSE)
    }
  }

  # Check the data type of posclass
  if (!is.na(posclass) && typeof(posclass) != typeof(labels[1])) {
    stop("posclass must be the same data type as labels", call. = FALSE)
  }

  # === Generate label factors ===
  flabels <- format_labels(labels, posclass)
  .check_cpp_func_error(flabels, "format_labels")

  flabels
}

#
# Rank scores
#
# Thin wrapper around get_score_ranks(), presumably compiled C++ given the
# .check_cpp_func_error() call. When 'validate' is TRUE the three inputs
# are checked first. Callers that have already validated them, such as
# reformat_data(), pass validate = FALSE. Returns whatever get_score_ranks()
# produced; callers read its "ranks" and "rank_idx" elements.
#
.rank_scores <- function(scores, na_worst = TRUE, ties_method = "equiv",
                         validate = TRUE) {

  # Optional input checks, in argument order
  if (validate) {
    .validate_scores(scores)
    .validate_na_worst(na_worst)
    .validate_ties_method(ties_method)
  }

  # Delegate ranking, then let the error helper inspect the result
  rank_info <- get_score_ranks(scores, na_worst, ties_method)
  .check_cpp_func_error(rank_info, "get_score_ranks")

  rank_info
}

#
# Validate arguments of reformat_data()
#
# Stops with an error on the first invalid argument; otherwise returns the
# result of the last check (.validate_mode()). Any *named* entry in '...' is
# rejected. Unnamed entries in '...' are not rejected here.
#
.validate_reformat_data_args <- function(scores, labels, modname, dsid,
                                         posclass, na_worst, ties_method,
                                         mode, ...) {

  # Check '...'
  # When named and unnamed extras are mixed, names() contains "" for the
  # unnamed ones. Drop those so the message lists only real names.
  extra_names <- names(list(...))
  if (!is.null(extra_names)) {
    bad_names <- extra_names[nzchar(extra_names)]
    stop(paste0("Invalid arguments: ", paste(bad_names, collapse = ", ")),
         call. = FALSE)
  }

  # Check scores and labels
  .validate_scores_and_labels(NULL, NULL, scores, labels)

  # Check model name
  .validate_modname(modname)

  # Check dataset ID
  .validate_dsid(dsid)

  # Check posclass
  .validate_posclass(posclass)

  # Check na_worst
  .validate_na_worst(na_worst)

  # Check ties_method
  .validate_ties_method(ties_method)

  # Check mode
  .validate_mode(mode)

}

#
# Validate 'fmdat' object generated by reformat_data()
#
# Returns 'fmdat' with its "validated" attribute set to TRUE, or stops if the
# object is malformed. An 'fmdat' already marked as validated is returned
# unchanged.
#
.validate.fmdat <- function(fmdat) {
  # Need to validate only once. isTRUE() tolerates a missing (NULL)
  # "validated" attribute; 'TRUE && NULL' would otherwise raise a type
  # error before .validate_basic() can report the real problem.
  if (inherits(fmdat, "fmdat") && isTRUE(attr(fmdat, "validated"))) {
    return(fmdat)
  }

  # Validate class items and attributes
  item_names <- c("scores", "labels", "ranks", "rank_idx")
  attr_names <- c("modname", "dsid", "nn", "np", "args", "validated")
  arg_names <- c("posclass", "na_worst", "ties_method", "modname", "dsid")
  .validate_basic(fmdat, "fmdat", "reformat_data", item_names, attr_names,
                  arg_names)

  # Check values of class items: all four items hold one entry per
  # observation (scores included, as in .validate.sdat())
  n_labels <- length(fmdat[["labels"]])
  if (n_labels == 0
      || length(fmdat[["scores"]]) != n_labels
      || length(fmdat[["ranks"]]) != n_labels
      || length(fmdat[["rank_idx"]]) != n_labels) {
    stop("List items in fmdat must be all the same lengths", call. = FALSE)
  }

  # Labels
  assertthat::assert_that(is.atomic(fmdat[["labels"]]),
                          is.vector(fmdat[["labels"]]),
                          is.numeric(fmdat[["labels"]]))

  # Ranks
  assertthat::assert_that(is.atomic(fmdat[["ranks"]]),
                          is.vector(fmdat[["ranks"]]),
                          is.numeric(fmdat[["ranks"]]))

  # Rank index
  assertthat::assert_that(is.atomic(fmdat[["rank_idx"]]),
                          is.vector(fmdat[["rank_idx"]]),
                          is.integer(fmdat[["rank_idx"]]))

  attr(fmdat, "validated") <- TRUE
  fmdat
}

#
# Validate 'sdat' object generated by reformat_data()
#
# Returns 'sdat' with its "validated" attribute set to TRUE, or stops if the
# object is malformed. An 'sdat' already marked as validated is returned
# unchanged.
#
.validate.sdat <- function(sdat) {
  # Need to validate only once. isTRUE() tolerates a missing (NULL)
  # "validated" attribute; 'TRUE && NULL' would otherwise raise a type
  # error before .validate_basic() can report the real problem.
  if (inherits(sdat, "sdat") && isTRUE(attr(sdat, "validated"))) {
    return(sdat)
  }

  # Validate class items and attributes
  item_names <- c("scores", "labels")
  attr_names <- c("modname", "dsid", "nn", "np", "args", "validated")
  arg_names <- c("posclass", "na_worst", "ties_method", "modname", "dsid")
  .validate_basic(sdat, "sdat", "reformat_data", item_names, attr_names,
                  arg_names)

  # Check values of class items: scores and labels must be non-empty and
  # of equal length
  if (length(sdat[["labels"]]) == 0
      || length(sdat[["labels"]]) != length(sdat[["scores"]])) {
    stop("List items in sdat must be all the same lengths", call. = FALSE)
  }

  # Labels
  assertthat::assert_that(is.atomic(sdat[["labels"]]),
                          is.vector(sdat[["labels"]]),
                          is.numeric(sdat[["labels"]]))

  attr(sdat, "validated") <- TRUE
  sdat
}
takayasaito/precrec documentation built on Aug. 24, 2017, 8:07 a.m.