library(tidyverse)
library(foreach)
library(doMC)
library(scales)
library(egg)
library(MendelianRandomization)
library(cause)
library(MRMix)
library(MRPRESSO)
library(rslurm)
library(WWER)
# Some helper functions to assist in simulating and generating plots.

# Run `n_iter` replicates of one simulation setting in parallel and stack the
# per-replicate results into a single data frame.
#
# Note that this closes over a globally defined n_cores variable, which is
# used to register the doMC backend. run_sim() is not defined in this part of
# the file and must be available wherever this function is executed.
#
# Args:
#   name: Label of the simulation setting, forwarded to run_sim().
#   n_iter: Number of replicates to run; iterations are numbered 1..n_iter.
#   Na, Nb, M, ha, hb, hu, nu, eta, gamma, delta, p, q, r, s: Simulation
#     parameters, forwarded unchanged to run_sim().
#   af_ref_file: Optional path forwarded to run_sim() (default NULL).
#
# Returns:
#   The bind_rows() combination of every run_sim() result.
run_sim_niter <- function(name, n_iter, Na, Nb, M, ha, hb, hu, nu, eta, gamma, delta, p, q, r, s, af_ref_file = NULL) {
  doMC::registerDoMC(cores = n_cores)
  # Log the setting so it shows up in the job output.
  print(c(n_cores, n_iter, Na, Nb, M, ha, hb, hu, nu, eta, gamma, delta, p, q, r, s))
  # seq_len() instead of 1:n_iter: with n_iter = 0 the latter would iterate
  # over c(1, 0) and run two spurious replicates.
  res_tibble <- foreach(iter = seq_len(n_iter), .combine = bind_rows) %dopar% {
    run_sim(name, iter, Na, Nb, M, ha, hb, hu, nu, eta, gamma, delta, p, q, r, s, af_ref_file)
  }
  res_tibble
}


# Standard error of the mean of `x`.
#
# Args:
#   x: Numeric or logical vector.
#   na.rm: If TRUE, missing values are dropped from both the variance and the
#     sample size.
#
# Returns:
#   sqrt(var(x) / n), where n is the number of observations actually used.
se <- function(x, na.rm = FALSE) {
  # var() drops NAs when na.rm = TRUE, so the divisor has to count the same
  # observations; length(x) would include the NAs and shrink the estimate.
  n_obs <- if (na.rm) sum(!is.na(x)) else length(x)
  sqrt(var(x, na.rm = na.rm) / n_obs)
}


# Collapse per-iteration simulation results into one row per simulation
# setting and method.
#
# Args:
#   plot_tibble: Data frame with one row per iteration and method. It must
#     contain the grouping columns used below plus pab, pba, Rab_hat, Rba_hat,
#     Sab_hat, Sba_hat, runtime and rg.
#
# Returns:
#   A tibble with, per group: the fraction of p-values below 0.05 in each
#   direction (stat_ab, stat_ba) with standard errors, the mean absolute error
#   of the effect estimates (mae_ab, mae_ba) with standard errors, and the
#   group means of Rab_hat, Rba_hat, Sab_hat, Sba_hat, runtime and rg. The
#   method "egger_wf_5e-8" is relabelled "aaa_egger_wf_5e-8" (presumably so it
#   sorts first in plot legends -- confirm).
make_summary_tibble <- function(plot_tibble) {
  summary_tibble <- plot_tibble %>%
    dplyr::group_by(name, method, mr_method, p_thresh, Na, Nb, M, ha, hb, hu, nu, eta, gamma, delta, q, r, s, Rab, Rba) %>%
    mutate(
      stat_ab = mean(pab < 0.05, na.rm = TRUE), stat_ba = mean(pba < 0.05, na.rm = TRUE),
      se_stat_ab = se(pab < 0.05, na.rm = TRUE), se_stat_ba = se(pba < 0.05, na.rm = TRUE),
      # The MAE columns must be computed before Rab_hat / Rba_hat are replaced
      # by their group means further down in this same mutate().
      mae_ab = mean(abs(Rab - Rab_hat), na.rm = TRUE), mae_ba = mean(abs(Rba - Rba_hat), na.rm = TRUE),
      se_mae_ab = se(abs(Rab - Rab_hat), na.rm = TRUE), se_mae_ba = se(abs(Rba - Rba_hat), na.rm = TRUE),
      Rab_hat = mean(Rab_hat, na.rm = TRUE), Rba_hat = mean(Rba_hat, na.rm = TRUE),
      Sab_hat = mean(Sab_hat, na.rm = TRUE), Sba_hat = mean(Sba_hat, na.rm = TRUE),
      runtime = mean(runtime, na.rm = TRUE), rg = mean(rg)) %>%
    ungroup() %>%
    # Every row of a group now carries identical summary values, so distinct()
    # keeps exactly one row per group.
    distinct(name, method, mr_method, p_thresh, Na, Nb, M, ha, hb, hu, rg, nu, eta, gamma, delta, q, r, s, Rab, Rba,
             stat_ab, stat_ba, se_stat_ab, se_stat_ba, mae_ab, mae_ba, se_mae_ab, se_mae_ba,
             Rab_hat, Rba_hat, Sab_hat, Sba_hat, runtime) %>%
    mutate(method = if_else(method == "egger_wf_5e-8", "aaa_egger_wf_5e-8", method))
  summary_tibble
}


# Axis transformation that plots values on a reversed logarithmic scale,
# i.e. x is mapped to -log(x, base).
#
# Args:
#   base: Base of the logarithm (default e).
#
# Returns:
#   The transformation object built by scales' trans_new(), named
#   "reverselog-<base>", with log-spaced breaks and domain [1e-100, Inf).
reverselog_trans <- function(base = exp(1)) {
  trans <- function(x) -log(x, base)
  inv <- function(x) base^(-x)
  # All arguments are named: passing the breaks positionally would land them
  # in a different parameter on scales versions that insert d_transform /
  # d_inverse after `inverse`.
  trans_new(
    name = paste0("reverselog-", format(base)),
    transform = trans,
    inverse = inv,
    breaks = log_breaks(base = base),
    domain = c(1e-100, Inf))
}
# Setup simulations df for rslurm.
n_iter <- 100

af_ref_file <- "/gpfs/commons/projects/UKBB/sumstats/ldsc_results/eur_w_ld_chr/19.l2.ldscore.gz"

# Simulations matching LCV main.
M_lcv <- 500000

h <- 0.3
# The vectors below are parallel: element i of each one defines grid point i,
# so every sim_lcv_* tibble has five rows (scalars are recycled by tibble()).
eta_lcv <- c(0, sqrt(0.1), sqrt(0.2), sqrt(0.4), sqrt(0.6))
nu_lcv <- c(0, sqrt(0.1), sqrt(0.2), sqrt(0.4), sqrt(0.6))
p_lcv <- c(5000, 4750, 4500, 4000, 3500) / M_lcv
q_lcv <- (nu_lcv**2) / (2 - nu_lcv**2) # c(0, 0.053, 0.111, 0.25, 0.429)
r_lcv <- (1 - q_lcv) / 2

sim_lcv_1 <- tibble(
  name = "Equal power, equal architecture.",
  n_iter = n_iter, Na = 100000, Nb = 100000, M = M_lcv,
  ha = h, hb = h, hu = h,
  nu = nu_lcv, eta = eta_lcv, gamma = 0.0, delta = 0.0,
  p = p_lcv, q = q_lcv, r = r_lcv, s = r_lcv, af_ref_file = af_ref_file)

sim_lcv_2 <- tibble(
  name = "Study 1 higher power, equal architecture.",
  n_iter = n_iter, Na = 100000, Nb = 20000, M = M_lcv,
  ha = h, hb = h, hu = h,
  nu = nu_lcv, eta = eta_lcv, gamma = 0.0, delta = 0.0,
  p = p_lcv, q = q_lcv, r = r_lcv, s = r_lcv, af_ref_file = af_ref_file)


# From here on r_lcv is deliberately redefined: the remaining variants split
# the non-shared fraction 1/3 : 2/3 between r and s instead of evenly.
# NOTE(review): with s = 2 * r, Mb = M*p*s is twice Ma = M*p*r, which looks
# like trait B getting MORE variants than trait A, while the names below say
# "trait 2 less polygenic" -- confirm the intended direction.
r_lcv <- (1 - q_lcv) / 3
s_lcv <- 2 * (1 - q_lcv) / 3
sim_lcv_3 <- tibble(
  name = "Equal power, trait 2 less polygenic.",
  n_iter = n_iter, Na = 100000, Nb = 100000, M = M_lcv,
  ha = h, hb = h, hu = h,
  nu = nu_lcv, eta = eta_lcv, gamma = 0.0, delta = 0.0,
  p = p_lcv, q = q_lcv, r = r_lcv, s = s_lcv, af_ref_file = af_ref_file)

sim_lcv_4 <- tibble(
  name = "Study 1 higher power, trait 2 less polygenic.",
  n_iter = n_iter, Na = 100000, Nb = 20000, M = M_lcv,
  ha = h, hb = h, hu = h,
  nu = nu_lcv, eta = eta_lcv, gamma = 0.0, delta = 0.0,
  p = p_lcv, q = q_lcv, r = r_lcv, s = s_lcv, af_ref_file = af_ref_file)

sim_settings_lcv <- dplyr::bind_rows(sim_lcv_1, sim_lcv_2, sim_lcv_3, sim_lcv_4)

# Same settings, augmented row by row with the sa / sb / su values returned by
# WWER:::h2_to_s2() and with the variant counts implied by M, p, q, r and s.
# The result is ungrouped so that later verbs do not silently operate row-wise.
sim_settings_lcv_data <- sim_settings_lcv %>%
  dplyr::rowwise() %>%
  dplyr::mutate(vars = list(WWER:::h2_to_s2(M, ha, hb, hu, nu, eta, gamma, delta, p, q, r, s))) %>%
  dplyr::mutate(sa = vars$sa, sb = vars$sb, su = vars$su, Ma = M * p * r, Mb = M * p * s, Mu = M * p * q) %>%
  dplyr::select(-vars) %>%
  dplyr::ungroup()
# Note: There does not appear to be a way to have rslurm request multiple CPUs
#   on a node (e.g. for doMC) without also distributing multiple jobs over them.
#   Thus this sets submit = FALSE so I can manually edit the CPUs per node
#   requested in the submission script.
# TODO(brielin): Right now the bimmer functions need to be loaded with
#  devtools::load_all(), then passed via add_objects. Figure out how to do it
#  via just loading the package.
n_cores <- 8 # Need to edit this into the submit.sh script.
sjob_lcv <- rslurm::slurm_apply(
  run_sim_niter, sim_settings_lcv, jobname = "bimr_lcv", nodes = nrow(sim_settings_lcv),
  cpus_per_node = 1, slurm_options = list(time = "36:00:00", mem = "32G"),
  add_objects = c("n_cores"),
  submit = FALSE)
# Simulations matching CAUSE main.
af_ref_file <- "/gpfs/commons/projects/UKBB/sumstats/ldsc_results/eur_w_ld_chr/19.l2.ldscore.gz"
n_iter <- 100
h <- 0.25
# q_cause has five grid points; every vector derived from it below is parallel
# to it, so each sim_cause_* tibble has five rows.
q_cause <- c(0, 0.05, 0.111, 0.222, 0.333)  # shared/all = (0.10, 0.19, 0.36, 0.5)
M_cause <- 500000
M_eff_cause <- ceiling((2 / (q_cause + 1)) * 1000)
p_cause <- M_eff_cause / M_cause
r_cause <- (1 - q_cause) / 2
nu_cause <- sqrt(h)  # Using nu = sqrt(h) means hu = 0.5 when q = 1/3
hu_cause <- (M_cause * q_cause / (M_cause * q_cause + M_cause * r_cause)) * (h / nu_cause**2)

# 50k = 100-110 sig
# 20k = 15-30 sig
# 15k = 6-15
sim_cause_1 <- tibble(
  name = "Equal power, weak shared effect",
  n_iter = n_iter, Na = 50000, Nb = 50000, M = M_cause,
  ha = h, hb = h, hu = hu_cause,
  nu = nu_cause, eta = sqrt(0.02) * nu_cause, gamma = 0.0, delta = 0.0,
  p = p_cause, q = q_cause, r = r_cause, s = r_cause, af_ref_file = af_ref_file)

sim_cause_2 <- tibble(
  name = "Equal power, stronger shared effect",
  n_iter = n_iter, Na = 50000, Nb = 50000, M = M_cause,
  ha = h, hb = h, hu = hu_cause,
  nu = nu_cause, eta = sqrt(0.05) * nu_cause, gamma = 0.0, delta = 0.0,
  p = p_cause, q = q_cause, r = r_cause, s = r_cause, af_ref_file = af_ref_file)

sim_cause_3 <- tibble(
  name = "Study 2 higher power, stronger shared effect",
  n_iter = n_iter, Na = 20000, Nb = 50000, M = M_cause,
  ha = h, hb = h, hu = hu_cause,
  nu = nu_cause, eta = sqrt(0.05) * nu_cause, gamma = 0.0, delta = 0.0,
  p = p_cause, q = q_cause, r = r_cause, s = r_cause, af_ref_file = af_ref_file)

sim_cause_4 <- tibble(
  name = "Study 1 higher power, stronger shared effect",
  n_iter = n_iter, Na = 50000, Nb = 20000, M = M_cause,
  ha = h, hb = h, hu = hu_cause,
  nu = nu_cause, eta = sqrt(0.05) * nu_cause, gamma = 0.0, delta = 0.0,
  p = p_cause, q = q_cause, r = r_cause, s = r_cause, af_ref_file = af_ref_file)


sim_settings_cause <- dplyr::bind_rows(sim_cause_1, sim_cause_2, sim_cause_3, sim_cause_4)

# Same settings, augmented row by row with the sa / sb / su values returned by
# WWER:::h2_to_s2() and with the variant counts implied by M, p, q, r and s.
# The result is ungrouped so that later verbs do not silently operate row-wise.
sim_settings_cause_data <- sim_settings_cause %>%
  dplyr::rowwise() %>%
  dplyr::mutate(vars = list(WWER:::h2_to_s2(M, ha, hb, hu, nu, eta, gamma, delta, p, q, r, s))) %>%
  dplyr::mutate(sa = vars$sa, sb = vars$sb, su = vars$su, Ma = M * p * r, Mb = M * p * s, Mu = M * p * q) %>%
  dplyr::select(-vars) %>%
  dplyr::ungroup()
n_cores <- 8
# submit = FALSE: only the job files are written; the job is submitted by hand.
sjob_cause <- rslurm::slurm_apply(
  run_sim_niter, sim_settings_cause, jobname = "bimr_cause", nodes = nrow(sim_settings_cause),
  cpus_per_node = 1, slurm_options = list(time = "36:00:00", mem = "32G"),
  add_objects = c("n_cores"),
  submit = FALSE)
# New sims with different polyG separated from conf effect sizes.
# Three families are generated from the same six architectures: equal sample
# sizes, Nb smaller, and Na smaller.

# Constants shared by all three families. They are defined once, before any
# use, so that no family depends on a value left over from an earlier section.
h <- 0.3
af_ref_file <- "/gpfs/commons/projects/UKBB/sumstats/ldsc_results/eur_w_ld_chr/19.l2.ldscore.gz"
n_iter <- 100
# "asym" grids (3 points) are used when A and B differ in polygenicity,
# "sym" grids (2 points) when they are equal.
nu_asym <- c(sqrt(0.1), sqrt(0.3), sqrt(0.3))
eta_asym <- c(sqrt(0.3), sqrt(0.1), sqrt(0.3))
nu_sym <- c(sqrt(0.1), sqrt(0.3))
eta_sym <- c(sqrt(0.3), sqrt(0.3))
M_l <- 500     # variant count for a "low" polygenicity component
M_h <- 2000    # variant count for a "high" polygenicity component
M_new <- 500000
n_cores <- 8

# Build the settings table for one family of polygenicity simulations.
#
# Reads the constants defined just above (h, n_iter, M_new, M_l, M_h, the
# nu / eta grids and af_ref_file) from the enclosing environment.
#
# Args:
#   Na, Nb: Sample sizes of the two studies.
#   size_label: Text appended to every setting name, e.g. "Na = Nb".
#
# Returns:
#   A tibble with one row per (architecture, nu/eta grid point), in the order
#   lhh, lhl, lll, hhh, hhl, hll.
make_polyg_settings <- function(Na, Nb, size_label) {
  # One architecture: M_a / M_b / M_u are the variant counts assigned to A, B
  # and U; p is their total as a fraction of M_new, and r / s / q are the
  # shares of that total belonging to A / B / U.
  one_architecture <- function(label, M_a, M_b, M_u, nu, eta) {
    M_tot <- M_a + M_b + M_u
    tibble(
      name = paste0("PolyG: ", label, "; ", size_label),
      n_iter = n_iter, Na = Na, Nb = Nb, M = M_new,
      ha = h, hb = h, hu = h, nu = nu, eta = eta, gamma = 0, delta = 0,
      p = M_tot / M_new, q = M_u / M_tot, r = M_a / M_tot, s = M_b / M_tot,
      af_ref_file = af_ref_file
    )
  }
  dplyr::bind_rows(
    one_architecture("high A, high B, low U", M_h, M_h, M_l, nu_sym, eta_sym),
    one_architecture("high A, low B, low U", M_h, M_l, M_l, nu_asym, eta_asym),
    one_architecture("low A, low B, low U", M_l, M_l, M_l, nu_sym, eta_sym),
    one_architecture("high A, high B, high U", M_h, M_h, M_h, nu_sym, eta_sym),
    one_architecture("high A, low B, high U", M_h, M_l, M_h, nu_asym, eta_asym),
    one_architecture("low A, low B, high U", M_l, M_l, M_h, nu_sym, eta_sym)
  )
}

# Unlike the LCV and CAUSE settings, these tables are not augmented with
# h2_to_s2() output; that step was commented out in the original script:
#  dplyr::rowwise() %>% dplyr::mutate(vars=list(h2_to_s2(M, ha, hb, hu, nu, eta, gamma, delta, p, q, r, s))) %>%
#  dplyr::mutate(sa = vars$sa, sb= vars$sb, su = vars$su, Ma = M*p*r, Mb = M*p*s, Mu = M*p*q) %>% dplyr::select(-vars)

# Equal sample sizes.
sim_settings_polyg_equal <- make_polyg_settings(Na = 100000, Nb = 100000, size_label = "Na = Nb")
# slurm_apply() is called at top level (not from a helper) and with
# submit = FALSE, so only the job files are written and the job is submitted
# by hand.
sjob_polyg_equal <- rslurm::slurm_apply(
  run_sim_niter, sim_settings_polyg_equal, jobname = "bimr_polyg_equal", nodes = nrow(sim_settings_polyg_equal),
  cpus_per_node = 1, slurm_options = list(time = "36:00:00", mem = "40G"),
  add_objects = c("n_cores"),
  submit = FALSE)

# Nb smaller.
sim_settings_polyg_Nb_small <- make_polyg_settings(Na = 100000, Nb = 25000, size_label = "Na > Nb")
sjob_polyg_Nb_small <- rslurm::slurm_apply(
  run_sim_niter, sim_settings_polyg_Nb_small, jobname = "bimr_polyg_Nb_small", nodes = nrow(sim_settings_polyg_Nb_small),
  cpus_per_node = 1, slurm_options = list(time = "36:00:00", mem = "40G"),
  add_objects = c("n_cores"),
  submit = FALSE)

# Na smaller.
sim_settings_polyg_Na_small <- make_polyg_settings(Na = 25000, Nb = 100000, size_label = "Na < Nb")
sjob_polyg_Na_small <- rslurm::slurm_apply(
  run_sim_niter, sim_settings_polyg_Na_small, jobname = "bimr_polyg_Na_small", nodes = nrow(sim_settings_polyg_Na_small),
  cpus_per_node = 1, slurm_options = list(time = "36:00:00", mem = "40G"),
  add_objects = c("n_cores"),
  submit = FALSE)
# Re-create the job handles by name, read back each job's output as a table,
# and summarise it. The three steps are grouped per simulation family.

# LCV simulations.
sjob_lcv <- rslurm::slurm_job(jobname = "bimr_lcv", nodes = nrow(sim_settings_lcv))
lcv_results <- rslurm::get_slurm_out(sjob_lcv, outtype = "table")
lcv_results_summary <- make_summary_tibble(lcv_results)

# CAUSE simulations.
sjob_cause <- rslurm::slurm_job(jobname = "bimr_cause", nodes = nrow(sim_settings_cause))
cause_results <- rslurm::get_slurm_out(sjob_cause, outtype = "table")
cause_results_summary <- make_summary_tibble(cause_results)

# Polygenicity simulations, equal sample sizes.
sjob_polyg_equal <- rslurm::slurm_job(jobname = "bimr_polyg_equal", nodes = nrow(sim_settings_polyg_equal))
polyg_equal_results <- rslurm::get_slurm_out(sjob_polyg_equal, outtype = "table")
polyg_equal_results_summary <- make_summary_tibble(polyg_equal_results)

# Polygenicity simulations, Na smaller.
sjob_polyg_Na_small <- rslurm::slurm_job(jobname = "bimr_polyg_Na_small", nodes = nrow(sim_settings_polyg_Na_small))
polyg_Na_small_results <- rslurm::get_slurm_out(sjob_polyg_Na_small, outtype = "table")
polyg_Na_small_results_summary <- make_summary_tibble(polyg_Na_small_results)

# Polygenicity simulations, Nb smaller.
sjob_polyg_Nb_small <- rslurm::slurm_job(jobname = "bimr_polyg_Nb_small", nodes = nrow(sim_settings_polyg_Nb_small))
polyg_Nb_small_results <- rslurm::get_slurm_out(sjob_polyg_Nb_small, outtype = "table")
polyg_Nb_small_results_summary <- make_summary_tibble(polyg_Nb_small_results)

# All polygenicity summaries stacked: equal, Na small, Nb small.
polyg_all_summary <- rbind(
  polyg_equal_results_summary,
  polyg_Na_small_results_summary,
  polyg_Nb_small_results_summary)

# Figure 2: false positive rate panels for the LCV, CAUSE and polygenicity simulations.

theme_set(theme_classic())

# Methods shown in Figure 2. "aaa_egger_wf_5e-8" is the name that
# make_summary_tibble() gives to "egger_wf_5e-8".
# Alternative selections kept for reference:
# show_methods = c("egger_5e-8", "ivw_5e-8", "mbe_5e-8", "aps_5e-8", "raps_5e-8", "mr_mix_5e-8", "median_5e-8")
# show_methods = c("egger_5e-8", "egger_wa_5e-8", "egger_wf_5e-8", "egger_s_5e-8")
show_methods <- c("egger_5e-8", "aaa_egger_wf_5e-8", "egger_s_5e-8", "cause_1e-3", "ivw_5e-8", "mr_mix_5e-8", "mr_presso_5e-8", "mbe_5e-8")
# Display labels, keyed by method name. Keying by name (instead of relying on
# the alphabetical order of the method values) keeps every label attached to
# the right method even if a method is absent from the data.
show_labels <- c(
  "aaa_egger_wf_5e-8" = "WWER", "cause_1e-3" = "CAUSE", "egger_5e-8" = "Egger",
  "egger_s_5e-8" = "Steiger", "ivw_5e-8" = "IVW", "mbe_5e-8" = "MBE",
  "mr_mix_5e-8" = "MR Mix", "mr_presso_5e-8" = "MR PRESSO")

# Panel A: LCV simulations, FPR against genetic correlation. Error bars are
# +/- 1.96 standard errors.
g1 <- lcv_results_summary %>%
  filter(method %in% show_methods) %>%
  ggplot() +
  geom_line(aes(x = rg, y = stat_ab, color = method)) +
  geom_errorbar(aes(ymin = stat_ab - 1.96 * se_stat_ab, ymax = stat_ab + 1.96 * se_stat_ab, x = rg), width = 0.025) +
  facet_wrap(vars(name), nrow = 1) +
  scale_color_brewer(labels = show_labels, palette = "Dark2") +
  labs(x = NULL, title = "A) Simulations under LCV model.", y = "FPR, A to B") +
  guides(color = "none")

g2 <- lcv_results_summary %>%
  filter(method %in% show_methods) %>%
  ggplot() +
  geom_line(aes(x = rg, y = stat_ba, color = method)) +
  geom_errorbar(aes(ymin = stat_ba - 1.96 * se_stat_ba, ymax = stat_ba + 1.96 * se_stat_ba, x = rg), width = 0.025) +
  facet_wrap(vars(name), nrow = 1) +
  scale_color_brewer(labels = show_labels, palette = "Dark2", name = "Method") +
  labs(x = "Genetic correlation", title = NULL, y = "FPR, B to A") +
  guides(color = "none")

# Panel B: CAUSE simulations, FPR against q. g3 is the only panel that keeps
# its colour legend.
g3 <- cause_results_summary %>%
  filter(method %in% show_methods) %>%
  ggplot() +
  geom_line(aes(x = q, y = stat_ab, color = method)) +
  geom_errorbar(aes(ymin = stat_ab - 1.96 * se_stat_ab, ymax = stat_ab + 1.96 * se_stat_ab, x = q), width = 0.015) +
  facet_wrap(vars(name), nrow = 1) +
  scale_color_brewer(labels = show_labels, palette = "Dark2") +
  labs(x = NULL, title = "B) Simulations under CAUSE model.", y = "FPR, A to B", color = "Method")

g4 <- cause_results_summary %>%
  filter(method %in% show_methods) %>%
  ggplot() +
  geom_line(aes(x = q, y = stat_ba, color = method)) +
  geom_errorbar(aes(ymin = stat_ba - 1.96 * se_stat_ba, ymax = stat_ba + 1.96 * se_stat_ba, x = q), width = 0.015) +
  facet_wrap(vars(name), nrow = 1) +
  scale_color_brewer(labels = show_labels, palette = "Dark2") +
  ylim(0, 1) +
  labs(x = "Proportion of shared variants", y = "FPR, B to A") +
  guides(color = "none")

# Panel C: polygenicity simulations with high A and high B, at
# nu = eta = sqrt(0.3) and Nb = 100000. near() is used instead of == because
# nu and eta are sqrt() results, which are not safe to compare exactly.
polyg_fig2_data <- polyg_all_summary %>%
  filter(grepl("PolyG: high A, high B", name),
         dplyr::near(nu, sqrt(0.3)), dplyr::near(eta, sqrt(0.3)), Nb == 100000) %>%
  filter(method %in% show_methods)

g5 <- polyg_fig2_data %>%
  ggplot() + facet_wrap(vars(name), nrow = 1) +
  geom_col(aes(x = method, y = stat_ab, fill = method), position = "dodge") +
  geom_errorbar(aes(ymin = stat_ab - 1.96 * se_stat_ab, ymax = stat_ab + 1.96 * se_stat_ab, x = method), width = 0.15) +
  scale_fill_brewer(labels = show_labels, palette = "Dark2") +
  ylim(0, 1) +
  theme(axis.text.x = element_blank(), axis.ticks.x = element_blank()) +
  labs(x = NULL, title = "C) Simulations modeling polygenicity of U.", y = "FPR, A to B") +
  guides(fill = "none")

g6 <- polyg_fig2_data %>%
  ggplot() + facet_wrap(vars(name), nrow = 1) +
  geom_col(aes(x = method, y = stat_ba, fill = method), position = "dodge") +
  geom_errorbar(aes(ymin = stat_ba - 1.96 * se_stat_ba, ymax = stat_ba + 1.96 * se_stat_ba, x = method), width = 0.15) +
  scale_fill_brewer(labels = show_labels, palette = "Dark2") +
  scale_x_discrete(labels = show_labels) +
  theme(axis.text.x = element_text(color = "black", angle = 30, vjust = .8, hjust = 0.8)) +
  ylim(0, 1) +
  labs(x = "Method", y = "FPR, B to A") +
  guides(fill = "none")


# Stack the six panels vertically and write the figure to disk.
plot <- ggarrange(g1, g2, g3, g4, g5, g6, nrow = 6)
ggsave("../Figures/wwer_fig2.pdf", device = "pdf", plot = plot, dpi = 300)
plot
# Methods shown in the exploratory polygenicity plots below. The summary
# tables store "egger_wf_5e-8" under the name "aaa_egger_wf_5e-8" (see
# make_summary_tibble()), so that is the name that has to be listed here;
# the original name would never match and the method would be dropped.
show_methods <- c("egger_5e-8", "aaa_egger_wf_5e-8", "egger_s_5e-8", "cause_1e-3", "mr_mix_5e-8", "mbe_5e-8")
# show_methods = c("egger_5e-8", "egger_wf_5e-8", "egger_s_5e-8")
# show_methods = c("egger_wf_5e-8", "egger_5e-8", "median_5e-8", "mbe_5e-8", "raps_5e-8")

# Bar chart of the false positive rate per method, one facet per simulation
# name, for a single (nu, eta) grid point of the polygenicity simulations.
#
# Args:
#   nu_value, eta_value: Grid point to show; matched with near() because the
#     stored values are sqrt() results.
#   direction: "ab" or "ba"; selects the stat_<direction> and
#     se_stat_<direction> columns.
#   n_row: Number of facet rows.
#
# Returns:
#   The ggplot object (printed when the call is evaluated at top level).
plot_polyg_fpr <- function(nu_value, eta_value, direction, n_row) {
  stat_col <- paste0("stat_", direction)
  se_col <- paste0("se_stat_", direction)
  polyg_all_summary %>%
    filter(dplyr::near(nu, nu_value), dplyr::near(eta, eta_value)) %>%
    filter(method %in% show_methods) %>%
    ggplot() +
    geom_col(aes(x = method, y = .data[[stat_col]], color = method)) +
    geom_errorbar(
      aes(ymin = .data[[stat_col]] - 1.96 * .data[[se_col]],
          ymax = .data[[stat_col]] + 1.96 * .data[[se_col]],
          x = method),
      width = 0.15) +
    ylim(0, 1) +
    # Keep the plain column name as the axis title.
    labs(y = stat_col) +
    facet_wrap(vars(name), nrow = n_row)
}

plot_polyg_fpr(sqrt(0.3), sqrt(0.3), "ab", 6)

plot_polyg_fpr(sqrt(0.3), sqrt(0.3), "ba", 6)

plot_polyg_fpr(sqrt(0.1), sqrt(0.3), "ab", 6)

plot_polyg_fpr(sqrt(0.1), sqrt(0.3), "ba", 6)

plot_polyg_fpr(sqrt(0.3), sqrt(0.1), "ab", 3)

plot_polyg_fpr(sqrt(0.3), sqrt(0.1), "ba", 3)

# Let's make some supplementary tables.

sim_settings_polyg <- rbind(sim_settings_polyg_equal, sim_settings_polyg_Na_small, sim_settings_polyg_Nb_small)

# Build one wide supplementary table: one row per study, one column per
# mr_method, each cell formatted as "<value> (<standard error, 2 digits>)".
#
# Args:
#   summary_tbl: Output of make_summary_tibble().
#   value_col, se_col: Names of the value column and of its standard error.
#   study_cols: Names of the numeric columns which, rounded to 2 digits, are
#     appended to `name` to form the study label.
#
# Returns:
#   A tibble with a `study` column followed by one column per mr_method.
#   Rows with p_thresh equal to "5e-6" are excluded.
make_supp_table <- function(summary_tbl, value_col, se_col, study_cols) {
  kept <- summary_tbl %>% dplyr::filter(p_thresh != "5e-6")
  rounded_params <- lapply(study_cols, function(col) round(kept[[col]], 2))
  kept$study <- do.call(paste, c(list(kept$name), rounded_params))
  kept$cell <- paste0(kept[[value_col]], " (", round(kept[[se_col]], digits = 2), ")")
  kept %>%
    dplyr::select(study, mr_method, cell) %>%
    tidyr::pivot_wider(names_from = mr_method, values_from = cell)
}

# The LCV tables additionally leave out the "egger_wa" method.
lcv_no_wa <- lcv_results_summary %>% dplyr::filter(mr_method != "egger_wa")

# False positive rate tables. LCV and polygenicity studies are labelled by
# (nu, eta), CAUSE studies by q.
res_fpr_lcv_ab <- make_supp_table(lcv_no_wa, "stat_ab", "se_stat_ab", c("nu", "eta"))
res_fpr_lcv_ba <- make_supp_table(lcv_no_wa, "stat_ba", "se_stat_ba", c("nu", "eta"))
res_fpr_cause_ab <- make_supp_table(cause_results_summary, "stat_ab", "se_stat_ab", "q")
res_fpr_cause_ba <- make_supp_table(cause_results_summary, "stat_ba", "se_stat_ba", "q")
res_fpr_polyg_ab <- make_supp_table(polyg_all_summary, "stat_ab", "se_stat_ab", c("nu", "eta"))
res_fpr_polyg_ba <- make_supp_table(polyg_all_summary, "stat_ba", "se_stat_ba", c("nu", "eta"))

# Mean absolute error tables, same layout.
res_mae_lcv_ab <- make_supp_table(lcv_no_wa, "mae_ab", "se_mae_ab", c("nu", "eta"))
res_mae_lcv_ba <- make_supp_table(lcv_no_wa, "mae_ba", "se_mae_ba", c("nu", "eta"))
res_mae_cause_ab <- make_supp_table(cause_results_summary, "mae_ab", "se_mae_ab", "q")
res_mae_cause_ba <- make_supp_table(cause_results_summary, "mae_ba", "se_mae_ba", "q")
res_mae_polyg_ab <- make_supp_table(polyg_all_summary, "mae_ab", "se_mae_ab", c("nu", "eta"))
res_mae_polyg_ba <- make_supp_table(polyg_all_summary, "mae_ba", "se_mae_ba", c("nu", "eta"))

# FPR and MAE side by side per study; method columns get _FPR / _MAE suffixes.
st4 <- dplyr::left_join(res_fpr_lcv_ab, res_mae_lcv_ab, by = "study", suffix = c("_FPR", "_MAE"))
st5 <- dplyr::left_join(res_fpr_lcv_ba, res_mae_lcv_ba, by = "study", suffix = c("_FPR", "_MAE"))
st6 <- dplyr::left_join(res_fpr_cause_ab, res_mae_cause_ab, by = "study", suffix = c("_FPR", "_MAE"))
st7 <- dplyr::left_join(res_fpr_cause_ba, res_mae_cause_ba, by = "study", suffix = c("_FPR", "_MAE"))
st8 <- dplyr::left_join(res_fpr_polyg_ab, res_mae_polyg_ab, by = "study", suffix = c("_FPR", "_MAE"))
st9 <- dplyr::left_join(res_fpr_polyg_ba, res_mae_polyg_ba, by = "study", suffix = c("_FPR", "_MAE"))

outfile <- "supp_tables_null.xlsx"

# ST1-ST3 hold the simulation settings, ST4-ST9 the result tables.
supp_sheets <- list(
  ST1 = sim_settings_lcv, ST2 = sim_settings_cause, ST3 = sim_settings_polyg,
  ST4 = st4, ST5 = st5, ST6 = st6, ST7 = st7, ST8 = st8, ST9 = st9)

supp_tables <- openxlsx::createWorkbook()
for (sheet in names(supp_sheets)) {
  openxlsx::addWorksheet(supp_tables, sheet)
  openxlsx::writeData(supp_tables, sheet, supp_sheets[[sheet]])
}

openxlsx::saveWorkbook(wb = supp_tables, file = outfile, overwrite = TRUE)

# Summary table.

# TODO(brielin): Add < 0.05 WITH SEs on both 0.05 and 0.2
all_results <- dplyr::bind_rows(lcv_results_summary, cause_results_summary, polyg_all_summary)

# One row per (method, setting, direction). fpr_se is the lower end of the
# interval stat - 1.96 * se for that direction.
long_results <- all_results %>% dplyr::filter(p_thresh != "5e-6") %>%
  mutate(category = paste0(name, nu, eta, q), fpr_se_ab = stat_ab - 1.96 * se_stat_ab, fpr_se_ba = stat_ba - 1.96 * se_stat_ba) %>%
  dplyr::select(mr_method, category, fpr_se_ab, fpr_se_ba, runtime) %>%
  tidyr::pivot_longer(cols = c("fpr_se_ab", "fpr_se_ba"), names_to = "direction", values_to = "fpr_se") %>%
  dplyr::mutate(cat = paste0(category, direction))

# Same layout for the mean absolute error.
long_results_mae <- all_results %>% dplyr::filter(p_thresh != "5e-6") %>%
  mutate(category = paste0(name, nu, eta, q)) %>%
  dplyr::select(mr_method, category, mae_ab, mae_ba) %>%
  tidyr::pivot_longer(cols = c("mae_ab", "mae_ba"), names_to = "direction", values_to = "mae") %>%
  dplyr::mutate(cat = paste0(category, direction))

# The MAE column is copied over by row position, so make sure both tables
# list the same method / setting in every row before doing so.
stopifnot(
  identical(long_results$mr_method, long_results_mae$mr_method),
  identical(long_results$category, long_results_mae$category))
long_results$mae <- long_results_mae$mae

# Rank the methods within every (setting, direction), then average the ranks,
# the indicator columns and the runtime per method.
result_summary_table <- long_results %>%
  group_by(cat) %>%
  mutate(rank_fpr = rank(fpr_se), rank_mae = rank(mae), perc5 = fpr_se < 0.05, perc20 = fpr_se < 0.2) %>%
  group_by(mr_method) %>%
  mutate(mean_fprr = round(mean(rank_fpr), 3), mean_maer = round(mean(rank_mae), 3),
         mean_perc5 = round(mean(perc5), 3), mean_perc20 = round(mean(perc20), 3),
         mean_runtime = round(mean(runtime), 3)) %>%
  distinct(mr_method, mean_fprr, mean_maer, mean_perc5, mean_perc20, mean_runtime) %>%
  ungroup() %>%
  arrange(desc(mean_perc20))

knitr::kable(result_summary_table, format = "latex", booktabs = TRUE)

# Code to generate the panels in Figure 1 ----

af_ref_file <- "/gpfs/commons/projects/UKBB/sumstats/ldsc_results/eur_w_ld_chr/19.l2.ldscore.gz"
theme_set(theme_classic())

# Both Figure 1 simulations use the same Na, Nb, M, ha, hb, hu, p and reference
# file; only the arguments forwarded through `...` differ between them.
simulate_fig1 <- function(...) {
  generate_sumstats_bimr(
    Na = 50000, Nb = 50000, M = 25000,
    ha = 0.3, hb = 0.3, hu = 0.3,
    ...,
    p = 0.1, af_ref_file = af_ref_file
  )
}

# "Null" setting: nu and eta are non-zero, gamma and delta are both zero.
fig_ds_null <- simulate_fig1(
  nu = sqrt(0.3), eta = sqrt(0.3), gamma = 0, delta = 0,
  q = 0.33, r = 0.33, s = 0.33
)

# "Alt" setting: nu and eta are zero, gamma is non-zero.
fig_ds_alt <- simulate_fig1(
  nu = 0, eta = 0, gamma = sqrt(0.1), delta = 0,
  q = 0.2, r = 0.4, s = 0.4
)

# Builds the per-SNP table for the first set of summary statistics
# (`sumstats_1`) of one simulated dataset.
#
# sim:      list produced by generate_sumstats_bimr(); the fields read here are
#           beta$A, beta$B, Z$A, Z$B, Z$U, sumstats_1$beta_hat and
#           sumstats_1$se_hat.
# z_cutoff: sig_A / sig_B are TRUE when |beta_hat / se_hat| exceeds this value.
#
# Returns a tibble restricted to SNPs with ZU + ZA + ZB > 0, carrying the true
# effects, the estimates, the significance flags, a `label` column and the
# Welch statistic comparing the B and A estimates.
fig1_first_dataset <- function(sim, z_cutoff = 5.45) {
  first_stats <- sim$sumstats_1
  tibble(
    "beta_A" = sim$beta$A, "beta_B" = sim$beta$B,
    "beta_hat_A" = first_stats$beta_hat$A, "beta_hat_B" = first_stats$beta_hat$B,
    "se_hat_A" = first_stats$se_hat$A, "se_hat_B" = first_stats$se_hat$B,
    "ZA" = sim$Z$A, "ZB" = sim$Z$B, "ZU" = sim$Z$U
  ) %>%
    mutate(
      Z_tot = ZU + ZA + ZB,
      sig_A = abs(beta_hat_A / se_hat_A) > z_cutoff,
      sig_B = abs(beta_hat_B / se_hat_B) > z_cutoff
    ) %>%
    filter(Z_tot > 0) %>%
    # Alternative labelling kept from the original script:
    # mutate(label = if_else(ZU==1, "U", if_else(ZA==1 & ZB==1, "A", "E")))
    mutate(
      label = if_else(
        ZU == 1, "U",
        if_else(ZA == 1 & ZB == 1, "A, B", if_else(ZA == 1, "A", "B"))
      ),
      # NOTE(review): this file attaches WWER, not bimmer -- confirm the bimmer
      # namespace is installed wherever this script is run.
      welch_stat = bimmer:::welch_test(beta_hat_B, se_hat_B, beta_hat_A, se_hat_A)$t
    )
}

# Builds the regression-weight table for the second set of summary statistics
# (`sumstats_2`) of one simulated dataset.
#
# sim:          same list that was passed to fig1_first_dataset().
# first_tbl:    result of fig1_first_dataset(sim); both tables apply the same
#               Z_tot > 0 filter to the same Z columns, so their rows line up.
# welch_cutoff: a SNP keeps a non-zero "WWER" weight only if its Welch
#               statistic exceeds this value and it is flagged sig_A.
#
# Returns the SNP rows twice, stacked: once with cat = "Egger" (weight is the
# 0/1 sig_A flag) and once with cat = "WWER" (weight is the Welch statistic
# when it passes the cutoff, otherwise 0).
fig1_second_dataset <- function(sim, first_tbl, welch_cutoff = 2) {
  second_stats <- sim$sumstats_2
  egger_tbl <- tibble(
    "beta_hat_A" = second_stats$beta_hat$A, "beta_hat_B" = second_stats$beta_hat$B,
    "se_hat_A" = second_stats$se_hat$A, "se_hat_B" = second_stats$se_hat$B,
    "ZA" = sim$Z$A, "ZB" = sim$Z$B, "ZU" = sim$Z$U
  ) %>%
    mutate(Z_tot = ZU + ZA + ZB) %>%
    filter(Z_tot > 0) %>%
    mutate(
      welch_stat = first_tbl$welch_stat,
      weight = as.numeric(first_tbl$sig_A),
      cat = "Egger"
    )

  wwer_tbl <- egger_tbl %>%
    mutate(
      weight = if_else(
        (first_tbl$welch_stat > welch_cutoff) & (weight > 0),
        first_tbl$welch_stat, 0
      ),
      cat = "WWER"
    )

  bind_rows(egger_tbl, wwer_tbl)
}

# Layers shared by the two second-dataset panels: points filled by the Welch
# statistic plus one weighted linear fit per value of `cat`.
fig1_regression_panel <- function(second_tbl) {
  ggplot(second_tbl, aes(beta_hat_A, beta_hat_B)) +
    geom_point(aes(fill = welch_stat), shape = 21, stroke = 0, size = 2) +
    scale_fill_gradient2() +
    geom_smooth(formula = y ~ x, method = lm, aes(weight = weight, color = cat))
}

null_first <- fig1_first_dataset(fig_ds_null)
null_second <- fig1_second_dataset(fig_ds_null, null_first)
alt_first <- fig1_first_dataset(fig_ds_alt)
alt_second <- fig1_second_dataset(fig_ds_alt, alt_first)

# Left column (null simulation): legends are hidden because the right column
# shows the same legends. `"none"` replaces the deprecated `guides(x = FALSE)`.
g1 <- ggplot(null_first, aes(beta_A, beta_B, color = label)) +
  geom_point() +
  labs(x = NULL, y = NULL, title = "Null, true effects") +
  guides(color = "none")
g2 <- ggplot(null_first, aes(beta_hat_A, beta_hat_B, color = welch_stat)) +
  geom_point() +
  scale_color_gradient2() +
  guides(color = "none") +
  labs(x = NULL, y = NULL, title = "Null, estimates in first dataset")
g3 <- fig1_regression_panel(null_second) +
  labs(x = NULL, y = NULL, title = "Null, estimates in second dataset") +
  guides(fill = "none", color = "none")

# Right column (alt simulation): these panels carry the legends.
g4 <- ggplot(alt_first, aes(beta_A, beta_B, color = label)) +
  geom_point() +
  labs(x = NULL, y = NULL, color = "SNP effect", title = "Alt, true effects")
g5 <- ggplot(alt_first, aes(beta_hat_A, beta_hat_B, color = welch_stat)) +
  geom_point() +
  scale_color_gradient2(name = "Welch statistic") +
  labs(x = NULL, y = NULL, title = "Alt, estimates in first dataset")
g6 <- fig1_regression_panel(alt_second) +
  labs(x = NULL, y = NULL, color = "Method", title = "Alt, estimates in second dataset") +
  guides(fill = "none")

ggarrange(g1, g4, g2, g5, g3, g6, ncol = 2, left = "Beta B", bottom = "Beta A")


# brielin/WWER documentation built on May 22, 2022, 7:28 p.m.