# R/pipeline.R
#
# Defines functions: pipeline, fit.pipeline, predict.pipeline,
# generate_pipelines
#
# Documented in: pipeline

#' Pipeline for preprocessing, estimation and postprocessing
#'
#' Bundles a recipe and a parsnip model specification into a single object,
#' optionally updating either component with the arguments supplied in `...`.
#'
#' @param preprocessing recipe object
#' @param model_spec parsnip model_spec object
#' @param ... named arguments used to update pipeline steps. Arguments whose
#'   names contain a double underscore (`__`) are passed to `update()` on the
#'   recipe; all other arguments are passed to `parsnip::set_args()` on the
#'   model specification. Every argument must be named.
#'
#' @return An object of class `pipeline`: a list with the elements `recipe`
#'   and `model_spec`.
#' @export
#' @importFrom rlang list2
#' @importFrom stringr str_detect
#' @importFrom parsnip set_args
#' @importFrom stats update
pipeline <- function(preprocessing, model_spec, ...) {
  args <- list2(...)

  if (length(args) > 0) {
    arg_names <- names(args)

    # Arguments are routed to the recipe or the model purely by name, so an
    # unnamed argument cannot be routed; fail loudly instead of silently
    # dropping it.
    if (is.null(arg_names) || anyNA(arg_names) || any(arg_names == "")) {
      stop("All arguments supplied via `...` must be named.", call. = FALSE)
    }

    # separate model_spec and recipe arguments
    is_recipe_arg <- str_detect(arg_names, "__")
    args_model <- args[!is_recipe_arg]
    args_recipe <- args[is_recipe_arg]

    # update
    if (length(args_model) > 0) {
      model_spec <- set_args(model_spec, !!!args_model)
    }

    if (length(args_recipe) > 0) {
      preprocessing <- update(preprocessing, !!!args_recipe)
    }
  }

  structure(
    list(recipe = preprocessing, model_spec = model_spec),
    class = "pipeline"
  )
}

#' Fit a pipeline
#'
#' Preps the pipeline's recipe and fits the pipeline's model specification to
#' the processed training data.
#'
#' @param object a `pipeline` object holding the elements `recipe` and
#'   `model_spec`.
#' @param data optional training data handed to `recipes::prep()`. When `NULL`
#'   (the default) the recipe is prepped without supplying training data.
#' @param ... currently unused; accepted for consistency with the `fit()`
#'   generic.
#'
#' @return An object of class `pipeline`: a list with the elements `recipe`
#'   (the prepped recipe) and `model_fit` (the fitted model).
#' @export
#' @importFrom recipes prep juice
#' @importFrom parsnip fit
#' @importFrom stats formula
fit.pipeline <- function(object, data = NULL, ...) {

  # `training = NULL` is the default of prep(), so a single call covers both
  # the "no data supplied" and the "data supplied" case.
  prepped_recipe <- prep(object$recipe, training = data)

  # Fit on the processed training set, using the formula derived from the
  # prepped recipe.
  model_fit <- fit(
    object$model_spec,
    formula(prepped_recipe),
    data = juice(prepped_recipe)
  )

  structure(
    list(recipe = prepped_recipe, model_fit = model_fit),
    class = "pipeline"
  )
}


#' Predict from a fitted pipeline
#'
#' Bakes `new_data` with the pipeline's prepped recipe and predicts from the
#' fitted model on the result.
#'
#' @param object a fitted `pipeline` object, i.e. one holding the elements
#'   `recipe` and `model_fit`.
#' @param new_data data to bake with the recipe and predict on.
#' @param type prediction type forwarded to `predict()` on the fitted model;
#'   `NULL` (the default) leaves the choice to that method.
#' @param opts list of options forwarded as `opts` to `predict()` on the
#'   fitted model.
#' @param ... further arguments; they are appended to `opts`.
#'
#' @return The value returned by `predict()` on the fitted model.
#' @export
#' @importFrom rlang list2
#' @importFrom recipes bake
#' @importFrom stats predict
predict.pipeline <- function(object, new_data, type = NULL, opts = list(),
                             ...) {
  # An unfitted pipeline carries `model_spec` rather than `model_fit`.
  if (is.null(object[["model_fit"]])) {
    stop("The pipeline must be fitted before calling `predict()`.",
         call. = FALSE)
  }

  # Remaining dots keep being folded into `opts`, while `type` and `opts` are
  # now passed on under their own names instead of being buried inside `opts`.
  opts <- c(opts, list2(...))

  baked <- bake(object[["recipe"]], new_data = new_data)
  predict(object[["model_fit"]], new_data = baked, type = type, opts = opts)
}


# Build one pipeline per row of a parameter grid.
#
# model_spec: model specification passed to pipeline() as `model_spec`.
# recipe:     recipe passed to pipeline() as `preprocessing`.
# param_grid: data frame (or tibble) with one row per parameter combination;
#             the column names are used as the argument names handed to
#             pipeline().
#
# Returns a list with one element per row of `param_grid`, each the result of
# calling pipeline() with that row's values.
#' @importFrom rlang exec
#' @importFrom purrr map
generate_pipelines <- function(model_spec, recipe, param_grid) {

  map(seq_len(nrow(param_grid)), function(i) {
    # Convert the row to a named list. `drop = FALSE` keeps a one-column grid
    # as a data frame, and as.list() preserves both the column names and each
    # column's own type (no coercion to a common type).
    param_set <- as.list(param_grid[i, , drop = FALSE])

    # create pipeline
    exec(pipeline, preprocessing = recipe, model_spec = model_spec,
         !!!param_set)
  })
}
# stevenpawley/tidycrossval documentation built on Oct. 3, 2019, 3:32 p.m.