Nothing
#' Train a model across horizons and validation datasets
#'
#' Train a user-defined forecast model for each horizon, 'h', and across the validation
#' datasets, 'd'. If \code{method = "direct"}, a total of 'h' * 'd' models are trained.
#' If \code{method = "multi_output"}, a total of 1 * 'd' models are trained.
#' These models can be trained in parallel with the \code{future} package.
#'
#' @param lagged_df An object of class 'lagged_df' from \code{\link{create_lagged_df}}.
#' @param windows An object of class 'windows' from \code{\link{create_windows}}.
#' @param model_name A name for the model.
#' @param model_function A user-defined wrapper function for model training that takes the following
#' arguments: (1) a horizon-specific data.frame made with \code{create_lagged_df(..., type = "train")}
#' (i.e., the dataset(s) stored in \code{lagged_df}) and, optionally, (2) any number of additional named arguments
#' which can be passed in \code{...} in this function. When a validation window is held out, the rows of
#' that window are removed from the data.frame before it is passed in, and their row positions (relative to
#' the full horizon-specific data.frame) are stored in \code{attributes(data)$validation_indices}.
#' @param ... Optional. Named arguments passed into the user-defined \code{model_function}. They are
#' evaluated once, in the calling environment, before any models are trained.
#' @param use_future Boolean. If \code{TRUE}, the \code{future} package is used for training models in parallel.
#' The models will train in parallel across either (a) model forecast horizons or (b) validation windows,
#' whichever is longer (i.e., \code{length(create_lagged_df())} or \code{nrow(create_windows())}); ties
#' parallelize across horizons. The user should run \code{future::plan(future::multisession)} or similar
#' prior to this function to train these models in parallel.
#' @return An S3 object of class 'forecast_model': A nested list of trained models. Models can be accessed with
#' \code{my_trained_model$horizon_h$window_w$model} where 'h' gives the forecast horizon and 'w' gives
#' the validation dataset window number from \code{create_windows()}. Each window element also holds
#' \code{window}, \code{window_length}, \code{valid_indices}, and \code{date_indices}.
#'
#' @section Methods and related functions:
#'
#' The output of \code{train_model} can be passed into
#'
#' \itemize{
#' \item \code{\link{return_error}}
#' \item \code{\link{return_hyper}}
#' }
#'
#' and has the following generic S3 methods
#'
#' \itemize{
#' \item \code{\link[=predict.forecast_model]{predict}}
#' \item \code{\link[=plot.training_results]{plot}} (from \code{predict.forecast_model(data = create_lagged_df(..., type = "train"))})
#' \item \code{\link[=plot.forecast_results]{plot}} (from \code{predict.forecast_model(data = create_lagged_df(..., type = "forecast"))})
#' }
#' @example /R/examples/example_train_model.R
#' @export
train_model <- function(lagged_df, windows, model_name, model_function, ..., use_future = FALSE) {
  if (missing(lagged_df) || !methods::is(lagged_df, "lagged_df")) {
    stop("The 'lagged_df' argument takes an object of class 'lagged_df' as input. Run create_lagged_df() first.")
  }
  if (missing(windows) || !methods::is(windows, "windows")) {
    stop("The 'windows' argument takes an object of class 'windows' as input. Run create_windows() first.")
  }
  if (missing(model_name)) {
    stop("Enter a model name for the 'model_name' argument.")
  }
  if (missing(model_function) || !is.function(model_function)) {
    stop("The 'model_function' argument takes a user-defined model training function as input.")
  }
  row_indices <- attributes(lagged_df)$row_indices
  date_indices <- attributes(lagged_df)$date_indices
  horizons <- attributes(lagged_df)$horizons
  data_start <- attributes(lagged_df)$data_start
  data_stop <- attributes(lagged_df)$data_stop
  skeleton <- attributes(lagged_df)$skeleton
  n_windows <- nrow(windows)
  #----------------------------------------------------------------------------
  # Optional model arguments passed in '...'. list(...) forces each argument in the environment where
  # the caller wrote it, so objects local to a calling function are found. Evaluating substituted
  # expressions inside lapply() would instead look names up in lapply()'s own frame. The resulting
  # values are captured by the closures below, so they reach future_lapply() workers as plain values
  # rather than relying on the future package finding them as globals.
  n_args <- ...length()
  if (n_args > 0) {
    model_function_args <- list(...)
  }
  #----------------------------------------------------------------------------
  # The default future behavior is to parallelize the model training over the longer dimension: (a) number of
  # forecast horizons or (b) number of validation windows. This is due to a current limitation
  # in the future package on changing object size limitations for nested futures where
  # "options(globals.maxSize.default = Inf)" isn't recognized.
  if (isTRUE(use_future)) {
    if (length(horizons) >= n_windows) {
      lapply_across_horizons <- future.apply::future_lapply
      lapply_across_val_windows <- base::lapply
    } else {
      lapply_across_horizons <- base::lapply
      lapply_across_val_windows <- future.apply::future_lapply
    }
  } else {
    lapply_across_horizons <- base::lapply
    lapply_across_val_windows <- base::lapply
  }
  #----------------------------------------------------------------------------
  # Seq along model forecast horizon > cross-validation windows. The 'future.seed' formal absorbs the
  # future.seed = 1 argument when base::lapply() forwards it to the worker function.
  data_out <- lapply_across_horizons(lagged_df, function(data, future.seed, ...) { # Model forecast horizon.
    model_plus_valid_data <- lapply_across_val_windows(seq_len(n_windows), function(i, future.seed, ...) { # Validation windows within horizon.
      window_length <- windows[i, "window_length"]
      if (is.null(date_indices)) {
        valid_indices <- windows[i, "start"]:windows[i, "stop"]
        valid_indices_date <- NULL
      } else {
        # When create_lagged_df(..., groups = NULL, keep_rows = FALSE), the validation indices need an offset to account for the fact that
        # validation windows are selected where row_indices %in% valid_indices which maps back to the input dataset to create_lagged_df()
        # which will have 1:max(lookback) more rows than the dataset that comes out of create_lagged_df(..., groups = NULL, keep_rows = FALSE).
        in_window <- date_indices >= windows[i, "start"] & date_indices <= windows[i, "stop"]
        valid_indices <- min(row_indices) - 1 + which(in_window)
        valid_indices_date <- date_indices[in_window]
      }
      # A window length of 0 that spans the entire dataset removes the nested cross-validation and
      # trains on all input data in lagged_df.
      if (window_length == 0 && (windows[i, "start"] == data_start) && (windows[i, "stop"] == data_stop)) {
        train_data <- data
      } else { # Model training with external block-contiguous cv.
        # These validation indices will always start at 1. They index with respect to the
        # data passed into model_function() as opposed to the dataset passed in create_lagged_df().
        # These indices, stored as an attribute, are for manual filtering any skeleton lagged_dfs in model_function().
        validation_indices <- which(row_indices %in% valid_indices)
        attributes(data)$validation_indices <- validation_indices
        # x[-integer(0), ] selects zero rows in R, so only drop rows when the window overlaps the data.
        if (length(validation_indices) > 0) {
          train_data <- data[-(validation_indices), , drop = FALSE]
        } else {
          train_data <- data
        }
      }
      if (n_args == 0) { # No user-defined model args passed in ...
        model <- try(model_function(train_data))
      } else {
        model <- try(do.call(model_function, append(list(train_data), model_function_args)))
      }
      if (inherits(model, "try-error")) {
        warning(paste0("A model returned class 'try-error' for validation window ", i))
      }
      list("model" = model, "window" = i, "window_length" = window_length, "valid_indices" = valid_indices,
           "date_indices" = valid_indices_date)
    }, future.seed = 1) # End model training across nested cross-validation windows for the horizon in "data".
    names(model_plus_valid_data) <- paste0("window_", seq_len(n_windows))
    attr(model_plus_valid_data, "horizon") <- attributes(data)$horizon
    model_plus_valid_data
  }, future.seed = 1) # End training across horizons.
  attr(data_out, "model_name") <- model_name
  attr(data_out, "horizons") <- horizons
  attr(data_out, "outcome_col") <- attributes(lagged_df)$outcome_col
  attr(data_out, "outcome_cols") <- attributes(lagged_df)$outcome_cols
  attr(data_out, "outcome_name") <- attributes(lagged_df)$outcome_name
  attr(data_out, "outcome_names") <- attributes(lagged_df)$outcome_names
  attr(data_out, "outcome_levels") <- attributes(lagged_df)$outcome_levels
  attr(data_out, "row_indices") <- row_indices
  attr(data_out, "date_indices") <- date_indices
  attr(data_out, "frequency") <- attributes(lagged_df)$frequency
  attr(data_out, "data_stop") <- data_stop
  attr(data_out, "groups") <- attributes(lagged_df)$groups
  attr(data_out, "method") <- attributes(lagged_df)$method
  attr(data_out, "skeleton") <- skeleton
  class(data_out) <- c("forecast_model", class(data_out))
  return(data_out)
}
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
#' Predict on validation datasets or forecast
#'
#' Predict with a 'forecast_model' object from \code{train_model()}. If \code{data = create_lagged_df(..., type = "train")},
#' predictions are returned for the outer-loop nested cross-validation datasets.
#' If \code{data} is an object of class 'lagged_df' from \code{create_lagged_df(..., type = "forecast")},
#' predictions are returned for the horizons specified in \code{create_lagged_df(horizons = ...)}.
#'
#' @param ... One or more trained models from \code{train_model()}.
#' @param prediction_function A list of user-defined prediction functions with length equal to
#' the number of models supplied in \code{...}; a single function not wrapped in a list is also accepted
#' when exactly one model is supplied. The prediction functions
#' take 2 required positional arguments--(1) a trained model, i.e., the object returned by the user-defined
#' \code{model_function} in \code{train_model()}, and (2) a
#' data.frame of model features from \code{create_lagged_df()}. For numeric outcomes and \code{method = "direct"}, the function should \code{return()}
#' a 1- or 3-column data.frame of model predictions. If the prediction function returns a 1-column data.frame, point forecasts are assumed.
#' If the prediction function returns a 3-column data.frame, lower and upper forecast bounds are assumed (the
#' order and names of the 3 columns do not matter; columns are identified by their means, from smallest (lower
#' bound) to largest (upper bound)). For factor outcomes and \code{method = "direct"}, the function should \code{return()}
#' (1) a 1-column data.frame of the model-predicted factor level or (2) an L-column data.frame of class probabilities where
#' 'L' equals the number of levels in the outcome; columns should be ordered, from left to right, the same as
#' \code{levels(data$outcome)} which is the default behavior for most \code{predict(..., type = "prob")} functions.
#' Column names do not matter. For numeric outcomes and \code{method = "multi_output"}, the function should \code{return()} an
#' h-column data.frame of model predictions--1 column for each horizon. Forecast intervals and factor outcomes are not currently
#' supported with \code{method = "multi_output"}.
#' @param data If \code{data} is a training dataset from \code{create_lagged_df(..., type = "train")}, validation dataset
#' predictions are returned; else, if \code{data} is a forecasting dataset from \code{create_lagged_df(..., type = "forecast")},
#' forecasts from horizons 1:h are returned.
#' @return If \code{data = create_lagged_df(..., type = "train")}, an S3 object of class 'training_results'. If
#' \code{data = create_lagged_df(..., type = "forecast")}, an S3 object of class 'forecast_results'.
#'
#' \strong{Columns in returned 'training_results' data.frame:}
#' \itemize{
#' \item \code{model}: User-supplied model name in \code{train_model()}.
#' \item \code{model_forecast_horizon}: The direct-forecasting time horizon that the model was trained on.
#' \item \code{window_length}: Validation window length measured in dataset rows.
#' \item \code{window_number}: Validation dataset number.
#' \item \code{valid_indices}: Validation dataset row names from \code{attributes(create_lagged_df())$row_indices}.
#' \item \code{date_indices}: If given and \code{method = "direct"}, validation dataset date indices from \code{attributes(create_lagged_df())$date_indices}.
#' If given and \code{method = "multi_output"}, date_indices represents the date of the forecast.
#' \item \code{"groups"}: If given, the user-supplied groups in \code{create_lagged_df()}.
#' \item \code{"outcome_name"}: The target being forecasted.
#' \item \code{"outcome_name"_pred}: The model predictions.
#' \item \code{"outcome_name"_pred_lower}: If given, the lower prediction bounds returned by the user-supplied prediction function.
#' \item \code{"outcome_name"_pred_upper}: If given, the upper prediction bounds returned by the user-supplied prediction function.
#' \item \code{forecast_indices}: If \code{method = "multi_output"}, the validation index of the h-step-ahead forecast.
#' \item \code{forecast_date_indices}: If \code{method = "multi_output"}, the validation date index of the h-step-ahead forecast.
#' }
#'
#' \strong{Columns in returned 'forecast_results' data.frame:}
#' \itemize{
#' \item \code{model}: User-supplied model name in \code{train_model()}.
#' \item \code{model_forecast_horizon}: If \code{method = "direct"}, the direct-forecasting time horizon that the model was trained on.
#' \item \code{horizon}: Forecast horizons, 1:h, measured in dataset rows.
#' \item \code{window_length}: Validation window length measured in dataset rows.
#' \item \code{forecast_period}: The forecast period in row indices or dates. The forecast period starts at either \code{attributes(create_lagged_df())$data_stop + 1} for row indices or \code{attributes(create_lagged_df())$data_stop + 1 * frequency} for date indices.
#' \item \code{"groups"}: If given, the user-supplied groups in \code{create_lagged_df()}.
#' \item \code{"outcome_name"}: The target being forecasted.
#' \item \code{"outcome_name"_pred}: The model forecasts.
#' \item \code{"outcome_name"_pred_lower}: If given, the lower forecast bounds returned by the user-supplied prediction function.
#' \item \code{"outcome_name"_pred_upper}: If given, the upper forecast bounds returned by the user-supplied prediction function.
#' }
#'
#' @example /R/examples/example_predict_train_model.R
#' @export
predict.forecast_model <- function(..., prediction_function = list(NULL), data) {
  #----------------------------------------------------------------------------
  model_list <- list(...)
  if (length(model_list) == 0 || !all(unlist(lapply(model_list, function(x) {class(x)[1]})) %in% "forecast_model")) {
    stop("The '...' argument takes 1 or more objects of class 'forecast_model' as input. Run train_model() first.
Also, the arguments 'prediction_function' and 'data' need to be named and not positional because they
follow '...'.")
  }
  # A single prediction function may be passed without wrapping it in a list.
  if (is.function(prediction_function)) {
    prediction_function <- list(prediction_function)
  }
  if (length(model_list) != length(prediction_function)) {
    stop("The number of prediction functions does not equal the number of forecast models.")
  }
  if (missing(data) || !methods::is(data, "lagged_df")) {
    stop("The 'data' argument takes a training or forecasting dataset of class 'lagged_df' from create_lagged_df().")
  }
  #----------------------------------------------------------------------------
  # Dataset metadata is taken from the first model; all models are assumed to share it.
  type <- attributes(data)$type # train or forecast.
  method <- attributes(model_list[[1]])$method
  horizons <- attributes(model_list[[1]])$horizons
  outcome_col <- attributes(model_list[[1]])$outcome_col
  outcome_cols <- attributes(model_list[[1]])$outcome_cols
  outcome_name <- attributes(model_list[[1]])$outcome_name
  outcome_names <- attributes(model_list[[1]])$outcome_names
  outcome_levels <- attributes(model_list[[1]])$outcome_levels
  row_indices <- attributes(model_list[[1]])$row_indices
  date_indices <- attributes(model_list[[1]])$date_indices
  frequency <- attributes(model_list[[1]])$frequency
  groups <- attributes(model_list[[1]])$groups
  #----------------------------------------------------------------------------
  # Validation results keep the training data_stop; forecasts use the forecast dataset's data_stop.
  if (type == "train") {
    data_stop <- attributes(model_list[[1]])$data_stop
  } else {
    data_stop <- attributes(data)$data_stop
  }
  #----------------------------------------------------------------------------
  # Seq along forecast model[i] > model forecast horizon[j] > validation window number[k].
  data_model <- lapply(seq_along(model_list), function(i) {
    prediction_fun <- prediction_function[[i]]
    data_horizon <- lapply(seq_along(model_list[[i]]), function(j) {
      data_win_num <- lapply(seq_along(model_list[[i]][[j]]), function(k) {
        data_results <- model_list[[i]][[j]][[k]]
        # Predict on training data or the forecast dataset?
        if (type == "train") { # Nested cross-validation.
          # These validation indices will always start at 1. They index with respect to the
          # data passed into prediction_function() as opposed to the dataset passed in create_lagged_df().
          # These indices, stored as an attribute, are for manual filtering any skeleton lagged_dfs in prediction_function().
          validation_indices <- which(row_indices %in% data_results$valid_indices)
          x_valid <- data[[j]][validation_indices, -(outcome_cols), drop = FALSE]
          y_valid <- data[[j]][validation_indices, outcome_cols, drop = FALSE] # Actuals in function return.
          attributes(x_valid)$validation_indices <- validation_indices
          attributes(x_valid)$horizons <- attributes(data[[j]])$horizons
          data_pred <- try(prediction_fun(data_results$model, x_valid)) # Nested cross-validation.
          if (methods::is(data_pred, "try-error")) {
            warning(paste0("Model '", attributes(model_list[[i]])$model_name, "' returned class 'try-error' for model ", j, " in validation window ", k))
          }
          if (!is.null(groups)) {
            data_groups <- x_valid[, groups, drop = FALSE] # Save out group identifiers.
          }
        } else { # Forecast.
          forecast_period <- data[[j]][, "index", drop = FALSE]
          names(forecast_period) <- "forecast_period"
          forecast_horizons <- data[[j]][, "horizon", drop = FALSE]
          data_for_forecast <- data[[j]][, !names(data[[j]]) %in% c("index", "horizon"), drop = FALSE] # Remove ID columns for predict().
          data_pred <- try(prediction_fun(data_results$model, data_for_forecast)) # User-defined prediction function.
          if (methods::is(data_pred, "try-error")) {
            warning(paste0("Model '", attributes(model_list[[i]])$model_name, "' returned class 'try-error' for model ", j, " in validation window ", k))
          }
          if (!is.null(groups)) {
            data_groups <- data_for_forecast[, groups, drop = FALSE]
          }
        } # End forecast.
        #----------------------------------------------------------------------
        # Check the return of the user-defined predict() function.
        if (is.null(outcome_levels) && method == "direct" && !ncol(data_pred) %in% c(1, 3)) { # Numeric outcome.
          stop("For numeric outcomes, the user-defined prediction function needs to return 1- or 3-column data.frame of model predictions.")
        }
        if (is.null(outcome_levels) && method == "multi_output" && ncol(data_pred) != length(horizons)) { # Numeric outcome.
          stop("For numeric outcomes, the user-defined prediction function needs to return a length(horizons)-column data.frame of model predictions.")
        }
        if (!is.null(outcome_levels)) { # Factor outcome.
          if (ncol(data_pred) == 1 && method == "direct" && !methods::is(data_pred[, 1], "factor")) {
            stop("For factor outcomes where the predicted factor level is returned (e.g., predict(..., type = 'response')), the user-defined
prediction function needs to return 1-column data.frame of factor predictions. If returning class
probabilities, the number of data.frame columns should equal the number of factor levels and be in the same order as levels(data$outcome).")
          }
          if (method == "direct" && ncol(data_pred) > 1 && ncol(data_pred) != length(outcome_levels)) {
            stop("For factor outcomes where class probabilities are returned (e.g., predict(..., type = 'prob')),
the number of data.frame columns returned should equal the number of factor levels.")
          }
        }
        #----------------------------------------------------------------------
        # Format the returned data.frames from predict().
        if (method == "direct") {
          if (is.null(outcome_levels)) { # Numeric outcome.
            if (ncol(data_pred) == 1) {
              names(data_pred) <- paste0(outcome_names, "_pred")
            } else { # Confidence/credible forecast intervals.
              # Find the lower, point, and upper forecasts by column mean and order the columns accordingly.
              data_pred <- data_pred[order(unlist(lapply(data_pred, mean, na.rm = TRUE)))]
              names(data_pred) <- c(paste0(outcome_names, "_pred_lower"), paste0(outcome_names, "_pred"), paste0(outcome_names, "_pred_upper"))
              # Re-order so that the point forecast is first.
              data_pred <- data_pred[, c(2, 1, 3)]
            }
          } else { # Factor outcome.
            if (ncol(data_pred) == 1) { # A predicted factor level.
              names(data_pred) <- paste0(outcome_names, "_pred")
            } else { # Predicted class probabilities.
              names(data_pred) <- outcome_levels
            }
          } # End reformatting of the returned data.frame of predictions.
        } else if (method == "multi_output") {
          if (is.null(outcome_levels)) { # Numeric outcome.
            names(data_pred) <- paste0(outcome_names, "_pred")
          } else {
            stop("Multi-output models do not currently support factor outcomes.")
          }
        }
        #----------------------------------------------------------------------
        model_name <- attributes(model_list[[i]])$model_name
        if (type == "train") { # Nested cross-validation.
          data_temp <- data.frame("model" = model_name,
                                  "model_forecast_horizon" = horizons[j],
                                  "window_length" = data_results$window_length,
                                  "window_number" = data_results$window,
                                  "valid_indices" = data_results$valid_indices)
          data_temp$date_indices <- data_results$date_indices
          if (is.null(groups)) {
            data_temp <- cbind(data_temp, y_valid, data_pred)
          } else {
            data_temp <- cbind(data_temp, data_groups, y_valid, data_pred)
          }
        } else { # Forecast.
          data_temp <- data.frame("model" = model_name,
                                  "model_forecast_horizon" = horizons[j],
                                  "horizon" = forecast_horizons,
                                  "window_length" = data_results$window_length,
                                  "window_number" = data_results$window,
                                  "forecast_period" = forecast_period)
          if (is.null(groups)) {
            data_temp <- cbind(data_temp, data_pred)
          } else {
            data_temp <- cbind(data_temp, data_groups, data_pred)
          }
        } # End forecast results.
        data_temp$model <- as.character(data_temp$model) # Coerce to remove any factor levels.
        data_temp
      }) # End cross-validation window predictions.
      data_win_num <- suppressMessages(dplyr::bind_rows(data_win_num))
    }) # End horizon-level predictions.
    data_horizon <- suppressMessages(dplyr::bind_rows(data_horizon))
  }) # End model-level predictions.
  data_out <- suppressMessages(dplyr::bind_rows(data_model))
  # For multi-output models, the data (a) need to be reshaped from a wide to long format
  # and (b) the validation indices that represent the forecast origin need to be changed
  # to the index for the predicted time period.
  if (method == "multi_output") {
    if (type == "train") {
      data_actual_temp <- dplyr::select(data_out, -dplyr::ends_with("_pred"))
      data_pred_temp <- dplyr::select(data_out, -!!outcome_names)
      data_actual_temp <- tidyr::pivot_longer(data_actual_temp, cols = !!outcome_names, names_to = "remove", values_to = outcome_name)
      data_actual_temp$remove <- NULL
      data_pred_temp <- tidyr::pivot_longer(data_pred_temp, cols = dplyr::ends_with("_pred"), values_to = paste0(outcome_name, "_pred"))
      data_actual_temp[, paste0(outcome_name, "_pred")] <- data_pred_temp[, paste0(outcome_name, "_pred")]
      data_out <- data_actual_temp
      data_out$model_forecast_horizon <- rep(horizons, nrow(data_out) / length(horizons))
      data_out$forecast_indices <- data_out$valid_indices
      data_out$forecast_indices <- data_out$model_forecast_horizon + data_out$forecast_indices
      data_out$horizons <- NULL
      names(data_out)[names(data_out) == "model_forecast_horizon"] <- "horizon"
      if (!is.null(date_indices)) {
        # The forecast date is 'horizon' steps of 'frequency' after the forecast origin date.
        data_out$forecast_date_indices <- as.Date(unlist(purrr::map(seq_len(nrow(data_out)), function(i) {
          base::seq(data_out$date_indices[i], by = frequency, length.out = data_out$horizon[i] + 1)[data_out$horizon[i] + 1]})),
          origin = "1970-01-01")
      }
      data_out <- dplyr::arrange(data_out, .data$model, .data$valid_indices, .data$forecast_indices, .data$horizon)
    } else if (type == "forecast") {
      forecast_period <- data.frame(data_out[, "forecast_period", drop = FALSE], stringsAsFactors = FALSE)
      forecast_period <- data.frame(strsplit(forecast_period$forecast_period, ", "), stringsAsFactors = FALSE)
      names(forecast_period) <- "forecast_period"
      forecast_period$forecast_period <- if (is.null(date_indices)) {as.integer(forecast_period$forecast_period)} else {as.Date(forecast_period$forecast_period)}
      data_out <- tidyr::pivot_longer(data_out, cols = dplyr::ends_with("_pred"),
                                      names_to = "remove", values_to = paste0(outcome_name, "_pred"))
      data_out$remove <- NULL
      data_out$model_forecast_horizon <- NULL
      data_out$horizon <- horizons
      data_out$forecast_period <- forecast_period$forecast_period
    }
  }
  data_out <- as.data.frame(data_out)
  # seq_len() keeps a zero-row result valid; 1:0 would be a length-2 vector and fail in row.names<-.
  row.names(data_out) <- seq_len(nrow(data_out))
  attr(data_out, "method") <- method
  attr(data_out, "horizons") <- horizons
  attr(data_out, "outcome_col") <- outcome_col
  attr(data_out, "outcome_cols") <- outcome_cols
  attr(data_out, "outcome_name") <- outcome_name
  attr(data_out, "outcome_names") <- outcome_names
  attr(data_out, "outcome_levels") <- outcome_levels
  attr(data_out, "row_indices") <- row_indices
  attr(data_out, "date_indices") <- date_indices
  attr(data_out, "frequency") <- frequency
  attr(data_out, "data_stop") <- data_stop
  attr(data_out, "groups") <- groups
  if (type == "train") {
    class(data_out) <- c("training_results", "forecast_model", class(data_out))
  } else {
    class(data_out) <- c("forecast_results", "forecast_model", class(data_out))
  }
  return(data_out)
}
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
#' Plot an object of class training_results
#'
#' Several diagnostic plots can be returned to assess the quality of the forecasts
#' based on predictions on the validation datasets.
#'
#' @param x An object of class 'training_results' from \code{predict.forecast_model()}.
#' @param type Plot type. One of \code{"prediction"} (the default; validation dataset predictions vs. actuals),
#' \code{"residual"} (validation dataset residuals through time), or \code{"forecast_stability"}
#' (rolling-origin forecasts faceted by dataset index; numeric, non-grouped outcomes only, and exactly one of
#' \code{windows} or \code{valid_indices} must be given). Only the first element is used.
#' @param models Optional. Filter results by user-defined model name from \code{train_model()}.
#' @param horizons Optional. A numeric vector of model forecast horizons to filter results by horizon-specific model.
#' @param windows Optional. A numeric vector of window numbers to filter results.
#' @param valid_indices Optional. A numeric or date vector to filter results by validation row indices or dates.
#' @param group_filter Optional. A string for filtering plot results for grouped time series
#' (e.g., \code{"group_col_1 == 'A'"}). The results are passed to \code{dplyr::filter()} internally.
#' The string is parsed and evaluated as R code, so it should not come from an untrusted source.
#' @param facet Optional. For numeric outcomes, a formula with any combination of \code{horizon}, \code{model}, or \code{group} (for grouped time series)
#' passed to \code{ggplot2::facet_grid()} internally (e.g., \code{horizon ~ model}, \code{horizon + model ~ .}, \code{~ horizon + group}).
#' @param keep_missing Boolean. If \code{TRUE}, predictions are plotted for indices/dates where the outcome is missing.
#' @param ... Not used.
#' @return Diagnostic plots of class 'ggplot'.
#' @export
plot.training_results <- function(x,
                                  type = c("prediction", "residual", "forecast_stability"),
                                  facet = horizon ~ model, models = NULL, horizons = NULL,
                                  windows = NULL, valid_indices = NULL, group_filter = NULL, keep_missing = FALSE, ...) { # nocov start
  #----------------------------------------------------------------------------
  data <- x
  rm(x)
  # The default plot is predicting over historical validation windows. Only the first
  # element is used; an unrecognized plot type is an error instead of a silent NULL.
  type <- match.arg(type[1], c("prediction", "residual", "forecast_stability"))
  if (type == "forecast_stability") {
    if (!xor(is.null(windows), is.null(valid_indices))) {
      stop("Select either (a) one or more validation windows, 'windows', or (b) a range of dataset rows, 'valid_indices', to reduce plot size.")
    }
  }
  if (!is.null(attributes(data)$group) && !type %in% c("prediction", "residual")) {
    stop("Only 'prediction' and 'residual' plots are currently available for grouped models.")
  }
  if (!is.null(attributes(data)$outcome_levels) && !type %in% c("prediction", "residual")) {
    stop("Only 'prediction' and 'residual' plots are currently available for models with factors as outcomes.")
  }
  #----------------------------------------------------------------------------
  method <- attributes(data)$method
  outcome_name <- attributes(data)$outcome_name
  outcome_levels <- attributes(data)$outcome_levels
  date_indices <- attributes(data)$date_indices
  groups <- attributes(data)$group
  data <- as.data.frame(data) # Coerce to remove new vctrs warning about mismatched classes when joining data.
  #----------------------------------------------------------------------------
  # Row-wise paste of the columns 'cols' of 'df' with "-" separators. Unlike apply(), which
  # coerces a mixed-type data.frame to a character matrix via format() and so pads numbers to
  # a common width (e.g., "model- 1"), paste() converts each column element-wise. With no
  # columns, one "" per row is returned, matching the previous apply() behavior.
  paste_columns <- function(df, cols) {
    if (length(cols) == 0) {
      return(rep("", nrow(df)))
    }
    do.call(paste, c(unname(as.list(df[, cols, drop = FALSE])), sep = "-"))
  }
  #----------------------------------------------------------------------------
  if (method == "multi_output") {
    data$valid_indices <- data$forecast_indices # The plot indices are the forecast period, not the forecast origin.
    if (!is.null(date_indices)) {
      data$date_indices <- data$forecast_date_indices # The plot indices are the forecast period, not the forecast origin.
    }
  }
  #----------------------------------------------------------------------------
  facets <- forecastML_facet_plot(facet, groups) # Function in zzz.R.
  facet <- facets[[1]]
  facet_names <- facets[[2]]
  #----------------------------------------------------------------------------
  # For factor outcomes, is the prediction a factor level or probability.
  if (!is.null(outcome_levels)) {
    factor_level <- any(names(data) %in% paste0(outcome_name, "_pred"))
    factor_prob <- !factor_level
    if (type == "residual" && factor_prob) {
      stop("Residual plots with predicted class probabilities are not currently supported.")
    }
    if (type == "residual" && !is.null(groups)) {
      stop("Residual plots with grouped data are not currently supported.")
    }
  }
  #----------------------------------------------------------------------------
  if (isFALSE(keep_missing) && is.null(outcome_levels)) {
    # Set the predicted values to NA to keep any missing data in the actual
    # time series so gaps in data collection are not connected with a line in the plot.
    data[is.na(data[, outcome_name]), paste0(outcome_name, "_pred")] <- NA
    if (all(any(grepl("_pred_lower", names(data))), any(grepl("_pred_upper", names(data))))) {
      data[is.na(data[, outcome_name]), paste0(outcome_name, "_pred_lower")] <- NA
      data[is.na(data[, outcome_name]), paste0(outcome_name, "_pred_upper")] <- NA
    }
  } else if (isFALSE(keep_missing) && !is.null(outcome_levels)) {
    data <- data[!is.na(data[, outcome_name]), ]
  }
  #----------------------------------------------------------------------------
  # Residual calculations.
  if (is.null(outcome_levels)) { # Numeric outcome.
    data$residual <- data[, outcome_name] - data[, paste0(outcome_name, "_pred")]
  } else { # Factor outcome.
    if (factor_level) {
      # Binary accuracy/residual. A residual of 1 is an incorrect classification.
      data$residual <- ifelse(as.character(data[, outcome_name, drop = TRUE]) != as.character(data[, paste0(outcome_name, "_pred"), drop = TRUE]), 1, 0)
    } else { # Class probabilities were predicted.
      # drop = FALSE keeps a data.frame (and its names) even if only one probability column matches.
      data_residual <- 1 - data[, names(data) %in% outcome_levels, drop = FALSE]
      names(data_residual) <- paste0(names(data_residual), "_residual")
      data <- dplyr::bind_cols(data, data_residual)
      rm(data_residual)
    }
  }
  #----------------------------------------------------------------------------
  # Filtering results based on user input.
  models <- if (is.null(models)) {unique(data$model)} else {models}
  if (method == "direct") {
    horizons <- if (is.null(horizons)) {unique(data$model_forecast_horizon)} else {horizons}
  } else {
    horizons <- if (is.null(horizons)) {unique(data$horizon)} else {horizons}
  }
  windows <- if (is.null(windows)) {unique(data$window_number)} else {windows}
  # A user-supplied 'valid_indices' (row indices or dates) is always honored; when it is
  # missing, default to every date (date-indexed data) or every row index.
  if (is.null(valid_indices)) {
    if (is.null(date_indices)) {
      valid_indices <- unique(data$valid_indices)
    } else {
      valid_indices <- unique(data$date_indices)
    }
  }
  #----------------------------------------------------------------------------
  data_plot <- data
  #----------------------------------------------------------------------------
  # Rename for consistency with plotting code.
  if (method == "direct") {
    data_plot$horizon <- data_plot$model_forecast_horizon
    data_plot$model_forecast_horizon <- NULL
  }
  data_plot <- data_plot[data_plot$model %in% models & data_plot$horizon %in% horizons &
                           data_plot$window_number %in% windows, ]
  if (!is.null(group_filter)) {
    # NOTE(review): 'group_filter' is user-supplied code parsed and evaluated as-is.
    data_plot <- dplyr::filter(data_plot, eval(parse(text = group_filter)))
  }
  #----------------------------------------------------------------------------
  if (methods::is(valid_indices, "Date") || methods::is(valid_indices, "POSIXt")) {
    data_plot <- data_plot[data_plot$date_indices %in% valid_indices, ] # Filter plots by dates.
    data_plot$index <- data_plot$date_indices
  } else {
    data_plot <- data_plot[data_plot$valid_indices %in% valid_indices, ] # Filter plots by row indices.
    if (!is.null(date_indices)) {
      data_plot$index <- data_plot$date_indices
    } else {
      data_plot$index <- data_plot$valid_indices
    }
  }
  #----------------------------------------------------------------------------
  # Set up ggplot color and group parameters.
  if (is.null(outcome_levels)) { # Numeric outcomes.
    if (is.null(groups)) {
      data_plot <- dplyr::arrange(data_plot, .data$model, .data$window_number)
    } else {
      # Splice one column symbol per grouping column so every group column is sorted on.
      data_plot <- dplyr::arrange(data_plot, .data$model, .data$window_number, !!!lapply(groups, as.name))
    }
    #--------------------------------------------------------------------------
    # ggplot colors and facets are complementary: all facets, same color; all colors, no facet.
    ggplot_color <- c(c("model", "horizon", groups)[!c("model", "horizon", groups) %in% facet_names])
    temp_1 <- unlist(Map(function(x) {toupper(substr(x[1], 1, 1))}, ggplot_color))
    temp_2 <- unlist(Map(function(x) {substr(x, 2, nchar(x))}, ggplot_color))
    legend_title <- paste(temp_1, temp_2, sep = "")
    legend_title <- paste(legend_title, collapse = " + ")
    #--------------------------------------------------------------------------
    data_plot$ggplot_color <- paste_columns(data_plot, ggplot_color)
    # Give predictions a name in the legend if plot is faceted by model and horizon (and group if groups are given).
    if (length(ggplot_color) == 0) {
      data_plot$ggplot_color <- "Prediction"
    }
    # Used to avoid lines spanning any gaps between validation windows.
    if (all(data_plot$window_number == 1)) { # One window; no need to add the window number to the legend.
      data_plot$ggplot_group <- paste_columns(data_plot, ggplot_color)
    } else {
      data_plot$ggplot_group <- paste_columns(data_plot, c("window_number", ggplot_color))
    }
    # Coerce to viridis color scale with an ordered factor. With the data.frame sorted, unique() pulls the levels in their order of appearance.
    data_plot$ggplot_color <- factor(data_plot$ggplot_color, levels = unique(data_plot$ggplot_color), ordered = TRUE)
    #--------------------------------------------------------------------------
  } else { # Factor outcomes.
    data_plot$ggplot_color_group <- paste_columns(data_plot, c("model", "horizon", groups))
  }
  #----------------------------------------------------------------------------
  #----------------------------------------------------------------------------
  if (type %in% c("prediction", "residual")) {
    #--------------------------------------------------------------------------
    # Melt the data for plotting.
    if (is.null(outcome_levels)) { # Numeric outcome.
      data_plot <- tidyr::gather(data_plot, "outcome", "value",
                                 -!!names(data_plot)[!names(data_plot) %in% c(outcome_name, paste0(outcome_name, "_pred"))])
    } else { # Factor outcome.
      if (any(names(data_plot) %in% paste0(outcome_name, "_pred"))) { # A factor level was predicted.
        data_plot <- tidyr::gather(data_plot, "outcome", "value",
                                   -!!names(data_plot)[!names(data_plot) %in% c(outcome_name, paste0(outcome_name, "_pred"))])
      } else { # Class probabilities were predicted.
        data_plot <- suppressWarnings(tidyr::gather(data_plot, "outcome", "value",
                                                    -!!names(data_plot)[!names(data_plot) %in% c(outcome_name, outcome_levels)]))
      }
    }
    #--------------------------------------------------------------------------
    # If date indices exist, plot with them.
    if (!is.null(date_indices)) {
      data_plot$index <- data_plot$date_indices
    }
    #--------------------------------------------------------------------------
    if (type == "prediction") {
      if (is.null(outcome_levels)) { # Numeric outcome; plot historical predictions.
        p <- ggplot()
        #----------------------------------------------------------------------
        # Plot actuals on the first or lowest layer.
        if (is.null(groups)) { # Single time series.
          p <- p + geom_line(data = data_plot[data_plot$outcome == outcome_name, ],
                             aes(x = .data$index, y = .data$value, group = .data$ggplot_group), color = "grey50")
        } else { # Actuals, grouped time series.
          # If faceting by group, this reduces to the single time series case so the actuals
          # will be the default grey so as not to double encode the plot data.
          if (any(facet_names %in% groups)) {
            p <- p + geom_line(data = data_plot[data_plot$outcome == outcome_name, ],
                               aes(x = .data$index, y = .data$value), color = "grey50")
          } else {
            p <- p + geom_line(data = data_plot[data_plot$outcome == outcome_name, ],
                               aes(x = .data$index, y = .data$value,
                                   group = .data$ggplot_group,
                                   color = .data$ggplot_color), linetype = 2)
            p <- p + scale_color_viridis_d()
          }
        }
        #----------------------------------------------------------------------
        # Plot predictions.
        # If the plotting data.frame has both lower and upper forecasts plot these bounds.
        # We'll add the shading in a lower ggplot layer so the point forecasts are on top in the final plot.
        if (all(any(grepl("_pred_lower", names(data_plot))), any(grepl("_pred_upper", names(data_plot))))) {
          p <- p + geom_ribbon(data = data_plot[data_plot$outcome == outcome_name, ],
                               aes(x = .data$index, ymin = eval(parse(text = paste0(outcome_name, "_pred_lower"))),
                                   ymax = eval(parse(text = paste0(outcome_name, "_pred_upper"))),
                                   fill = .data$ggplot_color, group = .data$ggplot_group, color = NULL), alpha = .25, show.legend = FALSE)
        }
        p <- p + geom_line(data = data_plot[data_plot$outcome != outcome_name, ], # Predictions in melted data.
                           aes(x = .data$index, y = .data$value, group = .data$ggplot_group, color = .data$ggplot_color),
                           size = 1.05, linetype = 1)
        if (!is.null(facet)) {
          p <- p + facet_grid(facet, drop = TRUE)
        }
        p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
        #--------------------------------------------------------------------------
        #--------------------------------------------------------------------------
      } else { # Factor outcome.
        if (factor_prob) { # Plot class probabilities with un-grouped data.
          if (is.null(groups)) {
            # Only 1 actual needs to be plotted.
            data_plot$ggplot_color_group[data_plot$outcome == outcome_name] <- "Actual"
            data_actual <- data_plot[data_plot$outcome %in% outcome_name, ]
            data_pred <- data_plot[data_plot$outcome %in% outcome_levels, ]
            data_actual$value <- factor(data_actual$value, levels = outcome_levels, ordered = TRUE)
            data_pred$outcome <- factor(data_pred$outcome, levels = outcome_levels, ordered = TRUE)
            data_pred$value <- as.numeric(data_pred$value)
            p <- ggplot()
            p <- p + geom_col(data = data_pred,
                              aes(x = .data$index, y = .data$value, color = .data$outcome, fill = .data$outcome),
                              position = position_stack(reverse = TRUE))
            p <- p + geom_col(data = data_actual,
                              aes(x = .data$index, y = 1, color = .data$value, fill = .data$value))
            p <- p + scale_y_continuous(limits = 0:1)
            p <- p + scale_color_viridis_d(drop = FALSE)
            p <- p + scale_fill_viridis_d(drop = FALSE)
            p <- p + facet_wrap(~ ggplot_color_group, ncol = 1, scales = "free_y")
            p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
          } else { # Plot class probabilities with grouped data.
            data_plot$actual_or_pred <- NA
            data_plot$actual_or_pred[data_plot$outcome == outcome_name] <- "Actual"
            data_plot$actual_or_pred[data_plot$outcome != outcome_name] <- "Prediction"
            data_plot$y_value <- NA
            data_plot$y_value[data_plot$outcome == outcome_name] <- 1
            data_plot$y_value[data_plot$outcome != outcome_name] <- data_plot$value[data_plot$outcome != outcome_name]
            data_plot$y_value <- as.numeric(data_plot$y_value)
            data_plot$plot_color <- NA
            data_plot$plot_color[data_plot$outcome == outcome_name] <- data_plot$value[data_plot$outcome == outcome_name]
            data_plot$plot_color[data_plot$outcome != outcome_name] <- data_plot$outcome[data_plot$outcome != outcome_name]
            data_plot$plot_color <- factor(data_plot$plot_color, levels = outcome_levels, ordered = TRUE)
            p <- ggplot()
            p <- p + geom_col(data = data_plot,
                              aes(x = .data$index, y = .data$y_value, color = .data$plot_color, fill = .data$plot_color),
                              position = position_stack(reverse = TRUE))
            p <- p + scale_y_continuous(limits = 0:1, breaks = c(0, .5, 1))
            p <- p + scale_color_viridis_d(drop = FALSE)
            p <- p + scale_fill_viridis_d(drop = FALSE)
            p <- p + facet_grid(ggplot_color_group + actual_or_pred ~ ., scales = "free_y")
            p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
          } # End plot class probabilities with grouped data.
          #------------------------------------------------------------------------
        } else { # Plot class levels.
          if (is.null(groups)) { # Plot class levels with un-grouped data.
            # Only 1 actual needs to be plotted.
            data_plot$ggplot_color_group[data_plot$outcome == outcome_name] <- "Actual"
            data_plot$value <- factor(data_plot$value, levels = c(outcome_levels), ordered = TRUE)
            p <- ggplot()
            p <- p + geom_col(data = data_plot,
                              aes(x = .data$index, y = 1, color = .data$value, fill = .data$value),
                              position = position_stack(reverse = TRUE))
            p <- p + scale_color_viridis_d(drop = FALSE)
            p <- p + scale_fill_viridis_d(drop = FALSE)
            p <- p + facet_wrap(~ ggplot_color_group, ncol = 1, scales = "free_y")
            p <- p + theme_bw() + theme(axis.text.y = element_blank(), axis.ticks.y = element_blank())
          } else { # Plot class levels with grouped data.
            data_plot$actual_or_pred <- NA
            data_plot$actual_or_pred[data_plot$outcome == outcome_name] <- "Actual"
            data_plot$actual_or_pred[data_plot$outcome != outcome_name] <- "Prediction"
            data_plot$value <- factor(data_plot$value, levels = outcome_levels, ordered = TRUE)
            p <- ggplot()
            p <- p + geom_col(data = data_plot,
                              aes(x = .data$index, y = 1, color = .data$value, fill = .data$value),
                              position = position_stack(reverse = TRUE))
            p <- p + scale_y_continuous(limits = 0:1, breaks = c(0, .5, 1))
            p <- p + scale_color_viridis_d(drop = FALSE)
            p <- p + scale_fill_viridis_d(drop = FALSE)
            p <- p + facet_grid(ggplot_color_group + actual_or_pred ~ ., scales = "free_y")
            p <- p + theme_bw() + theme(axis.text.y = element_blank(), axis.ticks.y = element_blank(), panel.spacing = unit(0, "lines"))
          }
        }
      } # End prediction plots for numeric and factor outcomes.
      #--------------------------------------------------------------------------
    } else if (type == "residual") {
      if (is.null(outcome_levels)) { # Numeric outcome.
        # Plot historical predictions.
        p <- ggplot(data_plot[data_plot$outcome != outcome_name, ],
                    aes(x = .data$index, y = .data$residual,
                        group = .data$ggplot_group, color = .data$ggplot_color))
        p <- p + geom_line(size = 1.05, linetype = 1)
        p <- p + geom_hline(yintercept = 0)
        p <- p + scale_color_viridis_d()
        # Same NULL guard as the prediction plot: 'facet' may be NULL for an unfaceted plot.
        if (!is.null(facet)) {
          p <- p + facet_grid(facet, drop = TRUE)
        }
        p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
      } else { # Factor outcome.
        if (is.null(groups)) {
          data_plot$ggplot_color_group <- factor(as.character(data_plot$ggplot_color_group), ordered = TRUE,
                                                 levels = c(unique(as.character(data_plot$ggplot_color_group)), "Actual"))
          data_plot$ggplot_color_group[data_plot$outcome == outcome_name] <- "Actual"
        }
        p <- ggplot()
        # Plot predictions to avoid duplicate residuals in plots.
        p <- p + geom_tile(data = data_plot[data_plot$outcome != outcome_name, ], aes(x = .data$index, y = .data$ggplot_color_group,
                                                                                      fill = ordered(.data$residual)))
        p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
      }
    } # End residual plot.
    #--------------------------------------------------------------------------
    # Plot labels.
    if (type == "prediction") {
      if (is.null(outcome_levels)) { # Numeric outcome.
        if (!is.null(groups)) { # Grouped time series.
          if (any(facet_names %in% groups)) {
            p <- p + xlab("Dataset index") + ylab("Outcome") + labs(color = legend_title, group = NULL) +
              ggtitle("Forecasts vs. Actuals Through Time")
          } else {
            p <- p + xlab("Dataset index") + ylab("Outcome") + labs(color = legend_title, group = NULL) +
              ggtitle("Forecasts vs. Actuals Through Time", subtitle = c("Dashed lines are actuals"))
          }
        } else { # Single time series.
          p <- p + xlab("Dataset index") + ylab("Outcome") + labs(color = legend_title, group = NULL) +
            ggtitle("Forecasts vs. Actuals Through Time")
        }
      } else { # Factor outcome.
        if (factor_prob) {
          p <- p + xlab("Dataset index") + ylab("Outcome probability") + labs(color = "Outcome", fill = "Outcome") +
            ggtitle("Forecasts vs. Actuals Through Time")
        } else {
          p <- p + xlab("Dataset index") + ylab("Outcome") + labs(color = "Outcome", fill = "Outcome") +
            ggtitle("Forecasts vs. Actuals Through Time")
        }
      }
    } else if (type == "residual") {
      if (is.null(outcome_levels)) {
        p <- p + xlab("Dataset index") + ylab("Residual") + labs(color = NULL, group = NULL) +
          ggtitle("Forecast Error Through Time")
      } else {
        if (is.null(groups)) {
          p <- p + xlab("Dataset index") + ylab("Residual and model") + labs(fill = "Residual") +
            ggtitle("Forecast Error Through Time")
        } else {
          p <- p + xlab("Dataset index") + ylab("Residual, model, and group") + labs(fill = "Residual") +
            ggtitle("Forecast Error Through Time")
        }
      }
    }
    return(suppressWarnings(p))
  }
  #----------------------------------------------------------------------------
  if (type %in% c("forecast_stability")) {
    data_plot$forecast_origin <- with(data_plot, valid_indices - horizon)
    data_plot$group <- with(data_plot, paste0(valid_indices))
    data_plot$group <- ordered(data_plot$group)
    # Plotting the original time-series in each facet. Because the plot is faceted by valid_indices, we'll do a bit of a hack here to create
    # the same line plot for each facet.
    data_outcome <- data_plot %>%
      dplyr::select(valid_indices, !!outcome_name) %>%
      dplyr::distinct(valid_indices, .keep_all = TRUE)
    data_outcome$index <- data_outcome$valid_indices
    data_outcome$valid_indices <- NULL # remove to avoid confusion in facet_wrap()
    data_outcome <- data_outcome[rep(seq_len(nrow(data_outcome)), length(unique(data_plot$valid_indices))), ]
    p <- ggplot()
    if (max(data_plot$horizon) != 1) {
      p <- p + geom_line(data = data_plot, aes(x = .data$forecast_origin,
                                               y = eval(parse(text = paste0(outcome_name, "_pred"))),
                                               color = factor(.data$model)), size = 1, linetype = 1, show.legend = FALSE)
    }
    p <- p + geom_point(data = data_plot, aes(x = .data$forecast_origin,
                                              y = eval(parse(text = paste0(outcome_name, "_pred"))),
                                              color = factor(.data$model)))
    p <- p + geom_point(data = data_plot, aes(x = .data$valid_indices,
                                              y = eval(parse(text = outcome_name)), fill = "Actual"))
    p <- p + scale_color_viridis_d()
    p <- p + facet_wrap(~ valid_indices)
    p <- p + geom_line(data = data_outcome, aes(x = .data$index,
                                                y = eval(parse(text = outcome_name))), color = "gray50")
    p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
    p <- p + xlab("Dataset index") + ylab("Outcome") + labs(color = "Model") + labs(fill = NULL) +
      ggtitle("Rolling Origin Forecast Stability - Faceted by dataset index")
    return(suppressWarnings(p))
  }
} # nocov end
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
#' Plot an object of class forecast_results
#'
#' A forecast plot for each horizon for each model in \code{predict.forecast_model()}.
#'
#' @param x An object of class 'forecast_results' from \code{predict.forecast_model()}.
#' @param data_actual A data.frame containing the target/outcome name and any grouping columns.
#' The data can be historical actuals and/or holdout/test data.
#' @param actual_indices Required if \code{data_actual} is given. A vector or 1-column data.frame
#' of numeric row indices or dates (class 'Date' or 'POSIXt') with length \code{nrow(data_actual)}.
#' The data can be historical actuals and/or holdout/test data.
#' @param facet Optional. For numeric outcomes, a formula with any combination of \code{horizon}, \code{model}, or \code{group} (for grouped time series)
#' passed to \code{ggplot2::facet_grid()} internally (e.g., \code{horizon ~ model}, \code{horizon + model ~ .}, \code{~ horizon + group}).
#' Can be \code{NULL}.
#' @param models Optional. Filter results by user-defined model name from \code{train_model()}.
#' @param horizons Optional. Filter results by horizon.
#' @param windows Optional. Filter results by validation window number.
#' @param group_filter Optional. A string for filtering plot results for grouped time-series (e.g., \code{"group_col_1 == 'A'"});
#' passed to \code{dplyr::filter()} internally.
#' @param ... Not used.
#' @return Forecast plot of class 'ggplot'.
#' @export
plot.forecast_results <- function(x, data_actual = NULL, actual_indices = NULL, facet = horizon ~ model,
models = NULL, horizons = NULL, windows = NULL,
group_filter = NULL, ...) { # nocov start
#----------------------------------------------------------------------------
if(xor(is.null(data_actual), is.null(actual_indices))) {
stop("If plotting a hold-out or comparison dataset, both 'data_actual' and 'actual_indices' need to be specified.")
}
data_forecast <- x
rm(x)
type <- "forecast" # Only one plot option at present.
#----------------------------------------------------------------------------
method <- attributes(data_forecast)$method
forecast_horizons <- attributes(data_forecast)$horizons
outcome_col <- attributes(data_forecast)$outcome_col
outcome_name <- attributes(data_forecast)$outcome_name
outcome_levels <- attributes(data_forecast)$outcome_levels
date_indices <- attributes(data_forecast)$date_indices
groups <- attributes(data_forecast)$group
data_stop <- attributes(data_forecast)$data_stop
data_forecast <- as.data.frame(data_forecast) # Coerce to remove new vctrs warning about mismatched classes when joining data.
if (!is.null(outcome_levels) && !is.null(groups) && !is.null(data_actual)) {
stop("Plotting forecasts from grouped time series with an actuals dataset is not currently supported.")
}
#----------------------------------------------------------------------------
names(data_forecast)[names(data_forecast) == "forecast_period"] <- "index" # For code uniformity.
#----------------------------------------------------------------------------
# For factor outcomes, is the prediction a factor level or probability.
if (!is.null(outcome_levels)) {
factor_level <- if (any(names(data_forecast) %in% paste0(outcome_name, "_pred"))) {TRUE} else {FALSE}
factor_prob <- !factor_level
}
#----------------------------------------------------------------------------
facets <- forecastML_facet_plot(facet, groups) # Function in zzz.R.
facet <- facets[[1]]
facet_names <- facets[[2]]
#----------------------------------------------------------------------------
if (!is.null(data_actual)) {
data_actual <- data_actual[, c(outcome_name, groups), drop = FALSE]
data_actual$index <- actual_indices
if (!is.null(group_filter)) {
data_actual <- dplyr::filter(data_actual, eval(parse(text = group_filter)))
}
}
#----------------------------------------------------------------------------
# Filter plots using user input.
models <- if (is.null(models)) {unique(data_forecast$model)} else {models}
horizons <- if (is.null(horizons)) {forecast_horizons} else {horizons}
windows <- if (is.null(windows)) {unique(data_forecast$window_number)} else {windows}
if (method == "multi_output") {
data_forecast$model_forecast_horizon <- data_forecast$horizon # Used for filtering below.
}
data_forecast <- data_forecast[data_forecast$model %in% models &
data_forecast$model_forecast_horizon %in% horizons &
data_forecast$window_number %in% windows, ]
if (!is.null(group_filter)) {
data_forecast <- dplyr::filter(data_forecast, eval(parse(text = group_filter)))
}
#----------------------------------------------------------------------------
data_plot <- data_forecast
#----------------------------------------------------------------------------
if (type == "forecast") {
#----------------------------------------------------------------------------
# Set up ggplot color and group parameters.
if (is.null(outcome_levels)) { # Numeric outcomes.
if (is.null(groups)) {
data_plot <- dplyr::arrange(data_plot, .data$model, .data$model_forecast_horizon, .data$horizon, .data$window_number)
} else {
data_plot <- dplyr::arrange(data_plot, .data$model, .data$model_forecast_horizon, .data$window_number, .data$horizon, eval(parse(text = groups)))
}
#--------------------------------------------------------------------------
# ggplot colors and facets are complimentary: all facets, same color; all colors, no facet.
ggplot_color <- c(c("model", "horizon", groups)[!c("model", "horizon", groups) %in% facet_names])
if (method == "multi_output") { # With 1 model, all horizons belong to the same forecast and should have the same color.
ggplot_color <- ggplot_color[!ggplot_color %in% "horizon"]
}
#--------------------------------------------------------------------------
# There is a distinction between horizon and the model horizon; rename to work with the facet inputs.
data_plot$forecast_horizon <- data_plot$horizon
data_plot$horizon <- data_plot$model_forecast_horizon
data_plot$ggplot_color <- apply(data_plot[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
# Give predictions a name in the legend if plot is faceted by model and horizon (and group if groups are given).
if (length(ggplot_color) == 0) {
data_plot$ggplot_color <- "Forecast"
}
# Used to avoid lines spanning any gaps between validation windows.
if (all(data_plot$window_number == 1)) { # One window; no need to add the window number to the legend.
data_plot$ggplot_group <- apply(data_plot[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
} else {
data_plot$ggplot_group <- apply(data_plot[, c("window_number", ggplot_color), drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
}
# Coerce to viridis color scale with an ordered factor. With the data.frame sorted, unique() pulls the levels in their order of appearance.
data_plot$ggplot_color <- factor(data_plot$ggplot_color, levels = unique(data_plot$ggplot_color), ordered = TRUE)
data_plot$ggplot_group <- factor(data_plot$ggplot_group, levels = unique(data_plot$ggplot_group), ordered = TRUE)
}
#--------------------------------------------------------------------------
if (!is.null(data_actual)) {
#--------------------------------------------------------------------------
# If the plot is faceted by model, repeat the actuals dataset once for each model in a long format for faceting by model.
if (length(unique(data_plot$model)) == 1) {
data_actual$model <- unique(data_plot$model)
} else {
n_reps <- nrow(data_actual)
data_actual <- data_actual[rep(1:nrow(data_actual), length(unique(data_plot$model))), ]
data_actual$model <- rep(unique(data_plot$model), each = n_reps)
}
#--------------------------------------------------------------------------
# If the plot is colored by model, repeat the actuals dataset once for each model in a long format for faceting by model.
if (length(unique(data_plot$model)) == 1) {
data_actual$model <- unique(data_plot$model)
} else {
n_reps <- nrow(data_actual)
data_actual <- data_actual[rep(1:nrow(data_actual), length(unique(data_plot$model))), ]
data_actual$model <- rep(unique(data_plot$model), each = n_reps)
}
#--------------------------------------------------------------------------
if (length(ggplot_color) > 0) {
if (ggplot_color == "horizon") {
data_actual$ggplot_color <- NA
} else {
data_actual$ggplot_color <- apply(data_actual[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
}
} else {
data_actual$ggplot_color <- "Forecast"
}
# Give predictions a name in the legend if plot is faceted by model and horizon (and group if groups are given).
if (length(ggplot_color) == 0) {
data_actual$ggplot_color <- "Forecast"
}
if (length(ggplot_color) > 0) {
if (ggplot_color == "horizon") {
data_actual$ggplot_group <- "Forecast"
} else {
data_actual$ggplot_group <- apply(data_actual[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
}
} else {
data_actual$ggplot_group <- "Forecast"
}
#------------------------------------------------------------------------
# Coerce to viridis color scale with an ordered factor. The levels in the actual data are limited
# to those factor levels that appear in the forecast data.
data_actual$ggplot_color <- factor(data_actual$ggplot_color, levels = levels(data_plot$ggplot_color), ordered = TRUE)
data_actual$ggplot_group <- factor(data_actual$ggplot_group, levels = levels(data_plot$ggplot_color), ordered = TRUE)
}
#--------------------------------------------------------------------------
if (is.null(outcome_levels)) { # Numeric outcome.
p <- ggplot()
if (1 %in% horizons || nrow(data_plot) == 1 || (method == "multi_output" && any(c(facet_names %in% "horizon")))) { # Use geom_point instead of geom_line to plot a 1-step-ahead forecast.
if (method == "direct") {
data_plot_temp <- data_plot[data_plot$horizon == 1, ]
} else if (method == "multi_output") {
data_plot_temp <- data_plot
}
# If the plotting data.frame has both lower and upper forecasts plot these bounds.
# We'll add the shading in a lower ggplot layer so the point forecasts are on top in the final plot.
if (all(any(grepl("_pred_lower", names(data_plot))), any(grepl("_pred_upper", names(data_plot))))) {
# geom_ribbon() does not work with a single data point when forecast bounds are plotted.
p <- p + geom_linerange(data = data_plot_temp,
aes(x = .data$index, ymin = eval(parse(text = paste0(outcome_name, "_pred_lower"))),
ymax = eval(parse(text = paste0(outcome_name, "_pred_upper"))),
color = .data$ggplot_color, group = .data$ggplot_group), alpha = .25, size = 3, show.legend = FALSE)
}
p <- p + geom_point(data = data_plot_temp,
aes(x = .data$index, y = eval(parse(text = paste0(outcome_name, "_pred"))),
color = .data$ggplot_color, group = .data$ggplot_group))
}
if ((method == "direct" && !all(horizons == 1)) || (method == "multi_output" && nrow(data_plot) > 1 && !any(c(facet_names %in% "horizon")))) { # Plot forecasts for model forecast horizons > 1.
if (method == "direct") {
data_plot_temp <- data_plot[data_plot$horizon != 1, ]
} else if (method == "multi_output") {
data_plot_temp <- data_plot
}
# If the plotting data.frame has bother lower and upper forecasts plot these bounds.
if (all(any(grepl("_pred_lower", names(data_plot))), any(grepl("_pred_upper", names(data_plot))))) {
p <- p + geom_ribbon(data = data_plot_temp,
aes(x = .data$index, ymin = eval(parse(text = paste0(outcome_name, "_pred_lower"))),
ymax = eval(parse(text = paste0(outcome_name, "_pred_upper"))),
fill = .data$ggplot_color, group = .data$ggplot_group, color = NULL), alpha = .25, show.legend = FALSE)
}
p <- p + geom_line(data = data_plot_temp,
aes(x = .data$index, y = eval(parse(text = paste0(outcome_name, "_pred"))),
color = .data$ggplot_color, group = .data$ggplot_group))
}
# Add user-defined actuals data to the plots
if (!is.null(data_actual)) {
if (is.null(groups)) {
p <- p + geom_line(data = data_actual, aes(x = .data$index,
y = eval(parse(text = outcome_name))), color = "grey50")
} else {
# If faceting by group, this reduces to the single time series case so the actuals
# will be the default grey so as not to double encode the plot data.
if (any(facet_names %in% groups)) {
p <- p + geom_line(data = data_actual, aes(x = .data$index, y = eval(parse(text = outcome_name))), color = "grey50", show.legend = FALSE)
} else if (facet_names != c("model", "horizon") || facet_names != c("horizon", "model")) { # The actuals colors cannot be uniquely mapped to the forecast plot colors.
p <- p + geom_line(data = data_actual, aes(x = .data$index, y = eval(parse(text = outcome_name))), color = "grey50", show.legend = FALSE)
} else if (facet_names == c("model", "horizon") || facet_names == c("horizon", "model")) { # The actuals can be uniquely mapped to the forecasts given the faceting.
p <- p + geom_line(data = data_actual, aes(x = .data$index,
y = eval(parse(text = outcome_name)),
color = .data$ggplot_color,
group = .data$ggplot_group), show.legend = FALSE)
}
}
}
p <- p + scale_color_viridis_d()
p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
p <- p + facet_grid(facet, scales = "free_y")
#--------------------------------------------------------------------------
#--------------------------------------------------------------------------
} else { # Factor outcome.
data_plot <- data_forecast
if (factor_prob) {
# Melt the data for plotting the multiple class probabilities in stacked bars.
data_plot <- tidyr::gather(data_plot, "outcome", "value", -!!names(data_plot)[!names(data_plot) %in% c(outcome_levels)])
}
#------------------------------------------------------------------------
# If there is only 1 validation window, remove it from the grouping to reduce clutter.
if (length(unique(data_plot$window_number)) == 1) {
data_plot$ggplot_color_group <- apply(data_plot[, c("model", "model_forecast_horizon", groups), drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
} else {
data_plot$ggplot_color_group <- apply(data_plot[, c("model", "model_forecast_horizon", "window_number", groups), drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
}
#------------------------------------------------------------------------
if (!is.null(data_actual)) {
# actual or forecast: these are all actuals.
data_actual$actual_or_forecast <- "actual"
# historical, test, or model_forecast: these may be any combination of historical data and a holdout test dataset.
data_actual$time_series_type <- with(data_actual, ifelse(index <= data_stop, "historical", "test"))
names(data_actual)[names(data_actual) == outcome_name] <- "outcome" # Standardize before concat with forecasts.
data_actual$ggplot_color_group <- "Actual" # Actuals will be plotted in the top plot facet.
data_actual$value <- 1 # Plot a solid bar with probability 1 in geom_col().
# In cases where historical data is provided in data_actual, duplicate the historical data
# such that it appears as a sequence in each plot facet. Here, 'ggplot_color_group' gives the,
# possibly user-filtered, plot facets.
if ("historical" %in% unique(data_actual$time_series_type)) {
data_hist <- data_actual[data_actual$time_series_type == "historical", c("index", "outcome", "value", "ggplot_color_group")]
n_rows <- nrow(data_hist)
data_hist <- data_hist[rep(1:nrow(data_hist), length(unique(data_plot$ggplot_color_group))), ]
data_hist$ggplot_color_group <- rep(unique(data_plot$ggplot_color_group), each = n_rows)
data_actual <- suppressWarnings(dplyr::bind_rows(data_hist, data_actual))
}
}
#------------------------------------------------------------------------
# Standardize names for plotting and before any concatenation with data_actual.
names(data_plot)[names(data_plot) == "index"] <- "index"
data_plot$actual_or_forecast <- "forecast"
data_plot$time_series_type <- "model_forecast"
#------------------------------------------------------------------------
if (factor_level) {
# Standardize names for plotting and before any concatenation with data_actual.
names(data_plot)[names(data_plot) == paste0(outcome_name, "_pred")] <- "outcome"
data_plot$value <- 1
}
#------------------------------------------------------------------------
if (!is.null(data_actual)) {
data_plot <- suppressWarnings(dplyr::bind_rows(data_plot, data_actual))
}
data_plot$ggplot_color_group <- factor(data_plot$ggplot_color_group, levels = rev(unique(data_plot$ggplot_color_group)), ordered = TRUE)
data_plot$value <- as.numeric(data_plot$value)
data_plot$outcome <- factor(data_plot$outcome, levels = outcome_levels, ordered = TRUE)
#------------------------------------------------------------------------
if (!is.null(groups)) {
data_plot <- dplyr::distinct(data_plot, .data$ggplot_color_group, .data$index, .data$outcome, .keep_all = TRUE)
}
p <- ggplot()
p <- p + geom_col(data = data_plot,
aes(x = .data$index, y = .data$value, color = .data$outcome, fill = .data$outcome),
position = position_stack(reverse = TRUE))
p <- p + scale_color_viridis_d(drop = FALSE)
p <- p + scale_fill_viridis_d(drop = FALSE)
if (is.null(groups)) {
p <- p + facet_wrap(~ ggplot_color_group, scales = "free_y")
} else {
p <- p + facet_grid(ggplot_color_group ~ ., scales = "free_y")
}
p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
if (factor_level) {
p <- p + theme(axis.text.y = element_blank(), axis.ticks.y = element_blank(), panel.spacing = unit(0, "lines"))
}
} # End numeric and factor outcome plot setup.
#--------------------------------------------------------------------------
# Add a vertical line to mark the beginning of the forecast period.
p <- p + geom_vline(xintercept = data_stop, color = "red")
#--------------------------------------------------------------------------
if (is.null(outcome_levels)) { # Numeric outcome.
temp_1 <- unlist(Map(function(x) {toupper(substr(x[1], 1, 1))}, ggplot_color))
temp_2 <- unlist(Map(function(x) {substr(x, 2, nchar(x))}, ggplot_color))
x_axis_title <- paste(temp_1, temp_2, sep = "")
x_axis_title <- paste(x_axis_title, collapse = " + ")
p <- p + xlab("Dataset index") + ylab("Outcome") +
labs(color = x_axis_title, fill = NULL) +
ggtitle("H-Step-Ahead Model Forecasts")
} else { # Factor outcome.
p <- p + xlab("Dataset index") + ylab("Outcome") +
labs(color = "Outcome", fill = "Outcome") +
ggtitle("H-Step-Ahead Model Forecasts")
}
return(suppressWarnings(p))
}
} # nocov end
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.