R/cross_validate.R

Defines functions cross_validate cross_validate.default cross_validate.nested_cv

Documented in cross_validate

#' Perform cross-validation on an rsample resampling object
#'
#' Fits `object` on the analysis data of every split in `resamples`, predicts
#' on the matching assessment data and scores those predictions with
#' `scoring`. For `nested_cv` objects, the tuned hyperparameter values stored
#' for each outer split are applied to the model before it is fitted.
#'
#' @param resamples an rsample resampling object (e.g. a `nested_cv` object)
#'   with a `splits` column
#' @param object a parsnip model specification, or a `pipeline` object
#' @param recipe a recipe used as the preprocessing step when `object` is a
#'   bare model specification; not used when `object` is a pipeline
#' @param scoring a yardstick metric function, called as
#'   `scoring(data, truth = <outcome>, estimate = <prediction column>)` where
#'   `data` is the predictions bound to the assessment data
#' @param keep_preds logical, store the predictions of each split in a
#'   `predictions` column, default is TRUE
#' @param keep_models logical, store the fitted model of each split in a
#'   `models` column, default is FALSE
#' @param .options options created with `future_options()` that are passed to
#'   furrr, e.g. to declare additional packages required by tuning and fitting
#'   functions on the workers
#'
#' @return `resamples` with added list columns `training` and `testing` (the
#'   analysis and assessment data of each split), `outer_scores` (the output of
#'   `scoring` for each split) and, optionally, `predictions` and `models`
#' @export
#' @importFrom furrr future_options
cross_validate <- function(resamples, object, recipe = NULL, scoring, keep_preds = TRUE,
                           keep_models = FALSE, .options = future_options()) {
  # Dispatch on the class of `resamples`; classes without a dedicated method
  # fall through to cross_validate.default().
  UseMethod(generic = "cross_validate", object = resamples)
}

# Default method: cross-validates `object` over the rsplit objects stored in
# the `splits` column of `resamples`.
#
# Each split is fitted on its analysis data, predicted on its assessment data
# and scored with `scoring()`. The folds are evaluated with
# furrr::future_map2(). Returns `resamples` with these added list columns:
#   - `training` and `testing` (the analysis / assessment data of each split)
#   - `outer_scores` (the scoring() result of each split)
#   - `predictions` (only when `keep_preds` is TRUE)
#   - `models` (the fitted pipelines; only when `keep_models` is TRUE)
# An error is raised unless the model mode is "classification" or
# "regression".
#' @export
#' @importFrom purrr map
#' @importFrom furrr future_map2 future_options
#' @importFrom dplyr bind_cols mutate
#' @importFrom formula.tools lhs.vars
#' @importFrom rsample analysis assessment
#' @importFrom tibble as_tibble tibble
cross_validate.default <- function(resamples, object, recipe = NULL, scoring, keep_preds = TRUE,
                                   keep_models = FALSE, .options = future_options()) {

  # The model specification is taken to be the second element of a pipeline.
  # A bare specification is wrapped in a pipeline once, up front: the result
  # is the same for every fold, so there is no need to rebuild it per fold.
  if (inherits(object, "pipeline")) {
    estimator <- object[[2]]
  } else {
    estimator <- object
    object <- pipeline(preprocessing = recipe, model_spec = object)
  }

  # Fail early with a clear message. Otherwise switch() would silently return
  # NULL and scoring() would fail with an obscure error.
  model_mode <- estimator$mode
  if (!is.character(model_mode) || length(model_mode) != 1L ||
      !(model_mode %in% c("classification", "regression"))) {
    stop("`object` must have mode \"classification\" or \"regression\".",
         call. = FALSE)
  }

  # Name of the predict() output column holding the estimates that are
  # compared against the truth.
  estimate_col <- switch(
    model_mode,
    classification = ".pred_class",
    regression = ".pred"
  )

  # create train and test data splits
  resamples <- resamples %>%
    mutate(training = map(resamples$splits, analysis),
           testing = map(resamples$splits, assessment))

  results <- future_map2(resamples$training, resamples$testing, function(X_train, X_test) {
    fitted <- object %>% fit(X_train)
    preds <- fitted %>% predict(new_data = X_test)

    # Assumes `scoring` captures truth/estimate with tidy eval (as yardstick
    # metrics do), so `!!` injects the column names.
    scores <- preds %>%
      bind_cols(X_test) %>%
      scoring(
        truth = !!lhs.vars(formula(fitted$recipe)),
        estimate = !!estimate_col)

    list(scores = scores,
         predictions = if (keep_preds) preds else NULL,
         models = if (keep_models) fitted else NULL)
  }, .options = .options)

  results <- set_names(results, seq_along(results))

  resamples$outer_scores <- map(results, ~ .x$scores)

  if (keep_preds) {
    resamples$predictions <- map(results, ~ .x$predictions)
  }

  if (keep_models) {
    resamples$models <- map(results, ~ .x$models)
  }

  resamples
}

# Method for nested_cv objects: each outer split is fitted with its own tuned
# hyperparameter values and scored on that split's assessment data.
#
# The names of the hyperparameter columns are read from the "tuning"
# attribute of `resamples`. Row i's values in those columns are applied for
# split i: they are passed to pipeline() for a bare model specification, or
# to update() for a pipeline.
#
# Returns `resamples` with these added list columns:
#   - `training` and `testing` (the analysis / assessment data of each split)
#   - `outer_scores` (the scoring() result of each split)
#   - `predictions` (only when `keep_preds` is TRUE)
#   - `models` (only when `keep_models` is TRUE)
# The "tuning" attribute is re-attached before returning (presumably because
# mutate() can drop it). An error is raised unless the model mode is
# "classification" or "regression".
#' @export
#' @importFrom purrr map
#' @importFrom furrr future_pmap future_options
#' @importFrom dplyr bind_cols mutate
#' @importFrom formula.tools lhs.vars
#' @importFrom rsample analysis assessment
#' @importFrom tibble as_tibble tibble
cross_validate.nested_cv <- function(resamples, object, recipe = NULL, scoring,
                                     keep_preds = TRUE, keep_models = FALSE,
                                     .options = future_options()) {

  # The model specification is taken to be the second element of a pipeline.
  if (inherits(object, "pipeline")) {
    estimator <- object[[2]]
  } else {
    estimator <- object
  }

  # Fail early with a clear message. Otherwise switch() would silently return
  # NULL and scoring() would fail with an obscure error.
  model_mode <- estimator$mode
  if (!is.character(model_mode) || length(model_mode) != 1L ||
      !(model_mode %in% c("classification", "regression"))) {
    stop("`object` must have mode \"classification\" or \"regression\".",
         call. = FALSE)
  }

  # Name of the predict() output column holding the estimates that are
  # compared against the truth.
  estimate_col <- switch(
    model_mode,
    classification = ".pred_class",
    regression = ".pred"
  )

  # Build one named list of tuned hyperparameter values per row (outer
  # split). `[[` extracts the i-th value from both atomic and list columns
  # and keeps attributes such as factor levels.
  hyper_pars_names <- attr(resamples, "tuning")
  hyper_cols <- resamples[hyper_pars_names]
  pars <- lapply(seq_len(nrow(resamples)), function(i) {
    lapply(hyper_cols, `[[`, i)
  })

  resamples <- resamples %>%
    mutate(training = map(resamples$splits, analysis),
           testing = map(resamples$splits, assessment))

  results <- future_pmap(
    list(resamples$training, resamples$testing, pars),
    function(X_train, X_test, par) {
      # Apply this split's tuned hyperparameters to the model.
      if (inherits(object, "pipeline")) {
        model <- object %>% update(!!!par)
      } else {
        model <- exec(pipeline, preprocessing = recipe, model_spec = object, !!!par)
      }

      fitted <- model %>% fit(X_train)
      preds <- fitted %>% predict(new_data = X_test)

      # Assumes `scoring` captures truth/estimate with tidy eval (as
      # yardstick metrics do), so `!!` injects the column names.
      scores <- preds %>%
        bind_cols(X_test) %>%
        scoring(
          truth = !!lhs.vars(formula(fitted$recipe)),
          estimate = !!estimate_col)

      list(scores = scores,
           predictions = if (keep_preds) preds else NULL,
           models = if (keep_models) fitted else NULL)
    },
    .options = .options
  )
  results <- set_names(results, seq_along(results))

  resamples$outer_scores <- map(results, ~ .x$scores)

  if (keep_preds) {
    resamples$predictions <- map(results, ~ .x$predictions)
  }

  if (keep_models) {
    resamples$models <- map(results, ~ .x$models)
  }

  attr(resamples, "tuning") <- hyper_pars_names

  resamples
}
stevenpawley/tidycrossval documentation built on Oct. 3, 2019, 3:32 p.m.