# R/estimation.R

#' @include util.R structural_model.R
NULL

# Base class of the estimation hierarchy: holds the structural causal model
# that an estimand (or a collection of estimands) is attached to.
setClass(
  Class = "EstimationBase",
  slots = c("model" = "StructuralCausalModel")
)

# Replication estimand: extends the base class with a display `name`.
setClass(
  Class = "ReplicationEstimand",
  contains = "EstimationBase",
  slots = c(name = "character")
)

# Replication estimand over a pair of outcomes (`outcome1`, `outcome2`), with an
# optional condition (`cond`) and its printable form (`condition_string`).
setClass(
  Class = "ReplicationCorrelationEstimand",
  contains = "ReplicationEstimand",
  slots = c(
    outcome1 = "character",
    outcome2 = "character",
    cond = "ANY",
    condition_string = "character"
  )
)

# Generic estimand: a name, the outcome group it belongs to, and a printable
# description of its conditioning event.
setClass(
  Class = "Estimand",
  contains = "EstimationBase",
  slots = list(
    name = "character",
    outcome_group = "character",
    condition_string = "character"
  )
)

# Estimand defined from two other estimands, stored in `left` and `right`.
setClass(
  Class = "DiffEstimand",
  contains = "Estimand",
  slots = list(
    left = "Estimand",
    right = "Estimand"
  )
)

# Marker class for estimands of discrete variables; adds no slots of its own.
setClass(Class = "DiscreteEstimand", contains = "Estimand")

# Atomic discrete estimand: an outcome variable under an intervention (a list),
# optionally restricted by a condition `cond`.
setClass(
  Class = "AtomEstimand",
  contains = "DiscreteEstimand",
  slots = list(
    intervention = "list",
    outcome = "character",
    cond = "ANY"
  )
)

# Difference estimand for discrete variables; combines the discrete marker
# class with the left/right structure of DiffEstimand.
setClass(
  Class = "DiscreteDiffEstimand",
  contains = c(
    "DiscreteEstimand",
    "DiffEstimand"
  )
)

# Estimand of a discretized variable: carries the comparison `direction`
# (a factor) and the numeric `cutpoint`(s).
setClass(
  Class = "DiscretizedEstimand",
  contains = "DiscreteEstimand",
  slots = c(
    direction = "factor",
    cutpoint = "numeric"
  )
)

# Difference of two discretized estimands at a cutpoint.
setClass(
  Class = "DiscretizedDiffEstimand",
  contains = c(
    "DiscretizedEstimand",
    "DiscreteDiffEstimand"
  )
)

# Atomic estimand of a discretized variable (one cutpoint of a collection).
setClass(
  Class = "DiscretizedAtomEstimand",
  contains = c(
    "DiscretizedEstimand",
    "AtomEstimand"
  )
)

# Collection form of the discretized atom estimand; it is expanded into one
# DiscretizedAtomEstimand per cutpoint by get_component_estimands().
setClass(
  Class = "DiscretizedAtomEstimandCollection",
  contains = c(
    "DiscretizedEstimand",
    "AtomEstimand"
  )
)

# Collection form of the discretized difference estimand (left minus right
# collections of discretized atoms).
setClass(
  Class = "DiscretizedDiffEstimandCollection",
  contains = c(
    "DiscretizedEstimand",
    "DiscreteDiffEstimand"
  )
)

# Mean-type estimand built on a collection of discretized atoms (`group`).
setClass(
  Class = "DiscretizedMeanEstimand",
  contains = "Estimand",
  slots = c(group = "DiscretizedAtomEstimandCollection")
)

# Utility-type estimand built on a collection of discretized atoms (`group`).
setClass(
  Class = "DiscretizedUtilityEstimand",
  contains = "Estimand",
  slots = c(group = "DiscretizedAtomEstimandCollection")
)

# Difference between two discretized mean estimands.
setClass(Class = "DiscretizedMeanDiffEstimand", contains = "DiffEstimand")

# Difference between two discretized utility estimands.
setClass(Class = "DiscretizedUtilityDiffEstimand", contains = "DiffEstimand")

#' Collection of S4 estimands
#'
#' Container class that keeps a table of estimand objects together with the
#' information passed to the Stan model.
#'
#' @slot estimands \code{data.frame} used as internal storage for the estimands.
#' @slot est_stan_info \code{list} used in the Stan model.
#'
#' @export
setClass(
  Class = "EstimandCollection",
  contains = "EstimationBase",
  slots = list(
    estimands = "data.frame",
    est_stan_info = "list"
  )
)

# Results table class: a tibble ("tbl_df") subclass. The generator function
# returned by setClass() is kept under the same name.
EstimandResults <- setClass(Class = "EstimandResults", contains = "tbl_df")

# Initializer for mean-difference estimands.
# After the default initialization, the name is derived by joining the `name`
# slots of the objects passed through `...` with " - ", and the outcome group
# is copied from the first of them.
setMethod("initialize", "DiscretizedMeanDiffEstimand", function(.Object, ...) {
  .Object <- callNextMethod(.Object, ...)

  components <- list2(...)
  component_names <- map_chr(components, function(component) component@name)

  .Object@name <- str_c(component_names, collapse = " - ")
  .Object@outcome_group <- components[[1]]@outcome_group

  .Object
})

# Initializer for utility-difference estimands.
# After the default initialization, the name is derived by joining the `name`
# slots of the objects passed through `...` with " - ", and the outcome group
# is copied from the first of them.
setMethod("initialize", "DiscretizedUtilityDiffEstimand", function(.Object, ...) {
  .Object <- callNextMethod(.Object, ...)

  components <- list2(...)
  component_names <- map_chr(components, function(component) component@name)

  .Object@name <- str_c(component_names, collapse = " - ")
  .Object@outcome_group <- components[[1]]@outcome_group

  .Object
})

# Generic: compute an estimand's value from a known joint distribution.
# Dispatches on `est` only.
setGeneric(
  "calculate_from_known_dgp",
  def = function(est, joint_dist, ...) {
    standardGeneric("calculate_from_known_dgp")
  },
  signature = "est"
)

# Value of an atom estimand under a known joint distribution.
#
# est:        AtomEstimand (uses its `cond`, `intervention` and `outcome` slots).
# joint_dist: object with a `types_data` slot containing a nested `outcomes`
#             column plus `ex_prob` and the column named by `prob_var`.
# prob_var:   unquoted column holding the type probabilities (default `prob`).
#
# Returns the sum of the renormalised joint probability over the rows where the
# outcome equals 1 after applying the intervention.
setMethod("calculate_from_known_dgp", "AtomEstimand", function(est, joint_dist, prob_var = prob) {
  # Weight each row by condition * ex_prob * prob, then renormalise to sum to 1.
  joint_dist@types_data <- joint_dist@types_data %>%
    unnest(outcomes) %>%
    mutate(
      full_joint_prob = !!est@cond * ex_prob * {{ prob_var }},
      full_joint_prob = full_joint_prob / sum(full_joint_prob)
    )

  joint_dist <- set_obs_outcomes(joint_dist, !!!est@intervention)

  # Filter expression "<outcome> == 1", built from the outcome name.
  outcome_is_one <- as_quosure(
    parse_expr(str_c(est@outcome, " == 1")),
    env = global_env()
  )

  joint_dist@types_data %>%
    filter(!!outcome_is_one) %>%
    pull(full_joint_prob) %>%
    sum()
})

# Value of a difference estimand under a known joint distribution: the value of
# the left component minus the value of the right component.
setMethod("calculate_from_known_dgp", "DiscreteDiffEstimand", function(est, joint_dist, prob_var = prob) {
  left_value <- calculate_from_known_dgp(est@left, joint_dist, {{ prob_var }})
  right_value <- calculate_from_known_dgp(est@right, joint_dist, {{ prob_var }})

  left_value - right_value
})


# Values of all "atom" and "diff" estimands of a collection under a known joint
# distribution.
#
# est:        EstimandCollection; its `estimands` table must have `est_type`
#             and `est_obj` columns.
# joint_dist: joint distribution passed on to the per-estimand methods.
# as_df:      if TRUE (default) return the estimands table (without `est_obj`)
#             with the computed values in a `prob` column; if FALSE return the
#             values as a vector named by estimand.
#
# Fixes relative to the previous version:
#   * the `as_df = FALSE` branch read a non-existent `r_prob` column although
#     the values are written to `prob`;
#   * an unused `set_model` branch that referenced an undefined `model` object
#     was removed.
setMethod("calculate_from_known_dgp", "EstimandCollection", function(est, joint_dist, as_df = TRUE) {
  est_calculator <- function(joint_dist, prob_var, calculated_var) {
    est@estimands %>%
      filter(fct_match(est_type, c("atom", "diff"))) %>%
      mutate(
        {{ calculated_var }} := map_dbl(est_obj, calculate_from_known_dgp, joint_dist, prob_var = {{ prob_var }})
      ) %>%
      select(-est_obj)
  }

  calculated_est <- est_calculator(joint_dist, prob_var = prob, calculated_var = prob)

  if (!as_df) {
    # Estimand tables elsewhere in this file label rows with `estimand_name`;
    # fall back to `name` for tables that use the older column name.
    name_col <- if ("estimand_name" %in% names(calculated_est)) "estimand_name" else "name"

    calculated_est <- purrr::set_names(calculated_est$prob, calculated_est[[name_col]])
  }

  return(calculated_est)
})

# Generic: one-row description (name, outcome group, ...) of an estimand.
setGeneric(
  "get_discrete_estimand_info",
  def = function(est) standardGeneric("get_discrete_estimand_info")
)

# Description of a discrete estimand: its name and outcome group.
setMethod("get_discrete_estimand_info", "DiscreteEstimand", function(est) {
  tibble(
    estimand_name = est@name,
    outcome_group = est@outcome_group
  )
})

# Description of a discretized difference estimand: name, outcome group and
# the cutpoint it refers to.
setMethod("get_discrete_estimand_info", "DiscretizedDiffEstimand", function(est) {
  tibble(
    estimand_name = est@name,
    outcome_group = est@outcome_group,
    cutpoint = est@cutpoint
  )
})

# Description of a discretized atom estimand: name, outcome group and the
# cutpoint it refers to.
setMethod("get_discrete_estimand_info", "DiscretizedAtomEstimand", function(est) {
  tibble(
    estimand_name = est@name,
    outcome_group = est@outcome_group,
    cutpoint = est@cutpoint
  )
})

# Generic: pull estimand draws out of a fitted model. Dispatches on `est` only.
setGeneric(
  "extract_from_fit",
  def = function(est, fit, levels, unique_ids, between_entity_diff_info, no_sim_diag, ...) {
    standardGeneric("extract_from_fit")
  },
  signature = "est"
)

# Extract estimand draws from a fitted model and summarise them.
#
# est:        EstimandCollection; its `estimands` table is joined to the draws
#             by `estimand_id`.
# fit:        fitted model object supporting `as.array(fit, par = ...)`.
# levels:     optional names of columns in `unique_ids` that identify entities;
#             when supplied (with `unique_ids`) per-entity draws are extracted
#             from "iter_level_entity_estimand" and nested in `level_estimands`.
# unique_ids: table holding one column per level.
# between_entity_diff_info: optional table keyed by `diff_index` and
#             `estimand_id`; when supplied and non-NULL, draws from
#             "iter_between_level_entity_diff_estimand" are nested in
#             `between_entity_estimands`.
# no_sim_diag: forwarded to `diagnose()`.
# quants:     forwarded to `summ_iter_data()`.
#
# Returns a tibble of class "EstimandResults".
#
# Changes relative to the previous version: `between_entity_diff_info` may now
# be omitted (it was evaluated without a missing() guard, unlike `levels` and
# `unique_ids`), and the unreachable `return(NULL)` after `stop()` in the error
# handlers was removed.
setMethod("extract_from_fit", "EstimandCollection", function(est, fit, levels, unique_ids, between_entity_diff_info, no_sim_diag, quants, ...) {
  discrete_est_info <- est@estimands

  # Population-level draws: the third array dimension indexes the parameters,
  # whose names carry the estimand id.
  discrete_estimation <- fit %>%
    as.array(par = "iter_estimand") %>%
    plyr::adply(3, diagnose, no_sim_diag) %>%
    tidyr::extract(parameters, "estimand_id", "(\\d+)", convert = TRUE) %>%
    mutate(iter_data = map(iter_data, ~ tibble(iter_estimand = c(.), iter_id = seq(NROW(.) * NCOL(.))))) %>%
    full_join(discrete_est_info, ., by = c("estimand_id"))

  # Every estimand in the collection must have draws in the fit.
  stopifnot(!any(map_lgl(discrete_estimation$iter_data, is_empty)))

  if (!missing(levels) && !is_empty(levels) && !missing(unique_ids)) {
    discrete_level_estimation <- tryCatch(
      as.array(fit, par = "iter_level_entity_estimand"),
      error = function(err) {
        stop("Failed to find iter_level_entity_estimand parameter.")
      })

    # Stack the entities of all levels into one table and number them
    # consecutively (ordered by level, then entity) in `long_entity_index`.
    long_entity_ids <- map_dfr(
      levels,
      ~ unique_ids %>%
        select(all_of(.x)) %>%
        distinct() %>%
        rename("entity_name" = .x) %>%
        mutate_all(lst(entity_index = as.integer)) %>%
        mutate(
          entity_name = as.character(entity_name),
          level = .x),
      .id = "level_index") %>%
      arrange(level_index, entity_index) %>%
      mutate(long_entity_index = seq(n()))

    if (!is_null(discrete_level_estimation)) {
      # Parameter names carry "<estimand_id>,<long_entity_index>".
      discrete_estimation <- discrete_level_estimation %>%
        plyr::adply(3, diagnose, no_sim_diag) %>%
        tidyr::extract(parameters, c("estimand_id", "long_entity_index"), "(\\d+),(\\d+)", convert = TRUE) %>%
        mutate(iter_data = map(iter_data, ~ tibble(iter_estimand = c(.), iter_id = seq(NROW(.) * NCOL(.))))) %>%
        summ_iter_data(quants = quants) %>%
        left_join(long_entity_ids, by = c("long_entity_index")) %>%
        group_nest(estimand_id, .key = "level_estimands") %>%
        left_join(discrete_estimation, ., by = c("estimand_id"))
    }

    # Guard against the argument being omitted before inspecting its value.
    if (!missing(between_entity_diff_info) && !is_null(between_entity_diff_info)) {
      between_entity_diff_estimation <- tryCatch(
        as.array(fit, par = "iter_between_level_entity_diff_estimand"),
        error = function(err) {
          stop("Failed to find iter_between_level_entity_diff_estimand parameter.")
        })

      # Parameter names carry "<estimand_id>,<diff_index>".
      discrete_estimation <- between_entity_diff_estimation %>%
        plyr::adply(3, diagnose, no_sim_diag) %>%
        tidyr::extract(parameters, c("estimand_id", "diff_index"), "(\\d+),(\\d+)", convert = TRUE) %>%
        mutate(iter_data = map(iter_data, ~ tibble(iter_estimand = c(.), iter_id = seq(NROW(.) * NCOL(.))))) %>%
        summ_iter_data(quants = quants) %>%
        left_join(between_entity_diff_info, by = c("diff_index", "estimand_id")) %>%
        group_nest(estimand_id, .key = "between_entity_estimands") %>%
        left_join(discrete_estimation, ., by = "estimand_id")
    }
  }

  discrete_estimation %>%
    summ_iter_data(quants = quants) %>%
    new_tibble(nrow = nrow(.), class = "EstimandResults")
})

# Generic: render results as a LaTeX tabular environment.
setGeneric(
  "latex_tablular",
  def = function(est) {
    standardGeneric("latex_tablular")
  }
)

# Render an EstimandResults table as a LaTeX `tabular` string.
#
# The quantile columns are those named "per_0.<digits>"; each row of `est`
# becomes one table row "$<estimand_name>$ & q1 & q2 ... \\", with every second
# row prefixed by a grey \rowcolor. The header, the rows and the footer are
# concatenated without a separator and returned as a single string.
setMethod("latex_tablular", "EstimandResults", function(est) {
  # Quantile levels recovered from the column names (e.g. "per_0.25" -> 0.25).
  quants <- str_subset(names(est), "^per_0\\.\\d+$") %>%
    str_extract("0\\.\\d+") %>%
    as.numeric()
  num_quants <- length(quants)

  # One LaTeX row per estimand, joined by newlines.
  results_latex <- est %>%
    # mutate_at(vars(starts_with("per_0."), rhat, ess_bulk, ess_tail), ~ sprintf("%.2f", .)) %>%
    # unite("latex", starts_with("per_0."), rhat, ess_bulk, ess_tail, sep = " & ") %>%
    mutate_at(vars(starts_with("per_0.")), ~ sprintf("%.2f", .)) %>%
    unite("latex", starts_with("per_0."), sep = " & ") %>%
    mutate(
      row_color = if_else((row_number() %% 2) == 0, "\\rowcolor{gray!20}", ""),
      latex = str_c(row_color, "$", estimand_name, "$ & ", latex, "\\\\")
    ) %>%
    pull(latex) %>%
    str_c(collapse = "\n")

    # Disabled header variant that also had simulation-diagnostics columns:
    #                  & \\multicolumn{${num_quants}}{c}{Percentiles} & \\multicolumn{3}{c}{Simulation Diagnostics} \\\\
    #                    \\cmidrule(l){2-${num_quants + 1}} \\cmidrule(l){${num_quants + 2}-${num_quants + 4}}
    # Estimand         & ${str_c(quants * 100, collapse = '\\\\% &')}\\% & $\\widehat{R}$ & Bulk ESS & Tail ESS \\\\

  # Header template (interpolated with num_quants/quants), body rows, footer.
  str_c(
    str_interp(

    "\\begin{tabular}{l*{${num_quants}}{c}}
     \\toprule
                     & \\multicolumn{${num_quants}}{c}{Percentiles} \\\\
                       \\cmidrule(l){2-${num_quants + 1}}
    Estimand         & ${str_c(quants * 100, collapse = '\\\\% &')}\\% \\\\
    \\midrule"
  ),

  results_latex,

  "\\bottomrule
   \\end{tabular}"

 )

  # Earlier cat()-based header output, kept for reference:
  # cat(str_interp("& \\multicolumn{${length(quants)}}{c}{Percentiles} \\\\\n"))
  # cat(str_interp("\\cmidrule(l){2-${length(quants)}}\n"))
  # cat(str_c("Treatment Effect & ", str_c(quants * 100, collapse = "\\% &"), " \\% \\\\\n"))
  # cat("\\midrule\n")
})

# Plot posterior summaries of estimands as horizontal interval plots.
#
# x:              EstimandResults table (needs `estimand_id`, `estimand_name`
#                 and the per_0.05 ... per_0.95 quantile columns).
# y:              name of the column mapped to the y axis (string, passed to
#                 aes_string()).
# estimands:      optional character vector of estimand names to keep; if it is
#                 named, the names are prepended to the labels as friendly names.
# levels:         optional levels; when given, the nested `level_estimands`
#                 are plotted instead of the population-level summaries.
# prior_results:  optional results table plotted alongside `x` as "prior".
# zero_line:      add a dotted vertical line at 0.
# wrap_after_friendly_name: separate friendly name and estimand name with a
#                 newline instead of ", ".
#
# Returns a ggplot object.
setMethod("plot", c("EstimandResults"), function(x,
                                                 y = "estimand_id",
                                                 estimands = NULL,
                                                 levels = NULL,
                                                 prior_results = NULL,
                                                 zero_line = TRUE,
                                                 wrap_after_friendly_name = FALSE) {
  # Hard-coded: the "density" branch below is not reachable in this version.
  plot_type <- "linerange"

  estimand_friendly_renamer <- identity

  # Stack fit and prior results, tagging each row with its source in `run_type`.
  if (!is_null(prior_results)) {
    x %<>% bind_rows(fit = ., prior = prior_results, .id = "run_type")
  }

  if (!is_null(estimands)) {
    x %<>% filter(estimand_name %in% estimands)

    if (is_named(estimands)) {
      # Recode each estimand label to "<friendly name><sep><estimand name>".
      estimand_friendly_renamer <- function(est_names) {
        rename_sep <- if (wrap_after_friendly_name) "\n" else ", "

        rename_list <- estimands %>%
          purrr::set_names(str_c(names(.), ., sep = rename_sep))

        exec(fct_recode, est_names, !!!rename_list)
      }
    }
  }

  ordered_estimand_labels <- if (is_null(estimands)) unique(x$estimand_name) else estimands

  # Without levels the label order is reversed before releveling.
  if (is_null(levels)) {
    ordered_estimand_labels %<>% rev()
  }

  # Turn estimand_id into a factor labelled by estimand name, order it, apply
  # the friendly names and escape square brackets in the labels.
  x %<>%
    mutate(estimand_id = factor(estimand_id,
                                levels = estimand_id,
                                labels = estimand_name) %>%
             exec(fct_relevel, ., !!!ordered_estimand_labels) %>%
             estimand_friendly_renamer() %>%
             fct_relabel(. %>% str_replace_all(c("\\[" = "\\\\[", "\\]" = "\\\\]"))))

  if (!is_null(levels)) {
    # Drop the population-level summary columns so the unnested per-entity
    # columns of the same names can take their place.
    x %<>%
      select(-one_of("iter_data", "ess_bulk", "ess_tail", "rhat"), -starts_with("per_0."), -mean) %>% # Not using population level iter_data
      unnest(level_estimands) %>%
      filter(fct_match(level, levels))
  }

  plot_obj <- if (plot_type == "density") {
    if (!is_null(levels)) stop("Density plot for levels not yet supported.")

    x %>%
      unnest(iter_data) %>%
      ggplot(aes(x = iter_estimand, y = estimand_id)) +
      ggridges::geom_density_ridges(quantile_lines = TRUE, quantiles = c(0.1, 0.25, 0.5, 0.75, 0.9)) +
      labs(x = "") +
      scale_y_discrete("", labels = latex2exp::TeX) +
      theme_minimal() +
      theme(axis.text.y = element_text(size = 14), strip.text = element_blank()) +
      NULL
  } else {
    # Dodge fit vs. prior vertically only when prior results are shown.
    plot_position <- ggstance::position_dodgev(height = if (!is_null(prior_results)) 0.75 else 0)

    x %>%
      ggplot(aes_string(y = y,
                 group = if (!is_null(prior_results)) "run_type" else if (!is_null(levels)) "entity_name" else NA,
                 color = if (!is_null(prior_results)) "run_type" else NULL)) +
      ggstance::geom_crossbarh(aes(x = per_0.5, xmin = per_0.25, xmax = per_0.75), width = 0.25, position = plot_position) +
      ggstance::geom_crossbarh(aes(x = per_0.5, xmin = per_0.1, xmax = per_0.9), width = 0.4, position = plot_position) +
      ggstance::geom_linerangeh(aes(xmin = per_0.05, xmax = per_0.95), fatten = 3, position = plot_position) +
      scale_y_discrete("", labels = latex2exp::TeX) +
      labs(
        x = "",
        caption = "Line range: 90% credible interval. Outer box: 80% credible interval. Inner box: 50% credible interval.
                   Thick vertical line: median.") +
      theme_minimal() +
      theme(axis.text.y = element_text(size = 12), strip.text = element_text(size = 12)) +
      NULL
  }

  if (!is_null(prior_results)) {
    plot_obj <- plot_obj +
      scale_color_manual("", values = c("fit" = "black", "prior"= "grey"), labels = c("fit" = "Fit", "prior"= "Prior Prediction")) +
      theme(legend.position = "bottom")
  }

  if (zero_line) {
    plot_obj <- plot_obj + geom_vline(xintercept = 0, linetype = "dotted")
  }

  return(plot_obj)
})

# Generic: step plot of cumulative probabilities for discretized estimands.
setGeneric(
  "plot_cdf",
  def = function(data, ..., just_data = FALSE, per = seq(0.0, 0.8, 0.2)) standardGeneric("plot_cdf")
)

# Step plot of the posterior median cumulative probability by percentile.
#
# data:      EstimandResults table with `estimand_name`, `estimand_id`,
#            `cutpoint` and `per_0.5` columns.
# ...:       estimand names; each argument is one group of names that is
#            prepared together.
# just_data: if TRUE, return the prepared data instead of the plot.
# per:       percentile values assigned to the rows of each group, each value
#            repeated twice.
#            NOTE(review): this presumes every group has exactly
#            2 * length(per) rows after the Inf-cutpoint rows are appended --
#            confirm against callers.
#
# Returns a ggplot object, or the prepared data when `just_data` is TRUE.
setMethod("plot_cdf", "EstimandResults", function(data, ..., just_data = FALSE, per = seq(0.0, 0.8, 0.2)) {
  estimands <- list2(...)
  facet <- FALSE

  # Rows of the requested estimands ordered by cutpoint, plus a copy of each
  # estimand's highest-id row with cutpoint = Inf, tagged with percentiles.
  prep_est_data <- function(curr_estimand_name, data) {
    data %>%
      filter(estimand_name %in% curr_estimand_name) %>%
      arrange(cutpoint) %>%
      bind_rows(
        group_by(., estimand_name) %>%
          filter(estimand_id == max(estimand_id)) %>%
          mutate(cutpoint = Inf) %>%
          ungroup()
      ) %>%
      mutate(percentile = rep(per, each = 2))
  }

  # NOTE(review): `estimands` comes from list2(), so is.list() looks always
  # TRUE here and the else branch appears unreachable -- confirm.
  if (is.list(estimands)) {
    facet <- TRUE
  } else {
    estimands %<>% list()
  }

  # Because `.` is passed explicitly as `data`, this evaluates
  # map(estimands, prep_est_data, data = data). The estimand name is then split
  # into outcome / intervention / condition parts to build display labels.
  data %<>%
    map(estimands, prep_est_data, data = .) %>%
    bind_rows(.id = "est_group") %>%
    tidyr::extract(estimand_name, into = c("outcome", "intervention_name", "intervention_value", "condition"),
                   "Y\\^\\{(\\w+)\\}_\\{(\\w)=(\\d+)\\}(?:.+\\|\\s([^\\]]+))?", remove = FALSE) %>%
    mutate(
      condition = if_else(!is.na(condition) > 0, str_c("|", condition), ""),
      intervention = str_glue("E\\[Y^{{{outcome}}}_{{{intervention_name}={intervention_value}}} < c{condition}\\]"),
      est_group = str_glue("Pr\\[Y^{{{outcome}}}_{{{intervention_name}}} < c{condition}\\]")
  )

  if (just_data) {
     return(data)
  }

  # `data = . %>% ...` below is a magrittr functional sequence: ggplot2 applies
  # it to the plot data to build the layer's own data.
  plot_obj <- ggplot(data, aes(percentile, per_0.5)) +
    geom_step(aes(color = intervention_value)) +
    geom_area(aes(ymax = per_0.5, fill = intervention_value), alpha = 0.125, position = position_identity(),
              data = . %>%
                group_by(est_group, intervention_value) %>%
                bind_rows(mutate(., cutpoint = dplyr::lead(cutpoint), percentile = dplyr::lead(percentile))) %>%
                arrange(estimand_id) %>%
                ungroup()
    ) +
    scale_y_continuous("Cumulative Probability") +
    scale_x_continuous(latex2exp::TeX("Percentiles of observed $Y$"), breaks = data$percentile, labels = . %>% multiply_by(100)) +
    scale_color_brewer("Intervention", palette = "Dark2", labels = latex2exp::TeX) +
    scale_color_brewer("Intervention", palette = "Dark2", labels = latex2exp::TeX, aesthetics = "fill") +
    theme(legend.position = "right", panel.grid.minor.x = element_blank()) +
    NULL

  if (facet) {
    plot_obj <- plot_obj +
      facet_wrap(vars(est_group), ncol = 3, labeller = as_labeller(latex2exp::TeX, default = label_parsed))
  }

  return(plot_obj)
})

# Generic: step plot of differences in cumulative probabilities.
setGeneric(
  "plot_cdf_diff",
  def = function(data, ..., just_data = FALSE, per = seq(0.0, 0.8, 0.2)) standardGeneric("plot_cdf_diff")
)

# Step plot of the posterior median difference in cumulative probability, with
# ribbons for the 50%, 80% and 90% credible intervals.
#
# data:      EstimandResults table; must already contain `percentile`, `year`,
#            `cutpoint`, `estimand_id`, `estimand_name` and the per_0.* columns
#            (this method does not assign `percentile` itself).
# ...:       names of the estimands to plot; one facet per estimand when more
#            than one is given.
# just_data: if TRUE, return the prepared data instead of the plot.
#
# Returns a ggplot object, or the prepared data when `just_data` is TRUE.
#
# Previous signature, which also took `per`:
# setMethod("plot_cdf_diff", "EstimandResults", function(data, ..., just_data = FALSE, per = seq(0.0, 0.8, 0.2)) {
setMethod("plot_cdf_diff", "EstimandResults", function(data, ..., just_data = FALSE) {
  estimands <- list2(...)

  # For each requested estimand: order by year and cutpoint and, within each
  # year, append a copy of the highest-id row with cutpoint = Inf and a new id.
  data <- map(estimands, function(curr_estimand_name, data) {
    data %>%
      filter(estimand_name == curr_estimand_name) %>%
      arrange(year, cutpoint) %>%
      group_by(year) %>%
      bind_rows(
        filter(., estimand_id == max(estimand_id)) %>%
          mutate(estimand_id = max(estimand_id) + 1, cutpoint = Inf)
      ) %>%
      ungroup()
      # mutate(percentile = per)
  },
  data = data) %>%
    purrr::set_names(estimands) %>%
    bind_rows(.id = "est_group") %>%
    # Rewrite the group label: leading "(E"/"E" -> "Pr", escape square
    # brackets, drop a trailing ")".
    mutate(est_group = str_replace_all(est_group, c("\\(?E" = "Pr", "([\\[\\]])" = "\\\\\\1", "\\)$" = "")))

  if (just_data) return(data)

  # Inside the braces `.` is the prepared data; `. %>% ...` expressions are
  # magrittr functional sequences that ggplot2 applies to the plot data.
  plot_obj <- data %>% {
    ggplot(., aes(percentile, per_0.5)) +
      geom_hline(yintercept = 0, linetype = "dotted") +
      geom_step(direction = "hv") +
      # 50% credible interval.
      geom_ribbon(aes(percentile, ymin = per_0.25, ymax = per_0.75), alpha = 0.25,
                  data = . %>%
                    group_by(est_group) %>%
                    bind_rows(mutate(., cutpoint = dplyr::lead(cutpoint), percentile = dplyr::lead(percentile))) %>%
                    arrange(estimand_id) %>%
                    ungroup()
      ) +
      # 80% credible interval.
      geom_ribbon(aes(percentile, ymin = per_0.1, ymax = per_0.9), alpha = 0.25,
                  data = . %>%
                    group_by(est_group) %>%
                    bind_rows(mutate(., cutpoint = dplyr::lead(cutpoint), percentile = dplyr::lead(percentile))) %>%
                    arrange(estimand_id) %>%
                    ungroup()
      ) +
      # 90% credible interval.
      geom_ribbon(aes(percentile, ymin = per_0.05, ymax = per_0.95), alpha = 0.25,
                  data = . %>%
                    group_by(est_group) %>%
                    bind_rows(mutate(., cutpoint = dplyr::lead(cutpoint), percentile = dplyr::lead(percentile))) %>%
                    arrange(estimand_id) %>%
                    ungroup()
      ) +
      scale_y_continuous("Cumulative Probability Difference") +
      scale_x_continuous(latex2exp::TeX("Percentiles of observed $Y$"), breaks = .$percentile, labels = . %>% multiply_by(100)) +
      labs(caption = "The black line shows the posterior median while the grey ribbons represent the 50, 80, and 90% credible intervals.") +
      theme(legend.position = "top", panel.grid.minor.x = element_blank()) +
      NULL
  }

  if (length(estimands) > 1) {
    plot_obj <- plot_obj +
      facet_wrap(vars(est_group), ncol = 2, labeller = as_labeller(latex2exp::TeX, default = label_parsed), scales = "free")
  }

  return(plot_obj)
})

# Generic replacement function: `obs_outcomes_setter(est) <- value`.
# Dispatches on `est` only.
setGeneric(
  "obs_outcomes_setter<-",
  def = function(est, value) {
    standardGeneric("obs_outcomes_setter<-")
  },
  signature = "est"
)

# Store `value` in the object's `obs_outcomes_setter` slot and return the
# modified object.
# NOTE(review): the EstimationBase class declared in this file has only a
# `model` slot, so this assignment looks like it would fail unless the slot is
# declared elsewhere -- confirm.
setMethod("obs_outcomes_setter<-", "EstimationBase", function(est, value) {
  est@obs_outcomes_setter <- value
  est
})

# Set the observed-outcomes setter on a difference estimand and propagate the
# same value to its left and right components.
setMethod("obs_outcomes_setter<-", "DiscreteDiffEstimand", function(est, value) {
  est <- callNextMethod()

  est@left <- `obs_outcomes_setter<-`(est@left, value = value)
  est@right <- `obs_outcomes_setter<-`(est@right, value = value)

  est
})

# Generic: range of values an estimand can take.
setGeneric("get_range", def = function(est) standardGeneric("get_range"))

# Default range for an estimand: the unit interval.
setMethod("get_range", "Estimand", function(est) {
  c(0, 1)
})

# Range for a discretized estimand: the first and last cutpoint of the model's
# first discretized type.
setMethod("get_range", "DiscretizedEstimand", function(est) {
  cutpoints <- est@model@discretized_types[[1]]@cutpoints

  c(first(cutpoints), last(cutpoints))
})

# Generic: attach a model to an estimand. Dispatches on `est` only.
setGeneric(
  "set_model",
  def = function(est, model) {
    standardGeneric("set_model")
  },
  signature = "est"
)

# Store `model` in the object's `model` slot and return the modified object.
setMethod("set_model", "EstimationBase", function(est, model) {
  est@model <- model
  est
})

# Attach `model` to a difference estimand and to both of its components.
setMethod("set_model", "DiffEstimand", function(est, model) {
  est <- callNextMethod()

  est@left <- set_model(est@left, model)
  est@right <- set_model(est@right, model)

  est
})

# Attach `model` to a discretized atom collection and fill in everything that
# depends on the model: cutpoints, direction, outcome names (the names of the
# cutpoints) and the display name "Pr[Y^{group}_{interventions} <dir> c<cond>]".
setMethod("set_model", "DiscretizedAtomEstimandCollection", function(est, model) {
  est <- callNextMethod()

  response_type <- model@discretized_responses[[est@outcome_group]]
  est@cutpoint <- get_discretized_cutpoints(response_type)
  est@direction <- response_type@direction
  est@outcome <- names(est@cutpoint)

  # "name=value" pairs of the intervention, comma separated.
  intervention_string <- str_c(
    imap_chr(est@intervention, ~ str_interp("${.y}=${.x}")),
    collapse = ","
  )

  est@name <- str_interp("Pr[Y^{${est@outcome_group}}_{${intervention_string}} ${est@direction} c${est@condition_string}]")

  est
})

# Generic: expand an estimand into a table of its component estimands, with
# ids starting at `next_estimand_id` / `next_estimand_group_id`.
setGeneric(
  "get_component_estimands",
  def = function(est, next_estimand_id, next_estimand_group_id) {
    standardGeneric("get_component_estimands")
  }
)

# A plain estimand is its own single component: one row with its description,
# the object itself, the given id and no group id.
setMethod("get_component_estimands", "Estimand", function(est, next_estimand_id, next_estimand_group_id) {
  estimand_info <- get_discrete_estimand_info(est)

  mutate(
    estimand_info,
    est_obj = list(est),
    estimand_id = next_estimand_id,
    estimand_group_id = NA_integer_
  )
})

# Components of a discrete difference estimand: the left components, the right
# components and the difference itself. Ids are assigned left first, then
# right, then the difference, which records the ids of its two sides.
# NOTE(review): this method takes no `next_estimand_group_id`; confirm how
# callers that pass three arguments are expected to dispatch.
setMethod("get_component_estimands", "DiscreteDiffEstimand", function(est, next_estimand_id) {
  left_est_data <- get_component_estimands(est@left, next_estimand_id)
  next_estimand_id <- max(pull(left_est_data, estimand_id)) + 1

  right_est_data <- get_component_estimands(est@right, next_estimand_id)
  next_estimand_id <- max(pull(right_est_data, estimand_id)) + 1

  diff_est_data <- mutate(
    get_discrete_estimand_info(est),
    est_obj = list(est),
    estimand_id = next_estimand_id,
    estimand_group_id = NA_integer_,
    estimand_id_left = left_est_data$estimand_id,
    estimand_id_right = right_est_data$estimand_id
  )

  bind_rows(diff_est_data, left_est_data, right_est_data)
})

# Components of a discretized atom collection:
#   * one DiscretizedAtomEstimand per cutpoint (consecutive ids, all sharing
#     `next_estimand_group_id` as group id),
#   * one DiscretizedMeanEstimand ("E[...]"),
#   * one DiscretizedUtilityEstimand ("EU[...]"),
# the last two referring back to the group through `mean_estimand_group_id`.
setMethod("get_component_estimands", "DiscretizedAtomEstimandCollection", function(est, next_estimand_id, next_estimand_group_id) {
  # Build the atom for a single cutpoint; `outcome` is the cutpoint's name.
  build_atom <- function(cutpoint, outcome) {
    new(
      "DiscretizedAtomEstimand",
      name = est@name,
      outcome_group = est@outcome_group,
      cutpoint = cutpoint,
      direction = est@direction,
      intervention = est@intervention,
      outcome = outcome,
      cond = est@cond
    )
  }

  atom_objs <- map(imap(est@cutpoint, build_atom), set_model, est@model)

  discretized_atoms <- atom_objs %>%
    map_df(function(atom_obj) mutate(get_discrete_estimand_info(atom_obj), est_obj = list(atom_obj))) %>%
    mutate(
      estimand_id = seq(next_estimand_id, next_estimand_id + n() - 1),
      estimand_group_id = next_estimand_group_id
    )

  next_estimand_id <- max(pull(discretized_atoms, estimand_id)) + 1

  # Collection name without the "< c" / "> c" comparison part.
  base_name <- str_remove(est@name, "\\s*[<>]\\s*c\\s*")

  atom_mean_est <- tibble(
    estimand_name = str_replace(base_name, "Pr\\[", "E["),
    outcome_group = est@outcome_group,
    estimand_id = next_estimand_id,
    mean_estimand_group_id = next_estimand_group_id,
    est_obj = list(new(
      "DiscretizedMeanEstimand",
      name = estimand_name,
      outcome_group = outcome_group,
      group = est
    ))
  )

  next_estimand_id <- max(pull(atom_mean_est, estimand_id)) + 1

  utility_est <- tibble(
    estimand_name = str_replace(base_name, "Pr\\[", "EU["),
    outcome_group = est@outcome_group,
    estimand_id = next_estimand_id,
    mean_estimand_group_id = next_estimand_group_id,
    est_obj = list(new(
      "DiscretizedUtilityEstimand",
      name = estimand_name,
      outcome_group = outcome_group,
      group = est
    ))
  )

  bind_rows(discretized_atoms, atom_mean_est, utility_est)
})

# Components of a discretized difference collection.
#
# Expands the left and right sides first, then adds:
#   * one mean-difference estimand (left mean minus right mean),
#   * one utility-difference estimand,
#   * one DiscretizedDiffEstimand per cutpoint shared by both sides.
# Ids are handed out sequentially in that order; every difference row records
# the ids of its two sides in `estimand_id_left` / `estimand_id_right`.
# Returns the per-cutpoint differences bound with the left, right,
# mean-difference and utility-difference rows.
setMethod("get_component_estimands", "DiscretizedDiffEstimandCollection", function(est, next_estimand_id, next_estimand_group_id) {
  # Expand the left side and advance both counters past the ids it used.
  left_est_data <- est@left %>% get_component_estimands(next_estimand_id, next_estimand_group_id)
  next_estimand_id <- left_est_data %>% pull(estimand_id) %>% max() %>% add(1)
  next_estimand_group_id <- left_est_data %>% pull(estimand_group_id) %>% max(na.rm = TRUE) %>% add(1)

  # Same for the right side.
  right_est_data <- est@right %>% get_component_estimands(next_estimand_id, next_estimand_group_id)
  next_estimand_id <- right_est_data %>% pull(estimand_id) %>% max() %>% add(1)
  next_estimand_group_id <- right_est_data %>% pull(estimand_group_id) %>% max(na.rm = TRUE) %>% add(1)

  # Pick the mean estimand row from each side (left first, then right) and
  # build a single-row table for their difference.
  mean_diff_data <- list(left_est_data, right_est_data) %>%
    map(filter, map_lgl(est_obj, is, "DiscretizedMeanEstimand")) %>%
    map_df(select, estimand_id, est_obj) %>% {
      left_right_list <- purrr::set_names(.$est_obj, c("left", "right"))
      mean_diff_est_obj <- exec(new, "DiscretizedMeanDiffEstimand", !!!left_right_list)

      # Widen the two ids into estimand_id_left / estimand_id_right.
      transmute(., estimand_id, name = c("left", "right")) %>%
        pivot_wider(values_from = estimand_id, names_prefix = "estimand_id_") %>%
        mutate(est_obj = list(mean_diff_est_obj),
               estimand_id = next_estimand_id,
               outcome_group = mean_diff_est_obj@outcome_group,
               estimand_name = mean_diff_est_obj@name)
    }

  next_estimand_id <- mean_diff_data %>% pull(estimand_id) %>% max() %>% add(1)

  # Same construction for the utility estimands.
  utility_diff_data <- list(left_est_data, right_est_data) %>%
    map(filter, map_lgl(est_obj, is, "DiscretizedUtilityEstimand")) %>%
    map_df(select, estimand_id, est_obj) %>% {
      left_right_list <- purrr::set_names(.$est_obj, c("left", "right"))
      utility_diff_est_obj <- exec(new, "DiscretizedUtilityDiffEstimand", !!!left_right_list)

      transmute(., estimand_id, name = c("left", "right")) %>%
        pivot_wider(values_from = estimand_id, names_prefix = "estimand_id_") %>%
        mutate(est_obj = list(utility_diff_est_obj),
               estimand_id = next_estimand_id,
               outcome_group = utility_diff_est_obj@outcome_group,
               estimand_name = utility_diff_est_obj@name)
    }

  next_estimand_id <- utility_diff_data %>% pull(estimand_id) %>% max() %>% add(1)

  # Pair left and right atoms by cutpoint (rows without a cutpoint are the
  # mean/utility rows and are excluded) and build one difference per pair.
  inner_join(
    select(left_est_data, est_obj, cutpoint, estimand_id) %>% filter(!is.na(cutpoint)),
    select(right_est_data, est_obj, cutpoint, estimand_id) %>% filter(!is.na(cutpoint)),
    by = c("cutpoint"), suffix = c("_left", "_right")
  ) %>%
    mutate(
      estimand_id = seq(next_estimand_id, next_estimand_id + n() - 1),
      estimand_group_id = next_estimand_group_id,
      est_obj = pmap(lst(est_obj_left, est_obj_right, cutpoint), function(est_obj_left, est_obj_right, cutpoint) {
        new(
          "DiscretizedDiffEstimand",
          name = str_c(est_obj_left@name, " - ", est_obj_right@name),
          outcome_group = est_obj_left@outcome_group,
          cutpoint = cutpoint,
          left = est_obj_left,
          right = est_obj_right
        )
      }),

      outcome_group = map_chr(est_obj, ~ .@outcome_group),
      estimand_name = map_chr(est_obj, ~ .@name),
    ) %>%
    # Drop the joined component objects; their ids stay in estimand_id_left/right.
    select(-matches("(est_obj)_(left|right)$")) %>%
    bind_rows(left_est_data, right_est_data, mean_diff_data, utility_diff_data)
})

# Generic: number of estimands, optionally restricted to given estimand types.
setGeneric(
  "num_estimands",
  def = function(est, est_class = "ANY") {
    standardGeneric("num_estimands")
  }
)

# Number of estimands in the collection. With `est_class` all NA (the default)
# every row is counted; otherwise only rows whose `est_type` matches
# `est_class`.
setMethod("num_estimands", "EstimandCollection", function(est, est_class = NA) {
  if (all(is.na(est_class))) {
    return(nrow(est@estimands))
  }

  nrow(filter(est@estimands, fct_match(est_type, est_class)))
})

# Generic: build the data structures passed to Stan. Dispatches on `est` only.
setGeneric(
  "get_stan_data_structures",
  def = function(est, ...) {
    standardGeneric("get_stan_data_structures")
  },
  signature = "est"
)

# Stan index structures for an atom estimand.
#
# Returns a list with:
#   abducted_prob_size  - number of rows selected by the condition, or 0 when
#                         the condition selects every row;
#   abducted_prob_index - positions of the selected rows (NULL when size is 0);
#   est_prob_index      - probability indices of the outcome under the
#                         intervention, restricted to the selected rows when
#                         the condition is binding;
#   est_prob_size       - length of est_prob_index, as a 1-d array.
setMethod("get_stan_data_structures", "AtomEstimand", function(est) {
  types_data <- est@model@types_data

  # Every column except the nested `outcomes` identifies a type.
  mask_group_by <- setdiff(names(types_data), "outcomes")

  # A type is kept if the condition holds for any of its outcome rows.
  abducted_mask <- types_data %>%
    unnest(outcomes) %>%
    mutate(abducted_mask = !!est@cond) %>%
    group_by_at(vars(all_of(mask_group_by))) %>%
    mutate(abducted_type_mask = any(abducted_mask)) %>%
    pull(abducted_type_mask)

  est_prob_index <- get_prob_indices(est@model, est@outcome, !!!est@intervention)

  # lst() evaluates its entries in order, so later entries can use earlier ones.
  lst(
    abducted_prob_size = if (all(abducted_mask)) 0 else sum(abducted_mask),
    abducted_prob_index = if (abducted_prob_size > 0) which(abducted_mask), # Row major index

    est_prob_index = if (abducted_prob_size > 0) intersect(est_prob_index, abducted_prob_index) else est_prob_index,
    est_prob_size = as.array(length(est_prob_index))
  )
})

# Estimands without a more specific method contribute nothing to the Stan data.
setMethod("get_stan_data_structures", "Estimand", function(est) {
  NULL
})

# Stan data structures for a whole collection: compute them per estimand
# (in parallel through pbmcapply when `cores > 1`) and merge the resulting
# lists element-wise with list_merge().
setMethod("get_stan_data_structures", "EstimandCollection", function(est, cores = 1) {
  if (cores > 1) {
    map_fun <- partial(pbmcapply::pbmclapply, ignore.interactive = TRUE, mc.silent = TRUE, mc.cores = cores)
  } else {
    map_fun <- map
  }

  per_estimand <- map_fun(est@estimands$est_obj, get_stan_data_structures)

  reduce(per_estimand, function(accum, to_add) list_merge(accum, !!!to_add))
})

# Generic: accessor for the Stan information stored with an object.
setGeneric(
  "get_stan_info",
  def = function(est) {
    standardGeneric("get_stan_info")
  }
)

# Return the collection's `est_stan_info` slot.
setMethod("get_stan_info", "EstimandCollection", function(est) {
  est@est_stan_info
})

# Generic: add between-entity difference estimands for the given levels.
setGeneric(
  "add_between_level_entity_diff_estimands",
  def = function(est, levels, analysis_data) standardGeneric("add_between_level_entity_diff_estimands")
)

# Placeholder implementation: intentionally does nothing and returns NULL.
# NOTE(review): callers that assign the result back to the collection would
# receive NULL rather than `est` -- confirm this is intended.
setMethod("add_between_level_entity_diff_estimands", "EstimandCollection", function(est, levels, analysis_data) {
  NULL
})

#' Create an S4 estimand for a discrete variable
#'
#' The estimand is named \code{E[OUTCOME_{interventions} | condition]}, where
#' the condition part comes from \code{cond_desc} if supplied and is otherwise
#' derived from the expression given as \code{cond}.
#'
#' @param outcome Name of the outcome variable.
#' @param ... Named intervention values (\code{variable = value}).
#' @param cond Optional conditioning expression; captured unevaluated.
#' @param cond_desc Optional description of the condition, used in the name
#'   instead of the text of \code{cond}.
#'
#' @return \code{AtomEstimand} S4 object
#' @export
build_atom_estimand <- function(outcome, ..., cond, cond_desc) {
  intervention <- list2(...)

  # "_{a=1,b=0}" style subscript, or nothing when there is no intervention.
  intervention_string <- ""
  if (!is_empty(intervention)) {
    intervention_string <- intervention %>%
      imap_chr(~ str_interp("${.y}=${.x}")) %>%
      str_c(collapse = ",") %>%
      str_c("_{", ., "}")
  }

  if (!missing(cond_desc)) {
    condition_string <- str_c(" | ", cond_desc)
  } else if (!missing(cond)) {
    # Upper-case each "&"-separated clause of the condition and show "==" as "=".
    condition_string <- as_label(enquo(cond)) %>%
      str_split("\\s*&\\s*", simplify = TRUE) %>%
      c() %>%
      map_chr(str_to_upper) %>%
      map(str_replace_all, fixed("=="), "=") %>%
      str_c(., collapse = ", ") %>%
      str_c(" | ", .)
  } else {
    condition_string <- ""
  }

  new(
    "AtomEstimand",
    name = str_interp("E[${str_to_upper(outcome)}${intervention_string}${condition_string}]"),
    outcome_group = outcome,
    intervention = intervention,
    outcome = outcome,
    cond = if (!missing(cond)) enquo(cond) else TRUE,
    condition_string = condition_string
  )
}

# Create a ReplicationCorrelationEstimand for a pair of outcomes.
#
# outcome1, outcome2: names of the two outcome variables.
# cond:               optional condition as a string (NA for none); it is
#                     upper-cased for the name and stored as given.
#
# The estimand is named "SampleCor[OUTCOME1, OUTCOME2 | COND]".
build_replication_correlation_estimand <- function(outcome1, outcome2, cond = NA_character_) {
  if (!is.na(cond)) {
    condition_string <- str_c(" | ", str_to_upper(cond))
  } else {
    condition_string <- ""
  }

  new(
    "ReplicationCorrelationEstimand",
    name = str_interp("SampleCor[${str_to_upper(outcome1)}, ${str_to_upper(outcome2)}${condition_string}]"),
    outcome1 = outcome1,
    outcome2 = outcome2,
    cond = cond,
    condition_string = condition_string
  )
}

#' Create an S4 difference estimand for discrete variables
#'
#' Both estimands must belong to the same outcome group. The result is named
#' \code{"<left name> - <right name>"}.
#'
#' @param left Left estimand, to difference from.
#' @param right Right estimand, to difference out.
#'
#' @return \code{DiscreteDiffEstimand} S4 object
#' @export
build_diff_estimand <- function(left, right) {
  diff_outcome_group <- union(left@outcome_group, right@outcome_group)

  # The two sides must share a single outcome group.
  stopifnot(length(diff_outcome_group) == 1)

  new(
    "DiscreteDiffEstimand",
    name = str_interp("${left@name[1]} - ${right@name[1]}"),
    outcome_group = diff_outcome_group,
    left = left,
    right = right
  )
}

#' Create an S4 estimand for a discretized variable
#'
#' @param outcome_group Name of the discretized variable group.
#' @param ... Named intervention values (\code{variable = value}).
#' @param cond Optional conditioning expression; captured unevaluated.
#' @param cond_desc Optional string description of the condition, used instead
#'   of the text of \code{cond}.
#'
#' @return \code{DiscretizedAtomEstimandCollection} S4 object
#' @export
build_discretized_atom_estimand <- function(outcome_group, ..., cond, cond_desc) {
  if (!missing(cond_desc)) {
    condition_string <- str_c(" | ", cond_desc)
  } else if (!missing(cond)) {
    # Upper-case each "&"-separated clause of the condition and show "==" as "=".
    condition_string <- as_label(enquo(cond)) %>%
      str_split("\\s*&\\s*", simplify = TRUE) %>%
      c() %>%
      map_chr(str_to_upper) %>%
      map(str_replace_all, fixed("=="), "=") %>%
      str_c(., collapse = ", ") %>%
      str_c(" | ", .)
  } else {
    condition_string <- ""
  }

  new(
    "DiscretizedAtomEstimandCollection",
    outcome_group = outcome_group,
    intervention = list(...),
    cond = if (!missing(cond)) enquo(cond) else TRUE,
    condition_string = condition_string
  )
}

#' Create an S4 difference estimand for a discretized variable
#'
#' The result is named \code{"<left name> - <right name>"} and its outcome
#' group is left as \code{NA}.
#'
#' @param left Left estimand, to difference from.
#' @param right Right estimand, to difference out.
#'
#' @return \code{DiscretizedDiffEstimandCollection} S4 object
#' @export
build_discretized_diff_estimand <- function(left, right) {
  diff_name <- str_interp("${left@name[1]} - ${right@name[1]}")

  new(
    "DiscretizedDiffEstimandCollection",
    name = diff_name,
    left = left,
    right = right,
    outcome_group = NA_character_
  )
}
# karimn/boundr documentation built on March 1, 2021, 6:57 p.m.