R/sbc.R

Defines functions sbc_power_plot sbc_power summarise_sbc_diagnostics plot_sbc_params sbc

# Run simulation-based calibration (SBC) for a Stan model.
#
# Calls `generator()` N_steps times, fits every simulated dataset with
# `sampling_multi()` and collects per-parameter evaluations plus per-fit
# sampler diagnostics.
#
# Arguments:
#   model     - passed unchanged to `sampling_multi()`.
#   generator - function called with no arguments; each call must return a
#               list with elements `true` (values named by Stan parameter
#               name; the names are passed to `rstan::extract()`) and
#               `observed` (the dataset handed to `sampling_multi()`).
#   N_steps   - number of simulated datasets; must be at least 1.
#   ...       - forwarded to `sampling_multi()`.
#
# Returns a list with elements
#   params      - the rows returned by `evaluate_all_params()` for every fit,
#                 row-bound and tagged with a `run` column,
#   diagnostics - one row per fit: divergent transitions, max-treedepth hits,
#                 number of low-BFMI chains, total elapsed time, minimum
#                 n_eff and maximum Rhat over the fit summary,
#   data        - the list of observed datasets,
#   true_values - the list of true parameter values.
sbc <- function(model, generator, N_steps, ...) {
  if (length(N_steps) != 1 || !is.numeric(N_steps) || is.na(N_steps) ||
      N_steps < 1) {
    stop("N_steps must be a positive integer", call. = FALSE)
  }

  # seq_len() (not 1:N_steps) so the count can never expand to c(1, 0).
  generated <- lapply(seq_len(N_steps), function(i) generator())
  true_list <- lapply(generated, function(d) d$true)
  observed_list <- lapply(generated, function(d) d$observed)

  fits <- sampling_multi(model, observed_list, ...)

  param_stats <-
    fits %>% imap_dfr(function(fit, data_id) {
      true_values <- true_list[[data_id]]
      samples <- rstan::extract(fit, pars = names(true_values))
      evaluate_all_params(samples, true_values) %>% mutate(run = data_id)
    })

  diagnostics <-
    fits %>% imap_dfr(function(fit, run_id) {
      # Compute the fit summary once and reuse it for both n_eff and Rhat.
      fit_summary <- summary(fit)$summary
      data.frame(
        run = run_id,
        n_divergent = rstan::get_num_divergent(fit),
        n_treedepth = rstan::get_num_max_treedepth(fit),
        n_chains_low_bfmi = length(rstan::get_low_bfmi_chains(fit)),
        total_time = sum(rstan::get_elapsed_time(fit)),
        min_n_eff = min(fit_summary[, "n_eff"], na.rm = TRUE),
        max_Rhat = max(fit_summary[, "Rhat"], na.rm = TRUE)
      )
    })

  list(
    params = param_stats,
    diagnostics = diagnostics,
    data = observed_list,
    true_values = true_list
  )
}

# Print two diagnostic plots for SBC parameter results.
#
#  1. Per-parameter histogram of `order_within` (bins of width `binwidth`
#     starting at 1). A grey band marks the 0.5% and 99.5% quantiles of the
#     per-bin count under Binomial(n_simulations, binwidth / 100), and a line
#     marks its median.
#  2. Per-parameter scatter of posterior `median` against `true_value`, with
#     the identity line.
#
# Arguments:
#   params   - data frame with columns `run`, `param_name`, `order_within`,
#              `true_value` and `median` (e.g. the `params` element of sbc()).
#   binwidth - histogram bin width; must divide 100. The ranks are assumed to
#              lie on a 1..100 scale (TODO confirm against
#              evaluate_all_params()).
#   caption  - optional text appended as " - <caption>" to the title of the
#              second plot.
#
# Both plots are printed; the second one is returned invisibly.
plot_sbc_params <- function(params, binwidth = 10, caption = NULL) {
  if (100 %% binwidth != 0) {
    stop("binwidth has to divide 100")
  }
  n_simulations <- length(unique(params$run))

  # CI - taken from https://github.com/seantalts/simulation-based-calibration/blob/master/Rsbc/generate_plots_sbc_inla.R
  ci <- qbinom(c(0.005, 0.5, 0.995), size = n_simulations,
               prob = binwidth / 100)
  ci_lower <- ci[1]
  ci_mid <- ci[2]
  ci_upper <- ci[3]

  # The visualisation style is also taken from the repository above.
  rank_plot <- params %>%
    ggplot(aes(x = order_within)) +
    # annotate() draws the median line once per panel. geom_segment() would
    # draw it once per row of `params`, and a `params` column could mask it.
    annotate("segment", x = 0, xend = 100, y = ci_mid, yend = ci_mid,
             colour = "grey25") +
    geom_polygon(
      data = data.frame(
        x = c(-10, 0, -10, 110, 100, 110, -10),
        y = c(ci_lower, ci_mid, ci_upper, ci_upper, ci_mid, ci_lower, ci_lower)
      ),
      aes(x = x, y = y), inherit.aes = FALSE,
      fill = "grey45", color = "grey25", alpha = 0.5
    ) +
    geom_histogram(breaks = seq(1, 101, by = binwidth), closed = "left",
                   fill = "#A25050", colour = "black") +
    facet_wrap(~param_name, scales = "free_y") +
    ggtitle("Posterior order within 100 samples")
  print(rank_plot)

  # Fade points as the number of simulations grows.
  point_alpha <- 1 / ((n_simulations * 0.03) + 1)
  # Only append " - caption" when a caption was given (avoids a dangling dash).
  title_suffix <- if (is.null(caption)) "" else paste0(" - ", caption)
  median_plot <- params %>%
    ggplot(aes(x = true_value, y = median)) +
    geom_point(alpha = point_alpha) +
    geom_abline(slope = 1, intercept = 0, color = "blue") +
    facet_wrap(~param_name, scales = "free") +
    ggtitle(paste0("Median of marginal posteriors vs. true value",
                   title_suffix))
  print(median_plot)

  invisible(median_plot)
}

# Summarise the per-fit sampler diagnostics produced by sbc().
#
# Each row of `sbc_results$diagnostics` is one fit. Most outputs are the
# fraction of fits with a given problem.
#
# Arguments:
#   sbc_results     - list with a `diagnostics` data frame (as returned by
#                     sbc()).
#   n_eff_threshold - a fit counts as `low_neff` when its minimum n_eff is
#                     below this value (default 100, the previous hard-coded
#                     value).
#   rhat_threshold  - a fit counts as `high_Rhat` when its maximum Rhat is
#                     above this value (default 1.1, the previous hard-coded
#                     value).
#
# Returns a one-row summary with the columns has_divergence, has_treedepth,
# has_low_bfmi (fractions), median_total_time, low_neff and high_Rhat
# (fractions).
summarise_sbc_diagnostics <- function(sbc_results, n_eff_threshold = 100,
                                      rhat_threshold = 1.1) {
  sbc_results$diagnostics %>%
    summarise(
      has_divergence = mean(n_divergent > 0),
      has_treedepth = mean(n_treedepth > 0),
      has_low_bfmi = mean(n_chains_low_bfmi > 0),
      median_total_time = median(total_time),
      # `!!` injects the argument values so data columns cannot shadow them.
      low_neff = mean(min_n_eff < !!n_eff_threshold),
      high_Rhat = mean(max_Rhat > !!rhat_threshold)
    )
}

# Add two logical power indicators to every row of `sbc_params`.
#
# `q_lower` and `q_upper` are presumably posterior interval bounds (not
# visible here; confirm against evaluate_all_params()).
#
#   sign_determined        - TRUE when both bounds have the same sign as
#                            `true_value` (for a true value of exactly 0, both
#                            bounds must be 0).
#   effect_above_threshold - TRUE when the sign is determined and the whole
#                            interval lies above `effect_threshold` or below
#                            `-effect_threshold`.
#
# Returns `sbc_params` with these two columns appended.
sbc_power <- function(sbc_params, effect_threshold = 0.1) {
  mutate(
    sbc_params,
    sign_determined =
      sign(q_lower) == sign(true_value) & sign(q_upper) == sign(true_value),
    effect_above_threshold =
      sign_determined &
      (q_upper < -effect_threshold | q_lower > effect_threshold)
  )
}

# Plot detection rates (sign determined / effect above threshold) against
# |true_value| on a log10 x axis, one facet per indicator.
#
# Arguments:
#   sbc_power_res    - output of sbc_power(); needs `true_value`,
#                      `sign_determined` and `effect_above_threshold`.
#   group            - optional unquoted expression used for both colour and
#                      smoothing groups; the default `1` puts all rows in one
#                      group.
#   size_of_interest - x positions of vertical reference lines.
#
# Returns the ggplot object without printing it.
sbc_power_plot <- function(sbc_power_res, group = 1,
                           size_of_interest = c(0.1, 0.5, 1)) {
  group <- enquo(group)
  # Fade points as the number of rows grows, but never below alpha = 0.01.
  point_alpha <- max(0.01, 1 / ((nrow(sbc_power_res) * 0.01) + 1))
  sbc_power_res %>%
    tidyr::pivot_longer(
      c(sign_determined, effect_above_threshold),
      names_to = "stat",
      values_to = "value"
    ) %>%
    mutate(value = as.numeric(value)) %>%
    ggplot(aes(x = abs(true_value), y = value,
               color = !!group, group = !!group)) +
    # Jitter only vertically so the 0/1 outcomes stay separated.
    geom_jitter(width = 0, height = 0.1, alpha = point_alpha) +
    geom_smooth(method = "gam", formula = y ~ s(x, bs = "cs"), se = FALSE) +
    geom_vline(xintercept = size_of_interest) +
    scale_x_log10() +
    facet_wrap(~stat, ncol = 1)
}
martinmodrak/revize-rs documentation built on March 9, 2021, 5:30 a.m.