# R/estimators.R
#
# Defines functions: est_tml, est_onestep, two_phase_eif, est_plugin, cv_eif
#
# Documented in: cv_eif, est_onestep, est_plugin, est_tml, two_phase_eif

utils::globalVariables(c("..w_names", "A", "Z", "Y", "R", "v_star"))

#' EIF for natural and interventional (in)direct effects
#'
#' @param fold Object specifying cross-validation folds as generated by a call
#'   to \code{\link[origami]{make_folds}}.
#' @param data_in A \code{data.table} containing the observed data, with
#'   columns in the order specified by the NPSEM (Y, M, R, Z, A, W), with column
#'   names set appropriately based on the data. Such a structure is merely a
#'   convenience utility for passing data around to the various core estimation
#'   routines and is automatically generated by \code{\link{medoutcon}}.
#' @param contrast A \code{numeric} double indicating the two values of the
#'   intervention \code{A} to be compared. The default value of \code{NULL} has
#'   no effect, as the value of the argument \code{effect} is instead used to
#'   define the contrasts. To override \code{effect}, provide a \code{numeric}
#'   double vector, giving the values of a' and a*, e.g., \code{c(0, 1)}.
#' @param g_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'   (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'   learners from \pkg{sl3}; used to fit a model for the propensity score.
#' @param h_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'   (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'   learners from \pkg{sl3}; used to fit a model for a parameterization of the
#'   propensity score that conditions on the mediators.
#' @param b_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'   (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'   learners from \pkg{sl3}; used to fit a model for the outcome regression.
#' @param q_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'   (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'   learners from \pkg{sl3}; used to fit a model for a nuisance regression of
#'   the intermediate confounder, conditioning on the treatment and potential
#'   baseline covariates.
#' @param r_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'   (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'   learners from \pkg{sl3}; used to fit a model for a nuisance regression of
#'   the intermediate confounder, conditioning on the mediators, the treatment,
#'   and potential baseline confounders.
#' @param u_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'   (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'   learners from \pkg{sl3}; used to fit a pseudo-outcome regression required
#'   for the efficient influence function.
#' @param v_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'   (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'   learners from \pkg{sl3}; used to fit a pseudo-outcome regression required
#'   for the efficient influence function.
#' @param d_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'   (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'   learners from \pkg{sl3}; used to fit an initial efficient influence
#'   function regression when computing the efficient influence function in a
#'   two-phase sampling design.
#' @param effect_type A \code{character} indicating whether components of the
#'   interventional or natural (in)direct effects are to be estimated. In the
#'   case of the natural (in)direct effects, estimation of several nuisance
#'   parameters is unnecessary.
#' @param w_names A \code{character} vector of the names of the columns that
#'   correspond to baseline covariates (W). The input for this argument is
#'   automatically generated by \code{\link{medoutcon}}.
#' @param m_names A \code{character} vector of the names of the columns that
#'   correspond to mediators (M). The input for this argument is automatically
#'   generated by \code{\link{medoutcon}}.
#' @param g_bounds A \code{numeric} vector containing two values, the
#'   first being the minimum allowable estimated propensity score value and the
#'   second being the maximum allowable estimated propensity score value.
#'
#' @importFrom assertthat assert_that
#' @importFrom data.table data.table copy
#' @importFrom origami training validation fold_index
#' @importFrom sl3 Lrnr_mean
#'
#' @keywords internal
cv_eif <- function(fold,
                   data_in,
                   contrast,
                   g_learners,
                   h_learners,
                   b_learners,
                   q_learners,
                   r_learners,
                   u_learners,
                   v_learners,
                   d_learners,
                   effect_type = c("interventional", "natural"),
                   w_names,
                   m_names,
                   g_bounds = c(0.005, 0.995)) {
  # Computes the un-centered EIF (and the nuisance components needed for a
  # later TMLE fluctuation step) on the validation split of a single fold.
  # Returns a list with `tmle_components` (a data.table of nuisance estimates
  # for validation rows with R == 1), `D_star` (the un-centered EIF, adjusted
  # for two-phase sampling when applicable) and `D_pred` (the predicted
  # centered EIF from the two-phase adjustment, or NA when no adjustment is
  # made).

  # resolve the effect type to a single string: when the argument is left at
  # (or passed as) its two-element default, the scalar `if ()` conditions
  # below would otherwise receive a length-2 logical and fail
  effect_type <- match.arg(effect_type)

  # make training and validation data
  # NOTE(review): `fold` is never referenced explicitly in this function; the
  #   origami helpers (training, validation, fold_index) presumably look it up
  #   in the calling frame -- confirm against the origami documentation
  train_data <- origami::training(data_in)
  valid_data <- origami::validation(data_in)

  # 1) fit regression for propensity score regression
  g_out <- fit_treat_mech(
    train_data = train_data,
    valid_data = valid_data,
    contrast = contrast,
    learners = g_learners,
    w_names = w_names,
    m_names = m_names,
    type = "g",
    bounds = g_bounds
  )

  # 2) fit clever regression for treatment, conditional on mediators
  h_out <- fit_treat_mech(
    train_data = train_data,
    valid_data = valid_data,
    contrast = contrast,
    learners = h_learners,
    w_names = w_names,
    m_names = m_names,
    type = "h",
    bounds = g_bounds
  )

  # 3) fit outcome regression
  b_out <- fit_out_mech(
    train_data = train_data,
    valid_data = valid_data,
    contrast = contrast,
    learners = b_learners,
    m_names = m_names,
    w_names = w_names
  )

  # 4) fit mediator-outcome confounder regression, excluding mediator(s)
  if (effect_type == "natural") {
    # NOTE: in this case Z := 1 in the wrapper function, so overriding the
    #       provided learner with an intercept model guarantees predictions
    #       that are returned will be uniformly 1
    q_learners <- sl3::Lrnr_mean$new()
  }
  q_out <- fit_moc_mech(
    train_data = train_data,
    valid_data = valid_data,
    contrast = contrast,
    learners = q_learners,
    m_names = m_names,
    w_names = w_names,
    type = "q"
  )

  # 5) fit mediator-outcome confounder regression, conditioning on mediator(s)
  if (effect_type == "natural") {
    # NOTE: in this case Z := 1 in the wrapper function, so overriding the
    #       provided learner with an intercept model guarantees predictions
    #       that are returned will be uniformly 1
    r_learners <- sl3::Lrnr_mean$new()
  }
  r_out <- fit_moc_mech(
    train_data = train_data,
    valid_data = valid_data,
    contrast = contrast,
    learners = r_learners,
    m_names = m_names,
    w_names = w_names,
    type = "r"
  )

  # extract components; NOTE: only do this for observations in validation set
  # NOTE(review): the g and q predictions are subset to rows with R == 1 while
  #   the b, h and r predictions are used as returned -- presumably those
  #   fits already return predictions only for rows with R == 1; confirm
  #   against fit_out_mech / fit_treat_mech / fit_moc_mech
  b_prime <- b_out$b_est_valid$b_pred_A_prime
  h_star <- h_out$treat_est_valid$treat_pred_A_star
  g_star <- g_out$treat_est_valid$treat_pred_A_star[valid_data$R == 1]
  h_prime <- h_out$treat_est_valid$treat_pred_A_prime
  g_prime <- g_out$treat_est_valid$treat_pred_A_prime[valid_data$R == 1]
  q_prime_Z_one <-
    q_out$moc_est_valid_Z_one$moc_pred_A_prime[valid_data$R == 1]
  r_prime_Z_one <- r_out$moc_est_valid_Z_one$moc_pred_A_prime
  q_prime_Z_natural <-
    q_out$moc_est_valid_Z_natural$moc_pred_A_prime[valid_data$R == 1]
  r_prime_Z_natural <- r_out$moc_est_valid_Z_natural$moc_pred_A_prime

  # need pseudo-outcome regressions with intervention set to a contrast
  # NOTE: training fits of these nuisance functions must be performed using the
  #       data corresponding to the natural intervention value but predictions
  #       are only needed for u(z,a',w) and v(a*,w) as per form of the EIF
  valid_data_a_prime <- data.table::copy(valid_data)[, A := contrast[1]]
  valid_data_a_star <- data.table::copy(valid_data)[, A := contrast[2]]
  u_out <- fit_nuisance_u(
    train_data = train_data,
    valid_data = valid_data_a_prime,
    learners = u_learners,
    b_out = b_out,
    q_out = q_out,
    r_out = r_out,
    g_out = g_out,
    h_out = h_out,
    w_names = w_names
  )
  u_prime <- u_out$u_pred

  v_out <- fit_nuisance_v(
    train_data = train_data,
    valid_data = valid_data_a_star,
    contrast = contrast,
    learners = v_learners,
    b_out = b_out,
    q_out = q_out,
    m_names = m_names,
    w_names = w_names
  )
  v_star <- v_out$v_pred

  # NOTE: assuming Z in {0,1}; other cases not supported yet
  u_int_eif <- lapply(c(1, 0), function(z_val) {
    # intervene on a copy of the validation data (rows with R == 1 only),
    # setting Z = z and A = a'
    valid_data_z_interv <- data.table::copy(valid_data[R == 1, ])
    valid_data_z_interv[, `:=`(
      Z = z_val,
      A = contrast[1],
      U_pseudo = u_prime
    )]

    # predict u(z, a', w) using intervened data with treatment set A = a'
    # NOTE: here, obs_weights should not include two_phase_weights (?)
    suppressWarnings(
      u_task_valid_z_interv <- sl3::sl3_Task$new(
        data = valid_data_z_interv,
        weights = "obs_weights",
        covariates = c("Z", "A", w_names),
        outcome = "U_pseudo",
        outcome_type = "continuous"
      )
    )

    # return the prediction of u at the intervened value of Z
    out_valid <- u_out[["u_fit"]]$predict(u_task_valid_z_interv)
    return(out_valid)
  })
  # difference u(1, a', w) - u(0, a', w), following the order c(1, 0) above
  u_int_eif <- do.call(`-`, u_int_eif)

  # create inverse probability weights
  ipw_a_prime <- as.numeric(valid_data[R == 1, A] == contrast[1]) / g_prime
  ipw_a_star <- as.numeric(valid_data[R == 1, A] == contrast[2]) / g_star

  # residual term for outcome component of EIF
  c_star <- (g_prime / g_star) * (q_prime_Z_natural / r_prime_Z_natural) *
    (h_star / h_prime)

  # compute uncentered efficient influence function components
  # NOTE: each weight is divided by its own sample mean (stabilized weights)
  eif_y <- ipw_a_prime * c_star / mean(ipw_a_prime * c_star) *
    (valid_data[R == 1, Y] - b_prime)
  eif_u <- ipw_a_prime / mean(ipw_a_prime) * u_int_eif *
    (valid_data[R == 1, Z] - q_prime_Z_one)
  eif_v <- ipw_a_star / mean(ipw_a_star) * (v_out$v_pseudo - v_star)

  # SANITY CHECK: EIF_U should be ~ZERO~ for natural (in)direct effects
  if (effect_type == "natural") {
    assertthat::assert_that(all(eif_u == 0))
  }

  # un-centered efficient influence function
  eif <- eif_y + eif_u + eif_v + v_star

  # adjust un-centered EIF for two-phase sampling design
  # NOTE: the check uses the full data set (not just this fold), so that all
  #       folds take the same branch
  if (!all(data_in$R == 1) || !(all(data_in$two_phase_weights == 1))) {
    # compute a centered EIF
    plugin_est <- est_plugin(v_pred = v_star)
    centered_eif <- eif - plugin_est

    # estimate the conditional EIF using the validation data
    d_out <- fit_nuisance_d(
      train_data = train_data,
      valid_data = valid_data,
      contrast = contrast,
      learners = d_learners,
      b_out = b_out,
      g_out = g_out,
      h_out = h_out,
      q_out = q_out,
      r_out = r_out,
      u_out = u_out,
      v_out = v_out,
      m_names = m_names,
      w_names = w_names
    )
    centered_eif_pred <- d_out$d_pred

    # compute the two-phase sampling un-centered EIF
    full_eif <- two_phase_eif(
      R = valid_data$R,
      two_phase_weights = valid_data$two_phase_weights,
      eif = centered_eif,
      eif_predictions = centered_eif_pred,
      plugin_est = plugin_est
    )
  } else {
    # no two-phase sampling: the EIF is used as-is and no conditional EIF
    # predictions exist
    full_eif <- eif
    centered_eif_pred <- NA
  }

  # output list
  out <- list(
    tmle_components = data.table::data.table(
      # components necessary for fluctuation step of TMLE
      g_prime = g_prime, g_star = g_star, h_prime = h_prime, h_star = h_star,
      q_prime_Z_natural = q_prime_Z_natural, q_prime_Z_one = q_prime_Z_one,
      r_prime_Z_natural = r_prime_Z_natural, r_prime_Z_one = r_prime_Z_one,
      v_star = v_star, u_int_diff = u_int_eif,
      b_prime = b_prime, b_prime_Z_zero = v_out$b_A_prime_Z_zero,
      b_prime_Z_one = v_out$b_A_prime_Z_one, D_star = eif,
      # fold IDs
      fold = origami::fold_index()
    ),
    D_star = full_eif,
    D_pred = centered_eif_pred
  )
  return(out)
}


###############################################################################

#' Plug-in estimator
#'
#' Convenience wrapper that averages the predicted v(a, w) nuisance values to
#' obtain the plug-in (substitution) estimate.
#'
#' @param v_pred A \code{numeric} vector holding the predictions of the
#'   v(a, w) nuisance parameter.
#'
#' @return A \code{numeric} scalar, the arithmetic mean of \code{v_pred}.
#'
#' @keywords internal
est_plugin <- function(v_pred) {
  # the plug-in estimate is simply the sample average of the predictions
  plugin_estimate <- mean(v_pred)
  return(plugin_estimate)
}

###############################################################################

#' Two-phase sampling adjusted, un-centered efficient influence function
#'
#' Adjust the efficient influence function to account for the use of two-phase
#' sampling designs in measuring the mediators.
#'
#' @param R A \code{logical} (or 0/1 \code{numeric}) vector indicating whether
#'   a sampled observation's mediators were measured under the two-phase
#'   sampling design. Must not contain missing values.
#' @param two_phase_weights A \code{numeric} vector of known observation-level
#'   weights corresponding to the inverse probability of the mediators being
#'   measured. These weights should only be provided if the two-phase sampling
#'   indicator \code{R} is specified.
#' @param eif A \code{numeric} vector of the (centered) efficient influence
#'   function, with exactly one entry per observation having \code{R == 1}, in
#'   the same order as those observations appear in \code{R}.
#' @param eif_predictions A \code{numeric} vector of the predicted efficient
#'   influence function, conditioning on the mediator being measured.
#' @param plugin_est A \code{numeric} corresponding to the plug-in estimate for
#'   the given contrast.
#'
#' @return An un-centered efficient influence function that accounts for the
#'   two-phase sampling design; a \code{numeric} vector with one entry per
#'   element of \code{R}.
#'
#' @keywords internal
two_phase_eif <- function(R,
                          two_phase_weights,
                          eif,
                          eif_predictions,
                          plugin_est) {
  # flag the observations whose mediators were measured
  is_measured <- R == 1
  if (anyNA(is_measured)) {
    stop("`R` must not contain missing values.", call. = FALSE)
  }
  # guard against a misaligned EIF: silently padding or recycling here would
  # attach EIF values to the wrong observations
  if (length(eif) != sum(is_measured)) {
    stop(
      "`eif` must contain exactly one value per observation with `R == 1`.",
      call. = FALSE
    )
  }

  # compute the weights for the EIF update
  ipw_two_phase <- R * two_phase_weights

  # expand eif to the length of R: observations with R == 1 receive their EIF
  # value (in order); observations with R == 0 receive a zero
  new_eif <- numeric(length(R))
  new_eif[is_measured] <- eif

  # compute updated observed-data EIF by projection of complete-data EIF
  # NOTE: D_{obs} = R/g_R * D_{full} -
  #                 (R/g_R - 1) * E[D_{full} | R = 1, W, A, Z, Y]
  two_phase_eif <- ipw_two_phase * new_eif +
    (1 - ipw_two_phase) * eif_predictions

  # return the un-centered two-phase eif
  uncentered_two_phase_eif <- two_phase_eif + plugin_est
  return(uncentered_two_phase_eif)
}

###############################################################################

#' One-step estimator for natural and interventional (in)direct effects
#'
#' @param data A \code{data.table} containing the observed data, with columns
#'  in the order specified by the NPSEM (Y, M, R, Z, A, W), with column names
#'  set appropriately based on the input data. Such a structure is merely a
#'  convenience utility to passing data around to the various core estimation
#'  routines and is automatically generated by \code{\link{medoutcon}}.
#' @param contrast A \code{numeric} double indicating the two values of the
#'  intervention \code{A} to be compared. The default value of \code{NULL} has
#'  no effect, as the value of the argument \code{effect} is instead used to
#'  define the contrasts. To override \code{effect}, provide a \code{numeric}
#'  double vector, giving the values of a' and a*, e.g., \code{c(0, 1)}.
#' @param g_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'  (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'  learners from \pkg{sl3}; used to fit a model for the propensity score.
#' @param h_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'  (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'  learners from \pkg{sl3}; used to fit a model for a parameterization of the
#'  propensity score that conditions on the mediators.
#' @param b_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'  (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'  learners from \pkg{sl3}; used to fit a model for the outcome regression.
#' @param q_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'  (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'  learners from \pkg{sl3}; used to fit a model for a nuisance regression of
#'  the intermediate confounder, conditioning on the treatment and potential
#'  baseline covariates.
#' @param r_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'  (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'  learners from \pkg{sl3}; used to fit a model for a nuisance regression of
#'  the intermediate confounder, conditioning on the mediators, the treatment,
#'  and potential baseline confounders.
#' @param u_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'  (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'  learners from \pkg{sl3}; used to fit a pseudo-outcome regression required
#'  for the efficient influence function.
#' @param v_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'  (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'  learners from \pkg{sl3}; used to fit a pseudo-outcome regression required
#'  for the efficient influence function.
#' @param d_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'   (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'   learners from \pkg{sl3}; used to fit an initial efficient influence
#'   function regression when computing the efficient influence function in a
#'   two-phase sampling design.
#' @param w_names A \code{character} vector of the names of the columns that
#'  correspond to baseline covariates (W). The input for this argument is
#'  automatically generated by \code{\link{medoutcon}}.
#' @param m_names A \code{character} vector of the names of the columns that
#'  correspond to mediators (M). The input for this argument is automatically
#'  generated by \code{\link{medoutcon}}.
#' @param y_bounds A \code{numeric} double indicating the minimum and maximum
#'  observed values of the outcome variable Y prior to its being re-scaled to
#'  the unit interval.
#' @param g_bounds A \code{numeric} vector containing two values, the first
#'  being the minimum allowable estimated propensity score value and the second
#'  being the maximum allowable for estimated propensity score value.
#' @param effect_type A \code{character} indicating whether components of the
#'  interventional or natural (in)direct effects are to be estimated. In the
#'  case of the natural (in)direct effects, estimation of several nuisance
#'  parameters is unnecessary.
#' @param svy_weights A \code{numeric} vector of observation-level weights that
#'  have been computed externally. Such weights are used in the construction of
#'  a re-weighted estimator.
#' @param cv_folds A \code{numeric} integer specifying the number of folds to
#'  be created for cross-validation. Use of cross-validation allows for entropy
#'  conditions on the one-step estimator to be relaxed. For compatibility with
#'  \code{\link[origami]{make_folds}}, the value specified must be greater
#'  than or equal to 2; the default is to create 10 folds.
#' @param cv_strat A \code{logical} atomic vector indicating whether V-fold
#'  cross-validation should stratify the folds based on the outcome variable.
#'  If \code{TRUE}, the folds are stratified by passing the outcome variable to
#'  the \code{strata_ids} argument of \code{\link[origami]{make_folds}}. While
#'  the default is \code{FALSE}, an override is triggered when the incidence of
#'  the binary outcome variable falls below the tolerance in \code{strat_pmin}.
#' @param strat_pmin A \code{numeric} atomic vector indicating a tolerance for
#'  the minimum proportion of cases (for a binary outcome variable) below which
#'  stratified V-fold cross-validation is invoked if \code{cv_strat} is set to
#'  \code{TRUE} (default is \code{FALSE}). The default tolerance is 0.1.
#'
#' @importFrom assertthat assert_that
#' @importFrom stats var weighted.mean
#' @importFrom origami make_folds cross_validate folds_vfold
#'
#' @keywords internal
est_onestep <- function(data,
                        contrast,
                        g_learners,
                        h_learners,
                        b_learners,
                        q_learners,
                        r_learners,
                        u_learners,
                        v_learners,
                        d_learners,
                        w_names,
                        m_names,
                        y_bounds,
                        g_bounds = c(0.005, 0.995),
                        effect_type = c("interventional", "natural"),
                        svy_weights = NULL,
                        cv_folds = 10L,
                        cv_strat = FALSE,
                        strat_pmin = 0.1) {
  # Cross-validated one-step estimator: computes the un-centered EIF on each
  # validation fold via cv_eif(), re-scales it to the original outcome scale
  # and averages it. Returns a list with the estimate (`theta`), the plug-in
  # estimate (`theta_plugin`), the variance estimate (`var`), the centered
  # EIF (`eif`) and the estimator label (`type`).

  # resolve the effect type to a single string so that the vector default is
  # never forwarded to cv_eif(), which uses it in scalar `if ()` conditions
  effect_type <- match.arg(effect_type)

  # make sure that more than one fold is specified
  assertthat::assert_that(cv_folds > 1L)

  # create cross-validation folds
  # NOTE: stratification is used only when it was requested (cv_strat) AND
  #       the mean of Y is at or below the tolerance strat_pmin
  if (cv_strat && data[, mean(Y) <= strat_pmin]) {
    # check that outcome is binary for stratified V-fold cross-validation
    assertthat::assert_that(data[, all(unique(Y) %in% c(0, 1))])

    # if outcome is binary and rare, use stratified V-fold cross-validation
    folds <- origami::make_folds(
      data,
      fold_fun = origami::folds_vfold,
      V = cv_folds,
      strata_ids = data$Y
    )
  } else {
    # just use standard V-fold cross-validation
    folds <- origami::make_folds(
      data,
      fold_fun = origami::folds_vfold,
      V = cv_folds
    )
  }

  # estimate the EIF on a per-fold basis
  cv_eif_results <- origami::cross_validate(
    cv_fun = cv_eif,
    folds = folds,
    data_in = data,
    contrast = contrast,
    g_learners = g_learners,
    h_learners = h_learners,
    b_learners = b_learners,
    q_learners = q_learners,
    r_learners = r_learners,
    u_learners = u_learners,
    v_learners = v_learners,
    d_learners = d_learners,
    effect_type = effect_type,
    w_names = w_names,
    m_names = m_names,
    g_bounds = g_bounds,
    use_future = FALSE,
    .combine = FALSE
  )

  # get estimated efficient influence function
  # NOTE: the first element of the results holds the per-fold
  #       `tmle_components` tables returned by cv_eif()
  v_star <- do.call(rbind, cv_eif_results[[1]])$v_star
  # fold-wise results are concatenated in fold order; re-order them to match
  # the row order of the input data
  obs_valid_idx <- do.call(c, lapply(folds, `[[`, "validation_set"))
  cv_eif_est <- unlist(cv_eif_results$D_star)[order(obs_valid_idx)]

  # re-scale efficient influence function
  eif_est_rescaled <- scale_from_unit(cv_eif_est, y_bounds[2], y_bounds[1])

  # compute one-step estimate and variance from efficient influence function
  if (is.null(svy_weights)) {
    os_est <- mean(eif_est_rescaled)
    eif_est_out <- eif_est_rescaled
  } else {
    # compute a re-weighted one-step, with re-weighted influence function
    os_est <- stats::weighted.mean(eif_est_rescaled, svy_weights)
    eif_est_out <- eif_est_rescaled * svy_weights
  }
  os_var <- stats::var(eif_est_out) / length(eif_est_out)

  # output
  os_est_out <- list(
    theta = os_est,
    theta_plugin = est_plugin(v_star),
    var = os_var,
    eif = (eif_est_out - os_est),
    type = "onestep"
  )
  return(os_est_out)
}

###############################################################################

#' TML estimator for natural and interventional (in)direct effects
#'
#' @param data A \code{data.table} containing the observed data, with columns
#'  in the order specified by the NPSEM (Y, M, R, Z, A, W), with column names
#'  set appropriately based on the input data. Such a structure is merely a
#'  convenience utility to passing data around to the various core estimation
#'  routines and is automatically generated by \code{\link{medoutcon}}.
#' @param contrast A \code{numeric} double indicating the two values of the
#'  intervention \code{A} to be compared. The default value of \code{NULL} has
#'  no effect, as the value of the argument \code{effect} is instead used to
#'  define the contrasts. To override \code{effect}, provide a \code{numeric}
#'  double vector, giving the values of a' and a*, e.g., \code{c(0, 1)}.
#' @param g_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'  (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'  learners from \pkg{sl3}; used to fit a model for the propensity score.
#' @param h_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'  (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'  learners from \pkg{sl3}; used to fit a model for a parameterization of the
#'  propensity score that conditions on the mediators.
#' @param b_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'  (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'  learners from \pkg{sl3}; used to fit a model for the outcome regression.
#' @param q_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'  (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'  learners from \pkg{sl3}; used to fit a model for a nuisance regression of
#'  the intermediate confounder, conditioning on the treatment and potential
#'  baseline covariates.
#' @param r_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'  (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'  learners from \pkg{sl3}; used to fit a model for a nuisance regression of
#'  the intermediate confounder, conditioning on the mediators, the treatment,
#'  and potential baseline confounders.
#' @param u_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'  (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'  learners from \pkg{sl3}; used to fit a pseudo-outcome regression required
#'  for the efficient influence function.
#' @param v_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'  (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'  learners from \pkg{sl3}; used to fit a pseudo-outcome regression required
#'  for the efficient influence function.
#' @param d_learners A \code{\link[sl3]{Stack}} object, or other learner class
#'   (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#'   learners from \pkg{sl3}; used to fit an initial efficient influence
#'   function regression when computing the efficient influence function in a
#'   two-phase sampling design.
#' @param w_names A \code{character} vector of the names of the columns that
#'  correspond to baseline covariates (W). The input for this argument is
#'  automatically generated by \code{\link{medoutcon}}.
#' @param m_names A \code{character} vector of the names of the columns that
#'  correspond to mediators (M). The input for this argument is automatically
#'  generated by \code{\link{medoutcon}}.
#' @param y_bounds A \code{numeric} double indicating the minimum and maximum
#'  observed values of the outcome variable Y prior to its being re-scaled to
#'  the unit interval.
#' @param g_bounds A \code{numeric} vector containing two values, the first
#'  being the minimum allowable estimated propensity score value and the second
#'  being the maximum allowable for estimated propensity score value.
#' @param effect_type A \code{character} indicating whether components of the
#'  interventional or natural (in)direct effects are to be estimated. In the
#'  case of the natural (in)direct effects, estimation of several nuisance
#'  parameters is unnecessary.
#' @param svy_weights A \code{numeric} vector of observation-level weights that
#'  have been computed externally. Such weights are used in the construction of
#'  a re-weighted estimator.
#' @param cv_folds A \code{numeric} value specifying the number of folds to be
#'  created for cross-validation. Use of cross-validation allows for entropy
#'  conditions on the TML estimator to be relaxed. Note: for compatibility with
#'  \code{\link[origami]{make_folds}}, this value  must be greater than or
#'  equal to 2; the default is to create 10 folds.
#' @param cv_strat A \code{logical} atomic vector indicating whether V-fold
#'  cross-validation should stratify the folds based on the outcome variable.
#'  If \code{TRUE}, the folds are stratified by passing the outcome variable to
#'  the \code{strata_ids} argument of \code{\link[origami]{make_folds}}. While
#'  the default is \code{FALSE}, an override is triggered when the incidence of
#'  the binary outcome variable falls below the tolerance in \code{strat_pmin}.
#' @param strat_pmin A \code{numeric} atomic vector indicating a tolerance for
#'  the minimum proportion of cases (for a binary outcome variable) below which
#'  stratified V-fold cross-validation is invoked if \code{cv_strat} is set to
#'  \code{TRUE} (default is \code{FALSE}). The default tolerance is 0.1.
#' @param max_iter A \code{numeric} integer giving the maximum number of steps
#'  to be taken for the iterative procedure to construct a TML estimator.
#' @param tiltmod_tol A \code{numeric} indicating the maximum step size to be
#'  taken when performing TMLE updates based on logistic tilting models. When
#'  the step size of a given update exceeds this value, the update is avoided.
#'
#' @importFrom dplyr "%>%"
#' @importFrom assertthat assert_that
#' @importFrom origami make_folds cross_validate folds_vfold
#' @importFrom stats var as.formula plogis qlogis coef predict weighted.mean
#'   binomial
#' @importFrom glm2 glm2
#'
#' @keywords internal
est_tml <- function(data,
                    contrast,
                    g_learners,
                    h_learners,
                    b_learners,
                    q_learners,
                    r_learners,
                    u_learners,
                    v_learners,
                    d_learners,
                    w_names,
                    m_names,
                    y_bounds,
                    g_bounds = c(0.005, 0.95),
                    effect_type = c("interventional", "natural"),
                    svy_weights = NULL,
                    cv_folds = 10L,
                    cv_strat = FALSE,
                    strat_pmin = 0.1,
                    max_iter = 10L,
                    tiltmod_tol = 5) {
  # Overview of the steps performed below:
  #   1. build V-fold cross-validation splits (optionally stratified on Y);
  #   2. run cv_eif() per fold to get nuisance estimates and the EIF;
  #   3. iteratively tilt the outcome regression (b) and the intermediate
  #      confounder mechanism (q) with logistic fluctuation models, for at
  #      most max_iter iterations;
  #   4. tilt the substitution estimator (v) and average its predictions;
  #   5. return a list with the estimate (theta), the plug-in estimate, the
  #      EIF-based variance, the centered EIF, the iteration count and type.

  # resolve the effect type to a single string, since it is used in scalar
  # `if` conditions below (the unresolved length-2 default would fail there)
  effect_type <- match.arg(effect_type)

  # make sure that more than one fold is specified
  assertthat::assert_that(cv_folds > 1L)

  # create cross-validation folds
  if (cv_strat && data[, mean(Y) <= strat_pmin]) {
    # check that outcome is binary for stratified V-fold cross-validation
    assertthat::assert_that(data[, all(unique(Y) %in% c(0, 1))])

    # if outcome is binary and rare, use stratified V-fold cross-validation
    folds <- origami::make_folds(
      data,
      fold_fun = origami::folds_vfold,
      V = cv_folds,
      strata_ids = data$Y
    )
  } else {
    # just use standard V-fold cross-validation
    folds <- origami::make_folds(
      data,
      fold_fun = origami::folds_vfold,
      V = cv_folds
    )
  }

  # perform the cv_eif procedure on a per-fold basis
  cv_eif_results <- origami::cross_validate(
    cv_fun = cv_eif,
    folds = folds,
    data_in = data,
    contrast = contrast,
    g_learners = g_learners,
    h_learners = h_learners,
    b_learners = b_learners,
    q_learners = q_learners,
    r_learners = r_learners,
    u_learners = u_learners,
    v_learners = v_learners,
    d_learners = d_learners,
    effect_type = effect_type,
    w_names = w_names,
    m_names = m_names,
    g_bounds = g_bounds,
    use_future = FALSE,
    .combine = FALSE
  )

  # concatenate the per-fold nuisance function estimates, and re-order the
  # data so that its rows match the concatenated validation sets
  cv_eif_est <- do.call(rbind, cv_eif_results[[1]])
  obs_valid_idx <- do.call(c, lapply(folds, `[[`, "validation_set"))
  data <- data[obs_valid_idx]

  # extract nuisance function estimates and auxiliary quantities
  g_prime <- cv_eif_est$g_prime
  h_prime <- cv_eif_est$h_prime
  g_star <- cv_eif_est$g_star
  h_star <- cv_eif_est$h_star
  q_prime_Z_one <- cv_eif_est$q_prime_Z_one
  q_prime_Z_natural <- cv_eif_est$q_prime_Z_natural
  r_prime_Z_one <- cv_eif_est$r_prime_Z_one
  r_prime_Z_natural <- cv_eif_est$r_prime_Z_natural
  b_prime_Z_one <- cv_eif_est$b_prime_Z_one
  b_prime_Z_zero <- cv_eif_est$b_prime_Z_zero
  b_prime_Z_natural <- cv_eif_est$b_prime
  u_int_diff <- cv_eif_est$u_int_diff

  # observed quantities among units with R == 1; these do not change during
  # targeting, so they are extracted once
  y_obs <- data[R == 1, Y]
  z_obs <- data[R == 1, Z]
  tp_weights <- data[R == 1, two_phase_weights]

  # generate inverse weights and multiplier for auxiliary covariates
  ipw_prime <- as.numeric(data[R == 1, A] == contrast[1]) / g_prime
  c_star_mult <- (g_prime / g_star) * (h_star / h_prime)

  # observation weights shared by the b and q tilting regressions; they are
  # invariant across targeting iterations, so they are computed only once
  tilt_two_phase_weights <- sum(data$R) != nrow(data)
  if (tilt_two_phase_weights) {
    weights_prime_tilt <- as.numeric(data[R == 1, A] == contrast[1]) /
      g_prime * as.numeric(tp_weights)
  } else {
    weights_prime_tilt <- data$obs_weights * (data$A == contrast[1]) /
      g_prime
  }

  # helper: extract the single fluctuation coefficient of a tilting model,
  # falling back to zero (i.e., no update) when the coefficient is missing,
  # the fit failed to converge, or the coefficient exceeds tiltmod_tol in
  # absolute value
  safe_tilt_coef <- function(tilt_fit) {
    tilt_coef <- unname(stats::coef(tilt_fit))
    if (is.na(tilt_coef) || !tilt_fit$converged ||
      abs(tilt_coef) > tiltmod_tol) {
      return(0)
    }
    tilt_coef
  }

  # prepare for iterative targeting
  eif_stop_crit <- FALSE
  n_iter <- 0
  n_obs <- nrow(data)
  se_eif <- sqrt(stats::var(cv_eif_est$D_star) / n_obs)
  tilt_stop_crit <- se_eif / log(n_obs)
  b_score <- q_score <- Inf

  # perform iterative targeting, for at most max_iter iterations
  while (!eif_stop_crit && n_iter < max_iter) {
    # NOTE: check convergence condition for outcome regression; the absolute
    #       value is used so that a large negative mean score is not treated
    #       as converged, consistent with the overall stopping rule below
    if (abs(mean(b_score)) > tilt_stop_crit) {
      # compute auxiliary covariates from updated estimates
      c_star_Z_natural <- (q_prime_Z_natural / r_prime_Z_natural) *
        c_star_mult
      c_star_Z_one <- (q_prime_Z_one / r_prime_Z_one) * c_star_mult
      if (effect_type == "natural") {
        # NOTE: this exception handles 0/0 division, since q(1|a',...) = 1
        #       and r(1|a',...) = 1, improperly yielding 0/0 => NaN
        c_star_Z_zero <- (q_prime_Z_one / r_prime_Z_one) * c_star_mult
      } else {
        c_star_Z_zero <- ((1 - q_prime_Z_one) / (1 - r_prime_Z_one)) *
          c_star_mult
      }

      # bound and transform nuisance estimates for tilting regressions
      b_prime_Z_natural_logit <- b_prime_Z_natural %>%
        bound_precision() %>%
        stats::qlogis()
      b_prime_Z_one_logit <- b_prime_Z_one %>%
        bound_precision() %>%
        stats::qlogis()
      b_prime_Z_zero_logit <- b_prime_Z_zero %>%
        bound_precision() %>%
        stats::qlogis()

      # fit tilting model for the outcome mechanism: an intercept-free
      # logistic regression of the outcome on the auxiliary covariate, with
      # the current (logit-scale) outcome regression as offset
      b_tilt_fit <- suppressWarnings(
        glm2::glm2(
          stats::as.formula("y_scaled ~ -1 + offset(b_prime_logit) + c_star"),
          data = data.table::as.data.table(list(
            y_scaled = y_obs,
            b_prime_logit = b_prime_Z_natural_logit,
            c_star = c_star_Z_natural
          )),
          weights = weights_prime_tilt,
          family = stats::binomial(),
          start = 0
        )
      )
      b_tilt_coef <- safe_tilt_coef(b_tilt_fit)

      # update nuisance estimates via tilting regressions for outcome
      b_prime_Z_natural <- stats::plogis(b_prime_Z_natural_logit +
        b_tilt_coef * c_star_Z_natural)
      b_prime_Z_one <- stats::plogis(b_prime_Z_one_logit +
        b_tilt_coef * c_star_Z_one)
      b_prime_Z_zero <- stats::plogis(b_prime_Z_zero_logit +
        b_tilt_coef * c_star_Z_zero)

      # compute efficient score for outcome regression component
      b_score <- tp_weights * ipw_prime * c_star_Z_natural *
        (y_obs - b_prime_Z_natural)
    } else {
      b_score <- 0
    }

    # NOTE: check convergence condition for intermediate confounding, again
    #       on the absolute value of the mean score
    if (abs(mean(q_score)) > tilt_stop_crit) {
      if (effect_type == "natural") {
        # for the natural (in)direct effects, no tilting model is fitted and
        # the estimates are set to the observed Z
        q_prime_Z_one <- z_obs
        q_prime_Z_natural <- z_obs
      } else {
        # bound and transform the current estimate for the tilting regression
        q_prime_Z_one_logit <- q_prime_Z_one %>%
          bound_precision() %>%
          stats::qlogis()

        # fit tilting regression for intermediate confounding
        q_tilt_fit <- suppressWarnings(
          glm2::glm2(
            stats::as.formula(
              "Z ~ -1 + offset(q_prime_logit) + u_prime_diff"
            ),
            data = data.table::as.data.table(list(
              Z = z_obs,
              q_prime_logit = q_prime_Z_one_logit,
              u_prime_diff = u_int_diff
            )),
            weights = weights_prime_tilt,
            family = stats::binomial(),
            start = 0
          )
        )
        q_tilt_coef <- safe_tilt_coef(q_tilt_fit)

        # update nuisance estimates via tilting of intermediate confounder
        q_prime_Z_one <- stats::plogis(q_prime_Z_one_logit + q_tilt_coef *
          u_int_diff)
        q_prime_Z_natural <- (z_obs * q_prime_Z_one) +
          ((1 - z_obs) * (1 - q_prime_Z_one))
      }

      # compute efficient score for intermediate confounding component
      q_score <- ipw_prime * u_int_diff * (z_obs - q_prime_Z_one) *
        tp_weights
    } else {
      q_score <- 0
    }

    # check convergence and iterate the counter
    eif_stop_crit <- all(
      abs(c(mean(b_score), mean(q_score))) < tilt_stop_crit
    )
    n_iter <- n_iter + 1
  }

  # compute updated substitution estimator and prepare for tilting regression
  v_pseudo <- ((b_prime_Z_one * q_prime_Z_one) +
    (b_prime_Z_zero * (1 - q_prime_Z_one))) %>%
    bound_precision()
  v_star_logit <- cv_eif_est$v_star %>%
    bound_precision() %>%
    stats::qlogis()

  # fit tilting regression for substitution estimator: an intercept-only
  # logistic model with the initial (logit-scale) estimate as offset
  if (tilt_two_phase_weights) {
    weights_v_tilt <- (as.numeric(data[R == 1, A]) == contrast[2]) / g_star *
      (as.numeric(tp_weights))
  } else {
    weights_v_tilt <- data$obs_weights * (data$A == contrast[2]) / g_star
  }
  v_tilt_fit <- suppressWarnings(
    glm2::glm2(
      stats::as.formula("v_pseudo ~ offset(v_star_logit)"),
      data = data.table::as.data.table(list(
        v_pseudo = v_pseudo,
        v_star_logit = v_star_logit
      )),
      weights = weights_v_tilt,
      family = stats::binomial(),
      start = 0
    )
  )
  v_star_tmle <- unname(stats::predict(v_tilt_fit, type = "response"))

  # extract the estimated efficient influence function and put it back in
  # the order of the original data
  eif_est <- unlist(cv_eif_results$D_star)[order(obs_valid_idx)]

  # re-scale the TML predictions and the efficient influence function
  v_star_tmle_rescaled <- v_star_tmle %>%
    scale_from_unit(y_bounds[2], y_bounds[1])
  eif_est_rescaled <- eif_est %>%
    scale_from_unit(y_bounds[2], y_bounds[1])

  # compute TML estimator and variance from efficient influence function
  if (is.null(svy_weights)) {
    tml_est <- mean(v_star_tmle_rescaled)
    eif_est_out <- eif_est_rescaled
  } else {
    # compute a re-weighted TMLE, with re-weighted influence function
    # NOTE: the TML predictions follow the concatenated validation sets, so
    #       the survey weights are re-indexed to that order for the estimate
    tml_est <- stats::weighted.mean(
      v_star_tmle_rescaled,
      svy_weights[obs_valid_idx]
    )
    # NOTE: the EIF was restored to the original data order above, so it is
    #       paired with the survey weights as supplied (not re-indexed)
    eif_est_out <- eif_est_rescaled * svy_weights
  }
  tmle_var <- stats::var(eif_est_out) / length(eif_est_out)

  # output
  tmle_out <- list(
    theta = tml_est,
    theta_plugin = est_plugin(cv_eif_est$v_star),
    var = tmle_var,
    eif = (eif_est_out - tml_est),
    n_iter = n_iter,
    type = "tmle"
  )
  return(tmle_out)
}
nhejazi/medoutcon documentation built on July 16, 2025, 5:38 p.m.