The following examples walk through using chkptstanr with Stan.
The basic idea is to (1) write a custom Stan model (done by the user), (2) fit the model with cmdstanr (with the desired number of checkpoints), and then (3) return a cmdstanr object.
All steps except (1) are done internally, so the workflow is very similar to using cmdstanr.
library(chkptstanr)
library(posterior)
library(bayesplot)
The initial overhead is to create a folder that will store the checkpoints, i.e.,
path <- create_folder(folder_name = "chkpt_folder_m1")
Next is the Stan model:
stan_code <- "
data {
  int<lower=0> n;                 // number of groups
  array[n] real y;                // observed estimates
  array[n] real<lower=0> sigma;   // standard errors of the estimates
}
parameters {
  real mu;
  real<lower=0> tau;
  vector[n] eta;
}
transformed parameters {
  vector[n] theta;
  theta = mu + tau * eta;
}
model {
  target += normal_lpdf(eta | 0, 1);
  target += normal_lpdf(y | theta, sigma);
}
"
chkpt_stan() requires a list to be supplied to the data argument, much like rstan.
stan_data <- schools.data <- list(
  n = 8,
  y = c(28, 8, -3, 7, -1, 1, 18, 12),
  sigma = c(15, 10, 16, 11, 9, 11, 10, 18)
)
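Because the Stan program declares y and sigma with length n, a quick check of the list before fitting can catch mismatched lengths early. The following is a minimal base R sketch; the helper name check_schools_data is hypothetical and is not part of chkptstanr.

```r
# Hypothetical helper (not part of chkptstanr): verify that the data list
# is consistent with the declarations in the Stan data block.
check_schools_data <- function(d) {
  stopifnot(
    is.list(d),
    all(c("n", "y", "sigma") %in% names(d)),
    length(d$y) == d$n,       # y is declared with length n
    length(d$sigma) == d$n,   # sigma is declared with length n
    all(d$sigma >= 0)         # sigma has a lower bound of 0
  )
  invisible(TRUE)
}

stan_data <- list(
  n = 8,
  y = c(28, 8, -3, 7, -1, 1, 18, 12),
  sigma = c(15, 10, 16, 11, 9, 11, 10, 18)
)
check_schools_data(stan_data)
```

The check stops with an error if any condition fails, so a silent return means the list matches the data block.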
To show the basic idea of checkpointing, the following run was stopped after 2 checkpoints.
fit_m1 <- chkpt_stan(model_code = stan_code,
                     data = stan_data,
                     iter_warmup = 1000,
                     iter_sampling = 1000,
                     iter_per_chkpt = 250,
                     path = path)
#> Compiling Stan program...
#> Initial Warmup (Typical Set)
#> Chkpt: 1 / 8; Iteration: 250 / 2000 (warmup)
#> Chkpt: 2 / 8; Iteration: 500 / 2000 (warmup)
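The checkpoint counts in this output follow from the iteration settings. As a quick check of the arithmetic, here is a small base R sketch; it assumes, as in this example, that iter_warmup and iter_sampling are both multiples of iter_per_chkpt.

```r
iter_warmup    <- 1000
iter_sampling  <- 1000
iter_per_chkpt <- 250

# Total number of checkpoints across warmup and sampling
n_chkpt <- (iter_warmup + iter_sampling) / iter_per_chkpt

# Last iteration covered by each checkpoint, and the phase it falls in
last_iter <- seq_len(n_chkpt) * iter_per_chkpt
phase <- ifelse(last_iter <= iter_warmup, "warmup", "sample")

data.frame(chkpt = seq_len(n_chkpt), last_iter = last_iter, phase = phase)
```

With these settings there are 8 checkpoints, the first 4 in warmup and the last 4 in sampling.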
To finish the remaining 6 checkpoints, run the same code again, i.e.,
fit_m1 <- chkpt_stan(model_code = stan_code,
                     data = stan_data,
                     iter_warmup = 1000,
                     iter_sampling = 1000,
                     iter_per_chkpt = 250,
                     path = path)
#> Sampling next checkpoint
#> Chkpt: 3 / 8; Iteration: 750 / 2000 (warmup)
#> Chkpt: 4 / 8; Iteration: 1000 / 2000 (warmup)
#> Chkpt: 5 / 8; Iteration: 1250 / 2000 (sample)
#> Chkpt: 6 / 8; Iteration: 1500 / 2000 (sample)
#> Chkpt: 7 / 8; Iteration: 1750 / 2000 (sample)
#> Chkpt: 8 / 8; Iteration: 2000 / 2000 (sample)
#> Checkpointing complete
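The stop-and-resume pattern can be illustrated with a toy example in base R. This is only a sketch of the general idea and is not how chkptstanr is implemented: run_with_checkpoints is a hypothetical helper, the file names are invented, and a trivial computation stands in for a chunk of sampling.

```r
# Toy illustration of stop-and-resume; NOT chkptstanr's implementation.
# Each checkpoint is saved as an .rds file, and a rerun skips checkpoints
# whose files already exist in the folder.
run_with_checkpoints <- function(path, n_chkpt, stop_after = n_chkpt) {
  dir.create(path, showWarnings = FALSE, recursive = TRUE)
  done <- 0
  for (i in seq_len(n_chkpt)) {
    file <- file.path(path, sprintf("chkpt_%02d.rds", i))
    if (file.exists(file)) next      # finished in an earlier run
    if (done >= stop_after) break    # simulate an interruption
    saveRDS(i^2, file)               # stand-in for one chunk of sampling
    done <- done + 1
  }
  # Number of checkpoints stored so far
  length(list.files(path, pattern = "^chkpt_.*\\.rds$"))
}

toy_path <- file.path(tempdir(), "toy_chkpt")
unlink(toy_path, recursive = TRUE)
run_with_checkpoints(toy_path, n_chkpt = 8, stop_after = 2)  # interrupted
run_with_checkpoints(toy_path, n_chkpt = 8)                  # resumes
```

The second call does not redo the first two checkpoints; it picks up at the third because the earlier files are already on disk.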
Each sampling checkpoint contains 250 post-warmup draws per chain. These need to be combined with combine_chkpt_draws(), i.e.,
draws <- combine_chkpt_draws(fit_m1)
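Conceptually, combining checkpoints amounts to stacking chunks of draws along the iteration dimension. The sketch below shows that idea with the posterior package on simulated values. It is an assumption-laden illustration and not necessarily how combine_chkpt_draws() works internally; the chunks are random numbers, not output from a fitted model.

```r
library(posterior)

set.seed(1)
# Four simulated "checkpoints", each holding 250 iterations for 2 chains.
# The values are random numbers used purely for illustration.
chunks <- lapply(1:4, function(i) {
  draws_array(mu = rnorm(500), tau = abs(rnorm(500)), .nchains = 2)
})

# Stack the chunks along the iteration dimension
combined <- do.call(bind_draws, c(chunks, list(along = "iteration")))

niterations(combined)
nchains(combined)
```

Four chunks of 250 iterations each give a single draws_array with 1000 iterations and 2 chains, the same shape that combine_chkpt_draws() reports for this example.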
We developed chkptstanr to work seamlessly with the Stan ecosystem.
The object draws has been constructed to mimic what is provided when using cmdstanr directly.
combine_chkpt_draws(fit_m1)
#> # A draws_array: 1000 iterations, 2 chains, and 19 variables
#> , , variable = lp__
#>
#>          chain
#> iteration   1   2
#>         1 -34 -43
#>         2 -37 -41
#>         3 -36 -39
#>         4 -38 -38
#>         5 -38 -41
#>
#> , , variable = mu
#>
#>          chain
#> iteration    1    2
#>         1  5.2  2.6
#>         2 11.3  6.7
#>         3 -2.7  5.3
#>         4 -2.9  3.7
#>         5 -2.7 14.2
#>
#> , , variable = tau
#>
#>          chain
#> iteration    1     2
#>         1 23.3  2.61
#>         2  6.7  0.21
#>         3 12.7  4.44
#>         4 21.1  7.29
#>         5 18.8 10.94
#>
#> , , variable = eta[1]
#>
#>          chain
#> iteration     1     2
#>         1  0.10 -0.61
#>         2  0.89 -0.87
#>         3  1.62  0.83
#>         4  1.99  0.84
#>         5 -0.16  1.22
#>
#> # ... with 995 more iterations, and 15 more variables
draws can then be used with the R package posterior:
posterior::summarise_draws(draws)
#> # A tibble: 19 x 10
#>    variable     mean   median    sd   mad     q5   q95  rhat ess_bulk ess_tail
#>    <chr>       <dbl>    <dbl> <dbl> <dbl>  <dbl> <dbl> <dbl>    <dbl>    <dbl>
#>  1 lp__        -39.5    -39.2  2.59  2.58  -44.2 -35.9  1.00     640.    1008.
#>  2 mu           7.77     7.92  5.48  5.10  -1.43  16.0  1.01     530.     325.
#>  3 tau          6.82     5.32  5.75  4.71  0.434  18.7  1.00     649.     658.
#>  4 eta[1]      0.383    0.413 0.929 0.909  -1.20  1.87  1.00    1650.    1233.
#>  5 eta[2]   -0.00335 -0.00816 0.841 0.814  -1.34  1.40  1.00    1443.    1307.
#>  6 eta[3]     -0.176   -0.174 0.931 0.906  -1.67  1.42  1.00    1829.    1424.
#>  7 eta[4]   -0.00521 0.000856 0.862 0.841  -1.47  1.39  1.00    1565.    1407.
#>  8 eta[5]     -0.312   -0.350 0.873 0.835  -1.72  1.24  1.00    1661.    1616.
#>  9 eta[6]     -0.193   -0.190 0.889 0.909  -1.59  1.28  1.00    1915.    1404.
#> 10 eta[7]      0.387    0.358 0.876 0.864  -1.09  1.81  1.00    1574.    1370.
#> 11 eta[8]     0.0805   0.0611 0.970 0.960  -1.51  1.66  1.00    1031.    1236.
#> 12 theta[1]     11.5     10.2  8.29  6.99  0.268  26.4  1.00    1042.     728.
#> 13 theta[2]     7.87     7.87  6.20  5.66  -2.27  17.8  1.00    1549.    1515.
#> 14 theta[3]     6.01     6.63  8.25  6.63  -8.69  18.1  1.00    1102.    1075.
#> 15 theta[4]     7.75     7.76  6.65  5.96  -3.06  18.9  1.00    1674.    1210.
#> 16 theta[5]     5.05     5.70  6.44  5.75  -7.06  14.4  1.00    1405.    1416.
#> 17 theta[6]     6.21     6.60  6.92  6.15  -5.98  16.9  1.00    1890.    1195.
#> 18 theta[7]     10.8     10.1  6.71  6.03  0.992  23.1  1.00    1497.    1767.
#> 19 theta[8]     8.35     8.41  7.72  6.66  -3.88  20.7  1.00    1081.    1075.
The popular R package bayesplot can also be used:
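# geom_vline() comes from ggplot2, which is not attached above, so it is
# called with an explicit namespace prefix.
bayesplot::mcmc_trace(draws) +
  ggplot2::geom_vline(xintercept = seq(0, 1000, 250),
                      alpha = 0.25,
                      size = 2)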
The vertical lines are placed at each checkpoint.