R/tool_interface.R

Defines functions create_example_func .validate_create_usrtool_args create_usrtool .validate_create_toolset_args .rename_tool_names .create_toolobjs .prepare_init create_toolset

Documented in create_example_func create_toolset create_usrtool

#' Create a set of tools
#'
#' The \code{create_toolset} function takes the names of predefined tools
#'   and/or predefined tool sets and generates a list of \code{R6} tool
#'   objects that wrap Precision-Recall curve calculations.
#'
#' @param tool_names A character vector to specify the names of performance
#'   evaluation tools. Names are matched case-insensitively and may be
#'   abbreviated as long as the abbreviation is unambiguous. The following
#'   five tools are currently available.
#'
#'   \itemize{
#'     \item ROCR
#'     \item AUCCalculator
#'     \item PerfMeas
#'     \item PRROC
#'     \item precrec
#'   }
#'
#' @param set_names A character vector to specify predefined set names
#'   (matched case-insensitively). The following six sets are currently
#'   available.
#'
#'   \describe{
#'     \item{"def5"}{A set of 5 tools with \code{calc_auc = TRUE}
#'       and \code{store_res = TRUE}}
#'     \item{"auc5"}{A set of 5 tools with \code{calc_auc = TRUE}
#'       and \code{store_res = FALSE}}
#'     \item{"crv5"}{A set of 5 tools with \code{calc_auc = FALSE}
#'       and \code{store_res = TRUE}}
#'     \item{"def4"}{A set of 4 tools with \code{calc_auc = TRUE}
#'       and \code{store_res = TRUE}}
#'     \item{"auc4"}{A set of 4 tools with \code{calc_auc = TRUE}
#'       and \code{store_res = FALSE}}
#'     \item{"crv4"}{A set of 4 tools with \code{calc_auc = FALSE}
#'       and \code{store_res = TRUE}}
#'   }
#'
#' @param calc_auc A Boolean value to specify whether the AUC score should be
#'   calculated. It applies to the tools given by \code{tool_names}; the
#'   predefined sets use the settings listed above.
#'
#' @param store_res A Boolean value to specify whether the calculated curve is
#'   retrieved and stored. Like \code{calc_auc}, it applies to the tools given
#'   by \code{tool_names}.
#'
#' @return A named list of \code{R6} tool objects. Repeated tool names are
#'   made unique with a numeric suffix, e.g. \code{"ROCR.2"}.
#'
#' @seealso \code{\link{run_benchmark}} and \code{\link{run_evalcurve}} require
#'  the list of tools generated by this function.
#'  \code{\link{ToolROCR}}, \code{\link{ToolAUCCalculator}},
#'  \code{\link{ToolPerfMeas}}, \code{\link{ToolPRROC}}, and
#'  \code{\link{Toolprecrec}} are the underlying R6 tool classes.
#'
#' @examples
#' ## Create ROCR and precrec
#' toolset1 <- create_toolset(c("ROCR", "precrec"))
#' toolset1
#'
#' ## Create auc5 tools
#' toolset2 <- create_toolset(set_names = "auc5")
#' toolset2
#'
#' @export
create_toolset <- function(tool_names = NULL, set_names = NULL, calc_auc = TRUE,
                           store_res = TRUE) {
  # Normalise tool/set names and check the flags before building anything
  checked <- .validate_create_toolset_args(
    tool_names, set_names, calc_auc, store_res
  )

  # Expand the names into constructor parameters, then instantiate the tools
  .create_toolobjs(
    .prepare_init(
      checked$tool_names, checked$set_names,
      checked$calc_auc, checked$store_res
    )
  )
}

#
# Prepare init data for tool classes
#
# Expands individually named tools (`tool_names`) and predefined tool sets
# (`set_names`) into three parallel components. They are returned as an
# unnamed list that callers index positionally:
#   [[1]] tool name suffixes (the class looked up is "Tool<name>")
#   [[2]] the set name each tool belongs to
#   [[3]] one argument list per tool for its `$new()` constructor
# Tools from `tool_names` come first, followed by the tools of each set in
# the given order. `calc_auc`/`store_res` are used only for `tool_names`;
# each predefined set derives its flags from its "def"/"auc"/"crv" prefix.
# An unrecognised set name is an error.
#
.prepare_init <- function(tool_names, set_names, calc_auc, store_res) {
  new_tool_names <- NULL
  new_set_names <- NULL
  new_init_params <- NULL

  # Individually named tools: each one forms its own set named after the tool
  if (!is.null(tool_names)) {
    new_tool_names <- tool_names
    new_set_names <- tool_names
    new_init_params <- lapply(seq_along(tool_names), function(i) {
      list(
        calc_auc = calc_auc,
        store_res = store_res,
        setname = tool_names[i]
      )
    })
  }

  # Predefined sets: the digit suffix selects the tools, the prefix the flags.
  # Both are resolved per iteration (or fail) so no value leaks between sets.
  for (sname in set_names) {
    if (grepl("5$", sname)) {
      ntnames <- c("ROCR", "AUCCalculator", "PerfMeas", "PRROC", "precrec")
    } else if (grepl("4$", sname)) {
      ntnames <- c("ROCR", "AUCCalculator", "PerfMeas", "precrec")
    } else {
      stop("Invalid set_names: '", sname, "'", call. = FALSE)
    }

    if (grepl("^crv", sname)) {
      new_calc_auc <- FALSE
      new_store_res <- TRUE
    } else if (grepl("^auc", sname)) {
      new_calc_auc <- TRUE
      new_store_res <- FALSE
    } else if (grepl("^def", sname)) {
      new_calc_auc <- TRUE
      new_store_res <- TRUE
    } else {
      stop("Invalid set_names: '", sname, "'", call. = FALSE)
    }

    nparams <- rep(
      list(list(
        calc_auc = new_calc_auc,
        store_res = new_store_res,
        setname = sname
      )),
      length(ntnames)
    )

    # In "auc5", PRROC additionally receives `curve = FALSE` (presumably to
    # skip curve generation when only the AUC is needed -- see ToolPRROC).
    # Locate PRROC by name rather than by a fixed position.
    if (sname == "auc5") {
      nparams[[match("PRROC", ntnames)]]$curve <- FALSE
    }

    new_tool_names <- c(new_tool_names, ntnames)
    new_set_names <- c(new_set_names, rep(sname, length(ntnames)))
    new_init_params <- c(new_init_params, nparams)
  }

  # Return updated names with parameters
  list(new_tool_names, new_set_names, new_init_params)
}

#
# Create tool objects
#
# Builds one tool object per entry of `init_data` (the positional list
# produced by .prepare_init): the generator "Tool<name>" is looked up by name
# and its `$new()` is called with the matching parameter list, if non-empty.
# The result is named by the tool names, with repeats disambiguated by
# .rename_tool_names().
#
.create_toolobjs <- function(init_data) {
  tool_names <- init_data[[1]]
  init_params <- init_data[[3]]

  # Instantiate the idx-th tool, forwarding its constructor parameters
  new_tool <- function(idx) {
    generator <- get(paste0("Tool", tool_names[idx]))
    params <- init_params[[idx]]
    if (length(params) > 0) {
      return(do.call(generator$new, params))
    }
    generator$new()
  }

  toolobjs <- lapply(seq_along(tool_names), new_tool)
  names(toolobjs) <- .rename_tool_names(tool_names)

  # An object whose set name equals its tool name takes its (possibly
  # suffixed, e.g. "ROCR.2") list name as its set name. The update is made
  # through the object's own set_setname() method, so the list is not
  # reassigned here.
  for (idx in seq_along(tool_names)) {
    obj <- toolobjs[[idx]]
    if (obj$get_toolname() == obj$get_setname()) {
      obj$set_setname(names(toolobjs)[[idx]])
    }
  }

  toolobjs
}

#
# Rename duplicated tool names
#
# The first occurrence of each name is kept; later occurrences receive a
# running suffix, e.g. c("A", "A", "B", "A") -> c("A", "A.2", "B", "A.3").
#
.rename_tool_names <- function(tool_names) {
  # Per-name counter of how many occurrences have been seen so far
  seen <- integer(0)

  for (pos in which(duplicated(tool_names))) {
    key <- as.character(tool_names[pos])
    previous <- if (key %in% names(seen)) seen[[key]] else 1L
    seen[[key]] <- previous + 1L
    tool_names[pos] <- paste0(key, ".", seen[[key]])
  }

  tool_names
}

#
# Validate arguments and return updated arguments
#
# - `tool_names` are lower-cased and partially matched (pmatch) against the
#   five supported tools. Each is replaced with its canonical spelling, so
#   "rocr" and "ROC" both become "ROCR". An ambiguous or unknown name is an
#   error.
# - `set_names` are lower-cased and must all be predefined set names.
# - `calc_auc` and `store_res` must be single TRUE/FALSE values.
# At least one of `tool_names` and `set_names` must be non-NULL.
# Returns a list with elements tool_names, set_names, calc_auc and store_res.
#
.validate_create_toolset_args <- function(tool_names, set_names, calc_auc,
                                          store_res) {
  if (is.null(tool_names) && is.null(set_names)) {
    stop("Invalid tool_names and/or set_names", call. = FALSE)
  }

  new_tool_names <- NULL
  if (!is.null(tool_names)) {
    # Index i of this lookup table is the canonical name for pmatch result i
    canonical_names <- c(
      "ROCR", "AUCCalculator", "PerfMeas", "PRROC", "precrec"
    )
    tool_names <- tolower(tool_names)
    for (tname in tool_names) {
      assertthat::assert_that(assertthat::is.string(tname))
      idx <- pmatch(tname, tolower(canonical_names))
      if (is.na(idx)) {
        stop("Invalid tool_names", call. = FALSE)
      }
      new_tool_names <- c(new_tool_names, canonical_names[idx])
    }
  }

  if (!is.null(set_names)) {
    set_names <- tolower(set_names)
    valid_set_names <- c("def5", "auc5", "crv5", "def4", "auc4", "crv4")
    if (!all(set_names %in% valid_set_names)) {
      # Message is built from adjacent literals so it stays on one line
      stop(
        "Invalid set_names. Valid set_names are 'def5', 'auc5', 'crv5', ",
        "'def4', 'auc4', or 'crv4'.",
        call. = FALSE
      )
    }
  }

  assertthat::assert_that(assertthat::is.flag(calc_auc))
  assertthat::assert_that(assertthat::is.flag(store_res))

  list(
    tool_names = new_tool_names, set_names = set_names, calc_auc = calc_auc,
    store_res = store_res
  )
}

#' Create a user-defined tool
#'
#' The \code{create_usrtool} function creates an \code{R6} tool object for a
#'   user-defined Precision-Recall curve calculation. The tool either wraps a
#'   user-supplied function (\code{func}) or uses pre-calculated recall and
#'   precision values (\code{x} and \code{y}).
#'
#' @param tool_name A single string to specify the name of a user-defined tool.
#'
#' @param func A function to calculate a Precision-Recall curve and the AUC. It
#'   should take an element of the test dataset generated by
#'   \code{\link{create_testset}} as its only argument. It also should return a
#'   list with three elements - 'x', 'y', and 'auc' - that represent calculated
#'   recall and precision values plus the AUC score. Each element may also be a
#'   function without arguments that returns the value.
#'   See \code{\link{create_example_func}} for an example. If \code{func} is
#'   omitted, \code{x} and \code{y} must be numeric.
#'
#' @param calc_auc A Boolean value to specify whether the AUC score should be
#'   calculated.
#'
#' @param store_res A Boolean value to specify whether the calculated curve is
#'   retrieved and stored.
#'
#' @param x Pre-calculated recall values. Used together with \code{y} when
#'   neither of them contains \code{NA}.
#'
#' @param y Pre-calculated precision values. Used together with \code{x} when
#'   neither of them contains \code{NA}.
#'
#' @return A list containing a single \code{R6} tool object, named after
#'   \code{tool_name}.
#'
#' @seealso \code{\link{create_toolset}} to create a predefined tool set.
#'  \code{\link{create_testset}} for \code{testset}.
#'  \code{\link{create_example_func}} to create an example function.
#'
#' @examples
#' ## Create a new tool interface called "xyz"
#' efunc <- create_example_func()
#' toolset1 <- create_usrtool("xyz", efunc)
#' toolset1
#'
#' ## Example function with a correct argument
#' testset <- create_usrdata("bench", scores = c(0.1, 0.2), labels = c(1, 0))
#' retf <- efunc(testset[[1]])
#' retf
#'
#' @export
create_usrtool <- function(tool_name, func, calc_auc = TRUE, store_res = TRUE,
                           x = NA, y = NA) {
  # Validate arguments (a missing `func` stays missing in the validator)
  new_args <- .validate_create_usrtool_args(
    tool_name, func, calc_auc,
    store_res, x, y
  )

  # Without a user function the tool gets no wrapper (NA)
  usr_wrapper <- NA
  if (is.function(new_args$func)) {
    # The user function may return plain values or zero-argument accessors
    value_of <- function(v) if (is.function(v)) v() else v

    usr_wrapper <- function(testset, calc_auc, store_res) {
      # Calculate Precision-Recall curve
      preds <- new_args$func(testset)

      # AUC is evaluated first, then x and y, only when requested
      aucscore <- if (calc_auc) value_of(preds$auc) else NA

      if (store_res) {
        list(x = value_of(preds$x), y = value_of(preds$y), auc = aucscore)
      } else {
        NULL
      }
    }
  }

  # Create a tool interface class derived from ToolIFBase
  tool_cls <- R6::R6Class(
    paste0("Tool", new_args$tool_name),
    inherit = ToolIFBase,
    private = list(
      toolname = new_args$tool_name, f_wrapper = usr_wrapper,
      helpfile = FALSE
    )
  )

  # Constructor parameters for this single tool (its own set name)
  init_params <- .prepare_init(
    new_args$tool_name, NULL, new_args$calc_auc,
    new_args$store_res
  )[[3]][[1]]

  # Pre-calculated values are passed on only when both are fully specified
  if (!anyNA(x) && !anyNA(y)) {
    init_params$x <- x
    init_params$y <- y
  }

  obj <- list(do.call(tool_cls$new, init_params))
  names(obj) <- new_args$tool_name

  obj
}

#
# Validate arguments and return updated arguments
#
# - `tool_name` must be a single string.
# - If `func` is missing, `x` and `y` must be numeric and `func` becomes NA.
# - Otherwise `func` must be a one-argument function. It is run once on a
#   small benchmark test set and must return a list named exactly
#   x, y, auc (in that order); if that call fails, a descriptive error is
#   raised.
# - `calc_auc` and `store_res` must be single TRUE/FALSE values in both cases.
# Returns a list with elements tool_name, func, calc_auc, store_res, x and y.
#
.validate_create_usrtool_args <- function(tool_name, func, calc_auc,
                                          store_res, x, y) {
  assertthat::assert_that(assertthat::is.string(tool_name))

  if (missing(func)) {
    assertthat::assert_that(is.numeric(x))
    assertthat::assert_that(is.numeric(y))
    func <- NA
  } else {
    assertthat::assert_that(is.function(func))
    if (length(formals(func)) != 1) {
      stop("Invalid func. It must contain only one argument.", call. = FALSE)
    }

    testset <- create_usrdata("bench", scores = c(0.1, 0.2), labels = c(1, 0))
    single_testset <- testset[[1]]

    # Run func exactly once. Any error becomes a descriptive one; warnings
    # propagate to the caller as usual.
    func_ret <- tryCatch(
      func(single_testset),
      error = function(e) {
        stop("Invalid func. It failed with func(single_testset).",
          call. = FALSE
        )
      }
    )

    # identical() avoids recycling, so e.g. duplicated names cannot pass
    if (!is.list(func_ret) ||
      !identical(names(func_ret), c("x", "y", "auc"))) {
      stop("Invalid func. It must return list(x, y, auc).", call. = FALSE)
    }
  }

  # The flags are used in both modes, so check them unconditionally
  assertthat::assert_that(assertthat::is.flag(calc_auc))
  assertthat::assert_that(assertthat::is.flag(store_res))

  list(
    tool_name = tool_name, func = func, calc_auc = calc_auc,
    store_res = store_res, x = x, y = y
  )
}

#' Create an example for the func argument of the create_usrtool function
#'
#' The \code{create_example_func} function returns a minimal example of a
#'  function that can be passed as the \code{func} argument of
#'  \code{\link{create_usrtool}}.
#'
#' @return A function that takes a single test dataset and returns a list with
#'  three elements: \code{x} and \code{y} (identical, evenly spaced values from
#'  0 to 1 with step \code{1 / length(scores)}) and \code{auc} (always 0.5).
#'
#' @seealso \code{\link{create_usrtool}} requires the same format.
#'   \code{\link{create_testset}} for \code{testset}.
#'
#' @examples
#' ## Create a function
#' func <- create_example_func()
#' func
#'
#' @export
create_example_func <- function() {
  function(single_testset) {
    # Dummy curve: one step of 1 / n over [0, 1] for n scores
    n_scores <- length(single_testset$get_scores())
    grid <- seq(0, 1, by = 1 / n_scores)

    list(x = grid, y = grid, auc = 0.5)
  }
}

Try the prcbench package in your browser

Any scripts or data that you put into this service are public.

prcbench documentation built on March 31, 2023, 5:27 p.m.