#' Compute forecast error
#'
#' Compute forecast error metrics on the validation datasets or a new test dataset.
#'
#' @param data_results An object of class 'training_results' or 'forecast_results' from running (a)
#' \code{\link[=predict.forecast_model]{predict}} on a trained model or (b) \code{combine_forecasts()}.
#' @param data_test Required for forecast results only. If \code{data_results} is an object of class 'forecast_results', a data.frame used to
#' assess the accuracy of a 'forecast_results' object. \code{data_test} should have the outcome/target columns
#' and any grouping columns.
#' @param test_indices Required if \code{data_test} is given or 'rmsse' %in% \code{metrics}. A vector or 1-column data.frame of numeric
#' row indices or dates (class 'Date' or 'POSIXt') with length \code{nrow(data_test)}.
#' @param aggregate Default \code{median}. A function--without parentheses--that aggregates historical prediction or forecast error across time series.
#' All error metrics are first calculated at the level of the individual time series. \code{aggregate} is then used to combine error metrics across
#' validation windows and horizons. Aggregations are returned at the group level if \code{data_results} contains groups.
#' @param metrics A character vector of common forecast error metrics. The default behavior is to return all metrics.
#' 'rmsse' is only computed when both \code{data_test} and \code{test_indices} are given.
#' @param models Optional. A character vector of user-defined model names supplied to \code{train_model()} to filter results.
#' @param horizons Optional. A numeric vector to filter results by horizon.
#' @param windows Optional. A numeric vector to filter results by validation window number.
#' @param group_filter Optional. A string for filtering results for grouped time series
#' (e.g., \code{"group_col_1 == 'A'"}). \code{group_filter} is parsed and passed to \code{dplyr::filter()} internally.
#'
#' @return An S3 object of class 'validation_error', 'forecast_error', or 'forecastML_error': A list of data.frames
#' of error metrics for the validation or forecast dataset depending on the class of \code{data_results}: 'training_results',
#' 'forecast_results', or 'forecastML' from \code{combine_forecasts()}.
#'
#' A list containing: \cr
#'
#' \itemize{
#' \item Error metrics by model, horizon, and validation window
#' \item Error metrics by model and horizon, collapsed across validation windows
#' \item Global error metrics by model collapsed across horizons and validation windows
#'}
#' @section Error Metrics:
#'
#' \itemize{
#' \item \code{mae}: Mean absolute error (works with factor outcomes)
#' \item \code{mape}: Mean absolute percentage error
#' \item \code{mdape}: Median absolute percentage error
#' \item \code{smape}: Symmetrical mean absolute percentage error
#' \item \code{rmse}: Root mean squared error
#' \item \code{rmsse}: Root mean squared scaled error from the M5 competition
#'}
#' @section Methods and related functions:
#'
#' The output of \code{return_error()} has the following generic S3 methods
#'
#' \itemize{
#' \item \code{\link[=plot.validation_error]{plot}} from \code{return_error()}
#' \item \code{\link[=plot.forecast_error]{plot}} from \code{return_error()}
#' }
#' @example /R/examples/example_return_error.R
#' @export
return_error <- function(data_results, data_test = NULL, test_indices = NULL, aggregate = stats::median,
                         metrics = c("mae", "mape", "mdape", "smape", "rmse", 'rmsse'),
                         models = NULL, horizons = NULL, windows = NULL, group_filter = NULL) {
  #----------------------------------------------------------------------------
  # Input checks.
  is_training <- methods::is(data_results, "training_results")
  if (!(is_training || methods::is(data_results, "forecast_results"))) {
    stop("The 'data_results' argument takes an object of class 'training_results' or 'forecast_results' as input. Run predict() on a 'forecast_model' object first.")
  }
  if (is_training && "rmsse" %in% metrics && any(is.null(data_test), is.null(test_indices))) {
    warning("'rmsse' was not calculated. The 'rmsse' metric needs a dataset of actuals passed in 'data_test' and 'test_indices'.")
  }
  if (methods::is(data_results, "forecast_results") && is.null(data_test)) {
    stop("Computing forecast error metrics requires a data.frame input to the 'data_test' argument.")
  }
  if (xor(is.null(data_test), is.null(test_indices))) {
    stop("If using a test dataset to assess forecast error, both 'data_test' and 'test_indices' need to be specified.")
  }
  #----------------------------------------------------------------------------
  # Available metric names and their implementing functions (internal error functions in zzz.R),
  # kept in the same order so that a metric's position in 'error_metrics' indexes its function.
  # The M5 'rmsse' requires a dataset of actuals, so it is only available when 'data_test' is given.
  error_metrics <- c("mae", "mape", "mdape", "smape", "rmse")
  error_functions <- c(forecastML_mae, forecastML_mape, forecastML_mdape, forecastML_smape,
                       forecastML_rmse)
  if (!is.null(data_test)) {
    error_metrics <- c(error_metrics, "rmsse")
    error_functions <- c(error_functions, forecastML_rmsse)
  }
  # Filter the user input error metrics to only those that are available.
  metrics <- metrics[metrics %in% error_metrics]
  if (length(metrics) == 0) {
    stop("None of the error 'metrics' match any of 'mae', 'mape', 'mdape', 'smape', 'rmse' or 'rmsse'.")
  }
  # The user-selected 'metrics' select the matching error functions.
  select_error_funs <- match(metrics, error_metrics)
  error_functions <- error_functions[select_error_funs]
  names(error_functions) <- error_metrics[select_error_funs]
  #----------------------------------------------------------------------------
  # The return() from combine_forecasts(), 'forecastML', is also an object of class 'forecast_results' but does not need
  # filtering, so these input types will be handled slightly differently.
  is_forecastML <- methods::is(data_results, "forecastML")
  data <- data_results
  method <- attributes(data)$method
  outcome_name <- attributes(data)$outcome_name
  outcome_levels <- attributes(data)$outcome_levels
  groups <- attributes(data)$groups
  outcome_pred_name <- paste0(outcome_name, "_pred")
  #----------------------------------------------------------------------------
  # For factor outcomes, only predicted factor levels (an '<outcome>_pred' column) and the
  # 'mae' metric are supported; predicted class probabilities are rejected.
  if (!is.null(outcome_levels)) {
    factor_level <- any(names(data) %in% outcome_pred_name)
    if (!all(metrics %in% c("mae"))) {
      stop("Only the 'mae' metric is available for factor outcomes. Set 'metrics = 'mae'' and re-run.")
    }
    # This will eventually change.
    if (!factor_level) {
      stop("Error metrics with predicted class probabilities are not currently supported.")
    }
  }
  #----------------------------------------------------------------------------
  # If dates were given, use dates.
  if (!is.null(data$date_indices)) {
    data$valid_indices <- data$date_indices
  }
  #----------------------------------------------------------------------------
  # Merge user-supplied test data to the forecasts from predict.forecast_model().
  if (methods::is(data, "forecast_results")) {
    data_test$forecast_period <- test_indices
    data_test <- dplyr::select(data_test, .data$forecast_period, !!outcome_name, !!groups)
    data <- dplyr::inner_join(data, data_test, by = c("forecast_period", groups))
    if (nrow(data) == 0) {
      stop("The test dataset in 'data_test' does not overlap with the forecast period in 'data_results'.")
    }
  }
  #----------------------------------------------------------------------------
  # Special rmsse error metric from the M5 forecasting competition. Its scaling denominator is the
  # mean squared lag-1 difference of the actuals in 'data_test', computed per window and group,
  # and joined onto 'data' as the column 'mse_denom'.
  if (any(metrics %in% "rmsse")) {
    data_test$index <- test_indices
    data_test <- data_test %>% dplyr::select(.data$index, !!outcome_name, !!groups)
    # Repeat the actuals once per validation window (a single copy when there are no windows).
    n_windows <- max(length(unique(data_results$window_number)), 1)
    n_rows <- nrow(data_test)
    data_test <- data_test[rep(seq_len(n_rows), n_windows), , drop = FALSE]
    data_test$window_number <- rep(seq_len(n_windows), each = n_rows)
    if (is_training) {
      data_merge <- data %>% dplyr::select(.data$valid_indices, .data$window_number, !!groups)
      names(data_merge)[names(data_merge) == "valid_indices"] <- "index"
      data_merge$instance_in_valid_dataset <- TRUE
      if (unique(data$window_length) != 0) {  # Validation rows need to be removed from training rows.
        data_test <- dplyr::left_join(data_test, data_merge, by = c("index", "window_number", groups))
        data_test <- dplyr::distinct(data_test, .data$index, .data$window_number, !!!rlang::syms(groups), .keep_all = TRUE)
        # Blank out each window's own validation actuals so they don't enter its scaling denominator.
        data_test[which(data_test$instance_in_valid_dataset), outcome_name] <- NA
      } else {  # All training rows are also validation rows--essentially model fit error.
        data_test$window_number <- NULL
        data_test <- dplyr::left_join(data_test, data_merge, by = c("index", groups))
        data_test <- dplyr::distinct(data_test, .data$index, !!!rlang::syms(groups), .keep_all = TRUE)
      }
    }
    data_test <- data_test %>%
      dplyr::arrange(!!!rlang::syms(groups), .data$window_number, .data$index) %>%
      dplyr::group_by_at(dplyr::vars(.data$window_number, !!groups)) %>%
      dplyr::mutate("outcome_lag_1" = dplyr::lag(!!rlang::sym(outcome_name), 1),
                    "outcome_minus_outcome_lag_1" = !!rlang::sym(outcome_name) - .data$outcome_lag_1,
                    "sse_denom" = sum(.data$outcome_minus_outcome_lag_1^2, na.rm = TRUE),
                    # Count of non-missing lag-1 differences; converts sse_denom to mse_denom.
                    "n" = sum(!is.na(.data$outcome_minus_outcome_lag_1))) %>%
      dplyr::filter(.data$n > 0) %>%  # Avoid a division by zero for sparse data with NAs or only 1 data point.
      # No need to subtract 1 per the formula; 'n' already accounts for missing data from lags.
      dplyr::summarize("mse_denom" = .data$sse_denom[1] / .data$n[1])
    if (is_forecastML) {
      data$window_number <- 1  # combine_forecasts() output has no validation windows.
    }
    data <- dplyr::left_join(data, data_test, by = c("window_number", groups))
  }
  #----------------------------------------------------------------------------
  # Residual calculations.
  if (is.null(outcome_levels)) {  # Numeric outcome.
    data$residual <- data[[outcome_name]] - data[[outcome_pred_name]]
  } else {  # Factor outcome; predicted class probabilities were rejected above.
    # Binary accuracy/residual. A residual of 1 is an incorrect classification.
    data$residual <- ifelse(as.character(data[[outcome_name]]) != as.character(data[[outcome_pred_name]]), 1, 0)
  }
  #----------------------------------------------------------------------------
  # Filter results based on user input.
  if (!is_forecastML) {
    if (is.null(models)) {
      models <- unique(data$model)
    }
    if (is.null(horizons)) {
      horizons <- unique(data$model_forecast_horizon)
    }
    if (is.null(windows)) {
      windows <- unique(data$window_number)
    }
    data <- data[data$model %in% models & data$model_forecast_horizon %in% horizons & data$window_number %in% windows, ]
    if (!is.null(group_filter)) {
      data <- dplyr::filter(data, !!rlang::parse_expr(group_filter))
    }
  }
  #----------------------------------------------------------------------------
  # Compute the selected error metrics for an already-grouped dataset: one row per group.
  # The non-rmsse metrics come from the error functions, which are called with named args
  # x = 'residual', y = 'actual', and z = 'prediction'; unused args are passed in ... and ignored
  # (see zzz.R). 'rmsse' is computed here from the residuals and the 'mse_denom' column.
  compute_error_metrics <- function(data_grouped) {
    data_metrics <- NULL
    if (!all(metrics %in% "rmsse")) {
      data_metrics <- dplyr::summarize_at(data_grouped,
                                          dplyr::vars(1),  # 1 is a col position that gets the fun to run.
                                          .funs = error_functions,
                                          x = rlang::quo(.data$residual),
                                          y = rlang::sym(outcome_name),
                                          z = rlang::sym(outcome_pred_name))
    }
    if (any(metrics %in% "rmsse")) {
      data_rmsse <- dplyr::summarize(data_grouped,
                                     "h" = sum(!is.na(.data$residual)),
                                     "sse_num" = sum(.data$residual^2, na.rm = TRUE),
                                     "mse_denom" = .data$mse_denom[1])
      data_rmsse$rmsse <- with(data_rmsse, sqrt((1 / h) * (sse_num / mse_denom)))
      data_rmsse$h <- NULL
      data_rmsse$sse_num <- NULL
      data_rmsse$mse_denom <- NULL
      if (is.null(data_metrics)) {
        data_metrics <- data_rmsse
      } else {
        # Both summaries come from the same grouped data, so their rows are in the same order.
        data_metrics$rmsse <- data_rmsse$rmsse
      }
    }
    data_metrics
  }
  #----------------------------------------------------------------------------
  # The error combinations at higher levels of aggregation--'data_2' and 'data_3'--are based on
  # error metrics calculated at the lowest level--data_1--and combined using the user-supplied
  # 'aggregate' method (except for combine_forecasts() output, which has no window level).
  if (is_training) {  # Error metrics for training data.
    #--------------------------------------------------------------------------
    # Compute error metrics at the validation window level.
    data_grouped <- data %>%
      dplyr::group_by_at(dplyr::vars(.data$model, .data$model_forecast_horizon, .data$window_number)) %>%
      dplyr::mutate("window_start" = min(.data$valid_indices, na.rm = TRUE),
                    "window_stop" = max(.data$valid_indices, na.rm = TRUE),
                    "window_midpoint" = base::mean(.data$valid_indices, na.rm = TRUE)) %>%
      dplyr::group_by_at(dplyr::vars(.data$model, .data$model_forecast_horizon,
                                     .data$window_number, !!groups, .data$window_start,
                                     .data$window_stop, .data$window_midpoint))
    data_1 <- compute_error_metrics(data_grouped)
    #--------------------------------------------------------------------------
    # Compute error metric by horizon and window length across all validation windows.
    data_2 <- data_1 %>%
      dplyr::group_by_at(dplyr::vars(.data$model, .data$model_forecast_horizon)) %>%
      dplyr::mutate("window_start" = min(.data$window_start, na.rm = TRUE),
                    "window_stop" = max(.data$window_stop, na.rm = TRUE)) %>%
      dplyr::group_by_at(dplyr::vars(.data$model, .data$model_forecast_horizon, !!groups, .data$window_start, .data$window_stop)) %>%
      dplyr::summarize_at(dplyr::vars(!!!rlang::syms(metrics)), aggregate, na.rm = TRUE)
    #--------------------------------------------------------------------------
    # Compute error metric by model.
    data_3 <- data_1 %>%
      dplyr::group_by_at(dplyr::vars(.data$model)) %>%
      dplyr::mutate("window_start" = min(.data$window_start, na.rm = TRUE),
                    "window_stop" = max(.data$window_stop, na.rm = TRUE)) %>%
      dplyr::group_by_at(dplyr::vars(.data$model, !!groups, .data$window_start, .data$window_stop)) %>%
      dplyr::summarize_at(dplyr::vars(!!!rlang::syms(metrics)), aggregate, na.rm = TRUE)
    #--------------------------------------------------------------------------
  } else if (!is_forecastML) {  # Error metrics for the forecast_results class which has no validation windows and slightly different grouping columns.
    #------------------------------------------------------------------------
    # Compute error metrics at the validation window level.
    data_1 <- compute_error_metrics(
      dplyr::group_by_at(data, dplyr::vars(.data$model, .data$model_forecast_horizon, .data$window_number, !!groups))
    )
    #------------------------------------------------------------------------
    # Compute error metric by horizon across all validation windows.
    data_2 <- data_1 %>%
      dplyr::ungroup() %>%
      dplyr::group_by_at(dplyr::vars(.data$model, .data$model_forecast_horizon, !!groups)) %>%
      dplyr::summarize_at(dplyr::vars(!!!rlang::syms(metrics)), aggregate, na.rm = TRUE)
    #------------------------------------------------------------------------
    # Compute error metric by model.
    data_3 <- data_1 %>%
      dplyr::ungroup() %>%
      dplyr::group_by_at(dplyr::vars(.data$model, !!groups)) %>%
      dplyr::summarize_at(dplyr::vars(!!!rlang::syms(metrics)), aggregate, na.rm = TRUE)
    #--------------------------------------------------------------------------
  } else {  # Final forecasts from combine_forecasts().
    #------------------------------------------------------------------------
    # There are no validation windows to compute error metrics for.
    data_1 <- data.frame()
    #------------------------------------------------------------------------
    # Compute error metrics by model and horizon directly from the residuals.
    data_2 <- compute_error_metrics(
      dplyr::group_by_at(data, dplyr::vars(.data$model, .data$horizon, !!groups))
    )
    #------------------------------------------------------------------------
    # Compute error metrics by model directly from the residuals.
    data_3 <- compute_error_metrics(
      dplyr::group_by_at(data, dplyr::vars(.data$model, !!groups))
    )
    #------------------------------------------------------------------------
  }  # End error metrics for results from combine_forecasts().
  #----------------------------------------------------------------------------
  data_out <- list("error_by_window" = data_1, "error_by_horizon" = data_2, "error_global" = data_3)
  data_out[] <- lapply(data_out, as.data.frame)  # Remove the tibble class.
  if (is_training) {  # Validation error.
    class(data_out) <- c("validation_error", class(data_out))
  } else if (!is_forecastML) {  # Forecast error.
    class(data_out) <- c("forecast_error", class(data_out))
  } else {
    class(data_out) <- c("forecastML_error", class(data_out))
  }
  attr(data_out, "error_metrics") <- metrics
  attr(data_out, "outcome_name") <- outcome_name
  attr(data_out, "outcome_levels") <- outcome_levels
  attr(data_out, "method") <- method
  attr(data_out, "groups") <- groups
  return(data_out)
}
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
#' Plot validation dataset forecast error
#'
#' Plot forecast error at various levels of aggregation across validation datasets.
#' @param x An object of class 'validation_error' from \code{return_error()}.
#' @param type Select plot type: "window", "horizon", or "global"; \code{type = "window"} is the default plot.
#' @param metric Select error metric to plot (e.g., "mae"); \code{attributes(x)$error_metrics[1]} is the default metric.
#' @param models Optional. A vector of user-defined model names from \code{train_model()} to filter results.
#' @param horizons Optional. A numeric vector to filter results by horizon.
#' @param windows Optional. A numeric vector to filter results by validation window number.
#' @param group_filter A string for filtering plot results for grouped time series (e.g., \code{"group_col_1 == 'A'"}).
#' @param facet Optional. A formula with any combination of \code{horizon}, \code{model}, or \code{group} (for grouped time series)
#' passed to \code{ggplot2::facet_grid()} internally (e.g., \code{horizon ~ model}, \code{horizon + model ~ .}, \code{~ horizon + group}).
#' Can be \code{NULL}, in which case the default faceting, \code{horizon ~ model}, is used for every plot \code{type}.
#' @param ... Not used.
#' @return Forecast error plots of class 'ggplot'.
#' @export
plot.validation_error <- function(x, type = c("window", "horizon", "global"), metric = NULL,
                                  facet = NULL, models = NULL, horizons = NULL, windows = NULL, group_filter = NULL, ...) { # nocov start
  #----------------------------------------------------------------------------
  data_error <- x
  type <- type[1]
  method <- attributes(data_error)$method
  groups <- attributes(data_error)$groups
  error_metrics <- metric
  if (is.null(error_metrics)) {
    error_metrics <- attributes(data_error)$error_metrics[1]
  }
  if (!error_metrics %in% attributes(data_error)$error_metrics) {
    stop("The error metric in 'metric' is not in the validation error dataset 'x'; re-run return_error() with this metric.")
  }
  #----------------------------------------------------------------------------
  # Default plot facets; every plot type uses the same default layout.
  if (is.null(facet)) {
    facet <- horizon ~ model
  }
  facets <- forecastML_facet_plot(facet, groups)  # Function in zzz.R.
  facet <- facets[[1]]
  facet_names <- facets[[2]]
  #----------------------------------------------------------------------------
  # Select the error dataset for the plot type. Columns absent at a given level of
  # aggregation are filled with TRUE so the filtering code below runs for every type.
  if (type == "window") {
    data_plot <- data_error$error_by_window
    if (method == "direct") {
      data_plot$horizon <- data_plot$model_forecast_horizon  # Added for faceting.
    }
  } else if (type == "horizon") {
    data_plot <- data_error$error_by_horizon
    if (method == "direct") {
      data_plot$horizon <- data_plot$model_forecast_horizon  # Added for faceting.
    }
    data_plot$window_number <- TRUE  # Not in data, added for filtering.
  } else if (type == "global") {
    data_plot <- data_error$error_global
    data_plot$model_forecast_horizon <- TRUE  # Not in data, added for filtering.
    data_plot$window_number <- TRUE  # Not in data, added for filtering.
    data_plot$horizon <- TRUE  # Not in data, added for filtering.
  }
  #----------------------------------------------------------------------------
  # Filter the datasets based on user input.
  if (is.null(models)) {
    models <- unique(data_plot$model)
  }
  if (is.null(horizons)) {
    horizons <- unique(data_plot$horizon)
  }
  if (is.null(windows)) {
    windows <- unique(data_plot$window_number)
  }
  data_plot <- data_plot[data_plot$horizon %in% horizons & data_plot$model %in% models & data_plot$window_number %in% windows, ]
  #----------------------------------------------------------------------------
  # User filtering to display select results in grouped time series.
  if (!is.null(group_filter)) {
    data_plot <- dplyr::filter(data_plot, !!rlang::parse_expr(group_filter))
  }
  #----------------------------------------------------------------------------
  # Melt the data for plotting.
  data_plot <- tidyr::gather(data_plot, "error_metric", "value", -!!names(data_plot)[!names(data_plot) %in% error_metrics])
  #----------------------------------------------------------------------------
  # ggplot colors and facets are complementary: all facets, same color; all colors, no facet.
  ggplot_color <- c(c("model", "horizon", groups)[!c("model", "horizon", groups) %in% facet_names])
  #----------------------------------------------------------------------------
  if (is.null(groups)) {
    data_plot <- dplyr::arrange(data_plot, .data$model, .data$horizon, .data$window_number)
  } else {
    data_plot <- dplyr::arrange(data_plot, .data$model, .data$horizon, .data$window_number, !!!rlang::syms(groups))
  }
  if (type == "global") {
    data_plot$horizon <- "All"
    data_plot$horizon <- ordered(data_plot$horizon)
  }
  # Ordered factors keep the sorted order of appearance in facets and legends.
  data_plot$horizon <- factor(data_plot$horizon, levels = unique(data_plot$horizon), ordered = TRUE)
  data_plot[, groups] <- lapply(seq_along(data_plot[, groups, drop = FALSE]), function(i) {
    factor(data_plot[, groups[i]], levels = unique(data_plot[, groups[i]]), ordered = TRUE)
  })
  data_plot$ggplot_color <- apply(data_plot[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
  data_plot$ggplot_color <- factor(data_plot$ggplot_color, levels = unique(data_plot$ggplot_color), ordered = TRUE)
  # Give error a name in the legend if plot is faceted by model and horizon (and group if groups are given).
  if (length(ggplot_color) == 0) {
    data_plot$ggplot_color <- "Error"
    data_plot$ggplot_color <- ordered(data_plot$ggplot_color)
  }
  # Legend/axis title: the color columns with their first letter capitalized, joined by " + ".
  x_axis_title <- paste0(toupper(substr(ggplot_color, 1, 1)), substring(ggplot_color, 2))
  x_axis_title <- paste(x_axis_title, collapse = " + ")
  #----------------------------------------------------------------------------
  # Create plots.
  if (type == "window") {
    p <- ggplot()
    if (length(unique(data_plot$window_midpoint)) != 1) {
      p <- p + geom_line(data = data_plot, aes(x = .data$window_midpoint,
                                               y = .data$value,
                                               color = .data$ggplot_color,
                                               group = .data$ggplot_color
      ), size = 1.05)
    }
    p <- p + geom_point(data = data_plot, aes(x = .data$window_midpoint,
                                              y = .data$value,
                                              color = .data$ggplot_color,
                                              group = .data$ggplot_color
    ))
    p <- p + scale_color_viridis_d()
    p <- p + facet_grid(facet, scales = "free")
    p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
    p <- p + xlab("Dataset index") + ylab(paste0("Forecast error metric (", error_metrics, ")")) + labs(color = x_axis_title) +
      ggtitle("Forecast Error by Validation Window")
    return(suppressWarnings(p))
  }
  if (type == "horizon") {
    p <- ggplot()
    p <- p + geom_col(data = data_plot,
                      aes(x = .data$ggplot_color,
                          y = .data$value,
                          fill = .data$ggplot_color,
                          group = .data$ggplot_color))
    p <- p + facet_grid(facet, scales = "free_y")
    # The bars are mapped to 'fill', so the viridis palette must be a fill scale.
    p <- p + scale_fill_viridis_d()
    p <- p + theme_bw() + theme(
      panel.spacing = unit(0, "lines"),
      axis.text.x = element_blank()
    )
    p <- p + xlab(x_axis_title) + ylab(paste0("Forecast error metric (", error_metrics, ")")) + labs(fill = x_axis_title, alpha = NULL) +
      ggtitle("Forecast Error by Forecast Horizon")
    return(suppressWarnings(p))
  }
  if (type == "global") {
    p <- ggplot()
    p <- p + geom_bar(data = data_plot,
                      aes(x = .data$ggplot_color,
                          y = .data$value,
                          fill = .data$ggplot_color),
                      stat = "identity", position = position_dodge(width = 1))
    p <- p + scale_fill_viridis_d()
    p <- p + facet_grid(facet, scales = "free")
    p <- p + theme_bw() + theme(
      panel.spacing = unit(0, "lines"),
      axis.text.x = element_blank()
    )
    p <- p + xlab(x_axis_title) + ylab(paste0("Forecast error metric (", error_metrics, ")")) + labs(fill = x_axis_title, alpha = NULL) +
      ggtitle("Forecast Error Across Validation Windows and Horizons")
    return(suppressWarnings(p))
  }
} # nocov end
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
#' Plot forecast error
#'
#' Plot forecast error at various levels of aggregation.
#' @param x An object of class 'forecast_error' from \code{return_error()}.
#' @param type Select plot type; \code{type = "global"} is the default plot.
#' @param metric Select error metric to plot (e.g., "mae"); \code{attributes(x)$error_metrics[1]} is the default metric.
#' @param models Optional. A vector of user-defined model names from \code{train_model()} to filter results.
#' @param horizons Optional. A numeric vector to filter results by horizon.
#' @param windows Optional. A numeric vector to filter results by validation window number.
#' @param group_filter A string for filtering plot results for grouped time series (e.g., \code{"group_col_1 == 'A'"}).
#' @param facet Optional. A formula with any combination of \code{horizon}, \code{model}, or \code{group} (for grouped time series)
#' passed to \code{ggplot2::facet_grid()} internally (e.g., \code{horizon ~ model}, \code{horizon + model ~ .}, \code{~ horizon + group}).
#' Can be \code{NULL}, in which case the default faceting, \code{horizon ~ model}, is used.
#' @param ... Not used.
#' @return Forecast error plots of class 'ggplot'.
#' @export
plot.forecast_error <- function(x, type = c("global"), metric = NULL,
                                facet = NULL, models = NULL, horizons = NULL, windows = NULL, group_filter = NULL, ...) { # nocov start
  #----------------------------------------------------------------------------
  data_error <- x
  type <- type[1]
  method <- attributes(data_error)$method
  groups <- attributes(data_error)$groups
  error_metrics <- metric
  if (is.null(error_metrics)) {
    error_metrics <- attributes(data_error)$error_metrics[1]
  }
  if (!error_metrics %in% attributes(data_error)$error_metrics) {
    stop("The error metric in 'metric' is not in the forecast error dataset 'x'; re-run return_error() with this metric.")
  }
  #----------------------------------------------------------------------------
  # Default plot facets; every plot type uses the same default layout.
  if (is.null(facet)) {
    facet <- horizon ~ model
  }
  facets <- forecastML_facet_plot(facet, groups)  # Function in zzz.R.
  facet <- facets[[1]]
  facet_names <- facets[[2]]
  #----------------------------------------------------------------------------
  # Select the error dataset for the plot type. Columns absent at a given level of
  # aggregation are filled with TRUE so the filtering code below runs for every type.
  if (type == "window") {
    data_plot <- data_error$error_by_window
    if (method == "direct") {
      data_plot$horizon <- data_plot$model_forecast_horizon  # Added for faceting.
    }
  } else if (type == "horizon") {
    data_plot <- data_error$error_by_horizon
    if (method == "direct") {
      data_plot$horizon <- data_plot$model_forecast_horizon  # Added for faceting.
    }
    data_plot$window_number <- TRUE  # Not in data, added for filtering.
  } else if (type == "global") {
    data_plot <- data_error$error_global
    data_plot$model_forecast_horizon <- TRUE  # Not in data, added for filtering.
    data_plot$window_number <- TRUE  # Not in data, added for filtering.
    data_plot$horizon <- TRUE  # Not in data, added for filtering.
  }
  #----------------------------------------------------------------------------
  # Filter the datasets based on user input.
  if (is.null(models)) {
    models <- unique(data_plot$model)
  }
  if (is.null(horizons)) {
    horizons <- unique(data_plot$model_forecast_horizon)
  }
  if (is.null(windows)) {
    windows <- unique(data_plot$window_number)
  }
  data_plot <- data_plot[data_plot$model_forecast_horizon %in% horizons & data_plot$model %in% models & data_plot$window_number %in% windows, ]
  #----------------------------------------------------------------------------
  # User filtering to display select results in grouped time series.
  if (!is.null(group_filter)) {
    data_plot <- dplyr::filter(data_plot, !!rlang::parse_expr(group_filter))
  }
  #----------------------------------------------------------------------------
  # Melt the data for plotting.
  data_plot <- tidyr::gather(data_plot, "error_metric", "value", -!!names(data_plot)[!names(data_plot) %in% error_metrics])
  #----------------------------------------------------------------------------
  # ggplot colors and facets are complementary: all facets, same color; all colors, no facet.
  ggplot_color <- c(c("model", "horizon", groups)[!c("model", "horizon", groups) %in% facet_names])
  #----------------------------------------------------------------------------
  if (is.null(groups)) {
    data_plot <- dplyr::arrange(data_plot, .data$model, .data$window_number)
  } else {
    data_plot <- dplyr::arrange(data_plot, .data$model, .data$window_number, !!!rlang::syms(groups))
  }
  # Label the collapsed horizon before the color/group labels are built from it so legends
  # read "All" rather than the TRUE filtering placeholder.
  if (type == "global") {
    data_plot$horizon <- "All"
  }
  data_plot$model_forecast_horizon <- factor(data_plot$model_forecast_horizon, levels = unique(data_plot$model_forecast_horizon), ordered = TRUE)
  data_plot[, groups] <- lapply(seq_along(data_plot[, groups, drop = FALSE]), function(i) {
    factor(data_plot[, groups[i]], levels = unique(data_plot[, groups[i]]), ordered = TRUE)
  })
  data_plot$ggplot_color <- apply(data_plot[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
  # Give error a name in the legend if plot is faceted by model and horizon (and group if groups are given).
  if (length(ggplot_color) == 0) {
    data_plot$ggplot_color <- "Error"
  }
  # Used to avoid lines spanning any gaps between validation windows.
  if (all(data_plot$window_number == 1)) {  # One window; no need to add the window number to the legend.
    data_plot$ggplot_group <- apply(data_plot[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
  } else {
    data_plot$ggplot_group <- apply(data_plot[, c("window_number", ggplot_color), drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
  }
  # Coerce to viridis color scale with an ordered factor. With the data.frame sorted, unique() pulls the levels in their order of appearance.
  data_plot$ggplot_color <- factor(data_plot$ggplot_color, levels = unique(data_plot$ggplot_color), ordered = TRUE)
  data_plot$ggplot_group <- factor(data_plot$ggplot_group, levels = unique(data_plot$ggplot_group), ordered = TRUE)
  # Legend/axis title: the color columns with their first letter capitalized, joined by " + ".
  x_axis_title <- paste0(toupper(substr(ggplot_color, 1, 1)), substring(ggplot_color, 2))
  x_axis_title <- paste(x_axis_title, collapse = " + ")
  p <- ggplot()
  p <- p + geom_bar(data = data_plot,
                    aes(x = .data$ggplot_color,
                        y = .data$value,
                        fill = .data$ggplot_color,
                        group = .data$ggplot_group),
                    stat = "identity", position = position_dodge(width = 1))
  p <- p + scale_fill_viridis_d()
  p <- p + facet_grid(facet, scales = "free")
  p <- p + theme_bw() + theme(
    panel.spacing = unit(0, "lines"),
    axis.text.x = element_blank()
  )
  p <- p + xlab(x_axis_title) + ylab(paste0("Forecast error metric (", error_metrics, ")")) + labs(fill = x_axis_title, alpha = NULL)
  if (type %in% c("window")) {
    p <- p + ggtitle("Forecast Error by Validation Window, Model Forecast Horizon, & Model")
  } else if (type %in% c("horizon")) {
    p <- p + ggtitle("Forecast Error by Model Forecast Horizon & Model")
  } else if (type %in% c("global")) {
    p <- p + ggtitle("Forecast Error by Model")
  }
  return(suppressWarnings(p))
} # nocov end
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.