# R/surrogate_search.R
#
# Defines functions surrogate_search
#
# Documented in surrogate_search

#' Use a surrogate model to find potentially good parameter combinations.
#' @description The surrogate model used is a random forest regressor (see `ranger` package).
#' We generate n_candidates random parameter combinations, and ask the surrogate model to
#' rank them according to their predicted performance. We then take the top_n combinations
#' and pass them through to the actual underlying model.
#' @param resamples A data.frame with columns `splits` and `id`, created using the `rsample` package.
#' @param recipe The recipe to use. See package `recipes`.
#' @param param_set Param set created by calling `ParamHelpers::makeParamSet`.
#' @param n Number of surrogate iterations per run; each iteration evaluates `top_n`
#' parameter combinations with the underlying model. Can be a vector for iterative
#' surrogate search: `length(n)` runs are performed, run `i` doing `n[i]` iterations.
#' Every value must be at least 1.
#' @param scoring_func Your custom train/predict/score function.
#' Must take as parameters:
#' \itemize{
#'     \item a training dataframe
#'     \item the name of the target variable in the training dataframe
#'     \item a list of parameters (these are the hyperparameters we are tuning)
#'     \item an evaluation dataframe
#'     \item dots. These are additional non-tunable parameters that could be passed to the function.
#' }
#' @param ... Optional params forwarded to `grid_search` / `random_search`
#' (presumably passed on to `scoring_func`).
#' @param input Input to the surrogate model. This should be the output of a previous
#' parameter search. If `NULL` (the default), the first run falls back to random search
#' and later runs learn from its results.
#' @param surrogate_target Name (a string) of the performance column in `input` that the
#' surrogate model learns to predict. Lower values are treated as better: candidates are
#' ranked in ascending order of predicted value.
#' @param n_candidates How many candidate parameter combinations should be evaluated
#' in each surrogate model run. If n_candidates == 0, fall back to regular random
#' search (`n * top_n` random combinations).
#' If `n` is a vector, `n_candidates` must be a vector of length 1 or the same length as `n`.
#' @param top_n Out of the n_candidates, we will keep the top n (as predicted by the
#' surrogate) to test with the actual underlying model.
#' If `n` is a vector, `top_n` must be a vector of length 1 or the same length as `n`.
#' @param verbosity Integer: level of verbosity, or TRUE/FALSE
#' (TRUE is maximum verbosity, FALSE is not verbose).
#' @return A tidy data.frame, with one column per parameter, columns to identify the
#' paramset and the fold, a column giving the row indices of the evaluation dataset,
#' and columns for the performance scores (these are taken from the scoring function if
#' it returned a data.frame, otherwise it will just be a _score_ column).
#' Two extra columns are added: `surrogate_iteration` (0 for random-search fallback rows)
#' and `surrogate_run` (index into `n`).
#' @note The results of the surrogate search (i.e. the performance of parameters
#' selected by the surrogate model) are appended to the input after each run,
#' so that the surrogate model can also learn from its own suggestions.
#' @details `scoring_func` can return a single score as a numeric vector,
#' or multiple scores in a data.frame.
#' @export
#' @importFrom ParamHelpers generateRandomDesign dfRowsToList getParamIds
#' @importFrom dplyr bind_rows bind_cols arrange_at slice select mutate group_by_at summarise ungroup
#' @importFrom purrr map map_dfr
#' @importFrom rlang quo_name UQ ":=" sym is_true is_false
#' @importFrom crayon yellow magenta
surrogate_search <- function(
    resamples,
    recipe,
    param_set,
    n,
    scoring_func,
    ...,
    input = NULL,
    surrogate_target,
    n_candidates = 1000,
    top_n = 10,
    verbosity = TRUE
  ) {

  # One surrogate run: `n` iterations of (train surrogate on `input`, pick the
  # top_n predicted candidates, evaluate them for real, append to `input`).
  # Returns the bound results of all iterations of this run.
  do_it_once <- function(n, input, n_candidates, top_n, verbosity) {

    # With no history, or no candidates to rank, the surrogate cannot help:
    # evaluate the same budget (n * top_n) with plain random search instead.
    if (is.null(input) || n_candidates == 0) {
      if (verbosity >= 0) {
        if (is.null(input)) {
          cat("No input provided. Falling back to random search\n")
        } else {
          cat("n_candidates is 0. Falling back to random search\n")
        }
      }
      return(
        random_search(
          resamples = resamples,
          recipe = recipe,
          param_set = param_set,
          n = n * top_n,
          scoring_func = scoring_func,
          ...,
          verbosity = verbosity
        ) %>%
          mutate(surrogate_iteration = 0)
      )
    }

    # Loop invariants.
    param_ids <- getParamIds(param_set)
    target_var <- surrogate_target

    results <- vector("list", n)

    for (i in seq_len(n)) {
      if (verbosity > 0) {
        cat(yellow(paste("Surrogate search iteration:", i, "/", n, "\n")))
      }

      # Create random candidate params
      random_param_grid <-
        generateRandomDesign(n_candidates, param_set, trafo = TRUE)

      # Average the target across folds for each evaluated parameter combination.
      # ungroup() so the surrogate is trained on a plain (ungrouped) data frame.
      mean_perf <-
        input %>%
        group_by_at(param_ids) %>%
        summarise(!!target_var := mean(!!sym(target_var))) %>%
        ungroup()

      # Make performance predictions using surrogate
      surrogate_preds <-
        ranger_regressor(
          train_df = mean_perf,
          target_var = target_var,
          eval_df = random_param_grid,
          num.trees = 100
        )

      # Keep the top_n candidates with the lowest predicted target. The
      # prediction column is then dropped by exact name (not by regex), so
      # parameter columns whose names contain `target_var` are kept.
      selected_param_grid <-
        bind_cols(random_param_grid, !!target_var := surrogate_preds) %>%
        arrange_at(target_var) %>%
        slice(seq_len(top_n)) %>%
        select(-!!sym(target_var))

      # Evaluate the selected candidates with the real model.
      results[[i]] <-
        grid_search(
          resamples = resamples,
          recipe = recipe,
          scoring_func = scoring_func,
          ...,
          verbosity = verbosity - 1,
          param_grid = selected_param_grid
        ) %>%
        mutate(surrogate_iteration = i)

      # Append the results of this iteration so later iterations learn from them too
      input <-
        bind_rows(
          input,
          results[[i]] %>% select(-surrogate_iteration)
        )
    }

    bind_rows(results)
  }

  # Expand a length-1 per-run setting to one value per element of `n`;
  # any other length that differs from length(n) is an error.
  recycle_per_run <- function(x, arg_name) {
    if (length(x) == 1) {
      return(rep(x, length(n)))
    }
    if (length(x) != length(n)) {
      stop(
        paste(arg_name, "parameter must be of length 1 or same length as n"),
        call. = FALSE
      )
    }
    x
  }

  # Validate values, not just the length: n = 0 would otherwise make the
  # iteration loop misbehave.
  if (!is.numeric(n) || length(n) == 0 || anyNA(n) || any(n < 1)) {
    stop("n must be positive", call. = FALSE)
  }

  top_n <- recycle_per_run(top_n, "top_n")
  n_candidates <- recycle_per_run(n_candidates, "n_candidates")

  # Map TRUE/FALSE onto numeric levels so `verbosity - 1` can be passed down.
  if (is_true(verbosity)) {
    verbosity <- Inf
  } else if (is_false(verbosity)) {
    verbosity <- -Inf
  }

  all_output <- vector("list", length(n))

  for (run in seq_along(n)) {
    if (verbosity > 0) {
      cat(magenta(paste("Surrogate search run:", run, "/", length(n), "\n")))
    }

    output <-
      do_it_once(
        n = n[run],
        input = input,
        n_candidates = n_candidates[run],
        top_n = top_n[run],
        verbosity = verbosity - 1
      ) %>%
      mutate(surrogate_run = run)

    # Later runs learn from everything evaluated so far.
    input <- bind_rows(input, output)
    all_output[[run]] <- output
  }

  bind_rows(all_output)
}
# artichaud1/tidygrid documentation built on July 6, 2018, 9:10 a.m.