predict.mcmc_output: Predictions for State Space Models

View source: R/predict.R

predict.mcmc_output    R Documentation

Predictions for State Space Models

Description

Draw samples from the posterior predictive distribution for future time points, given the posterior draws of the hyperparameters \theta and the latent state \alpha_{n+1} returned by run_mcmc. The function can also be used to draw samples from the posterior predictive distribution of the observed time points, p(\tilde y_1, \ldots, \tilde y_n \mid y_1, \ldots, y_n).
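As a rough illustration of what such draws are, the following base-R sketch assumes a hypothetical local level model (y_t = \alpha_t + \epsilon_t, \alpha_{t+1} = \alpha_t + \eta_t) and uses synthetic stand-ins for the posterior draws of \theta = (sd_y, sd_level) and \alpha_{n+1}. Each posterior draw is propagated forward with its own parameter values (composition sampling). The function simulate_future is hypothetical and is not how bssm implements predict; its three components only loosely correspond to type = "state", "mean", and "response".

```r
# Hypothetical sketch (not bssm code): posterior predictive simulation for
#   y_t         = alpha_t + eps_t,  eps_t ~ N(0, sd_y^2)
#   alpha_{t+1} = alpha_t + eta_t,  eta_t ~ N(0, sd_level^2)
simulate_future <- function(theta, alpha_next, h) {
  nsim <- nrow(theta)
  state <- mean_y <- response <- matrix(NA_real_, h, nsim)
  for (i in seq_len(nsim)) {
    alpha <- alpha_next[i]  # posterior draw of alpha_{n+1}
    for (t in seq_len(h)) {
      state[t, i] <- alpha
      mean_y[t, i] <- alpha  # Z = 1, so the mean of y_t equals the state
      response[t, i] <- alpha + rnorm(1, 0, theta[i, "sd_y"])  # add observation noise
      alpha <- alpha + rnorm(1, 0, theta[i, "sd_level"])      # propagate the state
    }
  }
  list(state = state, mean = mean_y, response = response)
}

# Synthetic stand-ins for posterior draws (NOT output of run_mcmc)
set.seed(1)
nsim <- 200
theta <- cbind(sd_y = runif(nsim, 0.05, 0.15),
               sd_level = runif(nsim, 0.01, 0.05))
alpha_next <- rnorm(nsim, mean = 1, sd = 0.05)
draws <- simulate_future(theta, alpha_next, h = 8)
```

Each column of draws$response is one joint draw of the next 8 observations; the spread across columns reflects both parameter and state uncertainty.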

Usage

## S3 method for class 'mcmc_output'
predict(
  object,
  model,
  nsim,
  type = "response",
  future = TRUE,
  seed = sample(.Machine$integer.max, size = 1),
  ...
)

Arguments

object

Results object of class mcmc_output from run_mcmc.

model

A bssm_model object. Should have the same structure and class as the original model used in run_mcmc, so that the posterior samples of the model parameters are plugged into the right places. It is also possible to supply the original model to obtain predictions for past time points; in this case, set the argument future to FALSE.

nsim

Positive integer defining the number of samples to draw. Should be less than or equal to sum(object$counts), i.e. the number of samples in the MCMC output. Default is to use all the samples.

type

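Type of predictions. Possible choices are "mean", "response", and "state": predictions for the conditional mean of the observations, for new observations including the observation noise, or for the latent states, respectively.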

future

Default is TRUE, in which case predictions are for future time points, using the posterior samples of (\theta, \alpha_{n+1}), i.e. the posterior samples of the hyperparameters and of the state at time n + 1. Otherwise it is assumed that model corresponds to the original model.

seed

Seed for the C++ RNG (positive integer). Note that this affects only the C++ side; predict also uses the R RNG for subsampling, so for reproducible results you should call set.seed before predict.

...

Ignored.
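The interplay of nsim, sum(object$counts) and the R-side RNG can be sketched in base R. This assumes that the MCMC output stores unique draws together with multiplicities counts (a jump-chain representation, as suggested by sum(object$counts) in the nsim description) and that subsampling is done without replacement; both are assumptions, and subsample_draws is a hypothetical helper, not part of bssm.

```r
# Hypothetical helper: choose nsim posterior draws from a jump-chain
# representation, where unique draw j was visited counts[j] times.
subsample_draws <- function(counts, nsim = sum(counts)) {
  total <- sum(counts)
  if (nsim < 1 || nsim > total) {
    stop("nsim must be between 1 and sum(counts)")
  }
  expanded <- rep(seq_along(counts), times = counts)  # one entry per MCMC iteration
  expanded[sample.int(total, nsim)]                   # R RNG, without replacement
}

counts <- c(3, 1, 4, 2)  # toy multiplicities: 10 iterations in total
set.seed(42)
idx1 <- subsample_draws(counts, nsim = 5)
set.seed(42)
idx2 <- subsample_draws(counts, nsim = 5)
identical(idx1, idx2)  # TRUE: same R seed, same subsample
```

Because this step uses the R RNG, only set.seed (not the seed argument) makes it reproducible, which is why both are needed for fully replicable predictions.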

Value

A data.frame consisting of samples from the posterior predictive distribution.
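The Examples summarise this data.frame with dplyr. A base-R alternative is sketched here, assuming the long format used in those examples (columns time, sample, variable and value); pred below is a synthetic stand-in, not real output. With type = "state" there are several variables, so the rows would first need to be restricted to one variable.

```r
# Synthetic stand-in for the returned data.frame (long format, one variable)
set.seed(123)
pred <- expand.grid(sample = 1:500, time = 1975 + (0:3) / 4)
pred$variable <- "y"
pred$value <- rnorm(nrow(pred), mean = 1 + pred$time - 1975, sd = 0.1)

# Posterior predictive mean and central 95% interval for each time point
by_time <- split(pred$value, pred$time)  # list of draws, one element per time point
sumr <- data.frame(
  time = as.numeric(names(by_time)),
  mean = vapply(by_time, mean, numeric(1)),
  lwr  = vapply(by_time, quantile, numeric(1), probs = 0.025, names = FALSE),
  upr  = vapply(by_time, quantile, numeric(1), probs = 0.975, names = FALSE),
  row.names = NULL
)
sumr
```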

See Also

fitted for in-sample predictions.

Examples

library("bssm")
library("graphics")
y <- log10(JohnsonJohnson)
prior <- uniform(0.01, 0, 1)
model <- bsm_lg(window(y, end = c(1974, 4)), sd_y = prior,
  sd_level = prior, sd_slope = prior, sd_seasonal = prior)

mcmc_results <- run_mcmc(model, iter = 5000)
future_model <- model
future_model$y <- ts(rep(NA, 25), 
  start = tsp(model$y)[2] + 2 * deltat(model$y), 
  frequency = frequency(model$y))
# Use type = "state" for illustration; type = "mean" would give the signal directly
pred <- predict(mcmc_results, model = future_model, type = "state", 
  nsim = 1000)

library("dplyr")
sumr_fit <- as.data.frame(mcmc_results, variable = "states") |>
  group_by(time, iter) |> 
  mutate(signal = 
      value[variable == "level"] + 
      value[variable == "seasonal_1"]) |>
  group_by(time) |>
  summarise(mean = mean(signal), 
    lwr = quantile(signal, 0.025), 
    upr = quantile(signal, 0.975))

sumr_pred <- pred |> 
  group_by(time, sample) |>
  mutate(signal = 
      value[variable == "level"] + 
      value[variable == "seasonal_1"]) |>
  group_by(time) |>
  summarise(mean = mean(signal),
    lwr = quantile(signal, 0.025), 
    upr = quantile(signal, 0.975)) 
    
# If we used type = "mean", we could do
# sumr_pred <- pred |> 
#   group_by(time) |>
#   summarise(mean = mean(value),
#     lwr = quantile(value, 0.025), 
#     upr = quantile(value, 0.975)) 
    
library("ggplot2")
rbind(sumr_fit, sumr_pred) |> 
  ggplot(aes(x = time, y = mean)) + 
  geom_ribbon(aes(ymin = lwr, ymax = upr), 
   fill = "#92f0a8", alpha = 0.25) +
  geom_line(colour = "#92f0a8") +
  theme_bw() + 
  geom_point(data = data.frame(
    mean = log10(JohnsonJohnson), 
    time = time(JohnsonJohnson)))

# Posterior predictions for past observations:
yrep <- predict(mcmc_results, model = model, type = "response", 
  future = FALSE, nsim = 1000)
meanrep <- predict(mcmc_results, model = model, type = "mean", 
  future = FALSE, nsim = 1000)
  
sumr_yrep <- yrep |> 
  group_by(time) |>
  summarise(earnings = mean(value),
    lwr = quantile(value, 0.025), 
    upr = quantile(value, 0.975)) |>
  mutate(interval = "Observations")

sumr_meanrep <- meanrep |> 
  group_by(time) |>
  summarise(earnings = mean(value),
    lwr = quantile(value, 0.025), 
    upr = quantile(value, 0.975)) |>
  mutate(interval = "Mean")
    
rbind(sumr_meanrep, sumr_yrep) |> 
  mutate(interval = 
    factor(interval, levels = c("Observations", "Mean"))) |>
  ggplot(aes(x = time, y = earnings)) + 
  geom_ribbon(aes(ymin = lwr, ymax = upr, fill = interval), 
   alpha = 0.75) +
  theme_bw() + 
  geom_point(data = data.frame(
    earnings = model$y, 
    time = time(model$y)))    
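One simple posterior predictive check is the share of observations that fall inside the central interval of their replicated draws. The helper interval_coverage below is hypothetical (not part of bssm) and is exercised on synthetic draws. With the objects above, a call such as interval_coverage(as.numeric(model$y), split(yrep$value, yrep$time)) could apply it, assuming yrep contains a single variable and split orders the time points as in model$y.

```r
# Hypothetical helper: proportion of observations y[t] lying inside the
# central `level` interval of their posterior predictive draws draws[[t]]
interval_coverage <- function(y, draws, level = 0.95) {
  probs <- c((1 - level) / 2, (1 + level) / 2)
  bounds <- vapply(draws, quantile, numeric(2), probs = probs, names = FALSE)  # 2 x n
  mean(y >= bounds[1, ] & y <= bounds[2, ])
}

# Synthetic example: 40 time points, 1000 replicated draws each
set.seed(2023)
n <- 40
y <- rnorm(n)
draws <- lapply(seq_len(n), function(i) rnorm(1000))
interval_coverage(y, draws)
```

Coverage far below the nominal level suggests the model underestimates uncertainty; coverage near 1 for narrow-looking data may indicate overly wide predictive intervals.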



bssm documentation built on Nov. 2, 2023, 6:25 p.m.