# R/select_best.R
#
# Defines functions: select_best, select_best.default, select_best.nested_cv
#
# Documented in: select_best

#' Pick the tuning-parameter combination with the best mean score
#'
#' S3 generic that dispatches on the class of `resamples`.
#'
#' @param resamples Resampling result to dispatch on. The methods defined
#'   alongside this generic read its `"tuning"` attribute (names of the
#'   tuning-parameter columns) and its `tune_scores` element.
#' @param maximize Logical. `TRUE` (the default) selects the combination with
#'   the largest mean score; any other value selects the smallest.
#'
#' @return Whatever the dispatched method returns. The methods defined
#'   alongside this generic return a named list with one element per
#'   tuning parameter, holding the selected value.
#' @export
select_best <- function(resamples, maximize = TRUE) {
  # Dispatch on the first argument (`resamples`), which is UseMethod()'s default.
  UseMethod("select_best")
}

# Default method: average the `score` column over all tables stored in
# `resamples$tune_scores` for each tuning-parameter combination and return
# the combination with the best average.
#
# `resamples` must carry a "tuning" attribute naming the tuning-parameter
# columns. `maximize = TRUE` picks the largest mean score, anything else the
# smallest. Returns a named list with one element per tuning parameter
# (each element is empty when no combination has a non-missing mean score).
#' @export
#' @importFrom dplyr summarize filter select bind_rows ungroup row_number
#' @importFrom dplyr group_by_at
#' @importFrom purrr map
select_best.default <- function(resamples, maximize = TRUE) {

  tunable_args <- attr(resamples, "tuning")

  # Stack the per-element score tables into one; "fold" records which list
  # element each row came from (presumably one table per resampling fold).
  fold_scores <- bind_rows(resamples$tune_scores, .id = "fold")

  # which.max()/which.min() skip NA/NaN and return the index of the FIRST
  # extreme value. Comparing against max()/min() instead would yield NA as
  # soon as one combination has a missing mean score and silently drop every
  # row. (The helper is deliberately not called `pick`, which is a dplyr verb.)
  locate_best <- if (isTRUE(maximize)) which.max else which.min

  best_params <- fold_scores %>%
    group_by_at(tunable_args) %>%
    summarize(score = mean(!!sym("score"))) %>%
    ungroup() %>%
    # Keep the single winning row; ties resolve to the first row, and an
    # all-missing or empty table yields zero rows.
    filter(row_number() %in% locate_best(!!sym("score"))) %>%
    select(-c(!!sym("score")))

  # Convert the one-row table into a named list of parameter values.
  map(best_params, ~ .x)
}

# Nested cross-validation method: `resamples$tune_scores` is treated as a
# list whose elements are themselves lists of score tables (presumably outer
# folds containing inner folds). Scores are averaged per tuning-parameter
# combination within each outer element, then averaged again across outer
# elements, and the combination with the best overall mean is returned.
#
# `resamples` must carry a "tuning" attribute naming the tuning-parameter
# columns. `maximize = TRUE` picks the largest mean score, anything else the
# smallest. Returns a named list with one element per tuning parameter
# (each element is empty when no combination has a non-missing mean score).
#' @export
#' @importFrom dplyr summarize filter select ungroup row_number group_by_at
#' @importFrom purrr map_dfr map
select_best.nested_cv <- function(resamples, maximize = TRUE) {

  tunable_args <- attr(resamples, "tuning")

  # First level: one mean score per parameter combination for every outer
  # element; "id" records which outer element the row came from.
  outer_means <- map_dfr(resamples$tune_scores, function(inner_tables) {
    inner_tables %>%
      map_dfr(~ .x) %>%
      group_by_at(tunable_args) %>%
      summarize(score = mean(!!sym("score")))
  }, .id = "id")

  # which.max()/which.min() skip NA/NaN and return the index of the FIRST
  # extreme value. Comparing against max()/min() instead would yield NA as
  # soon as one combination has a missing mean score and silently drop every
  # row. (The helper is deliberately not called `pick`, which is a dplyr verb.)
  locate_best <- if (isTRUE(maximize)) which.max else which.min

  # Second level: average the per-outer-element means and keep the winner.
  best_params <- outer_means %>%
    group_by_at(tunable_args) %>%
    summarize(score = mean(!!sym("score"))) %>%
    ungroup() %>%
    # Ties resolve to the first row; an all-missing or empty table yields
    # zero rows.
    filter(row_number() %in% locate_best(!!sym("score"))) %>%
    select(-c(!!sym("score")))

  # Convert the one-row table into a named list of parameter values.
  map(best_params, ~ .x)
}
# stevenpawley/tidycrossval documentation built on Oct. 3, 2019, 3:32 p.m.