R/modeltime-refit.R

Defines functions mdl_time_refit.recursive_panel mdl_time_refit.recursive mdl_time_refit.model_fit mdl_time_refit.workflow mdl_time_refit.default mdl_time_refit modeltime_refit_parallel modeltime_refit_sequential modeltime_refit.mdl_time_tbl modeltime_refit

Documented in mdl_time_refit modeltime_refit

# MODELTIME REFIT ----

#' Refit one or more trained models to new data
#'
#' This is a wrapper for `fit()` that takes a
#' Modeltime Table and retrains each model on _new data_ re-using the parameters
#' and preprocessing steps used during the training process.
#'
#' @param object A Modeltime Table
#' @param data A `tibble` that contains data to retrain the model(s) using.
#' @param ... Additional arguments to control refitting.
#'
#'  __Ensemble Model Spec (`modeltime.ensemble`):__
#'
#'    When making a meta-learner with `modeltime.ensemble::ensemble_model_spec()`,
#'    use `...` to pass the `resamples` argument containing results
#'    from `modeltime.resample::modeltime_fit_resamples()`.
#'
#' @param control Used to control verbosity and parallel processing.
#'  See [control_refit()].
#'
#' @return
#' A Modeltime Table containing one or more re-trained models.
#'
#' @details
#'
#' Refitting is an important step prior to forecasting time series models.
#' The `modeltime_refit()` function makes it easy to recycle models,
#' retraining on new data.
#'
#' __Recycling Parameters__
#'
#' Parameters are recycled during retraining using the following criteria:
#'
#' - __Automated models__ (e.g. "auto arima") will have parameters recalculated.
#' - __Non-automated models__ (e.g. "arima") will have parameters preserved.
#' - All preprocessing steps will be reused on the data
#'
#' __Refit__
#'
#' The `modeltime_refit()` function is used to retrain models trained with `fit()`.
#'
#' __Refit XY__
#'
#' The XY format is not supported at this time.
#'
#' @seealso
#' [control_refit()]
#'
#'
#' @examples
#' library(tidyverse)
#' library(lubridate)
#' library(timetk)
#' library(parsnip)
#' library(rsample)
#' library(modeltime)
#'
#' # Data
#' m750 <- m4_monthly %>% filter(id == "M750")
#'
#' # Split Data 90/10
#' splits <- initial_time_split(m750, prop = 0.9)
#'
#' # --- MODELS ---
#'
#' model_fit_prophet <- prophet_reg() %>%
#'     set_engine(engine = "prophet") %>%
#'     fit(value ~ date, data = training(splits))
#'
#'
#' # ---- MODELTIME TABLE ----
#'
#' models_tbl <- modeltime_table(
#'     model_fit_prophet
#' )
#'
#' # ---- CALIBRATE ----
#' # - Calibrate on testing data set
#'
#' calibration_tbl <- models_tbl %>%
#'     modeltime_calibrate(new_data = testing(splits))
#'
#'
#' # ---- REFIT ----
#' # - Refit on full data set
#'
#' refit_tbl <- calibration_tbl %>%
#'     modeltime_refit(m750)
#'
#'
#' @name modeltime_refit
#' @importFrom parsnip fit fit_xy
NULL

#' @export
#' @rdname modeltime_refit
modeltime_refit <- function(object, data, ..., control = control_refit()) {
    # S3 generic: dispatch on the class of `object`. The method defined in this
    # file handles Modeltime Tables (class `mdl_time_tbl`).
    UseMethod("modeltime_refit", object)
}

#' @export
modeltime_refit.mdl_time_tbl <- function(object, data, ..., control = control_refit()) {

    # `object` is the Modeltime Table; `data` is the new data to retrain on
    refit_data <- data

    # Backwards compatibility: a NULL control falls back to the defaults
    if (is.null(control)) control <- control_refit()

    # Snapshot the descriptions before refitting so they can be compared with
    # the descriptions of the refitted models afterwards
    desc_user <- object$.model_desc
    desc_old  <- purrr::map_chr(object$.model, get_model_description)

    # Parallel only when more than one core is requested AND parallel is allowed
    refit_fun <- if ((control$cores > 1) && control$allow_par) {
        modeltime_refit_parallel
    } else {
        modeltime_refit_sequential
    }
    ret <- refit_fun(object, data = refit_data, control = control, ...)

    validate_models_are_not_null(ret, msg_main = "Some models failed during fitting: modeltime_refit()")

    # Description logic:
    # - description unchanged by the refit -> keep the user's description
    # - description changed (model parameters were updated) -> show the new
    #   description, prefixed with "UPDATE: "
    # The three helper columns are temporary and dropped at the end.
    ret %>%
        dplyr::mutate(
            .model_desc_user = desc_user,
            .model_desc_old  = desc_old,
            .model_desc_new  = purrr::map_chr(.model, .f = get_model_description),
            .model_desc      = ifelse(
                .model_desc_old == .model_desc_new,
                .model_desc_user,
                paste0("UPDATE: ", .model_desc_new)
            )
        ) %>%
        dplyr::select(-.model_desc_user, -.model_desc_old, -.model_desc_new)

}

# Refit every model of a Modeltime Table one after another.
#
# - object:  a Modeltime Table (columns `.model` and `.model_id` are used)
# - data:    new data each model is refitted on
# - ...:     forwarded to `mdl_time_refit()`
# - control: control object; `control$verbose` toggles progress messages
#
# Returns the ungrouped table with `.model` replaced by the refitted models.
# A model whose refit failed is stored as NULL and its error is reported
# through `message()`.
modeltime_refit_sequential <- function(object, data, ..., control) {

    t1 <- Sys.time()

    new_data <- data

    # Errors are captured (not raised) so one failing model does not stop the rest
    safe_modeltime_refit <- purrr::safely(mdl_time_refit, otherwise = NULL, quiet = TRUE)

    # Refit a single model and report progress.
    # Note: `id` and `ret` are referenced by name inside the glue/cli strings.
    refit_one <- function(obj, id) {

        if (control$verbose) {
            cli::cli_alert_info(cli::col_grey("Refitting Model: Model ID {id}"))
        }

        ret <- safe_modeltime_refit(obj, data = new_data, control = control, ...)
        res <- purrr::pluck(ret, "result")

        # The error is always reported, regardless of verbosity
        if (!is.null(ret$error)) message(stringr::str_glue("Model {id} Error: {ret$error}"))

        if (control$verbose) {
            if (!is.null(res)) {
                cli::cli_alert_success(cli::col_grey("Model Successfully Refitted: Model ID {id}"))
            } else {
                cli::cli_alert_danger(cli::col_grey("Model Failed Refitting: Model ID {id}"))
            }
        }

        res
    }

    ret <- object %>%
        dplyr::ungroup() %>%
        dplyr::mutate(.model = purrr::map2(.model, .model_id, refit_one))

    # PRINT TOTAL TIME
    if (control$verbose) {
        t <- Sys.time()
        message(stringr::str_glue("Total time | {round(t-t1, 3)} seconds"))
    }

    ret

}

# Refit every model of a Modeltime Table inside a `foreach` loop.
#
# - object:  a Modeltime Table (columns `.model` and `.model_id` are used)
# - data:    new data each model is refitted on
# - ...:     forwarded to `mdl_time_refit()`, as in the sequential version
# - control: control object; `verbose`, `allow_par` and `packages` are read here
#
# Returns the ungrouped table with `.model` replaced by the refitted models.
# A model whose refit failed is stored as NULL and its error is reported
# through `message()`, labelled with that model's `.model_id`.
modeltime_refit_parallel <- function(object, data, ..., control) {

    t1 <- Sys.time()

    new_data <- data
    data     <- object # object is a Modeltime Table

    # Capture `...` as a plain list so it can be referenced by name inside the
    # foreach body and forwarded to each refit call
    dots <- list(...)

    # Parallel Detection
    is_par_setup <- foreach::getDoParWorkers() > 1

    # If parallel processing is not set up, set up parallel backend
    par_setup_info <- setup_parallel_processing(control, is_par_setup, t1)
    clusters_made  <- par_setup_info$clusters_made
    cl             <- par_setup_info$cl
    t              <- par_setup_info$t

    # Setup Foreach
    `%op%` <- get_operator(allow_par = control$allow_par)

    # Safely refit: errors are captured and returned rather than raised
    safe_modeltime_refit <- purrr::safely(mdl_time_refit, otherwise = NULL, quiet = FALSE)

    ret <- data %>% dplyr::ungroup()

    # BEGIN LOOP
    if (control$verbose) {
        t <- Sys.time()
        message(stringr::str_glue(" Beginning Parallel Loop | {round(t-t1, 3)} seconds"))
    }

    mod_list <- foreach::foreach(
            id                  = ret$.model_id,
            .inorder            = TRUE,
            .packages           = control$packages,
            .verbose            = FALSE
        ) %op% {

            model <- ret %>%
                dplyr::filter(.model_id == id) %>%
                dplyr::select(.model) %>%
                dplyr::pull()

            # Forward the captured `...` so that extra arguments reach the
            # refit method exactly as they do in the sequential path
            mod <- do.call(
                safe_modeltime_refit,
                c(list(model[[1]], new_data, control = control), dots)
            )

            res <- mod %>%
                purrr::pluck("result")

            err <- mod %>%
                purrr::pluck("error")

            return(list(res = res, err = err))

        }

    # Collect models
    models <- mod_list %>% purrr::map(~ .x$res)

    # Collect errors. Pair each error with its actual `.model_id`; the position
    # in `mod_list` only equals the model id when ids happen to be 1..n.
    error_messages <- mod_list %>% purrr::map(~ .x$err)
    purrr::walk2(
        error_messages,
        ret$.model_id,
        function (e, id) {
            if (!is.null(e)) message(stringr::str_glue("Model {id} Error: {e}"))
        }
    )

    # Recombine models with modeltime table
    ret <- ret %>%
        dplyr::mutate(.model = models)


    # Finish Parallel Backend. Close clusters if we set up internally.
    finish_parallel_processing(control, clusters_made, cl, t1)

    # PRINT TOTAL TIME
    if (control$verbose) {
        t <- Sys.time()
        message(stringr::str_glue(" Total time | {round(t-t1, 3)} seconds"))
    }

    return(ret)

}





# #' @export
# modeltime_refit_xy.mdl_time_tbl <- function(object, x, y, ..., control = NULL) {
#     rlang::abort("Only models and workflows trained using `fit()` are supported at this time.")
# }


# REFIT ----

#' Modeltime Refit Helpers
#'
#' Used for low-level refitting of modeltime, parsnip and workflow models.
#' These functions are not intended for user use.
#'
#' @inheritParams modeltime_refit
#'
#' @return A refitted model object. The `workflow` and `model_fit` methods
#'  return the result of calling `fit()` on `data`.
#'
#' @keywords internal
#'
#' @export
mdl_time_refit <- function(object, data, ..., control = NULL) {
    # S3 generic: dispatch on the class of `object`
    UseMethod("mdl_time_refit", object)
}

#' @export
mdl_time_refit.default <- function(object, data, ..., control = NULL) {
    # Fallback method: reached when no class-specific refit method applies to
    # `object`. The message reports the first class of `object`.
    # NOTE(review): presumably `glubort()` interpolates the message and aborts;
    # confirm against its definition. The trailing ". ." in the message looks
    # like a leftover typo; check that nothing matches on it before cleaning up.
    glubort("No method for an object of class: {class(object)[1]}. .")
}

#' @export
mdl_time_refit.workflow <- function(object, data, ..., control = NULL) {

    # Refit the workflow on the new data. `...` and `control` are accepted for
    # signature compatibility with the generic but are not used here.
    object %>% fit(data)

}

#' @export
mdl_time_refit.model_fit <- function(object, data, ...,  control = NULL) {

    # Preprocessor recovered from the fitted model; it is passed to
    # `parsnip::fit()` in the formula position
    form <- pull_parsnip_preprocessor(object)

    # Refit the stored model specification on the new data. `...` and `control`
    # are accepted for signature compatibility but are not used here.
    object$spec %>%
        parsnip::fit(form, data = data)

}

#' @export
mdl_time_refit.recursive <- function(object, data, ..., control = NULL) {

    # ---- Parsnip model fit ----
    if (inherits(object, "model_fit")) {

        # Replace the stored train tail with the last rows of the new data,
        # keeping the same number of rows as the old tail
        tail_size <- nrow(object$spec$train_tail)
        object$spec$train_tail <- dplyr::slice_tail(data, n = tail_size)

        refitted <- mdl_time_refit.model_fit(object, data, ..., control = control)

        # Re-insert the "recursive" class between the first two classes of the fit
        fit_class       <- class(refitted)
        class(refitted) <- c(fit_class[1], "recursive", fit_class[2])

        return(refitted)
    }

    # ---- Otherwise: the recursive spec is read from `object$fit$fit$spec` ----

    # Read everything needed from the old spec BEFORE refitting
    old_spec    <- object$fit$fit$spec
    transformer <- old_spec$transform
    tail_size   <- nrow(old_spec$train_tail)
    chunk_size  <- old_spec$chunk_size

    # New train tail: last rows of the new data, same size as the old tail
    new_tail <- dplyr::slice_tail(data, n = tail_size)

    refitted <- mdl_time_refit.workflow(object, data, ..., control = control)

    # Make the refitted workflow recursive again
    refitted <- recursive(
        refitted,
        transform  = transformer,
        train_tail = new_tail,
        chunk_size = chunk_size
    )

    # Need to overwrite transformer with the original one
    refitted$fit$fit$spec$transform <- transformer

    refitted

}


#' @export
mdl_time_refit.recursive_panel <- function(object, data, ..., control = NULL) {

    # Median number of train-tail rows per value of the id column `id_col`
    median_rows_per_id <- function(tail_tbl, id_col) {
        tail_tbl %>%
            dplyr::count(!! rlang::sym(id_col)) %>%
            dplyr::pull(n) %>%
            stats::median(na.rm = TRUE)
    }

    # ---- Parsnip model fit ----
    if (inherits(object, "model_fit")) {

        n <- median_rows_per_id(object$spec$train_tail, object$spec$id)

        # Swap in a train tail built from the new data
        object$spec$train_tail <- data %>%
            panel_tail(
                id = !! object$spec$id,
                n  = n
            )

        refitted <- mdl_time_refit.model_fit(object, data, ..., control = control)

        # Re-insert the "recursive_panel" class between the first two classes
        fit_class       <- class(refitted)
        class(refitted) <- c(fit_class[1], "recursive_panel", fit_class[2])

        return(refitted)
    }

    # ---- Otherwise: the recursive spec is read from `object$fit$fit$spec` ----

    # Read everything needed from the old spec BEFORE refitting
    old_spec       <- object$fit$fit$spec
    transformer    <- old_spec$transform
    id_old         <- old_spec$id
    chunk_size_old <- old_spec$chunk_size

    n <- median_rows_per_id(old_spec$train_tail, id_old)

    # Train tail built from the new data
    train_tail_new <- data %>%
        panel_tail(
            id = !! id_old,
            n  = n
        )

    refitted <- mdl_time_refit.workflow(object, data, ..., control = control)

    # Make the refitted workflow recursive again
    refitted <- recursive(
        refitted,
        transform  = transformer,
        train_tail = train_tail_new,
        id         = id_old,
        chunk_size = chunk_size_old
    )

    # Need to overwrite transformer with the original one
    refitted$fit$fit$spec$transform <- transformer

    refitted

}



# # REFIT XY ----
#
# #' @export
# mdl_time_refit_xy.workflow <- function(object, x, y, control = NULL, ...) {
#     rlang::abort("Using 'mdl_time_refit_xy()' on a workflow object is not allowed. Try using 'modeltime_refit()'.")
# }
#
# #' @export
# mdl_time_refit_xy.model_fit <- function(object, x, y, control = NULL, ...) {
#
#     if (is.null(control)) {
#         control <- parsnip::control_parsnip()
#     }
#
#     model_spec <- object$spec
#
#     ret <- model_spec %>%
#         fit_xy(x = x, y = y, control = control, ...)
#
#     return(ret)
#
# }
#
# #' @export
# mdl_time_refit_xy.default <- function(object, x, y, control = NULL, ...) {
#     rlang::abort(paste0("No method for class '", class(object)[1], "'."))
# }

Try the modeltime package in your browser

Any scripts or data that you put into this service are public.

modeltime documentation built on Sept. 2, 2023, 5:06 p.m.