#
# Reformat input data for Precision-Recall and ROC evaluation
#
# Builds a validated S3 object from raw scores and labels.
#   scores, labels: model scores and observed labels
#   modname, dsid:  model name and dataset ID, stored as attributes
#   posclass:       label value treated as the positive class (may be NULL)
#   na_worst, ties_method: control how scores are ranked
#   mode:           normalised by .pmatch_mode(); "aucroc" yields an 'sdat'
#                   object (scores + labels only), any other mode yields an
#                   'fmdat' object that also carries ranks and rank_idx
#   ...:            forwarded to the argument helpers; named arguments that
#                   reach .validate_reformat_data_args() are rejected
# Returns the object as returned by .validate().
#
reformat_data <- function(scores, labels,
                          modname = as.character(NA), dsid = 1L,
                          posclass = NULL, na_worst = TRUE,
                          ties_method = "equiv", mode = "rocprc", ...) {
  # === Validate input arguments ===
  new_ties_method <- .pmatch_tiesmethod(ties_method, ...)
  new_na_worst <- .get_new_naworst(na_worst, ...)
  new_mode <- .pmatch_mode(mode)
  .validate_reformat_data_args(scores, labels,
    modname = modname, dsid = dsid,
    posclass = posclass, na_worst = new_na_worst,
    ties_method = new_ties_method, mode = new_mode,
    ...
  )

  # === Reformat input data ===
  # Get a factor with "positive" and "negative"
  fmtlabs <- .factor_labels(labels, posclass, validate = FALSE)

  # Branch on the normalised mode (the value that was validated above),
  # not on the raw argument: .pmatch_mode() may rewrite the user's input,
  # so comparing the raw 'mode' could build the wrong object type.
  if (new_mode == "aucroc") {
    # === Create an S3 object ===
    s3obj <- structure(
      list(
        scores = scores,
        labels = fmtlabs[["labels"]]
      ),
      class = "sdat"
    )
  } else {
    # Get score ranks and sorted indices
    sranks <- .rank_scores(scores, new_na_worst, new_ties_method,
      validate = FALSE
    )
    ranks <- sranks[["ranks"]]
    rank_idx <- sranks[["rank_idx"]]

    # === Create an S3 object ===
    s3obj <- structure(
      list(
        scores = scores,
        labels = fmtlabs[["labels"]],
        ranks = ranks,
        rank_idx = rank_idx
      ),
      class = "fmdat"
    )
  }

  # Set attributes
  attr(s3obj, "modname") <- modname
  attr(s3obj, "dsid") <- dsid
  attr(s3obj, "nn") <- fmtlabs[["nn"]]
  attr(s3obj, "np") <- fmtlabs[["np"]]
  attr(s3obj, "args") <- list(
    posclass = posclass, na_worst = new_na_worst,
    ties_method = new_ties_method,
    modname = modname, dsid = dsid
  )
  attr(s3obj, "validated") <- FALSE

  # Presumably S3-dispatches to .validate.fmdat() / .validate.sdat()
  .validate(s3obj)
}
#
# Factor labels
#
# Converts 'labels' into the structure returned by format_labels(); its
# "labels", "nn" and "np" elements are read by reformat_data().
#   labels:   observed labels (a factor or an atomic vector)
#   posclass: value identifying the positive class, or NULL
#   validate: when TRUE, validate 'labels' and 'posclass' first
# Stops if 'posclass' is not a level of factor 'labels', or if its type
# differs from the type of 'labels'.
#
.factor_labels <- function(labels, posclass, validate = TRUE) {
  # === Validate input arguments ===
  if (validate) {
    .validate_labels(labels)
    .validate_posclass(posclass)
  }

  # Update posclass if necessary
  if (is.null(posclass)) {
    # NA is passed on to format_labels() when no positive class is given
    posclass <- NA
  } else if (is.factor(labels)) {
    # For factor labels, format_labels() receives the level index instead
    pos_idx <- which(levels(labels) == posclass)
    # An unknown level would give integer(0) and an obscure failure in the
    # type check below, so report it explicitly.
    if (length(pos_idx) == 0L) {
      stop("posclass must be one of the levels of labels", call. = FALSE)
    }
    posclass <- pos_idx
  }

  # Check the data type of posclass
  if (!is.na(posclass) && typeof(posclass) != typeof(labels[1])) {
    stop("posclass must be the same data type as labels", call. = FALSE)
  }

  # === Generate label factors ===
  flabels <- format_labels(labels, posclass)
  .check_cpp_func_error(flabels, "format_labels")
  flabels
}
#
# Rank scores
#
# Returns the list produced by get_score_ranks(); its "ranks" and
# "rank_idx" elements are read by reformat_data().
#   scores:      model scores
#   na_worst:    forwarded to get_score_ranks()
#   ties_method: forwarded to get_score_ranks()
#   validate:    when TRUE, check the three inputs above before ranking
#
.rank_scores <- function(scores, na_worst = TRUE, ties_method = "equiv",
                         validate = TRUE) {
  if (validate) {
    .validate_scores(scores)
    .validate_na_worst(na_worst)
    .validate_ties_method(ties_method)
  }

  ranked <- get_score_ranks(scores, na_worst, ties_method)
  # Presumably raises if the compiled routine reported an error
  .check_cpp_func_error(ranked, "get_score_ranks")
  ranked
}
#
# Validate arguments of reformat_data()
#
# Stops on the first invalid argument. If any argument captured by '...'
# is named, all names in '...' are reported as invalid. The per-argument
# checks are delegated to .validate_*() helpers; their order decides which
# error is reported first.
#
.validate_reformat_data_args <- function(scores, labels, modname, dsid,
                                         posclass, na_worst, ties_method,
                                         mode, ...) {
  # Reject stray named arguments
  unexpected <- names(list(...))
  if (!is.null(unexpected)) {
    msg <- paste0("Invalid arguments: ", paste(unexpected, collapse = ", "))
    stop(msg, call. = FALSE)
  }

  .validate_scores_and_labels(NULL, NULL, scores, labels)
  .validate_modname(modname)
  .validate_dsid(dsid)
  .validate_posclass(posclass)
  .validate_na_worst(na_worst)
  .validate_ties_method(ties_method)
  .validate_mode(mode)
}
#
# Validate 'fmdat' object generated by reformat_data()
#
# Returns 'x' with attribute "validated" set to TRUE, or stops if the
# object is malformed. An 'fmdat' object already flagged as validated is
# returned unchanged.
#
.validate.fmdat <- function(x) {
  # Need to validate only once. isTRUE() treats a missing (NULL) or NA
  # "validated" attribute as "not validated yet" instead of erroring
  # inside '&&'; .validate_basic() then reports what is missing.
  if (inherits(x, "fmdat") && isTRUE(attr(x, "validated"))) {
    return(x)
  }

  # Validate class items and attributes
  item_names <- c("scores", "labels", "ranks", "rank_idx")
  attr_names <- c("modname", "dsid", "nn", "np", "args", "validated")
  arg_names <- c("posclass", "na_worst", "ties_method", "modname", "dsid")
  .validate_basic(
    x, "fmdat", "reformat_data", item_names, attr_names,
    arg_names
  )

  # Check values of class items: every item (scores included) must be
  # non-empty and share the same length
  n_labels <- length(x[["labels"]])
  if (n_labels == 0 ||
    length(x[["scores"]]) != n_labels ||
    length(x[["ranks"]]) != n_labels ||
    length(x[["rank_idx"]]) != n_labels) {
    stop("List items in fmdat must be all the same lengths", call. = FALSE)
  }

  # Labels
  assertthat::assert_that(
    is.atomic(x[["labels"]]),
    is.vector(x[["labels"]]),
    is.numeric(x[["labels"]])
  )

  # Ranks
  assertthat::assert_that(
    is.atomic(x[["ranks"]]),
    is.vector(x[["ranks"]]),
    is.numeric(x[["ranks"]])
  )

  # Rank index
  assertthat::assert_that(
    is.atomic(x[["rank_idx"]]),
    is.vector(x[["rank_idx"]]),
    is.integer(x[["rank_idx"]])
  )

  attr(x, "validated") <- TRUE
  x
}
#
# Validate 'sdat' object generated by reformat_data()
#
# Returns 'x' with attribute "validated" set to TRUE, or stops if the
# object is malformed. An 'sdat' object already flagged as validated is
# returned unchanged.
#
.validate.sdat <- function(x) {
  # Need to validate only once. isTRUE() treats a missing (NULL) or NA
  # "validated" attribute as "not validated yet" instead of erroring
  # inside '&&'; .validate_basic() then reports what is missing.
  if (inherits(x, "sdat") && isTRUE(attr(x, "validated"))) {
    return(x)
  }

  # Validate class items and attributes
  item_names <- c("scores", "labels")
  attr_names <- c("modname", "dsid", "nn", "np", "args", "validated")
  arg_names <- c("posclass", "na_worst", "ties_method", "modname", "dsid")
  .validate_basic(
    x, "sdat", "reformat_data", item_names, attr_names,
    arg_names
  )

  # Check values of class items: labels must be non-empty and match scores
  n_labels <- length(x[["labels"]])
  if (n_labels == 0 || length(x[["scores"]]) != n_labels) {
    stop("List items in sdat must be all the same lengths", call. = FALSE)
  }

  # Labels
  assertthat::assert_that(
    is.atomic(x[["labels"]]),
    is.vector(x[["labels"]]),
    is.numeric(x[["labels"]])
  )

  attr(x, "validated") <- TRUE
  x
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.