R/tune.R

Defines functions tune tune.default tune.nested_cv fit_and_score

Documented in tune

#' Hyperparameter tuning on rsample objects
#'
#' `tune` is a generic function that accepts a tibble of resampling partitions generated by one of
#' the resampling schemes available in the `rsample` package. A `parsnip` model specification and a
#' `recipes` recipe (or a single pipeline object) must also be supplied, along with a `yardstick`
#' scoring function. The `param_grid` argument accepts a `grid_regular` or `grid_random` object of
#' tuning parameters generated by the `dials` package. Hyperparameter tuning is then performed on
#' the resamples, which are returned as a tibble with an additional list-column called
#' `tune_scores` containing the tuning scores, plus one column per tuning parameter holding the
#' best-scoring value for each resample.
#'
#' If the resampling object is a `nested_cv` object, hyperparameter tuning is performed on the
#' inner resampling partitions, the inner scores are averaged for each outer fold, and the best
#' hyperparameters per outer fold are returned as additional columns in the output tibble.
#'
#' @param resamples tibble generated by rsample containing the resampling scheme
#' @param object parsnip model specification, or a pipeline object
#' @param recipe recipe object, optional if using a pipeline
#' @param param_grid dials grid_regular or grid_random object
#' @param scoring yardstick scoring function
#' @param maximize logical, maximize the scoring function, default = TRUE
#' @param .options future_options to pass additional packages required by tuning and fitting
#'   functions
#'
#' @return tibble containing resampling results
#' @export
#' @importFrom furrr future_options
tune <- function(resamples,
                 object,
                 recipe = NULL,
                 param_grid,
                 scoring,
                 maximize = TRUE,
                 .options = future_options()) {
  # Dispatch on the class of the resampling object: `nested_cv` objects get a
  # dedicated method, classes without one fall through to `tune.default`.
  UseMethod("tune", resamples)
}


#' @importFrom purrr map map_dbl pmap
#' @importFrom furrr future_map2_dbl future_options
#' @importFrom dplyr bind_cols select rename group_by summarize
#' @importFrom dplyr filter row_number ungroup
#' @importFrom tidyr expand_grid nest unnest
#' @importFrom rlang sym list2
#' @export
tune.default <- function(resamples, object, recipe = NULL, param_grid, scoring,
                         maximize = TRUE, .options = future_options()) {

  # Default method for flat (non-nested) resampling objects.
  #
  # Every resample is crossed with every row of `param_grid`. For each
  # combination a pipeline is fitted on the analysis set and scored on the
  # assessment set with `scoring` (see fit_and_score). Returns `resamples` with:
  #   * one extra column per tuning parameter holding that resample's
  #     best-scoring value (max score if `maximize` is TRUE, else min score),
  #   * a `tune_scores` list-column with the pipelines/parameters/scores of
  #     every combination tried on that resample,
  #   * a "tuning" attribute holding the parameter names, and "tune" appended
  #     to the class vector.

  arg_names <- names(param_grid)

  # cross with the param_grid: rows for each resample stay contiguous and in
  # the row order of `resamples` (expand_grid varies its first input slowest)
  inner_resamples <- resamples %>%
    expand_grid(param_grid)

  # build one pipeline per row from that row's hyperparameter values
  if (inherits(object, "pipeline")) {
    inner_resamples$pipelines <- inner_resamples %>%
      select(!!arg_names) %>%
      pmap(function(...) object %>% update(!!!list2(...)))
  } else {
    inner_resamples$pipelines <- inner_resamples %>%
      select(!!arg_names) %>%
      pmap(function(...) pipeline(recipe, object, !!!list2(...)))
  }

  # fit and score every (resample, hyperparameter) combination via furrr
  inner_resamples$score <- future_map2_dbl(
    inner_resamples$splits, inner_resamples$pipelines,
    fit_and_score, scoring,
    .options = .options
  )

  # keep the best scoring hyperparameter combination per resample
  which_fun <- if (isTRUE(maximize)) which.max else which.min

  scores_per_fold <- inner_resamples %>%
    group_by(!!sym("id")) %>%
    filter(row_number() == which_fun(!!sym("score"))) %>%
    ungroup()

  # filter() keeps the existing row order, so the winners line up
  # positionally with the rows of `resamples`
  resamples <- bind_cols(resamples, scores_per_fold %>% select(!!arg_names))

  # nest the scores back onto the resamples. split() orders its output by
  # factor level, so pin the levels to the resample order rather than relying
  # on the alphabetical order of the ids matching the row order
  fold_id <- factor(inner_resamples$id, levels = unique(resamples$id))

  resamples$tune_scores <- map(
    split(inner_resamples, fold_id), function(x, ...) {
      x <- x %>%
        select(-!!sym("splits")) %>%
        nest(tune_scores = c(!!sym("pipelines"), !!arg_names, !!sym("score")))
      x$tune_scores[[1]]
    })

  attr(resamples, "tuning") <- arg_names
  class(resamples) <- append(class(resamples), "tune")

  resamples
}


#' @importFrom purrr map map_dbl pmap
#' @importFrom furrr future_map2_dbl future_options
#' @importFrom dplyr bind_cols select rename group_by group_by_at summarize group_map
#' @importFrom dplyr filter row_number ungroup
#' @importFrom tidyr expand_grid nest unnest
#' @importFrom rlang sym list2
#' @export
tune.nested_cv <- function(resamples, object, recipe = NULL, param_grid, scoring,
                           maximize = TRUE, .options = future_options()) {

  # Method for rsample `nested_cv` objects.
  #
  # Hyperparameters are tuned on the inner resamples of every outer fold: each
  # inner split is crossed with every row of `param_grid`, fitted and scored
  # (see fit_and_score), and the inner scores are averaged per
  # (outer fold, hyperparameter combination). Returns `resamples` with:
  #   * one extra column per tuning parameter holding the values with the best
  #     mean inner score for that outer fold (max if `maximize` is TRUE, else
  #     min; NA when every mean score for that fold is NA),
  #   * a `tune_scores` list-column: for each outer fold, a list of score
  #     tables named by inner fold id,
  #   * a "tuning" attribute holding the parameter names, and "tune" appended
  #     to the class vector.

  arg_names <- names(param_grid)
  outer_ids <- resamples$id

  # flatten the resamples: one row per (outer fold, inner split)
  inner_resamples <- resamples %>%
    select(!!sym("id"), !!sym("inner_resamples")) %>%
    rename(outer_fold = !!sym("id")) %>%
    unnest(!!sym("inner_resamples")) %>%
    rename(inner_fold = !!sym("id"))

  # cross with the param_grid
  inner_resamples <- inner_resamples %>%
    expand_grid(param_grid)

  # build one pipeline per row from that row's hyperparameter values
  if (inherits(object, "pipeline")) {
    inner_resamples$pipelines <- inner_resamples %>%
      select(!!arg_names) %>%
      pmap(function(...) object %>% update(!!!list2(...)))
  } else {
    inner_resamples$pipelines <- inner_resamples %>%
      select(!!arg_names) %>%
      pmap(function(...) pipeline(recipe, object, !!!list2(...)))
  }

  # fit and score every (inner split, hyperparameter) combination via furrr
  inner_resamples$score <- future_map2_dbl(
    inner_resamples$splits, inner_resamples$pipelines,
    fit_and_score, scoring,
    .options = .options
  )

  # mean inner score per outer fold / hyperparameter combination
  scores <- inner_resamples %>%
    group_by_at(c(arg_names, "outer_fold")) %>%
    summarize(score = mean(!!sym("score"))) %>%
    ungroup()

  # get best scoring hyperparameter per outer fold
  which_fun <- if (isTRUE(maximize)) which.max else which.min

  # `[1]` turns the empty result of which.max()/which.min() on an all-NA fold
  # into NA, so that fold simply has no winner instead of breaking filter()
  scores_per_fold <- scores %>%
    group_by(!!sym("outer_fold")) %>%
    filter(row_number() == which_fun(!!sym("score"))[1]) %>%
    ungroup()

  # summarize() returns rows sorted by the grouping keys (hyperparameters
  # first), so the winners are NOT in outer-fold order. Realign them to the
  # rows of `resamples` by id before binding the columns positionally.
  best_rows <- match(outer_ids, scores_per_fold$outer_fold)
  best_params <- scores_per_fold[best_rows, arg_names, drop = FALSE]
  resamples <- bind_cols(resamples, best_params)

  # nest the scores back onto the resamples: one list of inner-fold score
  # tables per outer fold, in the row order of `resamples`
  fold_id <- factor(inner_resamples$outer_fold, levels = unique(outer_ids))

  resamples$tune_scores <- map(
    split(inner_resamples, fold_id), function(x, ...) {
      x <- x %>%
        select(-!!sym("splits"), -!!sym("outer_fold")) %>%
        nest(tune_scores = c(!!sym("pipelines"), !!arg_names, !!sym("score")))
      inner_scores <- x$tune_scores
      names(inner_scores) <- x$inner_fold
      inner_scores
    })

  attr(resamples, "tuning") <- arg_names
  class(resamples) <- append(class(resamples), "tune")

  resamples
}


#' @importFrom stats formula predict
#' @importFrom rsample analysis assessment form_pred
#' @importFrom formula.tools lhs.vars
#' @importFrom dplyr filter
fit_and_score <- function(rsplit, pipeline, scoring) {

  # Fit `pipeline` on the analysis set of `rsplit`, predict its assessment set
  # and return the `.estimate` value produced by the yardstick metric
  # `scoring`.
  #
  # The truth column is every variable of the fitted recipe's formula that is
  # not a predictor (NOTE(review): assumes a single outcome variable). The
  # estimate column is chosen from the model mode: `.pred_class` for
  # classification, `.pred` for regression; any other mode is an error.

  # fit model to recipe
  fitted <- pipeline %>% fit(data = analysis(rsplit))

  # outcome = formula variables that are not predictors
  assess_data <- assessment(rsplit)
  recipe_formula <- formula(fitted$recipe)
  outcome_name <- setdiff(all.vars(recipe_formula), form_pred(recipe_formula))

  # predict assessment set and keep the observed columns alongside
  pred <- fitted %>%
    predict(assess_data) %>%
    bind_cols(assess_data)

  # the prediction column holds the *estimate* passed to the metric
  mode <- pipeline$model_spec$mode
  estimate_col <- switch(
    mode,
    "classification" = ".pred_class",
    "regression" = ".pred",
    stop("Unsupported model mode '", mode,
         "'; expected 'classification' or 'regression'", call. = FALSE)
  )

  score <- pred %>% scoring(
    truth = !!outcome_name,
    estimate = !!estimate_col)

  score[[".estimate"]]
}
stevenpawley/tidycrossval documentation built on Oct. 3, 2019, 3:32 p.m.