Nothing
# MODELTIME FIT WORKFLOWSET ----
#' Fit a `workflowset` object to one or multiple time series
#'
#' This is a wrapper for `fit()` that takes a
#' `workflowset` object and fits each model on one or multiple
#' time series either sequentially or in parallel.
#'
#' @param object A workflow_set object, generated with the workflowsets::workflow_set function.
#' @param data A `tibble` that contains data to fit the models.
#' @param control An object used to modify the fitting process. See [control_fit_workflowset()].
#' @param ... Not currently used.
#'
#' @seealso
#' [control_fit_workflowset()]
#'
#' @examples
#' library(tidymodels)
#' library(workflowsets)
#' library(dplyr)
#' library(lubridate)
#' library(timetk)
#'
#' data_set <- m4_monthly
#'
#' # SETUP WORKFLOWSETS
#'
#' rec1 <- recipe(value ~ date + id, data_set) %>%
#' step_mutate(date_num = as.numeric(date)) %>%
#' step_mutate(month_lbl = lubridate::month(date, label = TRUE)) %>%
#' step_dummy(all_nominal(), one_hot = TRUE)
#'
#' mod1 <- linear_reg() %>% set_engine("lm")
#'
#' mod2 <- prophet_reg() %>% set_engine("prophet")
#'
#' wfsets <- workflowsets::workflow_set(
#' preproc = list(rec1 = rec1),
#' models = list(
#' mod1 = mod1,
#' mod2 = mod2
#' ),
#' cross = TRUE
#' )
#'
#' # FIT WORKFLOWSETS
#' # - Returns a Modeltime Table with fitted workflowsets
#'
#' wfsets %>% modeltime_fit_workflowset(data_set)
#'
#' @return
#' A Modeltime Table containing one or more fitted models.
#' @export
modeltime_fit_workflowset <- function(object, data, ..., control = control_fit_workflowset()) {

  # Guard clause: only `workflow_set` objects are accepted.
  is_workflow_set <- inherits(object, "workflow_set")
  if (!is_workflow_set) {
    cli::cli_abort("object argument must be a `workflow_set` object generated by the {.fn workflowsets::workflow_set} function.")
  }

  # The parallel path is taken only when more than one core is requested
  # AND parallel processing is allowed by the control object.
  use_parallel <- (control$cores > 1) && control$allow_par

  fitted_models <- if (use_parallel) {
    modeltime_fit_workflowset_parallel(object, data = data, control = control, ...)
  } else {
    modeltime_fit_workflowset_sequential(object, data = data, control = control, ...)
  }

  # Drop any names so the models are matched to descriptions purely by position.
  fitted_models <- unname(fitted_models)

  # Model descriptions are the upper-cased values of the first column of
  # the workflow set.
  model_labels <- toupper(dplyr::pull(object, 1))

  as_modeltime_table_from_workflowset(fitted_models, .model_desc = model_labels)
}
modeltime_fit_workflowset_sequential <- function(object, data, control, ...) {

  # Fits each workflow of `object` to `data`, one after another.
  #
  # Arguments:
  #   object  - the workflow set; one group per `wflow_id` is fitted.
  #   data    - data passed to `parsnip::fit()` for every workflow.
  #   control - control object; only `control$verbose` is read here.
  #   ...     - accepted for signature parity with the parallel variant; unused.
  #
  # Returns an unnamed list with one element per workflow: the fitted
  # object, or NULL when fitting raised an error (the error message is
  # reported via `message()`).

  start_time <- Sys.time()

  # One single-group tibble per workflow id. `as_factor()` is used
  # (rather than sorting the ids) presumably so the groups stay in the row
  # order of `object` -- the caller pairs results with ids by position.
  .models <- object %>%
    dplyr::mutate(wflow_id = forcats::as_factor(wflow_id)) %>%
    dplyr::group_by(wflow_id) %>%
    dplyr::group_split()

  # Errors are captured instead of aborting the whole loop.
  safe_fit <- purrr::safely(parsnip::fit, otherwise = NULL, quiet = TRUE)

  models <- .models %>%
    purrr::imap(
      .f = function(obj, id) {

        if (control$verbose) {
          cli::cli_alert_info(cli::col_grey("Fitting Model: {id}"))
        }

        # Extract the workflow stored in the second column of the group.
        mod <- obj %>%
          dplyr::pull(2) %>%
          purrr::pluck(1, "workflow", 1)

        ret <- safe_fit(mod, data = data)
        res <- ret %>% purrr::pluck("result")

        # Report only the condition's message. Interpolating the condition
        # object itself makes glue coerce every field of the condition,
        # which duplicates and garbles the output line.
        if (!is.null(ret$error)) {
          message(stringr::str_glue("Model {id} Error: {conditionMessage(ret$error)}"))
        }

        if (control$verbose) {
          if (is.null(res)) {
            cli::cli_alert_danger(cli::col_grey("Model Failed Fitting: {id}"))
          } else {
            cli::cli_alert_success(cli::col_grey("Model Successfully Fitted: {id}"))
          }
        }

        return(res)
      }
    )

  # PRINT TOTAL TIME
  if (control$verbose) {
    # Force seconds: a bare difftime switches to minutes/hours for long
    # runs while the message text always says "seconds".
    elapsed <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))
    message(stringr::str_glue("Total time | {round(elapsed, 3)} seconds"))
  }

  return(models)
}
modeltime_fit_workflowset_parallel <- function(object, data, control, ...) {

  # Fits each workflow of `object` to `data` inside a foreach loop.
  #
  # Arguments:
  #   object  - the workflow set; one group per `wflow_id` is fitted.
  #   data    - data passed to `parsnip::fit()` for every workflow.
  #   control - control object; `allow_par`, `packages` and `verbose` are
  #             read here and the whole object is handed to the parallel
  #             setup / teardown helpers.
  #   ...     - accepted for signature parity with the sequential variant; unused.
  #
  # Returns a list with one element per workflow: the fitted object, or
  # NULL when fitting raised an error (the error message is reported via
  # `message()` after the loop).

  start_time <- Sys.time()

  # Seconds elapsed since `start_time`. Units are forced because a bare
  # difftime switches to minutes/hours for long runs while the messages
  # below always say "seconds".
  elapsed_secs <- function() {
    round(as.numeric(difftime(Sys.time(), start_time, units = "secs")), 3)
  }

  # Model List: one single-group tibble per workflow id.
  .models <- object %>%
    dplyr::mutate(wflow_id = forcats::as_factor(wflow_id)) %>%
    dplyr::group_by(wflow_id) %>%
    dplyr::group_split()

  # Parallel Detection
  is_par_setup <- foreach::getDoParWorkers() > 1

  # If parallel processing is not set up, set up parallel backend
  par_setup_info <- setup_parallel_processing(control, is_par_setup, start_time)
  clusters_made <- par_setup_info$clusters_made
  cl <- par_setup_info$cl

  # Make sure the backend is also finished when the loop below aborts
  # (e.g. a worker fails to load a package); otherwise clusters created
  # by the setup step would be left open. On the normal path the flag is
  # set before the explicit call further down, so teardown runs only once.
  backend_finished <- FALSE
  on.exit(
    if (!backend_finished) {
      finish_parallel_processing(control, clusters_made, cl, start_time)
    },
    add = TRUE
  )

  # Setup Foreach
  `%op%` <- get_operator(allow_par = control$allow_par)

  # Setup Safe Modeling: errors are captured instead of aborting the loop.
  safe_fit <- purrr::safely(parsnip::fit, otherwise = NULL, quiet = FALSE)

  # BEGIN LOOP
  if (control$verbose) {
    message(stringr::str_glue(" Beginning Parallel Loop | {elapsed_secs()} seconds"))
  }

  ret <- foreach::foreach(
    this_model = .models,
    .inorder = TRUE,
    .packages = control$packages,
    .verbose = FALSE
  ) %op% {

    # Extract the workflow stored in the second column of the group and fit it.
    mod <- this_model %>%
      dplyr::pull(2) %>%
      purrr::pluck(1, "workflow", 1) %>%
      safe_fit(data)

    res <- mod %>%
      purrr::pluck("result")

    err <- mod %>%
      purrr::pluck("error")

    return(list(res = res, err = err))
  }

  # Collect models
  models <- ret %>% purrr::map(~ .x$res)

  # Collect errors
  error_messages <- ret %>% purrr::map(~ .x$err)

  # Report only each condition's message. Interpolating the condition
  # object itself makes glue coerce every field of the condition, which
  # duplicates and garbles the output line.
  purrr::iwalk(
    error_messages,
    function(e, id) {
      if (!is.null(e)) {
        message(stringr::str_glue("Model {id} Error: {conditionMessage(e)}"))
      }
    }
  )

  # Finish Parallel Backend. Close clusters if we set up internally.
  backend_finished <- TRUE
  finish_parallel_processing(control, clusters_made, cl, start_time)

  if (control$verbose) {
    message(stringr::str_glue(" Total time | {elapsed_secs()} seconds"))
  }

  return(models)
}
# CREATE MODEL GRID ----
#' Helper to make `parsnip` model specs from a `dials` parameter grid
#'
#' @param grid A tibble that forms a grid of parameters to adjust
#' @param f_model_spec A function name (quoted or unquoted) that
#' specifies a `parsnip` model specification function
#' @param engine_name A name of an engine to use. Gets passed to `parsnip::set_engine()`.
#' @param ... Static parameters that get passed to the f_model_spec
#' @param engine_params A `list` of additional parameters that can be passed to the
#' engine via `parsnip::set_engine(...)`.
#'
#' @details
#'
#' This is a helper function that combines `dials` grids with
#' `parsnip` model specifications. The intent is to make it easier
#' to generate `workflowset` objects for forecast evaluations
#' with `modeltime_fit_workflowset()`.
#'
#' The process follows:
#'
#' 1. Generate a grid (hyperparameter combinations)
#' 2. Use `create_model_grid()` to apply the parameter combinations to
#' a parsnip model spec and engine.
#'
#' The output contains a ".models" column that can be used as a list
#' of models inside the `workflow_set()` function.
#'
#' @return
#' Tibble with a new column named `.models`
#'
#'
#' @seealso
#' - [dials::grid_regular()]: For making parameter grids.
#' - [workflowsets::workflow_set()]: For creating a `workflowset` from the `.models` list stored in the ".models" column.
#' - [modeltime_fit_workflowset()]: For fitting a `workflowset` to forecast data.
#'
#' @examples
#'
#' library(tidymodels)
#'
#' # Parameters that get optimized
#' grid_tbl <- grid_regular(
#' learn_rate(),
#' levels = 3
#' )
#'
#' # Generate model specs
#' grid_tbl %>%
#' create_model_grid(
#' f_model_spec = boost_tree,
#' engine_name = "xgboost",
#' # Static boost_tree() args
#' mode = "regression",
#' # Static set_engine() args
#' engine_params = list(
#' max_depth = 5
#' )
#' )
#'
#' @export
create_model_grid <- function(grid, f_model_spec, engine_name, ..., engine_params = list()) {

  # The model spec function is captured unevaluated and reduced to its
  # name, so both `boost_tree` and "boost_tree" are accepted.
  spec_fn_name <- rlang::as_name(substitute(f_model_spec))

  # Pass 1: one model spec per grid row. The row's parameter values are
  # combined with the static arguments supplied through `...`.
  row_ids <- seq_len(nrow(grid))
  bare_specs <- purrr::map(row_ids, function(i) {
    row_args <- as.list(dplyr::slice(grid, i))
    do.call(spec_fn_name, c(row_args, list(...)))
  })

  # Pass 2: attach the engine (plus any extra engine arguments) to each spec.
  specs_with_engine <- purrr::map(bare_specs, function(spec) {
    parsnip::set_engine(spec, engine = engine_name, !!!engine_params)
  })

  # Return the original grid with the specs appended as a list-column.
  dplyr::bind_cols(grid, tibble::tibble(.models = specs_with_engine))
}
# HELPERS -----
as_modeltime_table_from_workflowset <- function(.l, .model_desc) {

  # Wrap the list of models in a tibble and number the rows in `.model_id`.
  model_tbl <- tibble::rowid_to_column(
    tibble::tibble(.model = .l),
    var = ".model_id"
  )

  # CHECKS
  # Other validations (model classes, trained status) are intentionally
  # not run here; only the NULL-model check is applied.
  validate_models_are_not_null(model_tbl, msg_main = "Some models failed during fitting: modeltime_fit_workflowset()")

  # CREATE MODELTIME OBJECT
  # Attach the descriptions, then prepend the modeltime table class.
  model_tbl <- dplyr::mutate(model_tbl, .model_desc = .model_desc)

  structure(model_tbl, class = c("mdl_time_tbl", class(model_tbl)))
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.