#' @title metric
#'
#' @description Returns a metric function that can be used in the experiments
#' (especially in the cross-validation experiments) to compute the performance.
#'
#' @details
#' This function is a utility function to select performance metrics from the
#' `measures` R package and to reformat them into a form that is required
#' by the `mlexperiments` R package. For `mlexperiments`, it is required that
#' a metric function takes the two arguments `ground_truth` and `predictions`,
#' as well as additional named arguments that are necessary to compute the
#' performance and that are provided via the ellipsis argument (`...`).
#' When using the performance metric with an experiment of class
#' `"MLCrossValidation"`, such arguments can be defined as a list provided to
#' the field `performance_metric_args` of the R6 class.
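#' For example (a sketch; `positive` is an argument expected by the binary
#' classification metrics from the `measures` package):
#' \preformatted{
#' performance_metric_args = list(positive = "1")
#' }
#'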
#' The main purpose of `mlexperiments::metric()` is convenience and the
#' re-use of existing metric implementations. However, custom functions can
#' easily be used to compute the performance of the experiments, simply by
#' providing a function that takes the above-mentioned arguments and returns
#' a single performance metric value, as sketched below.
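#'
#' A minimal sketch of such a custom metric (the name `my_rmse` is
#' illustrative only):
#' \preformatted{
#' my_rmse <- function(ground_truth, predictions, ...) {
#'   sqrt(mean((ground_truth - predictions)^2))
#' }
#' }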
#'
#' @param name A metric name. Accepted names are the names of the metric
#' functions exported from the `measures` R package.
#'
#' @return Returns a function that can be used to calculate the performance
#' metric throughout the experiments.
#'
#' @examples
#' if (requireNamespace("measures", quietly = TRUE)) {
#'   metric("AUC")
#' }
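#'
#' # a sketch of calling the returned function; `RMSE` from the `measures`
#' # package only needs `truth` and `response`, so no additional arguments
#' # are passed via the ellipsis here
#' if (requireNamespace("measures", quietly = TRUE)) {
#'   rmse <- metric("RMSE")
#'   rmse(ground_truth = c(1.0, 2.0, 3.0), predictions = c(1.1, 1.9, 3.2))
#' }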
#'
#' @export
#'
metric <- function(name) {
  stopifnot(
    "`name` must be a character of length 1" = is.character(name) &&
      length(name) == 1L
)
if (!requireNamespace("measures", quietly = TRUE)) {
stop(
paste0(
"Package \"measures\" must be installed to use ",
"function 'metric()'."
),
call. = FALSE
)
}
  # test whether the name is available in a different casing
metrics_table <- .get_metrics_metadata()
if (!(name %in% metrics_table[, as.character(get("function_name"))])) {
# try upper case
name_try <- toupper(name)
if (name_try %in% metrics_table[, as.character(get("function_name"))]) {
name <- name_try
} else {
stop("`name` is not a function exported from R package {measures}")
}
}
FUN <- utils::getFromNamespace(x = name, ns = "measures") # nolint
fun_name <- paste0("measures::", name)
# get first two default-arguments
default_args <- formals(FUN)
default_args_names <- names(default_args)
first_two_default_args <- default_args_names[1:2]
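  # the class- and regression-based metrics in `measures` take
  # `(truth, response, ...)`, whereas the probability-based metrics take
  # `(probabilities, truth, ...)`; the first two formal arguments thus
  # reveal the argument name that the predictions must be mapped to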
if ("response" %in% first_two_default_args) {
response_name <- "response"
} else {
response_name <- "probabilities"
}
# compose function body
fun_body <- paste0(
"args <- list(\n",
" truth = ground_truth,\n",
" ",
response_name,
" = predictions\n",
") \n",
"fun_default_args <- c(",
paste0("\"", default_args_names, collapse = "\", "),
"\")\n",
"fun_default_args <- setdiff(fun_default_args, \"...\")\n",
"if (length(kwargs) > 0L) {\n",
" valid_kwargs_names <- intersect(fun_default_args, names(kwargs))\n",
" kwargs <- kwargs[which(names(kwargs) %in% valid_kwargs_names)]\n",
" if (length(kwargs) > 0L) {\n",
" args <- c(args, kwargs)\n}\n}\n",
"return(do.call(",
fun_name,
", args))"
)
fun <- paste0(
"function(ground_truth, predictions, ...) {\n",
"kwargs <- list(...)\n",
fun_body,
"\n}"
)
return(eval(parse(text = fun)))
}
#' @title metric_types_helper
#'
#' @description Prepares the data to conform to the requirements of the
#' metrics from `measures`.
#'
#' @param FUN A metric function, created with [mlexperiments::metric()].
#' @param y The outcome vector.
#' @param perf_args A list. The arguments to call the metric function with.
#'
#' @details
#' The `measures` R package imposes restrictions on the data types of the
#' ground truth and the predictions, depending on the metric, i.e., on the
#' type of the task (regression or classification).
#' Thus, it is necessary to convert the inputs to the metric function
#' accordingly, which is done with this helper function.
#'
#' @return Returns the calculated performance measure.
#'
#' @examples
#' if (requireNamespace("measures", quietly = TRUE)) {
#'   set.seed(123)
#'   ground_truth <- sample(0:1, 100, replace = TRUE)
#'   predictions <- sample(0:1, 100, replace = TRUE)
#'   FUN <- metric("ACC")
#'
#'   perf_args <- list(
#'     ground_truth = ground_truth,
#'     predictions = predictions
#'   )
#'
#'   metric_types_helper(
#'     FUN = FUN,
#'     y = ground_truth,
#'     perf_args = perf_args
#'   )
#' }
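#'
#' # a sketch of the binary case: class probabilities are converted
#' # internally to class labels (thresholded at 0.5) before `ACC`
#' # is computed
#' if (requireNamespace("measures", quietly = TRUE)) {
#'   set.seed(123)
#'   ground_truth <- sample(0:1, 100, replace = TRUE)
#'   probabilities <- runif(100)
#'   FUN <- metric("ACC")
#'
#'   metric_types_helper(
#'     FUN = FUN,
#'     y = ground_truth,
#'     perf_args = list(
#'       ground_truth = ground_truth,
#'       predictions = probabilities
#'     )
#'   )
#' }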
#'
#' @export
#'
metric_types_helper <- function(FUN, y, perf_args) { # nolint
stopifnot(
"`FUN` must be a function" = is.function(FUN),
"`perf_args` must be a list" = is.list(perf_args),
"`perf_args` must contain named elements `ground_truth` and `predictions`" = all(
c("ground_truth", "predictions") %in% names(perf_args)
)
)
# note that this is very specific to the measures package
if (!requireNamespace("measures", quietly = TRUE)) {
stop(
paste0(
"Package \"measures\" must be installed to use ",
"function 'metric_types_helper()'."
),
call. = FALSE
)
}
  # recover the name of the wrapped `measures` function from the deparsed
  # body of `FUN` (the composed function body contains `measures::<name>,`)
  pat <- ".*measures::(.*),.*"
fun_line <- grep(
pattern = pat,
x = deparse(FUN),
value = TRUE
)
error <- FALSE
  # if a function from the `measures` package was found, apply the
  # type-conversion helpers below
if (length(fun_line) > 0) {
fun_name <- gsub(pattern = pat, replacement = "\\1", x = fun_line)
metric_metadata <- .metric_metadata(fun_name = fun_name)
if (fun_name == "PPV") {
# fix wrong metadata-list entry from measures::listAllMeasures
metric_metadata$probabilities <- FALSE
}
if (fun_name %in% c("ACC", "MMCE", "BER")) {
# infer binary cases
if (.test_binary_y_and_pred_probs(y, perf_args)) {
if (.test_pred_probs_value_boundaries(perf_args)) {
metric_metadata$binary <- TRUE
}
}
}
# fix binary metrics here
# logic for conversion of probabilities to classes in case of binary
# classification
if (
.test_binary_y_and_pred_probs(y, perf_args) &&
isTRUE(metric_metadata$binary) &&
isFALSE(metric_metadata$probabilities)
) {
      # now test value boundaries
if (.test_pred_probs_value_boundaries(perf_args)) {
if (!is.factor(y)) {
# convert to factor
y <- factor(y)
}
lvls <- levels(y)
if ("positive" %in% names(perf_args)) {
val_positive <- perf_args$positive
val_negative <- setdiff(lvls, val_positive)
} else {
if ("0" %in% lvls && "1" %in% lvls) {
val_positive <- "1"
val_negative <- "0"
} else {
stop("Argument 'pos_level' is missing.")
}
}
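        # threshold the probabilities at 0.5 to obtain class labels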
perf_args$predictions <- ifelse(
test = perf_args$predictions > 0.5,
yes = val_positive,
no = val_negative
)
perf_args$predictions <- factor(
x = perf_args$predictions,
levels = lvls
)
if (any(is.na(perf_args$predictions))) {
error <- TRUE
}
}
}
}
tryCatch(
expr = {
      if (isTRUE(error)) {
        stop("An error happened preparing the response for the binary metric.")
      }
return(do.call(FUN, perf_args))
},
error = function(e) {
if (
grepl(
pattern = "Assertion on 'truth' failed: Must be of type 'factor'",
x = e
)
) {
        # convert to factor; derive the levels from `y`, since `lvls` is
        # only defined if the binary-conversion branch above was entered
        perf_args$ground_truth <- factor(
          x = perf_args$ground_truth,
          levels = levels(factor(y))
        )
error <- FALSE
} else if (
grepl(
pattern = paste0(
"Assertion on 'response' failed: Must have length ",
"\\d+, but has length \\d+\\."
),
x = e
)
) {
msg <- paste0(
"An error occurred... Try to use 'predict_args <- list(",
"reshape = TRUE)'"
)
stop(paste0(msg, "\n", e))
}
if (isFALSE(error)) {
        # retry with the converted arguments (recursive execution)
args <- list(
FUN = FUN,
y = y,
perf_args = perf_args
)
return(do.call(metric_types_helper, args))
} else {
stop(e)
}
}
)
}
.metric_from_char <- function(metric_vector) {
sapply(
X = metric_vector,
FUN = function(x) {
metric(x)
},
USE.NAMES = TRUE,
simplify = FALSE
)
}
.get_metrics_metadata <- function() {
metrics_table <- measures::listAllMeasures() |>
data.table::data.table()
return(metrics_table)
}
.metric_metadata <- function(fun_name) {
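  # default metadata: assume a non-probability, non-binary metric with an
  # unknown task; refined below based on measures::listAllMeasures()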
outlist <- list(
probabilities = FALSE,
task = FALSE,
binary = FALSE
)
if (requireNamespace("measures", quietly = TRUE)) {
metrics_table <- .get_metrics_metadata()
metric_row <- metrics_table[get("function_name") == fun_name, ]
if (metric_row[, .N] == 1) {
if (grepl(pattern = "classification", x = metric_row$task)) {
outlist[["task"]] <- "classification"
} else if (grepl(pattern = "regression", x = metric_row$task)) {
outlist[["task"]] <- "regression"
}
if (outlist[["task"]] == "classification") {
if (grepl(pattern = "binary", x = metric_row$task)) {
outlist[["binary"]] <- TRUE
}
outlist[["probabilities"]] <- metric_row$probabilities
}
}
}
return(outlist)
}
.test_binary_y_and_pred_probs <- function(y, perf_args) {
  # a binary outcome combined with more than two distinct prediction values
  # suggests that the predictions are probabilities
  return(
    length(unique(y)) <= 2 &&
      length(unique(perf_args$predictions)) > 2
  )
}
.test_pred_probs_value_boundaries <- function(perf_args) {
  # predictions that all lie within [0, 1] are consistent with probabilities
  return(
    min(perf_args$predictions) >= 0 &&
      max(perf_args$predictions) <= 1
  )
}