#' Run a simulation-based calibration loop.
#'
#' Calls `generator()` `N_steps` times, fits every simulated dataset via
#' `sampling_multi()`, and collects per-parameter evaluation results and
#' per-fit sampler diagnostics.
#'
#' @param model Model object handed unchanged to `sampling_multi()`.
#' @param generator Zero-argument function returning a list with elements
#'   `true` (named list of true parameter values) and `observed` (the data
#'   passed to the sampler).
#' @param N_steps Number of simulated datasets to generate and fit.
#' @param ... Further arguments forwarded to `sampling_multi()`.
#' @return A list with elements `params` (output of `evaluate_all_params()`
#'   for every fit, with a `run` column), `diagnostics` (one row per fit),
#'   `data` (the simulated observed datasets) and `true_values` (the
#'   simulated true parameter values).
sbc <- function(model, generator, N_steps, ...) {
  # seq_len() keeps the N_steps == 0 case empty, where 1:N_steps would
  # iterate over c(1, 0).
  simulations <- lapply(seq_len(N_steps), function(i) generator())
  true_list <- lapply(simulations, function(sim) sim$true)
  observed_list <- lapply(simulations, function(sim) sim$observed)

  fits <- sampling_multi(model, observed_list, ...)

  param_stats <-
    fits %>% imap_dfr(function(fit, data_id) {
      true_values <- true_list[[data_id]]
      # Only extract the parameters for which a true value is known.
      samples <- rstan::extract(fit, pars = names(true_values))
      evaluate_all_params(samples, true_values) %>%
        mutate(run = data_id)
    })

  diagnostics <-
    fits %>% imap_dfr(function(fit, run_id) {
      # Compute the fit summary once and reuse it for both statistics.
      fit_summary <- summary(fit)$summary
      data.frame(
        run = run_id,
        n_divergent = rstan::get_num_divergent(fit),
        n_treedepth = rstan::get_num_max_treedepth(fit),
        n_chains_low_bfmi = length(rstan::get_low_bfmi_chains(fit)),
        total_time = sum(rstan::get_elapsed_time(fit)),
        min_n_eff = min(fit_summary[, "n_eff"], na.rm = TRUE),
        max_Rhat = max(fit_summary[, "Rhat"], na.rm = TRUE)
      )
    })

  list(
    params = param_stats,
    diagnostics = diagnostics,
    data = observed_list,
    true_values = true_list
  )
}
#' Print the two SBC overview plots for evaluated parameters.
#'
#' The first plot is a histogram of `order_within` per parameter, overlaid
#' with a band derived from binomial quantiles of the expected bin count.
#' The second plot shows the posterior `median` against `true_value`.
#'
#' @param params Data frame with columns `run`, `order_within`,
#'   `param_name`, `true_value` and `median`.
#' @param binwidth Width of the histogram bins; has to divide 100.
#' @param caption Optional text appended to the title of the second plot.
#' @return Called for its side effect of printing two plots.
plot_sbc_params <- function(params, binwidth = 10, caption = NULL) {
  if (100 %% binwidth != 0) {
    stop("binwidth has to divide 100")
  }

  n_simulations <- length(unique(params$run))

  # CI - taken from https://github.com/seantalts/simulation-based-calibration/blob/master/Rsbc/generate_plots_sbc_inla.R
  # 0.5%, 50% and 99.5% binomial quantiles of the count falling in one bin.
  ci <- qbinom(
    c(0.005, 0.5, 0.995),
    size = n_simulations,
    prob = binwidth / 100
  )
  ci_lower <- ci[1]
  ci_mid <- ci[2]
  ci_upper <- ci[3]

  band <- data.frame(
    x = c(-10, 0, -10, 110, 100, 110, -10),
    y = c(ci_lower, ci_mid, ci_upper, ci_upper, ci_mid, ci_lower, ci_lower)
  )

  # The visualisation style taken as well from https://github.com/seantalts/simulation-based-calibration/blob/master/Rsbc/generate_plots_sbc_inla.R
  # The reference line is added with annotate() so that its position is
  # taken from plain values: inside aes() a column of `params` with the same
  # name would take precedence over a local variable.
  rank_plot <- ggplot(params, aes(x = order_within)) +
    annotate(
      "segment",
      x = 0, xend = 100, y = ci_mid, yend = ci_mid,
      colour = "grey25"
    ) +
    geom_polygon(
      data = band,
      aes(x = x, y = y),
      fill = "grey45", color = "grey25", alpha = 0.5
    ) +
    geom_histogram(
      breaks = seq(1, 101, by = binwidth),
      closed = "left",
      fill = "#A25050", colour = "black"
    ) +
    facet_wrap(~param_name, scales = "free_y") +
    ggtitle("Posterior order within 100 samples")
  print(rank_plot)

  # Make points more transparent as the number of simulations grows.
  point_alpha <- 1 / ((n_simulations * 0.03) + 1)

  # Only append the separator when a caption was actually supplied.
  scatter_title <- "Median of marginal posteriors vs. true value"
  if (!is.null(caption)) {
    scatter_title <- paste0(scatter_title, " - ", caption)
  }

  scatter_plot <- ggplot(params, aes(x = true_value, y = median)) +
    geom_point(alpha = point_alpha) +
    geom_abline(slope = 1, intercept = 0, color = "blue") +
    facet_wrap(~param_name, scales = "free") +
    ggtitle(scatter_title)
  print(scatter_plot)
}
#' Summarise sampler diagnostics over all SBC runs.
#'
#' @param sbc_results List with a `diagnostics` element holding the columns
#'   `n_divergent`, `n_treedepth`, `n_chains_low_bfmi`, `total_time`,
#'   `min_n_eff` and `max_Rhat`.
#' @return The result of `summarise()`: the fraction of runs with at least
#'   one divergence, at least one treedepth hit, at least one low-BFMI chain,
#'   a minimal n_eff below 100 or a maximal Rhat above 1.1, together with the
#'   median total time.
summarise_sbc_diagnostics <- function(sbc_results) {
  run_diagnostics <- sbc_results$diagnostics
  summarise(
    run_diagnostics,
    # The mean of a logical vector is the share of runs where it is TRUE.
    has_divergence = mean(n_divergent > 0),
    has_treedepth = mean(n_treedepth > 0),
    has_low_bfmi = mean(n_chains_low_bfmi > 0),
    median_total_time = median(total_time),
    low_neff = mean(min_n_eff < 100),
    high_Rhat = mean(max_Rhat > 1.1)
  )
}
#' Flag rows where the posterior interval determines the effect.
#'
#' Adds two logical columns: `sign_determined` is TRUE when both `q_lower`
#' and `q_upper` have the same sign as `true_value`;
#' `effect_above_threshold` additionally requires the whole interval to lie
#' above `effect_threshold` or below `-effect_threshold`.
#'
#' @param sbc_params Data frame with columns `q_lower`, `q_upper` and
#'   `true_value`.
#' @param effect_threshold Magnitude the interval has to exceed.
#' @return `sbc_params` with the two columns added.
sbc_power <- function(sbc_params, effect_threshold = 0.1) {
  mutate(
    sbc_params,
    sign_determined =
      sign(q_lower) == sign(true_value) & sign(q_upper) == sign(true_value),
    effect_above_threshold =
      sign_determined &
        (q_upper < -effect_threshold | q_lower > effect_threshold)
  )
}
#' Plot the power indicators against the absolute true value.
#'
#' Reshapes `sign_determined` and `effect_above_threshold` into long format
#' and draws, one facet per indicator, jittered 0/1 points with a GAM smooth
#' on a log10 x axis.
#'
#' @param sbc_power_res Data frame with columns `true_value`,
#'   `sign_determined` and `effect_above_threshold`.
#' @param group Unquoted expression used for both the colour and the group
#'   aesthetic.
#' @param size_of_interest x positions at which vertical lines are drawn.
#' @return A ggplot object.
sbc_power_plot <- function(sbc_power_res, group = 1, size_of_interest = c(0.1,0.5,1)) {
  # Capture the expression before anything forces the argument.
  group <- enquo(group)

  # Colouring points by run for small numbers of runs was sketched here in
  # commented-out code; it is not active.

  # More rows give more transparent points, never below 0.01.
  jitter_alpha <- max(0.01, 1 / (1 + 0.01 * nrow(sbc_power_res)))

  long_res <- gather(
    sbc_power_res,
    "stat", "value",
    sign_determined, effect_above_threshold
  )
  long_res <- mutate(long_res, value = as.numeric(value))

  ggplot(
    long_res,
    aes(x = abs(true_value), y = value, color = !!group, group = !!group)
  ) +
    geom_jitter(width = 0, height = 0.1, alpha = jitter_alpha) +
    geom_smooth(method = 'gam', formula = y ~ s(x, bs = "cs"), se = FALSE) +
    geom_vline(xintercept = size_of_interest) +
    scale_x_log10() +
    facet_wrap(~stat, ncol = 1)
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.