#' Bundle a preprocessing recipe and a model specification into a pipeline
#'
#' Arguments supplied through `...` are routed by name: those whose name
#' contains a double underscore (`"__"`) are passed to `update()` on the
#' recipe, all others are passed to `parsnip::set_args()` on the model
#' specification.
#'
#' @param preprocessing recipe object
#' @param model_spec parsnip model_spec object
#' @param ... named arguments used to update pipeline steps
#'
#' @return a list of class `"pipeline"` holding the (possibly updated)
#'   `recipe` and `model_spec`
#' @export
#' @importFrom rlang list2
#' @importFrom stringr str_detect
#' @importFrom parsnip set_args
pipeline <- function(preprocessing, model_spec, ...) {
  dots <- list2(...)

  # a double underscore in the argument name marks a recipe argument
  is_recipe_arg <- str_detect(names(dots), "__")
  recipe_updates <- dots[is_recipe_arg]
  model_updates <- dots[!is_recipe_arg]

  if (length(model_updates) > 0) {
    model_spec <- set_args(model_spec, !!!model_updates)
  }
  if (length(recipe_updates) > 0) {
    preprocessing <- update(preprocessing, !!!recipe_updates)
  }

  out <- list(recipe = preprocessing, model_spec = model_spec)
  class(out) <- "pipeline"
  out
}
#' Fit a pipeline
#'
#' Preps the pipeline's recipe and then fits the model specification on the
#' processed training data obtained with `juice()`.
#'
#' @param object a `pipeline` object holding `recipe` and `model_spec`
#' @param data optional training data passed to `prep()`; when `NULL` the
#'   recipe is prepped without supplying data
#' @param ... further arguments forwarded to the model `fit()` call
#'
#' @return a list of class `"pipeline"` with the prepped `recipe` and the
#'   fitted model in `model_fit`
#' @export
#' @importFrom recipes prep juice
#' @importFrom parsnip fit
#' @importFrom stats formula
fit.pipeline <- function(object, data = NULL, ...) {
  # `...` is part of the signature because the `fit` generic has it; S3
  # methods must accept every argument of their generic.
  prepped_recipe <- if (is.null(data)) {
    prep(object$recipe)
  } else {
    prep(object$recipe, data)
  }

  # The model is fitted on the baked training set, using the formula derived
  # from the prepped recipe.
  model_fit <- fit(
    object$model_spec,
    formula(prepped_recipe),
    data = juice(prepped_recipe),
    ...
  )

  structure(
    list(recipe = prepped_recipe, model_fit = model_fit),
    class = "pipeline"
  )
}
#' Predict from a fitted pipeline
#'
#' Bakes `new_data` with the pipeline's prepped recipe and predicts with the
#' fitted model.
#'
#' @param object a fitted `pipeline` holding `recipe` and `model_fit`
#' @param new_data data to preprocess and predict on
#' @param ... `type`, if supplied, is passed as the `type` argument of the
#'   model's predict method; all remaining arguments are collected into a
#'   list and passed as `opts`
#'
#' @return the result of `predict()` on the fitted model
#' @export
#' @importFrom rlang list2
#' @importFrom recipes bake
#' @importFrom stats predict
predict.pipeline <- function(object, new_data, ...) {
  args <- list2(...)

  # `type` is an argument of the model predict method itself, not an engine
  # option; left inside `opts` it would not select the prediction type.
  # `[[` is used instead of `$` to avoid partial name matching.
  type <- args[["type"]]
  args[["type"]] <- NULL

  data <- bake(object$recipe, new_data = new_data)
  predict(object$model_fit, new_data = data, type = type, opts = args)
}
#' Build one pipeline per row of a parameter grid
#'
#' Each row of `param_grid` is turned into a named list (names taken from the
#' column names) and spliced into a call to `pipeline()`.
#'
#' @param model_spec parsnip model_spec object
#' @param recipe recipe object
#' @param param_grid data frame with one column per argument and one row per
#'   parameter combination
#'
#' @return a list with one pipeline per row of `param_grid`
#' @noRd
#' @importFrom rlang exec
#' @importFrom purrr map
generate_pipelines <- function(model_spec, recipe, param_grid) {
  map(seq_len(nrow(param_grid)), function(i) {
    # `drop = FALSE` keeps a one-column grid as a data frame instead of
    # collapsing it to a vector; `as.list()` then gives a named list that
    # preserves each column's type (no coercion to a common atomic type).
    args <- as.list(param_grid[i, , drop = FALSE])
    exec(pipeline, preprocessing = recipe, model_spec = model_spec, !!!args)
  })
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.