# forecastML -- forecast error metrics and their plotting methods.
#' Compute forecast error
#'
#' Compute forecast error metrics on the validation datasets or a new test dataset.
#'
#' @param data_results An object of class 'training_results' or 'forecast_results' from running (a)
#' \code{\link[=predict.forecast_model]{predict}} on a trained model or (b) \code{combine_forecasts()}.
#' @param data_test Required for forecast results only. If \code{data_results} is an object of class 'forecast_results', a data.frame used to
#' assess the accuracy of a 'forecast_results' object. \code{data_test} should have the outcome/target columns
#' and any grouping columns.
#' @param test_indices Required if \code{data_test} is given or 'rmsse' %in% \code{metrics}. A vector or 1-column data.frame of numeric
#' row indices or dates (class 'Date' or 'POSIXt') with length \code{nrow(data_test)}.
#' @param aggregate Default \code{median}. A function--without parentheses--that aggregates historical prediction or forecast error across time series.
#' All error metrics are first calculated at the level of the individual time series. \code{aggregate} is then used to combine error metrics across
#' validation windows and horizons. Aggregations are returned at the group level if \code{data_results} contains groups.
#' @param metrics A character vector of common forecast error metrics. The default behavior is to return all metrics.
#' @param models Optional. A character vector of user-defined model names supplied to \code{train_model()} to filter results.
#' @param horizons Optional. A numeric vector to filter results by horizon.
#' @param windows Optional. A numeric vector to filter results by validation window number.
#' @param group_filter Optional. A string for filtering plot results for grouped time series
#' (e.g., \code{"group_col_1 == 'A'"}). \code{group_filter} is passed to \code{dplyr::filter()} internally.
#'
#' @return An S3 object of class 'validation_error', 'forecast_error', or 'forecastML_error': A list of data.frames
#' of error metrics for the validation or forecast dataset depending on the class of \code{data_results}: 'training_results',
#' 'forecast_results', or 'forecastML' from \code{combine_forecasts()}.
#'
#' A list containing: \cr
#'
#' \itemize{
#' \item Error metrics by model, horizon, and validation window
#' \item Error metrics by model and horizon, collapsed across validation windows
#' \item Global error metrics by model collapsed across horizons and validation windows
#'}
#' @section Error Metrics:
#'
#' \itemize{
#' \item \code{mae}: Mean absolute error (works with factor outcomes)
#' \item \code{mape}: Mean absolute percentage error
#' \item \code{mdape}: Median absolute percentage error
#' \item \code{smape}: Symmetric mean absolute percentage error
#' \item \code{rmse}: Root mean squared error
#' \item \code{rmsse}: Root mean squared scaled error from the M5 competition
#'}
#' @section Methods and related functions:
#'
#' The output of \code{return_error()} has the following generic S3 methods
#'
#' \itemize{
#' \item \code{\link[=plot.validation_error]{plot}} from \code{return_error()}
#' \item \code{\link[=plot.forecast_error]{plot}} from \code{return_error()}
#' }
#' @example /R/examples/example_return_error.R
#' @export
return_error <- function(data_results, data_test = NULL, test_indices = NULL,
                         aggregate = stats::median,
                         metrics = c("mae", "mape", "mdape", "smape", "rmse", "rmsse"),
                         models = NULL, horizons = NULL, windows = NULL,
                         group_filter = NULL) {

  # Input checks: 'data_results' must come from predict.forecast_model() or combine_forecasts().
  if (!(methods::is(data_results, "training_results") || methods::is(data_results, "forecast_results"))) {
    stop("The 'data_results' argument takes an object of class 'training_results' or 'forecast_results' as input. Run predict() on a 'forecast_model' object first.")
  }

  # 'rmsse' needs a dataset of training-period actuals; warn--rather than stop--so
  # that the remaining requested metrics are still computed and returned.
  if (methods::is(data_results, "training_results") && "rmsse" %in% metrics && any(is.null(data_test), is.null(test_indices))) {
    warning("'rmsse' was not calculated. The 'rmsse' metric needs a dataset of actuals passed in 'data_test' and 'test_indices'.")
  }

  if (methods::is(data_results, "forecast_results") && is.null(data_test)) {
    stop("Computing forecast error metrics requires a data.frame input to the 'data_test' argument.")
  }

  if (xor(is.null(data_test), is.null(test_indices))) {
    stop("If using a test dataset to assess forecast error, both 'data_test' and 'test_indices' need to be specified.")
  }
  #----------------------------------------------------------------------------
  # The order of these available metrics should match the error_functions vector later in the script.
  # Only 'mae' is available for factor outcomes at present; an error is thrown below if this is not the case.
  if (is.null(data_test)) {  # The M5 rmsse requires a dataset of actuals.
    error_metrics <- c("mae", "mape", "mdape", "smape", "rmse")
  } else {
    error_metrics <- c("mae", "mape", "mdape", "smape", "rmse", "rmsse")
  }

  # Filter the user input error metrics to only those that are available.
  metrics <- metrics[metrics %in% error_metrics]

  if (length(metrics) == 0) {
    stop("None of the error 'metrics' match any of 'mae', 'mape', 'mdape', 'smape', 'rmse' or 'rmsse'.")
  }
  #----------------------------------------------------------------------------
  # The return() from combine_forecasts(), 'forecastML', is also an object of class 'forecast_results'
  # but does not need filtering, so these input types are handled slightly differently.
  is_forecastML <- methods::is(data_results, "forecastML")

  data <- data_results

  # Metadata set upstream by train_model()/predict()/combine_forecasts().
  method <- attributes(data)$method
  outcome_name <- attributes(data)$outcome_name
  outcome_levels <- attributes(data)$outcome_levels
  groups <- attributes(data)$groups

  if (is_forecastML) {type <- attributes(data)$type}  # 'horizon' or 'error'; used for group_by() for outputs from combine_forecasts().
  #----------------------------------------------------------------------------
  # For factor outcomes, is the prediction a factor level or probability?
  if (!is.null(outcome_levels)) {

    factor_level <- if (any(names(data) %in% paste0(outcome_name, "_pred"))) {TRUE} else {FALSE}
    factor_prob <- !factor_level

    if (!all(metrics %in% c("mae"))) {
      stop("Only the 'mae' metric is available for factor outcomes. Set 'metrics = 'mae'' and re-run.")
    }

    # This will eventually change.
    if (factor_prob) {
      stop("Error metrics with predicted class probabilities are not currently supported.")
    }
  }
  #----------------------------------------------------------------------------
  # If dates were given, use dates.
  if (!is.null(data$date_indices)) {
    data$valid_indices <- data$date_indices
  }
  #----------------------------------------------------------------------------
  # Merge user-supplied test data to the forecasts from predict.forecast_model().
  if (methods::is(data, "forecast_results")) {

    data_test$forecast_period <- test_indices

    data_test <- dplyr::select(data_test, .data$forecast_period, !!outcome_name, !!groups)

    data <- dplyr::inner_join(data, data_test, by = c("forecast_period", groups))

    if (nrow(data) == 0) {
      stop("The test dataset in 'data_test' does not overlap with the forecast period in 'data_results'.")
    }
  }
  #----------------------------------------------------------------------------
  # Special rmsse error metric from the M5 forecasting competition. This metric needs
  # access to the actuals from the training data which is unique for the supported error metrics.
  if (any(metrics %in% "rmsse")) {

    data_test$index <- test_indices

    data_test <- data_test %>% dplyr::select(.data$index, !!outcome_name, !!groups)

    # Repeat prediction data for multiple windows.
    n_windows <- length(unique(data_results$window_number))
    if (n_windows == 0) {n_windows <- 1}

    n_rows <- nrow(data_test)

    data_test <- data_test[rep(1:n_rows, n_windows), , drop = FALSE]
    data_test$window_number <- rep(1:n_windows, each = n_rows)

    if (methods::is(data_results, "training_results")) {

      # Flag which rows of the actuals fall inside a validation window.
      data_merge <- data %>% dplyr::select(.data$valid_indices, .data$window_number, !!groups)
      names(data_merge)[names(data_merge) == "valid_indices"] <- "index"
      data_merge$instance_in_valid_dataset <- TRUE
    }

    if (methods::is(data_results, "training_results")) {

      if (unique(data$window_length) != 0) {  # Validation rows need to be removed from training rows.

        data_test <- dplyr::left_join(data_test, data_merge, by = c("index", "window_number", groups))
        data_test <- dplyr::distinct(data_test, .data$index, .data$window_number, !!!rlang::syms(groups), .keep_all = TRUE)
        # Set validation-window actuals to NA so they are excluded from the rmsse scaling denominator.
        data_test[which(data_test$instance_in_valid_dataset), outcome_name] <- NA

      } else {  # All training rows are also validation rows--essentially model fit error.

        data_test$window_number <- NULL
        data_test <- dplyr::left_join(data_test, data_merge, by = c("index", groups))
        data_test <- dplyr::distinct(data_test, .data$index, !!!rlang::syms(groups), .keep_all = TRUE)
      }
    }

    # Per-window/per-group scaling denominator: the mean squared 1-step-ahead naive error.
    data_test <- data_test %>%
      dplyr::arrange(!!!rlang::syms(groups), .data$window_number, .data$index) %>%
      dplyr::group_by_at(dplyr::vars(.data$window_number, !!groups)) %>%
      dplyr::mutate("outcome_lag_1" = dplyr::lag(!!rlang::sym(outcome_name), 1),
                    "outcome_minus_outcome_lag_1" = !!rlang::sym(outcome_name) - .data$outcome_lag_1,
                    "sse_denom" = sum(.data$outcome_minus_outcome_lag_1^2, na.rm = TRUE),
                    "n" = sum(!is.na(.data$outcome_minus_outcome_lag_1))) %>%  # Used to compute average to convert sse_denom to mse_denom.
      dplyr::filter(n > 0) %>%  # Avoid a division by zero error for sparse data with NAs or only 1 data point.
      dplyr::summarize("mse_denom" = .data$sse_denom[1] / (n[1]))  # No need to subtract 1 per the formula, 'n' already accounts for missing data from lags.

    if (!is_forecastML) {
      data <- dplyr::left_join(data, data_test, by = c("window_number", groups))
    } else {
      data$window_number <- 1  # Final forecasts from combine_forecasts() have no validation windows.
      data <- dplyr::left_join(data, data_test, by = c("window_number", groups))
    }
  }
  #----------------------------------------------------------------------------
  # Residual calculations.
  if (is.null(outcome_levels)) {  # Numeric outcome.
    data$residual <- data[, outcome_name] - data[, paste0(outcome_name, "_pred")]
  } else {  # Factor outcome.
    if (factor_level) {
      # Binary accuracy/residual. A residual of 1 is an incorrect classification.
      data$residual <- ifelse(as.character(data[, outcome_name, drop = TRUE]) != as.character(data[, paste0(outcome_name, "_pred"), drop = TRUE]), 1, 0)
    } else {  # Class probabilities were predicted; unsupported--an error was thrown above.
      # data_residual <- 1 - data[, names(data) %in% outcome_levels]
      # names(data_residual) <- paste0(names(data_residual), "_residual")
      # data <- dplyr::bind_cols(data, data_residual)
      # rm(data_residual)
    }
  }
  #----------------------------------------------------------------------------
  # Filter results based on user input.
  if (!is_forecastML) {

    models <- if (is.null(models)) {unique(data$model)} else {models}
    horizons <- if (is.null(horizons)) {unique(data$model_forecast_horizon)} else {horizons}
    windows <- if (is.null(windows)) {unique(data$window_number)} else {windows}

    data <- data[data$model %in% models & data$model_forecast_horizon %in% horizons & data$window_number %in% windows, ]

    if (!is.null(group_filter)) {
      data <- dplyr::filter(data, eval(parse(text = group_filter)))
    }
  }
  #----------------------------------------------------------------------------
  # Select error functions. The forecastML internal error functions are in zzz.R.
  # The functions are called with named x, y, and z args in dplyr::summarize().
  if (is.null(data_test)) {  # The M5 rmsse requires a dataset of actuals.
    error_functions <- c(forecastML_mae, forecastML_mape, forecastML_mdape, forecastML_smape,
                         forecastML_rmse)
  } else {
    error_functions <- c(forecastML_mae, forecastML_mape, forecastML_mdape, forecastML_smape,
                         forecastML_rmse, forecastML_rmsse)
  }

  # The user-selected 'metrics' are used to select the appropriate error functions.
  select_error_funs <- sapply(metrics, function(x) {which(x == error_metrics)})
  error_functions <- error_functions[select_error_funs]
  names(error_functions) <- error_metrics[select_error_funs]
  #----------------------------------------------------------------------------
  # The error combinations at higher levels of aggregation--'data_2' and 'data_3'--are based on
  # error metrics calculated at the lowest level--data_1--and combined using the user-supplied
  # 'aggregate' method.
  # For all error function args: x = 'residual', y = 'actual', and z = 'prediction'. Unused
  # function args for a given metric are passed in ... and ignored. See zzz.R.
  if (methods::is(data_results, "training_results")) {  # Error metrics for training data.
    #--------------------------------------------------------------------------
    # Compute error metrics at the validation window level.
    if (!all(metrics %in% "rmsse")) {

      data_1 <- data %>%
        dplyr::group_by_at(dplyr::vars(.data$model, .data$model_forecast_horizon, .data$window_number)) %>%
        dplyr::mutate("window_start" = min(.data$valid_indices, na.rm = TRUE),
                      "window_stop" = max(.data$valid_indices, na.rm = TRUE),
                      "window_midpoint" = base::mean(.data$valid_indices, na.rm = TRUE)) %>%
        dplyr::group_by_at(dplyr::vars(.data$model, .data$model_forecast_horizon,
                                       .data$window_number, !!groups, .data$window_start,
                                       .data$window_stop, .data$window_midpoint)) %>%
        dplyr::summarize_at(dplyr::vars(1),  # 1 is a col position that gets the fun to run; args x, y, and z defined below.
                            .funs = error_functions,
                            x = rlang::quo(.data$residual),
                            y = rlang::sym(outcome_name),
                            z = rlang::sym(paste0(outcome_name, "_pred")))
    }

    if (any(metrics %in% "rmsse")) {

      data_1_rmsse <- data %>%
        dplyr::group_by_at(dplyr::vars(.data$model, .data$model_forecast_horizon, .data$window_number)) %>%
        dplyr::mutate("window_start" = min(.data$valid_indices, na.rm = TRUE),
                      "window_stop" = max(.data$valid_indices, na.rm = TRUE),
                      "window_midpoint" = base::mean(.data$valid_indices, na.rm = TRUE)) %>%
        dplyr::group_by_at(dplyr::vars(.data$model, .data$model_forecast_horizon,
                                       .data$window_number, !!groups, .data$window_start,
                                       .data$window_stop, .data$window_midpoint)) %>%
        dplyr::summarize("h" = sum(!is.na(.data$residual)),
                         "sse_num" = sum(.data$residual^2, na.rm = TRUE),
                         "mse_denom" = .data$mse_denom[1])

      # M5 rmsse: sqrt(mean squared forecast error / mean squared 1-step naive error).
      data_1_rmsse$rmsse <- with(data_1_rmsse, sqrt((1 / h) * (sse_num / mse_denom)))

      data_1_rmsse$h <- NULL
      data_1_rmsse$sse_num <- NULL
      data_1_rmsse$mse_denom <- NULL

      if (all(metrics %in% "rmsse")) {
        data_1 <- data_1_rmsse
      } else {
        # NOTE(review): assumes data_1 and data_1_rmsse share identical grouping and row order
        # because they are built from the same grouped 'data' -- confirm if grouping changes.
        data_1$rmsse <- data_1_rmsse$rmsse
      }
    }
    #--------------------------------------------------------------------------
    # Compute error metric by horizon and window length across all validation windows.
    data_2 <- data_1 %>%
      dplyr::group_by_at(dplyr::vars(.data$model, .data$model_forecast_horizon)) %>%
      dplyr::mutate("window_start" = min(.data$window_start, na.rm = TRUE),
                    "window_stop" = max(.data$window_stop, na.rm = TRUE)) %>%
      dplyr::group_by_at(dplyr::vars(.data$model, .data$model_forecast_horizon, !!groups, .data$window_start, .data$window_stop)) %>%
      dplyr::summarize_at(dplyr::vars(!!!rlang::syms(metrics)), aggregate, na.rm = TRUE)
    #--------------------------------------------------------------------------
    # Compute error metric by model.
    data_3 <- data_1 %>%
      dplyr::group_by_at(dplyr::vars(.data$model)) %>%
      dplyr::mutate("window_start" = min(.data$window_start, na.rm = TRUE),
                    # Bug fix: 'window_stop' was previously computed as max(.data$window_start),
                    # which mis-reported the global window range; use 'window_stop' as in data_2.
                    "window_stop" = max(.data$window_stop, na.rm = TRUE)) %>%
      dplyr::group_by_at(dplyr::vars(.data$model, !!groups, .data$window_start, .data$window_stop)) %>%
      dplyr::summarize_at(dplyr::vars(!!!rlang::syms(metrics)), aggregate, na.rm = TRUE)
    #--------------------------------------------------------------------------
  } else if (!is_forecastML) {  # Error metrics for the forecast_results class which has no validation windows and slightly different grouping columns.
    #--------------------------------------------------------------------------
    # Compute error metrics at the validation window level.
    if (!all(metrics %in% "rmsse")) {

      data_1 <- data %>%
        dplyr::group_by_at(dplyr::vars(.data$model, .data$model_forecast_horizon, .data$window_number, !!groups)) %>%
        dplyr::summarize_at(dplyr::vars(1),  # 1 is a col position that gets the fun to run; args x, y, and z defined below.
                            .funs = error_functions,
                            x = rlang::quo(.data$residual),
                            y = rlang::sym(outcome_name),
                            z = rlang::sym(paste0(outcome_name, "_pred")))
    }

    if (any(metrics %in% "rmsse")) {

      data_1_rmsse <- data %>%
        dplyr::group_by_at(dplyr::vars(.data$model, .data$model_forecast_horizon, .data$window_number, !!groups)) %>%
        dplyr::summarize("h" = sum(!is.na(.data$residual)),
                         "sse_num" = sum(.data$residual^2, na.rm = TRUE),
                         "mse_denom" = .data$mse_denom[1])

      data_1_rmsse$rmsse <- with(data_1_rmsse, sqrt((1 / h) * (sse_num / mse_denom)))

      data_1_rmsse$h <- NULL
      data_1_rmsse$sse_num <- NULL
      data_1_rmsse$mse_denom <- NULL

      if (all(metrics %in% "rmsse")) {
        data_1 <- data_1_rmsse
      } else {
        data_1$rmsse <- data_1_rmsse$rmsse
      }
    }
    #--------------------------------------------------------------------------
    # Compute error metric by horizon and window across all validation windows.
    data_2 <- data_1 %>%
      dplyr::ungroup() %>%
      dplyr::group_by_at(dplyr::vars(.data$model, .data$model_forecast_horizon, !!groups)) %>%
      dplyr::summarize_at(dplyr::vars(!!!rlang::syms(metrics)), aggregate, na.rm = TRUE)
    #--------------------------------------------------------------------------
    # Compute error metric by model.
    data_3 <- data_1 %>%
      dplyr::ungroup() %>%
      dplyr::group_by_at(dplyr::vars(.data$model, !!groups)) %>%
      dplyr::summarize_at(dplyr::vars(!!!rlang::syms(metrics)), aggregate, na.rm = TRUE)
    #--------------------------------------------------------------------------
  } else {  # Final forecasts from combine_forecasts().
    #--------------------------------------------------------------------------
    # There are no validation windows to compute error metrics for.
    data_1 <- data.frame()
    #--------------------------------------------------------------------------
    # Compute error metric by model and horizon across all validation windows.
    if (!all(metrics %in% "rmsse")) {

      if (type == "horizon") {

        data_2 <- data %>%
          dplyr::group_by_at(dplyr::vars(.data$model, .data$horizon, !!groups)) %>%
          dplyr::summarize_at(dplyr::vars(1),  # 1 is a col position that gets the fun to run; args x, y, and z defined below.
                              .funs = error_functions,
                              x = rlang::quo(.data$residual),
                              y = rlang::sym(outcome_name),
                              z = rlang::sym(paste0(outcome_name, "_pred")))
      }

      if (any(metrics %in% "rmsse")) {

        data_2_rmsse <- data %>%
          dplyr::group_by_at(dplyr::vars(.data$model, .data$horizon, !!groups)) %>%
          dplyr::summarize("h" = sum(!is.na(.data$residual)),
                           "sse_num" = sum(.data$residual^2, na.rm = TRUE),
                           "mse_denom" = .data$mse_denom[1])

        data_2_rmsse$rmsse <- with(data_2_rmsse, sqrt((1 / h) * (sse_num / mse_denom)))

        data_2_rmsse$h <- NULL
        data_2_rmsse$sse_num <- NULL
        data_2_rmsse$mse_denom <- NULL

        if (all(metrics %in% "rmsse")) {
          data_2 <- data_2_rmsse
        } else {
          data_2$rmsse <- data_2_rmsse$rmsse
        }
      }
    }
    #--------------------------------------------------------------------------
    # Compute error metric by model.
    if (type == "horizon") {

      # Bug fix: this inner condition previously duplicated the enclosing
      # 'type == "horizon"' check; it now mirrors the '!all(metrics %in% "rmsse")'
      # guard used for data_2 above so the generic summarize is skipped when
      # only 'rmsse' was requested.
      if (!all(metrics %in% "rmsse")) {

        data_3 <- data %>%
          dplyr::group_by_at(dplyr::vars(.data$model, !!groups)) %>%
          dplyr::summarize_at(dplyr::vars(1),  # 1 is a col position that gets the fun to run; args x, y, and z defined below.
                              .funs = error_functions,
                              x = rlang::quo(.data$residual),
                              y = rlang::sym(outcome_name),
                              z = rlang::sym(paste0(outcome_name, "_pred")))
      }

      if (any(metrics %in% "rmsse")) {

        data_3_rmsse <- data %>%
          dplyr::group_by_at(dplyr::vars(.data$model, !!groups)) %>%
          dplyr::summarize("h" = sum(!is.na(.data$residual)),
                           "sse_num" = sum(.data$residual^2, na.rm = TRUE),
                           "mse_denom" = .data$mse_denom[1])

        data_3_rmsse$rmsse <- with(data_3_rmsse, sqrt((1 / h) * (sse_num / mse_denom)))

        data_3_rmsse$h <- NULL
        data_3_rmsse$sse_num <- NULL
        data_3_rmsse$mse_denom <- NULL

        if (all(metrics %in% "rmsse")) {
          data_3 <- data_3_rmsse
        } else {
          data_3$rmsse <- data_3_rmsse$rmsse
        }
      }
    }
    #--------------------------------------------------------------------------
  }  # End error metrics for results from combine_forecasts().
  #----------------------------------------------------------------------------
  data_out <- list("error_by_window" = data_1, "error_by_horizon" = data_2, "error_global" = data_3)

  data_out[] <- lapply(data_out, as.data.frame)  # Remove the tibble class.

  # Class the return value so the appropriate plot() S3 method is dispatched.
  if (methods::is(data_results, "training_results")) {  # Validation error.
    class(data_out) <- c("validation_error", class(data_out))
  } else {  # Forecast error.
    if (!is_forecastML) {
      class(data_out) <- c("forecast_error", class(data_out))
    } else {
      class(data_out) <- c("forecastML_error", class(data_out))
    }
  }

  # Carry metadata forward for the plot() methods.
  attr(data_out, "error_metrics") <- metrics
  attr(data_out, "outcome_name") <- outcome_name
  attr(data_out, "outcome_levels") <- outcome_levels
  attr(data_out, "method") <- method
  attr(data_out, "groups") <- groups

  return(data_out)
}
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
#' Plot validation dataset forecast error
#'
#' Plot forecast error at various levels of aggregation across validation datasets.
#' @param x An object of class 'validation_error' from \code{return_error()}.
#' @param type Select plot type; \code{type = "window"} is the default plot.
#' @param metric Select error metric to plot (e.g., "mae"); \code{attributes(x)$error_metrics[1]} is the default metric.
#' @param models Optional. A vector of user-defined model names from \code{train_model()} to filter results.
#' @param horizons Optional. A numeric vector to filter results by horizon.
#' @param windows Optional. A numeric vector to filter results by validation window number.
#' @param group_filter A string for filtering plot results for grouped time series (e.g., \code{"group_col_1 == 'A'"}).
#' @param facet Optional. A formula with any combination of \code{horizon}, \code{model}, or \code{group} (for grouped time series).
#' passed to \code{ggplot2::facet_grid()} internally (e.g., \code{horizon ~ model}, \code{horizon + model ~ .}, \code{~ horizon + group}).
#' Can be \code{NULL}. The default faceting is set internally depending on the plot \code{type}.
#' @param ... Not used.
#' @return Forecast error plots of class 'ggplot'.
#' @export
plot.validation_error <- function(x, type = c("window", "horizon", "global"), metric = NULL,
                                  facet = NULL, models = NULL, horizons = NULL, windows = NULL, group_filter = NULL, ...) { # nocov start
  #----------------------------------------------------------------------------
  data_error <- x
  type <- type[1]  # Default to the "window" plot.

  # Metadata attached by return_error().
  method <- attributes(data_error)$method
  groups <- attributes(data_error)$groups

  error_metrics <- metric

  if (is.null(error_metrics)) {
    error_metrics <- attributes(data_error)$error_metrics[1]
  }

  if (!error_metrics %in% attributes(data_error)$error_metrics) {
    stop("The error metric in 'metric' is not in the validation error dataset 'x'; re-run return_error() with this metric.")
  }
  #----------------------------------------------------------------------------
  # Set default plot facets for each plot type.
  if (is.null(facet)) {
    if (type == "window") {
      facet <- horizon ~ model
    } else if (type == "horizon") {
      facet <- horizon ~ model
    } else if (type == "global") {
      facet <- horizon ~ model
    }
  }

  facets <- forecastML_facet_plot(facet, groups)  # Function in zzz.R.
  facet <- facets[[1]]
  facet_names <- facets[[2]]

  # Pull the error data.frame matching the requested aggregation level. Columns that
  # are absent at a given level are filled with TRUE so the shared filtering below works.
  if (type == "window") {

    data_plot <- data_error$error_by_window

    if (method == "direct") {
      data_plot$horizon <- data_plot$model_forecast_horizon  # Added for faceting.
    }

  } else if (type == "horizon") {

    data_plot <- data_error$error_by_horizon

    if (method == "direct") {
      data_plot$horizon <- data_plot$model_forecast_horizon  # Added for faceting.
    }

    data_plot$window_number <- TRUE  # Not in data, added for filtering.

  } else if (type == "global") {

    data_plot <- data_error$error_global
    data_plot$model_forecast_horizon <- TRUE  # Not in data, added for filtering.
    data_plot$window_number <- TRUE  # Not in data, added for filtering.
    data_plot$horizon <- TRUE  # Not in data, added for filtering.
  }
  #----------------------------------------------------------------------------
  # Filter the datasets based on user input.
  models <- if (is.null(models)) {unique(data_plot$model)} else {models}

  if (is.null(horizons)) {
    horizons <- unique(data_plot$horizon)
  }

  windows <- if (is.null(windows)) {unique(data_plot$window_number)} else {windows}

  data_plot <- data_plot[data_plot$horizon %in% horizons & data_plot$model %in% models & data_plot$window_number %in% windows, ]
  #----------------------------------------------------------------------------
  # User filtering to display select results in grouped time series.
  if (!is.null(group_filter)) {
    data_plot <- dplyr::filter(data_plot, eval(parse(text = group_filter)))
  }
  #----------------------------------------------------------------------------
  # Melt the data for plotting: one row per error metric value.
  data_plot <- tidyr::gather(data_plot, "error_metric", "value", -!!names(data_plot)[!names(data_plot) %in% error_metrics])
  #----------------------------------------------------------------------------
  # ggplot colors and facets are complementary: all facets, same color; all colors, no facet.
  ggplot_color <- c(c("model", "horizon", groups)[!c("model", "horizon", groups) %in% facet_names])
  #----------------------------------------------------------------------------
  if (is.null(groups)) {
    data_plot <- dplyr::arrange(data_plot, .data$model, .data$horizon, .data$window_number)
  } else {
    data_plot <- dplyr::arrange(data_plot, .data$model, .data$horizon, .data$window_number, !!!rlang::syms(groups))
  }

  if (type == "global") {
    data_plot$horizon <- "All"
    data_plot$horizon <- ordered(data_plot$horizon)
  }

  # Ordered factors coerce the viridis scale to follow the order of appearance in the sorted data.
  data_plot$horizon <- factor(data_plot$horizon, levels = unique(data_plot$horizon), ordered = TRUE)

  data_plot[, groups] <- lapply(seq_along(data_plot[, groups, drop = FALSE]), function(i) {
    factor(data_plot[, groups[i]], levels = unique(data_plot[, groups[i]]), ordered = TRUE)
  })

  data_plot$ggplot_color <- apply(data_plot[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
  data_plot$ggplot_color <- factor(data_plot$ggplot_color, levels = unique(data_plot$ggplot_color), ordered = TRUE)

  # Give error a name in the legend if plot is faceted by model and horizon (and group if groups are given).
  if (length(ggplot_color) == 0) {
    data_plot$ggplot_color <- "Error"
    data_plot$ggplot_color <- ordered(data_plot$ggplot_color)
  }

  # Title-case the coloring column names for the axis/legend title, e.g. "Model + Horizon".
  temp_1 <- unlist(Map(function(x) {toupper(substr(x[1], 1, 1))}, ggplot_color))
  temp_2 <- unlist(Map(function(x) {substr(x, 2, nchar(x))}, ggplot_color))
  x_axis_title <- paste(temp_1, temp_2, sep = "")
  x_axis_title <- paste(x_axis_title, collapse = " + ")
  #----------------------------------------------------------------------------
  # Create plots.
  if (type == "window") {

    p <- ggplot()

    # Only connect points with lines when there is more than one validation window.
    if (length(unique(data_plot$window_midpoint)) != 1) {
      p <- p + geom_line(data = data_plot, aes(x = .data$window_midpoint,
                                               y = .data$value,
                                               color = .data$ggplot_color,
                                               group = .data$ggplot_color
                                               ), size = 1.05)
    }

    p <- p + geom_point(data = data_plot, aes(x = .data$window_midpoint,
                                              y = .data$value,
                                              color = .data$ggplot_color,
                                              group = .data$ggplot_color
                                              ))
    p <- p + scale_color_viridis_d()
    p <- p + facet_grid(facet, scales = "free")
    p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
    p <- p + xlab("Dataset index") + ylab(paste0("Forecast error metric (", error_metrics, ")")) + labs(color = x_axis_title) +
      ggtitle("Forecast Error by Validation Window")

    return(suppressWarnings(p))
  }

  if (type == "horizon") {

    p <- ggplot()
    p <- p + geom_col(data = data_plot,
                      aes(x = .data$ggplot_color,
                          y = .data$value,
                          fill = .data$ggplot_color,
                          group = .data$ggplot_color))
    p <- p + facet_grid(facet, scales = "free_y")
    # Bug fix: geom_col() above maps the 'fill' aesthetic, so the viridis scale must be a
    # fill scale; this branch previously called scale_color_viridis_d(), leaving the bars
    # on the default palette (inconsistent with the "global" plot below).
    p <- p + scale_fill_viridis_d()
    p <- p + theme_bw() + theme(
      panel.spacing = unit(0, "lines"),
      axis.text.x = element_blank()
    )
    p <- p + xlab(x_axis_title) + ylab(paste0("Forecast error metric (", error_metrics, ")")) + labs(fill = x_axis_title, alpha = NULL) +
      ggtitle("Forecast Error by Forecast Horizon")

    return(suppressWarnings(p))
  }

  if (type == "global") {

    p <- ggplot()
    p <- p + geom_bar(data = data_plot,
                      aes(x = .data$ggplot_color,
                          y = .data$value,
                          fill = .data$ggplot_color),
                      stat = "identity", position = position_dodge(width = 1))
    p <- p + scale_fill_viridis_d()
    p <- p + facet_grid(facet, scales = "free")
    p <- p + theme_bw() + theme(
      panel.spacing = unit(0, "lines"),
      axis.text.x = element_blank()
    )
    p <- p + xlab(x_axis_title) + ylab(paste0("Forecast error metric (", error_metrics, ")")) + labs(fill = x_axis_title, alpha = NULL) +
      ggtitle("Forecast Error Across Validation Windows and Horizons")

    return(suppressWarnings(p))
  }
} # nocov end
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
#' Plot forecast error
#'
#' Plot forecast error at various levels of aggregation.
#' @param x An object of class 'forecast_error' from \code{return_error()}.
#' @param type Select plot type; \code{type = "global"} is the default plot.
#' @param metric Select error metric to plot (e.g., "mae"); \code{attributes(x)$error_metrics[1]} is the default metric.
#' @param models Optional. A vector of user-defined model names from \code{train_model()} to filter results.
#' @param horizons Optional. A numeric vector to filter results by horizon.
#' @param windows Optional. A numeric vector to filter results by validation window number.
#' @param group_filter A string for filtering plot results for grouped time series (e.g., \code{"group_col_1 == 'A'"}).
#' @param facet Optional. A formula with any combination of \code{horizon}, \code{model}, or \code{group} (for grouped time series).
#' passed to \code{ggplot2::facet_grid()} internally (e.g., \code{horizon ~ model}, \code{horizon + model ~ .}, \code{~ horizon + group}).
#' Can be \code{NULL}. The default faceting is set internally depending on the plot \code{type}.
#' @param ... Not used.
#' @return Forecast error plots of class 'ggplot'.
#' @export
plot.forecast_error <- function(x, type = c("global"), metric = NULL,
facet = NULL, models = NULL, horizons = NULL, windows = NULL, group_filter = NULL, ...) { # nocov start
#----------------------------------------------------------------------------
data_error <- x
# Only "global" is offered as a default, but "window" and "horizon" are handled below.
type <- type[1]
# Metadata attached by return_error().
method <- attributes(data_error)$method
groups <- attributes(data_error)$groups
error_metrics <- metric
# Default to the first error metric computed by return_error().
if (is.null(error_metrics)) {
error_metrics <- attributes(data_error)$error_metrics[1]
}
if (!error_metrics %in% attributes(data_error)$error_metrics) {
stop("The error metric in 'metric' is not in the forecast error dataset 'x'; re-run return_error() with this metric.")
}
#----------------------------------------------------------------------------
# Set default plot facets for each plot type.
if (is.null(facet)) {
if (type == "window") {
facet <- horizon ~ model
} else if (type == "horizon") {
facet <- horizon ~ model
} else if (type == "global") {
facet <- horizon ~ model
}
}
# Resolve the facet formula into a ggplot facet spec and the facet column names.
facets <- forecastML_facet_plot(facet, groups) # Function in zzz.R.
facet <- facets[[1]]
facet_names <- facets[[2]]
# Pull the error data.frame matching the requested aggregation level. Columns absent at a
# given level are filled with TRUE so the shared filtering code below works unchanged.
if (type == "window") {
data_plot <- data_error$error_by_window
if (method == "direct") {
data_plot$horizon <- data_plot$model_forecast_horizon # Added for faceting.
}
} else if (type == "horizon") {
data_plot <- data_error$error_by_horizon
if (method == "direct") {
data_plot$horizon <- data_plot$model_forecast_horizon # Added for faceting.
}
data_plot$window_number <- TRUE # Not in data, added for filtering.
} else if (type == "global") {
data_plot <- data_error$error_global
data_plot$model_forecast_horizon <- TRUE # Not in data, added for filtering.
data_plot$window_number <- TRUE # Not in data, added for filtering.
data_plot$horizon <- TRUE # Not in data, added for filtering.
}
#----------------------------------------------------------------------------
# Filter the datasets based on user input.
models <- if (is.null(models)) {unique(data_plot$model)} else {models}
if (is.null(horizons)) {
horizons <- unique(data_plot$model_forecast_horizon)
}
windows <- if (is.null(windows)) {unique(data_plot$window_number)} else {windows}
data_plot <- data_plot[data_plot$model_forecast_horizon %in% horizons & data_plot$model %in% models & data_plot$window_number %in% windows, ]
#----------------------------------------------------------------------------
# User filtering to display select results in grouped time series.
if (!is.null(group_filter)) {
data_plot <- dplyr::filter(data_plot, eval(parse(text = group_filter)))
}
#----------------------------------------------------------------------------
# Melt the data for plotting: one row per error metric value.
data_plot <- tidyr::gather(data_plot, "error_metric", "value", -!!names(data_plot)[!names(data_plot) %in% error_metrics])
#----------------------------------------------------------------------------
# ggplot colors and facets are complementary: all facets, same color; all colors, no facet.
ggplot_color <- c(c("model", "horizon", groups)[!c("model", "horizon", groups) %in% facet_names])
#----------------------------------------------------------------------------
# Sort so that ordered-factor levels below follow a stable order of appearance.
if (is.null(groups)) {
data_plot <- dplyr::arrange(data_plot, .data$model, .data$window_number)
} else {
data_plot <- dplyr::arrange(data_plot, .data$model, .data$window_number, !!!rlang::syms(groups))
}
data_plot$model_forecast_horizon <- factor(data_plot$model_forecast_horizon, levels = unique(data_plot$model_forecast_horizon), ordered = TRUE)
data_plot[, groups] <- lapply(seq_along(data_plot[, groups, drop = FALSE]), function(i) {
factor(data_plot[, groups[i]], levels = unique(data_plot[, groups[i]]), ordered = TRUE)
})
# Compound legend key, e.g. "model-horizon-group".
data_plot$ggplot_color <- apply(data_plot[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
# Give error a name in the legend if plot is faceted by model and horizon (and group if groups are given).
if (length(ggplot_color) == 0) {
data_plot$ggplot_color <- "Error"
}
# Used to avoid lines spanning any gaps between validation windows.
if (all(data_plot$window_number == 1)) { # One window; no need to add the window number to the legend.
data_plot$ggplot_group <- apply(data_plot[, ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
} else {
data_plot$ggplot_group <- apply(data_plot[, c("window_number", ggplot_color), drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
}
# Coerce to viridis color scale with an ordered factor. With the data.frame sorted, unique() pulls the levels in their order of appearance.
data_plot$ggplot_color <- factor(data_plot$ggplot_color, levels = unique(data_plot$ggplot_color), ordered = TRUE)
data_plot$ggplot_group <- factor(data_plot$ggplot_group, levels = unique(data_plot$ggplot_group), ordered = TRUE)
# Title-case the coloring column names for the axis/legend title, e.g. "Model + Horizon".
temp_1 <- unlist(Map(function(x) {toupper(substr(x[1], 1, 1))}, ggplot_color))
temp_2 <- unlist(Map(function(x) {substr(x, 2, nchar(x))}, ggplot_color))
x_axis_title <- paste(temp_1, temp_2, sep = "")
x_axis_title <- paste(x_axis_title, collapse = " + ")
# For the global plot, all horizons are collapsed into a single facet label.
if (type == "global") {
data_plot$horizon <- "All"
}
# All three plot types share one dodged-bar chart; only the title differs below.
p <- ggplot()
p <- p + geom_bar(data = data_plot,
aes(x = .data$ggplot_color,
y = .data$value,
fill = .data$ggplot_color,
group = .data$ggplot_group),
stat = "identity", position = position_dodge(width = 1))
p <- p + scale_fill_viridis_d()
p <- p + facet_grid(facet, scales = "free")
p <- p + theme_bw() + theme(
panel.spacing = unit(0, "lines"),
axis.text.x = element_blank()
)
p <- p + xlab(x_axis_title) + ylab(paste0("Forecast error metric (", error_metrics, ")")) + labs(fill = x_axis_title, alpha = NULL)
if (type %in% c("window")) {
p <- p + ggtitle("Forecast Error by Validation Window, Model Forecast Horizon, & Model")
} else if (type %in% c("horizon")) {
p <- p + ggtitle("Forecast Error by Model Forecast Horizon & Model")
} else if (type %in% c("global")) {
p <- p + ggtitle("Forecast Error by Model")
}
return(suppressWarnings(p))
} # nocov end
# End of file.