R/ensemble.R

Defines functions ensemble fit.ensemble predict.ensemble

Documented in ensemble fit.ensemble predict.ensemble

#' Create an ensemble model specification
#'
#' Create a stacked ensemble model from a list of base pipelines and a
#' meta-model specification. When fitted, each base pipeline is trained
#' within every resample produced by `cv`, and its out-of-fold predictions
#' become the training features of the meta model.
#'
#' @param base_pipelines named list of pipeline objects to use as the base
#'   models. The names label each pipeline's prediction column
#'   (`<name>_preds`), so they must be unique and non-empty. The outcome
#'   variable is read from the recipe of the first pipeline.
#' @param meta_model parsnip model specification to use as the meta learning
#'   model
#' @param cv function that takes a data.frame and returns an rsample
#'   resampling object with a `splits` column. By default this is
#'   `partial(vfold_cv, v = 3)`, i.e. 3-fold cross-validation.
#' @param pass_predictors logical scalar; if `TRUE`, also pass the original
#'   predictors to the meta model alongside the predictions from the base
#'   pipeline models
#'
#' @return an unfitted object of class `ensemble`: a list holding
#'   `target_variable`, `base_pipelines`, `meta_model`, `cv` and
#'   `pass_predictors`
#' @export
#' @importFrom formula.tools lhs.vars
#' @importFrom recipes prep
#' @importFrom purrr partial
#' @importFrom rsample vfold_cv
#' @examples
#' \dontrun{
#' library(tidymodels)
#' library(tidycrossval)
#' library(magrittr)
#'
#' data <- iris
#' rec <- data %>% recipe(Species ~ .) %>%
#'   step_scale(all_predictors()) %>%
#'   step_center(all_predictors())
#'
#' rf_pipeline <- pipeline(rec, rand_forest(mode = "classification") %>% set_engine("ranger"))
#' boost_pipeline <- pipeline(rec, boost_tree(mode = "classification") %>% set_engine("xgboost"))
#' base_pipelines <- list(rf = rf_pipeline, boost = boost_pipeline)
#' final_model <- mlp(mode = "classification") %>% set_engine("nnet")
#'
#' object <- ensemble(base_pipelines, final_model, cv = . %>% vfold_cv(v = 3))
#' ensemble_fitted <- fit(object, data = iris)
#' preds <- predict(ensemble_fitted, data)
#' }
ensemble <- function(
  base_pipelines,
  meta_model,
  cv = partial(vfold_cv, v = 3),
  pass_predictors = FALSE) {

  # validate eagerly so that mistakes surface here rather than as obscure
  # errors from inside fit()
  if (!is.list(base_pipelines) || length(base_pipelines) == 0L) {
    stop("`base_pipelines` must be a non-empty list of pipeline objects",
         call. = FALSE)
  }

  # the names label the per-pipeline prediction columns, so missing or
  # duplicated names would make those columns ambiguous
  pipeline_names <- names(base_pipelines)
  if (is.null(pipeline_names) || anyNA(pipeline_names) ||
      any(pipeline_names == "") || anyDuplicated(pipeline_names) > 0L) {
    stop("`base_pipelines` must have unique, non-empty names", call. = FALSE)
  }

  if (!is.logical(pass_predictors) || length(pass_predictors) != 1L ||
      is.na(pass_predictors)) {
    stop("`pass_predictors` must be a single TRUE or FALSE", call. = FALSE)
  }

  if (!is.function(cv)) {
    stop("`cv` must be a function that returns an rsample resampling object",
         call. = FALSE)
  }

  # outcome name(s) taken from the left-hand side of the formula of the
  # first pipeline's prepped recipe
  target_variable <- base_pipelines[[1]]$recipe %>%
    prep() %>%
    formula() %>%
    lhs.vars()

  structure(
    list(
      target_variable = target_variable,
      base_pipelines = base_pipelines,
      meta_model = meta_model,
      cv = cv,
      pass_predictors = pass_predictors
    ),
    class = "ensemble"
  )
}


#' Fit method for an ensemble object
#'
#' Fits each base pipeline within every resample produced by `object$cv`,
#' collects the out-of-fold predictions, refits the base pipelines on the
#' full training data and trains the meta model on the out-of-fold
#' predictions (plus the original predictors when `pass_predictors` is
#' `TRUE`).
#'
#' @param object unfitted ensemble object created by [ensemble()]
#' @param data data.frame of training data
#' @param ... currently unused
#'
#' @return fitted ensemble object (class `ensemble`) holding `base_models`
#'   (the base pipelines refitted on `data`), `meta_model` (the fitted meta
#'   model), `pass_predictors` and `target_variable`
#' @export
#' @importFrom rsample vfold_cv analysis assessment
#' @importFrom purrr map map2
#' @importFrom dplyr mutate bind_rows bind_cols
#' @importFrom rlang set_names
#' @importFrom parsnip fit_xy
fit.ensemble <- function(object, data, ...) {
  # a fitted ensemble shares the `ensemble` class but lacks the specification
  # fields, so refuse it explicitly instead of failing on `object$cv`
  if (is.null(object$base_pipelines) || !is.function(object$cv)) {
    stop("`object` must be an unfitted ensemble created by `ensemble()`",
         call. = FALSE)
  }

  base_pipelines <- object$base_pipelines
  target <- object$target_variable
  pred_names <- paste(names(base_pipelines), "preds", sep = "_")

  # resample the training data; the result's `splits` column holds one
  # rsplit per resample
  folds <- object$cv(data)

  # out-of-fold predictions: for every base pipeline, fit it on the analysis
  # set of each resample, predict that resample's assessment set and stack
  # the per-resample predictions row-wise (in resample order). Each
  # pipeline's prediction column is renamed to `<name>_preds` before binding,
  # so identically named prediction columns never need name repair.
  oof_preds <- map2(
    unname(base_pipelines),
    pred_names,
    function(pipe, pred_name) {
      map(folds$splits, function(split) {
        cv_fit <- fit(pipe, data = analysis(split))
        predict(cv_fit, new_data = assessment(split))
      }) %>%
        bind_rows() %>%
        set_names(pred_name)
    }
  ) %>%
    bind_cols()

  # refit every base pipeline on the full training data for use at predict
  # time
  fitted_base_models <- map(base_pipelines, ~ fit(.x, data = data))

  # held-out rows, stacked in the same resample order as `oof_preds`
  assessment_data <- map(folds$splits, assessment) %>% bind_rows()

  # optionally pass all original columns to the meta model; otherwise keep
  # only the outcome column next to the out-of-fold predictions
  if (!isTRUE(object$pass_predictors)) {
    assessment_data <- assessment_data[target]
  }
  meta_data <- bind_cols(oof_preds, assessment_data)

  # train the meta model on every column except the outcome
  meta_x <- meta_data[setdiff(names(meta_data), target)]
  meta_y <- meta_data[[target]]
  meta_fit <- fit_xy(object$meta_model, x = meta_x, y = meta_y)

  structure(
    list(
      base_models = fitted_base_models,
      meta_model = meta_fit,
      pass_predictors = object$pass_predictors,
      target_variable = target
    ),
    class = "ensemble"
  )
}


#' Predict method for ensemble object
#'
#' Generates predictions from every fitted base model, arranges them as the
#' `<name>_preds` columns used to train the meta model (followed by the
#' columns of `new_data` when `pass_predictors` was `TRUE`), and returns the
#' meta model's predictions for those features.
#'
#' @param object fitted ensemble object returned by `fit()`
#' @param new_data data.frame of new observations
#' @param ... additional arguments passed on to `predict()` for the meta
#'   model, e.g. `type = "prob"` for a classification meta model
#'
#' @return tibble of predictions from the meta model
#' @export
#' @importFrom purrr map_dfc map2
#' @importFrom dplyr bind_cols
#' @importFrom rlang set_names
#' @importFrom stats predict
predict.ensemble <- function(object, new_data, ...) {
  # an unfitted specification has the same class but no fitted base models
  if (is.null(object$base_models)) {
    stop("The ensemble has not been fitted; call `fit()` on it first.",
         call. = FALSE)
  }

  # one `<name>_preds` column per base model; each column is named before
  # binding so identically named prediction columns never need name repair
  pred_names <- paste(names(object$base_models), "preds", sep = "_")
  base_model_preds <- bind_cols(map2(
    unname(object$base_models),
    pred_names,
    function(model, pred_name) {
      set_names(predict(model, new_data = new_data), pred_name)
    }
  ))

  if (isTRUE(object$pass_predictors)) {
    base_model_preds <- bind_cols(base_model_preds, new_data)
  }

  # pass base model predictions to meta model for prediction
  predict(object$meta_model, new_data = base_model_preds, ...)
}
stevenpawley/tidycrossval documentation built on Oct. 3, 2019, 3:32 p.m.