inst/doc/loo2-lfo.R

params <-
list(EVAL = TRUE)

## ----SETTINGS-knitr, include=FALSE--------------------------------------------
stopifnot(require(knitr))
opts_chunk$set(
  comment=NA,
  eval = if (isTRUE(exists("params"))) params$EVAL else FALSE,
  dev = "png",
  dpi = 150,
  fig.asp = 0.618,
  fig.width = 5,
  out.width = "60%",
  fig.align = "center"
)

## ----more-knitr-ops, include=FALSE--------------------------------------------
knitr::opts_chunk$set(
  cache = TRUE,
  message = FALSE, 
  warning = FALSE
)

## ----pkgs, cache=FALSE--------------------------------------------------------
library("loo")
library("brms")
library("bayesplot")
library("ggplot2")
color_scheme_set("brightblue")
theme_set(theme_default())

CHAINS <- 4
SEED <- 5838296
set.seed(SEED)

## ----hurondata----------------------------------------------------------------
N <- length(LakeHuron)
df <- data.frame(
  y = as.numeric(LakeHuron),
  year = as.numeric(time(LakeHuron)),
  time = 1:N
)

ggplot(df, aes(x = year, y = y)) + 
  geom_point(size = 1) +
  labs(
    y = "Water Level (ft)", 
    x = "Year",
    title = "Water Level in Lake Huron (1875-1972)"
  ) 

## ----fit, results = "hide"----------------------------------------------------
fit <- brm(
  y ~ ar(time, p = 4), 
  data = df, 
  prior = prior(normal(0, 0.5), class = "ar"),
  control = list(adapt_delta = 0.99), 
  seed = SEED, 
  chains = CHAINS
)

## ----plotpreds, cache = FALSE-------------------------------------------------
preds <- posterior_predict(fit)
preds <- cbind(
  Estimate = colMeans(preds), 
  Q5 = apply(preds, 2, quantile, probs = 0.05),
  Q95 = apply(preds, 2, quantile, probs = 0.95)
)

ggplot(cbind(df, preds), aes(x = year, y = Estimate)) +
  geom_smooth(aes(ymin = Q5, ymax = Q95), stat = "identity", size = 0.5) +
  geom_point(aes(y = y)) + 
  labs(
    y = "Water Level (ft)", 
    x = "Year",
    title = "Water Level in Lake Huron (1875-1972)",
    subtitle = "Mean (blue) and 90% predictive intervals (gray) vs. observed data (black)"
  ) 

## ----setL---------------------------------------------------------------------
L <- 20

## ----loo1sap, cache = FALSE---------------------------------------------------
loo_cv <- loo(log_lik(fit)[, (L + 1):N])
print(loo_cv)

## ----exact_loglik, results="hide"---------------------------------------------
loglik_exact <- matrix(nrow = nsamples(fit), ncol = N)
for (i in L:(N - 1)) {
  past <- 1:i
  oos <- i + 1
  df_past <- df[past, , drop = FALSE]
  df_oos <- df[c(past, oos), , drop = FALSE]
  fit_i <- update(fit, newdata = df_past, recompile = FALSE)
  loglik_exact[, i + 1] <- log_lik(fit_i, newdata = df_oos, oos = oos)[, oos]
}

## ----helpers------------------------------------------------------------------
# some helper functions we'll use throughout

# more stable than log(sum(exp(x))) 
log_sum_exp <- function(x) {
  max_x <- max(x)  
  max_x + log(sum(exp(x - max_x)))
}

# more stable than log(mean(exp(x)))
log_mean_exp <- function(x) {
  log_sum_exp(x) - log(length(x))
}

# compute log of raw importance ratios
# sums over observations *not* over posterior samples
sum_log_ratios <- function(loglik, ids = NULL) {
  if (!is.null(ids)) loglik <- loglik[, ids, drop = FALSE]
  rowSums(loglik)
}

# for printing comparisons later
rbind_print <- function(...) {
  round(rbind(...), digits = 2)
}

## ----exact1sap, cache = FALSE-------------------------------------------------
exact_elpds_1sap <- apply(loglik_exact, 2, log_mean_exp)
exact_elpd_1sap <- c(ELPD = sum(exact_elpds_1sap[-(1:L)]))

rbind_print(
  "LOO" = loo_cv$estimates["elpd_loo", "Estimate"],
  "LFO" = exact_elpd_1sap
)

## ----setkthresh---------------------------------------------------------------
k_thres <- 0.7

## ----refit_loglik, results="hide"---------------------------------------------
approx_elpds_1sap <- rep(NA, N)

# initialize the process for i = L
past <- 1:L
oos <- L + 1
df_past <- df[past, , drop = FALSE]
df_oos <- df[c(past, oos), , drop = FALSE]
fit_past <- update(fit, newdata = df_past, recompile = FALSE)
loglik <- log_lik(fit_past, newdata = df_oos, oos = oos)
approx_elpds_1sap[L + 1] <- log_mean_exp(loglik[, oos])

# iterate over i > L
i_refit <- L
refits <- L
ks <- NULL
for (i in (L + 1):(N - 1)) {
  past <- 1:i
  oos <- i + 1
  df_past <- df[past, , drop = FALSE]
  df_oos <- df[c(past, oos), , drop = FALSE]
  loglik <- log_lik(fit_past, newdata = df_oos, oos = oos)
  
  logratio <- sum_log_ratios(loglik, (i_refit + 1):i)
  psis_obj <- suppressWarnings(psis(logratio))
  k <- pareto_k_values(psis_obj)
  ks <- c(ks, k)
  if (k > k_thres) {
    # refit the model based on the first i observations
    i_refit <- i
    refits <- c(refits, i)
    fit_past <- update(fit_past, newdata = df_past, recompile = FALSE)
    loglik <- log_lik(fit_past, newdata = df_oos, oos = oos)
    approx_elpds_1sap[i + 1] <- log_mean_exp(loglik[, oos])
  } else {
    lw <- weights(psis_obj, normalize = TRUE)[, 1]
    approx_elpds_1sap[i + 1] <- log_sum_exp(lw + loglik[, oos])
  }
} 

## ----plot_ks------------------------------------------------------------------
plot_ks <- function(ks, ids, thres = 0.6) {
  dat_ks <- data.frame(ks = ks, ids = ids)
  ggplot(dat_ks, aes(x = ids, y = ks)) + 
    geom_point(aes(color = ks > thres), shape = 3, show.legend = FALSE) + 
    geom_hline(yintercept = thres, linetype = 2, color = "red2") + 
    scale_color_manual(values = c("cornflowerblue", "darkblue")) + 
    labs(x = "Data point", y = "Pareto k") + 
    ylim(-0.5, 1.5)
}

## ----refitsummary1sap, cache=FALSE--------------------------------------------
cat("Using threshold ", k_thres, 
    ", model was refit ", length(refits), 
    " times, at observations", refits)

plot_ks(ks, (L + 1):(N - 1))

## ----lfosummary1sap, cache = FALSE--------------------------------------------
approx_elpd_1sap <- sum(approx_elpds_1sap, na.rm = TRUE)
rbind_print(
  "approx LFO" = approx_elpd_1sap,
  "exact LFO" = exact_elpd_1sap
)

## ----plot1sap, cache = FALSE--------------------------------------------------
dat_elpd <- data.frame(
  approx_elpd = approx_elpds_1sap,
  exact_elpd = exact_elpds_1sap
)

ggplot(dat_elpd, aes(x = approx_elpd, y = exact_elpd)) +
  geom_abline(color = "gray30") +
  geom_point(size = 2) +
  labs(x = "Approximate ELPDs", y = "Exact ELPDs")

## ----diffs1sap, cache=FALSE---------------------------------------------------
max_diff <- with(dat_elpd, max(abs(approx_elpd - exact_elpd), na.rm = TRUE))
mean_diff <- with(dat_elpd, mean(abs(approx_elpd - exact_elpd), na.rm = TRUE))

rbind_print(
  "Max diff" = round(max_diff, 2), 
  "Mean diff" =  round(mean_diff, 3)
)

## ----exact_loglikm, results="hide"--------------------------------------------
M <- 4
loglikm <- matrix(nrow = nsamples(fit), ncol = N)
for (i in L:(N - M)) {
  past <- 1:i
  oos <- (i + 1):(i + M)
  df_past <- df[past, , drop = FALSE]
  df_oos <- df[c(past, oos), , drop = FALSE]
  fit_past <- update(fit, newdata = df_past, recompile = FALSE)
  loglik <- log_lik(fit_past, newdata = df_oos, oos = oos)
  loglikm[, i + 1] <- rowSums(loglik[, oos])
}

## ----exact4sap, cache = FALSE-------------------------------------------------
exact_elpds_4sap <- apply(loglikm, 2, log_mean_exp)
(exact_elpd_4sap <- c(ELPD = sum(exact_elpds_4sap, na.rm = TRUE)))

## ----refit_loglikm, results="hide"--------------------------------------------
approx_elpds_4sap <- rep(NA, N)

# initialize the process for i = L
past <- 1:L
oos <- (L + 1):(L + M)
df_past <- df[past, , drop = FALSE]
df_oos <- df[c(past, oos), , drop = FALSE]
fit_past <- update(fit, newdata = df_past, recompile = FALSE)
loglik <- log_lik(fit_past, newdata = df_oos, oos = oos)
loglikm <- rowSums(loglik[, oos])
approx_elpds_1sap[L + 1] <- log_mean_exp(loglikm)

# iterate over i > L
i_refit <- L
refits <- L
ks <- NULL
for (i in (L + 1):(N - M)) {
  past <- 1:i
  oos <- (i + 1):(i + M)
  df_past <- df[past, , drop = FALSE]
  df_oos <- df[c(past, oos), , drop = FALSE]
  loglik <- log_lik(fit_past, newdata = df_oos, oos = oos)
  
  logratio <- sum_log_ratios(loglik, (i_refit + 1):i)
  psis_obj <- suppressWarnings(psis(logratio))
  k <- pareto_k_values(psis_obj)
  ks <- c(ks, k)
  if (k > k_thres) {
    # refit the model based on the first i observations
    i_refit <- i
    refits <- c(refits, i)
    fit_past <- update(fit_past, newdata = df_past, recompile = FALSE)
    loglik <- log_lik(fit_past, newdata = df_oos, oos = oos)
    loglikm <- rowSums(loglik[, oos])
    approx_elpds_4sap[i + 1] <- log_mean_exp(loglikm)
  } else {
    lw <- weights(psis_obj, normalize = TRUE)[, 1]
    loglikm <- rowSums(loglik[, oos])
    approx_elpds_4sap[i + 1] <- log_sum_exp(lw + loglikm)
  }
} 

## ----refitsummary4sap, cache = FALSE------------------------------------------
cat("Using threshold ", k_thres, 
    ", model was refit ", length(refits), 
    " times, at observations", refits)

plot_ks(ks, (L + 1):(N - M))

## ----lfosummary4sap, cache = FALSE--------------------------------------------
approx_elpd_4sap <- sum(approx_elpds_4sap, na.rm = TRUE)
rbind_print(
  "Approx LFO" = approx_elpd_4sap,
  "Exact LFO" = exact_elpd_4sap
)

## ----plot4sap, cache = FALSE--------------------------------------------------
dat_elpd_4sap <- data.frame(
  approx_elpd = approx_elpds_4sap,
  exact_elpd = exact_elpds_4sap
)

ggplot(dat_elpd_4sap, aes(x = approx_elpd, y = exact_elpd)) +
  geom_abline(color = "gray30") +
  geom_point(size = 2) +
  labs(x = "Approximate ELPDs", y = "Exact ELPDs")

## ----sessioninfo--------------------------------------------------------------
sessionInfo()

Try the loo package in your browser

Any scripts or data that you put into this service are public.

loo documentation built on March 31, 2023, 10:11 p.m.