#' Train a model across horizons and validation datasets
#'
#' Train a user-defined forecast model for each horizon, 'h', and across the validation
#' datasets, 'd'. If \code{method = "direct"}, a total of 'h' * 'd' models are trained.
#' If \code{method = "multi_output"}, a total of 1 * 'd' models are trained.
#' These models can be trained in parallel with the \code{future} package.
#'
#' @param lagged_df An object of class 'lagged_df' from \code{\link{create_lagged_df}}.
#' @param windows An object of class 'windows' from \code{\link{create_windows}}.
#' @param model_name A name for the model.
#' @param model_function A user-defined wrapper function for model training that takes the following
#' arguments: (1) a horizon-specific data.frame made with \code{create_lagged_df(..., type = "train")}
#' (i.e., the dataset(s) stored in \code{lagged_df}) and, optionally, (2) any number of additional named arguments
#' which can be passed in \code{...} in this function.
#' @param ... Optional. Named arguments passed into the user-defined \code{model_function}. The arguments are
#' evaluated once, in the environment that \code{train_model()} is called from, before any model is trained.
#' @param use_future Boolean. If \code{TRUE}, the \code{future} package is used for training models in parallel.
#' The models will train in parallel across either (a) model forecast horizons or (b) validation windows,
#' whichever is longer (i.e., \code{length(create_lagged_df())} or \code{nrow(create_windows())}). The user
#' should run \code{future::plan(future::multisession)} or similar prior to this function to train these models
#' in parallel.
#' @param python Boolean. If \code{TRUE}, the \code{reticulate} package is used for model training.
#' @return An S3 object of class 'forecast_model': A nested list of trained models. Models can be accessed with
#' \code{my_trained_model$horizon_h$window_w$model} where 'h' gives the forecast horizon and 'w' gives
#' the validation dataset window number from \code{create_windows()}.
#'
#' @section Methods and related functions:
#'
#' The output of \code{train_model} can be passed into
#'
#' \itemize{
#'   \item \code{\link{return_error}}
#'   \item \code{\link{return_hyper}}
#' }
#'
#' and has the following generic S3 methods
#'
#' \itemize{
#'   \item \code{\link[=predict.forecast_model]{predict}}
#'   \item \code{\link[=plot.training_results]{plot}} (from \code{predict.forecast_model(data = create_lagged_df(..., type = "train"))})
#'   \item \code{\link[=plot.forecast_results]{plot}} (from \code{predict.forecast_model(data = create_lagged_df(..., type = "forecast"))})
#' }
#' @example /R/examples/example_train_model.R
#' @export
train_model <- function(lagged_df, windows, model_name, model_function, ..., use_future = FALSE, python = FALSE) {

  if (missing(lagged_df) || !methods::is(lagged_df, "lagged_df")) {
    stop("The 'lagged_df' argument takes an object of class 'lagged_df' as input. Run create_lagged_df() first.")
  }

  if (missing(windows) || !methods::is(windows, "windows")) {
    stop("The 'windows' argument takes an object of class 'windows' as input. Run create_windows() first.")
  }

  if (missing(model_name)) {
    stop("Enter a model name for the 'model_name' argument.")
  }

  if (missing(model_function) || !is.function(model_function)) {
    stop("The 'model_function' argument takes a user-defined model training function as input.")
  }

  row_indices <- attributes(lagged_df)$row_indices
  date_indices <- attributes(lagged_df)$date_indices
  horizons <- attributes(lagged_df)$horizons
  data_start <- attributes(lagged_df)$data_start
  data_stop <- attributes(lagged_df)$data_stop
  skeleton <- attributes(lagged_df)$skeleton

  # These are the arguments for the user-defined modeling function passed in train_model() with ...
  # which is optional but potentially convenient for the user who can avoid re-defining the modeling function
  # in repeated calls to train_model(). They are forced here, once, with list(...) so that (a) each argument
  # is evaluated in the caller's environment (locals of a calling function are found and nothing is shadowed
  # by lapply() internals) and (b) the values, not unevaluated expressions, are what the workers receive when
  # use_future = TRUE. The named list is passed to model_function() with do.call().
  n_args <- ...length()
  model_function_args <- list(...)
  #----------------------------------------------------------------------------
  # The default future behavior is to parallelize the model training over the longer dimension: (a) number of
  # forecast horizons or (b) number of validation windows. This is due to a current limitation
  # in the future package on changing object size limitations for nested futures where
  # "options(globals.maxSize.default = Inf)" isn't recognized.
  if (isTRUE(use_future)) {

    if (length(horizons) >= nrow(windows)) {

      lapply_across_horizons <- future.apply::future_lapply
      lapply_across_val_windows <- base::lapply

    } else {

      lapply_across_horizons <- base::lapply
      lapply_across_val_windows <- future.apply::future_lapply
    }

  } else {

    lapply_across_horizons <- base::lapply
    lapply_across_val_windows <- base::lapply
  }
  #----------------------------------------------------------------------------
  # Seq along model forecast horizon > cross-validation windows.
  # The 'future.seed' and '...' parameters of the anonymous functions absorb the future_lapply()
  # arguments when the function is instead run through base::lapply(), which forwards them to FUN.
  data_out <- lapply_across_horizons(lagged_df, function(data, future.seed, ...) {  # model forecast horizon.

    horizon <- attributes(data)$horizons

    model_plus_valid_data <- lapply_across_val_windows(seq_len(nrow(windows)), function(i, future.seed, ...) {  # validation windows within model forecast horizon.

      window_length <- windows[i, "window_length"]

      if (is.null(date_indices)) {

        valid_indices <- windows[i, "start"]:windows[i, "stop"]
        valid_indices_date <- NULL

      } else {

        # When create_lagged_df(..., groups = NULL, keep_rows = FALSE), the validation indices need an offset to account for the fact that
        # validation windows are selected where row_indices %in% valid_indices which maps back to the input dataset to create_lagged_df()
        # which will have 1:max(lookback) more rows than the dataset that comes out of create_lagged_df(..., groups = NULL, keep_rows = FALSE).
        valid_indices <- min(row_indices) - 1 + which(date_indices >= windows[i, "start"] & date_indices <= windows[i, "stop"])
        valid_indices_date <- date_indices[date_indices >= windows[i, "start"] & date_indices <= windows[i, "stop"]]
      }

      # A window length of 0 that spans the entire dataset removes the nested cross-validation and
      # trains on all input data in lagged_df.
      if (window_length == 0 && (windows[i, "start"] == data_start) && (windows[i, "stop"] == data_stop)) {

        # Python data.frame preparation.
        if (python) {
          horizon <- as.integer(attributes(data)$horizons)
          data <- reticulate::r_to_py(data)
          reticulate::py_set_attr(data, "horizons", horizon)
        }

        # Model training over all data.
        if (n_args == 0) {  # No user-defined model args passed in ...
          model <- try(model_function(data))
        } else {
          model <- try(do.call(model_function, append(list(data), model_function_args)))
        }

      } else {  # Model training with external block-contiguous cv.

        # These validation indices will always start at 1. They index with respect to the
        # data passed into model_function() as opposed to the dataset passed in create_lagged_df().
        # These indices, stored as an attribute, are for manual filtering any skeleton lagged_dfs in model_function().
        validation_indices <- which(row_indices %in% valid_indices)
        attr(data, "validation_indices") <- validation_indices

        # Negative indexing with a zero-length vector selects zero rows rather than all rows, so
        # rows are only removed when the window actually overlaps this dataset.
        if (length(validation_indices) > 0) {
          data <- data[-(validation_indices), , drop = FALSE]
        }

        # Python data.frame preparation.
        if (python) {
          horizon <- as.integer(attributes(data)$horizons)
          validation_indices <- as.integer(validation_indices - 1)  # Match Python's 0-based indexing.
          data <- reticulate::r_to_py(data)
          reticulate::py_set_attr(data, "horizons", horizon)
          # NOTE(review): Python's range() excludes its stop value, so the string below does not cover
          # the last validation index when evaluated; confirm whether 'max + 1' is intended before changing
          # it because user-defined Python model functions may already compensate.
          if (length(validation_indices) > 0) {
            validation_range <- paste0("range(", min(validation_indices), ", ", max(validation_indices), ", 1)")
          } else {
            validation_range <- "range(0, 0, 1)"  # No validation rows: an empty Python range.
          }
          reticulate::py_set_attr(data, "validation_indices", validation_range)
        }

        if (n_args == 0) {  # No user-defined model args passed in ...
          model <- try(model_function(data))
        } else {
          model <- try({
            do.call(model_function, append(list(data), model_function_args))
          })
        }
      }

      # A failed fit is kept in the output (as the 'try-error' object) so the remaining models still train.
      if (methods::is(model, "try-error")) {
        warning(paste0("A model returned class 'try-error' for validation window ", i))
      }

      list("model" = model, "window" = i, "window_length" = window_length, "valid_indices" = valid_indices,
           "date_indices" = valid_indices_date)
    }, future.seed = 1)  # End model training across nested cross-validation windows for the horizon in "data".

    names(model_plus_valid_data) <- paste0("window_", seq_len(nrow(windows)))
    attr(model_plus_valid_data, "horizon") <- horizon
    model_plus_valid_data
  }, future.seed = 1, future.packages = "reticulate")  # End training across horizons.

  # Carry the dataset metadata forward; predict.forecast_model() reads these attributes.
  attr(data_out, "model_name") <- model_name
  attr(data_out, "horizons") <- horizons
  attr(data_out, "outcome_col") <- attributes(lagged_df)$outcome_col
  attr(data_out, "outcome_cols") <- attributes(lagged_df)$outcome_cols
  attr(data_out, "outcome_name") <- attributes(lagged_df)$outcome_name
  attr(data_out, "outcome_names") <- attributes(lagged_df)$outcome_names
  attr(data_out, "outcome_levels") <- attributes(lagged_df)$outcome_levels
  attr(data_out, "row_indices") <- row_indices
  attr(data_out, "date_indices") <- date_indices
  attr(data_out, "frequency") <- attributes(lagged_df)$frequency
  attr(data_out, "data_stop") <- data_stop
  attr(data_out, "groups") <- attributes(lagged_df)$groups
  attr(data_out, "method") <- attributes(lagged_df)$method
  attr(data_out, "skeleton") <- skeleton
  attr(data_out, "python") <- python

  class(data_out) <- c("forecast_model", class(data_out))

  return(data_out)
}
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
#' Predict on validation datasets or forecast
#'
#' Predict with a 'forecast_model' object from \code{train_model()}. If \code{data = create_lagged_df(..., type = "train")},
#' predictions are returned for the outer-loop nested cross-validation datasets.
#' If \code{data} is an object of class 'lagged_df' from \code{create_lagged_df(..., type = "forecast")},
#' predictions are returned for the horizons specified in \code{create_lagged_df(horizons = ...)}.
#'
#' @param ... One or more trained models from \code{train_model()}.
#' @param prediction_function A list of user-defined prediction functions with length equal to
#' the number of models supplied in \code{...}. The prediction functions
#' take 2 required positional arguments--(1) a 'forecast_model' object from \code{train_model()} and (2) a
#' data.frame of model features from \code{create_lagged_df()}. For numeric outcomes and \code{method = "direct"}, the function should \code{return()}
#' a 1- or 3-column data.frame of model predictions. If the prediction function returns a 1-column data.frame, point forecasts are assumed.
#' If the prediction function returns a 3-column data.frame, lower and upper forecast bounds are assumed (the
#' order and names of the 3 columns does not matter). For factor outcomes and \code{method = "direct"}, the function should \code{return()}
#' (1) a 1-column data.frame of the model-predicted factor level or (2) an L-column data.frame of class probabilities where
#' 'L' equals the number of levels in the outcome; columns should be ordered, from left to right, the same as
#' \code{levels(data$outcome)} which is the default behavior for most \code{predict(..., type = "prob")} functions.
#' Column names do not matter. For numeric outcomes and \code{method = "multi_output"}, the function should \code{return()} an
#' h-column data.frame of model predictions--1 column for each horizon. Forecast intervals and factor outcomes are not currently
#' supported with \code{method = "multi_output"}. An error raised by a prediction function stops \code{predict()} with
#' a message identifying the model, horizon, and validation window.
#' @param data If \code{data} is a training dataset from \code{create_lagged_df(..., type = "train")}, validation dataset
#' predictions are returned; else, if \code{data} is a forecasting dataset from \code{create_lagged_df(..., type = "forecast")},
#' forecasts from horizons 1:h are returned.
#' @return If \code{data = create_lagged_df(..., type = "train")}, an S3 object of class 'training_results'. If
#' \code{data = create_lagged_df(..., type = "forecast")}, an S3 object of class 'forecast_results'.
#'
#'   \strong{Columns in returned 'training_results' data.frame:}
#'     \itemize{
#'       \item \code{model}: User-supplied model name in \code{train_model()}.
#'       \item \code{model_forecast_horizon}: The direct-forecasting time horizon that the model was trained on.
#'       \item \code{window_length}: Validation window length measured in dataset rows.
#'       \item \code{window_number}: Validation dataset number.
#'       \item \code{valid_indices}: Validation dataset row names from \code{attributes(create_lagged_df())$row_indices}.
#'       \item \code{date_indices}: If given and \code{method = "direct"}, validation dataset date indices from \code{attributes(create_lagged_df())$date_indices}.
#'       If given and \code{method = "multi_output"}, date_indices represents the date of the forecast.
#'       \item \code{"groups"}: If given, the user-supplied groups in \code{create_lagged_df()}.
#'       \item \code{"outcome_name"}: The target being forecasted.
#'       \item \code{"outcome_name"_pred}: The model predictions.
#'       \item \code{"outcome_name"_pred_lower}: If given, the lower prediction bounds returned by the user-supplied prediction function.
#'       \item \code{"outcome_name"_pred_upper}: If given, the upper prediction bounds returned by the user-supplied prediction function.
#'       \item \code{forecast_indices}: If \code{method = "multi_output"}, the validation index of the h-step-ahead forecast.
#'       \item \code{forecast_date_indices}: If \code{method = "multi_output"}, the validation date index of the h-step-ahead forecast.
#'    }
#'
#'    \strong{Columns in returned 'forecast_results' data.frame:}
#'     \itemize{
#'       \item \code{model}: User-supplied model name in \code{train_model()}.
#'       \item \code{model_forecast_horizon}: If \code{method = "direct"}, the direct-forecasting time horizon that the model was trained on.
#'       \item \code{horizon}: Forecast horizons, 1:h, measured in dataset rows.
#'       \item \code{window_length}: Validation window length measured in dataset rows.
#'       \item \code{forecast_period}: The forecast period in row indices or dates. The forecast period starts at either \code{attributes(create_lagged_df())$data_stop + 1} for row indices or \code{attributes(create_lagged_df())$data_stop + 1 * frequency} for date indices.
#'       \item \code{"groups"}: If given, the user-supplied groups in \code{create_lagged_df()}.
#'       \item \code{"outcome_name"}: The target being forecasted.
#'       \item \code{"outcome_name"_pred}: The model forecasts.
#'       \item \code{"outcome_name"_pred_lower}: If given, the lower forecast bounds returned by the user-supplied prediction function.
#'       \item \code{"outcome_name"_pred_upper}: If given, the upper forecast bounds returned by the user-supplied prediction function.
#'    }
#'
#' @example /R/examples/example_predict_train_model.R
#' @export
predict.forecast_model <- function(..., prediction_function = list(NULL), data) {
  #----------------------------------------------------------------------------
  model_list <- list(...)

  # An empty '...' is rejected here; otherwise all(logical(0)) is TRUE and the call fails later
  # with an unrelated message.
  if (length(model_list) == 0 || !all(unlist(lapply(model_list, function(x) {class(x)[1]})) %in% "forecast_model")) {
    stop("The '...' argument takes 1 or more objects of class 'forecast_model' as input. Run train_model() first.
Also, the arguments 'prediction_function' and 'data' need to be named and not positional because they
follow '...'.")
  }

  if (length(model_list) != length(prediction_function)) {
    stop("The number of prediction functions does not equal the number of forecast models.")
  }

  if (!methods::is(data, "lagged_df")) {
    stop("The 'data' argument takes a training or forecasting dataset of class 'lagged_df' from create_lagged_df().")
  }
  #----------------------------------------------------------------------------
  # Dataset metadata is read from the first model only.
  type <- attributes(data)$type  # train or forecast.
  method <- attributes(model_list[[1]])$method
  horizons <- attributes(model_list[[1]])$horizons
  outcome_col <- attributes(model_list[[1]])$outcome_col
  outcome_cols <- attributes(model_list[[1]])$outcome_cols
  outcome_name <- attributes(model_list[[1]])$outcome_name
  outcome_names <- attributes(model_list[[1]])$outcome_names
  outcome_levels <- attributes(model_list[[1]])$outcome_levels
  row_indices <- attributes(model_list[[1]])$row_indices
  date_indices <- attributes(model_list[[1]])$date_indices
  frequency <- attributes(model_list[[1]])$frequency
  groups <- attributes(model_list[[1]])$groups
  skeleton <- attributes(model_list[[1]])$skeleton
  #----------------------------------------------------------------------------
  if (type == "train") {
    data_stop <- attributes(model_list[[1]])$data_stop
  } else {
    data_stop <- attributes(data)$data_stop
  }
  #----------------------------------------------------------------------------
  # Seq along forecast model[i] > model forecast horizon[j] > validation window number[k].
  data_model <- lapply(seq_along(model_list), function(i) {

    prediction_fun <- prediction_function[[i]]
    python <- attributes(model_list[[i]])$python

    data_horizon <- lapply(seq_along(model_list[[i]]), function(j) {

      data_win_num <- lapply(seq_along(model_list[[i]][[j]]), function(k) {

        data_results <- model_list[[i]][[j]][[k]]

        # Predict on training data or the forecast dataset?
        if (type == "train") {  # Nested cross-validation.

          # These validation indices will always start at 1. They index with respect to the
          # data passed into predict_function() as opposed to the dataset passed in create_lagged_df().
          # These indices, stored as an attribute, are for manual filtering any skeleton lagged_dfs in predict_function().
          validation_indices <- which(row_indices %in% data_results$valid_indices)

          x_valid <- data[[j]][validation_indices, -(outcome_cols), drop = FALSE]
          y_valid <- data[[j]][validation_indices, outcome_cols, drop = FALSE]  # Actuals in function return.

          attributes(x_valid)$horizons <- attributes(data[[j]])$horizons
          attributes(x_valid)$validation_indices <- validation_indices

          # Python data.frame preparation.
          if (python) {
            horizon <- as.integer(attributes(data[[j]])$horizons)
            validation_indices <- as.integer(validation_indices - 1)  # Match Python's 0-based indexing.
            py_x_valid <- reticulate::r_to_py(x_valid)
            reticulate::py_set_attr(py_x_valid, "horizons", horizon)
            reticulate::py_set_attr(py_x_valid, "validation_indices", paste0("range(", min(validation_indices), ", ", max(validation_indices), ", 1)"))
            data_pred <- try(prediction_fun(data_results$model, py_x_valid), silent = TRUE)  # Nested cross-validation.
          } else {
            data_pred <- try(prediction_fun(data_results$model, x_valid), silent = TRUE)  # Nested cross-validation.
          }

          if (!is.null(groups)) {
            data_groups <- x_valid[, groups, drop = FALSE]  # Save out group identifiers.
          }

        } else {  # Forecast.

          forecast_period <- data[[j]][, "index", drop = FALSE]
          names(forecast_period) <- "forecast_period"
          forecast_horizons <- data[[j]][, "horizon", drop = FALSE]

          data_for_forecast <- data[[j]][, !names(data[[j]]) %in% c("index", "horizon"), drop = FALSE]  # Remove ID columns for predict().

          # Python data.frame preparation.
          if (python) {
            py_data_for_forecast <- reticulate::r_to_py(data_for_forecast)
            data_pred <- try(prediction_fun(data_results$model, py_data_for_forecast), silent = TRUE)  # User-defined prediction function.
          } else {
            data_pred <- try(prediction_fun(data_results$model, data_for_forecast), silent = TRUE)  # User-defined prediction function.
          }

          if (!is.null(groups)) {
            data_groups <- data_for_forecast[, groups, drop = FALSE]
          }
        }  # End forecast.
        #----------------------------------------------------------------------
        # A failed prediction cannot be formatted or row-bound with the other results: ncol() of a
        # 'try-error' is NULL, which made the checks below fail with an unrelated message. Stop here
        # with the location of the failure and the underlying error instead.
        if (methods::is(data_pred, "try-error")) {
          stop(paste0("Model '", attributes(model_list[[i]])$model_name, "' returned class 'try-error' for model ", j,
                      " in validation window ", k, ": ", conditionMessage(attr(data_pred, "condition"))), call. = FALSE)
        }
        #----------------------------------------------------------------------
        # Check the return of the user-defined predict() function.
        if (is.null(outcome_levels) && method == "direct" && !ncol(data_pred) %in% c(1, 3)) {  # Numeric outcome.
          stop("For numeric outcomes, the user-defined prediction function needs to return 1- or 3-column data.frame of model predictions.")
        }

        if (is.null(outcome_levels) && method == "multi_output" && ncol(data_pred) != length(horizons)) {  # Numeric outcome.
          stop("For numeric outcomes, the user-defined prediction function needs to return a length(horizons)-column data.frame of model predictions.")
        }

        if (!is.null(outcome_levels)) {  # Factor outcome.

          if (ncol(data_pred) == 1 && method == "direct" && !methods::is(data_pred[, 1], "factor")) {
            stop("For factor outcomes where the predicted factor level is returned (e.g., predict(..., type = 'response')), the user-defined
prediction function needs to return 1-column data.frame of factor predictions. If returning class
probabilities, the number of data.frame columns should equal the number of factor levels and be in the same order as levels(data$outcome).")
          }

          if (method == "direct" && ncol(data_pred) > 1 && ncol(data_pred) != length(outcome_levels)) {
            stop("For factor outcomes where class probabilities are returned (e.g., predict(..., type = 'prob')),
the number of data.frame columns returned should equal the number of factor levels.")
          }
        }
        #----------------------------------------------------------------------
        # Format the returned data.frames from predict().
        if (method == "direct") {

          if (is.null(outcome_levels)) {  # Numeric outcome.

            if (ncol(data_pred) == 1) {

              names(data_pred) <- paste0(outcome_names, "_pred")

            } else {  # Confidence/credible forecast intervals.

              # Find the lower, point, and upper forecasts and order the columns accordingly.
              data_pred <- data_pred[order(unlist(lapply(data_pred, mean, na.rm = TRUE)))]
              names(data_pred) <- c(paste0(outcome_names, "_pred_lower"), paste0(outcome_names, "_pred"), paste0(outcome_names, "_pred_upper"))
              # Re-order so that the point forecast is first.
              data_pred <- data_pred[, c(2, 1, 3)]
            }

          } else {  # Factor outcome.

            if (ncol(data_pred) == 1) {  # A predicted factor level.
              names(data_pred) <- paste0(outcome_names, "_pred")
            } else {  # Predicted class probabilities.
              names(data_pred) <- outcome_levels
            }
          }  # End reformatting of the returned data.frame of predictions.

        } else if (method == "multi_output") {

          if (is.null(outcome_levels)) {  # Numeric outcome.
            names(data_pred) <- paste0(outcome_names, "_pred")
          } else {
            stop("Multi-output models do not currently support factor outcomes.")
          }
        }
        #----------------------------------------------------------------------
        model_name <- attributes(model_list[[i]])$model_name

        if (type == "train") {  # Nested cross-validation.

          data_temp <- data.frame("model" = model_name,
                                  "model_forecast_horizon" = horizons[j],
                                  "window_length" = data_results$window_length,
                                  "window_number" = data_results$window,
                                  "valid_indices" = data_results$valid_indices)

          data_temp$date_indices <- data_results$date_indices

          if (is.null(groups)) {
            data_temp <- cbind(data_temp, y_valid, data_pred)
          } else {
            data_temp <- cbind(data_temp, data_groups, y_valid, data_pred)
          }

        } else {  # Forecast.

          data_temp <- data.frame("model" = model_name,
                                  "model_forecast_horizon" = horizons[j],
                                  "horizon" = forecast_horizons,
                                  "window_length" = data_results$window_length,
                                  "window_number" = data_results$window,
                                  "forecast_period" = forecast_period)

          if (is.null(groups)) {
            data_temp <- cbind(data_temp, data_pred)
          } else {
            data_temp <- cbind(data_temp, data_groups, data_pred)
          }
        }  # End forecast results.

        data_temp$model <- as.character(data_temp$model)  # Coerce to remove any factor levels.

        data_temp
      })  # End cross-validation window predictions.
      data_win_num <- suppressMessages(dplyr::bind_rows(data_win_num))
    })  # End horizon-level predictions.
    data_horizon <- suppressMessages(dplyr::bind_rows(data_horizon))
  })  # End model-level predictions.

  data_out <- suppressMessages(dplyr::bind_rows(data_model))

  # For multi-output models, the data (a) need to be reshaped from a wide to long format
  # and (b) the validation indices that represent the forecast origin need to be changed
  # to the index for the predicted time period.
  if (method == "multi_output") {

    if (type == "train") {

      data_actual_temp <- dplyr::select(data_out, -dplyr::ends_with("_pred"))
      data_pred_temp <- dplyr::select(data_out, -!!outcome_names)

      data_actual_temp <- tidyr::pivot_longer(data_actual_temp, cols = !!outcome_names, names_to = "remove", values_to = outcome_name)
      data_actual_temp$remove <- NULL

      data_pred_temp <- tidyr::pivot_longer(data_pred_temp, cols = dplyr::ends_with("_pred"), values_to = paste0(outcome_name, "_pred"))

      data_actual_temp[, paste0(outcome_name, "_pred")] <- data_pred_temp[, paste0(outcome_name, "_pred")]

      data_out <- data_actual_temp

      data_out$model_forecast_horizon <- rep(horizons, nrow(data_out) / length(horizons))

      data_out$forecast_indices <- data_out$valid_indices
      data_out$forecast_indices <- data_out$model_forecast_horizon + data_out$forecast_indices
      data_out$horizons <- NULL

      names(data_out)[names(data_out) == "model_forecast_horizon"] <- "horizon"

      if (!is.null(date_indices)) {

        # Step each forecast-origin date forward by 'horizon' periods of the dataset frequency.
        data_out$forecast_date_indices <- as.Date(unlist(purrr::map(seq_len(nrow(data_out)), function(i) {
          base::seq(data_out$date_indices[i], by = frequency, length.out = data_out$horizon[i] + 1)[data_out$horizon[i] + 1]})),
          origin = "1970-01-01")
      }

      data_out <- dplyr::arrange(data_out, .data$model, .data$valid_indices, .data$forecast_indices, .data$horizon)

    } else if (type == "forecast") {

      forecast_period <- data.frame(data_out[, "forecast_period", drop = FALSE], stringsAsFactors = FALSE)
      forecast_period <- data.frame(strsplit(forecast_period$forecast_period, ", "), stringsAsFactors = FALSE)
      names(forecast_period) <- "forecast_period"
      forecast_period$forecast_period <- if (is.null(date_indices)){as.integer(forecast_period$forecast_period)} else {as.Date(forecast_period$forecast_period)}

      data_out <- tidyr::pivot_longer(data_out, cols = dplyr::ends_with("_pred"),
                                      names_to = "remove", values_to = paste0(outcome_name, "_pred"))
      data_out$remove <- NULL

      data_out$model_forecast_horizon <- NULL
      data_out$horizon <- horizons
      data_out$forecast_period <- forecast_period$forecast_period
    }
  }

  data_out <- as.data.frame(data_out)
  row.names(data_out) <- seq_len(nrow(data_out))  # seq_len() so that a zero-row result is valid.

  attr(data_out, "method") <- method
  attr(data_out, "horizons") <- horizons
  attr(data_out, "outcome_col") <- outcome_col
  attr(data_out, "outcome_cols") <- outcome_cols
  attr(data_out, "outcome_name") <- outcome_name
  attr(data_out, "outcome_names") <- outcome_names
  attr(data_out, "outcome_levels") <- outcome_levels
  attr(data_out, "row_indices") <- row_indices
  attr(data_out, "date_indices") <- date_indices
  attr(data_out, "frequency") <- frequency
  attr(data_out, "data_stop") <- data_stop
  attr(data_out, "groups") <- groups

  if (type == "train") {
    class(data_out) <- c("training_results", "forecast_model", class(data_out))
  } else {
    class(data_out) <- c("forecast_results", "forecast_model", class(data_out))
  }

  return(data_out)
}
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
# S3 generic for model residuals.
#' Return model residuals
#'
#' @param object An object of class 'training_results' from running \code{predict()} on a training dataset.
#' @param ... Not used.
#' @return A data.frame of model residuals of class 'training_residuals'.
#' @export
residuals <- function(object, ...) {
  # Dispatch on the class of the first argument, 'object'.
  UseMethod("residuals")
}
#' @export
residuals.training_results <- function(object, ...) {
  # Residuals for validation-dataset predictions.
  #
  # Numeric outcomes: actual - predicted. Factor outcomes with a predicted level: 1 for an incorrect
  # classification and 0 otherwise. Factor outcomes with predicted class probabilities: 1 - probability
  # for each class column. The result keeps the model, group, horizon, and index columns and has
  # class 'training_residuals'.

  outcome_name <- attributes(object)$outcome_name
  prediction_name <- paste0(outcome_name, "_pred")
  outcome_levels <- attributes(object)$outcome_levels
  groups <- attributes(object)$groups
  has_dates <- !is.null(attributes(object)$frequency)

  # The index column is the validation date when the dataset has a frequency, the row index otherwise.
  index_name <- if (has_dates) "date_indices" else "valid_indices"

  # predict.forecast_model() renames 'model_forecast_horizon' to 'horizon' for multi-output training
  # results, so fall back to that name when the direct-forecast column is absent. The column keeps
  # whichever name it has in 'object'.
  horizon_name <- if ("model_forecast_horizon" %in% names(object)) "model_forecast_horizon" else "horizon"

  id_names <- c("model", groups, horizon_name, index_name)

  if (is.null(outcome_levels)) {  # Numeric outcomes.

    object$residuals <- object[, outcome_name] - object[, prediction_name]
    object <- object[, c(id_names, "residuals")]

  } else if (prediction_name %in% names(object)) {  # Factor outcome: a predicted factor level.

    # Binary accuracy/residual. A residual of 1 is an incorrect classification.
    object$residuals <- ifelse(as.character(object[, outcome_name, drop = TRUE]) != as.character(object[, prediction_name, drop = TRUE]), 1, 0)
    object <- object[, c(id_names, "residuals")]

  } else {  # Factor outcome: predicted class probabilities.

    # The class probability columns are every column to the right of the outcome column.
    outcome_col <- which(names(object) == outcome_name) + 1
    prob_cols <- outcome_col:ncol(object)

    object[, prob_cols][] <- lapply(object[, prob_cols], function(x) {
      1 - x
    })

    names(object)[prob_cols] <- paste0(names(object)[prob_cols], "_residuals")
    object <- object[, c(id_names, names(object)[prob_cols])]
  }

  names(object)[names(object) == index_name] <- "index"

  attr(object, "groups") <- groups

  class(object) <- c("training_residuals", "data.frame")

  return(object)
}
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
#' Plot an object of class training_results
#'
#' Several diagnostic plots can be returned to assess the quality of the forecasts
#' based on predictions on the validation datasets.
#'
#' @param x An object of class 'training_results' from \code{predict.forecast_model()}.
#' @param type Plot type. One of \code{"prediction"} (the default; predictions vs. actuals on the validation datasets),
#' \code{"residual"} (forecast error through time), or \code{"forecast_stability"} (rolling origin forecasts faceted by
#' dataset index). Only the first element is used. \code{"forecast_stability"} requires exactly one of \code{windows} or
#' \code{valid_indices} and is not available for grouped time series or factor outcomes.
#' @param models Optional. Filter results by user-defined model name from \code{train_model()}.
#' @param horizons Optional. A numeric vector of model forecast horizons to filter results by horizon-specific model.
#' @param windows Optional. A numeric vector of window numbers to filter results.
#' @param valid_indices Optional. A numeric or date vector to filter results by validation row indices or dates.
#' @param group_filter Optional. A string for filtering plot results for grouped time series
#' (e.g., \code{"group_col_1 == 'A'"}). The results are passed to \code{dplyr::filter()} internally.
#' @param facet Optional. For numeric outcomes, a formula with any combination of \code{horizon}, \code{model}, or \code{group} (for grouped time series)
#' passed to \code{ggplot2::facet_grid()} internally (e.g., \code{horizon ~ model}, \code{horizon + model ~ .}, \code{~ horizon + group}).
#' @param keep_missing Boolean. If \code{TRUE}, predictions are plotted for indices/dates where the outcome is missing.
#' @param ... Not used.
#' @return Diagnostic plots of class 'ggplot'.
#' @export
plot.training_results <- function(x,
                                  type = c("prediction", "residual", "forecast_stability"),
                                  facet = horizon ~ model, models = NULL, horizons = NULL,
                                  windows = NULL, valid_indices = NULL, group_filter = NULL, keep_missing = FALSE, ...) { # nocov start

  #----------------------------------------------------------------------------
  data <- x
  rm(x)

  type <- type[1]  # The default plot is predicting over historical validation windows.

  # Fail early on an unknown plot type; otherwise no plotting branch below would run
  # and the function would silently return NULL.
  plot_types <- c("prediction", "residual", "forecast_stability")
  if (length(type) != 1 || is.na(type) || !type %in% plot_types) {
    stop("'type' must be one of 'prediction', 'residual', or 'forecast_stability'.")
  }

  if (type == "forecast_stability") {
    if (!xor(is.null(windows), is.null(valid_indices))) {
      stop("Select either (a) one or more validation windows, 'windows', or (b) a range of dataset rows, 'valid_indices', to reduce plot size.")
    }
  }

  if (!is.null(attributes(data)$group) && !type %in% c("prediction", "residual")) {
    stop("Only 'prediction' and 'residual' plots are currently available for grouped models.")
  }

  if (!is.null(attributes(data)$outcome_levels) && !type %in% c("prediction", "residual")) {
    stop("Only 'prediction' and 'residual' plots are currently available for models with factors as outcomes.")
  }
  #----------------------------------------------------------------------------
  # Pull the metadata off of the input before it is coerced to a plain data.frame.
  method <- attributes(data)$method
  outcome_name <- attributes(data)$outcome_name
  outcome_levels <- attributes(data)$outcome_levels
  date_indices <- attributes(data)$date_indices
  groups <- attributes(data)$group

  data <- as.data.frame(data)  # Coerce to remove new vctrs warning about mismatched classes when joining data.

  # Column names used repeatedly below.
  pred_col <- paste0(outcome_name, "_pred")
  pred_lower_col <- paste0(outcome_name, "_pred_lower")
  pred_upper_col <- paste0(outcome_name, "_pred_upper")

  # Prediction bounds are only handled/plotted when both a lower and an upper bound column are present.
  has_pred_bounds <- any(grepl("_pred_lower", names(data))) && any(grepl("_pred_upper", names(data)))
  #----------------------------------------------------------------------------
  if (method == "multi_output") {

    data$valid_indices <- data$forecast_indices  # The plot indices are the forecast period, not the forecast origin.

    if (!is.null(date_indices)) {
      data$date_indices <- data$forecast_date_indices  # The plot indices are the forecast period, not the forecast origin.
    }
  }
  #----------------------------------------------------------------------------
  facets <- forecastML_facet_plot(facet, groups)  # Function in zzz.R.
  facet <- facets[[1]]
  facet_names <- facets[[2]]
  #----------------------------------------------------------------------------
  # For factor outcomes, is the prediction a factor level or probability.
  if (!is.null(outcome_levels)) {

    factor_level <- pred_col %in% names(data)
    factor_prob <- !factor_level

    if (type == "residual" && factor_prob) {
      stop("Residual plots with predicted class probabilities are not currently supported.")
    }

    if (type == "residual" && !is.null(groups)) {
      stop("Residual plots with grouped data are not currently supported.")
    }
  }
  #----------------------------------------------------------------------------
  if (isFALSE(keep_missing)) {

    outcome_missing <- is.na(data[, outcome_name])

    if (is.null(outcome_levels)) {
      # Set the predicted values to NA to keep any missing data in the actual
      # time series so gaps in data collection are not connected with a line in the plot.
      data[outcome_missing, pred_col] <- NA

      if (has_pred_bounds) {
        data[outcome_missing, pred_lower_col] <- NA
        data[outcome_missing, pred_upper_col] <- NA
      }

    } else {
      # Factor outcomes: rows with a missing actual are dropped entirely.
      data <- data[!outcome_missing, ]
    }
  }
  #----------------------------------------------------------------------------
  # Residual calculations.
  if (is.null(outcome_levels)) {  # Numeric outcome.

    data$residual <- data[, outcome_name] - data[, pred_col]

  } else {  # Factor outcome.

    if (factor_level) {

      # Binary accuracy/residual. A residual of 1 is an incorrect classification.
      data$residual <- ifelse(as.character(data[, outcome_name, drop = TRUE]) != as.character(data[, pred_col, drop = TRUE]), 1, 0)

    } else {  # Class probabilities were predicted.

      data_residual <- 1 - data[, names(data) %in% outcome_levels]
      names(data_residual) <- paste0(names(data_residual), "_residual")
      data <- dplyr::bind_cols(data, data_residual)
      rm(data_residual)
    }
  }
  #----------------------------------------------------------------------------
  # Filtering results based on user input. Any filter left as NULL keeps everything.
  if (is.null(models)) {
    models <- unique(data$model)
  }

  if (is.null(horizons)) {
    horizons <- if (method == "direct") {unique(data$model_forecast_horizon)} else {unique(data$horizon)}
  }

  if (is.null(windows)) {
    windows <- unique(data$window_number)
  }

  if (is.null(valid_indices)) {
    valid_indices <- if (is.null(date_indices)) {unique(data$valid_indices)} else {unique(data$date_indices)}
  }
  #----------------------------------------------------------------------------
  data_plot <- data
  #----------------------------------------------------------------------------
  # Rename for consistency with plotting code.
  if (method == "direct") {
    data_plot$horizon <- data_plot$model_forecast_horizon
    data_plot$model_forecast_horizon <- NULL
  }

  data_plot <- data_plot[data_plot$model %in% models & data_plot$horizon %in% horizons &
                           data_plot$window_number %in% windows, ]

  if (!is.null(group_filter)) {
    # 'group_filter' is a user-supplied filter expression given as a string.
    data_plot <- dplyr::filter(data_plot, eval(parse(text = group_filter)))
  }
  #----------------------------------------------------------------------------
  # Filter by 'valid_indices' and set the x-axis variable, 'index'. This is the only place
  # 'valid_indices' is applied: dates are matched against dates and row indices against row indices.
  if (methods::is(valid_indices, "Date") || methods::is(valid_indices, "POSIXt")) {

    data_plot <- data_plot[data_plot$date_indices %in% valid_indices, ]  # Filter plots by dates.
    data_plot$index <- data_plot$date_indices

  } else {

    data_plot <- data_plot[data_plot$valid_indices %in% valid_indices, ]  # Filter plots by row indices.

    if (!is.null(date_indices)) {
      data_plot$index <- data_plot$date_indices
    } else {
      data_plot$index <- data_plot$valid_indices
    }
  }
  #----------------------------------------------------------------------------
  # Set up ggplot color and group parameters.
  if (is.null(outcome_levels)) {  # Numeric outcomes.

    if (is.null(groups)) {
      data_plot <- dplyr::arrange(data_plot, .data$model, .data$window_number)
    } else {
      # NOTE(review): with more than one grouping column, eval(parse()) appears to yield only the
      # last column's values, so rows would be ordered by the last group only -- confirm intent.
      data_plot <- dplyr::arrange(data_plot, .data$model, .data$window_number, eval(parse(text = groups)))
    }
    #--------------------------------------------------------------------------
    # ggplot colors and facets are complementary: all facets, same color; all colors, no facet.
    color_candidates <- c("model", "horizon", groups)
    ggplot_color <- color_candidates[!color_candidates %in% facet_names]

    # Legend title: capitalize the first letter of each color variable and join with " + ".
    legend_title <- paste0(toupper(substr(ggplot_color, 1, 1)), substr(ggplot_color, 2, nchar(ggplot_color)))
    legend_title <- paste(legend_title, collapse = " + ")
    #--------------------------------------------------------------------------
    data_plot$ggplot_color <- apply(data_plot[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})

    # Give predictions a name in the legend if plot is faceted by model and horizon (and group if groups are given).
    if (length(ggplot_color) == 0) {
      data_plot$ggplot_color <- "Prediction"
    }

    # Used to avoid lines spanning any gaps between validation windows.
    if (all(data_plot$window_number == 1)) {  # One window; no need to add the window number to the legend.
      data_plot$ggplot_group <- apply(data_plot[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
    } else {
      data_plot$ggplot_group <- apply(data_plot[, c("window_number", ggplot_color), drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
    }

    # Coerce to viridis color scale with an ordered factor. With the data.frame sorted, unique() pulls the levels in their order of appearance.
    data_plot$ggplot_color <- factor(data_plot$ggplot_color, levels = unique(data_plot$ggplot_color), ordered = TRUE)
    #--------------------------------------------------------------------------
  } else {  # Factor outcomes.

    data_plot$ggplot_color_group <- apply(data_plot[, c("model", "horizon", groups), drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
  }
  #----------------------------------------------------------------------------
  #----------------------------------------------------------------------------
  if (type %in% c("prediction", "residual")) {
    #--------------------------------------------------------------------------
    # Melt the data for plotting: the actuals and predictions are stacked into an
    # 'outcome' (source column name) / 'value' pair; all other columns are kept as-is.
    if (is.null(outcome_levels) || factor_level) {  # Numeric outcome or a predicted factor level.

      value_cols <- c(outcome_name, pred_col)

      data_plot <- tidyr::gather(data_plot, "outcome", "value",
                                 -!!names(data_plot)[!names(data_plot) %in% value_cols])

    } else {  # Class probabilities were predicted.

      value_cols <- c(outcome_name, outcome_levels)

      data_plot <- suppressWarnings(tidyr::gather(data_plot, "outcome", "value",
                                                  -!!names(data_plot)[!names(data_plot) %in% value_cols]))
    }
    #--------------------------------------------------------------------------
    # If date indices exist, plot with them.
    if (!is.null(date_indices)) {
      data_plot$index <- data_plot$date_indices
    }
    #--------------------------------------------------------------------------
    if (type == "prediction") {

      if (is.null(outcome_levels)) {  # Numeric outcome; plot historical predictions.

        p <- ggplot()
        #----------------------------------------------------------------------
        # Plot actuals on the first or lowest layer.
        if (is.null(groups)) {  # Single time series.

          p <- p + geom_line(data = data_plot[data_plot$outcome == outcome_name, ],
                             aes(x = .data$index, y = .data$value, group = .data$ggplot_group), color = "grey50")

        } else {  # Actuals, grouped time series.

          # If faceting by group, this reduces to the single time series case so the actuals
          # will be the default grey so as not to double encode the plot data.
          if (any(facet_names %in% groups)) {

            p <- p + geom_line(data = data_plot[data_plot$outcome == outcome_name, ],
                               aes(x = .data$index, y = .data$value), color = "grey50")

          } else {

            p <- p + geom_line(data = data_plot[data_plot$outcome == outcome_name, ],
                               aes(x = .data$index, y = .data$value,
                                   group = .data$ggplot_group,
                                   color = .data$ggplot_color), linetype = 2)
            p <- p + scale_color_viridis_d()
          }
        }
        #----------------------------------------------------------------------
        # Plot predictions.
        # If the plotting data.frame has both lower and upper forecasts plot these bounds.
        # We'll add the shading in a lower ggplot layer so the point forecasts are on top in the final plot.
        if (has_pred_bounds) {

          p <- p + geom_ribbon(data = data_plot[data_plot$outcome == outcome_name, ],
                               aes(x = .data$index, ymin = .data[[pred_lower_col]],
                                   ymax = .data[[pred_upper_col]],
                                   fill = .data$ggplot_color, group = .data$ggplot_group, color = NULL), alpha = .25, show.legend = FALSE)
        }

        p <- p + geom_line(data = data_plot[data_plot$outcome != outcome_name, ],  # Predictions in melted data.
                           aes(x = .data$index, y = .data$value, group = .data$ggplot_group, color = .data$ggplot_color),
                           size = 1.05, linetype = 1)

        if (!is.null(facet)) {
          p <- p + facet_grid(facet, drop = TRUE)
        }

        p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
        #----------------------------------------------------------------------
        #----------------------------------------------------------------------
      } else {  # Factor outcome.

        if (factor_prob) {  # Plot class probabilities.

          if (is.null(groups)) {  # Plot class probabilities with un-grouped data.

            # Only 1 actual needs to be plotted.
            data_plot$ggplot_color_group[data_plot$outcome == outcome_name] <- "Actual"

            data_actual <- data_plot[data_plot$outcome %in% outcome_name, ]
            data_pred <- data_plot[data_plot$outcome %in% outcome_levels, ]

            data_actual$value <- factor(data_actual$value, levels = outcome_levels, ordered = TRUE)
            data_pred$outcome <- factor(data_pred$outcome, levels = outcome_levels, ordered = TRUE)
            data_pred$value <- as.numeric(data_pred$value)

            p <- ggplot()
            p <- p + geom_col(data = data_pred,
                              aes(x = .data$index, y = .data$value, color = .data$outcome, fill = .data$outcome),
                              position = position_stack(reverse = TRUE))
            p <- p + geom_col(data = data_actual,
                              aes(x = .data$index, y = 1, color = .data$value, fill = .data$value))
            p <- p + scale_y_continuous(limits = 0:1)
            p <- p + scale_color_viridis_d(drop = FALSE)
            p <- p + scale_fill_viridis_d(drop = FALSE)
            p <- p + facet_wrap(~ ggplot_color_group, ncol = 1, scales = "free_y")
            p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))

          } else {  # Plot class probabilities with grouped data.

            is_actual <- data_plot$outcome == outcome_name
            is_pred <- data_plot$outcome != outcome_name

            data_plot$actual_or_pred <- NA
            data_plot$actual_or_pred[is_actual] <- "Actual"
            data_plot$actual_or_pred[is_pred] <- "Prediction"

            # Actuals are drawn as full-height bars; predictions as their class probabilities.
            data_plot$y_value <- NA
            data_plot$y_value[is_actual] <- 1
            data_plot$y_value[is_pred] <- data_plot$value[is_pred]
            data_plot$y_value <- as.numeric(data_plot$y_value)

            # Actuals are colored by the observed level; predictions by the level the probability is for.
            data_plot$plot_color <- NA
            data_plot$plot_color[is_actual] <- data_plot$value[is_actual]
            data_plot$plot_color[is_pred] <- data_plot$outcome[is_pred]
            data_plot$plot_color <- factor(data_plot$plot_color, levels = outcome_levels, ordered = TRUE)

            p <- ggplot()
            p <- p + geom_col(data = data_plot,
                              aes(x = .data$index, y = .data$y_value, color = .data$plot_color, fill = .data$plot_color),
                              position = position_stack(reverse = TRUE))
            p <- p + scale_y_continuous(limits = 0:1, breaks = c(0, .5, 1))
            p <- p + scale_color_viridis_d(drop = FALSE)
            p <- p + scale_fill_viridis_d(drop = FALSE)
            p <- p + facet_grid(ggplot_color_group + actual_or_pred ~ ., scales = "free_y")
            p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
          }  # End plot class probabilities with grouped data.
          #--------------------------------------------------------------------
        } else {  # Plot class levels.

          if (is.null(groups)) {  # Plot class levels with un-grouped data.

            # Only 1 actual needs to be plotted.
            data_plot$ggplot_color_group[data_plot$outcome == outcome_name] <- "Actual"

            data_plot$value <- factor(data_plot$value, levels = c(outcome_levels), ordered = TRUE)

            p <- ggplot()
            p <- p + geom_col(data = data_plot,
                              aes(x = .data$index, y = 1, color = .data$value, fill = .data$value),
                              position = position_stack(reverse = TRUE))
            p <- p + scale_color_viridis_d(drop = FALSE)
            p <- p + scale_fill_viridis_d(drop = FALSE)
            p <- p + facet_wrap(~ ggplot_color_group, ncol = 1, scales = "free_y")
            p <- p + theme_bw() + theme(axis.text.y = element_blank(), axis.ticks.y = element_blank())

          } else {  # Plot class levels with grouped data.

            data_plot$actual_or_pred <- NA
            data_plot$actual_or_pred[data_plot$outcome == outcome_name] <- "Actual"
            data_plot$actual_or_pred[data_plot$outcome != outcome_name] <- "Prediction"

            data_plot$value <- factor(data_plot$value, levels = outcome_levels, ordered = TRUE)

            p <- ggplot()
            p <- p + geom_col(data = data_plot,
                              aes(x = .data$index, y = 1, color = .data$value, fill = .data$value),
                              position = position_stack(reverse = TRUE))
            p <- p + scale_y_continuous(limits = 0:1, breaks = c(0, .5, 1))
            p <- p + scale_color_viridis_d(drop = FALSE)
            p <- p + scale_fill_viridis_d(drop = FALSE)
            p <- p + facet_grid(ggplot_color_group + actual_or_pred ~ ., scales = "free_y")
            p <- p + theme_bw() + theme(axis.text.y = element_blank(), axis.ticks.y = element_blank(), panel.spacing = unit(0, "lines"))
          }
        }
      }  # End prediction plots for numeric and factor outcomes.
      #------------------------------------------------------------------------
    } else if (type == "residual") {

      if (is.null(outcome_levels)) {  # Numeric outcome.

        # Plot historical residuals. Only the prediction rows of the melted data are used
        # so each residual is drawn once.
        p <- ggplot(data_plot[data_plot$outcome != outcome_name, ],
                    aes(x = .data$index, y = .data$residual,
                        group = .data$ggplot_group, color = .data$ggplot_color))
        p <- p + geom_line(size = 1.05, linetype = 1)
        p <- p + geom_hline(yintercept = 0)
        p <- p + scale_color_viridis_d()

        # Same guard as the prediction plot: only facet when a facet formula is available.
        if (!is.null(facet)) {
          p <- p + facet_grid(facet, drop = TRUE)
        }

        p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))

      } else {  # Factor outcome.

        if (is.null(groups)) {

          data_plot$ggplot_color_group <- factor(as.character(data_plot$ggplot_color_group), ordered = TRUE,
                                                 levels = c(unique(as.character(data_plot$ggplot_color_group)), "Actual"))
          data_plot$ggplot_color_group[data_plot$outcome == outcome_name] <- "Actual"
        }

        p <- ggplot()
        # Plot predictions to avoid duplicate residuals in plots.
        p <- p + geom_tile(data = data_plot[data_plot$outcome != outcome_name, ],
                           aes(x = .data$index, y = .data$ggplot_color_group,
                               fill = ordered(.data$residual)))
        p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
      }
    }  # End residual plot.
    #--------------------------------------------------------------------------
    # Plot labels.
    if (type == "prediction") {

      if (is.null(outcome_levels)) {  # Numeric outcome.

        p <- p + xlab("Dataset index") + ylab("Outcome") + labs(color = legend_title, group = NULL)

        # Actuals are only drawn as dashed, colored lines for grouped time series that are not faceted by group.
        if (!is.null(groups) && !any(facet_names %in% groups)) {
          p <- p + ggtitle("Forecasts vs. Actuals Through Time", subtitle = c("Dashed lines are actuals"))
        } else {
          p <- p + ggtitle("Forecasts vs. Actuals Through Time")
        }

      } else {  # Factor outcome.

        y_label <- if (factor_prob) {"Outcome probability"} else {"Outcome"}

        p <- p + xlab("Dataset index") + ylab(y_label) + labs(color = "Outcome", fill = "Outcome") +
          ggtitle("Forecasts vs. Actuals Through Time")
      }

    } else if (type == "residual") {

      if (is.null(outcome_levels)) {

        p <- p + xlab("Dataset index") + ylab("Residual") + labs(color = NULL, group = NULL) +
          ggtitle("Forecast Error Through Time")

      } else {

        y_label <- if (is.null(groups)) {"Residual and model"} else {"Residual, model, and group"}

        p <- p + xlab("Dataset index") + ylab(y_label) + labs(fill = "Residual") +
          ggtitle("Forecast Error Through Time")
      }
    }

    return(suppressWarnings(p))
  }
  #----------------------------------------------------------------------------
  if (type %in% c("forecast_stability")) {

    # The forecast origin is the index from which the forecast for 'valid_indices' was made.
    data_plot$forecast_origin <- with(data_plot, valid_indices - horizon)
    data_plot$group <- with(data_plot, paste0(valid_indices))
    data_plot$group <- ordered(data_plot$group)

    # Plotting the original time-series in each facet. Because the plot is faceted by valid_indices, we'll do a bit of a hack here to create
    # the same line plot for each facet.
    data_outcome <- data_plot %>%
      dplyr::select(valid_indices, !!outcome_name) %>%
      dplyr::distinct(valid_indices, .keep_all = TRUE)
    data_outcome$index <- data_outcome$valid_indices
    data_outcome$valid_indices <- NULL  # remove to avoid confusion in facet_wrap()

    # seq_len() keeps this an empty selection, rather than c(1, 0), if there are no rows.
    data_outcome <- data_outcome[rep(seq_len(nrow(data_outcome)), length(unique(data_plot$valid_indices))), ]

    p <- ggplot()

    if (max(data_plot$horizon) != 1) {
      p <- p + geom_line(data = data_plot, aes(x = .data$forecast_origin,
                                               y = .data[[pred_col]],
                                               color = factor(.data$model)), size = 1, linetype = 1, show.legend = FALSE)
    }

    p <- p + geom_point(data = data_plot, aes(x = .data$forecast_origin,
                                              y = .data[[pred_col]],
                                              color = factor(.data$model)))
    p <- p + geom_point(data = data_plot, aes(x = .data$valid_indices,
                                              y = .data[[outcome_name]], fill = "Actual"))
    p <- p + scale_color_viridis_d()
    p <- p + facet_wrap(~ valid_indices)
    p <- p + geom_line(data = data_outcome, aes(x = .data$index,
                                                y = .data[[outcome_name]]), color = "gray50")
    p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
    p <- p + xlab("Dataset index") + ylab("Outcome") + labs(color = "Model") + labs(fill = NULL) +
      ggtitle("Rolling Origin Forecast Stability - Faceted by dataset index")

    return(suppressWarnings(p))
  }
} # nocov end
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
#' Plot an object of class forecast_results
#'
#' A forecast plot for each horizon for each model in \code{predict.forecast_model()}.
#'
#' @param x An object of class 'forecast_results' from \code{predict.forecast_model()}.
#' @param data_actual A data.frame containing the target/outcome name and any grouping columns.
#' The data can be historical actuals and/or holdout/test data.
#' @param actual_indices Required if \code{data_actual} is given. A vector or 1-column data.frame
#' of numeric row indices or dates (class 'Date' or 'POSIXt') with length \code{nrow(data_actual)}.
#' The data can be historical actuals and/or holdout/test data.
#' @param facet Optional. For numeric outcomes, a formula with any combination of \code{horizon}, \code{model}, or \code{group} (for grouped time series)
#' passed to \code{ggplot2::facet_grid()} internally (e.g., \code{horizon ~ model}, \code{horizon + model ~ .}, \code{~ horizon + group}).
#' Can be \code{NULL}.
#' @param models Optional. Filter results by user-defined model name from \code{train_model()}.
#' @param horizons Optional. Filter results by horizon.
#' @param windows Optional. Filter results by validation window number.
#' @param group_filter Optional. A string for filtering plot results for grouped time-series (e.g., \code{"group_col_1 == 'A'"});
#' passed to \code{dplyr::filter()} internally.
#' @param ... Not used.
#' @return Forecast plot of class 'ggplot'.
#' @export
plot.forecast_results <- function(x, data_actual = NULL, actual_indices = NULL, facet = horizon ~ model,
models = NULL, horizons = NULL, windows = NULL,
group_filter = NULL, ...) { # nocov start
#----------------------------------------------------------------------------
if(xor(is.null(data_actual), is.null(actual_indices))) {
stop("If plotting a hold-out or comparison dataset, both 'data_actual' and 'actual_indices' need to be specified.")
}
data_forecast <- x
rm(x)
type <- "forecast" # Only one plot option at present.
#----------------------------------------------------------------------------
method <- attributes(data_forecast)$method
forecast_horizons <- attributes(data_forecast)$horizons
outcome_col <- attributes(data_forecast)$outcome_col
outcome_name <- attributes(data_forecast)$outcome_name
outcome_levels <- attributes(data_forecast)$outcome_levels
date_indices <- attributes(data_forecast)$date_indices
groups <- attributes(data_forecast)$group
data_stop <- attributes(data_forecast)$data_stop
data_forecast <- as.data.frame(data_forecast) # Coerce to remove new vctrs warning about mismatched classes when joining data.
if (!is.null(outcome_levels) && !is.null(groups) && !is.null(data_actual)) {
stop("Plotting forecasts from grouped time series with an actuals dataset is not currently supported.")
}
#----------------------------------------------------------------------------
names(data_forecast)[names(data_forecast) == "forecast_period"] <- "index" # For code uniformity.
#----------------------------------------------------------------------------
# For factor outcomes, is the prediction a factor level or probability.
if (!is.null(outcome_levels)) {
factor_level <- if (any(names(data_forecast) %in% paste0(outcome_name, "_pred"))) {TRUE} else {FALSE}
factor_prob <- !factor_level
}
#----------------------------------------------------------------------------
facets <- forecastML_facet_plot(facet, groups) # Function in zzz.R.
facet <- facets[[1]]
facet_names <- facets[[2]]
#----------------------------------------------------------------------------
if (!is.null(data_actual)) {
data_actual <- data_actual[, c(outcome_name, groups), drop = FALSE]
data_actual$index <- actual_indices
if (!is.null(group_filter)) {
data_actual <- dplyr::filter(data_actual, eval(parse(text = group_filter)))
}
}
#----------------------------------------------------------------------------
# Filter plots using user input.
models <- if (is.null(models)) {unique(data_forecast$model)} else {models}
horizons <- if (is.null(horizons)) {forecast_horizons} else {horizons}
windows <- if (is.null(windows)) {unique(data_forecast$window_number)} else {windows}
if (method == "multi_output") {
data_forecast$model_forecast_horizon <- data_forecast$horizon # Used for filtering below.
}
data_forecast <- data_forecast[data_forecast$model %in% models &
data_forecast$model_forecast_horizon %in% horizons &
data_forecast$window_number %in% windows, ]
if (!is.null(group_filter)) {
data_forecast <- dplyr::filter(data_forecast, eval(parse(text = group_filter)))
}
#----------------------------------------------------------------------------
data_plot <- data_forecast
#----------------------------------------------------------------------------
if (type == "forecast") {
#----------------------------------------------------------------------------
# Set up ggplot color and group parameters.
if (is.null(outcome_levels)) { # Numeric outcomes.
if (is.null(groups)) {
data_plot <- dplyr::arrange(data_plot, .data$model, .data$model_forecast_horizon, .data$horizon, .data$window_number)
} else {
data_plot <- dplyr::arrange(data_plot, .data$model, .data$model_forecast_horizon, .data$window_number, .data$horizon, eval(parse(text = groups)))
}
#--------------------------------------------------------------------------
# ggplot colors and facets are complimentary: all facets, same color; all colors, no facet.
ggplot_color <- c(c("model", "horizon", groups)[!c("model", "horizon", groups) %in% facet_names])
if (method == "multi_output") { # With 1 model, all horizons belong to the same forecast and should have the same color.
ggplot_color <- ggplot_color[!ggplot_color %in% "horizon"]
}
#--------------------------------------------------------------------------
# There is a distinction between horizon and the model horizon; rename to work with the facet inputs.
data_plot$forecast_horizon <- data_plot$horizon
data_plot$horizon <- data_plot$model_forecast_horizon
data_plot$ggplot_color <- apply(data_plot[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
# Give predictions a name in the legend if plot is faceted by model and horizon (and group if groups are given).
if (length(ggplot_color) == 0) {
data_plot$ggplot_color <- "Forecast"
}
# Used to avoid lines spanning any gaps between validation windows.
if (all(data_plot$window_number == 1)) { # One window; no need to add the window number to the legend.
data_plot$ggplot_group <- apply(data_plot[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
} else {
data_plot$ggplot_group <- apply(data_plot[, c("window_number", ggplot_color), drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
}
# Coerce to viridis color scale with an ordered factor. With the data.frame sorted, unique() pulls the levels in their order of appearance.
data_plot$ggplot_color <- factor(data_plot$ggplot_color, levels = unique(data_plot$ggplot_color), ordered = TRUE)
data_plot$ggplot_group <- factor(data_plot$ggplot_group, levels = unique(data_plot$ggplot_group), ordered = TRUE)
}
#--------------------------------------------------------------------------
if (!is.null(data_actual)) {
#--------------------------------------------------------------------------
# If the plot is faceted by model, repeat the actuals dataset once for each model in a long format for faceting by model.
if (length(unique(data_plot$model)) == 1) {
data_actual$model <- unique(data_plot$model)
} else {
n_reps <- nrow(data_actual)
data_actual <- data_actual[rep(1:nrow(data_actual), length(unique(data_plot$model))), ]
data_actual$model <- rep(unique(data_plot$model), each = n_reps)
}
#--------------------------------------------------------------------------
# If the plot is colored by model, repeat the actuals dataset once for each model in a long format for faceting by model.
if (length(unique(data_plot$model)) == 1) {
data_actual$model <- unique(data_plot$model)
} else {
n_reps <- nrow(data_actual)
data_actual <- data_actual[rep(1:nrow(data_actual), length(unique(data_plot$model))), ]
data_actual$model <- rep(unique(data_plot$model), each = n_reps)
}
#--------------------------------------------------------------------------
if (length(ggplot_color) > 0) {
if (ggplot_color == "horizon") {
data_actual$ggplot_color <- NA
} else {
data_actual$ggplot_color <- apply(data_actual[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
}
} else {
data_actual$ggplot_color <- "Forecast"
}
# Give predictions a name in the legend if plot is faceted by model and horizon (and group if groups are given).
if (length(ggplot_color) == 0) {
data_actual$ggplot_color <- "Forecast"
}
if (length(ggplot_color) > 0) {
if (ggplot_color == "horizon") {
data_actual$ggplot_group <- "Forecast"
} else {
data_actual$ggplot_group <- apply(data_actual[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
}
} else {
data_actual$ggplot_group <- "Forecast"
}
#------------------------------------------------------------------------
# Coerce to viridis color scale with an ordered factor. The levels in the actual data are limited
# to those factor levels that appear in the forecast data.
data_actual$ggplot_color <- factor(data_actual$ggplot_color, levels = levels(data_plot$ggplot_color), ordered = TRUE)
data_actual$ggplot_group <- factor(data_actual$ggplot_group, levels = levels(data_plot$ggplot_color), ordered = TRUE)
}
#--------------------------------------------------------------------------
if (is.null(outcome_levels)) { # Numeric outcome.
p <- ggplot()
if (1 %in% horizons || nrow(data_plot) == 1 || (method == "multi_output" && any(c(facet_names %in% "horizon")))) { # Use geom_point instead of geom_line to plot a 1-step-ahead forecast.
if (method == "direct") {
data_plot_temp <- data_plot[data_plot$horizon == 1, ]
} else if (method == "multi_output") {
data_plot_temp <- data_plot
}
# If the plotting data.frame has both lower and upper forecasts plot these bounds.
# We'll add the shading in a lower ggplot layer so the point forecasts are on top in the final plot.
if (all(any(grepl("_pred_lower", names(data_plot))), any(grepl("_pred_upper", names(data_plot))))) {
# geom_ribbon() does not work with a single data point when forecast bounds are plotted.
p <- p + geom_linerange(data = data_plot_temp,
aes(x = .data$index, ymin = eval(parse(text = paste0(outcome_name, "_pred_lower"))),
ymax = eval(parse(text = paste0(outcome_name, "_pred_upper"))),
color = .data$ggplot_color, group = .data$ggplot_group), alpha = .25, size = 3, show.legend = FALSE)
}
p <- p + geom_point(data = data_plot_temp,
aes(x = .data$index, y = eval(parse(text = paste0(outcome_name, "_pred"))),
color = .data$ggplot_color, group = .data$ggplot_group))
}
if ((method == "direct" && !all(horizons == 1)) || (method == "multi_output" && nrow(data_plot) > 1 && !any(c(facet_names %in% "horizon")))) { # Plot forecasts for model forecast horizons > 1.
if (method == "direct") {
data_plot_temp <- data_plot[data_plot$horizon != 1, ]
} else if (method == "multi_output") {
data_plot_temp <- data_plot
}
# If the plotting data.frame has both lower and upper forecasts, plot these bounds.
if (all(any(grepl("_pred_lower", names(data_plot))), any(grepl("_pred_upper", names(data_plot))))) {
p <- p + geom_ribbon(data = data_plot_temp,
aes(x = .data$index, ymin = eval(parse(text = paste0(outcome_name, "_pred_lower"))),
ymax = eval(parse(text = paste0(outcome_name, "_pred_upper"))),
fill = .data$ggplot_color, group = .data$ggplot_group, color = NULL), alpha = .25, show.legend = FALSE)
}
p <- p + geom_line(data = data_plot_temp,
aes(x = .data$index, y = eval(parse(text = paste0(outcome_name, "_pred"))),
color = .data$ggplot_color, group = .data$ggplot_group))
}
# Add user-defined actuals data to the plots
if (!is.null(data_actual)) {
if (is.null(groups)) {
p <- p + geom_line(data = data_actual, aes(x = .data$index,
y = eval(parse(text = outcome_name))), color = "grey50")
} else {
# If faceting by group, this reduces to the single time series case so the actuals
# will be the default grey so as not to double encode the plot data.
if (any(facet_names %in% groups)) {
p <- p + geom_line(data = data_actual, aes(x = .data$index, y = eval(parse(text = outcome_name))), color = "grey50", show.legend = FALSE)
} else if (facet_names != c("model", "horizon") || facet_names != c("horizon", "model")) { # The actuals colors cannot be uniquely mapped to the forecast plot colors.
p <- p + geom_line(data = data_actual, aes(x = .data$index, y = eval(parse(text = outcome_name))), color = "grey50", show.legend = FALSE)
} else if (facet_names == c("model", "horizon") || facet_names == c("horizon", "model")) { # The actuals can be uniquely mapped to the forecasts given the faceting.
p <- p + geom_line(data = data_actual, aes(x = .data$index,
y = eval(parse(text = outcome_name)),
color = .data$ggplot_color,
group = .data$ggplot_group), show.legend = FALSE)
}
}
}
p <- p + scale_color_viridis_d()
p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
p <- p + facet_grid(facet, scales = "free_y")
#--------------------------------------------------------------------------
#--------------------------------------------------------------------------
} else { # Factor outcome.
data_plot <- data_forecast
if (factor_prob) {
# Melt the data for plotting the multiple class probabilities in stacked bars.
data_plot <- tidyr::gather(data_plot, "outcome", "value", -!!names(data_plot)[!names(data_plot) %in% c(outcome_levels)])
}
#------------------------------------------------------------------------
# If there is only 1 validation window, remove it from the grouping to reduce clutter.
if (length(unique(data_plot$window_number)) == 1) {
data_plot$ggplot_color_group <- apply(data_plot[, c("model", "model_forecast_horizon", groups), drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
} else {
data_plot$ggplot_color_group <- apply(data_plot[, c("model", "model_forecast_horizon", "window_number", groups), drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
}
#------------------------------------------------------------------------
if (!is.null(data_actual)) {
# actual or forecast: these are all actuals.
data_actual$actual_or_forecast <- "actual"
# historical, test, or model_forecast: these may be any combination of historical data and a holdout test dataset.
data_actual$time_series_type <- with(data_actual, ifelse(index <= data_stop, "historical", "test"))
names(data_actual)[names(data_actual) == outcome_name] <- "outcome" # Standardize before concat with forecasts.
data_actual$ggplot_color_group <- "Actual" # Actuals will be plotted in the top plot facet.
data_actual$value <- 1 # Plot a solid bar with probability 1 in geom_col().
# In cases where historical data is provided in data_actual, duplicate the historical data
# such that it appears as a sequence in each plot facet. Here, 'ggplot_color_group' gives the,
# possibly user-filtered, plot facets.
if ("historical" %in% unique(data_actual$time_series_type)) {
data_hist <- data_actual[data_actual$time_series_type == "historical", c("index", "outcome", "value", "ggplot_color_group")]
n_rows <- nrow(data_hist)
data_hist <- data_hist[rep(1:nrow(data_hist), length(unique(data_plot$ggplot_color_group))), ]
data_hist$ggplot_color_group <- rep(unique(data_plot$ggplot_color_group), each = n_rows)
data_actual <- suppressWarnings(dplyr::bind_rows(data_hist, data_actual))
}
}
#------------------------------------------------------------------------
# Standardize names for plotting and before any concatenation with data_actual.
names(data_plot)[names(data_plot) == "index"] <- "index"
data_plot$actual_or_forecast <- "forecast"
data_plot$time_series_type <- "model_forecast"
#------------------------------------------------------------------------
if (factor_level) {
# Standardize names for plotting and before any concatenation with data_actual.
names(data_plot)[names(data_plot) == paste0(outcome_name, "_pred")] <- "outcome"
data_plot$value <- 1
}
#------------------------------------------------------------------------
if (!is.null(data_actual)) {
data_plot <- suppressWarnings(dplyr::bind_rows(data_plot, data_actual))
}
data_plot$ggplot_color_group <- factor(data_plot$ggplot_color_group, levels = rev(unique(data_plot$ggplot_color_group)), ordered = TRUE)
data_plot$value <- as.numeric(data_plot$value)
data_plot$outcome <- factor(data_plot$outcome, levels = outcome_levels, ordered = TRUE)
#------------------------------------------------------------------------
if (!is.null(groups)) {
data_plot <- dplyr::distinct(data_plot, .data$ggplot_color_group, .data$index, .data$outcome, .keep_all = TRUE)
}
p <- ggplot()
p <- p + geom_col(data = data_plot,
aes(x = .data$index, y = .data$value, color = .data$outcome, fill = .data$outcome),
position = position_stack(reverse = TRUE))
p <- p + scale_color_viridis_d(drop = FALSE)
p <- p + scale_fill_viridis_d(drop = FALSE)
if (is.null(groups)) {
p <- p + facet_wrap(~ ggplot_color_group, scales = "free_y")
} else {
p <- p + facet_grid(ggplot_color_group ~ ., scales = "free_y")
}
p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
if (factor_level) {
p <- p + theme(axis.text.y = element_blank(), axis.ticks.y = element_blank(), panel.spacing = unit(0, "lines"))
}
} # End numeric and factor outcome plot setup.
#--------------------------------------------------------------------------
# Add a vertical line to mark the beginning of the forecast period.
p <- p + geom_vline(xintercept = data_stop, color = "red")
#--------------------------------------------------------------------------
if (is.null(outcome_levels)) { # Numeric outcome.
temp_1 <- unlist(Map(function(x) {toupper(substr(x[1], 1, 1))}, ggplot_color))
temp_2 <- unlist(Map(function(x) {substr(x, 2, nchar(x))}, ggplot_color))
x_axis_title <- paste(temp_1, temp_2, sep = "")
x_axis_title <- paste(x_axis_title, collapse = " + ")
p <- p + xlab("Dataset index") + ylab("Outcome") +
labs(color = x_axis_title, fill = NULL) +
ggtitle("H-Step-Ahead Model Forecasts")
} else { # Factor outcome.
p <- p + xlab("Dataset index") + ylab("Outcome") +
labs(color = "Outcome", fill = "Outcome") +
ggtitle("H-Step-Ahead Model Forecasts")
}
return(suppressWarnings(p))
}
} # nocov end
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.