# R/modeltime-forecast.R

# MODELTIME FORECAST ----

#' Forecast future data
#'
#' The goal of `modeltime_forecast()` is to simplify the process of
#' forecasting future data.
#'
#' @param object A Modeltime Table
#' @param new_data A `tibble` containing future information to forecast.
#'  If `NULL`, forecasts the calibration data.
#' @param h The forecast horizon (can be used instead of `new_data` for
#'  time series with no exogenous regressors).
#'  Extends the calibration data `h` periods into the future.
#' @param actual_data Reference data that is combined with the output tibble and given a `.key = "actual"`
#' @param conf_interval An estimated confidence interval based on the calibration data.
#'  This is designed to estimate future confidence from _out-of-sample prediction error_.
#' @param conf_by_id Whether or not to produce confidence interval estimates by an ID feature.
#'
#'  - When `FALSE`, a global model confidence interval is provided.
#'
#'  - If `TRUE`, a local confidence interval is provided group-wise for each time series ID.
#'    To enable local confidence intervals, an `id` must be provided during `modeltime_calibrate()`.
#'
#' @param conf_method Algorithm used to produce confidence intervals. All confidence intervals are conformal predictions. Choose one of:
#'
#'  - `conformal_default`: Uses `qnorm()` to compute quantiles from out-of-sample (test set) residuals.
#'
#'  - `conformal_split`: Uses the split conformal inference method described by Lei _et al_ (2018)
#'
#' @param keep_data Whether or not to keep the `new_data` and `actual_data` as extra columns in the results.
#'  This can be useful if there is an important feature in the `new_data` and `actual_data` needed
#'  when forecasting.
#'  Default: `FALSE`.
#' @param arrange_index Whether or not to sort the index in rowwise chronological order (oldest to newest) or to
#'  keep the original order of the data.
#'  Default: `FALSE`.
#' @param ... Not currently used
#'
#'
#' @return
#' A tibble with predictions and time-stamp data. For ease of plotting and calculations,
#'  the column names are transformed to:
#'
#' - `.key`: Values labeled either "prediction" or "actual"
#' - `.index`: The timestamp index.
#' - `.value`: The value being forecasted.
#'
#' Additionally, if the Modeltime Table has been previously calibrated using [modeltime_calibrate()],
#' you will gain confidence intervals.
#' - `.conf_lo`: The lower limit of the confidence interval.
#' - `.conf_hi`: The upper limit of the confidence interval.
#'
#' Additional descriptive columns are included:
#' - `.model_id`: Model ID from the Modeltime Table
#' - `.model_desc`: Model Description from the Modeltime Table
#'
#' Unnecessary columns are _dropped_ to save space:
#' - `.model`
#' - `.calibration_data`
#'
#' @details
#'
#' The `modeltime_forecast()` function prepares a forecast for visualization with
#' [plot_modeltime_forecast()]. The forecast is controlled by `new_data` or `h`,
#' which can be combined with existing data (controlled by `actual_data`).
#' Confidence intervals are included if the incoming Modeltime Table has been
#' calibrated using [modeltime_calibrate()].
#' Otherwise confidence intervals are not estimated.
#'
#' __New Data__
#'
#' When forecasting, you can specify future data using `new_data`.
#' This is a future tibble with a date column and, if used during training,
#' columns for any exogenous regressors (xregs), extending beyond the trained dates.
#'
#'    - __Forecasting Evaluation Data__: If `new_data` is not provided, the
#'      `.calibration_data` is used by default.
#'      This is the equivalent of using `rsample::testing()` for getting test data sets.
#'    - __Forecasting Future Data__: See `timetk::future_frame()` for creating future tibbles.
#'    - __Xregs__: Can be used with this method
#'
#'
#' __H (Horizon)__
#'
#' When forecasting, you can specify `h`. This is a phrase like "1 year",
#' which extends the `.calibration_data` (1st priority) or the `actual_data` (2nd priority)
#' into the future.
#'    - __Forecasting Future Data__: All forecasts using `h` are
#'      ___extended after the calibration data or actual_data___.
#'
#'    - Extending `.calibration_data` - Calibration data is given 1st priority, which is
#'      desirable _after refitting_ with [modeltime_refit()].
#'      Internally, a call is made to `timetk::future_frame()` to
#'      expedite creating new data using the date feature.
#'    - Extending `actual_data` - If `h` is provided, and the modeltime table has not been
#'      calibrated, the "actual_data" will be extended into the future. This is useful
#'      in situations where you want to go directly from `modeltime_table()` to `modeltime_forecast()`
#'      without calibrating or refitting.
#'    - __Xregs__: Cannot be used because future data must include new xregs.
#'      If xregs are desired, build a future data frame and use `new_data`.
#'
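#' As a sketch, extending monthly data such as `m750` one year into the future is
#' roughly equivalent to the following internal call (shown for illustration only):
#'
#' ```r
#' timetk::future_frame(m750, .length_out = "1 year")
#' ```
#'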
#' __Actual Data__
#'
#' This is reference data that contains the true values of the time-stamp data.
#' It helps in visualizing the performance of the forecast vs the actual data.
#'
#' When `h` is used and the Modeltime Table has _not been calibrated_, then the
#' actual data is extended into the future periods that are defined by `h`.
#'
#' __Confidence Interval Estimation__
#'
#' Confidence intervals (`.conf_lo`, `.conf_hi`) are estimated based on the normal estimation of
#' the testing errors (out of sample) from [modeltime_calibrate()].
#' The out-of-sample error estimates are then carried through and
#' applied to any future forecasts.
#'
#' The confidence interval can be adjusted with the `conf_interval` parameter. The algorithm used
#' to produce confidence intervals can be changed with the `conf_method` parameter.
#'
#' _Conformal Default Method:_
#'
#' When `conf_method = "conformal_default"` (default), this method uses `qnorm()`
#' to produce a 95% confidence interval by default. It estimates a normal (Gaussian distribution)
#' based on the out-of-sample errors (residuals).
#'
#' The confidence interval is _mean-adjusted_, meaning that if the mean of the residuals
#' is non-zero, the confidence interval is adjusted to widen the interval to capture
#' the difference in means.
#'
#' _Conformal Split Method:_
#'
#' When `conf_method = "conformal_split"`, this method uses the split conformal inference method
#' described by Lei _et al_ (2018). This is also implemented in the `probably` R package's
#' `int_conformal_split()` function.
#'
#' _What happens to the confidence interval after refitting models?_
#'
#' Refitting has no effect on the confidence interval, since it is calculated independently of
#' the refitted model. New observations typically improve
#' future accuracy, which in most cases makes the out-of-sample confidence intervals conservative.
#'
#' __Keep Data__
#'
#' Include the new data (and actual data) as extra columns with the results of the model forecasts.
#' This can be helpful when the new data includes information useful to the forecasts.
#' An example is when forecasting _Panel Data_ and the new data contains
#' ID features related to the time series group that the forecast belongs to.
#'
#' __Arrange Index__
#'
#' By default, `modeltime_forecast()` keeps the original order of the data.
#' If desired, the user can sort the output by `.key`, `.model_id` and `.index`.
#'
#' @references
#' Lei, Jing, et al. "Distribution-free predictive inference for regression."
#' _Journal of the American Statistical Association_ 113.523 (2018): 1094-1111.
#'
#' @examples
#' library(tidyverse)
#' library(lubridate)
#' library(timetk)
#' library(parsnip)
#' library(rsample)
#' library(modeltime)
#'
#' # Data
#' m750 <- m4_monthly %>% filter(id == "M750")
#'
#' # Split Data 90/10
#' splits <- initial_time_split(m750, prop = 0.9)
#'
#' # --- MODELS ---
#'
#' # Model 1: prophet ----
#' model_fit_prophet <- prophet_reg() %>%
#'     set_engine(engine = "prophet") %>%
#'     fit(value ~ date, data = training(splits))
#'
#'
#' # ---- MODELTIME TABLE ----
#'
#' models_tbl <- modeltime_table(
#'     model_fit_prophet
#' )
#'
#' # ---- CALIBRATE ----
#'
#' calibration_tbl <- models_tbl %>%
#'     modeltime_calibrate(new_data = testing(splits))
#'
#' # ---- ACCURACY ----
#'
#' calibration_tbl %>%
#'     modeltime_accuracy()
#'
#' # ---- FUTURE FORECAST ----
#'
#' calibration_tbl %>%
#'     modeltime_forecast(
#'         new_data    = testing(splits),
#'         actual_data = m750
#'     )
#'
#' # ---- ALTERNATIVE: FORECAST WITHOUT CONFIDENCE INTERVALS ----
#' # Skips Calibration Step, No Confidence Intervals
#'
#' models_tbl %>%
#'     modeltime_forecast(
#'         new_data    = testing(splits),
#'         actual_data = m750
#'     )
#'
#' # ---- KEEP NEW DATA WITH FORECAST ----
#' # Keeps the new data. Useful if new data has information
#' #  like ID features that should be kept with the forecast data
#'
#' calibration_tbl %>%
#'     modeltime_forecast(
#'         new_data      = testing(splits),
#'         keep_data     = TRUE
#'     )
#'
#' @name modeltime_forecast
NULL

#' @export
#' @rdname modeltime_forecast
modeltime_forecast <- function(object, new_data = NULL, h = NULL, actual_data = NULL,
                               conf_interval = 0.95, conf_by_id = FALSE, conf_method = "conformal_default",
                               keep_data = FALSE, arrange_index = FALSE, ...) {

    # Required arguments & messages
    if (is.null(new_data) && is.null(h)) {
        if (all(c(".type", ".calibration_data") %in% names(object))) {
            message("Using '.calibration_data' to forecast.")
        } else if (!is.null(actual_data)) {
            message("Using 'actual_data' to forecast. This may not be desirable for sequence models such as ARIMA.")
        } else {
            rlang::abort("Forecast requires either: \n - 'new_data' \n - 'h'")
        }
    }

    # Horizon, h: Checks
    if (!is.null(h)) {

        # Check for .calibration data or actual data if using `h`
        # - h needs this to extend from
        using_actual <- FALSE
        if (!all(c(".type", ".calibration_data") %in% names(object))) {
            if (is.null(actual_data)) {
               rlang::abort("Forecasting with 'h' requires one of: \n - '.calibration_data' (see '?modeltime_calibrate()') \n - 'actual_data'")
            }
            using_actual <- TRUE
        }

        # Ensure no overlapping timestamps if using `h`
        # - `h` doesn't know how to handle overlapping time series
        abort_msg <- "modeltime_forecast(): Overlapping dates detected indicating time series groups. 'h' cannot be used to forecast. Try using 'new_data' that has been extended using `timetk::future_frame()`."
        if (using_actual) {
            validate_no_overlapping_dates(
                data          = actual_data,
                abort_message = abort_msg
            )
        } else {
            calib_data <- object %>%
                purrr::pluck(".calibration_data", 1)

            validate_no_overlapping_dates(
                data          = calib_data,
                abort_message = abort_msg
            )
        }
    }

    if (!conf_method %in% c("conformal_default", "conformal_split")) {
        rlang::abort('conf_method must be one of "conformal_default", "conformal_split"')
    }

    UseMethod("modeltime_forecast")
}

#' @export
modeltime_forecast.default <- function(object, new_data = NULL, h = NULL, actual_data = NULL,
                                       conf_interval = 0.95, conf_by_id = FALSE,  conf_method = "conformal_default",
                                       keep_data = FALSE, arrange_index = FALSE, ...) {
    glubort("Received an object of class: {class(object)[1]}. Expected an object of class:\n 1. 'mdl_time_tbl' - A Model Time Table made with 'modeltime_table()' and calibrated with 'modeltime_calibrate()'.")
}

#' @export
modeltime_forecast.mdl_time_tbl <- function(object, new_data = NULL, h = NULL, actual_data = NULL,
                                            conf_interval = 0.95, conf_by_id = FALSE, conf_method = "conformal_default",
                                            keep_data = FALSE, arrange_index = FALSE, ...) {

    data <- object

    n_models <- data$.model_id %>% unique() %>% length()

    # HANDLE CALIBRATION DATA
    if (!all(c(".type", ".calibration_data") %in% names(data))) {
        # glubort("Expecting columns '.type' and '.calibration_data'. Try running 'modeltime_calibrate()' before using 'modeltime_forecast()'.")
        conf_interval <- NULL
        data <- data %>%
            dplyr::mutate(
                .type = NA,
                .calibration_data = NA
            )
    } else {
        data_calibration <- data %>%
            dplyr::select(.model_id, .calibration_data)
    }

    # CREATE FORECAST ----

    # Compute first model with actual data
    ret_1 <- data %>%
        dplyr::ungroup() %>%
        dplyr::slice(1) %>%
        safe_modeltime_forecast_map(
            new_data      = new_data,
            h             = h,
            actual_data   = actual_data,
            keep_data     = if (conf_by_id) TRUE else keep_data,
            arrange_index = arrange_index,
            bind_actual   = TRUE # Rowwise binds the actual data during the forecast
        )

    if ("actual" %in% unique(ret_1$.key)) {
        ret_1 <- ret_1 %>%
            dplyr::mutate(.model_desc = ifelse(.key == "actual", "ACTUAL", .model_desc)) %>%
            dplyr::mutate(.model_id   = ifelse(.key == "actual", NA_integer_, .model_id))
    }

    # Compute subsequent models without actual data
    ret_2 <- tibble::tibble()

    if (n_models > 1) {

        ret_2 <- data %>%
            dplyr::ungroup() %>%
            dplyr::slice(2:dplyr::n()) %>%
            safe_modeltime_forecast_map(
                new_data      = new_data,
                h             = h,
                actual_data   = actual_data,
                keep_data     = if (conf_by_id) TRUE else keep_data,
                arrange_index = arrange_index,
                bind_actual   = FALSE # Skips iterative rowwise binding of actual_data
            )

    }

    # If errors occur, .nested.col remains - it must be removed
    if (".nested.col" %in% names(ret_1)) {
        ret_1 <- ret_1 %>%
            dplyr::select(-.nested.col)
    }

    if (".nested.col" %in% names(ret_2)) {
        ret_2 <- ret_2 %>%
            dplyr::select(-.nested.col)
    }

    ret <- dplyr::bind_rows(ret_1, ret_2)

    # ADD CONF INTERVALS ----
    if (!is.null(conf_interval)) {

        if (conf_by_id) {

            # Check that a 5th column (the id feature) exists
            id_col_found <- data_calibration$.calibration_data[[1]] %>% names() %>% length() == 5

            # Apply CI by ID if possible
            if (id_col_found) {

                id_col <- data_calibration$.calibration_data[[1]] %>% names() %>% purrr::pluck(5)

                ret <- ret %>%
                    safe_conf_interval_map_by_id(
                        data_calibration,
                        conf_interval = conf_interval,
                        conf_method   = conf_method,
                        id            = !! id_col
                    ) %>%
                    dplyr::select(.model_id, .model_desc, .key, .index, .value, .conf_lo, .conf_hi, dplyr::all_of(names(new_data)))

                # Remove unnecessary columns if `keep_data = FALSE`. Required to keep the id column.
                if (!keep_data) {
                    ret <- ret %>%
                        dplyr::select(.model_id:.conf_hi, dplyr::all_of(id_col))
                }

            } else {

                rlang::warn("The 'id' column in calibration data was not detected. Global Confidence Interval is being returned.")

                ret <- ret %>%
                    safe_conf_interval_map(data_calibration, conf_interval = conf_interval, conf_method = conf_method) %>%
                    dplyr::relocate(dplyr::starts_with(".conf_"), .after = .value)

                # Remove unnecessary columns if `keep_data = FALSE`
                if (!keep_data) {
                    ret <- ret %>%
                        dplyr::select(.model_id:.conf_hi)
                }
            }



        } else {
            ret <- ret %>%
                safe_conf_interval_map(
                    data_calibration,
                    conf_interval = conf_interval,
                    conf_method   = conf_method
                ) %>%
                dplyr::relocate(dplyr::starts_with(".conf_"), .after = .value)

            # Remove unnecessary columns if `keep_data = FALSE`
            if (!keep_data) {
                ret <- ret %>%
                    dplyr::select(.model_id:.conf_hi)
            }
        }

    }

    # REMOVE ANY EXTRA-ACTUAL DATA ----
    # - Happens when ensembling

    ret <- ret %>%
        dplyr::filter(.model_desc == "ACTUAL" | .key == "prediction")

    return(ret)
}


# SAFE FORECAST MAPPERS ----

safe_modeltime_forecast_map <- function(data, new_data = NULL, h = NULL, actual_data = NULL, bind_actual = TRUE, keep_data = FALSE, arrange_index = FALSE, ...) {

    safe_modeltime_forecast <- purrr::safely(mdl_time_forecast, otherwise = NA, quiet = FALSE)

    data %>%
        dplyr::mutate(.nested.col = purrr::map2(
            .x         = .model,
            .y         = .calibration_data,
            .f         = function(obj, cal) {

                ret <- safe_modeltime_forecast(
                    obj, cal,
                    new_data      = new_data,
                    h             = h,
                    actual_data   = actual_data,
                    bind_actual   = bind_actual,
                    keep_data     = keep_data,
                    arrange_index = arrange_index
                )

                err <- ret$error

                ret <- ret$result

                # if (!is.null(err)) warning(err)

                return(ret)
            })
        ) %>%
        # Drop unnecessary columns
        dplyr::select(-.model, -.type, -.calibration_data) %>%
        tidyr::unnest(cols = .nested.col)
}


# SAFE CONF INTERVAL MAPPERS ----

safe_conf_interval_map <- function(data, data_calibration, conf_interval, conf_method) {

    empty_ci_tbl <- tibble::tibble(
        .conf_lo = NA,
        .conf_hi = NA
    )

    if (conf_method == "conformal_default") {
        safe_ci <- purrr::safely(
            conformal_default_func,
            otherwise = empty_ci_tbl,
            quiet = FALSE
        )
    }
    if (conf_method == "conformal_split") {
        safe_ci <- purrr::safely(
            conformal_split_func,
            otherwise = empty_ci_tbl,
            quiet = FALSE
        )
    }

    data %>%
        dplyr::group_by(.model_id) %>%
        tidyr::nest() %>%
        dplyr::left_join(data_calibration, by = ".model_id") %>%
        dplyr::mutate(.ci = purrr::map2(data, .calibration_data, .f = function(data_1, data_2) {
            res <- safe_ci(data_1, data_2, conf_interval = conf_interval)
            res %>% purrr::pluck("result")
        })
        ) %>%
        dplyr::select(-.calibration_data) %>%
        tidyr::unnest(cols = c(data, .ci)) %>%
        dplyr::ungroup()
}

safe_conf_interval_map_by_id <- function(data, data_calibration, conf_interval, id, conf_method) {

    empty_ci_tbl <- tibble::tibble(
        .conf_lo = NA,
        .conf_hi = NA
    )

    if (conf_method == "conformal_default") {
        safe_ci <- purrr::safely(
            conformal_default_func,
            otherwise = empty_ci_tbl,
            quiet = FALSE
        )
    }
    if (conf_method == "conformal_split") {
        safe_ci <- purrr::safely(
            conformal_split_func,
            otherwise = empty_ci_tbl,
            quiet = FALSE
        )
    }

    forecast_nested_tbl <- data %>%
        tibble::rowid_to_column(var = "..rowid") %>%
        dplyr::group_by(.model_id, !! rlang::ensym(id)) %>%
        tidyr::nest()

    calibration_nested_tbl <- data_calibration %>%
        tidyr::unnest(.calibration_data) %>%
        dplyr::group_by(.model_id, !! rlang::ensym(id)) %>%
        tidyr::nest() %>%
        dplyr::rename(.calibration_data = data)

    # print(forecast_nested_tbl)
    # print(calibration_nested_tbl)

    forecast_nested_tbl %>%
        dplyr::left_join(
            calibration_nested_tbl,
            by = names(forecast_nested_tbl)[1:2]
        ) %>%
        dplyr::mutate(.ci = purrr::map2(data, .calibration_data, .f = function(data_1, data_2) {
            res <- safe_ci(data_1, data_2, conf_interval = conf_interval)
            res %>% purrr::pluck("result")
        })
        ) %>%
        dplyr::select(-.calibration_data) %>%
        tidyr::unnest(cols = c(data, .ci)) %>%
        dplyr::ungroup() %>%
        dplyr::arrange(..rowid) %>%
        dplyr::select(
            -..rowid
            # , - !! rlang::ensym(id)
        )
}

# * DEFAULT CI METHOD: QUANTILE OF NORMAL DISTRIBUTION AROUND PREDICTIONS ----
conformal_default_func <- function(data_1, data_2, conf_interval) {

    # Symmetrize residuals and estimate a normal interval around predictions
    ret <- tryCatch({

        residuals <- c(data_2$.residuals, -data_2$.residuals)
        s <- stats::sd(residuals)

        ci_tbl <- data_1 %>%
            dplyr::mutate(
                .conf_lo = stats::qnorm((1-conf_interval)/2, mean = .value, sd = s),
                .conf_hi = stats::qnorm((1+conf_interval)/2, mean = .value, sd = s)
            ) %>%
            dplyr::select(.conf_lo, .conf_hi)

        ci_tbl

    }, error = function(e) {
        tibble::tibble(
            .conf_lo = NA,
            .conf_hi = NA
        )
    })

    return(ret)

}
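
# Illustrative sketch (comments only; toy numbers, not part of the package):
# symmetrizing the residuals forces a zero mean, so a non-zero residual mean
# inflates the standard deviation and widens the interval.
#
#   res <- c(2, -1, 3)                 # hypothetical out-of-sample residuals
#   s   <- stats::sd(c(res, -res))     # ~2.37 here; symmetric around 0
#   stats::qnorm(c(0.025, 0.975), mean = 100, sd = s)  # ~ 100 +/- 1.96 * s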

# * CONFORMAL PREDICTION VIA SPLIT METHOD ----
# https://github.com/tidymodels/probably/blob/c46326651109fb2ebd1b3762b3cb086cfb96ac88/R/conformal_infer_split.R#L99
conformal_split_func <- function(data_1, data_2, conf_interval) {

    # Collect absolute residuals
    ret <- tryCatch({

        residuals        <- data_2$.residuals
        residuals_sorted <- residuals %>%
            abs() %>%
            sort(decreasing = FALSE)

        n     <- nrow(data_2)
        q_ind <- ceiling(conf_interval * n)
        q_val <- residuals_sorted[q_ind]

        ci_tbl <- data_1 %>%
            dplyr::mutate(
                .conf_lo = .value - q_val,
                .conf_hi = .value + q_val
            ) %>%
            dplyr::select(.conf_lo, .conf_hi)

        ci_tbl

    }, error = function(e) {
        tibble::tibble(
            .conf_lo = NA,
            .conf_hi = NA
        )
    })

    return(ret)

}
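
# Illustrative sketch (comments only; toy numbers, not part of the package):
# with n = 100 calibration residuals and conf_interval = 0.95, the half-width
# is the 95th-smallest absolute residual.
#
#   abs_res <- sort(abs(stats::rnorm(100)))
#   q_ind   <- ceiling(0.95 * 100)     # 95
#   abs_res[q_ind]                     # half-width: .value +/- this quantity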


# FORECAST UTILITIES ----

#' Modeltime Forecast Helpers
#'
#' Used for low-level forecasting of modeltime, parsnip and workflow models.
#' These functions are not intended for user use.
#'
#' @inheritParams modeltime_forecast
#' @param calibration_data Data that has been calibrated from a testing set
#' @param bind_actual Logical. Whether or not to skip rowwise binding of `actual_data`
#'
#' @return A tibble with forecast features
#'
#' @keywords internal
#'
#' @export
mdl_time_forecast <- function(object, calibration_data, new_data = NULL, h = NULL, actual_data = NULL, bind_actual = TRUE, keep_data = FALSE, arrange_index = FALSE, ...) {
    UseMethod("mdl_time_forecast", object)
}

#' @export
mdl_time_forecast.model_fit <- function(object, calibration_data, new_data = NULL, h = NULL, actual_data = NULL, bind_actual = TRUE, keep_data = FALSE, arrange_index = FALSE, ...) {

    calib_provided <- FALSE
    h_provided     <- FALSE

    # MODEL OBJECT

    # If no 'new_data', forecast 'calibration_data'
    if (is.null(new_data) && is.null(h)) {
        if (is.data.frame(calibration_data)) {
            new_data <- calibration_data
            calib_provided <- TRUE
        } else if (is.data.frame(actual_data)) {
            new_data <- actual_data
        } else {
            rlang::abort("Forecast requires 'new_data', 'calibration_data', or 'actual_data'.")
        }
    }

    # Convert 'h' to 'new_data'
    if (!is.null(h)) {
        if (is.data.frame(calibration_data)) {
            tryCatch({
                # Suppress date selection
                suppressMessages(new_data <- timetk::future_frame(calibration_data, .length_out = h, ...))
            }, error = function(e) {
                rlang::abort("Attempt to extend '.calibration_data' into the future using 'h' has failed.")
            })
        } else if (is.data.frame(actual_data)) {
            tryCatch({
                # Suppress date selection
                suppressMessages(new_data <- timetk::future_frame(actual_data, .length_out = h, ...))
            }, error = function(e) {
                rlang::abort("Attempt to extend 'actual_data' into the future using 'h' has failed.")
            })
        } else {
            rlang::abort("Forecast requires 'new_data', '.calibration_data', or 'actual_data'.")
        }

        h_provided <- TRUE
    }

    # For combining new data
    actual_data_unprocessed <- actual_data
    new_data_unprocessed    <- new_data

    # Setup data for predictions
    nms_time_stamp_predictors <- timetk::tk_get_timeseries_variables(new_data)[1]
    time_stamp_predictors_tbl <- new_data %>%
        dplyr::select(!! rlang::sym(nms_time_stamp_predictors)) %>%
        dplyr::rename(.index = !! rlang::sym(nms_time_stamp_predictors))

    # PREDICTIONS ----

    # Comment this out ----
    # obj <<- object
    # print({
    #     list(
    #         object   = object,
    #         class    = class(object),
    #         new_data = new_data
    #     )
    # })

    modeltime_forecast <- tryCatch({

        if (detect_net(object) && inherits(object, "recursive")) {
            predictions_tbl <- object %>% predict.recursive(new_data = new_data)
        } else if (detect_net(object) && inherits(object, "recursive_panel")) {
            predictions_tbl <- object %>% predict.recursive_panel(new_data = new_data)
        } else {
            predictions_tbl <- object %>% stats::predict(new_data = new_data)
        }

        modeltime_forecast <- predictions_tbl %>%
            dplyr::bind_cols(time_stamp_predictors_tbl)

    }, error = function(e) {
        if (any(c(h_provided, calib_provided))) {
            # Most likely issue: need to provide external regressors
            glubort("Problem occurred during prediction. Most likely cause is missing external regressors. Try using 'new_data' and supply a dataset containing all required columns. {e}")
        } else {
            glubort("Problem occurred during prediction. {e}")
        }
    })

    # Format data
    data_formatted <- modeltime_forecast %>%
        dplyr::mutate(.key = "prediction") %>%
        dplyr::select(.key, dplyr::everything())

    # COMBINE ACTUAL DATA

    if (!is.null(actual_data) && bind_actual) {

        # setup
        nms_final     <- names(data_formatted)

        # print(nms_final)

        if (length(object$preproc$y_var) > 0) {
            fit_interface <- "formula"
        } else {
            fit_interface <- "xy"
        }

        if (fit_interface == "formula") {
            nm_target <- object$preproc$y_var

            # Get the index and target variable symbol
            idx        <- actual_data %>% timetk::tk_index()
            target_sym <- rlang::sym(object$preproc$y_var)

            # Set ID & Index
            actual_data <- actual_data %>%
                dplyr::mutate(.key = "actual") %>%
                dplyr::mutate(.index = idx)

            # Common for Data.Frame Style - Apply transformation to target variable if needed
            # If preproc, rename first variable the target name
            pp_names <- names(object$preproc)
            if (any(pp_names == "terms") | any(pp_names == "x_var")) {
                formula_lhs <- find_formula_lhs(object$preproc)
                if (!is.null(formula_lhs)) {
                    actual_data <- actual_data %>%
                        dplyr::mutate(!! target_sym := eval(formula_lhs))
                } else {
                    warning(call. = FALSE, paste0("Cannot determine if transformation is required on 'actual_data'"))
                }
            }

            # Problem with Formula Style (e.g. stats::lm() ):
            # - Need to search model for formula LHS to know transformation
            # - Use find_formula_lhs() to search first level for formula
            if (object$spec$method$fit$interface == "formula") {

                formula_lhs <- find_formula_lhs(object$fit)
                if (!is.null(formula_lhs)) {
                    actual_data <- actual_data %>%
                        dplyr::mutate(!! target_sym := eval(formula_lhs))
                } else {
                    warning(call. = FALSE, paste0("Cannot determine if transformation is required on 'actual_data'"))
                }
            }

            # Issue #228 - fix `.pred_res`
            data_formatted <- actual_data %>%
                dplyr::bind_rows(data_formatted)

            if (".pred_res" %in% colnames(data_formatted)) {
                data_formatted <- data_formatted %>%
                    dplyr::rename(.pred = .pred_res)
            }

            if (".pred_res" %in% nms_final) {
                nms_final <- stringr::str_replace(nms_final, ".pred_res", ".pred")
            }

            data_formatted <- data_formatted %>%
                dplyr::mutate(.pred = ifelse(is.na(.pred), !! target_sym, .pred)) %>%
                dplyr::select(!!! rlang::syms(nms_final))

        } else {
            # XY Interface

            rlang::abort("XY Interface not yet implemented for 'modeltime_forecast()'. Try using the Formula Interface with `fit()` or a `workflow()`.")

            # actual_data <- prepare_data(object, actual_data) %>%
            #     tibble::as_tibble() %>%
            #     dplyr::mutate(.key = "actual") %>%
            #     dplyr::rename(.index = !! rlang::sym(nms_time_stamp_predictors))

        }

        # Issue #228 - fix .pred_res
        if (".pred_res" %in% colnames(data_formatted)) {
            data_formatted <- data_formatted %>%
                dplyr::rename(.pred = .pred_res)
        }

        if (".pred_res" %in% nms_final) {
            nms_final <- stringr::str_replace(nms_final, ".pred_res", ".pred")
        }

        data_formatted <- data_formatted %>%
            dplyr::select(!!! rlang::syms(nms_final))

    }

    # FINALIZE
    ret <- data_formatted %>%
        dplyr::rename(.value = .pred) %>%
        dplyr::select(.key, .index, .value) %>%
        dplyr::mutate(.key = factor(.key, levels = c("actual", "prediction")))

    # Keep Data
    act_tbl  <- NULL
    pred_tbl <- NULL
    if (keep_data) {

        if (!is.null(actual_data) && bind_actual) {
            act_tbl <- ret %>%
                dplyr::filter(.key == "actual") %>%
                dplyr::bind_cols(actual_data_unprocessed)
        }

        pred_tbl <- ret %>%
            dplyr::filter(.key == "prediction") %>%
            dplyr::bind_cols(new_data_unprocessed)

        ret <- dplyr::bind_rows(act_tbl, pred_tbl)

    }

    if (arrange_index) {
        ret <- ret %>%
            dplyr::arrange(.key, .index)
    }

    return(ret)

}

#' @export
mdl_time_forecast.workflow <- function(object, calibration_data, new_data = NULL, h = NULL, actual_data = NULL, bind_actual = TRUE, keep_data = FALSE, arrange_index = FALSE, ...) {

    calib_provided <- FALSE
    h_provided     <- FALSE

    # Checks
    if (!object$trained) {
        rlang::abort("Workflow must be trained using the 'fit()' function.")
    }

    # WORKFLOW MOLD

    # Contains $predictors, $outcomes, $blueprint
    # mld <- object %>% workflows::extract_mold()

    # UPGRADE MOLD (TEMP FIX) ----
    # - models built with hardhat <1.0.0 have issue
    #   https://github.com/tidymodels/hardhat/issues/200
    preprocessor <- workflows::extract_preprocessor(object)
    mld          <- hardhat::mold(preprocessor, preprocessor$template)

    # NEW DATA

    # If no 'new_data', forecast 'calibration_data' or 'actual_data'
    if (is.null(new_data) && is.null(h)) {
        if (is.data.frame(calibration_data)) {
            new_data <- calibration_data
            calib_provided <- TRUE
        } else if (is.data.frame(actual_data)) {
            new_data <- actual_data
        } else {
            rlang::abort("Forecast requires 'new_data', 'calibration_data', or 'actual_data'.")
        }
    }

    # Convert 'h' to 'new_data'
    if (!is.null(h)) {
        if (is.data.frame(calibration_data)) {
            tryCatch({
                # Suppress date selection message
                new_data <- suppressMessages(timetk::future_frame(calibration_data, .length_out = h, ...))
            }, error = function(e) {
                rlang::abort("Attempt to extend 'calibration_data' into the future using 'h' has failed.")
            })
        } else if (is.data.frame(actual_data)) {
            tryCatch({
                # Suppress date selection message
                new_data <- suppressMessages(timetk::future_frame(actual_data, .length_out = h, ...))
            }, error = function(e) {
                rlang::abort("Attempt to extend 'actual_data' into the future using 'h' has failed.")
            })
        } else {
            rlang::abort("Forecast requires 'new_data', 'calibration_data', or 'actual_data'.")
        }

        h_provided <- TRUE
    }

    # For combining new data
    actual_data_unprocessed <- actual_data
    new_data_unprocessed    <- new_data

    # Issue - forge() processes all recipe steps at once, so outcomes must be present
    # Reference: https://tidymodels.github.io/hardhat/articles/forge.html

    # Prep for Forge (Assign dummy y_var if needed)
    y_var <- names(mld$outcomes)
    if (!any(y_var %in% names(new_data))) {
        new_data[,y_var] <- 1000
    }

    # Apply forge just to get predictors
    new_data_forged <- tryCatch({
        hardhat::forge(new_data, mld$blueprint, outcomes = TRUE)
    }, error = function(e) {
        if (h_provided || calib_provided) {
            # Most likely issue: need to provide external regressors
            glubort("Problem occurred in getting predictors from new data. Most likely cause is missing external regressors. Try using 'new_data' and supply a dataset containing all required columns. {e}")
        } else {
            glubort("Problem occurred getting predictors from new data. {e}")
        }
    })

    # TIMESTAMP + PREDICTORS

    nms_time_stamp_predictors <- timetk::tk_get_timeseries_variables(new_data_forged$predictors)[1]

    if (!is.na(nms_time_stamp_predictors)) {
        time_stamp_predictors_tbl <- new_data_forged$predictors %>%
            dplyr::select(!! rlang::sym(nms_time_stamp_predictors)) %>%
            dplyr::rename(.index = !! rlang::sym(nms_time_stamp_predictors))
    } else {
        # Most ML algorithms won't have time stamp
        idx <- new_data %>% timetk::tk_index()

        if (length(idx) == nrow(new_data_forged$predictors)) {
            # IF index and predictors are same length, all good
            time_stamp_predictors_tbl <- new_data_forged$predictors %>%
                dplyr::mutate(.index = idx)
        } else {
            # Test to see if missing data is being dropped (happens with step_naomit)
            new_data_missing_removed_tbl <- new_data %>%
                tibble::rowid_to_column(var = "..id") %>%
                tidyr::drop_na()

            if (nrow(new_data_forged$predictors) == nrow(new_data_missing_removed_tbl)) {
                # If row counts are identical, subset idx by retained rows
                time_stamp_predictors_tbl <- new_data_forged$predictors %>%
                    dplyr::mutate(.index = idx[new_data_missing_removed_tbl$..id])
            } else {
                glubort("Problem occurred combining processed data with timestamps. Most likely cause is rows being added or removed during preprocessing. Try imputing missing values to retain the full number of rows.")
            }
        }
    }

    # FORGE NEW DATA -----

    # Fix - When ID is dummied
    id <- object$fit$fit$spec$id
    df_id <- NULL
    if (!is.null(id)) {
        df_id <- new_data %>% dplyr::select(dplyr::all_of(id))
    }

    # Issue - hardhat::forge defaults to outcomes = FALSE, which creates an error at predict.workflow()
    forged    <- hardhat::forge(new_data, mld$blueprint, outcomes = TRUE)
    new_data  <- forged$predictors
    fit       <- object$fit$fit

    # PREDICTIONS ----


    # Fix - When ID is dummied
    df <- new_data
    if (!is.null(id)) {
        if (!id %in% names(new_data)) {
            df <- new_data %>%
                dplyr::bind_cols(df_id)
            fit$spec$remove_id <- TRUE
        }
    }

    # PREDICT
    if (detect_net(fit) && inherits(fit, "recursive")) {
        data_formatted <- fit %>% predict.recursive(new_data = df)
    } else if (detect_net(fit) && inherits(fit, "recursive_panel")) {
        data_formatted <- fit %>% predict.recursive_panel(new_data = df)
    } else {
        data_formatted <- fit %>% stats::predict(new_data = df)
    }

    data_formatted <- data_formatted %>%
        dplyr::bind_cols(time_stamp_predictors_tbl) %>%
        dplyr::mutate(.key = "prediction") %>%
        dplyr::select(.key, dplyr::everything())


    # COMBINE ACTUAL DATA

    if (!is.null(actual_data) && bind_actual) {

        nms_final <- names(data_formatted)

        # mld <- object %>% workflows::extract_mold()

        actual_data_forged <- hardhat::forge(new_data = actual_data, blueprint = mld$blueprint, outcomes = TRUE)

        actual_data_prepped <- actual_data_forged$outcomes %>%
            dplyr::bind_cols(actual_data_forged$predictors) %>%
            dplyr::mutate(.key = "actual")

        # ---- BUG: if NROWs before/after don't match ----
        # - This situation is likely due to missing values in predictors
        # - Predictors are irrelevant to Actual Data, so the solution implemented fills
        #   missing predictors with a downup strategy. Filling reduces the likelihood that
        #   rows will be dropped.

        nrow_before <- nrow(actual_data)
        nrow_after  <- nrow(actual_data_prepped)
        idx_actual  <- timetk::tk_index(actual_data)

        if (nrow_after == nrow_before) {

            actual_data_prepped <- actual_data_prepped %>%
                dplyr::mutate(.index = idx_actual)

        } else if (nrow_after < nrow_before) {
            # message(stringr::str_glue("Transformations are resulting in a reduced number of rows in actual data.
            #                           Attempting to reconcile:
            #                           - Filling missing predictors in actual data to prevent NA values from causing rows to be dropped."))

            # Try to reconcile by filling in missing data
            # - Most likely cause is NA's being dropped by step_naomit()

            actual_data_fill_missing <- actual_data %>%
                tidyr::fill(-dplyr::all_of(names(actual_data_forged$outcomes)), .direction = "downup")

            actual_data_reconcile_forged <- hardhat::forge(
                new_data = actual_data_fill_missing,
                blueprint = mld$blueprint,
                outcomes = TRUE
            )

            actual_data_reconcile_prepped <- actual_data_reconcile_forged$outcomes %>%
                dplyr::bind_cols(actual_data_reconcile_forged$predictors) %>%
                dplyr::mutate(.key = "actual")

            if (nrow(actual_data_reconcile_prepped) == nrow(actual_data) ) {
                # message("Reconciliation successful.")

                actual_data_prepped <- actual_data_reconcile_prepped %>%
                    dplyr::mutate(.index = idx_actual)

            } else {
                rlang::warn("Could not reconcile actual data. To reconcile, please remove actual data from modeltime_forecast() and add manually using bind_rows().")
                actual_data_prepped <- NULL
            }

        } else {
            rlang::warn("Transformations are causing your actual data to increase in size. To reconcile, please remove actual data from modeltime_forecast() and add manually using bind_rows().")
            actual_data_prepped <- NULL
        }

        # Combine Actual Data with Data Formatted
        if (!is.null(actual_data_prepped)) {
            target_sym <- rlang::sym(names(actual_data_prepped)[1])

            # Issue #228 - fix .pred_res
            data_formatted <- actual_data_prepped %>%
                dplyr::bind_rows(data_formatted)


            if (".pred_res" %in% colnames(data_formatted)) {
                data_formatted <- data_formatted %>%
                    dplyr::rename(.pred = .pred_res)
            }

            if (".pred_res" %in% nms_final) {
                nms_final <- stringr::str_replace(nms_final, ".pred_res", ".pred")
            }

            data_formatted <- data_formatted %>%
                dplyr::mutate(.pred = ifelse(is.na(.pred), !! target_sym, .pred)) %>%
                dplyr::select(!!! rlang::syms(nms_final))
        }

    }

    # FINALIZE

    # Issue #228 - fix .pred_res
    if (".pred_res" %in% colnames(data_formatted)) {
        data_formatted <- data_formatted %>%
            dplyr::rename(.pred = .pred_res)
    }

    ret <- data_formatted %>%
        dplyr::rename(.value = .pred) %>%
        dplyr::select(.key, .index, .value) %>%
        dplyr::mutate(.key = factor(.key, levels = c("actual", "prediction")))

    # Keep Data
    act_tbl  <- NULL
    pred_tbl <- NULL
    if (keep_data) {

        if (!is.null(actual_data) && bind_actual) {
            act_tbl <- ret %>%
                dplyr::filter(.key == "actual") %>%
                dplyr::bind_cols(actual_data_unprocessed)
        }

        pred_tbl <- ret %>%
            dplyr::filter(.key == "prediction") %>%
            dplyr::bind_cols(new_data_unprocessed)

        ret <- dplyr::bind_rows(act_tbl, pred_tbl)

    }

    if (arrange_index) {
        ret <- ret %>%
            dplyr::arrange(.key, .index)
    }

    return(ret)

}
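
# EXAMPLE: timestamp reconciliation ----
# Illustrative sketch only (commented out; base R). When preprocessing drops
# NA rows (e.g. step_naomit()), the methods above recover the matching
# timestamps by tagging rows with an id, dropping incomplete rows, and then
# subsetting the original index by the retained ids:
#
#   df   <- data.frame(x = c(1, NA, 3, NA, 5))
#   idx  <- as.Date("2020-01-01") + 0:4
#   keep <- which(stats::complete.cases(df))
#   idx[keep]  # timestamps for the rows that survive the NA-drop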



detect_net <- function(object) {
    # TRUE when the underlying fit is a glmnet model: glmnet attaches one of
    # these internal classes depending on the model family
    inherits(object, c("_fishnet", "_elnet", "_multnet", "_lognet"))
}
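
# EXAMPLE: detect_net() ----
# Illustrative sketch only (commented out). glmnet fits carry one of the
# internal classes checked above, so a mock object with a matching class
# attribute is enough to exercise the helper:
#
#   mock_fit <- structure(list(), class = c("_elnet", "model_fit"))
#   detect_net(mock_fit)                                # TRUE
#   detect_net(structure(list(), class = "model_fit"))  # FALSE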
