R/modeltime-fit-workflowset.R

Defines functions as_modeltime_table_from_workflowset create_model_grid modeltime_fit_workflowset_parallel modeltime_fit_workflowset_sequential modeltime_fit_workflowset

Documented in create_model_grid modeltime_fit_workflowset

# MODELTIME FIT WORKFLOWSET ----

#' Fit a `workflowset` object to one or multiple time series
#'
#' A `fit()` wrapper for `workflow_set` objects: every workflow in the set is
#' trained on `data`, either one after another or in parallel, and the fitted
#' workflows are collected into a Modeltime Table.
#'
#' @param object A `workflow_set` object, as created by [workflowsets::workflow_set()].
#' @param data A `tibble` holding the data the models are fitted on.
#' @param control An object that modifies the fitting process. See [control_fit_workflowset()].
#'  Models are fitted in parallel only when `control$cores > 1` and
#'  `control$allow_par` is `TRUE`; otherwise they are fitted sequentially.
#' @param ... Not currently used.
#'
#' @seealso
#' [control_fit_workflowset()]
#'
#' @examples
#' library(tidymodels)
#' library(modeltime)
#' library(workflowsets)
#' library(tidyverse)
#' library(lubridate)
#' library(timetk)
#'
#' data_set <- m4_monthly
#'
#' # SETUP WORKFLOWSETS
#'
#' rec1 <- recipe(value ~ date + id, data_set) %>%
#'     step_mutate(date_num = as.numeric(date)) %>%
#'     step_mutate(month_lbl = lubridate::month(date, label = TRUE)) %>%
#'     step_dummy(all_nominal(), one_hot = TRUE)
#'
#' mod1 <- linear_reg() %>% set_engine("lm")
#'
#' mod2 <- prophet_reg() %>% set_engine("prophet")
#'
#' wfsets <- workflowsets::workflow_set(
#'     preproc = list(rec1 = rec1),
#'     models  = list(
#'         mod1 = mod1,
#'         mod2 = mod2
#'     ),
#'     cross   = TRUE
#' )
#'
#' # FIT WORKFLOWSETS
#' # - Returns a Modeltime Table with fitted workflowsets
#'
#' wfsets %>% modeltime_fit_workflowset(data_set)
#'
#' @return
#' A Modeltime Table with one row per workflow. The `.model_desc` column holds
#' the upper-cased values of the set's first column (the workflow ids).
#' @export
modeltime_fit_workflowset <- function(object, data, ..., control = control_fit_workflowset()) {

    if (!inherits(object, "workflow_set")) {
        rlang::abort("object argument must be a `workflow_set` object generated by workflowsets::workflow_set() function.")
    }

    # Pick the fitting backend: parallel only when more than one core is
    # requested AND parallel processing is allowed.
    fit_backend <- if (control$cores > 1 && control$allow_par) {
        modeltime_fit_workflowset_parallel
    } else {
        modeltime_fit_workflowset_sequential
    }

    fitted_models <- unname(fit_backend(object, data = data, control = control, ...))

    # Human-readable descriptions: upper-cased ids from the set's first column.
    model_descs <- stringr::str_to_upper(dplyr::pull(object, 1))

    as_modeltime_table_from_workflowset(fitted_models, .model_desc = model_descs)
}

# Fit each workflow of a workflow_set one at a time.
#
# object  : a workflow_set (one row per workflow).
# data    : data passed to parsnip::fit() for every workflow.
# control : control object; only `control$verbose` is read here.
# ...     : accepted for symmetry with the parallel path; unused.
#
# Returns an unnamed list with one element per workflow, in the row order of
# `object`: the fitted workflow, or NULL when fitting failed. Failures are
# reported with message() instead of stopping the loop.
modeltime_fit_workflowset_sequential <- function(object, data, control, ...) {

    t1 <- Sys.time()

    # One single-row tibble per workflow. as_factor() sets the levels in order
    # of first appearance, so group_split() keeps the rows in their original
    # order and the results line up with `object`.
    .models  <- object %>%
        dplyr::mutate(wflow_id = forcats::as_factor(wflow_id)) %>%
        dplyr::group_by(wflow_id) %>%
        dplyr::group_split()

    # Capture errors instead of raising them; a failed fit yields NULL.
    safe_fit <- purrr::safely(parsnip::fit, otherwise = NULL, quiet = TRUE)

    models <- .models %>%
        purrr::imap(
            .f = function(obj, id) {

                if (control$verbose) {
                    cli::cli_alert_info(cli::col_grey("Fitting Model: {id}"))
                }

                # The workflow is stored inside the second (list) column of
                # the workflow_set row.
                mod <- obj %>%
                    dplyr::pull(2) %>%
                    purrr::pluck(1, 'workflow', 1)

                ret <- safe_fit(mod, data = data)

                res <- ret %>% purrr::pluck("result")

                # Report only the condition's message. Interpolating the raw
                # condition object coerces each of its fields (message, call,
                # ...) to a separate string, which message() then concatenates.
                if (!is.null(ret$error)) {
                    err_msg <- conditionMessage(ret$error)
                    message(stringr::str_glue("Model {id} Error: {err_msg}"))
                }

                if (control$verbose) {
                    if (is.null(res)) {
                        cli::cli_alert_danger(cli::col_grey("Model Failed Fitting: {id}"))
                    } else {
                        cli::cli_alert_success(cli::col_grey("Model Successfully Fitted: {id}"))
                    }
                }

                res
            }
        )

    # PRINT TOTAL TIME
    if (control$verbose) {
        t <- Sys.time()
        message(stringr::str_glue("Total time | {round(t-t1, 3)} seconds"))
    }

    models
}

# Fit the workflows of a workflow_set in parallel with foreach.
#
# object  : a workflow_set (one row per workflow).
# data    : data passed to parsnip::fit() for every workflow.
# control : control object; reads `verbose`, `allow_par` and `packages` here
#           and is forwarded to the parallel setup/teardown helpers.
# ...     : accepted for symmetry with the sequential path; unused.
#
# Returns an unnamed list with one element per workflow, in the row order of
# `object` (`.inorder = TRUE`): the fitted workflow, or NULL when fitting
# failed. Failures are reported with message() after the loop finishes.
modeltime_fit_workflowset_parallel <- function(object, data, control, ...) {

    t1 <- Sys.time()

    # Model List: one single-row tibble per workflow, in original row order.
    .models  <- object %>%
        dplyr::mutate(wflow_id = forcats::as_factor(wflow_id)) %>%
        dplyr::group_by(wflow_id) %>%
        dplyr::group_split()

    # Parallel Detection
    is_par_setup <- foreach::getDoParWorkers() > 1

    # If parallel processing is not set up, set up parallel backend
    par_setup_info <- setup_parallel_processing(control, is_par_setup, t1)
    clusters_made  <- par_setup_info$clusters_made
    cl             <- par_setup_info$cl

    # Make sure the backend is torn down even if the foreach loop itself
    # errors (e.g. a worker cannot load `control$packages`); otherwise
    # internally created clusters would leak. On the normal path the backend
    # is finished explicitly below and this handler does nothing.
    backend_finished <- FALSE
    on.exit({
        if (!backend_finished) {
            finish_parallel_processing(control, clusters_made, cl, t1)
        }
    }, add = TRUE)

    # Setup Foreach
    `%op%` <- get_operator(allow_par = control$allow_par)

    # Setup Safe Modeling. quiet = TRUE (as in the sequential path): errors are
    # returned to the main process and reported once, below.
    safe_fit <- purrr::safely(parsnip::fit, otherwise = NULL, quiet = TRUE)

    # BEGIN LOOP
    if (control$verbose) {
        t <- Sys.time()
        message(stringr::str_glue(" Beginning Parallel Loop | {round(t-t1, 3)} seconds"))
    }

    ret <- foreach::foreach(
        this_model          = .models,
        .inorder            = TRUE,
        .packages           = control$packages,
        .verbose            = FALSE
    ) %op% {

        fit_out <- this_model %>%
            dplyr::pull(2) %>%
            purrr::pluck(1, 'workflow', 1) %>%
            safe_fit(data = data)

        # Last expression is the iteration's value (no top-level return()).
        list(res = fit_out[["result"]], err = fit_out[["error"]])
    }

    # Collect models
    models <- purrr::map(ret, function(x) x[["res"]])

    # Collect errors. Report only the condition message: interpolating the raw
    # condition object would coerce each of its fields to a separate string.
    purrr::iwalk(
        ret,
        function(x, id) {
            e <- x[["err"]]
            if (!is.null(e)) {
                err_msg <- conditionMessage(e)
                message(stringr::str_glue("Model {id} Error: {err_msg}"))
            }
        }
    )

    # Finish Parallel Backend. Close clusters if we set up internally.
    # Flag first so the on.exit handler never retries a failed teardown.
    backend_finished <- TRUE
    finish_parallel_processing(control, clusters_made, cl, t1)

    if (control$verbose) {
        t <- Sys.time()
        message(stringr::str_glue(" Total time | {round(t-t1, 3)} seconds"))
    }

    models
}

# CREATE MODEL GRID ----

#' Helper to make `parsnip` model specs from a `dials` parameter grid
#'
#' @param grid A tibble that forms a grid of parameters to adjust
#' @param f_model_spec A `parsnip` model specification function. It can be
#'  given unquoted (`boost_tree`), quoted (`"boost_tree"`), namespaced
#'  (`parsnip::boost_tree`), or as a function object.
#' @param engine_name A name of an engine to use. Gets passed to `parsnip::set_engine()`.
#' @param ... Static parameters that get passed to the f_model_spec
#' @param engine_params A `list` of additional parameters that can be passed to the
#'  engine via `parsnip::set_engine(...)`.
#'
#' @details
#'
#' This is a helper function that combines `dials` grids with
#' `parsnip` model specifications. The intent is to make it easier
#' to generate `workflowset` objects for forecast evaluations
#' with `modeltime_fit_workflowset()`.
#'
#' The process follows:
#'
#' 1. Generate a grid (hyperparameter combination)
#' 2. Use `create_model_grid()` to apply the parameter combinations to
#'    a parsnip model spec and engine.
#'
#' Each row of `grid` becomes one call to `f_model_spec`. The row's values and
#' the static `...` arguments are passed as named arguments, and the resulting
#' spec is given the engine `engine_name` with `engine_params`.
#'
#' The output contains a `.models` column that can be used as a list
#' of models inside the `workflow_set()` function.
#'
#' @return
#' `grid` with a new column named `.models`, a list of model specifications
#' (one per row of `grid`).
#'
#'
#' @seealso
#' - [dials::grid_regular()]: For making parameter grids.
#' - [workflowsets::workflow_set()]: For creating a `workflowset` from the `.models` list stored in the ".models" column.
#' - [modeltime_fit_workflowset()]: For fitting a `workflowset` to forecast data.
#'
#' @examples
#'
#' library(tidymodels)
#' library(modeltime)
#'
#' # Parameters that get optimized
#' grid_tbl <- grid_regular(
#'     learn_rate(),
#'     levels = 3
#' )
#'
#' # Generate model specs
#' grid_tbl %>%
#'     create_model_grid(
#'         f_model_spec = boost_tree,
#'         engine_name  = "xgboost",
#'         # Static boost_tree() args
#'         mode = "regression",
#'         # Static set_engine() args
#'         engine_params = list(
#'             max_depth = 5
#'         )
#'     )
#'
#' @export
create_model_grid <- function(grid, f_model_spec, engine_name, ..., engine_params = list()) {

    f_expr <- substitute(f_model_spec)

    if (is.symbol(f_expr) || is.character(f_expr)) {
        # A (quoted or unquoted) name. Resolve it from this package's
        # environment first, which is how the name was always resolved, and
        # fall back to the caller's environment so that spec functions defined
        # inside a user's own function are found as well.
        f_name     <- rlang::as_name(f_expr)
        caller_env <- parent.frame()
        lookup_env <- if (exists(f_name, envir = environment(), mode = "function")) {
            environment()
        } else {
            caller_env
        }
        spec_fn <- get(f_name, envir = lookup_env, mode = "function")
    } else {
        # Any other expression (e.g. `parsnip::boost_tree` or an anonymous
        # function) is simply evaluated in the caller's environment.
        spec_fn <- f_model_spec
    }

    # Static arguments are identical for every grid row; capture them once.
    static_args <- list(...)

    model_list <- purrr::map(seq_len(nrow(grid)), function(i) {
        # One grid row -> named list of length-1 parameter values.
        params <- append(as.list(dplyr::slice(grid, i)), static_args)
        spec   <- do.call(spec_fn, params)
        parsnip::set_engine(spec, engine = engine_name, !!! engine_params)
    })

    dplyr::bind_cols(grid, tibble::tibble(.models = model_list))
}


# HELPERS -----

# Wrap a list of fitted models in a Modeltime Table (class `mdl_time_tbl`).
#
# .l          : unnamed list of fitted models; the fitting helpers put NULL in
#               place of a model that failed to fit.
# .model_desc : descriptions, one per model (a single value is recycled).
#
# Resulting columns: `.model_id` (row id), `.model`, `.model_desc`.
as_modeltime_table_from_workflowset <- function(.l, .model_desc) {

    model_tbl <- tibble::rowid_to_column(tibble::tibble(.model = .l), var = ".model_id")

    # Guard against NULL entries (failed fits), with a message naming the caller.
    # Class and trained-state validation are currently disabled:
    # validate_model_classes(ret, accept_classes = c("model_fit", "workflow", "mdl_time_ensemble"))
    # validate_models_are_trained(ret)
    validate_models_are_not_null(model_tbl, msg_main = "Some models failed during fitting: modeltime_fit_workflowset()")

    model_tbl <- dplyr::mutate(model_tbl, .model_desc = .model_desc)

    structure(model_tbl, class = c("mdl_time_tbl", class(model_tbl)))
}

Try the modeltime package in your browser

Any scripts or data that you put into this service are public.

modeltime documentation built on Sept. 2, 2023, 5:06 p.m.