#' Create a set of tools
#'
#' The \code{create_toolset} function takes names of predefined tools and
#' generates a list of wrapper functions for Precision-Recall curve
#' calculations.
#'
#' @param tool_names A character vector to specify the names of
#' performance evaluation tools. The names for the following five tools can be
#' currently used.
#'
#' \itemize{
#' \item ROCR
#' \item AUCCalculator
#' \item PerfMeas
#' \item PRROC
#' \item precrec
#' }
#'
#' @param set_names A character vector to specify predefined set names.
#' The following six sets are currently available.
#'
#' \describe{
#' \item{"def5"}{A set of 5 tools with \code{calc_auc = TRUE}
#' and \code{store_res = TRUE}}
#' \item{"auc5"}{A set of 5 tools with \code{calc_auc = TRUE}
#' and \code{store_res = FALSE}}
#' \item{"crv5"}{A set of 5 tools with \code{calc_auc = FALSE}
#' and \code{store_res = TRUE}}
#' \item{"def4"}{A set of 4 tools with \code{calc_auc = TRUE}
#' and \code{store_res = TRUE}}
#' \item{"auc4"}{A set of 4 tools with \code{calc_auc = TRUE}
#' and \code{store_res = FALSE}}
#' \item{"crv4"}{A set of 4 tools with \code{calc_auc = FALSE}
#' and \code{store_res = TRUE}}
#' }
#'
#' @param calc_auc A Boolean value to specify whether the AUC score should be
#' calculated.
#'
#' @param store_res A Boolean value to specify whether the calculated curve is
#' retrieved and stored.
#'
#' @return A list of \code{R6} tool objects.
#'
#' @seealso \code{\link{run_benchmark}} and \code{\link{run_evalcurve}} require
#' the list of the tools generated by this function.
#' See \code{\link{ToolROCR}}, \code{\link{ToolAUCCalculator}},
#' \code{\link{ToolPerfMeas}}, \code{\link{ToolPRROC}}, and
#' \code{\link{Toolprecrec}} for the R6 tool classes.
#'
#' @examples
#' ## Create ROCR and precrec
#' toolset1 <- create_toolset(c("ROCR", "precrec"))
#' toolset1
#'
#' ## Create auc5 tools
#' toolset2 <- create_toolset(set_names = "auc5")
#' toolset2
#'
#' @export
create_toolset <- function(tool_names = NULL, set_names = NULL, calc_auc = TRUE,
                           store_res = TRUE) {
  # Check the arguments; tool names come back in their canonical spelling and
  # set names in lower case
  checked <- .validate_create_toolset_args(
    tool_names, set_names, calc_auc, store_res
  )

  # Expand the names into per-tool initialization data and instantiate one
  # tool object for each entry
  .create_toolobjs(
    .prepare_init(
      checked$tool_names, checked$set_names,
      checked$calc_auc, checked$store_res
    )
  )
}
#
# Prepare init data for tool classes
#
.prepare_init <- function(tool_names, set_names, calc_auc, store_res) {
  # Build the initialization data consumed by .create_toolobjs().
  #
  # Args:
  #   tool_names: character vector of tool names, or NULL.
  #   set_names: character vector of predefined set names, or NULL.
  #   calc_auc, store_res: flags applied to the individually named tools.
  #     Tools that come from a predefined set use the flags encoded in the
  #     set name instead.
  #
  # Returns an unnamed list of three parallel elements: the tool names, the
  # set name each tool belongs to, and a list of per-tool constructor
  # arguments. Stops with an error when a set name is not recognized.
  new_tool_names <- NULL
  new_set_names <- NULL
  new_init_params <- NULL

  # Individually named tools: each tool acts as its own set name
  if (!is.null(tool_names)) {
    new_tool_names <- tool_names
    new_set_names <- tool_names
    new_init_params <- lapply(seq_along(tool_names), function(i) {
      list(
        calc_auc = calc_auc,
        store_res = store_res,
        setname = tool_names[i]
      )
    })
  }

  # Tools belonging to the predefined sets; the 4-tool set omits PRROC
  tools5 <- c("ROCR", "AUCCalculator", "PerfMeas", "PRROC", "precrec")
  tools4 <- c("ROCR", "AUCCalculator", "PerfMeas", "precrec")

  for (sname in set_names) {
    # The trailing digit of the set name selects the tool list
    if (grepl("5$", sname)) {
      ntnames <- tools5
    } else if (grepl("4$", sname)) {
      ntnames <- tools4
    } else {
      # Fail loudly instead of reusing values left over from a previous
      # iteration
      stop("Invalid set name: ", sname, call. = FALSE)
    }

    # The prefix of the set name selects the calc_auc/store_res flags
    if (grepl("^crv", sname)) {
      new_calc_auc <- FALSE
      new_store_res <- TRUE
    } else if (grepl("^auc", sname)) {
      new_calc_auc <- TRUE
      new_store_res <- FALSE
    } else if (grepl("^def", sname)) {
      new_calc_auc <- TRUE
      new_store_res <- TRUE
    } else {
      stop("Invalid set name: ", sname, call. = FALSE)
    }

    nparams <- lapply(seq_along(ntnames), function(i) {
      list(
        calc_auc = new_calc_auc,
        store_res = new_store_res,
        setname = sname
      )
    })

    # In "auc5" the fourth tool (PRROC) additionally gets curve = FALSE
    if (sname == "auc5") {
      nparams[[4]]$curve <- FALSE
    }

    new_tool_names <- c(new_tool_names, ntnames)
    new_set_names <- c(new_set_names, rep(sname, length(ntnames)))
    new_init_params <- c(new_init_params, nparams)
  }

  # Return updated names with parameters
  list(new_tool_names, new_set_names, new_init_params)
}
#
# Create tool objects
#
.create_toolobjs <- function(init_data) {
  # Instantiate one tool object per entry of the data built by .prepare_init().
  # init_data[[1]] holds the tool names and init_data[[3]] the matching
  # constructor arguments; init_data[[2]] (set names) is not used here.
  tool_names <- init_data[[1]]
  init_params <- init_data[[3]]

  # Look up the generator named "Tool<name>" and call its constructor, with
  # arguments only when some were prepared
  build_tool <- function(tname, params) {
    tool_cls <- get(paste0("Tool", tname))
    if (length(params) == 0) {
      tool_cls$new()
    } else {
      do.call(tool_cls$new, params)
    }
  }

  toolobjs <- lapply(
    seq_along(tool_names),
    function(i) build_tool(tool_names[i], init_params[[i]])
  )

  # Repeated tools get a numeric suffix so that list names are unique
  names(toolobjs) <- .rename_tool_names(tool_names)

  # A tool whose set name equals its tool name takes its list name as the
  # set name
  for (i in seq_along(toolobjs)) {
    tool <- toolobjs[[i]]
    if (tool$get_toolname() == tool$get_setname()) {
      tool$set_setname(names(toolobjs)[[i]])
    }
  }

  toolobjs
}
#
# Rename duplicated tool names
#
.rename_tool_names <- function(tool_names) {
  # Make repeated tool names unique: the first occurrence is kept as is, and
  # later occurrences become "<name>.2", "<name>.3", ... in order of appearance.
  seen <- list() # last suffix handed out, keyed by tool name
  dup_positions <- which(duplicated(tool_names))

  for (pos in dup_positions) {
    key <- as.character(tool_names[pos])
    suffix <- if (key %in% names(seen)) seen[[key]] + 1 else 2
    seen[[key]] <- suffix
    tool_names[pos] <- paste0(key, ".", suffix)
  }

  tool_names
}
#
# Validate arguments and return updated arguments
#
.validate_create_toolset_args <- function(tool_names, set_names, calc_auc,
                                          store_res) {
  # Validate the arguments of create_toolset() and return them as a list.
  #
  # tool_names are matched case-insensitively, and unambiguous prefixes are
  # accepted (via pmatch); they are returned in their canonical spelling.
  # set_names are matched case-insensitively and returned in lower case.
  # calc_auc and store_res must be single Boolean flags.
  #
  # Stops with an error when both tool_names and set_names are NULL, or when
  # any name is not recognized.
  if (is.null(tool_names) && is.null(set_names)) {
    stop("Invalid tool_names and/or set_names", call. = FALSE)
  }

  new_tool_names <- NULL
  if (!is.null(tool_names)) {
    canonical <- c("ROCR", "AUCCalculator", "PerfMeas", "PRROC", "precrec")
    candidates <- tolower(canonical)
    for (tname in tolower(tool_names)) {
      assertthat::assert_that(assertthat::is.string(tname))
      idx <- pmatch(tname, candidates)
      # pmatch gives NA for both unknown and ambiguous names
      if (is.na(idx)) {
        stop("Invalid tool_names", call. = FALSE)
      }
      new_tool_names <- c(new_tool_names, canonical[idx])
    }
  }

  if (!is.null(set_names)) {
    set_names <- tolower(set_names)
    valid_set_names <- c("def5", "auc5", "crv5", "def4", "auc4", "crv4")
    if (!all(set_names %in% valid_set_names)) {
      # Assemble the message from pieces so that it carries no line break
      # from the source layout
      stop(
        "Invalid set_names. Valid set_names are 'def5', 'auc5', 'crv5', ",
        "'def4', 'auc4', or 'crv4'.",
        call. = FALSE
      )
    }
  }

  assertthat::assert_that(assertthat::is.flag(calc_auc))
  assertthat::assert_that(assertthat::is.flag(store_res))

  list(
    tool_names = new_tool_names, set_names = set_names, calc_auc = calc_auc,
    store_res = store_res
  )
}
#' Create a user-defined tool interface
#'
#' The \code{create_usrtool} function takes a name and a function for
#' Precision-Recall curve calculations and generates a tool interface as a
#' list that contains a single \code{R6} tool object.
#'
#' @param tool_name A single string to specify the name of a user-defined tool.
#'
#' @param func A function to calculate a Precision-Recall curve and the AUC. It
#' should take an element of the test dataset generated by
#' \code{\link{create_testset}} as an argument. It also should return a list
#' with three elements - 'x', 'y', and 'auc' that represent calculated recall
#' and precision values plus the AUC score.
#' See \code{\link{create_example_func}} for an example.
#'
#' @param calc_auc A Boolean value to specify whether the AUC score should be
#' calculated.
#'
#' @param store_res A Boolean value to specify whether the calculated curve is
#' retrieved and stored.
#'
#' @param x Set pre-calculated recall values.
#'
#' @param y Set pre-calculated precision values.
#'
#' @return A list of \code{R6} tool objects.
#'
#' @seealso \code{\link{create_toolset}} to create a predefined tool set.
#' \code{\link{create_testset}} for \code{testset}.
#' \code{\link{create_example_func}} to create an example function.
#'
#' @examples
#' ## Create a new tool interface called "xyz"
#' efunc <- create_example_func()
#' toolset1 <- create_usrtool("xyz", efunc)
#' toolset1
#'
#' ## Example function with a correct argument
#' testset <- create_usrdata("bench", scores = c(0.1, 0.2), labels = c(1, 0))
#' retf <- efunc(testset[[1]])
#' retf
#'
#' @export
create_usrtool <- function(tool_name, func, calc_auc = TRUE, store_res = TRUE,
                           x = NA, y = NA) {
  # Validate arguments; func comes back as NA when it was not supplied
  new_args <- .validate_create_usrtool_args(
    tool_name, func, calc_auc, store_res, x, y
  )

  # A result element may be a plain value or a zero-argument function that
  # produces the value
  resolve <- function(elem) {
    if (is.function(elem)) elem() else elem
  }

  # Wrap the user function; without a user function the wrapper stays NA
  usr_wrapper <- NA
  if (is.function(new_args$func)) {
    usr_wrapper <- function(testset, calc_auc, store_res) {
      # Calculate Precision-Recall curve
      preds <- new_args$func(testset)

      # The AUC is only resolved when requested
      aucscore <- if (calc_auc) resolve(preds$auc) else NA

      # Nothing is returned unless the result should be stored
      if (!store_res) {
        return(NULL)
      }
      list(x = resolve(preds$x), y = resolve(preds$y), auc = aucscore)
    }
  }

  # Define an R6 class for the new tool on top of the common base class
  tool_cls <- R6::R6Class(
    paste0("Tool", new_args$tool_name),
    inherit = ToolIFBase,
    private = list(
      toolname = new_args$tool_name,
      f_wrapper = usr_wrapper,
      helpfile = FALSE
    )
  )

  # Constructor arguments for the single tool
  ctor_args <- .prepare_init(
    new_args$tool_name, NULL, new_args$calc_auc, new_args$store_res
  )[[3]][[1]]

  # Pass pre-calculated values along only when neither x nor y contains NA
  if (!(any(is.na(x)) || any(is.na(y)))) {
    ctor_args$x <- x
    ctor_args$y <- y
  }

  # Return a one-element list named after the tool
  structure(
    list(do.call(tool_cls$new, ctor_args)),
    names = new_args$tool_name
  )
}
#
# Validate arguments and return updated arguments
#
.validate_create_usrtool_args <- function(tool_name, func, calc_auc,
                                          store_res, x, y) {
  # Validate the arguments of create_usrtool() and return them as a list.
  #
  # tool_name must be a single string. When func is missing, x and y must be
  # numeric and func is returned as NA. Otherwise func must be a function with
  # exactly one argument that runs on a small example test set and returns a
  # list whose names are exactly "x", "y", and "auc". calc_auc and store_res
  # must be single Boolean flags in both cases.
  assertthat::assert_that(assertthat::is.string(tool_name))

  if (missing(func)) {
    assertthat::assert_that(is.numeric(x))
    assertthat::assert_that(is.numeric(y))
    func <- NA
  } else {
    assertthat::assert_that(is.function(func))
    if (length(formals(func)) != 1) {
      stop("Invalid func. It must contain only one argument.", call. = FALSE)
    }

    # Try the function once on a two-point example test set
    testset <- create_usrdata("bench", scores = c(0.1, 0.2), labels = c(1, 0))
    single_testset <- testset[[1]]

    # Evaluate func a single time; any error is reported as an invalid func,
    # whereas warnings are left to propagate to the caller
    func_ret <- tryCatch(
      func(single_testset),
      error = function(e) {
        stop(
          "Invalid func. It failed with func(single_testset).",
          call. = FALSE
        )
      }
    )

    # identical() avoids the recycling of an element-wise == comparison, which
    # would accept a longer list that repeats the three names
    if (!is.list(func_ret) ||
      !identical(names(func_ret), c("x", "y", "auc"))) {
      stop("Invalid func. It must return list(x, y, auc).", call. = FALSE)
    }
  }

  # The flags are checked whether or not func was supplied
  assertthat::assert_that(assertthat::is.flag(calc_auc))
  assertthat::assert_that(assertthat::is.flag(store_res))

  list(
    tool_name = tool_name, func = func, calc_auc = calc_auc,
    store_res = store_res, x = x, y = y
  )
}
#' Create an example for the func argument of the create_usrtool function
#'
#' The \code{create_example_func} function creates an example for the
#' \code{\link{create_usrtool}} function.
#'
#' @return A function as an example for \code{\link{create_usrtool}}
#'
#' @seealso \code{\link{create_usrtool}} requires the same format.
#' \code{\link{create_testset}} for \code{testset}.
#'
#' @examples
#' ## Create a function
#' func <- create_example_func()
#' func
#'
#' @export
create_example_func <- function() {
  # The returned function takes one element of a test set and produces a
  # dummy result in the list(x, y, auc) format
  function(single_testset) {
    # Spacing of the grid: one step per score
    step <- 1 / length(single_testset$get_scores())
    grid <- seq(0, 1, by = step)

    # x and y share the same evenly spaced values; the AUC is a constant
    list(x = grid, y = grid, auc = 0.5)
  }
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.