SBC was primarily designed for continuous parameters, but it can also be used with models that have discrete parameters - whether those parameters are represented directly (e.g. in BUGS/JAGS) or marginalized out of the likelihood (as is usual in Stan).

Stan version and debugging

library(SBC)
library(ggplot2)

use_cmdstanr <- getOption("SBC.vignettes_cmdstanr", TRUE) # Set to FALSE to use rstan instead

if(use_cmdstanr) {
  library(cmdstanr)
} else {
  library(rstan)
  rstan_options(auto_write = TRUE)
}

# Multiprocessing support
library(future)
plan(multisession)

# The fits are very fast and we fit just a few simulations,
# so we force a minimum chunk size to reduce the overhead of
# parallelization and decrease total computation time.
options(SBC.min_chunk_size = 5)

# Setup caching of results
if(use_cmdstanr) {
  cache_dir <- "./_discrete_vars_SBC_cache"
} else {
  cache_dir <- "./_discrete_vars_rstan_SBC_cache"
}
cache_dir_jags <- "./_discrete_vars_SBC_cache"

if(!dir.exists(cache_dir)) {
  dir.create(cache_dir)
}
if(!dir.exists(cache_dir_jags)) {
  dir.create(cache_dir_jags)
}

We take the changepoint model from the Stan User's Guide: https://mc-stan.org/docs/2_26/stan-users-guide/change-point-section.html

cat(readLines("stan/discrete_vars1.stan"), sep = "\n")
if(use_cmdstanr) {
  model_1 <- cmdstan_model("stan/discrete_vars1.stan")
  backend_1 <- SBC_backend_cmdstan_sample(model_1)
} else {
  model_1 <- stan_model("stan/discrete_vars1.stan")
  backend_1 <- SBC_backend_rstan_sample(model_1)
}

Now, let's generate data from the model.

generate_single_sim_1 <- function(T, r_e, r_l) {
  e <- rexp(1, r_e)             # early rate (exponential prior)
  l <- rexp(1, r_l)             # late rate (exponential prior)
  s <- sample.int(T, size = 1)  # change point (uniform over 1..T)

  y <- array(NA_real_, T)
  for(t in 1:T) {
    if(t <= s) {
      rate <- e
    } else {
      rate <- l
    }
    y[t] <- rpois(1, rate) 
  }

  list(
    variables = list(
      e = e, l = l, s = s
    ), generated = list(
      T = T,
      r_e = r_e,
      r_l = r_l,
      y = y
    )
  )
}

generator_1 <- SBC_generator_function(generate_single_sim_1, T = 5, r_e = 0.5, r_l = 0.1)
set.seed(85394672)
datasets_1 <- generate_datasets(generator_1, 30)
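
It can be worth peeking at a single simulated dataset before fitting. The sketch below assumes the structure of the SBC_datasets object in current package versions ($variables is a matrix of prior draws with one row per simulation, $generated a list of Stan-style data lists):

datasets_1$variables[1, ]  # prior draw of e, l and s for the first simulation
datasets_1$generated[[1]]  # the matching data list (T, r_e, r_l, y)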

Additionally, we'll add a derived quantity expressing the total log-likelihood of the data given the fitted parameters. The expression within the derived_quantities() call is evaluated for both prior and posterior draws and included as another variable in the SBC checks. It turns out that this type of derived quantity can increase the sensitivity of SBC to some issues in the model. See vignette("limits_of_SBC") for a more detailed discussion.

log_lik_dq <- derived_quantities(log_lik = sum(dpois(y, ifelse(1:T < s, e, l), log = TRUE)))
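
To make the definition concrete, here is the same expression evaluated by hand for the first simulated dataset. This is only an illustration, not how compute_SBC() evaluates it internally, and it again assumes the $variables / $generated structure noted above:

# Merge the data list with the corresponding prior draws and evaluate
# the log-likelihood expression once, manually.
vars_1 <- list(
  e = posterior::extract_variable(datasets_1$variables, "e")[1],
  l = posterior::extract_variable(datasets_1$variables, "l")[1],
  s = posterior::extract_variable(datasets_1$variables, "s")[1]
)
with(c(datasets_1$generated[[1]], vars_1),
     sum(dpois(y, ifelse(1:T < s, e, l), log = TRUE)))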

So finally, let's actually compute SBC:

results_1 <- compute_SBC(datasets_1, backend_1, 
                    cache_mode = "results", 
                    cache_location = file.path(cache_dir, "model1"),
                    dquants = log_lik_dq)

Here we also use the caching feature to avoid recomputing the fits when recompiling this vignette. In practice, caching is not necessary but is often useful.

TODO: the diagnostic failures are false positives, because Rhat and ESS do not work very well for discrete parameters. We need to figure out how to handle this better.
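
One plausible mechanism behind such false positives (a hedged illustration of the general issue, not a statement about the package internals): when the sampler is effectively certain about s, every posterior draw of s is identical, and Rhat is undefined for a constant sequence.

# When all draws are a single value, posterior::rhat() has nothing to
# compare across chains and returns NA, which then surfaces as a
# diagnostic failure.
constant_draws <- matrix(rep(2L, 400), ncol = 4)  # 4 "chains" of draws of s
posterior::rhat(constant_draws)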

We can quickly note that the statistics for the s parameter are extreme - many ranks of 0 and extreme z-scores, including -Infinity. Seeing just one or two such fits should be enough to convince us that something is fundamentally wrong.

dplyr::filter(results_1$stats, variable == "s") 

Inspecting the statistics shows that the model is often very sure of the value of s while the simulated value is just one less.
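
We can pull those cases out directly - a hedged sketch that assumes the stats data frame exposes rank, simulated_value and mean columns (true for current SBC versions):

# Fits where the simulated s fell below every posterior draw (rank 0);
# the posterior mean should sit roughly one above the simulated value.
extreme_s <- dplyr::filter(results_1$stats, variable == "s", rank == 0)
dplyr::mutate(extreme_s, shift = mean - simulated_value)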

Looking at the ecdf_diff plot, we see that this seems to heavily compromise the inference for s, while the other parameters do not show such bad behaviour. Note that the log_lik derived quantity shows an even starker failure than s, so it indeed poses a stricter check in this scenario.

plot_ecdf_diff(results_1)
plot_rank_hist(results_1)

An important note: you may wonder how we got such a wiggly line for the s parameter - doesn't it have just 5 possible values? Shouldn't the ECDF therefore be one big staircase? In fact, the package does a little trick to make discrete parameters comparable to continuous ones: the rank of a discrete parameter is chosen uniformly at random from among all possible ranks (i.e. from the positions within the block of posterior draws that have exactly equal value). This means that if the model is well behaved, the ranks of a discrete parameter will be uniformly distributed across the whole range of possible ranks, and we can use exactly the same diagnostics for discrete parameters as for continuous ones.
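
Below is a minimal sketch of that tie-breaking idea (the package's actual internals may differ in details): count the posterior draws strictly smaller than the simulated value, count the draws exactly equal to it, and pick one of the admissible ranks uniformly at random.

# Randomized rank of a simulated value among posterior draws with ties.
rank_with_ties <- function(posterior_draws, simulated_value) {
  n_less <- sum(posterior_draws < simulated_value)
  n_equal <- sum(posterior_draws == simulated_value)
  # choose uniformly among the n_equal + 1 admissible ranks
  n_less + sample.int(n_equal + 1, size = 1) - 1
}

# If the simulated value and the draws share the same distribution,
# the resulting rank is uniform on 0..length(posterior_draws).
draws <- sample(1:5, 100, replace = TRUE)
rank_with_ties(draws, simulated_value = sample(1:5, 1))

With no ties this reduces to the usual SBC rank (the number of draws below the simulated value), so continuous parameters are unaffected.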

But back to the model - what happened? What is wrong with it? After some inspection, you may notice that the simulator does not match the model: the model uses the early rate (e) for points t < s, while the simulator uses e for points t <= s, so there is effectively a shift of one time point between the simulator and the model. Let's assume we believe that the Stan model is in fact right. We therefore update the simulator to match the model:

generate_single_sim_2 <- function(T, r_e, r_l) {
  e <- rexp(1, r_e)
  l <- rexp(1, r_l)
  s <- sample.int(T, size = 1)

  y <- array(NA_real_, T)
  for(t in 1:T) {
    if(t < s) { ### <--- Only change here
      rate <- e
    } else {
      rate <- l
    }
    y[t] <- rpois(1, rate) 
  }

  list(
    variables = list(
      e = e, l = l, s = s
    ), generated = list(
      T = T,
      r_e = r_e,
      r_l = r_l,
      y = y
    )
  )
}

generator_2 <- SBC_generator_function(generate_single_sim_2, T = 5, r_e = 0.5, r_l = 0.1)

And we can recompute:

set.seed(5846502)
datasets_2 <- generate_datasets(generator_2, 30)
results_2 <- compute_SBC(datasets_2, backend_1,
                    dquants = log_lik_dq, 
                    cache_mode = "results", 
                    cache_location = file.path(cache_dir, "model2"))
plot_rank_hist(results_2)
plot_ecdf_diff(results_2)

Looks good, so let us add some more simulations to make sure the model behaves well.

set.seed(54321488)
datasets_2_more <- generate_datasets(generator_2, 100)
results_2_more <- compute_SBC(datasets_2_more, backend_1,
                    dquants = log_lik_dq, 
                    cache_mode = "results", 
                    cache_location = file.path(cache_dir, "model3"))

results_2_all <- bind_results(results_2, results_2_more)
plot_rank_hist(results_2_all)
plot_ecdf_diff(results_2_all)

Now - as far as this number of simulations can see, the model is good: we get good behaviour for both the continuous and the discrete parameters as well as for the log_lik derived quantity. Hooray!
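
As an extra numerical check, we can look at the empirical coverage of central posterior intervals, which should be close to the nominal width for a calibrated model. This is a hedged usage sketch: it assumes the empirical_coverage(stats, width) interface of current SBC versions.

# Coverage of 50%, 80% and 95% central intervals across all variables;
# values close to the nominal widths support calibration.
empirical_coverage(results_2_all$stats, width = c(0.5, 0.8, 0.95))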

JAGS version

We can now write the same model in JAGS. This becomes a bit easier as JAGS lets us represent discrete parameters directly:

cat(readLines("other_models/changepoint.jags"), sep = "\n")

We will use the rjags package; let us first verify that it is installed correctly.

library(rjags)

We will also default to a relatively large number of samples, as we can expect some autocorrelation from the Gibbs sampler.

backend_jags <- SBC_backend_rjags("other_models/changepoint.jags",
                                  variable.names = c("e","l","s"),
                                  n.iter = 10000,
                                  n.burnin = 1000,
                                  n.chains = 4,
                                  thin = 10)

Running SBC with all the corrected datasets from before (rjags accepts input data in exactly the same format as Stan, so we can reuse the datasets without any change):

datasets_2_all <- bind_datasets(datasets_2, datasets_2_more)
results_jags <- compute_SBC(datasets_2_all, backend_jags,
                            dquants = log_lik_dq,
                        cache_mode = "results",
                        cache_location = file.path(cache_dir_jags, "rjags"))

As in the Stan case above, the Rhat and ESS warnings are false positives due to the discrete s parameter, which we need to handle better.

The rank plots show no problems.

plot_rank_hist(results_jags)
plot_ecdf_diff(results_jags)

As an exercise, we can also write the marginalized version of the model in JAGS. In some cases marginalization improves performance even for JAGS models; for this model, however, it is actually not an improvement, presumably because the model is very simple.

cat(readLines("other_models/changepoint_marginalized.jags"), sep = "\n")

The code got quite a bit more complex, so let's check that we didn't mess up the rewrite. First we build a backend with this new representation:

backend_jags_marginalized <- SBC_backend_rjags("other_models/changepoint_marginalized.jags",
                                  variable.names = c("e","l","s"),
                                  n.iter = 10000,
                                  n.burnin = 1000,
                                  n.chains = 4,
                                  thin = 10)

Then we run the actual SBC:

results_jags_marginalized <- compute_SBC(datasets_2_all, backend_jags_marginalized,
                                         dquants = log_lik_dq,
                        cache_mode = "results",
                        cache_location = file.path(cache_dir_jags, "rjags_marginalized"))

And the rank plots look good, so we probably did succeed in correctly marginalizing the s parameter!

plot_rank_hist(results_jags_marginalized)
plot_ecdf_diff(results_jags_marginalized)
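
Finally, to put a number on the performance claim, we can compare effective sample sizes between the two JAGS backends - a hedged sketch that assumes the stats data frame carries an ess_bulk column (true for current SBC versions):

# Similar distributions of ess_bulk across fits would support the claim
# that marginalization brings no real efficiency gain for this model.
summary(results_jags$stats$ess_bulk)
summary(results_jags_marginalized$stats$ess_bulk)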

