R/plot_forecast.R

Defines functions plot_forecast

Documented in plot_forecast

#' Plot Observed Data and BSTS Forecast
#'
#' Creates a plot of observed training data, the realised test data, the mean
#' BSTS forecast, and a shaded quantile band around the forecast.
#'
#' The test-data line and the mean-forecast line both start at the last
#' training observation, so they join the observed history without a gap. The
#' quantile band starts at the first forecast step. The legend is hidden.
#'
#' @param forecast A matrix of BSTS forecast samples with one row per draw and
#'   one column per forecast step. The number of columns must equal
#'   length(data_test).
#' @param data_train Numeric vector, training data (at least one value).
#' @param data_test Numeric vector, test data observed over the forecast
#'   horizon.
#' @param time Vector of time indices (coerced with as.numeric()), of length
#'   length(data_train) + length(data_test).
#' @param quant_high Numeric between 0 and 1, upper quantile of the band.
#' @param quant_low Numeric between 0 and 1, lower quantile of the band.
#' @param observed_col Character, color for the observed training and test
#'   data.
#' @param forecast_col Character, color for the mean forecast, the quantile
#'   band and its boundary lines.
#' @param title Character, title of the plot.
#' @param x_lab Character, x-axis label. Defaults to "Time".
#' @param y_lab Character, y-axis label. Defaults to "Crop yields".
#' @return A ggplot2 object.
#' @importFrom ggplot2 ggplot aes geom_line geom_ribbon scale_color_manual scale_fill_manual theme_light labs theme
#' @importFrom stats quantile
#' @export
#'
plot_forecast <- function(forecast, data_train, data_test, time, quant_high,
                          quant_low, observed_col, forecast_col, title,
                          x_lab = "Time", y_lab = "Crop yields") {
  n_train <- length(data_train)
  n_test <- length(data_test)

  # Validate up front: otherwise rep(NA, -1) fails cryptically, and
  # data.frame() may silently recycle mismatched columns into a wrong plot.
  if (n_train < 1L) {
    stop("`data_train` must contain at least one observation.", call. = FALSE)
  }
  if (length(dim(forecast)) != 2L) {
    stop("`forecast` must be a matrix (draws x forecast steps).", call. = FALSE)
  }
  if (ncol(forecast) != n_test) {
    stop("`forecast` must have one column per element of `data_test` (",
         n_test, "), not ", ncol(forecast), ".", call. = FALSE)
  }
  if (length(time) != n_train + n_test) {
    stop("`time` must have length(data_train) + length(data_test) = ",
         n_train + n_test, " elements, not ", length(time), ".",
         call. = FALSE)
  }

  # Padding that hides the training period entirely (used by the band).
  train_pad <- rep(NA, n_train)
  # Padding that ends on the last training value, so lines connect to history.
  join_pad <- c(rep(NA, n_train - 1L), data_train[n_train])

  res_data <- data.frame(
    Time = as.numeric(time),
    Observed = c(data_train, rep(NA, n_test)),
    Future = c(join_pad, data_test),
    AvgForecast = c(join_pad, colMeans(forecast)),
    # Quantiles are taken across draws (rows) for each forecast step (column).
    QuantHigh = c(train_pad,
                  apply(forecast, 2, stats::quantile, probs = quant_high)),
    QuantLow = c(train_pad,
                 apply(forecast, 2, stats::quantile, probs = quant_low))
  )

  ggplot2::ggplot(res_data, ggplot2::aes(x = Time)) +
    # Plot observed data
    ggplot2::geom_line(ggplot2::aes(y = Observed, color = "Observed Data"),
                       linewidth = 0.8, na.rm = TRUE) +
    # Plot the realised test data over the forecast horizon
    ggplot2::geom_line(ggplot2::aes(y = Future, color = "Observed Forecast"),
                       linetype = "dotted", linewidth = 1, na.rm = TRUE) +
    # Plot average forecast
    ggplot2::geom_line(ggplot2::aes(y = AvgForecast, color = "Average Forecast"),
                       linewidth = 1, na.rm = TRUE) +
    # Plot the quantile band as a shaded area
    ggplot2::geom_ribbon(ggplot2::aes(ymin = QuantLow, ymax = QuantHigh,
                                      fill = "Confidence Interval"),
                         alpha = 0.45, na.rm = TRUE) +
    # Thin lines marking the band bounds
    ggplot2::geom_line(ggplot2::aes(y = QuantLow, color = "Confidence Interval"),
                       linewidth = 0.1, na.rm = TRUE) +
    ggplot2::geom_line(ggplot2::aes(y = QuantHigh, color = "Confidence Interval"),
                       linewidth = 0.1, na.rm = TRUE) +

    # Labels and theme adjustments
    ggplot2::labs(title = title, x = x_lab, y = y_lab) +

    # Every colour level used above needs a value. A level missing from a
    # named manual scale falls outside its limits and is drawn in na.value
    # (grey), or errors on older ggplot2.
    ggplot2::scale_color_manual(name = "Legend",
                                values = c("Observed Data" = observed_col,
                                           "Observed Forecast" = observed_col,
                                           "Average Forecast" = forecast_col,
                                           "Confidence Interval" = forecast_col)) +

    # Define fill color for the ribbon
    ggplot2::scale_fill_manual(name = "",
                               values = c("Confidence Interval" = forecast_col)) +

    ggplot2::theme_light() +

    # Hide the legend
    ggplot2::theme(legend.position = "none")
}

Try the STCCGEV package in your browser

Any scripts or data that you put into this service are public.

STCCGEV documentation built on April 4, 2025, 1:50 a.m.