R/estim_onestep.R

Defines functions cv_eif est_onestep

Documented in cv_eif est_onestep

# Register the data.table `..`-prefixed column-selection variables as known
# globals so that code checking (codetools / R CMD check) does not report them
# as undefined. `..w_names` is not referenced in this file; presumably it is
# used elsewhere in the package (TODO confirm).
utils::globalVariables(
  c(
    "..eif_component_names",
    "..w_names"
  )
)

#' Efficient One-Step Estimator
#'
#' Evaluates the components of the efficient influence function (EIF) with
#' cross-fitting (via \code{cv_eif}), sums them for each observation, and
#' returns the one-step estimate (the mean of the summed EIF values) together
#' with its estimated variance, for each value of \code{delta}.
#'
#' @param data A \code{data.table} containing the observed data, with columns
#'  in the order specified by the NPSEM (Y, Z, A, W), with column names set
#'  appropriately based on the original input data. Such a structure is merely
#'  a convenience utility for passing data around to the various core
#'  estimation routines and is automatically generated by
#'  \code{\link{medshift}}. If a column \code{ids} is present, it is used as
#'  cluster identifiers both when creating folds and when estimating the
#'  variance.
#' @param delta A \code{numeric} value, or vector of values, indicating the
#'  degree of shift in the intervention to be used in defining the causal
#'  quantity of interest. In the case of binary interventions, this takes the
#'  form of an incremental propensity score shift, acting as a multiplier of
#'  the odds with which a given observational unit receives the intervention
#'  (EH Kennedy, 2018, JASA; <doi:10.1080/01621459.2017.1422737>).
#' @param g_learners A \code{\link[sl3]{Stack}} (or other learner class that
#'  inherits from \code{\link[sl3]{Lrnr_base}}), containing a single or set of
#'  instantiated learners from \pkg{sl3}, to be used in fitting the propensity
#'  score, i.e., g = P(A | W).
#' @param e_learners A \code{\link[sl3]{Stack}} (or other learner class that
#'  inherits from \code{\link[sl3]{Lrnr_base}}), containing a single or set of
#'  instantiated learners from \pkg{sl3}, to be used in fitting a propensity
#'  score that conditions on the mediators, i.e., e = P(A | Z, W).
#' @param m_learners A \code{\link[sl3]{Stack}} (or other learner class that
#'  inherits from \code{\link[sl3]{Lrnr_base}}), containing a single or set of
#'  instantiated learners from \pkg{sl3}, to be used in fitting the outcome
#'  regression, i.e., m(A, Z, W).
#' @param phi_learners A \code{\link[sl3]{Stack}} (or other learner class that
#'  inherits from \code{\link[sl3]{Lrnr_base}}), containing a single or set of
#'  instantiated learners from \pkg{sl3}, to be used in a regression of a
#'  pseudo-outcome on the baseline covariates, i.e.,
#'  phi(W) = E[m(A = 1, Z, W) - m(A = 0, Z, W) | W].
#' @param w_names A \code{character} vector of the names of the columns that
#'  correspond to baseline covariates (W). The input for this argument is
#'  automatically generated by \code{\link{medshift}}.
#' @param z_names A \code{character} vector of the names of the columns that
#'  correspond to mediators (Z). The input for this argument is automatically
#'  generated by \code{\link{medshift}}.
#' @param cv_folds A \code{numeric} specifying the number of folds to be
#'  created for cross-validation. Use of cross-validation / cross-fitting
#'  allows for entropy conditions on the AIPW estimator to be relaxed. Note:
#'  for compatibility with \code{\link[origami]{make_folds}}, this value must
#'  be greater than or equal to 2; the default is to create 10 folds.
#'
#' @return For a single value of \code{delta}, a \code{list} with elements
#'  \code{theta} (the one-step estimate), \code{var} (its estimated variance:
#'  the sample variance of the EIF divided by the number of observations, or
#'  of the within-cluster EIF means divided by the number of clusters when
#'  \code{ids} repeat), \code{eif} (the EIF values centered at \code{theta},
#'  in the row order of \code{data}), and \code{type} (\code{"onestep"}).
#'  For multiple values of \code{delta}, an unnamed \code{list} of such lists,
#'  one per value of \code{delta}.
#'
#' @importFrom stats var
#' @importFrom origami make_folds cross_validate folds_vfold
est_onestep <- function(data,
                        delta,
                        g_learners,
                        e_learners,
                        m_learners,
                        phi_learners,
                        w_names,
                        z_names,
                        cv_folds = 10) {
  # names of the EIF components that cv_eif returns for every fold; their
  # row-wise sum is the (un-centered) EIF of each observation
  eif_component_names <- c("Dy", "Da", "Dzw")

  # optional cluster identifiers (NULL when `data` has no "ids" column)
  ids <- data[["ids"]]

  # create folds for use with origami::cross_validate
  folds <- origami::make_folds(data,
    fold_fun = origami::folds_vfold,
    V = cv_folds,
    cluster_ids = ids
  )

  # validation-set row indices of each fold, keyed by fold index; cv_eif tags
  # its output with the fold index, so this lookup lets the EIF values (which
  # come back fold by fold) be put back into the row order of `data`
  valid_idx_by_fold <- lapply(folds, function(fold) {
    origami::validation(fold = fold)
  })
  names(valid_idx_by_fold) <- vapply(folds, function(fold) {
    as.character(origami::fold_index(fold = fold))
  }, character(1))

  # perform the cv_eif procedure on a per-fold basis
  cv_eif_results <- origami::cross_validate(
    cv_fun = cv_eif,
    folds = folds,
    data = data,
    delta = delta,
    g_learners = g_learners,
    e_learners = e_learners,
    m_learners = m_learners,
    phi_learners = phi_learners,
    w_names = w_names,
    z_names = z_names,
    use_future = FALSE,
    .combine = FALSE
  )

  # combine results of EIF components for full EIF, one shift at a time
  est_over_delta <- lapply(seq_along(delta), function(iter) {
    fold_results <- cv_eif_results[[iter]]

    # per-observation EIF within each fold: sum of its components
    eif_by_fold <- lapply(fold_results, function(x) {
      rowSums(x[, ..eif_component_names])
    })
    eif_vals <- unlist(eif_by_fold, use.names = FALSE)

    # concatenated fold results are in fold order, NOT row order; recover the
    # original row of every value so that it lines up with `ids` (and so the
    # returned EIF follows the rows of `data`)
    fold_keys <- vapply(fold_results, function(x) {
      as.character(x[["fold"]][1L])
    }, character(1))
    row_idx <- unlist(valid_idx_by_fold[fold_keys], use.names = FALSE)
    stopifnot(length(row_idx) == length(eif_vals))
    row_order <- order(row_idx)
    estim_eif <- eif_vals[row_order]
    eif_ids <- ids[row_idx[row_order]]

    # compute one-step estimate of parameter from EIF
    estim_onestep_param <- mean(estim_eif)

    # if cluster IDs repeat, average the EIF within each cluster and scale the
    # variance by the number of clusters rather than observations; tapply
    # groups on the IDs directly, so non-numeric IDs are handled correctly
    if (!is.null(eif_ids) && anyDuplicated(eif_ids) > 0L) {
      estim_eif_reduced <- as.numeric(tapply(estim_eif, eif_ids, mean))
      estim_onestep_var <- stats::var(estim_eif_reduced) /
        length(estim_eif_reduced)
    } else {
      estim_onestep_var <- stats::var(estim_eif) / length(estim_eif)
    }

    # output for a given delta
    list(
      theta = estim_onestep_param,
      var = estim_onestep_var,
      eif = estim_eif - estim_onestep_param,
      type = "onestep"
    )
  })

  # a single shift returns its result directly rather than a length-1 list
  if (length(delta) == 1) {
    est_over_delta[[1L]]
  } else {
    est_over_delta
  }
}

###############################################################################

#' Cross-validated Evaluation of Efficient Influence Function Components
#'
#' Fits the nuisance parameters (g, e, m, phi) on the training set of a single
#' cross-validation fold and evaluates the components of the efficient
#' influence function on that fold's validation set, once for each value of
#' \code{delta}.
#'
#' @param fold Object specifying cross-validation folds as generated by a call
#'  to \code{origami::make_folds}.
#' @param data A \code{data.table} containing the observed data, with columns
#'  in the order specified by the NPSEM (Y, Z, A, W), with column names set
#'  appropriately based on the original input data. Such a structure is merely
#'  a convenience utility for passing data around to the various core
#'  estimation routines and is automatically generated by
#'  \code{\link{medshift}}.
#' @param delta A \code{numeric} value, or vector of values, indicating the
#'  degree of shift in the intervention to be used in defining the causal
#'  quantity of interest. In the case of binary interventions, this takes the
#'  form of an incremental propensity score shift, acting as a multiplier of
#'  the odds with which a given observational unit receives the intervention
#'  (EH Kennedy, 2018, JASA; <doi:10.1080/01621459.2017.1422737>).
#' @param g_learners A \code{\link[sl3]{Stack}} (or other learner class that
#'  inherits from \code{\link[sl3]{Lrnr_base}}), containing a single or set of
#'  instantiated learners from \pkg{sl3}, to be used in fitting the propensity
#'  score, i.e., g = P(A | W).
#' @param e_learners A \code{\link[sl3]{Stack}} (or other learner class that
#'  inherits from \code{\link[sl3]{Lrnr_base}}), containing a single or set of
#'  instantiated learners from \pkg{sl3}, to be used in fitting a propensity
#'  score that conditions on the mediators, i.e., e = P(A | Z, W).
#' @param m_learners A \code{\link[sl3]{Stack}} (or other learner class that
#'  inherits from \code{\link[sl3]{Lrnr_base}}), containing a single or set of
#'  instantiated learners from \pkg{sl3}, to be used in fitting the outcome
#'  regression, i.e., m(A, Z, W).
#' @param phi_learners A \code{\link[sl3]{Stack}} (or other learner class that
#'  inherits from \code{\link[sl3]{Lrnr_base}}), containing a single or set of
#'  instantiated learners from \pkg{sl3}, to be used in a regression of a
#'  pseudo-outcome on the baseline covariates, i.e.,
#'  phi(W) = E[m(A = 1, Z, W) - m(A = 0, Z, W) | W].
#' @param w_names A \code{character} vector of the names of the columns that
#'  correspond to baseline covariates (W). The input for this argument is
#'  automatically generated by a call to the wrapper function \code{medshift}.
#' @param z_names A \code{character} vector of the names of the columns that
#'  correspond to mediators (Z). The input for this argument is automatically
#'  generated by a call to the wrapper function \code{medshift}.
#'
#' @return A \code{list} with one element per value of \code{delta}, each a
#'  \code{data.table} with one row per validation-set observation and columns
#'  \code{Dy}, \code{Da}, \code{Dzw} (the EIF components) and \code{fold}
#'  (the index of this fold).
#'
#' @importFrom data.table data.table
#' @importFrom origami training validation fold_index
#'
#' @keywords internal
cv_eif <- function(fold,
                   data,
                   delta,
                   g_learners,
                   e_learners,
                   m_learners,
                   phi_learners,
                   w_names,
                   z_names) {
  # make training and validation data; the fold is passed explicitly rather
  # than being located on the call stack by origami
  train_data <- origami::training(data, fold = fold)
  valid_data <- origami::validation(data, fold = fold)

  # compute nuisance parameters eta = (g, m, e, phi)
  ## 1) fit propensity score regression g = P(A | W) for the incremental
  ##    propensity score intervention, once per value of delta
  g_out <- lapply(delta, function(delta) {
    # NOTE: even this is _repeated_ computation since delta is a multiplier
    #       ...worth fixing later if a bottleneck.
    g_est_delta <- fit_g_mech(
      data = train_data, valid_data = valid_data,
      delta = delta,
      learners = g_learners, w_names = w_names
    )
    return(g_est_delta)
  })

  ## 2) fit clever regression for treatment, conditional on mediators,
  ##    i.e., e = P(A | Z, W)
  e_out <- fit_e_mech(
    data = train_data, valid_data = valid_data,
    learners = e_learners,
    z_names = z_names, w_names = w_names
  )

  ## 3) fit outcome regression m(A, Z, W)
  m_out <- fit_m_mech(
    data = train_data, valid_data = valid_data,
    learners = m_learners,
    z_names = z_names, w_names = w_names
  )

  ## 4) difference-reduced dimension regression with pseudo-outcome, phi(W)
  phi_est <- fit_phi_mech(
    train_data = train_data, valid_data = valid_data,
    learners = phi_learners,
    m_output = m_out, w_names = w_names
  )

  # get indices of treated and control units in validation data
  idx_A1 <- which(valid_data$A == 1)
  idx_A0 <- which(valid_data$A == 0)

  # loop over each delta-shift in grid to assemble EIF components
  eif_over_delta <- lapply(seq_along(delta), function(iter) {
    # compute component Dzw from nuisance parameters
    Dzw_groupwise <- compute_Dzw(g_output = g_out[[iter]], m_output = m_out)
    D_ZW <- Dzw_groupwise$dzw_cntrl + Dzw_groupwise$dzw_treat

    # compute component Da from nuisance parameters:
    # delta * phi(W) * (A - g(1 | W)) / (delta * g(1 | W) + g(0 | W))^2
    g_pred_A1 <- g_out[[iter]]$g_est$g_pred_A1
    g_pred_A0 <- g_out[[iter]]$g_est$g_pred_A0
    Da_numerator <- delta[iter] * phi_est * (valid_data$A - g_pred_A1)
    Da_denominator <- (delta[iter] * g_pred_A1 + g_pred_A0)^2
    D_A <- Da_numerator / Da_denominator

    # compute component Dy from nuisance parameters
    ipw_groupwise <- compute_ipw(
      g_output = g_out[[iter]], e_output = e_out,
      idx_treat = idx_A1, idx_cntrl = idx_A0
    )

    # outcome regression evaluated at each unit's observed treatment; units
    # with A not in {0, 1} keep a numeric NA
    m_pred_obs <- rep(NA_real_, nrow(valid_data))
    m_pred_obs[idx_A1] <- m_out$m_est$m_pred_A1[idx_A1]
    m_pred_obs[idx_A0] <- m_out$m_est$m_pred_A0[idx_A0]

    # stabilize weights in AIPW by dividing by sample average, n.b., E[g/e] = 1
    mean_aipw <- ipw_groupwise$mean_aipw
    g_shifted <- ipw_groupwise$g_shifted
    e_pred <- ipw_groupwise$e_pred
    D_Y <- ((g_shifted / e_pred) / mean_aipw) * (valid_data$Y - m_pred_obs)

    # output table of EIF results for given shift; the fold index lets the
    # caller map rows back to the original data
    eif_out <- data.table::data.table(
      Dy = D_Y, Da = D_A, Dzw = D_ZW,
      fold = origami::fold_index(fold = fold)
    )
    return(eif_out)
  })

  # output
  return(eif_over_delta)
}
nhejazi/medshift documentation built on Feb. 8, 2022, 10:55 p.m.