R/train_model.R

Defines functions plot.forecast_results plot.training_results predict.forecast_model train_model

Documented in plot.forecast_results plot.training_results predict.forecast_model train_model

#' Train a model across horizons and validation datasets
#'
#' Train a user-defined forecast model for each horizon, 'h', and across the validation
#' datasets, 'd'. If \code{method = "direct"}, a total of 'h' * 'd' models are trained.
#' If \code{method = "multi_output"}, a total of 1 * 'd' models are trained.
#' These models can be trained in parallel with the \code{future} package. If a validation
#' window contains none of the rows in \code{lagged_df}, the model for that window is trained
#' on all rows (there is nothing to hold out).
#'
#' @param lagged_df An object of class 'lagged_df' from \code{\link{create_lagged_df}}.
#' @param windows An object of class 'windows' from \code{\link{create_windows}}.
#' @param model_name A name for the model.
#' @param model_function A user-defined wrapper function for model training that takes the following
#' arguments: (1) a horizon-specific data.frame made with \code{create_lagged_df(..., type = "train")}
#' (i.e., the dataset(s) stored in \code{lagged_df}) and, optionally, (2) any number of additional named arguments
#' which can be passed in \code{...} in this function.
#' @param ... Optional. Named arguments passed into the user-defined \code{model_function}. They are
#' evaluated once, in the environment \code{train_model()} is called from, before any models are trained.
#' @param use_future Boolean. If \code{TRUE}, the \code{future} package is used for training models in parallel.
#' The models will train in parallel across either (a) model forecast horizons or (b) validation windows,
#' whichever is longer (i.e., \code{length(create_lagged_df())} or \code{nrow(create_windows())}). The user
#' should run \code{future::plan(future::multisession)} or similar prior to this function to train these models
#' in parallel.
#' @return An S3 object of class 'forecast_model': A nested list of trained models. Models can be accessed with
#' \code{my_trained_model$horizon_h$window_w$model} where 'h' gives the forecast horizon and 'w' gives
#' the validation dataset window number from \code{create_windows()}.
#'
#' @section Methods and related functions:
#'
#' The output of \code{train_model} can be passed into
#'
#' \itemize{
#'   \item \code{\link{return_error}}
#'   \item \code{\link{return_hyper}}
#' }
#'
#' and has the following generic S3 methods
#'
#' \itemize{
#'   \item \code{\link[=predict.forecast_model]{predict}}
#'   \item \code{\link[=plot.training_results]{plot}} (from \code{predict.forecast_model(data = create_lagged_df(..., type = "train"))})
#'   \item \code{\link[=plot.forecast_results]{plot}} (from \code{predict.forecast_model(data = create_lagged_df(..., type = "forecast"))})
#' }
#' @example /R/examples/example_train_model.R
#' @export
train_model <- function(lagged_df, windows, model_name, model_function, ..., use_future = FALSE) {

  if (missing(lagged_df) || !methods::is(lagged_df, "lagged_df")) {
    stop("The 'lagged_df' argument takes an object of class 'lagged_df' as input. Run create_lagged_df() first.")
  }

  if (missing(windows) || !methods::is(windows, "windows")) {
    stop("The 'windows' argument takes an object of class 'windows' as input. Run create_windows() first.")
  }

  if (missing(model_name)) {
    stop("Enter a model name for the 'model_name' argument.")
  }

  if (missing(model_function) || !is.function(model_function)) {
    stop("The 'model_function' argument takes a user-defined model training function as input.")
  }

  row_indices <- attributes(lagged_df)$row_indices
  date_indices <- attributes(lagged_df)$date_indices
  horizons <- attributes(lagged_df)$horizons
  data_start <- attributes(lagged_df)$data_start
  data_stop <- attributes(lagged_df)$data_stop
  skeleton <- attributes(lagged_df)$skeleton

  # Optional arguments for the user-defined modeling function, passed in train_model() with `...`. They let
  # the user avoid re-defining the modeling function in repeated calls to train_model() and are forwarded as
  # a named list in do.call().
  # list(...) forces every argument in the environment train_model() was called from. This means
  # (a) objects local to a calling function are found, not only globals, and (b) the evaluated values,
  # not unevaluated symbols, are what future_lapply() ships to workers when use_future = TRUE.
  n_args <- ...length()
  if (n_args > 0) {
    model_function_args <- list(...)
  }
  #----------------------------------------------------------------------------
  # The default future behavior is to parallelize the model training over the longer dimension: (a) number of
  # forecast horizons or (b) number of validation windows. This is due to a current limitation
  # in the future package on changing object size limitations for nested futures where
  # "options(globals.maxSize.default = Inf)" isn't recognized.
  if (isTRUE(use_future)) {

    if (length(horizons) >= nrow(windows)) {

      lapply_across_horizons <- future.apply::future_lapply
      lapply_across_val_windows <- base::lapply

    } else {

      lapply_across_horizons <- base::lapply
      lapply_across_val_windows <- future.apply::future_lapply
    }

  } else {

    lapply_across_horizons <- base::lapply
    lapply_across_val_windows <- base::lapply
  }
  #----------------------------------------------------------------------------
  # Seq along model forecast horizon > cross-validation windows.
  # The 'future.seed' formals absorb the future.seed = 1 argument when base::lapply() is used,
  # because base::lapply() forwards it to FUN rather than consuming it like future_lapply().
  data_out <- lapply_across_horizons(lagged_df, function(data, future.seed, ...) {  # model forecast horizon.

    model_plus_valid_data <- lapply_across_val_windows(seq_len(nrow(windows)), function(i, future.seed, ...) {  # validation windows within model forecast horizon.

      window_length <- windows[i, "window_length"]

      if (is.null(date_indices)) {

        valid_indices <- windows[i, "start"]:windows[i, "stop"]
        valid_indices_date <- NULL

      } else {

        # When create_lagged_df(..., groups = NULL, keep_rows = FALSE), the validation indices need an offset to account for the fact that
        # validation windows are selected where row_indices %in% valid_indices which maps back to the input dataset to create_lagged_df()
        # which will have 1:max(lookback) more rows than the dataset that comes out of create_lagged_df(..., groups = NULL, keep_rows = FALSE).
        valid_indices <- min(row_indices) - 1 + which(date_indices >= windows[i, "start"] & date_indices <= windows[i, "stop"])
        valid_indices_date <- date_indices[date_indices >= windows[i, "start"] & date_indices <= windows[i, "stop"]]
      }

      # A window length of 0 that spans the entire dataset removes the nested cross-validation and
      # trains on all input data in lagged_df.
      if (window_length == 0 && (windows[i, "start"] == data_start) && (windows[i, "stop"] == data_stop)) {

        # Model training over all data.
        if (n_args == 0) {  # No user-defined model args passed in ...

          model <- try(model_function(data))

        } else {

          model <- try(do.call(model_function, append(list(data), model_function_args)))
        }

      } else {  # Model training with external block-contiguous cv.

        # These validation indices will always start at 1. They index with respect to the
        # data passed into model_function() as opposed to the dataset passed in create_lagged_df().
        # These indices, stored as an attribute, are for manual filtering any skeleton lagged_dfs in model_function().
        validation_indices <- which(row_indices %in% valid_indices)
        attributes(data)$validation_indices <- validation_indices

        # Negative indexing with a zero-length vector, data[-integer(0), ], selects zero rows rather
        # than all rows. A window that overlaps none of this dataset's rows therefore has nothing to
        # hold out, so train on the full dataset instead of an empty one.
        if (length(validation_indices) > 0) {
          data_train <- data[-validation_indices, , drop = FALSE]
        } else {
          data_train <- data
        }

        if (n_args == 0) {  # No user-defined model args passed in ...

          model <- try(model_function(data_train))

        } else {

          model <- try(do.call(model_function, append(list(data_train), model_function_args)))
        }
      }

      if (methods::is(model, "try-error")) {
        warning(paste0("A model returned class 'try-error' for validation window ", i))
      }

      list("model" = model, "window" = i, "window_length" = window_length, "valid_indices" = valid_indices,
           "date_indices" = valid_indices_date)
    }, future.seed = 1)  # End model training across nested cross-validation windows for the horizon in "data".

    names(model_plus_valid_data) <- paste0("window_", seq_len(nrow(windows)))
    attr(model_plus_valid_data, "horizon") <- attributes(data)$horizon
    model_plus_valid_data
  }, future.seed = 1)  # End training across horizons.

  attr(data_out, "model_name") <- model_name
  attr(data_out, "horizons") <- horizons
  attr(data_out, "outcome_col") <- attributes(lagged_df)$outcome_col
  attr(data_out, "outcome_cols") <- attributes(lagged_df)$outcome_cols
  attr(data_out, "outcome_name") <- attributes(lagged_df)$outcome_name
  attr(data_out, "outcome_names") <- attributes(lagged_df)$outcome_names
  attr(data_out, "outcome_levels") <- attributes(lagged_df)$outcome_levels
  attr(data_out, "row_indices") <- row_indices
  attr(data_out, "date_indices") <- date_indices
  attr(data_out, "frequency") <- attributes(lagged_df)$frequency
  attr(data_out, "data_stop") <- attributes(lagged_df)$data_stop
  attr(data_out, "groups") <- attributes(lagged_df)$groups
  attr(data_out, "method") <- attributes(lagged_df)$method
  attr(data_out, "skeleton") <- skeleton


  class(data_out) <- c("forecast_model", class(data_out))

  return(data_out)
}
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------

#' Predict on validation datasets or forecast
#'
#' Predict with a 'forecast_model' object from \code{train_model()}. If \code{data = create_lagged_df(..., type = "train")},
#' predictions are returned for the outer-loop nested cross-validation datasets.
#' If \code{data} is an object of class 'lagged_df' from \code{create_lagged_df(..., type = "forecast")},
#' predictions are returned for the horizons specified in \code{create_lagged_df(horizons = ...)}.
#'
#' @param ... One or more trained models from \code{train_model()}.
#' @param prediction_function A list of user-defined prediction functions with length equal to
#' the number of models supplied in \code{...}. The prediction functions
#' take 2 required positional arguments--(1) a 'forecast_model' object from \code{train_model()} and (2) a
#' data.frame of model features from \code{create_lagged_df()}. For numeric outcomes and \code{method = "direct"}, the function should \code{return()}
#' a 1- or 3-column data.frame of model predictions. If the prediction function returns a 1-column data.frame, point forecasts are assumed.
#' If the prediction function returns a 3-column data.frame, lower and upper forecast bounds are assumed (the
#' order and names of the 3 columns do not matter). For factor outcomes and \code{method = "direct"}, the function should \code{return()}
#' (1) a 1-column data.frame of the model-predicted factor level or (2) an L-column data.frame of class probabilities where
#' 'L' equals the number of levels in the outcome; columns should be ordered, from left to right, the same as
#' \code{levels(data$outcome)} which is the default behavior for most \code{predict(..., type = "prob")} functions.
#' Column names do not matter. For numeric outcomes and \code{method = "multi_output"}, the function should \code{return()} an
#' h-column data.frame of model predictions--1 column for each horizon. Forecast intervals and factor outcomes are not currently
#' supported with \code{method = "multi_output"}. If a prediction function throws an error, \code{predict} stops with
#' an error that names the model, horizon-specific model, and validation window and includes the original error message.
#' @param data If \code{data} is a training dataset from \code{create_lagged_df(..., type = "train")}, validation dataset
#' predictions are returned; else, if \code{data} is a forecasting dataset from \code{create_lagged_df(..., type = "forecast")},
#' forecasts from horizons 1:h are returned.
#' @return If \code{data = create_lagged_df(..., type = "train")}, an S3 object of class 'training_results'. If
#' \code{data = create_lagged_df(..., type = "forecast")}, an S3 object of class 'forecast_results'.
#'
#'   \strong{Columns in returned 'training_results' data.frame:}
#'     \itemize{
#'       \item \code{model}: User-supplied model name in \code{train_model()}.
#'       \item \code{model_forecast_horizon}: The direct-forecasting time horizon that the model was trained on.
#'       \item \code{window_length}: Validation window length measured in dataset rows.
#'       \item \code{window_number}: Validation dataset number.
#'       \item \code{valid_indices}: Validation dataset row names from \code{attributes(create_lagged_df())$row_indices}.
#'       \item \code{date_indices}: If given and \code{method = "direct"}, validation dataset date indices from \code{attributes(create_lagged_df())$date_indices}.
#'       If given and \code{method = "multi_output"}, date_indices represents the date of the forecast.
#'       \item \code{"groups"}: If given, the user-supplied groups in \code{create_lagged_df()}.
#'       \item \code{"outcome_name"}: The target being forecasted.
#'       \item \code{"outcome_name"_pred}: The model predictions.
#'       \item \code{"outcome_name"_pred_lower}: If given, the lower prediction bounds returned by the user-supplied prediction function.
#'       \item \code{"outcome_name"_pred_upper}: If given, the upper prediction bounds returned by the user-supplied prediction function.
#'       \item \code{forecast_indices}: If \code{method = "multi_output"}, the validation index of the h-step-ahead forecast.
#'       \item \code{forecast_date_indices}: If \code{method = "multi_output"}, the validation date index of the h-step-ahead forecast.
#'    }
#'
#'    \strong{Columns in returned 'forecast_results' data.frame:}
#'     \itemize{
#'       \item \code{model}: User-supplied model name in \code{train_model()}.
#'       \item \code{model_forecast_horizon}: If \code{method = "direct"}, the direct-forecasting time horizon that the model was trained on.
#'       \item \code{horizon}: Forecast horizons, 1:h, measured in dataset rows.
#'       \item \code{window_length}: Validation window length measured in dataset rows.
#'       \item \code{forecast_period}: The forecast period in row indices or dates. The forecast period starts at either \code{attributes(create_lagged_df())$data_stop + 1} for row indices or \code{attributes(create_lagged_df())$data_stop + 1 * frequency} for date indices.
#'       \item \code{"groups"}: If given, the user-supplied groups in \code{create_lagged_df()}.
#'       \item \code{"outcome_name"}: The target being forecasted.
#'       \item \code{"outcome_name"_pred}: The model forecasts.
#'       \item \code{"outcome_name"_pred_lower}: If given, the lower forecast bounds returned by the user-supplied prediction function.
#'       \item \code{"outcome_name"_pred_upper}: If given, the upper forecast bounds returned by the user-supplied prediction function.
#'    }
#'
#' @example /R/examples/example_predict_train_model.R
#' @export
predict.forecast_model <- function(..., prediction_function = list(NULL), data) {

  #----------------------------------------------------------------------------
  model_list <- list(...)

  if (!all(unlist(lapply(model_list, function(x) {class(x)[1]})) %in% "forecast_model")) {
    stop("The '...' argument takes 1 or more objects of class 'forecast_model' as input. Run train_model() first.
         Also, the arguments 'prediction_function' and 'data' need to be named and not positional because they
         follow '...'.")
  }

  if (length(model_list) != length(prediction_function)) {
    stop("The number of prediction functions does not equal the number of forecast models.")
  }

  if (!methods::is(data, "lagged_df")) {
    stop("The 'data' argument takes a training or forecasting dataset of class 'lagged_df' from create_lagged_df().")
  }
  #----------------------------------------------------------------------------
  type <- attributes(data)$type  # train or forecast.

  method <- attributes(model_list[[1]])$method
  horizons <- attributes(model_list[[1]])$horizons
  outcome_col <- attributes(model_list[[1]])$outcome_col
  outcome_cols <- attributes(model_list[[1]])$outcome_cols
  outcome_name <- attributes(model_list[[1]])$outcome_name
  outcome_names <- attributes(model_list[[1]])$outcome_names
  outcome_levels <- attributes(model_list[[1]])$outcome_levels
  row_indices <- attributes(model_list[[1]])$row_indices
  date_indices <- attributes(model_list[[1]])$date_indices
  frequency <- attributes(model_list[[1]])$frequency
  groups <- attributes(model_list[[1]])$groups
  skeleton <- attributes(model_list[[1]])$skeleton
  #----------------------------------------------------------------------------
  if (type == "train") {

    data_stop <- attributes(model_list[[1]])$data_stop

  } else {

    data_stop <- attributes(data)$data_stop
  }
  #----------------------------------------------------------------------------
  # A failed user-defined prediction returns class 'try-error'. That object has no columns, so
  # continuing would only crash in the return-value checks below with an uninformative
  # "missing value where TRUE/FALSE needed". Stop here instead, naming the model (i), the
  # horizon-specific model (j), and the window (k), and forwarding the original error message.
  stop_if_prediction_failed <- function(data_pred, i, j, k) {
    if (methods::is(data_pred, "try-error")) {
      stop(paste0("Model '", attributes(model_list[[i]])$model_name, "' returned class 'try-error' for model ", j,
                  " in validation window ", k, ": ", conditionMessage(attr(data_pred, "condition"))), call. = FALSE)
    }
    invisible(NULL)
  }
  #----------------------------------------------------------------------------
  # Seq along forecast model[i] > model forecast horizon[j] > validation window number[k].
  data_model <- lapply(seq_along(model_list), function(i) {

    prediction_fun <- prediction_function[[i]]

    data_horizon <- lapply(seq_along(model_list[[i]]), function(j) {

      data_win_num <- lapply(seq_along(model_list[[i]][[j]]), function(k) {

        data_results <- model_list[[i]][[j]][[k]]

        # Predict on training data or the forecast dataset?
        if (type == "train") {  # Nested cross-validation.

          # These validation indices will always start at 1. They index with respect to the
          # data passed into predict_function() as opposed to the dataset passed in create_lagged_df().
          # These indices, stored as an attribute, are for manual filtering any skeleton lagged_dfs in predict_function().
          validation_indices <- which(row_indices %in% data_results$valid_indices)

          x_valid <- data[[j]][validation_indices, -(outcome_cols), drop = FALSE]
          y_valid <- data[[j]][validation_indices, outcome_cols, drop = FALSE]  # Actuals in function return.

          attributes(x_valid)$validation_indices <- validation_indices
          attributes(x_valid)$horizons <- attributes(data[[j]])$horizons

          data_pred <- try(prediction_fun(data_results$model, x_valid))  # Nested cross-validation.

          stop_if_prediction_failed(data_pred, i, j, k)

          if (!is.null(groups)) {

            data_groups <- x_valid[, groups, drop = FALSE]  # Save out group identifiers.
          }

        } else {  # Forecast.

          forecast_period <- data[[j]][, "index", drop = FALSE]
          names(forecast_period) <- "forecast_period"
          forecast_horizons <- data[[j]][, "horizon", drop = FALSE]

          data_for_forecast <- data[[j]][, !names(data[[j]]) %in% c("index", "horizon"), drop = FALSE]  # Remove ID columns for predict().

          data_pred <- try(prediction_fun(data_results$model, data_for_forecast))  # User-defined prediction function.

          stop_if_prediction_failed(data_pred, i, j, k)

          if (!is.null(groups)) {
            data_groups <- data_for_forecast[, groups, drop = FALSE]
          }
        }  # End forecast.
        #----------------------------------------------------------------------
        # Check the return of the user-defined predict() function.
        if (is.null(outcome_levels) && method == "direct" && !ncol(data_pred) %in% c(1, 3)) {  # Numeric outcome.
          stop("For numeric outcomes, the user-defined prediction function needs to return 1- or 3-column data.frame of model predictions.")
        }

        if (is.null(outcome_levels) && method == "multi_output" && ncol(data_pred) != length(horizons)) {  # Numeric outcome.
          stop("For numeric outcomes, the user-defined prediction function needs to return a length(horizons)-column data.frame of model predictions.")
        }

        if (!is.null(outcome_levels)) {  # Factor outcome.

          if (ncol(data_pred) == 1 && method == "direct" && !methods::is(data_pred[, 1], "factor")) {
            stop("For factor outcomes where the predicted factor level is returned (e.g., predict(..., type = 'response')), the user-defined
                 prediction function needs to return 1-column data.frame of factor predictions. If returning class
                 probabilities, the number of data.frame columns should equal the number of factor levels and be in the same order as levels(data$outcome).")
          }

          if (method == "direct" && ncol(data_pred) > 1 && ncol(data_pred) != length(outcome_levels)) {
            stop("For factor outcomes where class probabilities are returned (e.g., predict(..., type = 'prob')),
                 the number of data.frame columns returned should equal the number of factor levels.")
          }
        }
        #----------------------------------------------------------------------
        # Format the returned data.frames from predict().
        if (method == "direct") {

          if (is.null(outcome_levels)) {  # Numeric outcome.

            if (ncol(data_pred) == 1) {

              names(data_pred) <- paste0(outcome_names, "_pred")

            } else {  # Confidence/credible forecast intervals.

              # Find the lower, point, and upper forecasts and order the columns accordingly.
              data_pred <- data_pred[order(unlist(lapply(data_pred, mean, na.rm = TRUE)))]

              names(data_pred) <- c(paste0(outcome_names, "_pred_lower"), paste0(outcome_names, "_pred"), paste0(outcome_names, "_pred_upper"))

              # Re-order so that the point forecast is first.
              data_pred <- data_pred[, c(2, 1, 3)]
            }

          } else {  # Factor outcome.

            if (ncol(data_pred) == 1) {  # A predicted factor level.

              names(data_pred) <- paste0(outcome_names, "_pred")

            } else {  # Predicted class probabilities.

              names(data_pred) <- outcome_levels
            }
          }  # End reformatting of the returned data.frame of predictions.

        } else if (method == "multi_output") {

          if (is.null(outcome_levels)) {  # Numeric outcome.

              names(data_pred) <- paste0(outcome_names, "_pred")

          } else {

            stop("Multi-output models do not currently support factor outcomes.")
          }
        }
        #----------------------------------------------------------------------

        model_name <- attributes(model_list[[i]])$model_name

        if (type == "train") {  # Nested cross-validation.

          data_temp <- data.frame("model" = model_name,
                                  "model_forecast_horizon" = horizons[j],
                                  "window_length" = data_results$window_length,
                                  "window_number" = data_results$window,
                                  "valid_indices" = data_results$valid_indices)

          data_temp$date_indices <- data_results$date_indices

          if (is.null(groups)) {

            data_temp <- cbind(data_temp, y_valid, data_pred)

          } else {

            data_temp <- cbind(data_temp, data_groups, y_valid, data_pred)
          }

        } else {  # Forecast.

          data_temp <- data.frame("model" = model_name,
                                  "model_forecast_horizon" = horizons[j],
                                  "horizon" = forecast_horizons,
                                  "window_length" = data_results$window_length,
                                  "window_number" = data_results$window,
                                  "forecast_period" = forecast_period)

          if (is.null(groups)) {

            data_temp <- cbind(data_temp, data_pred)

          } else {

            data_temp <- cbind(data_temp, data_groups, data_pred)
          }
        }  # End forecast results.

        data_temp$model <- as.character(data_temp$model)  # Coerce to remove any factor levels.
        data_temp
      })  # End cross-validation window predictions.
      data_win_num <- suppressMessages(dplyr::bind_rows(data_win_num))
    })  # End horizon-level predictions.
    data_horizon <- suppressMessages(dplyr::bind_rows(data_horizon))
  })  # End model-level predictions.

  data_out <- suppressMessages(dplyr::bind_rows(data_model))

  # For multi-output models, the data (a) need to be reshaped from a wide to long format
  # and (b) the validation indices that represent the forecast origin need to be changed
  # to the index for the predicted time period.
  if (method == "multi_output") {

    if (type == "train") {

      data_actual_temp <- dplyr::select(data_out, -dplyr::ends_with("_pred"))
      data_pred_temp <- dplyr::select(data_out, -!!outcome_names)
      data_actual_temp <- tidyr::pivot_longer(data_actual_temp, cols = !!outcome_names, names_to = "remove", values_to = outcome_name)
      data_actual_temp$remove <- NULL
      data_pred_temp <- tidyr::pivot_longer(data_pred_temp, cols = dplyr::ends_with("_pred"), values_to = paste0(outcome_name, "_pred"))
      data_actual_temp[, paste0(outcome_name, "_pred")] <- data_pred_temp[, paste0(outcome_name, "_pred")]

      data_out <- data_actual_temp

      data_out$model_forecast_horizon <- rep(horizons, nrow(data_out) / length(horizons))

      data_out$forecast_indices <- data_out$valid_indices

      data_out$forecast_indices <- data_out$model_forecast_horizon + data_out$forecast_indices

      data_out$horizons <- NULL

      names(data_out)[names(data_out) == "model_forecast_horizon"] <- "horizon"

      if (!is.null(date_indices)) {

        data_out$forecast_date_indices <- as.Date(unlist(purrr::map(1:nrow(data_out), function(i) {
          base::seq(data_out$date_indices[i], by = frequency, length.out = data_out$horizon[i] + 1)[data_out$horizon[i] + 1]})),
          origin = "1970-01-01")
      }

      data_out <- dplyr::arrange(data_out, .data$model, .data$valid_indices, .data$forecast_indices, .data$horizon)

    } else if (type == "forecast") {

      forecast_period <- data.frame(data_out[, "forecast_period", drop = FALSE], stringsAsFactors = FALSE)
      forecast_period <- data.frame(strsplit(forecast_period$forecast_period, ", "), stringsAsFactors = FALSE)
      names(forecast_period) <- "forecast_period"
      forecast_period$forecast_period <- if (is.null(date_indices)){as.integer(forecast_period$forecast_period)} else {as.Date(forecast_period$forecast_period)}

      data_out <- tidyr::pivot_longer(data_out, cols = dplyr::ends_with("_pred"),
                                      names_to = "remove", values_to = paste0(outcome_name, "_pred"))

      data_out$remove <- NULL

      data_out$model_forecast_horizon <- NULL

      data_out$horizon <- horizons

      data_out$forecast_period <- forecast_period$forecast_period
    }
  }

  data_out <- as.data.frame(data_out)
  row.names(data_out) <- seq_len(nrow(data_out))  # seq_len() stays valid for a 0-row result; 1:0 would not.

  attr(data_out, "method") <- method
  attr(data_out, "horizons") <- horizons
  attr(data_out, "outcome_col") <- outcome_col
  attr(data_out, "outcome_cols") <- outcome_cols
  attr(data_out, "outcome_name") <- outcome_name
  attr(data_out, "outcome_names") <- outcome_names
  attr(data_out, "outcome_levels") <- outcome_levels
  attr(data_out, "row_indices") <- row_indices
  attr(data_out, "date_indices") <- date_indices
  attr(data_out, "frequency") <- frequency
  attr(data_out, "data_stop") <- data_stop
  attr(data_out, "groups") <- groups

  if (type == "train") {
    class(data_out) <- c("training_results", "forecast_model", class(data_out))
  } else {
    class(data_out) <- c("forecast_results", "forecast_model", class(data_out))
  }

  return(data_out)
}
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------

#' Plot an object of class training_results
#'
#' Several diagnostic plots can be returned to assess the quality of the forecasts
#' based on predictions on the validation datasets.
#'
#' @param x An object of class 'training_results' from \code{predict.forecast_model()}.
#' @param type Plot type. The default plot is "prediction" for validation dataset predictions.
#' @param models Optional. Filter results by user-defined model name from \code{train_model()}.
#' @param horizons Optional. A numeric vector of model forecast horizons to filter results by horizon-specific model.
#' @param windows Optional. A numeric vector of window numbers to filter results.
#' @param valid_indices Optional. A numeric or date vector to filter results by validation row indices or dates.
#' @param group_filter Optional. A string for filtering plot results for grouped time series
#' (e.g., \code{"group_col_1 == 'A'"}). The results are passed to \code{dplyr::filter()} internally.
#' @param facet Optional. For numeric outcomes, a formula with any combination of \code{horizon}, \code{model}, or \code{group} (for grouped time series)
#' passed to \code{ggplot2::facet_grid()} internally (e.g., \code{horizon ~ model}, \code{horizon + model ~ .}, \code{~ horizon + group}).
#' @param keep_missing Boolean. If \code{TRUE}, predictions are plotted for indices/dates where the outcome is missing.
#' @param ... Not used.
#' @return Diagnostic plots of class 'ggplot'.
#' @export
plot.training_results <- function(x,
                                  type = c("prediction", "residual", "forecast_stability"),
                                  facet = horizon ~ model, models = NULL, horizons = NULL,
                                  windows = NULL, valid_indices = NULL, group_filter = NULL, keep_missing = FALSE, ...) { # nocov start

  #----------------------------------------------------------------------------
  data <- x
  rm(x)

  type <- type[1]  # The default plot is predicting over historical validation windows.

  if (type == "forecast_stability") {
    if (!xor(is.null(windows), is.null(valid_indices))) {
      stop("Select either (a) one or more validation windows, 'windows', or (b) a range of dataset rows, 'valid_indices', to reduce plot size.")
    }
  }

  if (!is.null(attributes(data)$group) & !type %in% c("prediction", "residual")) {
    stop("Only 'prediction' and 'residual' plots are currently available for grouped models.")
  }

  if (!is.null(attributes(data)$outcome_levels) & !type %in% c("prediction", "residual")) {
    stop("Only 'prediction' and 'residual' plots are currently available for models with factors as outcomes.")
  }
  #----------------------------------------------------------------------------
  method <- attributes(data)$method
  outcome_col <- attributes(data)$outcome_col
  outcome_name <- attributes(data)$outcome_name
  outcome_levels <- attributes(data)$outcome_levels
  date_indices <- attributes(data)$date_indices
  frequency <- attributes(data)$frequency
  groups <- attributes(data)$group
  window_custom <- all(data$window_length == "custom")

  data <- as.data.frame(data)  # Coerce to remove new vctrs warning about mismatched classes when joining data.
  #----------------------------------------------------------------------------
  if (method == "multi_output") {

    data$valid_indices <- data$forecast_indices  # The plot indices are the forecast period, not the forecast origin.

    if (!is.null(date_indices)) {

      data$date_indices <- data$forecast_date_indices  # The plot indices are the forecast period, not the forecast origin.
    }
  }
  #----------------------------------------------------------------------------
  facets <- forecastML_facet_plot(facet, groups)  # Function in zzz.R.
  facet <- facets[[1]]
  facet_names <- facets[[2]]
  #----------------------------------------------------------------------------
  # For factor outcomes, is the prediction a factor level or probability.
  if (!is.null(outcome_levels)) {

    factor_level <- if (any(names(data) %in% paste0(outcome_name, "_pred"))) {TRUE} else {FALSE}
    factor_prob <- !factor_level

    if (type == "residual" && factor_prob) {
      stop("Residual plots with predicted class probabilities are not currently supported.")
    }

    if (type == "residual" && !is.null(groups)) {
      stop("Residual plots with grouped data are not currently supported.")
    }
  }
  #----------------------------------------------------------------------------
  if (isFALSE(keep_missing) && is.null(outcome_levels)) {
    # Set the predicted values to NA to keep any missing data in the actual
    # time series so gaps in data collection are not connected with a line in the plot.
    data[is.na(data[, outcome_name]), paste0(outcome_name, "_pred")] <- NA

    if (all(any(grepl("_pred_lower", names(data))), any(grepl("_pred_upper", names(data))))) {
      data[is.na(data[, outcome_name]), paste0(outcome_name, "_pred_lower")] <- NA
      data[is.na(data[, outcome_name]), paste0(outcome_name, "_pred_upper")] <- NA
    }

  } else if (isFALSE(keep_missing) && !is.null(outcome_levels)) {

    data <- data[!is.na(data[, outcome_name]), ]
  }
  #----------------------------------------------------------------------------
  # Residual calculations.
  if (is.null(outcome_levels)) {  # Numeric outcome.

    data$residual <- data[, outcome_name] - data[, paste0(outcome_name, "_pred")]

  } else {  # Factor outcome.

    if (factor_level) {

      # Binary accuracy/residual. A residual of 1 is an incorrect classification.
      data$residual <- ifelse(as.character(data[, outcome_name, drop = TRUE]) != as.character(data[, paste0(outcome_name, "_pred"), drop = TRUE]), 1, 0)

    } else {  # Class probabilities were predicted.

      data_residual <- 1 - data[, names(data) %in% outcome_levels]
      names(data_residual) <- paste0(names(data_residual), "_residual")
      data <- dplyr::bind_cols(data, data_residual)
      rm(data_residual)
    }
  }
  #----------------------------------------------------------------------------
  # Filtering results based on user input.
  models <- if (is.null(models)) {unique(data$model)} else {models}

  if (method == "direct") {

    horizons <- if (is.null(horizons)) {unique(data$model_forecast_horizon)} else {horizons}

  } else {

    horizons <- if (is.null(horizons)) {unique(data$horizon)} else {horizons}

  }

  windows <- if (is.null(windows)) {unique(data$window_number)} else {windows}

  if (is.null(date_indices)) {

    valid_indices <- if (is.null(valid_indices)) {unique(data$valid_indices)} else {valid_indices}

  } else {

    valid_indices <- if (is.null(valid_indices)) {unique(data$date_indices)} else {date_indices}
  }
  #----------------------------------------------------------------------------
  data_plot <- data
  #----------------------------------------------------------------------------
  # Rename for consistency with plotting code.
  if (method == "direct") {

    data_plot$horizon <- data_plot$model_forecast_horizon
    data_plot$model_forecast_horizon <- NULL
  }

  data_plot <- data_plot[data_plot$model %in% models & data_plot$horizon %in% horizons &
                         data_plot$window_number %in% windows, ]

  if (!is.null(group_filter)) {

    data_plot <- dplyr::filter(data_plot, eval(parse(text = group_filter)))
  }
  #----------------------------------------------------------------------------
  if (methods::is(valid_indices, "Date") || methods::is(valid_indices, "POSIXt")) {

    data_plot <- data_plot[data_plot$date_indices %in% valid_indices, ]  # Filter plots by dates.
    data_plot$index <- data_plot$date_indices

    } else {

      data_plot <- data_plot[data_plot$valid_indices %in% valid_indices, ]  # Filter plots by row indices.

      if (!is.null(date_indices)) {

        data_plot$index <- data_plot$date_indices

      } else {

        data_plot$index <- data_plot$valid_indices
      }
    }
  #----------------------------------------------------------------------------
  # Set up ggplot color and group parameters.
  if (is.null(outcome_levels)) {  # Numeric outcomes.

    if (is.null(groups)) {

      data_plot <- dplyr::arrange(data_plot, .data$model, .data$window_number)

    } else {

      data_plot <- dplyr::arrange(data_plot, .data$model, .data$window_number, eval(parse(text = groups)))
    }
    #--------------------------------------------------------------------------
    # ggplot colors and facets are complimentary: all facets, same color; all colors, no facet.
    ggplot_color <- c(c("model", "horizon", groups)[!c("model", "horizon", groups) %in% facet_names])

    temp_1 <- unlist(Map(function(x) {toupper(substr(x[1], 1, 1))}, ggplot_color))
    temp_2 <- unlist(Map(function(x) {substr(x, 2, nchar(x))}, ggplot_color))
    legend_title <- paste(temp_1, temp_2, sep = "")
    legend_title <- paste(legend_title, collapse = " + ")
    #--------------------------------------------------------------------------

    data_plot$ggplot_color <- apply(data_plot[,  ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})

    # Give predictions a name in the legend if plot is faceted by model and horizon (and group if groups are given).
    if (length(ggplot_color) == 0) {
      data_plot$ggplot_color <- "Prediction"
    }

    # Used to avoid lines spanning any gaps between validation windows.
    if (all(data_plot$window_number == 1)) {  # One window; no need to add the window number to the legend.

      data_plot$ggplot_group <- apply(data_plot[,  ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})

    } else {

      data_plot$ggplot_group <- apply(data_plot[,  c("window_number", ggplot_color), drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
    }

    # Coerce to viridis color scale with an ordered factor. With the data.frame sorted, unique() pulls the levels in their order of appearance.
    data_plot$ggplot_color <- factor(data_plot$ggplot_color, levels = unique(data_plot$ggplot_color), ordered = TRUE)
    #----------------------------------------------------------------------------

    if (!is.null(date_indices)) {
      # Create a dataset of points for those instances where there the outcomes are NA before and after a given instance.
      # Points are needed because ggplot will not plot a 1-instance geom_line().
      data_plot_point <- data_plot %>%
        dplyr::group_by(.data$ggplot_color) %>%
        dplyr::mutate("lag" = dplyr::lag(eval(parse(text = outcome_name)), 1),
                      "lead" = dplyr::lead(eval(parse(text = outcome_name)), 1)) %>%
        dplyr::filter(is.na(.data$lag) & is.na(.data$lead))

      data_plot <- data_plot[data_plot$date_indices %in% valid_indices, ]
      # This may be an empty data.frame if every time series has 2 or more contiguous records, and
      # suppressWarnings() suppresses a forcats warning.
      data_plot_point <- suppressWarnings(data_plot_point[data_plot_point$date_indices %in% date_indices, ])
    }
    #--------------------------------------------------------------------------
  } else {  # Factor outcomes.

    data_plot$ggplot_color_group <- apply(data_plot[,  c("model", "horizon", groups), drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
  }
  #----------------------------------------------------------------------------
  #----------------------------------------------------------------------------

  if (type %in% c("prediction", "residual")) {

    #--------------------------------------------------------------------------
    # Melt the data for plotting.
    if (is.null(outcome_levels)) {  # Numeric outcome.

      data_plot <- tidyr::gather(data_plot, "outcome", "value",
                                 -!!names(data_plot)[!names(data_plot) %in% c(outcome_name, paste0(outcome_name, "_pred"))])

    } else {  # Factor outcome.

      if (any(names(data_plot) %in% paste0(outcome_name, "_pred"))) {  # A factor level was predicted.

        data_plot <- tidyr::gather(data_plot, "outcome", "value",
                                   -!!names(data_plot)[!names(data_plot) %in% c(outcome_name, paste0(outcome_name, "_pred"))])

      } else {  # Class probabilities were predicted.

        data_plot <- suppressWarnings(tidyr::gather(data_plot, "outcome", "value",
                                                    -!!names(data_plot)[!names(data_plot) %in% c(outcome_name, outcome_levels)]))
      }
    }
    #--------------------------------------------------------------------------
    # If date indices exist, plot with them.
    if (!is.null(date_indices)) {
      data_plot$index <- data_plot$date_indices
    }
    #--------------------------------------------------------------------------
    if (type == "prediction") {

      if (is.null(outcome_levels)) {  # Numeric outcome; plot historical predictions.

        p <- ggplot()

        #----------------------------------------------------------------------
        # Plot actuals on the first or lowest layer.
        if (is.null(groups)) {  # Single time series.

          p <- p + geom_line(data = data_plot[data_plot$outcome == outcome_name, ],
                             aes(x = .data$index, y = .data$value, group = .data$ggplot_group), color = "grey50")

        } else {  # Actuals, grouped time series.

          # If faceting by group, this reduces to the single time series case so the actuals
          # will be the default grey so as not to double encode the plot data.
          if (any(facet_names %in% groups)) {

            p <- p + geom_line(data = data_plot[data_plot$outcome == outcome_name, ],
                               aes(x = .data$index, y = .data$value), color = "grey50")

          } else {

            p <- p + geom_line(data = data_plot[data_plot$outcome == outcome_name, ],
                               aes(x = .data$index, y = .data$value,
                                   group = .data$ggplot_group,
                                   color = .data$ggplot_color), linetype = 2)
            p <- p + scale_color_viridis_d()
          }
        }
        #----------------------------------------------------------------------
        # Plot predictions.

        # If the plotting data.frame has both lower and upper forecasts plot these bounds.
        # We'll add the shading in a lower ggplot layer so the point forecasts are on top in the final plot.
        if (all(any(grepl("_pred_lower", names(data_plot))), any(grepl("_pred_upper", names(data_plot))))) {

          p <- p + geom_ribbon(data = data_plot[data_plot$outcome == outcome_name, ],
                               aes(x = .data$index, ymin = eval(parse(text = paste0(outcome_name, "_pred_lower"))),
                                   ymax = eval(parse(text = paste0(outcome_name, "_pred_upper"))),
                                   fill = .data$ggplot_color, group = .data$ggplot_group, color = NULL), alpha = .25, show.legend = FALSE)
        }

        p <- p + geom_line(data = data_plot[data_plot$outcome != outcome_name, ],  # Predictions in melted data.
                           aes(x = .data$index, y = .data$value, group = .data$ggplot_group, color = .data$ggplot_color),
                           size = 1.05, linetype = 1)

        if (!is.null(facet)) {
          p <- p + facet_grid(facet, drop = TRUE)
        }
        p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))

    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------

    } else {  # Factor outcome.

      if (factor_prob) {  # Plot class probabilities with un-grouped data.

        if (is.null(groups)) {

          # Only 1 actual needs to be plotted.
          data_plot$ggplot_color_group[data_plot$outcome == outcome_name] <- "Actual"

          data_actual <- data_plot[data_plot$outcome %in% outcome_name, ]
          data_pred <- data_plot[data_plot$outcome %in% outcome_levels, ]

          data_actual$value <- factor(data_actual$value, levels = outcome_levels, ordered = TRUE)
          data_pred$outcome <- factor(data_pred$outcome, levels = outcome_levels, ordered = TRUE)

          data_pred$value <- as.numeric(data_pred$value)

          p <- ggplot()
          p <- p + geom_col(data = data_pred,
                            aes(x = .data$index, y = .data$value, color = .data$outcome, fill = .data$outcome),
                            position = position_stack(reverse = TRUE))
          p <- p + geom_col(data = data_actual,
                            aes(x = .data$index, y = 1, color = .data$value, fill = .data$value))
          p <- p + scale_y_continuous(limits = 0:1)
          p <- p + scale_color_viridis_d(drop = FALSE)
          p <- p + scale_fill_viridis_d(drop = FALSE)
          p <- p + facet_wrap(~ ggplot_color_group, ncol = 1, scales = "free_y")
          p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))

        } else {  # Plot class probabilities with grouped data.

          data_plot$actual_or_pred <- NA
          data_plot$actual_or_pred[data_plot$outcome == outcome_name] <- "Actual"
          data_plot$actual_or_pred[data_plot$outcome != outcome_name] <- "Prediction"

          data_plot$y_value <- NA
          data_plot$y_value[data_plot$outcome == outcome_name] <- 1
          data_plot$y_value[data_plot$outcome != outcome_name] <- data_plot$value[data_plot$outcome != outcome_name]
          data_plot$y_value <- as.numeric(data_plot$y_value)

          data_plot$plot_color <- NA
          data_plot$plot_color[data_plot$outcome == outcome_name] <- data_plot$value[data_plot$outcome == outcome_name]
          data_plot$plot_color[data_plot$outcome != outcome_name] <- data_plot$outcome[data_plot$outcome != outcome_name]

          data_plot$plot_color <- factor(data_plot$plot_color, levels = outcome_levels, ordered = TRUE)

          p <- ggplot()
          p <- p + geom_col(data = data_plot,
                            aes(x = .data$index, y = .data$y_value, color = .data$plot_color, fill = .data$plot_color),
                            position = position_stack(reverse = TRUE))
          p <- p + scale_y_continuous(limits = 0:1, breaks = c(0, .5, 1))
          p <- p + scale_color_viridis_d(drop = FALSE)
          p <- p + scale_fill_viridis_d(drop = FALSE)
          p <- p + facet_grid(ggplot_color_group + actual_or_pred ~ ., scales = "free_y")
          p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
        }  # End plot class probabilities with grouped data.

      #------------------------------------------------------------------------

      } else {  # Plot class levels.

        if (is.null(groups)) {  # Plot class levels with un-grouped data.

          # Only 1 actual needs to be plotted.
          data_plot$ggplot_color_group[data_plot$outcome == outcome_name] <- "Actual"

          data_plot$value <- factor(data_plot$value, levels = c(outcome_levels), ordered = TRUE)

          p <- ggplot()
          p <- p + geom_col(data = data_plot,
                            aes(x = .data$index, y = 1, color = .data$value, fill = .data$value),
                            position = position_stack(reverse = TRUE))
          p <- p + scale_color_viridis_d(drop = FALSE)
          p <- p + scale_fill_viridis_d(drop = FALSE)
          p <- p + facet_wrap(~ ggplot_color_group, ncol = 1, scales = "free_y")
          p <- p + theme_bw() + theme(axis.text.y = element_blank(), axis.ticks.y = element_blank())

        } else {  # Plot class levels with grouped data.

          data_plot$actual_or_pred <- NA
          data_plot$actual_or_pred[data_plot$outcome == outcome_name] <- "Actual"
          data_plot$actual_or_pred[data_plot$outcome != outcome_name] <- "Prediction"

          data_plot$value <- factor(data_plot$value, levels = outcome_levels, ordered = TRUE)

          p <- ggplot()
          p <- p + geom_col(data = data_plot,
                            aes(x = .data$index, y = 1, color = .data$value, fill = .data$value),
                            position = position_stack(reverse = TRUE))
          p <- p + scale_y_continuous(limits = 0:1, breaks = c(0, .5, 1))
          p <- p + scale_color_viridis_d(drop = FALSE)
          p <- p + scale_fill_viridis_d(drop = FALSE)
          p <- p + facet_grid(ggplot_color_group + actual_or_pred ~ ., scales = "free_y")
          p <- p + theme_bw() + theme(axis.text.y = element_blank(), axis.ticks.y = element_blank(), panel.spacing = unit(0, "lines"))
        }
      }
    }  # End prediction plots for numeric and factor outcomes.
    #--------------------------------------------------------------------------
    } else if (type == "residual") {

      if (is.null(outcome_levels)) {  # Numeric outcome.

        # Plot historical predictions.
        p <- ggplot(data_plot[data_plot$outcome != outcome_name, ],
                    aes(x = .data$index, y = .data$residual,
                        group = .data$ggplot_group, color = .data$ggplot_color))

        p <- p + geom_line(size = 1.05, linetype = 1)
        p <- p + geom_hline(yintercept = 0)
        p <- p + scale_color_viridis_d()
        p <- p + facet_grid(facet, drop = TRUE)
        p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))

      } else {  # Factor outcome.

        if (is.null(groups)) {

          data_plot$ggplot_color_group <- factor(as.character(data_plot$ggplot_color_group), ordered = TRUE,
                                                 levels = c(unique(as.character(data_plot$ggplot_color_group)), "Actual"))

          data_plot$ggplot_color_group[data_plot$outcome == outcome_name] <- "Actual"

        }

        p <- ggplot()

        # Plot predictions to avoid duplicate residuals in plots.
        p <- p + geom_tile(data = data_plot[data_plot$outcome != outcome_name, ], aes(x = .data$index, y = .data$ggplot_color_group,
                                                 fill = ordered(.data$residual)))
        p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
      }
    }  # End residual plot.
    #--------------------------------------------------------------------------
    # Plot labels.
    if (type == "prediction") {

      if (is.null(outcome_levels)) {  # Numeric outcome.

        if (!is.null(groups)) {  # Grouped time series.

          if (any(facet_names %in% groups)) {

            p <- p + xlab("Dataset index") + ylab("Outcome") + labs(color = legend_title, group = NULL) +
              ggtitle("Forecasts vs. Actuals Through Time")

          } else {

            p <- p + xlab("Dataset index") + ylab("Outcome") + labs(color = legend_title, group = NULL) +
              ggtitle("Forecasts vs. Actuals Through Time", subtitle = c("Dashed lines are actuals"))
          }

        } else {  # Single time series.

          p <- p + xlab("Dataset index") + ylab("Outcome") + labs(color = legend_title, group = NULL) +
            ggtitle("Forecasts vs. Actuals Through Time")
        }

      } else {  # Factor outcome.

        if (factor_prob) {

          p <- p + xlab("Dataset index") + ylab("Outcome probability") + labs(color = "Outcome", fill = "Outcome") +
            ggtitle("Forecasts vs. Actuals Through Time")

        } else {

          p <- p + xlab("Dataset index") + ylab("Outcome") + labs(color = "Outcome", fill = "Outcome") +
            ggtitle("Forecasts vs. Actuals Through Time")
        }
      }

    } else if (type == "residual") {

      if (is.null(outcome_levels)) {

        p <- p + xlab("Dataset index") + ylab("Residual") + labs(color = NULL, group = NULL) +
          ggtitle("Forecast Error Through Time")

      } else {

        if (is.null(groups)) {

          p <- p + xlab("Dataset index") + ylab("Residual and model") + labs(fill = "Residual") +
            ggtitle("Forecast Error Through Time")

        } else {

          p <- p + xlab("Dataset index") + ylab("Residual, model, and group") + labs(fill = "Residual") +
            ggtitle("Forecast Error Through Time")
        }
      }
    }
    return(suppressWarnings(p))
  }
  #----------------------------------------------------------------------------

  if (type %in% c("forecast_stability")) {

    data_plot$forecast_origin <- with(data_plot, valid_indices - horizon)

    data_plot$group <- with(data_plot, paste0(valid_indices))
    data_plot$group <- ordered(data_plot$group)

    # Plotting the original time-series in each facet. Because the plot is faceted by valid_indices, we'll do a bit of a hack here to create
    # the same line plot for each facet.
    data_outcome <- data_plot %>%
      dplyr::select(valid_indices, !!outcome_name) %>%
      dplyr::distinct(valid_indices, .keep_all = TRUE)
    data_outcome$index <- data_outcome$valid_indices
    data_outcome$valid_indices <- NULL  # remove to avoid confusion in facet_wrap()

    data_outcome <- data_outcome[rep(1:nrow(data_outcome), length(unique(data_plot$valid_indices))), ]

    p <- ggplot()
    if (max(data_plot$horizon) != 1) {
      p <- p + geom_line(data = data_plot, aes(x = .data$forecast_origin,
                                               y = eval(parse(text = paste0(outcome_name, "_pred"))),
                                               color = factor(.data$model)), size = 1, linetype = 1, show.legend = FALSE)
    }
    p <- p + geom_point(data = data_plot, aes(x = .data$forecast_origin,
                                              y = eval(parse(text = paste0(outcome_name, "_pred"))),
                                              color = factor(.data$model)))
    p <- p + geom_point(data = data_plot, aes(x = .data$valid_indices,
                                              y = eval(parse(text = outcome_name)), fill = "Actual"))
    p <- p + scale_color_viridis_d()
    p <- p + facet_wrap(~ valid_indices)

    p <- p + geom_line(data = data_outcome, aes(x = .data$index,
                                                y = eval(parse(text = outcome_name))), color = "gray50")
    p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
    p <- p + xlab("Dataset index") + ylab("Outcome") + labs(color = "Model") + labs(fill = NULL) +
      ggtitle("Rolling Origin Forecast Stability - Faceted by dataset index")

    return(suppressWarnings(p))
  }
} # nocov end
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------

#' Plot an object of class forecast_results
#'
#' A forecast plot for each horizon for each model in \code{predict.forecast_model()}.
#'
#' @param x An object of class 'forecast_results' from \code{predict.forecast_model()}.
#' @param data_actual A data.frame containing the target/outcome column and any grouping columns.
#' The data can be historical actuals and/or holdout/test data.
#' @param actual_indices Required if \code{data_actual} is given. A vector or 1-column data.frame
#' of numeric row indices or dates (class 'Date' or 'POSIXt') with length \code{nrow(data_actual)}.
#' The indices can span historical and/or holdout/test periods.
#' @param facet Optional. For numeric outcomes, a formula with any combination of \code{horizon}, \code{model}, or \code{group} (for grouped time series)
#' passed to \code{ggplot2::facet_grid()} internally (e.g., \code{horizon ~ model}, \code{horizon + model ~ .}, \code{~ horizon + group}).
#' Can be \code{NULL}.
#' @param models Optional. Filter results by user-defined model name from \code{train_model()}.
#' @param horizons Optional. Filter results by horizon.
#' @param windows Optional. Filter results by validation window number.
#' @param group_filter Optional. A string for filtering plot results for grouped time-series (e.g., \code{"group_col_1 == 'A'"});
#' passed to \code{dplyr::filter()} internally.
#' @param ... Not used.
#' @return Forecast plot of class 'ggplot'.
#' @export
plot.forecast_results <- function(x, data_actual = NULL, actual_indices = NULL, facet = horizon ~ model,
                                  models = NULL, horizons = NULL, windows = NULL,
                                  group_filter = NULL, ...) { # nocov start

  #----------------------------------------------------------------------------
  if(xor(is.null(data_actual), is.null(actual_indices))) {
    stop("If plotting a hold-out or comparison dataset, both 'data_actual' and 'actual_indices' need to be specified.")
  }

  data_forecast <- x
  rm(x)

  type <- "forecast"  # Only one plot option at present.
  #----------------------------------------------------------------------------
  method <- attributes(data_forecast)$method
  forecast_horizons <- attributes(data_forecast)$horizons
  outcome_col <- attributes(data_forecast)$outcome_col
  outcome_name <- attributes(data_forecast)$outcome_name
  outcome_levels <- attributes(data_forecast)$outcome_levels
  date_indices <- attributes(data_forecast)$date_indices
  groups <- attributes(data_forecast)$group
  data_stop <- attributes(data_forecast)$data_stop

  data_forecast <- as.data.frame(data_forecast)  # Coerce to remove new vctrs warning about mismatched classes when joining data.

  if (!is.null(outcome_levels) && !is.null(groups) && !is.null(data_actual)) {
    stop("Plotting forecasts from grouped time series with an actuals dataset is not currently supported.")
  }
  #----------------------------------------------------------------------------
  names(data_forecast)[names(data_forecast) == "forecast_period"] <- "index"  # For code uniformity.
  #----------------------------------------------------------------------------
  # For factor outcomes, is the prediction a factor level or probability.
  if (!is.null(outcome_levels)) {
    factor_level <- if (any(names(data_forecast) %in% paste0(outcome_name, "_pred"))) {TRUE} else {FALSE}
    factor_prob <- !factor_level
  }
  #----------------------------------------------------------------------------
  facets <- forecastML_facet_plot(facet, groups)  # Function in zzz.R.
  facet <- facets[[1]]
  facet_names <- facets[[2]]
  #----------------------------------------------------------------------------
  if (!is.null(data_actual)) {

    data_actual <- data_actual[, c(outcome_name, groups), drop = FALSE]

    data_actual$index <- actual_indices

    if (!is.null(group_filter)) {

      data_actual <- dplyr::filter(data_actual, eval(parse(text = group_filter)))
    }
  }
  #----------------------------------------------------------------------------
  # Filter plots using user input.
  models <- if (is.null(models)) {unique(data_forecast$model)} else {models}
  horizons <- if (is.null(horizons)) {forecast_horizons} else {horizons}
  windows <- if (is.null(windows)) {unique(data_forecast$window_number)} else {windows}

  if (method == "multi_output") {
    data_forecast$model_forecast_horizon <- data_forecast$horizon  # Used for filtering below.
  }

  data_forecast <- data_forecast[data_forecast$model %in% models &
                                 data_forecast$model_forecast_horizon %in% horizons &
                                 data_forecast$window_number %in% windows, ]

  if (!is.null(group_filter)) {
    data_forecast <- dplyr::filter(data_forecast, eval(parse(text = group_filter)))
  }
  #----------------------------------------------------------------------------
  data_plot <- data_forecast
  #----------------------------------------------------------------------------
  if (type == "forecast") {

    #----------------------------------------------------------------------------
    # Set up ggplot color and group parameters.
    if (is.null(outcome_levels)) {  # Numeric outcomes.

      if (is.null(groups)) {

        data_plot <- dplyr::arrange(data_plot, .data$model, .data$model_forecast_horizon, .data$horizon, .data$window_number)

      } else {

        data_plot <- dplyr::arrange(data_plot, .data$model, .data$model_forecast_horizon, .data$window_number, .data$horizon, eval(parse(text = groups)))
      }
      #--------------------------------------------------------------------------
      # ggplot colors and facets are complimentary: all facets, same color; all colors, no facet.
      ggplot_color <- c(c("model", "horizon", groups)[!c("model", "horizon", groups) %in% facet_names])

      if (method == "multi_output") {  # With 1 model, all horizons belong to the same forecast and should have the same color.
        ggplot_color <-  ggplot_color[!ggplot_color %in% "horizon"]
      }
      #--------------------------------------------------------------------------
      # There is a distinction between horizon and the model horizon; rename to work with the facet inputs.
      data_plot$forecast_horizon <- data_plot$horizon
      data_plot$horizon <- data_plot$model_forecast_horizon

      data_plot$ggplot_color <- apply(data_plot[,  ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})

      # Give predictions a name in the legend if plot is faceted by model and horizon (and group if groups are given).
      if (length(ggplot_color) == 0) {
        data_plot$ggplot_color <- "Forecast"
      }

      # Used to avoid lines spanning any gaps between validation windows.
      if (all(data_plot$window_number == 1)) {  # One window; no need to add the window number to the legend.

        data_plot$ggplot_group <- apply(data_plot[,  ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})

      } else {

        data_plot$ggplot_group <- apply(data_plot[,  c("window_number", ggplot_color), drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
      }

      # Coerce to viridis color scale with an ordered factor. With the data.frame sorted, unique() pulls the levels in their order of appearance.
      data_plot$ggplot_color <- factor(data_plot$ggplot_color, levels = unique(data_plot$ggplot_color), ordered = TRUE)

      data_plot$ggplot_group <- factor(data_plot$ggplot_group, levels = unique(data_plot$ggplot_group), ordered = TRUE)
    }
    #--------------------------------------------------------------------------
    if (!is.null(data_actual)) {

      #--------------------------------------------------------------------------
      # If the plot is faceted by model, repeat the actuals dataset once for each model in a long format for faceting by model.
      if (length(unique(data_plot$model)) == 1) {

        data_actual$model <- unique(data_plot$model)

      } else {

        n_reps <- nrow(data_actual)
        data_actual <- data_actual[rep(1:nrow(data_actual), length(unique(data_plot$model))), ]
        data_actual$model <- rep(unique(data_plot$model), each = n_reps)
      }
      #--------------------------------------------------------------------------
      # If the plot is colored by model, repeat the actuals dataset once for each model in a long format for faceting by model.
      if (length(unique(data_plot$model)) == 1) {

        data_actual$model <- unique(data_plot$model)

      } else {

        n_reps <- nrow(data_actual)
        data_actual <- data_actual[rep(1:nrow(data_actual), length(unique(data_plot$model))), ]
        data_actual$model <- rep(unique(data_plot$model), each = n_reps)
      }
      #--------------------------------------------------------------------------

      if (length(ggplot_color) > 0) {

        if (ggplot_color == "horizon") {

          data_actual$ggplot_color <- NA

        } else {

          data_actual$ggplot_color <- apply(data_actual[,  ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
        }
      } else {

        data_actual$ggplot_color <- "Forecast"
      }

      # Give predictions a name in the legend if plot is faceted by model and horizon (and group if groups are given).
      if (length(ggplot_color) == 0) {
        data_actual$ggplot_color <- "Forecast"
      }

      if (length(ggplot_color) > 0) {

        if (ggplot_color == "horizon") {

          data_actual$ggplot_group <- "Forecast"

        } else {

          data_actual$ggplot_group <- apply(data_actual[,  ggplot_color, drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
        }

      } else {

        data_actual$ggplot_group <- "Forecast"
      }
      #------------------------------------------------------------------------
      # Coerce to viridis color scale with an ordered factor. The levels in the actual data are limited
      # to those factor levels that appear in the forecast data.
      data_actual$ggplot_color <- factor(data_actual$ggplot_color, levels = levels(data_plot$ggplot_color), ordered = TRUE)

      data_actual$ggplot_group <- factor(data_actual$ggplot_group, levels = levels(data_plot$ggplot_color), ordered = TRUE)
    }
    #--------------------------------------------------------------------------
    if (is.null(outcome_levels)) {  # Numeric outcome.

      p <- ggplot()

      if (1 %in% horizons || nrow(data_plot) == 1 || (method == "multi_output" && any(c(facet_names %in% "horizon")))) {  # Use geom_point instead of geom_line to plot a 1-step-ahead forecast.

        if (method == "direct") {

          data_plot_temp <- data_plot[data_plot$horizon == 1, ]

        } else if (method == "multi_output") {

          data_plot_temp <- data_plot
        }

        # If the plotting data.frame has both lower and upper forecasts plot these bounds.
        # We'll add the shading in a lower ggplot layer so the point forecasts are on top in the final plot.
        if (all(any(grepl("_pred_lower", names(data_plot))), any(grepl("_pred_upper", names(data_plot))))) {

          # geom_ribbon() does not work with a single data point when forecast bounds are plotted.
          p <- p + geom_linerange(data = data_plot_temp,
                                  aes(x = .data$index, ymin = eval(parse(text = paste0(outcome_name, "_pred_lower"))),
                                      ymax = eval(parse(text = paste0(outcome_name, "_pred_upper"))),
                                      color = .data$ggplot_color, group = .data$ggplot_group), alpha = .25, size = 3, show.legend = FALSE)

        }

        p <- p + geom_point(data = data_plot_temp,
                            aes(x = .data$index, y = eval(parse(text = paste0(outcome_name, "_pred"))),
                            color = .data$ggplot_color, group = .data$ggplot_group))
      }

      if ((method == "direct" && !all(horizons == 1)) || (method == "multi_output" && nrow(data_plot) > 1 && !any(c(facet_names %in% "horizon")))) {  # Plot forecasts for model forecast horizons > 1.

        if (method == "direct") {

          data_plot_temp <- data_plot[data_plot$horizon != 1, ]

        } else if (method == "multi_output") {

          data_plot_temp <- data_plot
        }

        # If the plotting data.frame has bother lower and upper forecasts plot these bounds.
        if (all(any(grepl("_pred_lower", names(data_plot))), any(grepl("_pred_upper", names(data_plot))))) {

          p <- p + geom_ribbon(data = data_plot_temp,
                               aes(x = .data$index, ymin = eval(parse(text = paste0(outcome_name, "_pred_lower"))),
                                   ymax = eval(parse(text = paste0(outcome_name, "_pred_upper"))),
                                   fill = .data$ggplot_color, group = .data$ggplot_group, color = NULL), alpha = .25, show.legend = FALSE)
        }

        p <- p + geom_line(data = data_plot_temp,
                           aes(x = .data$index, y = eval(parse(text = paste0(outcome_name, "_pred"))),
                               color = .data$ggplot_color, group = .data$ggplot_group))
      }

      # Add user-defined actuals data to the plots
      if (!is.null(data_actual)) {

        if (is.null(groups)) {

          p <- p + geom_line(data = data_actual, aes(x = .data$index,
                                                     y = eval(parse(text = outcome_name))), color = "grey50")
        } else {

          # If faceting by group, this reduces to the single time series case so the actuals
          # will be the default grey so as not to double encode the plot data.
          if (any(facet_names %in% groups)) {

            p <- p + geom_line(data = data_actual, aes(x = .data$index, y = eval(parse(text = outcome_name))), color = "grey50", show.legend = FALSE)

          } else if (facet_names != c("model", "horizon") || facet_names != c("horizon", "model")) {  # The actuals colors cannot be uniquely mapped to the forecast plot colors.

            p <- p + geom_line(data = data_actual, aes(x = .data$index, y = eval(parse(text = outcome_name))), color = "grey50", show.legend = FALSE)

          } else if (facet_names == c("model", "horizon") || facet_names == c("horizon", "model")) {  # The actuals can be uniquely mapped to the forecasts given the faceting.

            p <- p + geom_line(data = data_actual, aes(x = .data$index,
                                                       y = eval(parse(text = outcome_name)),
                                                       color = .data$ggplot_color,
                                                       group = .data$ggplot_group), show.legend = FALSE)
          }
        }
      }

      p <- p + scale_color_viridis_d()
      p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
      p <- p + facet_grid(facet, scales = "free_y")

    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------

    } else {  # Factor outcome.

      data_plot <- data_forecast

      if (factor_prob) {
        # Melt the data for plotting the multiple class probabilities in stacked bars.
        data_plot <- tidyr::gather(data_plot, "outcome", "value", -!!names(data_plot)[!names(data_plot) %in% c(outcome_levels)])
      }
      #------------------------------------------------------------------------
      # If there is only 1 validation window, remove it from the grouping to reduce clutter.
      if (length(unique(data_plot$window_number)) == 1) {

        data_plot$ggplot_color_group <- apply(data_plot[,  c("model", "model_forecast_horizon", groups), drop = FALSE], 1, function(x) {paste(x, collapse = "-")})

      } else {

        data_plot$ggplot_color_group <- apply(data_plot[,  c("model", "model_forecast_horizon", "window_number", groups), drop = FALSE], 1, function(x) {paste(x, collapse = "-")})
      }
      #------------------------------------------------------------------------
      if (!is.null(data_actual)) {

        # actual or forecast: these are all actuals.
        data_actual$actual_or_forecast <- "actual"
        # historical, test, or model_forecast: these may be any combination of historical data and a holdout test dataset.
        data_actual$time_series_type <- with(data_actual, ifelse(index <= data_stop, "historical", "test"))
        names(data_actual)[names(data_actual) == outcome_name] <- "outcome"  # Standardize before concat with forecasts.
        data_actual$ggplot_color_group <- "Actual"  # Actuals will be plotted in the top plot facet.
        data_actual$value <- 1  # Plot a solid bar with probability 1 in geom_col().

        # In cases where historical data is provided in data_actual, duplicate the historical data
        # such that it appears as a sequence in each plot facet. Here, 'ggplot_color_group' gives the,
        # possibly user-filtered, plot facets.
        if ("historical" %in% unique(data_actual$time_series_type)) {

          data_hist <- data_actual[data_actual$time_series_type == "historical", c("index", "outcome", "value", "ggplot_color_group")]
          n_rows <- nrow(data_hist)
          data_hist <- data_hist[rep(1:nrow(data_hist), length(unique(data_plot$ggplot_color_group))), ]
          data_hist$ggplot_color_group <- rep(unique(data_plot$ggplot_color_group), each = n_rows)

          data_actual <- suppressWarnings(dplyr::bind_rows(data_hist, data_actual))
        }
      }
      #------------------------------------------------------------------------
      # Standardize names for plotting and before any concatenation with data_actual.
      names(data_plot)[names(data_plot) == "index"] <- "index"
      data_plot$actual_or_forecast <- "forecast"
      data_plot$time_series_type <- "model_forecast"
      #------------------------------------------------------------------------
      if (factor_level) {
        # Standardize names for plotting and before any concatenation with data_actual.
        names(data_plot)[names(data_plot) == paste0(outcome_name, "_pred")] <- "outcome"
        data_plot$value <- 1
      }
      #------------------------------------------------------------------------
      if (!is.null(data_actual)) {
        data_plot <- suppressWarnings(dplyr::bind_rows(data_plot, data_actual))
      }

      data_plot$ggplot_color_group <- factor(data_plot$ggplot_color_group, levels = rev(unique(data_plot$ggplot_color_group)), ordered = TRUE)
      data_plot$value <- as.numeric(data_plot$value)
      data_plot$outcome <- factor(data_plot$outcome, levels = outcome_levels, ordered = TRUE)
      #------------------------------------------------------------------------

      if (!is.null(groups)) {
        data_plot <- dplyr::distinct(data_plot, .data$ggplot_color_group, .data$index, .data$outcome, .keep_all = TRUE)
      }

      p <- ggplot()
      p <- p + geom_col(data = data_plot,
                        aes(x = .data$index, y = .data$value, color = .data$outcome, fill = .data$outcome),
                        position = position_stack(reverse = TRUE))
      p <- p + scale_color_viridis_d(drop = FALSE)
      p <- p + scale_fill_viridis_d(drop = FALSE)
      if (is.null(groups)) {
        p <- p + facet_wrap(~ ggplot_color_group, scales = "free_y")
      } else {
        p <- p + facet_grid(ggplot_color_group ~ ., scales = "free_y")
      }
      p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))

      if (factor_level) {
        p <- p + theme(axis.text.y = element_blank(), axis.ticks.y = element_blank(), panel.spacing = unit(0, "lines"))
      }
    }  # End numeric and factor outcome plot setup.
    #--------------------------------------------------------------------------
    # Add a vertical line to mark the beginning of the forecast period.
    p <- p + geom_vline(xintercept = data_stop, color = "red")
    #--------------------------------------------------------------------------

    if (is.null(outcome_levels)) {  # Numeric outcome.

      temp_1 <- unlist(Map(function(x) {toupper(substr(x[1], 1, 1))}, ggplot_color))
      temp_2 <- unlist(Map(function(x) {substr(x, 2, nchar(x))}, ggplot_color))
      x_axis_title <- paste(temp_1, temp_2, sep = "")
      x_axis_title <- paste(x_axis_title, collapse = " + ")

      p <- p + xlab("Dataset index") + ylab("Outcome") +
        labs(color = x_axis_title, fill = NULL) +
        ggtitle("H-Step-Ahead Model Forecasts")

    } else {  # Factor outcome.

      p <- p + xlab("Dataset index") + ylab("Outcome") +
        labs(color = "Outcome", fill = "Outcome") +
        ggtitle("H-Step-Ahead Model Forecasts")
    }

    return(suppressWarnings(p))
  }
} # nocov end

Try the forecastML package in your browser

Any scripts or data that you put into this service are public.

forecastML documentation built on July 8, 2020, 7:27 p.m.