test_boundr.R

#!/usr/bin/Rscript

# docopt command-line specification. This text is parsed by docopt (see the
# docopt::docopt() calls below), so the exact wording and spacing of the
# usage patterns and option lines define the script's CLI.
opt_desc <- "Usage:
  test_boundr single [--different-priors --num-entities=<num-entities> --true-hyper-sd=<sd> --constrained-prior]
  test_boundr multi <cores> <runs> [--num-entities=<num-entities> --true-hyper-sd=<sd> --constrained-prior --density-plots --output=<output-name>] [--append --different-priors --alt-hyper-sd=<sd> --alt-hyper-sd-constrained=<const-sd>]
  test_boundr prior-sequence <cores> <from-tau> <to-tau> <by-tau> [--num-entities=<num-entities>]
  test_boundr constrained-prior-sequence <cores> <unconst-tau> <from-tau> <to-tau> <by-tau> [--num-entities=<num-entities>]

Options:
  --output=<output-name>  Output name to use in file names [default: test_run]
  --num-entities=<num-entities>  Number of entities in model [default: 1]
  --true-hyper-sd=<sd>  True SD hyperparameter for prior [default: 2.5]
  --alt-hyper-sd=<sd>  Alternative SD hyperparameter to use for fit [default: 2]
  --alt-hyper-sd-constrained=<const-sd>  Alternative SD hyperparameter for the complier/defier groups
"

# Parse the command line with docopt. An interactive session has no
# command-line arguments, so a preset argument string is parsed instead;
# switch presets by swapping the active call with a commented-out one.
script_options <- if (!interactive()) {
  docopt::docopt(opt_desc)
} else {
  # Other presets:
  # docopt::docopt(opt_desc, "multi 12 12 --density-plots --num-entities=1 --true-hyper-sd=0.25 --alt-hyper-sd=0.9 --output=test.rds --different-priors")
  # docopt::docopt(opt_desc, "multi 12 300 --density-plots --num-entities=1 --append --output=constrained4.rds --different-priors --alt-hyper-sd=2 --alt-hyper-sd-constrained=1.5")
  # docopt::docopt(opt_desc, "single --constrained-prior --true-hyper-sd=5 --num-entities=1")
  # docopt::docopt(opt_desc, "multi 12 3 --density-plots --constrained-prior --true-hyper-sd=2 --num-entities=1")
  # docopt::docopt(opt_desc, "single --different-priors --num-entities=1")
  # docopt::docopt(opt_desc, "prior-sequence 12 0.5 10 0.5 --num-entities=1")
  # docopt::docopt(opt_desc, "constrained-prior-sequence 12 5 0.25 0.25 0.0 --num-entities=1")
  docopt::docopt(opt_desc, "single")
}

# Setup -------------------------------------------------------------------

library(magrittr)
library(tidyverse)
library(rlang)
library(rstan)
library(bayesplot)

library(econometr)
library(boundr)

# Coerce the string-valued docopt options to numbers. Options that were not
# supplied (presumably NULL from docopt) become zero-length vectors, which
# later code tests with is_empty().
script_options %<>%
  modify_at(c("cores", "runs", "num_entities"), as.integer) %>%
  modify_at(c("true_hyper_sd", "alt_hyper_sd", "alt_hyper_sd_constrained", "unconst_tau", "from_tau", "to_tau", "by_tau"), as.numeric)

# parallel::detectCores() returns NA when the core count cannot be
# determined; without na.rm, max() would propagate that NA into mc.cores.
options(mc.cores = max(1L, parallel::detectCores(), na.rm = TRUE))
rstan_options(auto_write = TRUE)

# "True" prior hyperparameters (SD and mean) for the discretized outcome,
# used to simulate data and, unless overridden, to fit.
#
# Without --constrained-prior both are scalars: SD --true-hyper-sd, mean 0.
# With --constrained-prior each is a list: `default` applies generally,
# while the "migration complier" entry is a formula that rewrites the
# sd/mean column of the table it is applied to (presumably boundr's table
# of latent types, which has an r_m column -- TODO confirm). Rows whose r_m
# is "always", "program defier" or "wedge defier" get SD 0.01 and mean -2;
# every other row gets SD 1 and mean 0.
if (script_options$constrained_prior) {
  true_discretized_beta_hyper_sd <- lst(
    default = script_options$true_hyper_sd,
    "migration complier" = ~ mutate(
      .,
      sd = if_else(fct_match(r_m, c("always", "program defier", "wedge defier")), 0.01, 1)
      # An earlier variant used script_options$true_hyper_sd instead of 1.
    )
  )

  true_discretized_beta_hyper_mean <- lst(
    default = 0,
    "migration complier" = ~ mutate(
      .,
      mean = if_else(fct_match(r_m, c("always", "program defier", "wedge defier")), -2, 0)
    )
  )
} else {
  true_discretized_beta_hyper_sd <- script_options$true_hyper_sd
  true_discretized_beta_hyper_mean <- 0
}

# Parallel map with a progress bar, built on pbmcapply::pbmclapply().
#   .x    - vector/list to iterate over.
#   .f    - function or purrr-style lambda (converted with as_mapper()).
#   ...   - extra arguments forwarded to .f on every call.
#   cores - number of worker processes. Values below 1, NULL or NA (e.g.
#           `script_options$cores %/% 4` when fewer than 4 cores are
#           requested, or single mode where <cores> is absent) are raised to
#           1, because mc.cores must be at least 1.
# Returns whatever pbmclapply() returns (presumably a list with one element
# per element of .x).
test_parallel_map <- function(.x, .f, ..., cores = script_options$cores) {
  cores <- max(1L, as.integer(cores), na.rm = TRUE)

  pbmcapply::pbmclapply(.x, as_mapper(.f), ..., ignore.interactive = TRUE, mc.silent = TRUE, mc.cores = cores)
}

# Models ------------------------------------------------------------------

# Discrete variables shared by the test models. Each define_response() call
# names a variable, optionally lists its `input` variables, and then names
# its response types as formulas giving the variable's value for each type.
discrete_variables <- list2(
  # b: no inputs; 1 for "program branch", 0 for "control branch".
  define_response(
    "b",

    "program branch" = ~ 1,
    "control branch" = ~ 0,
  ),

  # g: no inputs; 1 for "treatment sector", 0 for "control sector".
  define_response(
    "g",

    "treatment sector" = ~ 1,
    "control sector" = ~ 0,
  ),

  # z: no inputs; 1 for "village assigned treatment", 0 for
  # "village assigned control".
  define_response(
    "z",

    "village assigned treatment" = ~ 1,
    "village assigned control" = ~ 0,
  ),

  # m: responds to b, g and z. Each type is constant (never/always) or
  # follows ("complier") or inverts ("defier") one of the inputs.
  define_response(
    "m",
    input = c("b", "g", "z"),

    "never" = ~ 0,
    "program complier" = ~ b,
    "program defier" = ~ 1 - b,
    "wedge complier" = ~ g,
    "wedge defier" = ~ 1 - g,
    "treatment complier" = ~ z,
    # Treatment defiers are deliberately left out of the model:
    # "treatment defier" = ~ 1 - z,
    "always" = ~ 1,
  ),
)

# (hi, low) pairs of y response types that the discretized y response group
# permits; presumably `hi` is the type at a higher cutpoint and `low` at a
# lower one -- confirm against boundr's define_discretized_response_group().
#
# The allowed pairs follow a simple rule: a type may pair with itself,
# "always below" may sit above any type, and any type may sit above
# "never below". Rows are ordered with `hi` varying slowest and `low` in the
# order of `y_types`.
pruning_data <- local({
  y_types <- c(
    "always below", "program complier", "wedge complier", "migration complier",
    "program defier", "wedge defier", "migration defier", "never below"
  )

  expand_grid(hi = y_types, low = y_types) %>%
    filter(hi == low | hi == "always below" | low == "never below")
})

# test_model <- define_structural_causal_model(
#   !!!discrete_variables,
#
#   define_discretized_response_group(
#     "y",
#     cutpoints = c(-100, -20, 20, 100),
#     # cutpoints = c(-100, 20, 100),
#
#     input = c("b", "g", "m"),
#
#     "never below" = ~ 0,
#     "program complier" = ~ b,
#     "program defier" = ~ 1 - b,
#     "wedge complier" = ~ g,
#     "wedge defier" = ~ 1 - g,
#     "migration complier" = ~ m,
#     "migration defier" = ~ 1 - m,
#     "always below" = ~ 1,
#
#     pruning_data = pruning_data,
#   ),
#
#   exogenous_prob = tribble(
#     ~ b, ~ g, ~ z, ~ ex_prob,
#     0,   0,   0,   0.4,
#     1,   0,   0,   0.2,
#     1,   1,   0,   0.2,
#     1,   1,   1,   0.2
#   ),
# )
#
# test_model2 <- define_structural_causal_model(
#   !!!discrete_variables,
#
#   define_discretized_response_group(
#     "y",
#     cutpoints = c(-100, -20, -10, 10, 20, 100),
#     # cutpoints = c(-100, 20, 100),
#
#     input = c("b", "g", "m"),
#
#     "never below" = ~ 0,
#     "program complier" = ~ b,
#     "program defier" = ~ 1 - b,
#     "wedge complier" = ~ g,
#     "wedge defier" = ~ 1 - g,
#     "migration complier" = ~ m,
#     "migration defier" = ~ 1 - m,
#     "always below" = ~ 1,
#
#     pruning_data = pruning_data,
#   ),
#
#   exogenous_prob = tribble(
#     ~ b, ~ g, ~ z, ~ ex_prob,
#     0,   0,   0,   0.4,
#     1,   0,   0,   0.2,
#     1,   1,   0,   0.2,
#     1,   1,   1,   0.2
#   ),
# )
#
# test_model3 <- define_structural_causal_model(
#   !!!discrete_variables,
#
#   define_response(
#     "y",
#
#     input = c("b", "g", "m"),
#
#     "never below" = ~ 0,
#     "program complier" = ~ b,
#     "program defier" = ~ 1 - b,
#     "wedge complier" = ~ g,
#     "wedge defier" = ~ 1 - g,
#     "migration complier" = ~ m,
#     "migration defier" = ~ 1 - m,
#     "always below" = ~ 1,
#   ),
#
#   exogenous_prob = tribble(
#     ~ b, ~ g, ~ z, ~ ex_prob,
#     0,   0,   0,   0.4,
#     1,   0,   0,   0.2,
#     1,   1,   0,   0.2,
#     1,   1,   1,   0.2
#   ),
# )

# Main test model: the discrete variables b, g, z, m plus a discretized
# outcome y with cutpoints -100, -20 and 100. y's response types depend on
# b, g and m; `pruning_data` is passed to restrict how y's types may combine
# across cutpoints (exact semantics are boundr's -- TODO confirm).
test_model4 <- define_structural_causal_model(
  !!!discrete_variables,

  define_discretized_response_group(
    "y",
    cutpoints = c(-100, -20, 100),

    input = c("b", "g", "m"),

    "never below" = ~ 0,
    "program complier" = ~ b,
    "program defier" = ~ 1 - b,
    "wedge complier" = ~ g,
    "wedge defier" = ~ 1 - g,
    "migration complier" = ~ m,
    "migration defier" = ~ 1 - m,
    "always below" = ~ 1,

    pruning_data = pruning_data,
  ),

  # Distribution of (b, g, z): only these four combinations have positive
  # probability, and the probabilities sum to 1.
  exogenous_prob = tribble(
    ~ b, ~ g, ~ z, ~ ex_prob,
    0,   0,   0,   0.4,
    1,   0,   0,   0.2,
    1,   1,   0,   0.2,
    1,   1,   1,   0.2
  ),
)

# Reduced model with only z, m (responding to z) and y (responding to m).
# NOTE(review): this reuses `pruning_data`, which also names y types
# ("program complier", "wedge defier", ...) that this model's y does not
# define -- confirm boundr ignores those rows.
test_model5 <- define_structural_causal_model(
  define_response(
    "z",

    "village assigned treatment" = ~ 1,
    "village assigned control" = ~ 0,
  ),

  define_response(
    "m",
    input = c("z"),

    "never" = ~ 0,
    "treatment complier" = ~ z,
    "always" = ~ 1,
  ),

  define_discretized_response_group(
    "y",
    cutpoints = c(-100, -20, 100),

    input = c("m"),

    "never below" = ~ 0,
    "migration complier" = ~ m,
    "migration defier" = ~ 1 - m,
    "always below" = ~ 1,

    pruning_data = pruning_data
  ),

  exogenous_prob = tribble(
    ~ z, ~ ex_prob,
    0,   0.4,
    1,   0.6
  ),
)

# Model used by the runs below.
default_model <- test_model4

# Estimands ---------------------------------------------------------------

# Estimands for test_model4. build_diff_estimand() contrasts two
# build_atom_estimand()s of a discrete variable; the build_discretized_*
# variants do the same for the discretized outcome y (presumably yielding one
# estimand per cutpoint -- confirm in boundr). Named arguments such as
# `b = 0` set the variables' values, and `cond` restricts to the rows
# satisfying the condition.
with_discretized_estimands <- list2(
  # m under (b, g, z) = (1, 1, 1) vs (0, 0, 0).
  build_diff_estimand(
    build_atom_estimand("m", b = 1, g = 1, z = 1),
    build_atom_estimand("m", b = 0, g = 0, z = 0)
  ),

  # y under m = 1 vs m = 0 with b = g = z = 0, unconditional.
  build_discretized_diff_estimand(
    build_discretized_atom_estimand("y", b = 0, g = 0, z = 0, m = 1),
    build_discretized_atom_estimand("y", b = 0, g = 0, z = 0, m = 0)
  ),

  # y under (b, g, z) = (1, 1, 1) vs (0, 0, 0).
  build_discretized_diff_estimand(
    build_discretized_atom_estimand("y", b = 1, g = 1, z = 1),
    build_discretized_atom_estimand("y", b = 0, g = 0, z = 0)
  ),

  # y under m = 1 vs m = 0 with b = g = z = 0, conditional on
  # m == 1 & b == 0 & g == 0 & z == 0.
  build_discretized_diff_estimand(
    build_discretized_atom_estimand("y", b = 0, g = 0, z = 0, m = 1, cond = m == 1 & b == 0 & g == 0 & z == 0),
    build_discretized_atom_estimand("y", b = 0, g = 0, z = 0, m = 0, cond = m == 1 & b == 0 & g == 0 & z == 0)
  ),

  # y under m = 1 vs m = 0 with b = g = z = 1, conditional on
  # m == 1 & b == 1 & g == 1 & z == 1.
  build_discretized_diff_estimand(
    build_discretized_atom_estimand("y", b = 1, g = 1, z = 1, m = 1, cond = m == 1 & b == 1 & g == 1 & z == 1),
    build_discretized_atom_estimand("y", b = 1, g = 1, z = 1, m = 0, cond = m == 1 & b == 1 & g == 1 & z == 1)
  ),

  # y under m = 1 vs m = 0 with b = g = z = 1, restricted to m's
  # "treatment complier" response type (labelled M_{z=1} > M_{z=0}).
  build_discretized_diff_estimand(
    build_discretized_atom_estimand("y", b = 1, g = 1, z = 1, m = 1, cond = fct_match(r_m, "treatment complier"), cond_desc = "M_{z=1} > M_{z=0}"),
    build_discretized_atom_estimand("y", b = 1, g = 1, z = 1, m = 0, cond = fct_match(r_m, "treatment complier"), cond_desc = "M_{z=1} > M_{z=0}")
  ),
)

# test_estimands <- build_estimand_collection(
#   model = test_model,
#   utility = c(0, 1, 1.5),
#
#   !!!with_discretized_estimands
# )
#
# test_estimands2 <- build_estimand_collection(
#   model = test_model2,
#   utility = c(0, 1, 1.5, 1.75, 1.8),
#
#   !!!with_discretized_estimands
# )
#
# test_estimands3 <- build_estimand_collection(
#   model = test_model3,
#
#   build_diff_estimand(
#     build_atom_estimand("m", b = 1, g = 1, z = 1),
#     build_atom_estimand("m", b = 0, g = 0, z = 0)
#   ),
#
#   build_diff_estimand(
#     build_atom_estimand("y", b = 0, g = 0, z = 0, m = 1),
#     build_atom_estimand("y", b = 0, g = 0, z = 0, m = 0)
#   ),
#
#   build_diff_estimand(
#     build_atom_estimand("y", b = 1, g = 1, z = 1),
#     build_atom_estimand("y", b = 0, g = 0, z = 0)
#   ),
#
#   build_diff_estimand(
#     build_atom_estimand("y", b = 0, g = 0, z = 0, m = 1, cond = m == 1 & b == 0 & g == 0 & z == 0),
#     build_atom_estimand("y", b = 0, g = 0, z = 0, m = 0, cond = m == 1 & b == 0 & g == 0 & z == 0)
#   ),
# )

# Estimand collection for test_model4.
# NOTE(review): the meaning of `utility = c(0, 1)` is defined by boundr --
# confirm before changing it.
test_estimands4 <- build_estimand_collection(
  model = test_model4,
  utility = c(0, 1),

  !!!with_discretized_estimands
)

# Analogous estimands for the reduced test_model5 (only z and m are set).
test_estimands5 <- build_estimand_collection(
  model = test_model5,
  utility = c(0, 1),

  # m under z = 1 vs z = 0.
  build_diff_estimand(
    build_atom_estimand("m", z = 1),
    build_atom_estimand("m", z = 0)
  ),

  # y under m = 1 vs m = 0 with z = 0.
  build_discretized_diff_estimand(
    build_discretized_atom_estimand("y", z = 0, m = 1),
    build_discretized_atom_estimand("y", z = 0, m = 0)
  ),

  # y under z = 1 vs z = 0.
  build_discretized_diff_estimand(
    build_discretized_atom_estimand("y", z = 1),
    build_discretized_atom_estimand("y", z = 0)
  ),

  # y under m = 1 vs m = 0 with z = 0, conditional on m == 1 & z == 0.
  build_discretized_diff_estimand(
    build_discretized_atom_estimand("y", z = 0, m = 1, cond = m == 1 & z == 0),
    build_discretized_atom_estimand("y", z = 0, m = 0, cond = m == 1 & z == 0)
  ),

)

# Estimand collection used by the runs below.
default_estimands <- test_estimands4

# Estimand names used to pick rows out of estimation results (they are
# compared against `estimand_name`). They must match boundr's names for the
# estimands in `with_discretized_estimands` exactly. Each difference name is
# therefore assembled from its two components, so the two can never drift
# apart.

# y with m = 0 (unobserved) and m = 1 (observed) at b = g = z = 0,
# conditional on M = 1, B = 0, G = 0, Z = 0.
default_unobs_cf <- "Pr[Y^{y}_{b=0,g=0,z=0,m=0} < c | M = 1, B = 0, G = 0, Z = 0]"
default_obs_cf <- "Pr[Y^{y}_{b=0,g=0,z=0,m=1} < c | M = 1, B = 0, G = 0, Z = 0]"
default_cf_diff <- paste(default_obs_cf, default_unobs_cf, sep = " - ")

# The same contrast at b = g = z = 1, conditional on M = 1, B = 1, G = 1, Z = 1.
tot_unobs_cf <- "Pr[Y^{y}_{b=1,g=1,z=1,m=0} < c | M = 1, B = 1, G = 1, Z = 1]"
tot_obs_cf <- "Pr[Y^{y}_{b=1,g=1,z=1,m=1} < c | M = 1, B = 1, G = 1, Z = 1]"
tot_cf_diff <- paste(tot_obs_cf, tot_unobs_cf, sep = " - ")

# Unconditional contrast of m = 1 vs m = 0 at b = g = z = 0.
default_ate <- "Pr[Y^{y}_{b=0,g=0,z=0,m=1} < c] - Pr[Y^{y}_{b=0,g=0,z=0,m=0} < c]"

# Names for the reduced z/m model (test_model5), kept for reference:
# default_unobs_cf <- "Pr[Y^{y}_{z=0,m=0} < c | M = 1, Z = 0]"
# default_obs_cf <- "Pr[Y^{y}_{z=0,m=1} < c | M = 1, Z = 0]"

# Single Run --------------------------------------------------------------

# Single run: simulate one data set from the "true" prior, fit the model to
# it, and compare prior and posterior estimates with the known estimand
# values.
if (script_options$single) {
  # One prior-predictive simulation, reshaped into a list of per-entity
  # simulations named by entity_index.
  entity_data <- create_prior_predicted_simulation(default_model, sample_size = 4000, chains = 4, iter = 1000,
                                                   discretized_beta_hyper_mean = true_discretized_beta_hyper_mean,
                                                   discrete_beta_hyper_sd = script_options$true_hyper_sd,
                                                   discretized_beta_hyper_sd = true_discretized_beta_hyper_sd,
                                                   tau_level_sigma = 1,
                                                   num_entities = script_options$num_entities) %>%
    unnest(entity_data) %>%
    select(entity_index, sim) %>%
    deframe()

  # Known estimand values, averaged over entities.
  known_results <- entity_data %>%
    map_dfr(boundr:::get_known_estimands, default_estimands, .id = "entity_index") %>%
    group_by_at(vars(-entity_index, -prob)) %>%
    summarize(prob = mean(prob)) %>%
    ungroup() %>%
    select(estimand_name, cutpoint, prob)

  # Analysis data. The continuous outcome y is synthesized from the first
  # discretization indicator: 30 when y_1 == 0, otherwise -30.
  test_sim_data <- entity_data %>%
    map_dfr(create_simulation_analysis_data, .id = "entity_index") %>%
    mutate(y = if_else(y_1 == 0, 30, -30))
    # mutate(y = if_else(y_2 == 0, 30, if_else(y_1 == 0, sample(c(-1, 1), n(), replace = TRUE), -30)))
    # mutate(y = if_else(y_2 == 0, 30, if_else(y_1 == 0, runif(n(), -19, 19), -30)))

  # test_model %>%
  #   get_linear_programming_bounds(test_sim_data, "y_1", b = 1, g = 1, z = 1, m = 0)

  # With --different-priors, the fit uses --alt-hyper-sd for both the
  # discrete and discretized SDs instead of the true hyperparameters.
  test_sampler <- create_sampler(
    default_model,
    model_levels = "entity_index",
    analysis_data = test_sim_data,
    estimands = default_estimands,
    # y = y < -20,
    y = y,

    discretized_beta_hyper_mean = true_discretized_beta_hyper_mean,

    discrete_beta_hyper_sd = if (script_options$different_priors) script_options$alt_hyper_sd else script_options$true_hyper_sd,
    discretized_beta_hyper_sd = if (script_options$different_priors) script_options$alt_hyper_sd else true_discretized_beta_hyper_sd,
    # discretized_beta_hyper_sd = lst(default = 2,
    #                                 "migration complier" = ~ mutate(., sd = if_else(fct_match(r_m, c("always", "program defier", "wedge defier")),
    #                                                                                 0.1,
    #                                                                                 2))),
    tau_level_sigma = 1,
    calculate_marginal_prob = TRUE
  )

  # NOTE(review): this prior fit keeps only "iter_estimand" and
  # "total_abducted_log_prob", while the multi-run prior fits also keep
  # "single_discrete_marginal_p_r" and "discretized_marginal_p_r". Confirm
  # that get_marginal_latent_type_prob() below works with this `pars` set.
  test_prior_fit <- test_sampler %>%
    sampling(
      chains = 4,
      iter = 1000,
      # control = lst(adapt_delta = 0.99, max_treedepth = 12),
      pars = c("iter_estimand", "total_abducted_log_prob"),
      run_type = "prior-predict",
      save_background_joint_prob = TRUE
    )

  test_prior_results <- test_prior_fit %>%
    get_estimation_results(no_sim_diag = FALSE, quants = seq(0, 1, 0.1)) %T>%
    print(n = 1000)

  test_prior_marginal_prob <- test_prior_fit %>% get_marginal_latent_type_prob()

  test_fit <- test_sampler %>%
    sampling(
      chains = 4,
      iter = 1000,
      # control = lst(adapt_delta = 0.99, max_treedepth = 12),
      pars = c("iter_estimand", "single_discrete_marginal_p_r")
    )

  test_results <- test_fit %>%
    get_estimation_results(no_sim_diag = FALSE, quants = seq(0, 1, 0.1)) %T>%
    print(n = 1000)

  # Coverage: an estimand is "inside" when its known value lies between the
  # per_0.1 and per_0.9 columns (presumably the 10% and 90% posterior
  # quantiles).
  known_results %>%
    inner_join(test_results, by = c("cutpoint", "estimand_name")) %>%
    mutate_at(vars(starts_with("per_")), ~ . - prob) %>%
    mutate(coverage = if_else(per_0.1 > 0 | per_0.9 < 0, "outside", "inside") %>% factor()) %>%
    select(estimand_name, coverage) %>%
    print(n = 1000)

  # The plots below are printed explicitly. Inside an `if` body only the
  # final value would be auto-printed, so earlier plots would be lost.

  # Prior vs. posterior quantiles of the unobserved counterfactual, with its
  # known value as a vertical line. The quantile labels come from the
  # "per_" column-name suffix and are converted to numbers with
  # names_transform; current tidyr's names_ptypes does not coerce
  # character -> numeric.
  unobs_cf_quantile_plot <- bind_rows(prior = test_prior_results, posterior = test_results, .id = "fit_type") %>%
    filter(estimand_name == default_unobs_cf) %>%
    select(fit_type, starts_with("per_")) %>%
    pivot_longer(cols = -fit_type, names_to = "quant", names_prefix = "per_", names_transform = list(quant = as.numeric)) %>%
    ggplot() +
    geom_col(aes(quant, value, group = fit_type, fill = fit_type), position = position_dodge()) +
    geom_vline(xintercept = known_results %>% filter(estimand_name == default_unobs_cf) %>% pull(prob)) +
    scale_fill_discrete("") +
    scale_x_continuous(breaks = seq(0, 1, 0.2)) +
    labs(x = "", y = "") +
    theme_minimal() +
    theme(legend.position = "top")

  print(unobs_cf_quantile_plot)

  # Prior densities of the discretized marginal latent-type probabilities,
  # faceted by type and by m's response type.
  discretized_prior_marginal_plot <- test_prior_marginal_prob %>%
    filter(discretized) %>%
    select(type, r_m, iter_data) %>%
    unnest(iter_data) %>%
    ggplot() +
    geom_density(aes(iter_p_r)) +
    labs(x = "", y = "") +
    facet_grid(rows = vars(type), cols = vars(r_m), scales = "free") +
    theme_minimal() +
    NULL

  print(discretized_prior_marginal_plot)

  # Prior densities of the non-discretized marginal latent-type probabilities.
  discrete_prior_marginal_plot <- test_prior_marginal_prob %>%
    filter(!discretized) %>%
    select(type, iter_data) %>%
    unnest(iter_data) %>%
    ggplot() +
    geom_density(aes(iter_p_r)) +
    labs(x = "", y = "") +
    facet_wrap(vars(type), scales = "free") +
    theme_minimal() +
    NULL

  print(discrete_prior_marginal_plot)
}

# Multiple Runs -----------------------------------------------------------

if (script_options$multi) {
  num_runs <- script_options$runs

  test_sim_data <- create_prior_predicted_simulation(default_model, sample_size = 4000, chains = 4, iter = 1000,
                                                     discretized_beta_hyper_mean = true_discretized_beta_hyper_mean,
                                                     discrete_beta_hyper_sd = script_options$true_hyper_sd,
                                                     discretized_beta_hyper_sd = true_discretized_beta_hyper_sd,
                                                     tau_level_sigma = 1,
                                                     num_entities = script_options$num_entities,
                                                     num_sim = num_runs) %>%
    deframe()

  used_discretized_beta_hyper_sd <- if (script_options$different_priors) script_options$alt_hyper_sd else true_discretized_beta_hyper_sd
  used_discretized_beta_hyper_sd <- if (!is_empty(script_options$alt_hyper_sd_constrained)) {
    list(
      default = script_options$alt_hyper_sd_constrained,
      "always below" = used_discretized_beta_hyper_sd,
      "never below" = used_discretized_beta_hyper_sd
    )
  } else {
    used_discretized_beta_hyper_sd
  }

  test_run_data <- test_sim_data %>%
    test_parallel_map(cores = script_options$cores %/% 4,
    # map(
      function(entity_data, discrete_beta_hyper_sd, discretized_beta_hyper_sd, save_iter_data) {
        entity_data %<>% deframe()

        known_results <- entity_data %>%
          map_dfr(boundr:::get_known_estimands, default_estimands, .id = "entity_index") %>%
          group_by_at(vars(-entity_index, -prob)) %>%
          summarize(prob = mean(prob)) %>%
          ungroup()

        known_marginal_prob <- entity_data %>%
          map_dfr(boundr:::get_known_marginal_latent_type_prob, .id = "entity_index") %>%
          group_by_at(vars(-entity_index, -marginal_prob)) %>%
          summarize(marginal_prob = mean(marginal_prob)) %>%
          ungroup()

        sampler <- entity_data %>%
          map_dfr(create_simulation_analysis_data, .id = "entity_index") %>%
          # mutate(y = if_else(y_2 == 0, 30, if_else(y_1 == 0, 0, -30))) %>%
          # mutate(y = if_else(y_2 == 0, 30, if_else(y_1 == 0, runif(n(), -19, 19), -30))) %>%
          mutate(y = if_else(y_1 == 0, 30, -30)) %>%
          create_sampler(
            default_model,
            model_levels = "entity_index",
            analysis_data = .,
            estimands = default_estimands,
            y = y,

            discretized_beta_hyper_mean = true_discretized_beta_hyper_mean,

            discrete_beta_hyper_sd = discrete_beta_hyper_sd,
            discretized_beta_hyper_sd = discretized_beta_hyper_sd,
            tau_level_sigma = 1,
            calculate_marginal_prob = TRUE
          )

        fit <- sampler %>%
          sampling(
            pars = c("iter_estimand", "single_discrete_marginal_p_r", "discretized_marginal_p_r"),
            chains = 4, iter = 1000
          )

        results <- fit %>%
          get_estimation_results(quants = seq(0, 1, 0.1)) %>%
          select(estimand_name, cutpoint, starts_with("per_"), if (save_iter_data) "iter_data") %>%
          inner_join(
            select(known_results, estimand_name, cutpoint, prob),
            by = c("estimand_name", "cutpoint")
          ) %>%
          mutate(coverage = if_else((per_0.1 - prob) > 0 | (per_0.9 - prob) < 0, "outside", "inside") %>% factor())

        marginal_prob <- fit %>% get_marginal_latent_type_prob() %>%
          select(type_variable, type, starts_with("per_"), if (save_iter_data) "iter_data") %>%
          inner_join(known_marginal_prob, by = c("type_variable", "type")) %>%
          mutate(coverage = if_else((per_0.1 - marginal_prob) > 0 | (per_0.9 - marginal_prob) < 0, "outside", "inside") %>% factor())

        lp_bounds <- entity_data %>%
          # map(get_linear_programming_bounds, "y_1", b = 0, g = 0, z = 0, m = 0, cond = m == 1 & b == 0 & g == 0 & z == 0) %>%
          map(get_linear_programming_bounds, "y_1", z = 0, m = 0, cond = m == 1 & z == 0) %>%
          map(map, pluck, "objval") %>%
          map_df(as_tibble)

        tibble(results = list(results), marginal_prob = list(marginal_prob), lp_bounds = list(lp_bounds))
      },

      discrete_beta_hyper_sd = if (script_options$different_priors) script_options$alt_hyper_sd else script_options$true_hyper_sd,
      discretized_beta_hyper_sd = used_discretized_beta_hyper_sd,
      save_iter_data = script_options$density_plots
    ) %>%
    compact() %>%
    bind_rows(.id = "iter_id")

  dummy_data <- test_sim_data[[1]] %>%
    select(entity_index, sim) %>%
    deframe() %>%
    map_dfr(create_simulation_analysis_data, .id = "entity_index") %>%
    # mutate(y = if_else(y_2 == 0, 30, if_else(y_1 == 0, runif(n(), -19, 19), -30))) %>%
    mutate(y = if_else(y_1 == 0, 30, -30)) # This data isn't really used in prior prediction

  true_prior_sampler <- create_sampler(
    default_model,
    model_levels = "entity_index",
    analysis_data = dummy_data,
    estimands = default_estimands,
    y = y,


    discretized_beta_hyper_mean = true_discretized_beta_hyper_mean,

    discrete_beta_hyper_sd = script_options$true_hyper_sd,
    discretized_beta_hyper_sd = true_discretized_beta_hyper_sd,
    tau_level_sigma = 1,
    calculate_marginal_prob = TRUE
  )

  true_prior_fit <- true_prior_sampler %>%
    sampling(
      pars = c("iter_estimand", "single_discrete_marginal_p_r", "discretized_marginal_p_r"),
      chains = 4, iter = 1000,
      run_type = "prior-predict",
    )

  true_prior_results <- true_prior_fit %>% get_estimation_results(quants = seq(0, 1, 0.1))
  true_prior_marginal_prob <- true_prior_fit %>% get_marginal_latent_type_prob()
  prior_results <- true_prior_results
  prior_marginal_prob <- true_prior_marginal_prob

  test_run_data_file <- file.path("temp-data", str_c(script_options$output, ".rds"))

  if (script_options$append && file.exists(test_run_data_file)) {
    test_run_data %<>%
      bind_rows(read_rds(test_run_data_file))
  }

  write_rds(test_run_data, test_run_data_file)

  if (script_options$different_priors) {
    prior_sampler <- create_sampler(
      default_model,
      model_levels = "entity_index",
      analysis_data = dummy_data,
      estimands = default_estimands,
      y = y,

      discrete_beta_hyper_sd = script_options$alt_hyper_sd,
      discretized_beta_hyper_sd = used_discretized_beta_hyper_sd, #script_options$alt_hyper_sd,
      tau_level_sigma = 1,
      calculate_marginal_prob = TRUE
    )

    prior_fit <- prior_sampler %>%
      sampling(
        pars = c("iter_estimand", "single_discrete_marginal_p_r", "discretized_marginal_p_r"),
        chains = 4, iter = 1000,
        run_type = "prior-predict",
      )

    prior_results <- prior_fit %>% get_estimation_results(quants = seq(0, 1, 0.1))
    prior_marginal_prob <- prior_fit %>% get_marginal_latent_type_prob()

    test_run_data %>%
      select(iter_id, results) %>%
      unnest(results) %>%
      filter(estimand_name %in% c(default_unobs_cf, default_obs_cf, default_cf_diff, tot_unobs_cf, tot_obs_cf, tot_cf_diff, default_ate)) %>%
      pack(post = starts_with("per_")) %>%
      left_join(
        prior_results %>%
          pack(wrong_prior = starts_with("per_")),
        by = c("estimand_name", "cutpoint")
      ) %>%
      left_join(
        true_prior_results %>%
          pack(true_prior = starts_with("per_")),
        by = c("estimand_name", "cutpoint")
      ) %>%
      unpack(c(true_prior, wrong_prior, post), names_sep = "_") %>%
      mutate_at(vars(contains("prior_per_")), ~ . - prob) %>%
      mutate(wrong_prior_coverage = if_else(wrong_prior_per_0.1 > 0 | wrong_prior_per_0.9 < 0, "outside", "inside") %>% factor(),
             true_prior_coverage = if_else(true_prior_per_0.1 > 0 | true_prior_per_0.9 < 0, "outside", "inside") %>% factor()) %>%
      group_by_at(vars(estimand_name, any_of("cutpoint"))) %>%
      summarize_at(vars(ends_with("coverage")), ~ mean(fct_match(., "inside"))) %>%
      ungroup() %>%
      arrange_at(vars(estimand_name, any_of("cutpoint"))) %>%
      print(n = 1000, width = 160)
  } else {
    test_run_data %>%
      select(iter_id, results) %>%
      unnest(results) %>%
      group_by_at(vars(estimand_name, any_of("cutpoint"))) %>%
      summarize(coverage = mean(fct_match(coverage, "inside"))) %>%
      ungroup() %>%
      arrange_at(vars(estimand_name, any_of("cutpoint"))) %>%
      print(n = 1000)
  }

  get_prior_and_post <- function(estimand) {
    bind_rows(
      prior = if (script_options$different_priors) {
        prior_results %>%
          filter(estimand_name %in% estimand) %>%
          map_df(seq(num_runs), ~ mutate(..2, iter_id = ..1), .)
      },

      posterior = test_run_data %>%
        select(iter_id, results) %>%
        unnest(results) %>%
        # filter(estimand_name == default_unobs_cf) %>%
        filter(estimand_name %in% estimand) %>%
        mutate(iter_id = as.integer(iter_id)),

      "true prior" = true_prior_results %>%
        # filter(estimand_name == default_unobs_cf) %>%
        filter(estimand_name %in% estimand) %>%
        map_df(seq(num_runs), ~ mutate(..2, iter_id = ..1), .),

      .id = "fit_type"
    )
  }

  if (script_options$density_plots) {
    density_plots <- get_prior_and_post(tot_cf_diff) %>%
      # filter(iter_id %in% sample(.$iter_id, 50, replace = FALSE)) %>%
      select(iter_id, prob, fit_type, iter_data) %>%
      mutate(iter_data = map_if(iter_data, ~ !is_null(.x), select, -iter_id)) %>%
      unnest(iter_data) %>%
      ggplot() +
      geom_density(aes(iter_estimand, group = fit_type, color = fit_type)) +
      geom_vline(xintercept = 0, linetype = "dotted") +
      geom_vline(aes(xintercept = prob), data = . %>% distinct(iter_id, prob)) +
      scale_fill_discrete("", aesthetics = c("color", "fill")) +
      labs(x = "", y = "") +
           # subtitle = latex2exp::TeX("P(Y_{b=0,g=0,z=0,m=0} < c | M_{b=0,g=0,z=0} = 1)")) +
      facet_wrap(vars(iter_id), scales = "free") +
      theme_minimal() +
      theme(legend.position = "top", strip.text = element_blank(), plot.subtitle = element_text(size = 9))

    ggsave(file.path("temp-img", str_c(script_options$output, ".png")), density_plots)

    if (!interactive() && require(tcltk)) {
      x11()
      plot(density_plots)
      capture <- tk_messageBox(message = "Hit spacebar to close plots.")
    } else {
      plot(density_plots)
    }

    get_prior_and_post(default_cf_diff) %>%
      filter(fct_match(fit_type, "posterior")) %>%
      mutate(
        iter_id = seq(n()),
        iter_id = fct_reorder(factor(iter_id), prob)
      ) %>% {
        cowplot::plot_grid(
          ggplot(., aes(iter_id)) +
            geom_hline(yintercept = 0, linetype = "dotted") +
            geom_point(aes(y = prob), shape = 1, size = 2) +
            geom_pointrange(aes(y = per_0.5, ymin = per_0.1, ymax = per_0.9, color = coverage), fatten = 1) +
            scale_color_manual("", values = c("inside" = "black", "outside" = "red")) +
            labs(x = "", y = "", subtitle = "Uncentered") +
            theme_minimal() +
            theme(axis.text.x = element_blank(),
                  axis.ticks.x = element_blank(),
                  panel.grid.minor.x = element_blank(),
                  panel.grid.major.x = element_blank(),
                  legend.position = "none"),

          mutate_at(., vars(starts_with("per_")), ~ . - prob) %>%
            ggplot(aes(iter_id)) +
            geom_hline(yintercept = 0, linetype = "dotted") +
            geom_pointrange(aes(y = per_0.5, ymin = per_0.1, ymax = per_0.9, color = coverage), fatten = 1) +
            scale_color_manual("", values = c("inside" = "black", "outside" = "red")) +
            labs(x = "", y = "", subtitle = "Centered") +
            theme_minimal() +
            theme(axis.text.x = element_blank(),
                  axis.ticks.x = element_blank(),
                  panel.grid.minor.x = element_blank(),
                  panel.grid.major.x = element_blank(),
                  legend.position = "none"),

          ncol = 1
        )
      }

    get_prior_and_post(c(default_obs_cf, default_cf_diff)) %>%
      filter(fct_match(fit_type, "posterior")) %>%
      group_by(estimand_name) %>%
      mutate(
        iter_id = seq(n()),
        iter_id = fct_reorder(factor(iter_id), prob)
      ) %>%
      ungroup() %>%
      select(iter_id, estimand_name, per_0.1, per_0.5, per_0.8, prob, coverage) %>%
      mutate(estimand_name = if_else(str_detect(estimand_name, "-"), "diff", "obs")) %>%
      pivot_wider(names_from = "estimand_name", values_from = c("per_0.1", "per_0.5", "per_0.8", "prob", "coverage")) %>% {
        cowplot::plot_grid(
          ggplot(.) +
            geom_point(aes(prob_obs, per_0.5_diff, color = coverage_diff)) +
            geom_point(aes(prob_obs, prob_diff), shape = 3) +
            scale_color_manual("", values = c("inside" = "black", "outside" = "red")) +
            # labs(x = latex2exp::TeX("P\\[Y_{b=0,g=0,z=0,m=1} < c | M_{b=0,g=0,z=0} = 1\\]"),
            labs(x = "",
                 y = "TOT") +
                 # y = latex2exp::TeX("P\\[Y_{b=0,g=0,z=0,m=1} < c | M_{b=0,g=0,z=0} = 1\\] - P\\[Y_{b=0,g=0,z=0,m=0} < c | M_{b=0,g=0,z=0} = 1\\]")) +
            theme_minimal() +
            theme(legend.position = "none",
                  axis.text.x = element_blank()) +
            NULL,

        ggplot(.) +
          geom_pointrange(aes(prob_obs, ymin = per_0.1_diff, y = per_0.5_diff, ymax = per_0.8_diff,  color = coverage_diff), fatten = 1) +
          geom_line(aes(prob_obs, prob_obs), linetype = "dashed") +
          geom_line(aes(prob_obs, prob_obs - 1), linetype = "dashed") +
          geom_point(aes(prob_obs, prob_diff), shape = 3) +
          scale_color_manual("", values = c("inside" = "black", "outside" = "red")) +
          labs(x = latex2exp::TeX("P\\[Y_{b=0,g=0,z=0,m=1} < c | M_{b=0,g=0,z=0} = 1\\]"),
               y = "TOT") +
               # y = latex2exp::TeX("P\\[Y_{b=0,g=0,z=0,m=1} < c | M_{b=0,g=0,z=0} = 1\\] - P\\[Y_{b=0,g=0,z=0,m=0} < c | M_{b=0,g=0,z=0} = 1\\]")) +
          theme_minimal() +
          theme(legend.position = "none") +
          NULL,

        ncol = 1,
        rel_heights = c(1, 1.1))
      }

    prior_results %>%
      filter(estimand_name %in% c(default_obs_cf, default_unobs_cf)) %>%
      select(estimand_name, iter_data) %>%
      mutate(estimand_name = if_else(str_detect(estimand_name, "m=1"), "obs", "unobs")) %>%
      unnest(iter_data) %>%
      pivot_wider(names_from = "estimand_name", values_from = "iter_estimand") %>%
      ggplot(aes(obs, unobs)) +
      # geom_density2d()
      stat_density_2d(aes(fill = after_stat(nlevel)), geom = "polygon") +
      labs(x = "Observable", y = "Unobservable") +
      scale_fill_viridis_c() +
      theme_minimal() +
      theme(legend.position = "none")

    test_run_data %>%
      select(iter_id, results, lp_bounds) %>%
      sample_n(150) %>%
      mutate(
        results = map(results, filter, estimand_name == default_unobs_cf) %>%
          map(select, per_0.1, per_0.8, prob),
      ) %>%
      unnest(c(results, lp_bounds)) %>%
      mutate(iter_id = fct_reorder(iter_id, prob)) %>%
      ggplot(aes(iter_id)) +
      geom_linerange(aes(ymin = per_0.1, ymax = per_0.8), color = "red", size = 2, alpha = 0.5) +
      geom_errorbar(aes(ymin = min, ymax = max)) +
      geom_point(aes(y = prob), size = 1) +
      labs(
        x = "", y = ""
      ) +
      theme_minimal() +
      theme(legend.position = "top",
            axis.text.x = element_blank(),
            plot.subtitle = element_text(size = 9))

    marginal_prob_density_plots <- bind_rows(
      prior = if (script_options$different_priors) {
        prior_marginal_prob %>%
          map_df(seq(1), ~ mutate(.y, iter_id = .x), .)
      },

      # posterior = test_run_data %>%
      #   slice(1) %>%
      #   select(iter_id, marginal_prob) %>%
      #   unnest(marginal_prob) %>%
      #   filter(discretized) %>%
      #   mutate(iter_id = as.integer(iter_id)),

      "true prior" = true_prior_marginal_prob %>%
        map_df(seq(1), ~ mutate(.y, iter_id = .x), .),

    .id = "fit_type"
    ) %>%
      filter(
        # iter_id %in% sample(.$iter_id, 6, replace = FALSE),
        fct_match(type_variable, "r_y_1")
      ) %>%
      mutate(estimand_name = str_c(type_variable, " = ", type) %>% str_replace("^r", "R")) %>%
      # select(estimand_name, iter_id, marginal_prob, fit_type, iter_data, r_m) %>%
      select(estimand_name, iter_id, fit_type, iter_data, r_m) %>%
      mutate(iter_data = map_if(iter_data, ~ !is_null(.x), select, -iter_id)) %>%
      unnest(iter_data) %>%
      ggplot() +
      geom_density(aes(iter_p_r, group = fit_type, color = fit_type, fill = fit_type), alpha = 0.25) +
      # geom_vline(aes(xintercept = marginal_prob), data = . %>% distinct(estimand_name, iter_id, marginal_prob)) +
      scale_fill_discrete("", aesthetics = c("color", "fill")) +
      labs(
        x = "", y = ""
        # subtitle = latex2exp::TeX("P(Y_{b=0,g=0,z=0,m=0} < c | M_{b=0,g=0,z=0} = 1)")
      ) +
      # facet_grid(rows = vars(estimand_name), cols = vars(iter_id), scales = "free", switch = "y") +
      facet_grid(rows = vars(estimand_name), cols = vars(r_m), scales = "free", switch = "y") +
      # coord_cartesian(ylim = c(0, 8)) +
      theme_minimal() +
      theme(legend.position = "top",
            # strip.text.x = element_blank(),
            strip.text.y.left = element_text(angle = 0),
            axis.text.y = element_blank(),
            plot.subtitle = element_text(size = 9))

    ggsave(file.path("temp-img", str_c(script_options$output, "_marginal_type_prob.png")), marginal_prob_density_plots)

    if (!interactive() && require(tcltk)) {
      x11()
      plot(marginal_prob_density_plots)
      capture <- tk_messageBox(message = "Hit spacebar to close plots.")
    } else {
      plot(marginal_prob_density_plots)
    }
  }
}


# Prior Sequence ----------------------------------------------------------

if (script_options$prior_sequence) {
  # Prior sequence: for each tau in seq(<from-tau>, <to-tau>, <by-tau>), fit the
  # model in "prior-predict" mode with `discretized_beta_hyper_sd = tau` and plot
  # how the prior of `default_unobs_cf` and of the marginal `r_y_1` latent type
  # probabilities change with tau.

  # docopt returns positional arguments as character strings, which seq() cannot
  # use as `by` (and `%/%` cannot use at all); coerce them here. This is a no-op
  # if the Setup section already converted them.
  taus <- seq(
    as.numeric(script_options$from_tau),
    as.numeric(script_options$to_tau),
    as.numeric(script_options$by_tau)
  )

  # Data set simulated under the "true" hyperparameters; `y` is recoded from
  # `y_1` as +30 / -30.
  dummy_data <- create_prior_predicted_simulation(
    default_model,
    sample_size = 4000, chains = 4, iter = 1000,
    discrete_beta_hyper_sd = script_options$true_hyper_sd,
    discretized_beta_hyper_sd = true_discretized_beta_hyper_sd,
    tau_level_sigma = 1,
    num_entities = script_options$num_entities
  ) %>%
    unnest(entity_data) %>%
    select(entity_index, sim) %>%
    deframe() %>%
    map_dfr(create_simulation_analysis_data, .id = "entity_index") %>%
    mutate(y = if_else(y_1 == 0, 30, -30))

  # Each fit samples 4 chains, hence `cores %/% 4` for test_parallel_map();
  # clamp to at least 1 so that fewer than 4 cores does not yield 0 workers.
  all_prior_results <- taus %>%
    test_parallel_map(
      cores = max(1L, as.integer(script_options$cores) %/% 4L),
      function(curr_tau) {
        test_sampler <- create_sampler(
          default_model,
          model_levels = "entity_index",
          analysis_data = dummy_data,
          estimands = default_estimands,
          # y = y < -20,
          y = y,

          # `different_priors` is not part of the prior-sequence usage line, so
          # with docopt it is FALSE here and `true_hyper_sd` is used.
          discrete_beta_hyper_sd = if (script_options$different_priors) script_options$alt_hyper_sd else script_options$true_hyper_sd,
          discretized_beta_hyper_sd = curr_tau,

          tau_level_sigma = 1,
          calculate_marginal_prob = TRUE
        )

        test_prior_fit <- test_sampler %>%
          sampling(
            chains = 4,
            warmup = 500,
            iter = 2500,
            pars = c("iter_estimand", "marginal_p_r"),
            run_type = "prior-predict"
          )

        # One row per tau: estimand summaries and marginal latent-type
        # probabilities, each stored as a list column.
        tibble(
          results = test_prior_fit %>%
            get_estimation_results(no_sim_diag = FALSE, quants = seq(0, 1, 0.1)) %>%
            list(),

          marginal_prob = test_prior_fit %>%
            get_marginal_latent_type_prob() %>%
            list()
        )
      }
    ) %>%
    set_names(taus) %>%
    bind_rows(.id = "tau") %>%
    mutate(tau = as.numeric(tau))

  # Prior density of the unobservable counterfactual, one curve per tau.
  estimand_prior_density_plot <- all_prior_results %>%
    select(tau, results) %>%
    unnest(results) %>%
    # filter(estimand_name == "Pr[Y^{y}_{b=0,g=0,z=0,m=0} < c | M = 1, B = 0, G = 0, Z = 0]") %>%
    filter(estimand_name == default_unobs_cf) %>%
    unnest(iter_data) %>%
    ggplot() +
    geom_density(aes(iter_estimand, group = tau, color = tau)) +
    scale_color_viridis_c(expression(tau[beta])) +
    labs(x = "", y = "",
         # subtitle = latex2exp::TeX("P(Y_{b=0,g=0,z=0,m=0} < c | M_{b=0,g=0,z=0} = 1)")) +
         subtitle = latex2exp::TeX("P(Y_{z=0,m=0} < c | M_{z=0} = 1)")) +
    theme_minimal() +
    theme(legend.position = "right", plot.subtitle = element_text(size = 9))

  # Prior density of each marginal `r_y_1` latent type probability, per tau.
  marginal_prob_prior_density_plot <- all_prior_results %>%
    select(tau, marginal_prob) %>%
    unnest(marginal_prob) %>%
    filter(fct_match(type_variable, "r_y_1")) %>%
    mutate(estimand_name = str_c(type_variable, " = ", type) %>% str_replace("^r", "R")) %>%
    unnest(iter_data) %>%
    ggplot() +
    geom_density(aes(iter_p_r, group = tau, color = tau)) +
    scale_color_viridis_c(expression(tau[beta])) +
    labs(
      x = "", y = ""
    ) +
    facet_wrap(vars(estimand_name), scales = "free") +
    coord_cartesian(ylim = c(0, 8)) +
    theme_minimal() +
    theme(legend.position = "top",
          # strip.text.x = element_blank(),
          # strip.text.y.left = element_text(angle = 0),
          axis.text.y = element_blank(),
          plot.subtitle = element_text(size = 9))

  # Only the final value of a top-level `if` block is auto-printed, so the
  # first plot would otherwise be built and silently discarded.
  print(estimand_prior_density_plot)
  print(marginal_prob_prior_density_plot)
}

# Constrained Prior Sequence ----------------------------------------------------------

if (script_options$constrained_prior_sequence) {
  # Constrained prior sequence: keep `discrete_beta_hyper_sd` and the
  # "always below" / "never below" discretized hyper-SDs fixed at
  # <unconst-tau>, vary only the `default` discretized hyper-SD over
  # seq(<from-tau>, <to-tau>, <by-tau>), and plot the resulting priors.

  # docopt returns positional arguments as character strings; coerce them so
  # seq() works and the sampler receives numbers (no-op if already numeric).
  unconst_tau <- as.numeric(script_options$unconst_tau)
  taus <- seq(
    as.numeric(script_options$from_tau),
    as.numeric(script_options$to_tau),
    as.numeric(script_options$by_tau)
  )

  # Data set simulated under the "true" hyperparameters; `y` is recoded from
  # `y_1` as +30 / -30.
  dummy_data <- create_prior_predicted_simulation(
    default_model,
    sample_size = 4000, chains = 4, iter = 1000,
    discrete_beta_hyper_sd = script_options$true_hyper_sd,
    discretized_beta_hyper_sd = true_discretized_beta_hyper_sd,
    tau_level_sigma = 1,
    num_entities = script_options$num_entities
  ) %>%
    unnest(entity_data) %>%
    select(entity_index, sim) %>%
    deframe() %>%
    map_dfr(create_simulation_analysis_data, .id = "entity_index") %>%
    mutate(y = if_else(y_1 == 0, 30, -30))

  # NOTE: fits currently run sequentially via map(), so <cores> is unused here.
  # The parallel variant is kept (commented out) for reference.
  all_prior_results <- taus %>%
    # test_parallel_map(cores = script_options$cores %/% 4,
    map(
      function(curr_tau) {
        test_sampler <- create_sampler(
          default_model,
          model_levels = "entity_index",
          analysis_data = dummy_data,
          estimands = default_estimands,
          # y = y < -20,
          y = y,

          discrete_beta_hyper_sd = unconst_tau,
          discretized_beta_hyper_sd = list(
            default = curr_tau,
            "always below" = unconst_tau,
            "never below" = unconst_tau
          ),

          tau_level_sigma = 1,
          calculate_marginal_prob = TRUE
        )

        test_prior_fit <- test_sampler %>%
          sampling(
            chains = 4,
            warmup = 500,
            iter = 2500,
            pars = c("iter_estimand", "marginal_p_r"),
            run_type = "prior-predict"
          )

        # One row per tau: estimand summaries and marginal latent-type
        # probabilities, each stored as a list column.
        tibble(
          results = test_prior_fit %>%
            get_estimation_results(no_sim_diag = FALSE, quants = seq(0, 1, 0.1)) %>%
            list(),

          marginal_prob = test_prior_fit %>%
            get_marginal_latent_type_prob() %>%
            list()
        )
      }
    ) %>%
    set_names(taus) %>%
    bind_rows(.id = "tau") %>%
    mutate(tau = as.numeric(tau))

  # Prior density of the unobservable counterfactual, one curve per tau.
  estimand_prior_density_plot <- all_prior_results %>%
    select(tau, results) %>%
    unnest(results) %>%
    # filter(estimand_name == "Pr[Y^{y}_{b=0,g=0,z=0,m=0} < c | M = 1, B = 0, G = 0, Z = 0]") %>%
    filter(estimand_name == default_unobs_cf) %>%
    unnest(iter_data) %>%
    ggplot() +
    geom_density(aes(iter_estimand, group = tau, color = tau)) +
    scale_color_viridis_c(expression(tau[beta])) +
    labs(x = "", y = "",
         # subtitle = latex2exp::TeX("P(Y_{b=0,g=0,z=0,m=0} < c | M_{b=0,g=0,z=0} = 1)")) +
         subtitle = latex2exp::TeX("P(Y_{z=0,m=0} < c | M_{z=0} = 1)")) +
    theme_minimal() +
    theme(legend.position = "right", plot.subtitle = element_text(size = 9))

  # Prior density of each marginal `r_y_1` latent type probability, per tau.
  marginal_prob_prior_density_plot <- all_prior_results %>%
    select(tau, marginal_prob) %>%
    unnest(marginal_prob) %>%
    filter(fct_match(type_variable, "r_y_1")) %>%
    mutate(estimand_name = str_c(type_variable, " = ", type) %>% str_replace("^r", "R")) %>%
    unnest(iter_data) %>%
    ggplot() +
    geom_density(aes(iter_p_r, group = tau, color = tau)) +
    scale_color_viridis_c(expression(tau[beta])) +
    labs(
      x = "", y = ""
    ) +
    facet_wrap(vars(estimand_name), scales = "free") +
    coord_cartesian(ylim = c(0, 8)) +
    theme_minimal() +
    theme(legend.position = "top",
          # strip.text.x = element_blank(),
          # strip.text.y.left = element_text(angle = 0),
          axis.text.y = element_blank(),
          plot.subtitle = element_text(size = 9))

  # Only the final value of a top-level `if` block is auto-printed, so the
  # first plot would otherwise be built and silently discarded.
  print(estimand_prior_density_plot)
  print(marginal_prob_prior_density_plot)
}

# Diagnostics -------------------------------------------------------------

# color_scheme_set("darkgray")
# test_posterior <- as.array(test_fit)
# test_np <- nuts_params(test_fit)
# mcmc_pairs(test_posterior, np = test_np, regex_pars =  c("level_beta_sigma", "obs_beta\\[[2-4]"))
# karimn/boundr documentation built on March 1, 2021, 6:57 p.m.