# test_hard_restriction.R

#!/usr/bin/Rscript

# Command-line interface, written in docopt syntax. The parsed result is a
# list keyed by command / argument / option name (without the leading dashes
# or angle brackets).
#
# --density-plots and --different-priors are read further down in this script,
# so they are declared here; an undeclared flag is absent from the parsed
# options and `if (NULL)` aborts the run.
opt_desc <- "Usage:
  test_hard_restriction single [--num-entities=<num-entities> --true-hyper-sd=<sd>]
  test_hard_restriction multi <cores> <runs> [--num-entities=<num-entities> --true-hyper-sd=<sd> --output=<output-name>] [--append --alt-hyper-sd=<sd>] [--density-plots --different-priors]
  test_hard_restriction prior-sequence <cores> <from-tau> <to-tau> <by-tau> [--num-entities=<num-entities>] [--different-priors --alt-hyper-sd=<sd>]
  test_hard_restriction constrained-prior-sequence <cores> <unconst-tau> <from-tau> <to-tau> <by-tau> [--num-entities=<num-entities>]

Options:
  --output=<output-name>  Output name to use in file names [default: test_run]
  --num-entities=<num-entities>  Number of entities in model [default: 3]
  --true-hyper-sd=<sd>  True SD hyperparameter for prior [default: 2]
  --alt-hyper-sd=<sd>  Alternative SD hyperparameter to use for fit
  --append  Append the new runs to an existing output file
  --density-plots  Draw and save the density and coverage plots after a multi run
  --different-priors  Treat the fit prior as different from the true prior (uses --alt-hyper-sd)
"

# Parse the command line. An interactive session has no command line of its
# own, so a canned argument string stands in for it there.
script_options <- docopt::docopt(
  opt_desc,
  args = if (interactive()) {
    "single --num-entities=1"
    # "multi 12 12 --num-entities=1"
  } else {
    commandArgs(trailingOnly = TRUE)
  }
)

# Setup -------------------------------------------------------------------

library(magrittr)
library(tidyverse)
library(rlang)
library(rstan)
library(bayesplot)

library(econometr)
library(boundr)

# Option values come back from docopt as strings (or NULL when not supplied),
# so coerce every numeric argument once, up front. The *-tau arguments are
# included because they feed seq() and are passed on as hyper SDs; NULL values
# are left as NULL so that `%||%` fallbacks keep working.
script_options %<>%
  modify_at(c("cores", "runs", "num-entities"), as.integer) %>%
  modify_at(
    c("true-hyper-sd", "alt-hyper-sd", "from-tau", "to-tau", "by-tau", "unconst-tau"),
    ~ if (!is_null(.x)) as.numeric(.x)
  )

# Use every detected core. parallel::detectCores() returns NA when the count
# cannot be determined; na.rm keeps the single-core fallback in that case
# instead of setting the option to NA.
options(mc.cores = max(1, parallel::detectCores(), na.rm = TRUE))
# Have rstan write compiled models to disk (auto_write).
rstan_options(auto_write = TRUE)

# SD hyperparameters handed to the simulation and the samplers as
# `discretized_beta_hyper_sd`. Only a "default" entry is active; the
# commented entries are alternative per-type settings kept for experiments.
# NOTE(review): entries other than "default" look like they are keyed by
# response type name -- confirm against boundr.
true_discretized_beta_hyper_sd <- list(
  default = 1
  # "never below" = script_options$`true-hyper-sd`,
  # "always below" = script_options$`true-hyper-sd`,
  # "migration complier" = ~ mutate(., sd = if_else(fct_match(r_m, c("always")), 0, 1))
)

# Parallel stand-in for purrr::map(): applies `.f` (anything as_mapper()
# accepts) to each element of `.x` through pbmcapply::pbmclapply().
#
# .x     List or vector to iterate over.
# .f     Function or purrr-style mapper applied to each element.
# ...    Further arguments forwarded to pbmclapply (and on to `.f`).
# cores  Number of worker processes; defaults to the <cores> CLI argument.
#
# Returns the list produced by pbmclapply.
test_parallel_map <- function(.x, .f, ..., cores = script_options$cores) {
  # Callers pass `cores %/% 4`, which is 0 when fewer than four cores were
  # requested, and a worker count below one is rejected by mclapply. Never go
  # below a single worker (this also covers a NULL or NA `cores`).
  num_workers <- max(1L, cores, na.rm = TRUE)

  pbmcapply::pbmclapply(
    .x, as_mapper(.f), ...,
    ignore.interactive = TRUE, mc.silent = TRUE, mc.cores = num_workers
  )
}

# Models ------------------------------------------------------------------

# Allowed (hi, low) pairings of the eight outcome response types, one row per
# allowed pair. A pairing is allowed when `hi` is "always below", when `low`
# is "never below", or when both sides are the same type. Rows are ordered by
# `hi` and then by `low`, both following the type order listed below.
pruning_data <- local({
  outcome_types <- c(
    "always below",
    "program complier",
    "wedge complier",
    "migration complier",
    "program defier",
    "wedge defier",
    "migration defier",
    "never below"
  )

  expand_grid(hi = outcome_types, low = outcome_types) %>%
    filter(hi == "always below" | low == "never below" | hi == low)
})

# Structural causal model shared by every run mode of this script. It chains
# three variables, z -> m -> y; each variable lists its named response types,
# and each type is given as a one-sided formula over the variable's inputs.
test_model <- define_structural_causal_model(
  # z: no inputs; two response types fixed at 1 and 0.
  define_response(
    "z",

    "village assigned treatment" = ~ 1,
    "village assigned control" = ~ 0,
  ),

  # m: takes z as input; types are constant 0, equal to z, and constant 1.
  define_response(
    "m",
    input = c("z"),

    "never" = ~ 0,
    "treatment complier" = ~ z,
    "always" = ~ 1,
  ),

  # y: discretized outcome with cutpoints -100, -20 and 100, taking m as input.
  # NOTE(review): the simulated y values used below are 30 and -30, i.e. on
  # either side of the -20 cutpoint -- confirm that is the intended split.
  define_discretized_response_group(
    "y",
    cutpoints = c(-100, -20, 100),

    input = c("m"),

    "never below" = ~ 0,
    "migration complier" = ~ m,
    "migration defier" = ~ 1 - m,
    "always below" = ~ 1,

    # pruning_data = pruning_data
  ),

  # Probabilities attached to the two values of z (0.4 for z = 0, 0.6 for z = 1).
  exogenous_prob = tribble(
    ~ z, ~ ex_prob,
    0,   0.4,
    1,   0.6
  ),
)

# Alias used by the run modes below, so the model can be swapped in one place.
default_model <- test_model

# Estimands ---------------------------------------------------------------

# Estimands evaluated for every fit: four differences between pairs of atom
# estimands, built against test_model.
test_estimands <- build_estimand_collection(
  model = test_model,
  utility = c(0, 1),

  # m under z = 1 minus m under z = 0.
  build_diff_estimand(
    build_atom_estimand("m", z = 1),
    build_atom_estimand("m", z = 0)
  ),

  # y under (z = 0, m = 1) minus y under (z = 0, m = 0).
  build_discretized_diff_estimand(
    build_discretized_atom_estimand("y", z = 0, m = 1),
    build_discretized_atom_estimand("y", z = 0, m = 0)
  ),

  # y under z = 1 minus y under z = 0.
  build_discretized_diff_estimand(
    build_discretized_atom_estimand("y", z = 1),
    build_discretized_atom_estimand("y", z = 0)
  ),

  # Same contrast as the second estimand, with the condition m == 1 & z == 0
  # attached to both atoms. The default_* labels defined after this block
  # appear to name these atoms and their difference.
  build_discretized_diff_estimand(
    build_discretized_atom_estimand("y", z = 0, m = 1, cond = m == 1 & z == 0),
    build_discretized_atom_estimand("y", z = 0, m = 0, cond = m == 1 & z == 0)
  ),
)

# Alias used by the run modes below.
default_estimands <- test_estimands

# Labels used to pick estimands out of result tables by `estimand_name`:
# the m = 1 ("obs") and m = 0 ("unobs") conditional estimands, plus the label
# of their difference, which is the two labels joined by " - ".
default_obs_cf <- "Pr[Y^{y}_{z=0,m=1} < c | M = 1, Z = 0]"
default_unobs_cf <- "Pr[Y^{y}_{z=0,m=0} < c | M = 1, Z = 0]"
default_cf_diff <- str_c(default_obs_cf, default_unobs_cf, sep = " - ")

# Single Run --------------------------------------------------------------

# Single run: simulate one data set from the prior, fit it once in
# prior-predictive mode and once against the data, then compare both fits
# with the known estimand values of the simulation.
if (script_options$single) {
  # Simulate, then reshape into a list of simulations named by entity_index.
  entity_data <- create_prior_predicted_simulation(default_model, sample_size = 4000, chains = 4, iter = 1000,
                                                   discrete_beta_hyper_sd = script_options$`true-hyper-sd`,
                                                   discretized_beta_hyper_sd = true_discretized_beta_hyper_sd,
                                                   tau_level_sigma = 1,
                                                   num_entities = script_options$`num-entities`) %>%
    unnest(entity_data) %>%
    select(entity_index, sim) %>%
    deframe()

  # Known estimand values per entity, averaged over entities.
  known_results <- entity_data %>%
    map_dfr(boundr:::get_known_estimands, default_estimands, .id = "entity_index") %>%
    group_by_at(vars(-entity_index, -prob)) %>%
    summarize(prob = mean(prob)) %>%
    ungroup() %>%
    select(estimand_name, cutpoint, prob)

  # Analysis data for all entities; y is recoded to 30 / -30 from y_1.
  test_sim_data <- entity_data %>%
    map_dfr(create_simulation_analysis_data, .id = "entity_index") %>%
    mutate(y = if_else(y_1 == 0, 30, -30))
    # mutate(y = if_else(y_1 == 0, if_else(y_2 == 0, 0, 30), -30))
    # mutate(y = if_else(y_2 == 0, 30, if_else(y_1 == 0, sample(c(-1, 1), n(), replace = TRUE), -30)))
    # mutate(y = if_else(y_2 == 0, 30, if_else(y_1 == 0, runif(n(), -19, 19), -30)))

  # test_model %>%
  #   get_linear_programming_bounds(test_sim_data, "y_1", b = 1, g = 1, z = 1, m = 0)

  test_sampler <- create_sampler(
    default_model,
    model_levels = "entity_index",
    analysis_data = test_sim_data,
    estimands = default_estimands,
    # y = y < -20,
    y = y,

    discrete_beta_hyper_sd = script_options$`true-hyper-sd`,
    discretized_beta_hyper_sd = true_discretized_beta_hyper_sd,

    tau_level_sigma = 1,
    calculate_marginal_prob = TRUE
  )

  # Prior-predictive fit from the same sampler.
  test_prior_fit <- test_sampler %>%
    sampling(
      chains = 4,
      iter = 1000,
      # control = lst(adapt_delta = 0.99, max_treedepth = 12),
      pars = c("iter_estimand", "discretized_beta"),
      run_type = "prior-predict",
      save_background_joint_prob = TRUE
    )

  # %T>% prints the table and still assigns the unprinted result.
  test_prior_results <- test_prior_fit %>%
    get_estimation_results(no_sim_diag = FALSE, quants = seq(0, 1, 0.1)) %T>%
    print(n = 1000)

  test_prior_marginal_prob <- test_prior_fit %>% get_marginal_latent_type_prob()

  # Posterior fit.
  test_fit <- test_sampler %>%
    sampling(
      chains = 4,
      iter = 1000,
      # control = lst(adapt_delta = 0.99, max_treedepth = 12),
      # pars = c("iter_estimand", "single_discrete_marginal_p_r"),
      pars = c("iter_estimand", "single_discrete_marginal_p_r", "discretized_beta"),
    )

  test_results <- test_fit %>%
    get_estimation_results(no_sim_diag = FALSE, quants = seq(0, 1, 0.1)) %T>%
    print(n = 1000)

  # Coverage table: centre the posterior quantiles on the known value and flag
  # estimands whose 10%-90% interval does not contain it.
  known_results %>%
    inner_join(test_results, by = c("cutpoint", "estimand_name")) %>%
    mutate_at(vars(starts_with("per_")), ~ . - prob) %>%
    mutate(coverage = if_else(per_0.1 > 0 | per_0.9 < 0, "outside", "inside") %>% factor()) %>%
    select(estimand_name, coverage) %>%
    print(n = 1000)

  # Prior vs posterior quantiles of the unobserved counterfactual.
  # names_transform converts the "0", "0.1", ... column suffixes to numbers;
  # names_ptypes only asserts a type in tidyr >= 1.1.0 and rejects the strings.
  # NOTE(review): this plot and the next one are built but not printed when
  # the whole block is run at once (only the last expression is shown).
  bind_rows(prior = test_prior_results, posterior = test_results, .id = "fit_type") %>%
    filter(estimand_name == default_unobs_cf) %>%
    select(fit_type, starts_with("per_")) %>%
    pivot_longer(names_to = "quant", cols = -fit_type, names_prefix = "per_", names_transform = list(quant = as.numeric)) %>%
    ggplot() +
    geom_col(aes(quant, value, group = fit_type, fill = fit_type), position = position_dodge()) +
    geom_vline(xintercept = known_results %>% filter(estimand_name == default_unobs_cf) %>% pull(prob)) +
    scale_fill_discrete("") +
    scale_x_continuous(breaks = seq(0, 1, 0.2)) +
    labs(x = "", y = "") +
    theme_minimal() +
    theme(legend.position = "top")

  # Prior densities of the marginal type probabilities, discretized variables.
  test_prior_marginal_prob %>%
    filter(discretized) %>%
    select(type, r_m, iter_data) %>%
    unnest(iter_data) %>%
    ggplot() +
    geom_density(aes(iter_p_r)) +
    labs(x = "", y = "") +
    facet_grid(rows = vars(type), cols = vars(r_m), scales = "free") +
    theme_minimal() +
    NULL

  # Same for the non-discretized variables.
  test_prior_marginal_prob %>%
    filter(!discretized) %>%
    select(type, iter_data) %>%
    unnest(iter_data) %>%
    ggplot() +
    geom_density(aes(iter_p_r)) +
    labs(x = "", y = "") +
    facet_wrap(vars(type), scales = "free") +
    theme_minimal() +
    NULL
}

# Multiple Runs -----------------------------------------------------------

# Multiple runs: simulate <runs> data sets from the true prior, fit each one
# in parallel, record coverage of the known values, save everything to
# temp-data/<output>.rds and (with --density-plots) draw diagnostic plots.
if (script_options$multi) {
  num_runs <- script_options$runs

  # One list element per simulated run.
  test_sim_data <- create_prior_predicted_simulation(default_model, sample_size = 4000, chains = 4, iter = 1000,
                                                     discrete_beta_hyper_sd = script_options$`true-hyper-sd`,
                                                     discretized_beta_hyper_sd = true_discretized_beta_hyper_sd,
                                                     tau_level_sigma = 1,
                                                     num_entities = script_options$`num-entities`,
                                                     num_sim = num_runs) %>%
    deframe()

  # Fit with the alternative hyper SD when one was given, else the true one.
  used_discretized_beta_hyper_sd <- script_options$`alt-hyper-sd` %||% true_discretized_beta_hyper_sd

  test_run_data <- test_sim_data %>%
    # Each fit runs 4 chains, so a quarter of the cores is used for workers;
    # max() keeps at least one worker when fewer than 4 cores were requested.
    test_parallel_map(cores = max(1L, script_options$cores %/% 4L),
    # map(
      function(entity_data, discrete_beta_hyper_sd, discretized_beta_hyper_sd, save_iter_data) {
        entity_data %<>% deframe()

        # Known estimand values, averaged over entities.
        known_results <- entity_data %>%
          map_dfr(boundr:::get_known_estimands, default_estimands, .id = "entity_index") %>%
          group_by_at(vars(-entity_index, -prob)) %>%
          summarize(prob = mean(prob)) %>%
          ungroup()

        # Known marginal latent type probabilities, averaged over entities.
        known_marginal_prob <- entity_data %>%
          map_dfr(boundr:::get_known_marginal_latent_type_prob, .id = "entity_index") %>%
          group_by_at(vars(-entity_index, -marginal_prob)) %>%
          summarize(marginal_prob = mean(marginal_prob)) %>%
          ungroup()

        sampler <- entity_data %>%
          map_dfr(create_simulation_analysis_data, .id = "entity_index") %>%
          # mutate(y = if_else(y_2 == 0, 30, if_else(y_1 == 0, 0, -30))) %>%
          # mutate(y = if_else(y_2 == 0, 30, if_else(y_1 == 0, runif(n(), -19, 19), -30))) %>%
          mutate(y = if_else(y_1 == 0, 30, -30)) %>%
          create_sampler(
            default_model,
            model_levels = "entity_index",
            analysis_data = .,
            estimands = default_estimands,
            y = y,

            discrete_beta_hyper_sd = discrete_beta_hyper_sd,
            discretized_beta_hyper_sd = discretized_beta_hyper_sd,
            tau_level_sigma = 1,
            calculate_marginal_prob = TRUE
          )

        fit <- sampler %>%
          sampling(
            pars = c("iter_estimand", "single_discrete_marginal_p_r", "discretized_marginal_p_r"),
            chains = 4, iter = 1000
          )

        # "inside" when the known value lies within the 10%-90% interval.
        results <- fit %>%
          get_estimation_results(quants = seq(0, 1, 0.1)) %>%
          select(estimand_name, cutpoint, starts_with("per_"), if (save_iter_data) "iter_data") %>%
          inner_join(
            select(known_results, estimand_name, cutpoint, prob),
            by = c("estimand_name", "cutpoint")
          ) %>%
          mutate(coverage = if_else((per_0.1 - prob) > 0 | (per_0.9 - prob) < 0, "outside", "inside") %>% factor())

        marginal_prob <- fit %>% get_marginal_latent_type_prob() %>%
          select(type_variable, type, starts_with("per_"), if (save_iter_data) "iter_data") %>%
          inner_join(known_marginal_prob, by = c("type_variable", "type")) %>%
          mutate(coverage = if_else((per_0.1 - marginal_prob) > 0 | (per_0.9 - marginal_prob) < 0, "outside", "inside") %>% factor())

        # Objective values of the linear-programming bounds, per entity.
        lp_bounds <- entity_data %>%
          map(get_linear_programming_bounds, "y_1", z = 0, m = 0, cond = m == 1 & z == 0) %>%
          map(map, pluck, "objval") %>%
          map_df(as_tibble)

        # One row per run, with list columns.
        tibble(results = list(results),
               marginal_prob = list(marginal_prob),
               lp_bounds = list(lp_bounds))
      },

      discrete_beta_hyper_sd = script_options$`alt-hyper-sd` %||% script_options$`true-hyper-sd`,
      discretized_beta_hyper_sd = used_discretized_beta_hyper_sd,
      save_iter_data = TRUE
    ) %>%
    compact() %>%
    bind_rows(.id = "iter_id")

  dummy_data <- test_sim_data[[1]] %>%
    select(entity_index, sim) %>%
    deframe() %>%
    map_dfr(create_simulation_analysis_data, .id = "entity_index") %>%
    # mutate(y = if_else(y_2 == 0, 30, if_else(y_1 == 0, runif(n(), -19, 19), -30))) %>%
    mutate(y = if_else(y_1 == 0, 30, -30)) # This data isn't really used in prior prediction

  # Prior-predictive fit under the true hyperparameters, for comparison.
  true_prior_sampler <- create_sampler(
    default_model,
    model_levels = "entity_index",
    analysis_data = dummy_data,
    estimands = default_estimands,
    y = y,

    discrete_beta_hyper_sd = script_options$`true-hyper-sd`,
    discretized_beta_hyper_sd = true_discretized_beta_hyper_sd,
    tau_level_sigma = 1,
    calculate_marginal_prob = TRUE
  )

  true_prior_fit <- true_prior_sampler %>%
    sampling(
      pars = c("iter_estimand", "single_discrete_marginal_p_r", "discretized_marginal_p_r"),
      chains = 4, iter = 1000,
      run_type = "prior-predict",
    )

  true_prior_results <- true_prior_fit %>% get_estimation_results(quants = seq(0, 1, 0.1))
  true_prior_marginal_prob <- true_prior_fit %>% get_marginal_latent_type_prob()
  prior_results <- true_prior_results
  prior_marginal_prob <- true_prior_marginal_prob

  test_run_data_file <- file.path("temp-data", str_c(script_options$output, ".rds"))

  if (script_options$append && file.exists(test_run_data_file)) {
    test_run_data %<>%
      bind_rows(read_rds(test_run_data_file))
  }

  # Create the output directory first: a missing directory would otherwise
  # throw away the finished simulation runs at write time.
  dir.create(dirname(test_run_data_file), showWarnings = FALSE, recursive = TRUE)
  write_rds(test_run_data, test_run_data_file)

  # Share of runs whose interval covers the known value, per estimand.
  test_run_data %>%
    select(iter_id, results) %>%
    unnest(results) %>%
    group_by_at(vars(estimand_name, any_of("cutpoint"))) %>%
    summarize(coverage = mean(fct_match(coverage, "inside"))) %>%
    ungroup() %>%
    arrange_at(vars(estimand_name, any_of("cutpoint"))) %>%
    print(n = 1000)

  # Stack the posterior results of every run with one copy of the true-prior
  # results per run, for the estimand name(s) in `estimand`; `fit_type` tells
  # the two sources apart.
  get_prior_and_post <- function(estimand) {
    bind_rows(
      # prior = if (script_options$`different-priors`) {
      #   prior_results %>%
      #     filter(estimand_name %in% estimand) %>%
      #     map_df(seq(num_runs), ~ mutate(..2, iter_id = ..1), .)
      # },

      posterior = test_run_data %>%
        select(iter_id, results) %>%
        unnest(results) %>%
        # filter(estimand_name == default_unobs_cf) %>%
        filter(estimand_name %in% estimand) %>%
        mutate(iter_id = as.integer(iter_id)),

      "true prior" = true_prior_results %>%
        # filter(estimand_name == default_unobs_cf) %>%
        filter(estimand_name %in% estimand) %>%
        map_df(seq(num_runs), ~ mutate(..2, iter_id = ..1), .),

      .id = "fit_type"
    )
  }

  # isTRUE(): an option missing from the parsed command line is NULL, and
  # `if (NULL)` is an error.
  if (isTRUE(script_options$`density-plots`)) {
    dir.create("temp-img", showWarnings = FALSE, recursive = TRUE)

    # Per-run prior vs posterior densities of the counterfactual difference.
    # (`tot_cf_diff` was never defined; the difference estimand label in this
    # script is `default_cf_diff`.)
    density_plots <- get_prior_and_post(default_cf_diff) %>%
      # filter(iter_id %in% sample(.$iter_id, 50, replace = FALSE)) %>%
      select(iter_id, prob, fit_type, iter_data) %>%
      mutate(iter_data = map_if(iter_data, ~ !is_null(.x), select, -iter_id)) %>%
      unnest(iter_data) %>%
      ggplot() +
      geom_density(aes(iter_estimand, group = fit_type, color = fit_type)) +
      geom_vline(xintercept = 0, linetype = "dotted") +
      geom_vline(aes(xintercept = prob), data = . %>% distinct(iter_id, prob)) +
      scale_fill_discrete("", aesthetics = c("color", "fill")) +
      labs(x = "", y = "") +
           # subtitle = latex2exp::TeX("P(Y_{b=0,g=0,z=0,m=0} < c | M_{b=0,g=0,z=0} = 1)")) +
      facet_wrap(vars(iter_id), scales = "free") +
      theme_minimal() +
      theme(legend.position = "top", strip.text = element_blank(), plot.subtitle = element_text(size = 9))

    ggsave(file.path("temp-img", str_c(script_options$output, ".png")), density_plots)

    # In a non-interactive run, open a window and block on a message box so
    # the plot stays visible. tcltk is optional, hence requireNamespace().
    if (!interactive() && requireNamespace("tcltk", quietly = TRUE)) {
      x11()
      plot(density_plots)
      capture <- tcltk::tk_messageBox(message = "Hit spacebar to close plots.")
    } else {
      plot(density_plots)
    }

    # NOTE(review): the four plots that follow are built but never printed or
    # saved when this block runs as a whole; they only show when evaluated
    # one at a time in an interactive session.

    # Posterior intervals per run, ordered by the known value: raw (top) and
    # centred on the known value (bottom).
    get_prior_and_post(default_cf_diff) %>%
      filter(fct_match(fit_type, "posterior")) %>%
      mutate(
        iter_id = seq(n()),
        iter_id = fct_reorder(factor(iter_id), prob)
      ) %>% {
        cowplot::plot_grid(
          ggplot(., aes(iter_id)) +
            geom_hline(yintercept = 0, linetype = "dotted") +
            geom_point(aes(y = prob), shape = 1, size = 2) +
            geom_pointrange(aes(y = per_0.5, ymin = per_0.1, ymax = per_0.9, color = coverage), fatten = 1) +
            scale_color_manual("", values = c("inside" = "black", "outside" = "red")) +
            labs(x = "", y = "", subtitle = "Uncentered") +
            theme_minimal() +
            theme(axis.text.x = element_blank(),
                  axis.ticks.x = element_blank(),
                  panel.grid.minor.x = element_blank(),
                  panel.grid.major.x = element_blank(),
                  legend.position = "none"),

          mutate_at(., vars(starts_with("per_")), ~ . - prob) %>%
            ggplot(aes(iter_id)) +
            geom_hline(yintercept = 0, linetype = "dotted") +
            geom_pointrange(aes(y = per_0.5, ymin = per_0.1, ymax = per_0.9, color = coverage), fatten = 1) +
            scale_color_manual("", values = c("inside" = "black", "outside" = "red")) +
            labs(x = "", y = "", subtitle = "Centered") +
            theme_minimal() +
            theme(axis.text.x = element_blank(),
                  axis.ticks.x = element_blank(),
                  panel.grid.minor.x = element_blank(),
                  panel.grid.major.x = element_blank(),
                  legend.position = "none"),

          ncol = 1
        )
      }

    # Difference estimate against the known observable probability.
    # NOTE(review): the interval drawn here runs from per_0.1 to per_0.8,
    # while the coverage colour is computed from per_0.1 / per_0.9 -- confirm
    # whether per_0.9 was intended.
    get_prior_and_post(c(default_obs_cf, default_cf_diff)) %>%
      filter(fct_match(fit_type, "posterior")) %>%
      group_by(estimand_name) %>%
      mutate(
        iter_id = seq(n()),
        iter_id = fct_reorder(factor(iter_id), prob)
      ) %>%
      ungroup() %>%
      select(iter_id, estimand_name, per_0.1, per_0.5, per_0.8, prob, coverage) %>%
      # The difference label is the only one containing "-".
      mutate(estimand_name = if_else(str_detect(estimand_name, "-"), "diff", "obs")) %>%
      pivot_wider(names_from = "estimand_name", values_from = c("per_0.1", "per_0.5", "per_0.8", "prob", "coverage")) %>% {
        cowplot::plot_grid(
          ggplot(.) +
            geom_point(aes(prob_obs, per_0.5_diff, color = coverage_diff)) +
            geom_point(aes(prob_obs, prob_diff), shape = 3) +
            scale_color_manual("", values = c("inside" = "black", "outside" = "red")) +
            # labs(x = latex2exp::TeX("P\\[Y_{b=0,g=0,z=0,m=1} < c | M_{b=0,g=0,z=0} = 1\\]"),
            labs(x = "",
                 y = "TOT") +
                 # y = latex2exp::TeX("P\\[Y_{b=0,g=0,z=0,m=1} < c | M_{b=0,g=0,z=0} = 1\\] - P\\[Y_{b=0,g=0,z=0,m=0} < c | M_{b=0,g=0,z=0} = 1\\]")) +
            theme_minimal() +
            theme(legend.position = "none",
                  axis.text.x = element_blank()) +
            NULL,

          ggplot(.) +
            geom_pointrange(aes(prob_obs, ymin = per_0.1_diff, y = per_0.5_diff, ymax = per_0.8_diff, color = coverage_diff), fatten = 1) +
            geom_line(aes(prob_obs, prob_obs), linetype = "dashed") +
            geom_line(aes(prob_obs, prob_obs - 1), linetype = "dashed") +
            geom_point(aes(prob_obs, prob_diff), shape = 3) +
            scale_color_manual("", values = c("inside" = "black", "outside" = "red")) +
            labs(x = latex2exp::TeX("P\\[Y_{b=0,g=0,z=0,m=1} < c | M_{b=0,g=0,z=0} = 1\\]"),
                 y = "TOT") +
                 # y = latex2exp::TeX("P\\[Y_{b=0,g=0,z=0,m=1} < c | M_{b=0,g=0,z=0} = 1\\] - P\\[Y_{b=0,g=0,z=0,m=0} < c | M_{b=0,g=0,z=0} = 1\\]")) +
            theme_minimal() +
            theme(legend.position = "none") +
            NULL,

          ncol = 1,
          rel_heights = c(1, 1.1)
        )
      }

    # Joint prior density of the observable and unobservable estimands.
    prior_results %>%
      filter(estimand_name %in% c(default_obs_cf, default_unobs_cf)) %>%
      select(estimand_name, iter_data) %>%
      mutate(estimand_name = if_else(str_detect(estimand_name, "m=1"), "obs", "unobs")) %>%
      unnest(iter_data) %>%
      pivot_wider(names_from = "estimand_name", values_from = "iter_estimand") %>%
      ggplot(aes(obs, unobs)) +
      # geom_density2d()
      stat_density_2d(aes(fill = after_stat(nlevel)), geom = "polygon") +
      labs(x = "Observable", y = "Unobservable") +
      scale_fill_viridis_c() +
      theme_minimal() +
      theme(legend.position = "none")

    # Posterior interval (red) against the linear-programming bounds (error
    # bars) and the known value (point), per run.
    test_run_data %>%
      select(iter_id, results, lp_bounds) %>%
      # sample_n(150) %>%
      mutate(
        results = map(results, filter, estimand_name == default_unobs_cf) %>%
          map(select, per_0.1, per_0.8, prob),
      ) %>%
      unnest(c(results, lp_bounds)) %>%
      mutate(iter_id = fct_reorder(iter_id, prob)) %>%
      ggplot(aes(iter_id)) +
      geom_linerange(aes(ymin = per_0.1, ymax = per_0.8), color = "red", size = 2, alpha = 0.5) +
      geom_errorbar(aes(ymin = min, ymax = max)) +
      geom_point(aes(y = prob), size = 1) +
      labs(
        x = "", y = ""
      ) +
      theme_minimal() +
      theme(legend.position = "top",
            axis.text.x = element_blank(),
            plot.subtitle = element_text(size = 9))

    # Prior densities of the marginal type probabilities for r_y_1.
    marginal_prob_density_plots <- bind_rows(
      prior = if (isTRUE(script_options$`different-priors`)) {
        prior_marginal_prob %>%
          map_df(seq(1), ~ mutate(.y, iter_id = .x), .)
      },

      # posterior = test_run_data %>%
      #   slice(1) %>%
      #   select(iter_id, marginal_prob) %>%
      #   unnest(marginal_prob) %>%
      #   filter(discretized) %>%
      #   mutate(iter_id = as.integer(iter_id)),

      "true prior" = true_prior_marginal_prob %>%
        map_df(seq(1), ~ mutate(.y, iter_id = .x), .),

      .id = "fit_type"
    ) %>%
      filter(
        # iter_id %in% sample(.$iter_id, 6, replace = FALSE),
        fct_match(type_variable, "r_y_1")
      ) %>%
      mutate(estimand_name = str_c(type_variable, " = ", type) %>% str_replace("^r", "R")) %>%
      # select(estimand_name, iter_id, marginal_prob, fit_type, iter_data, r_m) %>%
      select(estimand_name, iter_id, fit_type, iter_data, r_m) %>%
      mutate(iter_data = map_if(iter_data, ~ !is_null(.x), select, -iter_id)) %>%
      unnest(iter_data) %>%
      ggplot() +
      geom_density(aes(iter_p_r, group = fit_type, color = fit_type, fill = fit_type), alpha = 0.25) +
      # geom_vline(aes(xintercept = marginal_prob), data = . %>% distinct(estimand_name, iter_id, marginal_prob)) +
      scale_fill_discrete("", aesthetics = c("color", "fill")) +
      labs(
        x = "", y = ""
        # subtitle = latex2exp::TeX("P(Y_{b=0,g=0,z=0,m=0} < c | M_{b=0,g=0,z=0} = 1)")
      ) +
      # facet_grid(rows = vars(estimand_name), cols = vars(iter_id), scales = "free", switch = "y") +
      facet_grid(rows = vars(estimand_name), cols = vars(r_m), scales = "free", switch = "y") +
      # coord_cartesian(ylim = c(0, 8)) +
      theme_minimal() +
      theme(legend.position = "top",
            # strip.text.x = element_blank(),
            strip.text.y.left = element_text(angle = 0),
            axis.text.y = element_blank(),
            plot.subtitle = element_text(size = 9))

    ggsave(file.path("temp-img", str_c(script_options$output, "_marginal_type_prob.png")), marginal_prob_density_plots)

    if (!interactive() && requireNamespace("tcltk", quietly = TRUE)) {
      x11()
      plot(marginal_prob_density_plots)
      capture <- tcltk::tk_messageBox(message = "Hit spacebar to close plots.")
    } else {
      plot(marginal_prob_density_plots)
    }
  }
}


# Prior Sequence ----------------------------------------------------------

# Prior sequence: run prior-predictive fits over a grid of values for the
# discretized hyper SD (tau) and plot how the implied priors change with it.
if (script_options$`prior-sequence`) {
  # The tau arguments arrive from the command line as strings; seq() needs
  # numbers (as.numeric() is a no-op if they were already converted).
  taus <- seq(
    as.numeric(script_options$`from-tau`),
    as.numeric(script_options$`to-tau`),
    as.numeric(script_options$`by-tau`)
  )

  # isTRUE(): a flag missing from the parsed command line is NULL, and
  # `if (NULL)` is an error. The fallback to the true hyper SD when no
  # alternative was given matches the multi-run branch.
  fit_discrete_beta_hyper_sd <- if (isTRUE(script_options$`different-priors`)) {
    script_options$`alt-hyper-sd` %||% script_options$`true-hyper-sd`
  } else {
    script_options$`true-hyper-sd`
  }

  # Placeholder data: the sampler needs analysis data even though every fit
  # below is run in "prior-predict" mode.
  dummy_data <- create_prior_predicted_simulation(default_model, sample_size = 4000, chains = 4, iter = 1000,
                                                  discrete_beta_hyper_sd = script_options$`true-hyper-sd`,
                                                  discretized_beta_hyper_sd = true_discretized_beta_hyper_sd,
                                                  tau_level_sigma = 1,
                                                  num_entities = script_options$`num-entities`) %>%
    unnest(entity_data) %>%
    select(entity_index, sim) %>%
    deframe() %>%
    map_dfr(create_simulation_analysis_data, .id = "entity_index") %>%
    mutate(y = if_else(y_1 == 0, 30, -30))

  all_prior_results <- taus %>%
    # Each fit runs 4 chains, so a quarter of the cores is used for workers;
    # max() keeps at least one worker when fewer than 4 cores were requested.
    test_parallel_map(
      cores = max(1L, script_options$cores %/% 4L),
      function(curr_tau) {
        test_sampler <- create_sampler(
          default_model,
          model_levels = "entity_index",
          analysis_data = dummy_data,
          estimands = default_estimands,
          # y = y < -20,
          y = y,

          discrete_beta_hyper_sd = fit_discrete_beta_hyper_sd,
          discretized_beta_hyper_sd = curr_tau,

          tau_level_sigma = 1,
          calculate_marginal_prob = TRUE
        )

        test_prior_fit <- test_sampler %>%
          sampling(
            chains = 4,
            warmup = 500,
            iter = 2500,
            pars = c("iter_estimand", "marginal_p_r"),
            run_type = "prior-predict",
          )

        # One row per tau, with list columns.
        tibble(
          results = test_prior_fit %>%
            get_estimation_results(no_sim_diag = FALSE, quants = seq(0, 1, 0.1)) %>%
            list(),

          marginal_prob = test_prior_fit %>%
            get_marginal_latent_type_prob() %>%
            list()
        )
      }
    ) %>%
    # Carry tau through bind_rows() as the element name, then back to numeric.
    set_names(taus) %>%
    bind_rows(.id = "tau") %>%
    mutate(tau = as.numeric(tau))

  # Prior density of the unobserved counterfactual, one curve per tau.
  # NOTE(review): this plot is built but not printed when the whole block is
  # run at once (only the last expression of the block is shown).
  all_prior_results %>%
    select(tau, results) %>%
    unnest(results) %>%
    # filter(estimand_name == "Pr[Y^{y}_{b=0,g=0,z=0,m=0} < c | M = 1, B = 0, G = 0, Z = 0]") %>%
    filter(estimand_name == default_unobs_cf) %>%
    unnest(iter_data) %>%
    ggplot() +
    geom_density(aes(iter_estimand, group = tau, color = tau)) +
    scale_color_viridis_c(expression(tau[beta])) +
    labs(x = "", y = "",
         # subtitle = latex2exp::TeX("P(Y_{b=0,g=0,z=0,m=0} < c | M_{b=0,g=0,z=0} = 1)")) +
         subtitle = latex2exp::TeX("P(Y_{z=0,m=0} < c | M_{z=0} = 1)")) +
    theme_minimal() +
    theme(legend.position = "right", plot.subtitle = element_text(size = 9))

  # Prior densities of the r_y_1 marginal type probabilities, one curve per tau.
  all_prior_results %>%
    select(tau, marginal_prob) %>%
    unnest(marginal_prob) %>%
    filter(fct_match(type_variable, "r_y_1")) %>%
    mutate(estimand_name = str_c(type_variable, " = ", type) %>% str_replace("^r", "R")) %>%
    unnest(iter_data) %>%
    ggplot() +
    geom_density(aes(iter_p_r, group = tau, color = tau)) +
    scale_color_viridis_c(expression(tau[beta])) +
    labs(
      x = "", y = ""
    ) +
    facet_wrap(vars(estimand_name), scales = "free") +
    coord_cartesian(ylim = c(0, 8)) +
    theme_minimal() +
    theme(legend.position = "top",
          # strip.text.x = element_blank(),
          # strip.text.y.left = element_text(angle = 0),
          axis.text.y = element_blank(),
          plot.subtitle = element_text(size = 9))
}

# Constrained Prior Sequence ----------------------------------------------------------

# Constrained prior sequence: like the prior sequence, but the grid value
# (tau) is only the "default" discretized hyper SD, while the
# "always below" / "never below" entries and the discrete hyper SD stay at
# the fixed <unconst-tau> value.
if (script_options$`constrained-prior-sequence`) {
  # The tau arguments arrive from the command line as strings; they are used
  # as numeric hyperparameters and in seq(), so convert them here
  # (as.numeric() is a no-op if they were already converted).
  unconst_tau <- as.numeric(script_options$`unconst-tau`)
  taus <- seq(
    as.numeric(script_options$`from-tau`),
    as.numeric(script_options$`to-tau`),
    as.numeric(script_options$`by-tau`)
  )

  # Placeholder data: the sampler needs analysis data even though every fit
  # below is run in "prior-predict" mode.
  dummy_data <- create_prior_predicted_simulation(default_model, sample_size = 4000, chains = 4, iter = 1000,
                                                  discrete_beta_hyper_sd = script_options$`true-hyper-sd`,
                                                  discretized_beta_hyper_sd = true_discretized_beta_hyper_sd,
                                                  tau_level_sigma = 1,
                                                  num_entities = script_options$`num-entities`) %>%
    unnest(entity_data) %>%
    select(entity_index, sim) %>%
    deframe() %>%
    map_dfr(create_simulation_analysis_data, .id = "entity_index") %>%
    mutate(y = if_else(y_1 == 0, 30, -30))

  all_prior_results <- taus %>%
    # test_parallel_map(cores = script_options$cores %/% 4,
    map(
      function(curr_tau) {
        test_sampler <- create_sampler(
          default_model,
          model_levels = "entity_index",
          analysis_data = dummy_data,
          estimands = default_estimands,
          # y = y < -20,
          y = y,

          discrete_beta_hyper_sd = unconst_tau,
          discretized_beta_hyper_sd = list(
            default = curr_tau,
            "always below" = unconst_tau,
            "never below" = unconst_tau
          ),

          tau_level_sigma = 1,
          calculate_marginal_prob = TRUE
        )

        test_prior_fit <- test_sampler %>%
          sampling(
            chains = 4,
            warmup = 500,
            iter = 2500,
            pars = c("iter_estimand", "marginal_p_r"),
            run_type = "prior-predict",
          )

        # One row per tau, with list columns.
        tibble(
          results = test_prior_fit %>%
            get_estimation_results(no_sim_diag = FALSE, quants = seq(0, 1, 0.1)) %>%
            list(),

          marginal_prob = test_prior_fit %>%
            get_marginal_latent_type_prob() %>%
            list()
        )
      }
    ) %>%
    # Carry tau through bind_rows() as the element name, then back to numeric.
    set_names(taus) %>%
    bind_rows(.id = "tau") %>%
    mutate(tau = as.numeric(tau))

  # Prior density of the unobserved counterfactual, one curve per tau.
  # NOTE(review): this plot is built but not printed when the whole block is
  # run at once (only the last expression of the block is shown).
  all_prior_results %>%
    select(tau, results) %>%
    unnest(results) %>%
    # filter(estimand_name == "Pr[Y^{y}_{b=0,g=0,z=0,m=0} < c | M = 1, B = 0, G = 0, Z = 0]") %>%
    filter(estimand_name == default_unobs_cf) %>%
    unnest(iter_data) %>%
    ggplot() +
    geom_density(aes(iter_estimand, group = tau, color = tau)) +
    scale_color_viridis_c(expression(tau[beta])) +
    labs(x = "", y = "",
         # subtitle = latex2exp::TeX("P(Y_{b=0,g=0,z=0,m=0} < c | M_{b=0,g=0,z=0} = 1)")) +
         subtitle = latex2exp::TeX("P(Y_{z=0,m=0} < c | M_{z=0} = 1)")) +
    theme_minimal() +
    theme(legend.position = "right", plot.subtitle = element_text(size = 9))

  # Prior densities of the r_y_1 marginal type probabilities, one curve per tau.
  all_prior_results %>%
    select(tau, marginal_prob) %>%
    unnest(marginal_prob) %>%
    filter(fct_match(type_variable, "r_y_1")) %>%
    mutate(estimand_name = str_c(type_variable, " = ", type) %>% str_replace("^r", "R")) %>%
    unnest(iter_data) %>%
    ggplot() +
    geom_density(aes(iter_p_r, group = tau, color = tau)) +
    scale_color_viridis_c(expression(tau[beta])) +
    labs(
      x = "", y = ""
    ) +
    facet_wrap(vars(estimand_name), scales = "free") +
    coord_cartesian(ylim = c(0, 8)) +
    theme_minimal() +
    theme(legend.position = "top",
          # strip.text.x = element_blank(),
          # strip.text.y.left = element_text(angle = 0),
          axis.text.y = element_blank(),
          plot.subtitle = element_text(size = 9))
}

# Diagnostics -------------------------------------------------------------

# color_scheme_set("darkgray")
# test_posterior <- as.array(test_fit)
# test_np <- nuts_params(test_fit)
# mcmc_pairs(test_posterior, np = test_np, regex_pars =  c("level_beta_sigma", "obs_beta\\[[2-4]"))
# karimn/boundr documentation built on March 1, 2021, 6:57 p.m.