# tests/simulations/coverage.R

# Config ------------------------------------------------------------------
# Number of simulated experiments drawn for each design in the grid below.
N_SIMS <- 10000


# Seed --------------------------------------------------------------------
# Fix the RNG state so the whole simulation is reproducible.
set.seed(seed = 1)


# Timing ------------------------------------------------------------------
# Start a wall-clock timer; the tictoc::toc() at the end of the script
# reports the elapsed time.
tictoc::tic()


# Libraries ---------------------------------------------------------------
library(here)
library(tidyverse)
library(ggthemes)
devtools::load_all()


# Helpers -----------------------------------------------------------------
# Compute one experiment's confidence intervals at several levels.
#
# experiment:  a single simulated experiment, forwarded to `ci`.
# conf_levels: numeric vector of confidence levels to evaluate.
# ci:          interval function, called as ci(level, experiment = experiment)
#              once per element of `conf_levels`.
#
# Returns a tibble with one row per confidence level: `conf_level` and a
# list-column `ci` holding whatever `ci` returned at that level.
map_ci <- function(experiment, conf_levels, ci) {
  intervals <- map(conf_levels, ci, experiment = experiment)
  tibble(conf_level = conf_levels, ci = intervals)
}

# Check whether an interval covers the true mean of the selected arm.
#
# tmt_effect:   true mean of arm 1; every other arm has true mean 0
#               (arms are simulated as c(tmt_effect, 0, ..., 0)).
# best_obs_tmt: index of the arm chosen as the observed winner; must be a
#               single non-missing value.
# ci:           interval vector; element 1 is used as the lower bound and
#               element 3 as the upper bound (element 2 is not used here).
#
# Returns TRUE when the true value lies in [ci[1], ci[3]] (both ends
# inclusive), FALSE otherwise, and NA when the result depends on a missing
# bound.
is_covered <- function(tmt_effect, best_obs_tmt, ci) {
  # Fail early with a clear message instead of a downstream pmap_lgl() error.
  stopifnot(length(best_obs_tmt) == 1, !is.na(best_obs_tmt))

  # Scalar choice, so a plain if/else (ifelse() is meant for vectors).
  true_value <- if (best_obs_tmt == 1) tmt_effect else 0

  # Inclusive on both ends, same as dplyr::between().
  true_value >= ci[1] && true_value <= ci[3]
}


# Setup -------------------------------------------------------------------
# Design grid: every combination of number of arms and arm-1 effect size,
# with a common per-arm standard error.
experiment_config <- crossing(
  n_tmt      = c(2, 10, 50),
  tmt_effect = seq(-3, 3, by = 1),
  tmt_se     = 1
)

# Simulation inputs, row-aligned with experiment_config: arm 1 has mean
# tmt_effect, the remaining n_tmt - 1 arms have mean 0, and every arm uses
# the same standard error.
sim_config <- experiment_config %>%
  transmute(
    tmt_mean = map2(tmt_effect, n_tmt, function(effect, n) {
      c(effect, rep(0, n - 1))
    }),
    tmt_se   = map2(tmt_se, n_tmt, rep)
  )

# Confidence levels at which every interval type is evaluated.
conf_levels <- seq(0.05, 0.95, by = 0.05)


# Simulate Experiment -----------------------------------------------------
message("Simulating experiments...")
experiment_tb <- experiment_config %>%
  mutate(experiments = pmap(sim_config, sim_experiments, n_sims = N_SIMS))


# Intervals ---------------------------------------------------------------
# For every simulated experiment, record which arm had the largest
# `ate_hat` and compute standard, conditional, and hybrid intervals at each
# level in `conf_levels`. The nested result keeps one row per design, with
# per-experiment values held in list-columns.
message("Calculating intervals...")
ci_tb_nested <- experiment_tb %>%
  transmute(n_tmt,
            tmt_effect,
            tmt_se,
            # Position of each experiment within its design's experiment list.
            experiment_id = map(experiments, seq_along),
            # Index of the arm with the largest ate_hat (presumably the
            # per-arm effect estimates) -- the "observed winner".
            best_obs_tmt  = map_depth(experiments, 2, ~which.max(.x$ate_hat)),
            # Each entry below is a map_ci() tibble of (conf_level, ci).
            standard    = map_depth(experiments, 2, map_ci,
                                    ci          = ci_standard,
                                    conf_levels = conf_levels),
            conditional = map_depth(experiments, 2, map_ci,
                                    ci          = ci_conditional,
                                    conf_levels = conf_levels),
            hybrid      =  map_depth(experiments, 2, map_ci,
                                     ci          = ci_hybrid,
                                     conf_levels = conf_levels))

# Flatten to one row per (design, experiment, interval type, conf level).
ci_tb <- ci_tb_nested %>%
  # One row per experiment; the interval columns are still list-columns of
  # map_ci() tibbles at this point.
  unnest(c(experiment_id, best_obs_tmt,
           standard, conditional, hybrid)) %>%
  # Stack the three interval methods into a `ci_type` label and a `ci`
  # list-column.
  pivot_longer(c(standard, conditional, hybrid),
               names_to = "ci_type",
               values_to = "ci") %>%
  # Expand each map_ci() tibble into its conf_level rows. best_obs_tmt
  # appears to still be a list-column after the first unnest, so it is
  # unnested here as well.
  unnest(c(ci, best_obs_tmt))


# Coverage ----------------------------------------------------------------
message("Calculating coverage...")

# Flag, for every interval, whether it contains the selected arm's true mean.
coverage_tb <- ci_tb %>%
  mutate(covered = pmap_lgl(list(tmt_effect   = tmt_effect,
                                 best_obs_tmt = best_obs_tmt,
                                 ci           = ci),
                            is_covered))

# Coverage count and rate per design x interval type x confidence level,
# plus lower/upper bounds on each rate from ci_proportion().
summary_tb <- coverage_tb %>%
  group_by(n_tmt,
           tmt_effect = factor(tmt_effect),
           tmt_se,
           ci_type,
           conf_level) %>%
  summarise(n = n(),
            y = sum(covered)) %>%
  mutate(p    = y / n,
         p_lo = map2_dbl(y, n, ci_proportion, "lo"),
         p_hi = map2_dbl(y, n, ci_proportion, "hi")) %>%
  ungroup()


# Plot --------------------------------------------------------------------
# Empirical coverage against nominal confidence level, one panel per
# (number of arms, interval type) and one line per treatment effect.
# Vertical bars draw the p_lo/p_hi bounds from ci_proportion(); the dashed
# diagonal marks nominal coverage (a calibrated interval lies on y = x).
message("Plotting...")
coverage_plot <- ggplot(summary_tb,
                        aes(x     = conf_level,
                            y     = p,
                            ymin  = p_lo,
                            ymax  = p_hi,
                            color = tmt_effect,
                            group = tmt_effect)) +
  # Reference line; geom_abline() does not inherit the mapped aesthetics.
  geom_abline(slope = 1, intercept = 0,
              linetype = "dashed", color = "grey50") +
  geom_line() +
  # Draws the ymin/ymax aesthetics mapped above.
  geom_linerange() +
  facet_grid(n_tmt ~ ci_type) +
  scale_color_viridis_d() +
  labs(x     = "Nominal confidence level",
       y     = "Empirical coverage",
       color = "Treatment effect") +
  theme_bw()


# Save --------------------------------------------------------------------
# Write the figure, the per-interval table, and the summary table under
# tests/simulations/.
message("Saving...")
ggsave(filename = here("tests", "simulations", "coverage_plot.png"),
       plot     = coverage_plot,
       width    = 8,
       height   = 5,
       units    = "in")

walk2(list(ci_tb, summary_tb),
      c("ci_tb.rds", "summary_tb.rds"),
      ~write_rds(.x, here("tests", "simulations", .y)))


# Done --------------------------------------------------------------------
message("Done.")
tictoc::toc()
# adviksh/winfer documentation built on Dec. 24, 2019, 7:05 p.m.