# Source file: R/est_marg_nuisance_params.R

# Defines functions: est_marg_nuisance_params

# Documented in: est_marg_nuisance_params

#' @title Estimate nuisance parameters for each marginal mixture component
#'
#' @description For each marginal mixture component rule found, create a
#' g estimator for the probability of being exposed to the rule thresholds,
#' and a Q estimator for the outcome E(Y| A = a_mix, W).
#' Get estimates of g and Q using the validation data and
#' calculate the clever covariate used in the TMLE fluctuation step.
#'
#' @details The rules in \code{marg_decisions} are split into groups with
#' \code{marginal_group_split()}. Within each group the rule whose
#' \code{quantile} equals 1 is the reference level: observations satisfying
#' it are kept and coded \code{a = 0}. Every rule with \code{quantile > 1} is
#' a comparison level: observations satisfying it are kept and coded
#' \code{a = 1}. Reference and comparison rows are stacked, separately for the
#' training and the validation data. Both Super Learners are trained on the
#' stacked training data and predicted on the stacked validation data.
#'
#' @param at Training data
#' @param av Validation data
#' @param w Vector of characters denoting covariates
#' @param y Outcome variable
#' @param aw_stack Super Learner library for fitting Q (outcome mechanism)
#' and g (treatment mechanism)
#' @param family Binomial or continuous
#' @param a Vector of characters that denote the mixture components
#' @param marg_decisions List of rules found within the fold for each
#' mixture component
#' @param h_aw_trunc_lvl Truncation level of the clever covariate
#' (induces more bias to reduce variance)
#' @param no_marg_rules TRUE/FALSE if no marginal rules were
#' found across the folds
#' @param parallel_cv TRUE/FALSE if cv parallelization is used
#' @param seed Seed number
#' @importFrom magrittr %>%
#' @importFrom rlang :=
#' @importFrom dplyr group_by filter top_n
#' @return marginal_data: A list of data frames for each mixture component
#' with the baseline covariates, exposure, outcome, and nuisance parameters
#' needed to calculate the ATE. Each data frame holds validation rows with the
#' added columns \code{ghat_1w}, \code{h_aw}, \code{qbar_aw}, \code{qbar_1w}
#' and \code{qbar_0w}. When \code{no_marg_rules} is TRUE the list has a single
#' element: \code{av} with those five columns set to \code{NA}.
#'
#' @export
est_marg_nuisance_params <- function(at,
                                     av,
                                     w,
                                     y,
                                     aw_stack,
                                     family,
                                     a,
                                     no_marg_rules,
                                     marg_decisions,
                                     parallel_cv,
                                     seed,
                                     h_aw_trunc_lvl) {
  if (parallel_cv == TRUE) {
    # Switch the future plan to sequential for the fits below; presumably this
    # avoids nested parallelism when the outer CV loop is already parallel.
    future::plan(future::sequential, gc = TRUE)
  }

  set.seed(seed)

  marginal_data <- list()

  if (no_marg_rules == FALSE) {
    marg_decisions_groups <- marginal_group_split(marg_decisions)

    for (i in seq_along(marg_decisions_groups)) {
      at_c <- at
      av_c <- av

      variable_decisions <- marg_decisions_groups[[i]]

      # Reference level: the rule for the lowest quantile region.
      quant_one_row <- variable_decisions[variable_decisions$quantile == 1, ]
      quant_one_rule <- quant_one_row$rules

      # The rule is stored as text and evaluated against the data columns.
      at_c_ref_data <- at_c %>%
        dplyr::mutate("a" := ifelse(eval(parse(text = quant_one_rule)), 1, 0))

      av_c_ref_data <- av_c %>%
        dplyr::mutate("a" := ifelse(eval(parse(text = quant_one_rule)), 1, 0))

      # Keep only rows inside the reference region ...
      at_c_ref_data <- at_c_ref_data[at_c_ref_data[, "a"] == 1, ]
      av_c_ref_data <- av_c_ref_data[av_c_ref_data[, "a"] == 1, ]

      # ... and recode them as the unexposed (a = 0) arm of each contrast.
      at_c_ref_data[, "a"] <- 0
      av_c_ref_data[, "a"] <- 0

      quant_comparisons <- variable_decisions[variable_decisions$quantile > 1, ]

      in_group_marg_data <- list()

      # seq_len() so that a group without comparison rules is skipped instead
      # of iterating over c(1, 0).
      for (j in seq_len(nrow(quant_comparisons))) {
        target_m_row <- quant_comparisons[j, ]

        at_c_comp_data <- at_c %>%
          dplyr::mutate("a" := ifelse(eval(parse(
            text =
              target_m_row$rules
          )), 1, 0))

        av_c_comp_data <- av_c %>%
          dplyr::mutate("a" := ifelse(eval(parse(
            text =
              target_m_row$rules
          )), 1, 0))

        # Keep only rows inside the comparison region (the a = 1 arm). The
        # validation subset must come from the validation data, not from the
        # training data.
        at_c_comp_data <- at_c_comp_data[at_c_comp_data[, "a"] == 1, ]
        av_c_comp_data <- av_c_comp_data[av_c_comp_data[, "a"] == 1, ]

        at_data <- rbind(at_c_ref_data, at_c_comp_data)
        av_data <- rbind(av_c_ref_data, av_c_comp_data)

        # ---- g: P(a = 1 | W), trained on training rows ----
        task_at <- sl3::make_sl3_Task(
          data = at_data,
          covariates = w,
          outcome = "a",
          outcome_type = "binomial",
          folds = 2
        )

        task_av <- sl3::make_sl3_Task(
          data = av_data,
          covariates = w,
          outcome = "a",
          outcome_type = "binomial",
          folds = 2
        )

        sl <- sl3::Lrnr_sl$new(
          learners = aw_stack,
          metalearner = sl3::Lrnr_nnls$new()
        )

        sl_fit <- suppressWarnings(sl$train(task_at))

        ghat_1w <- bound_precision(sl_fit$predict(task_av))

        h_aw <- calc_clever_covariate(
          ghat_1_w = ghat_1w,
          data = av_data,
          exposure = "a",
          h_aw_trunc_lvl = h_aw_trunc_lvl,
          type = "reg"
        )

        av_data$ghat_1w <- ghat_1w
        av_data$h_aw <- h_aw

        # ---- Q: E(Y | a, W), trained on training rows ----
        task_at <- sl3::make_sl3_Task(
          data = at_data,
          covariates = c(w, "a"),
          outcome = y,
          outcome_type = family,
          folds = 2
        )

        # Counterfactual copies of the validation rows.
        x_m1 <- x_m0 <- av_data
        x_m1$a <- 1 # under exposure
        x_m0$a <- 0 # under control

        task_av <- sl3::make_sl3_Task(
          data = av_data,
          covariates = c(w, "a"),
          outcome = y,
          outcome_type = family
        )

        task_av_1 <- sl3::make_sl3_Task(
          data = x_m1,
          covariates = c(w, "a"),
          outcome = y,
          outcome_type = family
        )

        task_av_0 <- sl3::make_sl3_Task(
          data = x_m0,
          covariates = c(w, "a"),
          outcome = y,
          outcome_type = family
        )

        sl <- sl3::Lrnr_sl$new(
          learners = aw_stack,
          metalearner = sl3::Lrnr_nnls$new()
        )

        sl_fit <- suppressWarnings(sl$train(task_at))

        q_bar_aw <- sl_fit$predict(task_av)
        q_bar_1w <- sl_fit$predict(task_av_1)
        q_bar_0w <- sl_fit$predict(task_av_0)

        av_data$qbar_aw <- q_bar_aw
        av_data$qbar_1w <- q_bar_1w
        av_data$qbar_0w <- q_bar_0w

        in_group_marg_data[[j]] <- av_data
      }
      marginal_data[[i]] <- in_group_marg_data
    }

    # Flatten group -> comparison nesting into one list of data frames.
    marginal_data <- unlist(marginal_data, recursive = FALSE, use.names = FALSE)
  } else {
    # No rules found: return the validation data with placeholder columns so
    # the result has the same column layout as the rule-based branch.
    av$ghat_1w <- NA
    av$h_aw <- NA
    av$qbar_aw <- NA
    av$qbar_1w <- NA
    av$qbar_0w <- NA
    marginal_data[[1]] <- av
  }

  return(marginal_data)
}
# blind-contours/CVtreeMLE documentation built on June 22, 2024, 8:53 p.m.