# Chunk defaults for this document: collapse = TRUE and "#>" as the prefix
# for printed output.
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")
library(magrittr)
library(conservationSpillover)
library(mgcv)
library(sf)
# Stored simulator object and basis data used throughout the script.
simulator0 <- readRDS("data/plasmode_simulator_0.rds")
basisdt <- readRDS("data/simulation_basisdt.rds")

# Generator of simulation basis data, built from the basis data above.
create_sim_basis <- create_sim_basis_maker(basisdt)

# Sampled ids; missing entries are dropped before use.
# TODO : why are there NAs in ids?
ids <- readRDS("data/bas_3000_sample.rds")
ids <- ids[!is.na(ids)]
# Baseline weight "estimator": gives every row of `data` a weight of 1.
#
# data: a data frame (anything with nrow()).
# Returns a numeric vector of ones, one per row.
naivimator <- function(data) {
  n_rows <- nrow(data)
  rep(1, times = n_rows)
}

# Weight estimator built by make_pw_estimator() around glm_fitter.
# The right-hand side crosses both exposure terms (A and A_tilde) with a set
# of transformed covariates: log(ha) interacted with slope, zero-indicators
# for slope, p_wet and neighbor_p_wet, the raw p_wet / neighbor_p_wet values,
# log1p(coast_2500) and travel.
# NOTE(review): B = 10 is passed straight to make_pw_estimator(); presumably
# the number of permutations/replicates -- confirm in the package docs.
pw_glminator_1 <- make_pw_estimator(
  fitter = glm_fitter,
  rhs_formula = ~ (A + A_tilde)*(
     log(ha)*slope +
     I(slope == 0) +
     I(p_wet == 0) + p_wet +
     I(neighbor_p_wet == 0) + neighbor_p_wet +
     log1p(coast_2500) + travel),
  B = 10
)

# Simpler variant of pw_glminator_1: same fitter and B, but the covariates
# enter untransformed (ha, slope, p_wet, neighbor_p_wet, coast_2500, travel)
# and without the zero-indicator terms. Both exposure terms are still crossed
# with every covariate term.
pw_glminator_2 <- make_pw_estimator(
  fitter = glm_fitter,
  rhs_formula = ~ (A + A_tilde)*(
     ha*slope +
     p_wet + neighbor_p_wet +
     coast_2500 + travel),
  B = 10
)

# Weight estimator built by make_pw_estimator() around xg_fitter, using the
# pre-computed covariate columns (log_ha, slope_0, p_wet_0, ...). These
# columns must already exist in the data passed to the estimator.
pw_xgbinator_1 <- make_pw_estimator(
  xg_fitter,
  B = 5,
  rhs_formula = ~ A + A_tilde + log_ha + slope + slope_0 + p_wet + p_wet_0 +
    neighbor_p_wet + neighbor_p_wet_0 + log1p_coast_2500 + travel,
  predictor = function(object, newdata) {
    # Feature columns, in the same order as the terms of rhs_formula.
    feature_cols <- c(
      "A", "A_tilde", "log_ha", "slope", "slope_0", "p_wet", "p_wet_0",
      "neighbor_p_wet", "neighbor_p_wet_0", "log1p_coast_2500", "travel"
    )
    feature_matrix <- data.matrix(newdata[, feature_cols])
    predict(object, newdata = feature_matrix)
  },
  fitter_args = list(
    nrounds = 50,
    # objective = "reg:logistic",
    objective = "binary:logistic",
    eval_metric = "logloss",
    max_depth = 15,
    verbose = 0
  )
)

# Variant of the xg_fitter-based weight estimator that uses the raw covariate
# columns (ha, slope, p_wet, neighbor_p_wet, coast_2500, travel).
pw_xgbinator_2 <- make_pw_estimator(
  xg_fitter,
  B = 5,
  rhs_formula = ~ A + A_tilde + ha + slope + p_wet + neighbor_p_wet +
    coast_2500 + travel,
  predictor = function(object, newdata) {
    # Feature columns, in the same order as the terms of rhs_formula.
    feature_cols <- c(
      "A", "A_tilde", "ha", "slope", "p_wet", "neighbor_p_wet",
      "coast_2500", "travel"
    )
    feature_matrix <- data.matrix(newdata[, feature_cols])
    predict(object, newdata = feature_matrix)
  },
  fitter_args = list(
    nrounds = 50,
    # objective = "reg:logistic",
    objective = "binary:logistic",
    eval_metric = "logloss",
    max_depth = 15,
    verbose = 0
  )
)

# Weight estimator built by make_ipw_estimator() around glm_fitter with a rich
# parametric right-hand side: log(ha) * slope, zero-indicators for slope,
# p_wet and neighbor_p_wet, cubic polynomials in p_wet and neighbor_p_wet,
# log1p(coast_2500) * travel, and higher-order terms in log(ha) and slope.
# NOTE(review): in a formula `*` binds tighter than `+`, so the last two lines
# read as I(log(ha)^2) + [I(log(ha)^3) * I(slope^2)] + I(slope^3). Confirm
# that only this single cubic-by-squared interaction was intended.
ipw_glminator_c <- make_ipw_estimator(
  fitter = glm_fitter,
  rhs_formula = ~ log(ha)*slope + 
    I(slope == 0) + 
    I(p_wet == 0) + p_wet + I(p_wet^2) + I(p_wet^3)  +
    I(neighbor_p_wet == 0) + neighbor_p_wet + I(neighbor_p_wet^2) + I(neighbor_p_wet^3)  +
    log1p(coast_2500)*travel +
    I(log(ha)^2) + I(log(ha)^3) * 
    I(slope^2) + I(slope^3)
)

# Weight estimator built by make_ipw_estimator() around gam_fitter.
# Combines parametric terms (log(ha) * slope, zero-indicators and linear terms
# for p_wet / neighbor_p_wet, log1p(coast_2500), travel) with smooth terms
# s() in log(ha), slope, p_wet and travel.
ipw_gaminator_c <- make_ipw_estimator(
  gam_fitter,
  rhs_formula = ~ log(ha)*slope +
     I(p_wet == 0) + p_wet +
     I(neighbor_p_wet == 0) + neighbor_p_wet +
     log1p(coast_2500) + travel +
     s(log(ha)) + s(slope) + s(p_wet) + s(travel)
)

# Smaller gam_fitter-based weight estimator: main effects only (no log(ha) by
# slope interaction and no zero-indicators), with smooth terms s() limited to
# p_wet and travel.
ipw_gaminator_m <- make_ipw_estimator(
  gam_fitter,
  rhs_formula = ~ log(ha) + slope +
     p_wet + neighbor_p_wet +
     log1p(coast_2500) + travel +
    s(p_wet) +
    s(travel)
)

# Named registry of every weight estimator compared in the simulations; the
# names are used as labels for the results.
weight_estimators <- setNames(
  list(
    naivimator,
    pw_glminator_1,
    pw_glminator_2,
    pw_xgbinator_1,
    pw_xgbinator_2,
    ipw_glminator_c,
    ipw_gaminator_c,
    ipw_gaminator_m
  ),
  c(
    "naivimator",
    "pw_glminator_1",
    "pw_glminator_2",
    "pw_xgbinator_1",
    "pw_xgbinator_2",
    "ipw_glminator_c",
    "ipw_gaminator_c",
    "ipw_gaminator_m"
  )
)
# Sets every coefficient whose name starts with "A" to zero and leaves the
# remaining coefficients untouched.
#
# b: named numeric coefficient vector.
# x: not used; accepted so all parameter modifiers share a (b, x) signature.
# Returns a numeric vector carrying the names of `b`.
null_A <- function(b, x) {
  coef_names <- names(b)
  starts_with_a <- grepl("^A", coef_names)
  setNames(ifelse(starts_with_a, 0, b), coef_names)
}

# Overrides the exposure coefficients: "A" becomes 1, "A_tilde" and
# "A:A_tilde" become 0; all other coefficients keep their value.
#
# b: named numeric coefficient vector.
# x: not used; accepted so all parameter modifiers share a (b, x) signature.
# Returns a numeric vector carrying the names of `b`.
de_A <- function(b, x) {
  coef_names <- names(b)
  updated <- dplyr::case_when(
    coef_names == "A" ~ 1,
    coef_names == "A_tilde" ~ 0,
    coef_names == "A:A_tilde" ~ 0,
    TRUE ~ b
  )
  setNames(updated, coef_names)
}

# Overrides the exposure coefficients: "A" becomes 0, "A_tilde" becomes 0.5
# and "A:A_tilde" becomes 0; all other coefficients keep their value.
#
# b: named numeric coefficient vector.
# x: not used; accepted so all parameter modifiers share a (b, x) signature.
# Returns a numeric vector carrying the names of `b`.
ie_A <- function(b, x) {
  coef_names <- names(b)
  updated <- dplyr::case_when(
    coef_names == "A" ~ 0,
    coef_names == "A_tilde" ~ 0.5,
    coef_names == "A:A_tilde" ~ 0,
    TRUE ~ b
  )
  setNames(updated, coef_names)
}

# Overrides the exposure coefficients: "A" becomes 1, "A_tilde" becomes 0.5
# and "A:A_tilde" becomes 0; all other coefficients keep their value.
#
# b: named numeric coefficient vector.
# x: not used; accepted so all parameter modifiers share a (b, x) signature.
# Returns a numeric vector carrying the names of `b`.
de_ie_A <- function(b, x) {
  coef_names <- names(b)
  updated <- dplyr::case_when(
    coef_names == "A" ~ 1,
    coef_names == "A_tilde" ~ 0.5,
    coef_names == "A:A_tilde" ~ 0,
    TRUE ~ b
  )
  setNames(updated, coef_names)
}

# Overrides the exposure coefficients: "A" becomes 1, "A_tilde" becomes 0.5
# and "A:A_tilde" becomes -0.25; all other coefficients keep their value.
#
# b: named numeric coefficient vector.
# x: not used; accepted so all parameter modifiers share a (b, x) signature.
# Returns a numeric vector carrying the names of `b`.
de_ie_int_A <- function(b, x) {
  coef_names <- names(b)
  updated <- dplyr::case_when(
    coef_names == "A" ~ 1,
    coef_names == "A_tilde" ~ 0.5,
    coef_names == "A:A_tilde" ~ -0.25,
    TRUE ~ b
  )
  setNames(updated, coef_names)
}

# Multiplies every coefficient whose name starts with "slope" by 50; all other
# coefficients are returned unchanged.
# NOTE(review): the name says 10x but the multiplier in the code is 50 --
# confirm which one is intended.
#
# b: named numeric coefficient vector.
# x: not used; accepted so all parameter modifiers share a (b, x) signature.
# Returns a numeric vector carrying the names of `b`.
slope_10x <- function(b, x) {
  coef_names <- names(b)
  is_slope <- grepl("^slope", coef_names)
  setNames(ifelse(is_slope, 50 * b, b), coef_names)
}

# Doubles every coefficient whose name starts with "slope"; all other
# coefficients are returned unchanged.
#
# b: named numeric coefficient vector.
# x: not used; accepted so all parameter modifiers share a (b, x) signature.
# Returns a numeric vector carrying the names of `b`.
slope_2x <- function(b, x) {
  coef_names <- names(b)
  is_slope <- grepl("^slope", coef_names)
  setNames(ifelse(is_slope, 2 * b, b), coef_names)
}

# Multiplies every coefficient by 10 except those whose name starts with the
# literal text "I(slope == 0)", which are returned unchanged.
# NOTE(review): the function name suggests a 5x change to p_wet, which is not
# what the body does -- confirm the intended scaling.
#
# b: named numeric coefficient vector.
# x: not used; accepted so all parameter modifiers share a (b, x) signature.
# Returns a numeric vector carrying the names of `b`.
fiveX_pwet <- function(b, x) {
  coef_names <- names(b)
  # The parentheses must be escaped: left bare, "(slope == 0)" is a regex
  # group, so the pattern would look for "Islope == 0" and never match a
  # coefficient name such as "I(slope == 0)TRUE".
  is_slope_zero <- grepl("^I\\(slope == 0\\)", coef_names)
  setNames(ifelse(is_slope_zero, b, 10 * b), coef_names)
}

# Doubles every coefficient except those whose name starts with "A".
#
# b: named numeric coefficient vector.
# x: not used; accepted so all parameter modifiers share a (b, x) signature.
# Returns a numeric vector carrying the names of `b`.
double_notA <- function(b, x) {
  coef_names <- names(b)
  starts_with_a <- grepl("^A", coef_names)
  setNames(ifelse(starts_with_a, b, 2 * b), coef_names)
}

# Triples every coefficient except those whose name starts with "A".
#
# b: named numeric coefficient vector.
# x: not used; accepted so all parameter modifiers share a (b, x) signature.
# Returns a numeric vector carrying the names of `b`.
triple_notA <- function(b, x) {
  coef_names <- names(b)
  starts_with_a <- grepl("^A", coef_names)
  setNames(ifelse(starts_with_a, b, 3 * b), coef_names)
}

# Multiplies every coefficient by 10 except those whose name starts with "A".
#
# b: named numeric coefficient vector.
# x: not used; accepted so all parameter modifiers share a (b, x) signature.
# Returns a numeric vector carrying the names of `b`.
tenX_notA <- function(b, x) {
  coef_names <- names(b)
  starts_with_a <- grepl("^A", coef_names)
  setNames(ifelse(starts_with_a, b, 10 * b), coef_names)
}

# Multiplies every coefficient by 10 except the one named exactly "slope",
# which is returned unchanged.
# NOTE(review): the name suggests the opposite (scale only slope by 10) --
# confirm whether the two ifelse() branches are in the intended order.
#
# b: named numeric coefficient vector.
# x: not used; accepted so all parameter modifiers share a (b, x) signature.
# Returns a numeric vector carrying the names of `b`.
tenX_slope <- function(b, x) {
  coef_names <- names(b)
  is_slope_main <- grepl("^slope$", coef_names)
  setNames(ifelse(is_slope_main, b, 10 * b), coef_names)
}

# Doubles every coefficient except those whose name starts with "A" or
# contains a dot followed by digits (presumably smooth-term coefficients such
# as "s(travel).3" -- confirm against the fitted model's names).
#
# b: named numeric coefficient vector.
# x: not used; accepted so all parameter modifiers share a (b, x) signature.
# Returns a numeric vector carrying the names of `b`.
double_notA_notSpline <- function(b, x) {
  coef_names <- names(b)
  leave_alone <- grepl("(^A|\\.[0-9]+)", coef_names)
  setNames(ifelse(leave_alone, b, 2 * b), coef_names)
}

# Re-solves the intercept of a logistic model so that the average predicted
# probability over the rows of `x` equals `target`, keeping all other
# coefficients fixed.
#
# b:        named numeric coefficient vector; its first element is treated as
#           the intercept and is replaced.
# x:        numeric design matrix whose columns line up with `b`.
# target:   desired mean of plogis(x %*% b); must lie strictly in (0, 1).
#           Defaults to 0.02, the value previously hard-coded here.
# interval: length-2 numeric giving the range searched for the new intercept.
# Returns `b` with its first element replaced by the solved intercept, named
# "(Intercept)".
fix_marginal_A_to_popmean <- function(b, x, target = 0.02,
                                      interval = c(-10, 10)) {
  stopifnot(
    is.numeric(target), length(target) == 1, target > 0, target < 1,
    is.numeric(interval), length(interval) == 2
  )

  other_coefs <- b[-1]

  # Gap between the mean predicted probability at intercept `b0` and `target`.
  marginal_gap <- function(b0) {
    mean(plogis(x %*% c(b0, other_coefs))) - target
  }

  b0 <- uniroot(
    marginal_gap,
    lower = interval[1],
    upper = interval[2]
  )$root

  c(`(Intercept)` = b0, other_coefs)
}


# Grid of simulation settings. Each element is a two-element list named
# "exposure_newparms" / "outcome_newparms", holding functions of (b, x) that
# rewrite the exposure-model and outcome-model coefficients respectively.
#
# Exposure modifiers (outer .x): `id` leaves the coefficients unchanged and
# `pop` applies slope_10x(). Outcome modifiers (outer .y): one of the five
# exposure-effect setters (null_A, de_A, ie_A, de_ie_A, de_ie_int_A) composed
# with fiveX_pwet. Two exposure modifiers crossed with five outcome modifiers
# give 10 settings.
# NOTE(review): purrr::compose() applies functions right-to-left by default,
# so fiveX_pwet presumably runs before the effect setter -- confirm.
sim_settings <-
  purrr::cross2(
    .x = list(id = function(b, x) { b }, 
              pop =function(b, x) {
                slope_10x(b, x)
                # %>%
                  # fix_marginal_A_to_popmean(x)
                # tenX_slope(b, x) 
                # double_notA_notSpline(b, x)
                # %>%
                # fix_marginal_A_to_popmean(b, x)
              }),
    .y = purrr::cross2(
      .x = list(null = null_A, de = de_A, ie = ie_A, 
                de_ie = de_ie_A, de_ie_int_A = de_ie_int_A),
      .y = list(
        # double_notA = tenX_slope
        # id = function(b, x) { b }, 
                # Despite the entry name, the function used here is fiveX_pwet.
                slope_10x = fiveX_pwet
                )
      # Collapse each (effect setter, covariate modifier) pair into a single
      # function.
      ) %>% purrr::map(~ purrr::compose(!!! .x))
  ) %>%
  purrr::map(
    ~ set_names(.x, c("exposure_newparms", "outcome_newparms"))
  )
# Fits the weighted gaussian GLM Y ~ A * A_tilde and returns its coefficient
# table as produced by broom::tidy().
#
# data:    data frame containing Y, A and A_tilde.
# weights: observation weights handed to glm().
msm_estimator <- function(data, weights) {
  fit <- glm(Y ~ A * A_tilde, data = data, weights = weights, family = gaussian)
  broom::tidy(fit)
}

# TRUE when `x` (e.g. the list returned by a purrr::safely() wrapper) has a
# non-NULL "error" element.
is_error <- function(x) {
  captured <- x[["error"]]
  !is.null(captured)
}

# Runs every weight estimator on one simulated dataset and, for each, computes
# weighted balance diagnostics and the causal estimate.
#
# simdt:             list with a `data` element (the simulated data frame).
# weight_estimators: named list of functions taking the data and returning a
#                    weight vector.
# causal_estimator:  function(data, weights = ) returning the estimates.
#
# Returns a list with one entry per weight estimator, each a list with
#   estimates:      purrr::safely() result of causal_estimator,
#   smd:            purrr::safely() result of smd::smd() on selected covariates,
#   weight_summary: summary() of the weights.
# If a weight estimator itself failed, its entry holds empty data frames for
# `estimates` and `smd`, and the captured safely() result (including the
# error) in `weight_summary`.
estimate_from_sim <- function(simdt, weight_estimators, causal_estimator) {

  dt <- simdt$data

  weights <- purrr::imap(
    .x = weight_estimators,
    .f = function(estimator, label) {
      print(sprintf("estimating weights by %s", label))
      purrr::safely(estimator)(dt)
    }
  )

  purrr::map(
    .x = weights,
    .f = function(fit) {

      # Return early for failed estimators. Without this the code would carry
      # on with w = NULL and silently report unweighted results under the
      # failed estimator's name.
      if (is_error(fit)) {
        return(
          list(
            estimates = data.frame(),
            smd = data.frame(),
            weight_summary = fit
          )
        )
      }

      w <- fit[["result"]]

      smd <- purrr::safely(smd::smd)(
        x = dt[, c("log_ha", "slope", "travel", "p_wet", "log1p_coast_2500")],
        g = dt$A,
        w = w
      )

      estimate <- purrr::safely(causal_estimator)(dt, weights = w)

      list(
        estimates = estimate,
        smd = smd,
        weight_summary = summary(w)
      )
    }
  )
}
# Exploratory check on a single stored simulation basis: simulate under one
# setting and look at unweighted summaries.
# test <- create_sim_basis(seed_id = ids[1], 10)
test <- dir("data/sims", full.names = TRUE) %>%
  .[6] %>%
  readRDS()

# Simulate from the basis data using the 8th setting of sim_settings.
x <- do.call(simulator0, args = c(list(newdata = test$data), sim_settings[[8]]))
mean(x$A)

# hist(x$Y)

# dplyr verbs are namespaced because dplyr is never attached in this script.
x %>%
  dplyr::group_by(A) %>%
  dplyr::summarise(mean(Y))
lm(Y ~ A * A_tilde, data = x)
# Derive the transformed covariate columns that the xg_fitter-based estimators
# and the balance diagnostics refer to by name.
# dplyr::mutate is namespaced because dplyr is never attached in this script.
x <-
  x %>%
  dplyr::mutate(
    log_ha = log(ha),
    slope_0 = (slope == 0),
    p_wet_0 = (p_wet == 0),
    neighbor_p_wet_0 = (neighbor_p_wet == 0),
    log1p_coast_2500 = log1p(coast_2500)
  )

# Quick look at one estimator's weights in a weighted linear model.
ww <- ipw_glminator_c(x)
lm(Y ~ A * A_tilde, data = x, weights = ww)

# Run every weight estimator on the simulated data.
yyy <-
  estimate_from_sim(
    simdt = list(data = x),
    # simdt = list(data = x[sample(1:nrow(x), 3000, replace = FALSE), ]),
    weight_estimators = weight_estimators,
    causal_estimator = msm_estimator
  )

# Reference values subtracted from the non-intercept estimates. The vector is
# matched by position, so it assumes each estimator's rows come in the order
# A, A_tilde, A:A_tilde.
oracle <- c(1, 0.5, -0.00)
purrr::map_dfr(
  yyy,
  function(fit) {
    est <- fit[["estimates"]]$result
    # Skip estimators that produced no estimates (weighting or outcome model
    # failed) instead of erroring on a NULL table.
    if (is.null(est)) {
      return(NULL)
    }
    est %>%
      dplyr::filter(term != "(Intercept)") %>%
      dplyr::mutate(bias = estimate - oracle)
  },
  .id = "estimator"
) %>%
  # Namespaced because dplyr is never attached in this script.
  dplyr::mutate(ab_bias = abs(bias)) %>%
  View()
# Build simulation basis datasets for the first `nn` ids, each with a buffer
# drawn at random from 5:15 (seeded for reproducibility).
set.seed(34)
nn <- 50
buffers <- sample(5:15, size = nn, replace = TRUE)

purrr::walk2(
  .x = ids[1:nn],
  .y = buffers,
  .f = function(id, buffer) {
    make_sim_basis_data(
      create_sim_basis, id = id, buffer = buffer,
      # Reports (and prints) whether an .rds file for this hash already exists
      # under data/sims.
      checker = function(hash) {
        already_saved <-
          file.exists(here::here("data/sims", sprintf("%s.rds", hash)))
        if (already_saved) {
          print(sprintf("%s exists", hash))
        }
        already_saved
      },
      # Writes the generated object to data/sims/<sha>.rds.
      redirector = function(out, sha) {
        saveRDS(out, file = here::here("data/sims", sprintf("%s.rds", sha)))
      }
    )
  }
)
  dir("data/sims", full.names = TRUE) %>% 
  purrr::map_dfr(
    .f = ~ {
      dt <- readRDS(.x)
      tibble(!!! dt[c("n", "buffer")])
    },
    .id = "simid"
  )
# For (up to) the first 25 stored simulation bases: derive covariates,
# subsample to at most 3000 rows, then simulate and estimate under every
# setting in sim_settings. The result is a list (per basis) of lists (per
# setting) of estimate_from_sim() outputs.
sim_res <-
  dir("data/sims", full.names = TRUE) %>%
  # head() copes with fewer than 25 files; `.[1:25]` would pad with NA paths.
  head(25) %>%
  purrr::map(
    .f = function(path) {
      dt <- readRDS(path)

      dt$data <- dt$data %>%
        dplyr::mutate(
          log_ha = log(ha),
          # Needed by pw_xgbinator_1, whose formula refers to slope_0.
          slope_0 = (slope == 0),
          p_wet_0 = (p_wet == 0),
          neighbor_p_wet_0 = (neighbor_p_wet == 0),
          log1p_coast_2500 = log1p(coast_2500)
        )

      n_rows <- nrow(dt$data)
      if (n_rows > 3000) {
        dt$data <- dt$data[sample(seq_len(n_rows), 3000, replace = FALSE), ]
      }

      purrr::imap(
        .x = sim_settings,
        .f = function(parms, setting) {
          print(sprintf("do sim setting %s", setting))
          make_sim(basedt = dt$data, simulator = simulator0, parms = parms) %>%
            estimate_from_sim(
              weight_estimators = weight_estimators,
              causal_estimator = msm_estimator
            )
        }
      )
    }
  )
# Flatten the nested simulation results (basis -> setting -> estimator) into
# one data frame of estimates, with id columns "sim", "setting" and
# "estimator".
res <- purrr::map_dfr(
  .x = sim_res,
  .f = function(per_sim) {
    purrr::map_dfr(
      .x = per_sim,
      .f = function(per_setting) {
        purrr::map_dfr(
          .x = per_setting,
          .f = function(per_estimator) {
            purrr::pluck(per_estimator, "estimates", "result")
          },
          .id = "estimator"
        )
      },
      .id = "setting"
    )
  },
  .id = "sim"
)

# Attach bias and absolute bias to every non-intercept estimate.
# The reference vectors are recycled by position within each setting, which
# assumes rows come in repeating (A, A_tilde, A:A_tilde) order.
# NOTE(review): confirm these reference values are on the same scale as the
# coefficients set by de_A / ie_A / de_ie_int_A.
# dplyr verbs are namespaced because dplyr is never attached in this script;
# a bare filter() would otherwise resolve to stats::filter().
res <-
  res %>%
  dplyr::filter(term != "(Intercept)") %>%
  dplyr::group_by(setting) %>%
  dplyr::mutate(
    bias = dplyr::case_when(
      setting %in% c("1", "2")  ~ estimate - c(0, 0, 0),
      setting %in% c("3", "4")  ~ estimate - c(0.003, 0, 0),
      setting %in% c("5", "6")  ~ estimate - c(0, 0.002, 0),
      setting %in% c("7", "8")  ~ estimate - c(0.003, 0.002, 0),
      setting %in% c("9", "10") ~ estimate - c(0.003, 0.002, -0.001)
    ),
    ab_bias = abs(bias)
  ) %>%
  # Drop the grouping so later steps do not inherit it by accident.
  dplyr::ungroup()


library(ggplot2)

# Mean absolute bias and RMSE per (setting, term, estimator), spread into one
# pair of columns per setting.
# dplyr verbs are namespaced because dplyr is never attached in this script.
res %>%
  dplyr::group_by(setting, term, estimator) %>%
  dplyr::summarise(
    mab = mean(abs(bias)),
    rmse = sqrt(mean(bias^2))
  ) %>%
  # summarise() leaves the result grouped; drop that before reshaping.
  dplyr::ungroup() %>%
  tidyr::pivot_wider(
    names_from = "setting",
    values_from = c("mab", "rmse")
  )

# Bias of the "A" term by setting, one panel per estimator.
# dplyr::filter is namespaced: dplyr is never attached in this script, so a
# bare filter() would resolve to stats::filter().
res %>%
  dplyr::filter(term == "A") %>%
  ggplot(aes(y = bias, x = setting)) +
    geom_hline(yintercept = 0) +
    geom_point() +
    facet_grid(
      . ~ estimator,
      scales = "free"
    )
  # dir("data/sims", full.names = TRUE) %>% 
  # # .[1:1] %>%
  # purrr::map(
  #   .f = ~ {
  #     dt <- readRDS(.x)
  #     dt$mean_A})


# bsaul/conservation_spillover documentation built on Feb. 26, 2021, 1:28 p.m.