# R/batch_grid_search.R

#' Perform batch grid search over a grid of parameters.
#'
#' @description If your parameter search is likely to take a long time, this
#' function allows you to do it in batches, saving the results of the search to
#' disk after each batch. This incurs a penalty in running time, because the
#' assessment splits are recomputed (or `baked`, in `recipes` terminology) at
#' the beginning of each batch. The smaller the batch size, the bigger the
#' penalty.
#'
#' @param resamples A data.frame with columns `splits` and `id`, created using
#'   the `rsample` package.
#' @param recipe The recipe to use. See package `recipes`.
#' @param param_grid List of lists of parameters passed to `scoring_func`
#'   (generated by `purrr::cross` for example).
#' @param scoring_func Your custom train/predict/score function.
#'   Must take as parameters:
#'   \itemize{
#'     \item a training dataframe
#'     \item the name of the target variable in the training dataframe
#'     \item a list of parameters (these are the hyperparameters we are tuning)
#'     \item an evaluation dataframe
#'     \item dots. These are additional non-tunable parameters that could be
#'       passed to the function.
#'   }
#' @param ... Optional arguments forwarded to `grid_search()` (and from there,
#'   presumably, to `scoring_func`).
#' @param batch_size Size of the batches.
#' @param out_folder Where to save the intermediate batch results. The folder
#'   is created if it does not exist.
#' @param file_prefix Used to name the results files.
#' @param overwrite Overwrite existing results files or create new ones.
#' @param verbosity Integer: level of verbosity, or TRUE/FALSE for max/min
#'   verbosity. The value is forwarded unchanged to `grid_search()`; any
#'   TRUE/non-zero value also enables the per-batch progress output of this
#'   function.
#'
#' @details `scoring_func` can return a single score as a numeric vector,
#' or multiple scores in a data.frame.
#'
#' Each batch result is written to `<out_folder>/<file_prefix><n>.RDS`, where
#' the numbers `n` are supplied by `make_file_numbers()`. The output folder is
#' scanned for existing result files: if `overwrite` is `FALSE`, the outputs of
#' the current run are written to files numbered after the existing ones;
#' otherwise numbering starts at 1.
#'
#' When `verbosity` is on, the file numbers are displayed once and the batch
#' number is printed at the beginning of each batch.
#'
#' @return A tidy data.frame, the aggregate result (the per-batch results bound
#'   by rows). This is the same as without the batches.
#' @export
batch_grid_search <- function(resamples,
                              recipe,
                              param_grid,
                              scoring_func,
                              ...,
                              batch_size,
                              out_folder = ".",
                              file_prefix = "batch_",
                              overwrite = FALSE,
                              verbosity = TRUE) {
  # `verbosity` may be a logical or an integer level; collapse it to a single
  # flag for this function's own progress output.
  verbose <- isTRUE(as.logical(verbosity))

  # Pre-fill everything except the parameter grid, which changes per batch.
  partial_grid_search <-
    purrr::partial(
      grid_search,
      resamples = resamples,
      rec = recipe,
      scoring_func = scoring_func,
      ...,
      verbosity = verbosity
    )

  batch_list <- make_batches(param_grid, batch_size)
  nb_batches <- length(batch_list)

  new_file_numbers <- make_file_numbers(overwrite = overwrite,
                                        out_folder = out_folder,
                                        file_prefix = file_prefix,
                                        n = nb_batches)

  # Guarantee the documented "created if not found" behaviour before the first
  # saveRDS(); a no-op when the folder already exists.
  if (!dir.exists(out_folder)) {
    dir.create(out_folder, recursive = TRUE)
  }

  if (verbose) {
    utils::str(new_file_numbers)
  }

  batch_results <- purrr::pmap(
    list(
      batch_list,
      # seq_len() rather than 1:n so that zero batches yields an empty index.
      seq_len(nb_batches),
      new_file_numbers
    ),
    function(param_grid, i, new_file_number) {
      if (verbose) {
        print(paste("Batch", i, "/", nb_batches))
      }
      res_df <- partial_grid_search(param_grid = param_grid)
      # Persist each batch as soon as it completes so that partial progress
      # survives an interruption.
      out_file <- file.path(
        out_folder,
        paste0(file_prefix, new_file_number, ".RDS")
      )
      saveRDS(res_df, out_file)
      res_df
    }
  )

  dplyr::bind_rows(batch_results)
}
# artichaud1/cook documentation built on May 21, 2019, 9:23 a.m.