#' @export
#' @title make a (multicore) mlr benchmark experiment of your learner(s) on your task(s)
#' @author Thomas Goossens
#' @import mlr
#' @import parallelMap
#' @importFrom magrittr %>%
#' @param tasks a list which elements are objects of class \code{mlr::makeRegrTask()}
#' @param learners a list which elements are objects of class \code{mlr::makeLearner()}
#' @param measures a list of the mlr performance measures you want to get. Default is \code{list(mlr::rmse)}
#' @param keep.pred a boolean specifying if you want to keep the bmr preds. Default = \code{TRUE}.
#' Necessary to further analyze benchmark performances
#' @param models a boolean specifying if you want to keep the bmr models. Default = \code{FALSE}
#' @param level a character specifying the parallelization level passed to
#' \code{parallelMap::parallelStart()} (only used when \code{cpus > 1}). Default = \code{"mlr.benchmark"}
#' @param resamplings a character specifying the type of mlr's resampling strategy, passed to
#' \code{mlr::makeResampleDesc()}. Default = \code{"LOO"}
#' @param cpus an integer specifying the number of cpus to use for the benchmark. Default is 4
#' @param temp_dir a character specifying the path of an existing directory where you want to save
#' the bmr temporary outputs. If \code{NULL} (default) nothing is written to disk.
#' @param prefix a character specifying the prefix you want to add to each bmr temporary file name.
#' @param groupSize a numeric specifying the number of tasks you want to benchmark in a single batch.
#' If \code{NULL} the value will be set to the length of \code{tasks} argument. Values larger than the
#' number of tasks are capped; the last batch may hold fewer tasks. Default is \code{NULL}
#' @param removeTemp a boolean specifying if the temporary .rds files generated by the function must be
#' deleted at the end of the process.
#' @param crash a boolean. \code{TRUE} if a failing learner must stop the benchmark (the error is then
#' reported in \code{output$condition}). \code{FALSE} if you want the function to handle the learner
#' error and continue. This is based on \code{mlr::configureMlr(on.learner.error = "warn")}.
#' Default to \code{FALSE}
#' @return A 2 elements named list
#' \itemize{
#' \item \code{snitch} : a boolean. Is \code{TRUE} if the benchmark was conducted (possibly with
#' warnings). Is \code{FALSE} if the function raised an error
#' \item \code{output} : a named list which elements are :\itemize{
#' \item \code{value} : an object returned by \code{mlr::benchmark()} (merged with
#' \code{mlr::mergeBenchmarkResults()} when several batches were run), or \code{NULL} on error.
#' \item \code{condition} : a named list with elements \code{type} (a character : success, warning,
#' or error) and \code{message} (a character describing the condition).
#' }
#' }
#' @details The function handles learners errors. See \code{mlr::configureMlr()}. mlr is set back to
#' \code{on.learner.error = "stop"} when the function exits, even on failure.
#' Tasks are benchmarked by batches of \code{groupSize} tasks; \code{set.seed(1985)} is called before
#' each batch to make the experiments reproducible.
#' When warnings are raised the benchmark is still run only once: the first warning message is
#' reported in \code{output$condition$message} and the warnings are still signalled to the caller.
#' When \code{temp_dir} is provided, only the files written by this call are read back and merged,
#' in batch order.
#' @examples
#'\dontrun{
#' # load magrittr for pipe use : %>%
#' library(magrittr)
#'
#' # create the dataset
#' myDataset = makeDataset(
#' dfrom = "2017-03-04T15:00:00Z",
#' dto = "2017-03-04T18:00:00Z",
#' sensor = "tsa")
#'
#' # extract the list of hourly sets of records
#' myDataset = myDataset$output$value
#'
#' # create the tasks
#' myTasks = purrr::map(myDataset, makeTask, target = "tsa")
#'
#' # extract the used sids of each task from the outputs
#' myUsedSids = myTasks %>% purrr::modify_depth(1, ~.$output$stations$used)
#'
#' # extract the tasks from the outputs
#' myTasks = myTasks %>% purrr::modify_depth(1, ~.$output$value$task)
#'
#' # Conduct a batch of benchmarks experiments without saving temp files
#' myBmrsBatch = makeBmrsBatch(
#' tasks = myTasks,
#' learners = agrometeorLearners,
#' measures = list(mlr::rmse),
#' keep.pred = TRUE,
#' models = FALSE,
#' groupSize = NULL,
#' level = "mlr.benchmark",
#' resamplings = "LOO",
#' cpus = 1,
#' prefix = NULL,
#' temp_dir = NULL,
#' removeTemp = FALSE,
#' crash = FALSE)
#'
#' # Keep the relevant information
#' myBmrsBatch = myBmrsBatch$output$value
#'
#' # make a plot from the myBmrsBatch
#' mlr::plotBMRBoxplots(myBmrsBatch,
#' measure = mlr::rmse,
#' order.lrn = getBMRLearnerIds(myBmrsBatch),
#' pretty.names = FALSE)
#' }
makeBmrsBatch <- function(
  tasks,
  learners,
  measures = list(rmse),
  keep.pred = TRUE,
  models = FALSE,
  level = "mlr.benchmark",
  resamplings = "LOO",
  cpus = 4,
  temp_dir = NULL,
  prefix = NULL,
  groupSize = NULL,
  removeTemp = FALSE,
  crash = FALSE) {
  output <- list(value = NULL, condition = list(type = NULL, message = NULL))
  snitch <- FALSE
  # First warning signalled while benchmarking (NULL when none was raised)
  firstWarning <- NULL

  # Benchmark tasks[start:end]. Returns the bmr object, or the path of the
  # .rds file it was written to when temp_dir is set.
  runBatch <- function(start, end) {
    message(paste0(
      "Conducting batch of benchmark experiments for tasks ",
      start, "-",
      end))
    # set seed to make bmr experiments reproducible
    set.seed(1985)
    # enable parallelization at the requested level; always stop it on exit,
    # even if mlr::benchmark() fails
    if (cpus > 1) {
      parallelMap::parallelStart(mode = "multicore", cpus = cpus, level = level)
      on.exit(parallelMap::parallelStop(), add = TRUE)
    }
    firstId <- mlr::getTaskId(tasks[[start]])
    lastId <- mlr::getTaskId(tasks[[end]])
    bmr <- mlr::benchmark(
      learners = learners,
      tasks = tasks[start:end],
      resamplings = mlr::makeResampleDesc(resamplings),
      measures = measures,
      keep.pred = keep.pred,
      models = models)
    message(paste0(
      "Results of batch of Benchmark experiments for tasks ",
      firstId, " - ",
      lastId, " conducted"))
    if (is.null(temp_dir)) {
      return(bmr)
    }
    # write the batch to disk so it does not have to stay in RAM
    path <- file.path(
      temp_dir,
      paste0(prefix, "_bmr_", firstId, "_", lastId, ".rds"))
    saveRDS(object = bmr, file = path)
    message(paste0(
      "Results of batch of Benchmark experiments for tasks ",
      firstId, " - ",
      lastId, " written to file. "))
    path
  }

  # Run every batch, then merge the results into a single bmr object.
  doBenchmark <- function() {
    # validate/split before touching global mlr settings
    groups <- computeTaskGroups(length(tasks), groupSize)
    # configure mlr to avoid crash if a learner fails on a task
    # https://stackoverflow.com/questions/55608882/how-to-make-the-benchmark-function-not-to-fail-if-a-specific-learner-fails-on-a
    if (!isTRUE(crash)) {
      mlr::configureMlr(on.learner.error = "warn", on.error.dump = TRUE)
    } else {
      mlr::configureMlr(on.learner.error = "stop", on.error.dump = FALSE)
    }
    # set mlr back to default settings whatever happens next
    on.exit(
      mlr::configureMlr(on.learner.error = "stop", on.error.dump = FALSE),
      add = TRUE)
    results <- Map(runBatch, groups$start, groups$end)
    if (!is.null(temp_dir)) {
      # read back only the files written by this call, in batch order
      bmrFiles <- unlist(results, use.names = FALSE)
      results <- lapply(bmrFiles, readRDS)
      if (isTRUE(removeTemp)) {
        file.remove(bmrFiles)
      }
    }
    bmr <- if (length(results) > 1) {
      mlr::mergeBenchmarkResults(results)
    } else {
      results[[1]]
    }
    message("Success, batch of benchmark experiment conducted")
    bmr
  }

  tryCatch(
    expr = {
      # check if output dir exists
      if (!is.null(temp_dir)) {
        stopifnot(dir.exists(temp_dir))
      }
      # record (without muffling) the first warning so the benchmark runs once
      output$value <- withCallingHandlers(
        doBenchmark(),
        warning = function(w) {
          if (is.null(firstWarning)) {
            firstWarning <<- w
          }
        })
      if (is.null(firstWarning)) {
        output$condition$type <- "success"
        output$condition$message <- "benchmarks conducted"
      } else {
        output$condition$type <- "warning"
        output$condition$message <- paste0(
          "AgrometeoR::makeBmrsBatch raised a warning -> ",
          conditionMessage(firstWarning))
      }
      snitch <- TRUE
    },
    error = function(e) {
      output$condition$type <<- "error"
      output$condition$message <<- paste0(
        "AgrometeoR::makeBmrsBatch raised an error -> ",
        conditionMessage(e))
    }
  )
  message(paste0(
    "makeBmrsBatch has encountered : ",
    output$condition$type,
    ". \n",
    "All done with makeBmrsBatch "
  ))
  list(snitch = snitch, output = output)
}

# Internal helper (not exported): split nTasks tasks into consecutive batches
# of at most groupSize tasks. Returns a list of integer vectors `start`/`end`
# (1-based, inclusive); the last batch holds the remainder.
computeTaskGroups <- function(nTasks, groupSize = NULL) {
  nTasks <- as.integer(nTasks)
  if (length(nTasks) != 1 || is.na(nTasks) || nTasks < 1L) {
    stop("`tasks` must contain at least one task", call. = FALSE)
  }
  if (is.null(groupSize) || groupSize > nTasks) {
    groupSize <- nTasks
  }
  groupSize <- as.integer(groupSize)
  if (length(groupSize) != 1 || is.na(groupSize) || groupSize < 1L) {
    stop("`groupSize` must be NULL or a number >= 1", call. = FALSE)
  }
  starts <- seq.int(from = 1L, to = nTasks, by = groupSize)
  ends <- pmin(starts + groupSize - 1L, nTasks)
  list(start = starts, end = ends)
}
# (Web page residue, not part of the R code:)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.