R/fit_mechanisms.R

Defines functions fit_nuisance_d fit_nuisance_v fit_nuisance_u fit_moc_mech fit_out_mech fit_treat_mech

Documented in fit_moc_mech fit_nuisance_d fit_nuisance_u fit_nuisance_v fit_out_mech fit_treat_mech

# Register column names referenced through data.table non-standard evaluation
# (and the `..` prefix lookup) so that R CMD check does not report them as
# undefined global variables.
utils::globalVariables(
  c(
    "..w_names", "A", "Z", "R", "V_pseudo",
    "obs_weights", "two_phase_weights", "eif"
  )
)

#' Fit propensity scores for treatment contrasts
#'
#' Fits a binary classification model for the treatment \code{A} and returns
#' bounded probabilities of receiving each of the two contrast levels.
#'
#' @param train_data A \code{data.table} containing the observed data; columns
#'   are in the order specified by the NPSEM (Y, M, R, Z, A, W), with column
#'   names set appropriately based on the data. Such a structure is merely a
#'   convenience utility for passing data around to the various core
#'   estimation routines and is automatically generated by
#'   \code{\link{medoutcon}}. This object is not modified.
#' @param valid_data A holdout data set, with columns exactly matching those
#'   appearing in the preceding argument \code{train_data}, to be used for
#'   estimation via cross-fitting. Optional, defaulting to \code{NULL}. This
#'   object is not modified.
#' @param contrast A \code{numeric} double indicating the two values of the
#'   intervention \code{A} to be compared. The default value of \code{c(0, 1)}
#'   assumes a binary intervention node \code{A}.
#' @param learners \code{\link[sl3]{Stack}}, or other learner class (inheriting
#'   from \code{\link[sl3]{Lrnr_base}}), containing a set of learners from
#'   \pkg{sl3}, to be used in fitting a propensity score models, i.e., g := P(A
#'   = 1 | W) and h := P(A = 1 | M, W).
#' @param m_names A \code{character} vector of the names of the columns that
#'   correspond to mediators (M). The input for this argument is automatically
#'   generated by \code{\link{medoutcon}}.
#' @param w_names A \code{character} vector of the names of the columns that
#'   correspond to baseline covariates (W). The input for this argument is
#'   automatically generated by \code{\link{medoutcon}}.
#' @param type A \code{character} indicating which of the treatment mechanism
#'   variants to estimate. Option \code{"g"} is the propensity score g(A|W)
#'   while option \code{"h"} is a re-parameterized mediator density h(A|M,W).
#'   For \code{"h"}, the observation weights are multiplied by the two-phase
#'   sampling weights and only units with \code{R == 1} are used.
#' @param bounds A \code{numeric} vector containing two values, the first being
#'   the minimum allowable estimated propensity score value and the second
#'   being the maximum allowable for estimated propensity score value.
#'
#' @return A \code{list}. If \code{valid_data} is \code{NULL}, it contains
#'   \code{treat_est}, a \code{data.table} with columns
#'   \code{treat_pred_A_prime} and \code{treat_pred_A_star}, and the fitted
#'   learner \code{treat_fit}. Otherwise, it contains \code{treat_est_train}
#'   and \code{treat_est_valid} (both structured like \code{treat_est}) along
#'   with \code{treat_fit}.
#'
#' @importFrom data.table as.data.table copy setnames ":="
#' @importFrom sl3 sl3_Task
fit_treat_mech <- function(train_data,
                           valid_data = NULL,
                           contrast,
                           learners,
                           m_names,
                           w_names,
                           type = c("g", "h"),
                           bounds = c(0.01, 0.99)) {
  # resolve the default c("g", "h") to a single choice; comparing the raw
  # length-2 default inside if() would be an error
  type <- match.arg(type)

  if (type == "g") {
    # NOTE: estimation of treatment propensity does not require two-phase
    #       sampling weights
    cov_names <- w_names
  } else {
    cov_names <- c(m_names, w_names)

    # fold the two-phase sampling weights into the observation weights and
    # drop units not sampled in the second stage. The data are copied first
    # because `:=` modifies by reference: without the copy, the caller's
    # `obs_weights` would be overwritten (and compounded across calls).
    # NOTE: importantly, re-weighting the propensity score (g) estimator is
    #       not necessary under two-phase sampling of the mediators
    reweight_phase_two <- function(data) {
      if (is.null(data)) {
        return(NULL)
      }
      data <- data.table::copy(data)
      data[, obs_weights := two_phase_weights * obs_weights]
      data[R == 1, ]
    }
    train_data <- reweight_phase_two(train_data)
    valid_data <- reweight_phase_two(valid_data)
  }

  ## construct a treatment-mechanism task for a given data set
  make_treat_task <- function(data) {
    sl3::sl3_Task$new(
      data = data,
      weights = "obs_weights",
      covariates = cov_names,
      outcome = "A",
      outcome_type = "binomial"
    )
  }

  ## map predicted P(A = 1 | .) to bounded probabilities of the two contrasts
  contrast_probs <- function(treat_pred) {
    treat_pred_A_prime <- contrast[1] * treat_pred +
      (1 - contrast[1]) * (1 - treat_pred)
    treat_pred_A_star <- contrast[2] * treat_pred +
      (1 - contrast[2]) * (1 - treat_pred)

    ## bounding to numerical precision and for positivity considerations
    out_treat_mat <- cbind(treat_pred_A_prime, treat_pred_A_star)
    out_treat_est <- apply(out_treat_mat, 2, function(x) {
      x_precise <- bound_precision(x)
      bound_propensity(x_precise, bounds = bounds)
    })
    out_treat_est <- data.table::as.data.table(out_treat_est)
    data.table::setnames(out_treat_est, c(
      "treat_pred_A_prime",
      "treat_pred_A_star"
    ))
    out_treat_est
  }

  ## fit treatment mechanism
  treat_fit <- learners$train(make_treat_task(train_data))

  if (is.null(valid_data)) {
    ## use the training data for prediction if no validation data provided
    out <- list(
      treat_est = contrast_probs(treat_fit$predict()),
      treat_fit = treat_fit
    )
  } else {
    ## contrast-specific predictions on both training and validation data
    out <- list(
      treat_est_train = contrast_probs(
        treat_fit$predict(make_treat_task(train_data))
      ),
      treat_est_valid = contrast_probs(
        treat_fit$predict(make_treat_task(valid_data))
      ),
      treat_fit = treat_fit
    )
  }
  out
}

###############################################################################

#' Fit outcome regression
#'
#' Fits the outcome regression b(A, Z, M, W) among units sampled in the second
#' phase (\code{R == 1}). The observation weights are multiplied by the
#' two-phase sampling weights. Returns predictions under the observed treatment
#' and under each of the two contrast levels.
#'
#' @param train_data A \code{data.table} containing the observed data, with
#'  columns in the order specified by the NPSEM (Y, M, R, Z, A, W), with column
#'  names set based on the input data. Such a structure is a convenience
#'  utility for passing data around to the various core estimation routines
#'  and is automatically generated by \code{\link{medoutcon}}. This object is
#'  not modified.
#' @param valid_data A holdout data set, with columns exactly matching those
#'  appearing in the preceding argument \code{train_data}, to be used for
#'  estimation via cross-fitting. Optional, defaulting to \code{NULL}. This
#'  object is not modified.
#' @param contrast A \code{numeric} double indicating the two values of the
#'  intervention \code{A} to be compared. The default of \code{c(0, 1)} assumes
#'  a binary intervention node \code{A}.
#' @param learners \code{\link[sl3]{Stack}}, or other learner class (inheriting
#'  from \code{\link[sl3]{Lrnr_base}}), containing a set of learners from
#'  \pkg{sl3}, to be used in fitting the outcome regression, i.e., b(A,Z,M,W).
#' @param m_names A \code{character} vector of the names of the columns that
#'  correspond to mediators (M). The input for this argument is automatically
#'  generated by \code{\link{medoutcon}}.
#' @param w_names A \code{character} vector of the names of the columns that
#'  correspond to baseline covariates (W). The input for this argument is
#'  automatically generated by \code{\link{medoutcon}}.
#'
#' @return A \code{list}. If \code{valid_data} is \code{NULL}, it contains
#'  \code{b_est}, a \code{data.table} with columns \code{b_pred_A_natural},
#'  \code{b_pred_A_prime}, and \code{b_pred_A_star}, and the fitted learner
#'  \code{b_fit}. Otherwise, it contains \code{b_est_train} and
#'  \code{b_est_valid} (both structured like \code{b_est}) along with
#'  \code{b_fit}. All predictions are for units with \code{R == 1} only.
#'
#' @importFrom data.table as.data.table copy setnames ":="
#' @importFrom sl3 sl3_Task
fit_out_mech <- function(train_data,
                         valid_data = NULL,
                         contrast,
                         learners,
                         m_names,
                         w_names) {
  ## fold two-phase sampling weights into the observation weights and drop
  ## units not sampled in the second stage. The data are copied first because
  ## `:=` modifies by reference and would otherwise overwrite (and, across
  ## calls, compound) the caller's `obs_weights`. NULL is passed through so
  ## that the no-validation-data branch below remains reachable.
  reweight_phase_two <- function(data) {
    if (is.null(data)) {
      return(NULL)
    }
    data <- data.table::copy(data)
    data[, obs_weights := two_phase_weights * obs_weights]
    data[R == 1, ]
  }
  train_data <- reweight_phase_two(train_data)
  valid_data <- reweight_phase_two(valid_data)

  ## construct an outcome regression task for a given data set
  make_b_task <- function(data) {
    sl3::sl3_Task$new(
      data = data,
      weights = "obs_weights",
      covariates = c(m_names, "Z", "A", w_names),
      outcome = "Y"
    )
  }

  ## fit the outcome regression on the observed (training) data
  b_natural_fit <- learners$train(make_b_task(train_data))

  ## predictions after setting A := contrast[1] (a') and A := contrast[2] (a*)
  predict_contrasts <- function(data) {
    data_intervene <- data.table::copy(data)
    data_intervene[, A := contrast[1]]
    b_pred_A_prime <- b_natural_fit$predict(make_b_task(data_intervene))
    data_intervene[, A := contrast[2]]
    b_pred_A_star <- b_natural_fit$predict(make_b_task(data_intervene))
    list(prime = b_pred_A_prime, star = b_pred_A_star)
  }

  ## combine natural and post-intervention predictions into a data.table
  assemble_b_est <- function(data, natural_pred) {
    intervened <- predict_contrasts(data)
    out_b_est <- data.table::as.data.table(cbind(
      natural_pred,
      intervened$prime,
      intervened$star
    ))
    data.table::setnames(out_b_est, c(
      "b_pred_A_natural",
      "b_pred_A_prime",
      "b_pred_A_star"
    ))
    out_b_est
  }

  ## predictions on observed data (i.e., under observed treatment status)
  b_natural_pred_train <- b_natural_fit$predict()

  if (is.null(valid_data)) {
    ## use training data for counterfactual prediction if no validation data
    out <- list(
      b_est = assemble_b_est(train_data, b_natural_pred_train),
      b_fit = b_natural_fit
    )
  } else {
    b_natural_pred_valid <- b_natural_fit$predict(make_b_task(valid_data))
    out <- list(
      b_est_train = assemble_b_est(train_data, b_natural_pred_train),
      b_est_valid = assemble_b_est(valid_data, b_natural_pred_valid),
      b_fit = b_natural_fit
    )
  }
  out
}

###############################################################################

#' Fit intermediate confounding mechanism with(out) conditioning on mediators
#'
#' Fits a binary classification model for the intermediate confounder
#' \code{Z} given \code{A} and either W (\code{type = "q"}) or (M, W)
#' (\code{type = "r"}). Returns predictions under the observed treatment and
#' under each of the two contrast levels.
#'
#' @param train_data A \code{data.table} containing observed data, with columns
#'  in the order specified by the NPSEM (Y, M, R, Z, A, W), with column names
#'  set appropriately based on the input data. Such a structure is a
#'  convenience utility for passing data around to the various core estimation
#'  routines and is automatically generated by \code{\link{medoutcon}}. This
#'  object is not modified.
#' @param valid_data A holdout data set, with columns exactly matching those
#'  appearing in the preceding argument \code{train_data}, to be used for
#'  estimation via cross-fitting. Optional, defaulting to \code{NULL}. This
#'  object is not modified.
#' @param contrast A \code{numeric} double indicating the two values of the
#'  intervention \code{A} to be compared. The default value of \code{c(0, 1)}
#'  assumes a binary intervention node \code{A}.
#' @param learners \code{\link[sl3]{Stack}}, or other learner class (inheriting
#'  from \code{\link[sl3]{Lrnr_base}}), containing a set of learners from
#'  \pkg{sl3}, to be used in fitting a model for the intermediate confounding
#'  mechanism, i.e., q = E[Z | a', W] and r = E[Z | a', M, W].
#' @param m_names A \code{character} vector of the names of the columns that
#'  correspond to mediators (M). The input for this argument is automatically
#'  generated by a call to the wrapper function \code{\link{medoutcon}}.
#' @param w_names A \code{character} vector of the names of the columns that
#'  correspond to baseline covariates (W). The input for this argument is
#'  automatically generated by \code{\link{medoutcon}}.
#' @param type A \code{character} vector indicating whether to condition on the
#'  mediators (M) or not. Specifically, this is an option for specifying one of
#'  two types of nuisance regressions: "r" is defined as the component that
#'  conditions on the mediators (i.e., r = E[Z | a', M, W]) while "q" is
#'  defined as the component that does not (i.e., q = E[Z | a', W]). Both
#'  multiply the observation weights by the two-phase sampling weights; only
#'  "r" restricts to units with \code{R == 1}.
#'
#' @return A \code{list}. If \code{valid_data} is \code{NULL}, it contains
#'  \code{moc_est}, a \code{data.table} with columns
#'  \code{moc_pred_A_natural}, \code{moc_pred_A_prime}, and
#'  \code{moc_pred_A_star}, and the fitted learner \code{moc_fit}. Otherwise,
#'  it contains \code{moc_est_train_Z_one} and \code{moc_est_valid_Z_one}
#'  (predictions structured like \code{moc_est}), \code{moc_est_train_Z_natural}
#'  and \code{moc_est_valid_Z_natural} (the same predictions converted to the
#'  probability of the observed value of \code{Z}), and \code{moc_fit}.
#'
#' @importFrom data.table as.data.table copy setnames ":="
#' @importFrom sl3 sl3_Task
fit_moc_mech <- function(train_data,
                         valid_data = NULL,
                         contrast,
                         learners,
                         m_names,
                         w_names,
                         type = c("q", "r")) {
  # resolve the default c("q", "r") to a single choice; comparing the raw
  # length-2 default inside if() would be an error
  type <- match.arg(type)

  ## fold two-phase sampling weights into the observation weights. The data
  ## are copied first because `:=` modifies by reference and would otherwise
  ## overwrite (and, across calls, compound) the caller's `obs_weights`. NULL
  ## is passed through so the no-validation-data branch remains reachable.
  reweight <- function(data) {
    if (is.null(data)) {
      return(NULL)
    }
    data <- data.table::copy(data)
    data[, obs_weights := two_phase_weights * obs_weights]
    data
  }
  # NOTE: for type "q" the re-weighting might not be necessary, check with Nima
  train_data <- reweight(train_data)
  valid_data <- reweight(valid_data)

  if (type == "q") {
    cov_names <- w_names
  } else {
    cov_names <- c(m_names, w_names)

    # remove observations that were not sampled in second stage
    train_data <- train_data[R == 1, ]
    if (!is.null(valid_data)) {
      valid_data <- valid_data[R == 1, ]
    }
  }

  ## construct a task for the intermediate confounder on a given data set
  make_moc_task <- function(data) {
    sl3::sl3_Task$new(
      data = data,
      weights = "obs_weights",
      covariates = c("A", cov_names),
      outcome = "Z",
      outcome_type = "binomial"
    )
  }

  ## fit model on observed data
  moc_fit <- learners$train(make_moc_task(train_data))

  ## predictions after setting A := contrast[1] (a') and A := contrast[2] (a*)
  predict_contrasts <- function(data) {
    data_intervene <- data.table::copy(data)
    data_intervene[, A := contrast[1]]
    moc_pred_A_prime <- moc_fit$predict(make_moc_task(data_intervene))
    data_intervene[, A := contrast[2]]
    moc_pred_A_star <- moc_fit$predict(make_moc_task(data_intervene))
    list(prime = moc_pred_A_prime, star = moc_pred_A_star)
  }

  ## combine natural and post-intervention predictions into a data.table
  assemble_moc_est <- function(data, natural_pred) {
    intervened <- predict_contrasts(data)
    out_moc_est <- data.table::as.data.table(cbind(
      natural_pred,
      intervened$prime,
      intervened$star
    ))
    data.table::setnames(out_moc_est, c(
      "moc_pred_A_natural",
      "moc_pred_A_prime",
      "moc_pred_A_star"
    ))
    out_moc_est
  }

  ## convert predicted P(Z = 1 | .) into the probability of the observed Z
  at_observed_z <- function(moc_est, z) {
    moc_est * z + (1 - moc_est) * (1 - z)
  }

  ## predictions on observed data (i.e., under observed treatment status)
  moc_pred_A_natural_train <- moc_fit$predict()

  if (is.null(valid_data)) {
    ## use training data for counterfactual prediction if no validation data
    out <- list(
      moc_est = assemble_moc_est(train_data, moc_pred_A_natural_train),
      moc_fit = moc_fit
    )
  } else {
    ## prediction on observed data, in validation set
    moc_pred_A_natural_valid <- moc_fit$predict(make_moc_task(valid_data))

    moc_est_train <- assemble_moc_est(train_data, moc_pred_A_natural_train)
    moc_est_valid <- assemble_moc_est(valid_data, moc_pred_A_natural_valid)
    out <- list(
      moc_est_train_Z_one = moc_est_train,
      moc_est_valid_Z_one = moc_est_valid,
      moc_est_train_Z_natural = at_observed_z(moc_est_train, train_data$Z),
      moc_est_valid_Z_natural = at_observed_z(moc_est_valid, valid_data$Z),
      moc_fit = moc_fit
    )
  }
  out
}

###############################################################################

#' Fit pseudo-outcome regression conditioning on mediator-outcome confounder
#'
#' Builds the pseudo-outcome b(a', Z, M, W) times a ratio of nuisance
#' estimates among units with \code{R == 1}. It regresses this on (Z, A, W)
#' and returns predictions on the validation and training sets.
#'
#' @param train_data A \code{data.table} containing observed data, with columns
#'  in the order specified by the NPSEM (Y, M, R, Z, A, W), with column names
#'  set appropriately based on the input data. Such a structure is a
#'  convenience utility for passing data around to the various core estimation
#'  routines and is automatically generated by \code{\link{medoutcon}}. This
#'  object is not modified.
#' @param valid_data A holdout data set, with columns exactly matching those
#'  appearing in the preceding argument \code{train_data}, to be used for
#'  estimation via cross-fitting. NOT optional for this nuisance parameter.
#'  This object is not modified.
#' @param learners \code{\link[sl3]{Stack}}, or other learner class (inheriting
#'  from \code{\link[sl3]{Lrnr_base}}), containing a set of learners from
#'  \pkg{sl3}, to be used in fitting a model for this nuisance parameter.
#' @param b_out Output from the internal function for fitting the outcome
#'  regression \code{\link{fit_out_mech}}.
#' @param q_out Output from the internal function for fitting the mechanism of
#'  the intermediate confounder without conditioning on mediators, i.e.,
#'  \code{\link{fit_moc_mech}}, setting \code{type = "q"}.
#' @param r_out Output from the internal function for fitting the mechanism of
#'  the intermediate confounder while conditioning on mediators, i.e.,
#'  \code{\link{fit_moc_mech}}, setting \code{type = "r"}.
#' @param g_out Output from the internal function for fitting the treatment
#'  mechanism without conditioning on mediators \code{\link{fit_treat_mech}}.
#' @param h_out Output from the internal function for fitting the treatment
#'  mechanism conditioning on the mediators \code{\link{fit_treat_mech}}.
#' @param w_names A \code{character} vector of the names of the columns that
#'  correspond to baseline covariates (W). The input for this argument is
#'  automatically generated by \code{\link{medoutcon}}.
#'
#' @return A \code{list} with the fitted learner \code{u_fit}, and numeric
#'  predictions on the validation (\code{u_pred}) and training
#'  (\code{u_train_pred}) sets, restricted to units with \code{R == 1}.
#'
#' @importFrom data.table as.data.table copy setnames ":="
#' @importFrom sl3 sl3_Task Lrnr_mean
#' @importFrom stats sd
fit_nuisance_u <- function(train_data,
                           valid_data,
                           learners,
                           b_out,
                           q_out,
                           r_out,
                           g_out,
                           h_out,
                           w_names) {
  # update observation weights with two-phase sampling weights, if necessary.
  # The data are copied first because `:=` modifies by reference and would
  # otherwise overwrite (and, across calls, compound) the caller's weights.
  train_data <- data.table::copy(train_data)
  valid_data <- data.table::copy(valid_data)
  train_data[, obs_weights := two_phase_weights * obs_weights]
  valid_data[, obs_weights := two_phase_weights * obs_weights]

  ## extract nuisance estimates necessary for constructing pseudo-outcome;
  ## g and q were estimated on all units, so they are subset to R == 1 here,
  ## while b, h, and r were estimated on R == 1 units only
  b_prime <- b_out$b_est_train$b_pred_A_prime
  h_star <- h_out$treat_est_train$treat_pred_A_star
  g_star <- g_out$treat_est_train$treat_pred_A_star[train_data$R == 1]
  h_prime <- h_out$treat_est_train$treat_pred_A_prime
  g_prime <- g_out$treat_est_train$treat_pred_A_prime[train_data$R == 1]
  q_prime_Z_natural <-
    q_out$moc_est_train_Z_natural$moc_pred_A_prime[train_data$R == 1]
  r_prime_Z_natural <- r_out$moc_est_train_Z_natural$moc_pred_A_prime

  # remove observations that were not sampled in second stage
  train_data <- train_data[R == 1, ]
  valid_data <- valid_data[R == 1, ]

  ## create multiplier for pseudo-outcome and then pseudo-outcome
  c_star <- (g_prime / g_star) * (q_prime_Z_natural / r_prime_Z_natural) *
    (h_star / h_prime)
  u_pseudo_train <- b_prime * c_star

  ## override choice of learner with intercept model if constant
  if (stats::sd(u_pseudo_train) < .Machine$double.eps) {
    warning("U: constant pseudo-outcome, using intercept model.")
    learners <- sl3::Lrnr_mean$new()
  }

  ## construct data set and training task
  u_data_train <- data.table::as.data.table(cbind(
    train_data[, ..w_names],
    train_data$A, train_data$Z,
    u_pseudo_train,
    train_data$obs_weights
  ))
  data.table::setnames(u_data_train, c(
    w_names, "A", "Z", "U_pseudo",
    "obs_weights"
  ))
  suppressWarnings(
    u_task_train <- sl3::sl3_Task$new(
      data = u_data_train,
      weights = "obs_weights",
      covariates = c("Z", "A", w_names),
      outcome = "U_pseudo",
      outcome_type = "continuous"
    )
  )

  ## fit model for nuisance parameter regression on training data
  u_param_fit <- learners$train(u_task_train)

  ## construct data set and validation task for prediction; the outcome
  ## column is a zero placeholder since only predictions are needed
  u_data_valid <- data.table::as.data.table(cbind(
    valid_data[, ..w_names],
    valid_data$A, valid_data$Z,
    rep(0, nrow(valid_data)),
    valid_data$obs_weights
  ))
  data.table::setnames(u_data_valid, c(
    w_names, "A", "Z", "U_pseudo",
    "obs_weights"
  ))
  suppressWarnings(
    u_task_valid <- sl3::sl3_Task$new(
      data = u_data_valid,
      weights = "obs_weights",
      covariates = c("Z", "A", w_names),
      outcome = "U_pseudo",
      outcome_type = "continuous"
    )
  )

  ## predict from nuisance parameter regression on validation and training data
  u_valid_pred <- u_param_fit$predict(u_task_valid)
  u_train_pred <- u_param_fit$predict(u_task_train)

  ## return fit and predictions on validation and training sets
  list(
    u_fit = u_param_fit,
    u_pred = as.numeric(u_valid_pred),
    u_train_pred = as.numeric(u_train_pred)
  )
}

###############################################################################

#' Fit pseudo-outcome regression conditioning on treatment and baseline
#'
#' Builds the pseudo-outcome by summing b(a', z, M, W) * P(Z = z | a', W) over
#' the observed values z of the intermediate confounder. It regresses this on
#' (A, W) among units with \code{R == 1} and predicts on the validation set
#' with \code{A} set to \code{contrast[2]}.
#'
#' @param train_data A \code{data.table} containing observed data, with columns
#'  in the order specified by the NPSEM (Y, M, R, Z, A, W), with column names
#'  set appropriately based on the input data. Such a structure is a
#'  convenience utility for passing data around to the various core estimation
#'  routines and is automatically generated by \code{\link{medoutcon}}.
#' @param valid_data A holdout data set, with columns exactly matching those
#'  appearing in the preceding argument \code{train_data}, to be used for
#'  estimation via cross-fitting. Not optional for this nuisance parameter.
#' @param contrast A \code{numeric} double indicating the two values of the
#'  intervention \code{A} to be compared. The default value of \code{c(0, 1)}
#'  assumes a binary intervention node \code{A}.
#' @param learners \code{\link[sl3]{Stack}}, or other learner class (inheriting
#'  from \code{\link[sl3]{Lrnr_base}}), containing a set of learners from
#'  \pkg{sl3}, to be used in fitting a model for this nuisance parameter.
#' @param b_out Output from the internal function for fitting the outcome
#'  regression \code{\link{fit_out_mech}}.
#' @param q_out Output from the internal function for fitting the mechanism of
#'  the intermediate confounder without conditioning on the mediators, i.e.,
#'  \code{\link{fit_moc_mech}}, setting \code{type = "q"}.
#' @param m_names A \code{character} vector of the names of the columns that
#'  correspond to mediators (M). The input for this argument is automatically
#'  generated by a call to the wrapper function \code{\link{medoutcon}}.
#' @param w_names A \code{character} vector of the names of the columns that
#'  correspond to baseline covariates (W). The input for this argument is
#'  automatically generated by \code{\link{medoutcon}}.
#'
#' @return A \code{list} with the fitted learner \code{v_fit}, and numeric
#'  vectors \code{v_pred}, \code{v_train_pred}, \code{v_pseudo},
#'  \code{v_pseudo_train}, \code{b_A_prime_Z_zero}, and
#'  \code{b_A_prime_Z_one}, all restricted to units with \code{R == 1}.
#'
#' @importFrom data.table as.data.table copy setnames ":="
#' @importFrom sl3 sl3_Task Lrnr_mean
#' @importFrom stats sd
fit_nuisance_v <- function(train_data,
                           valid_data,
                           contrast,
                           learners,
                           b_out,
                           q_out,
                           m_names,
                           w_names) {
  ## extract nuisance estimates necessary for this routine
  q_train_prime_Z_one <-
    q_out$moc_est_train_Z_one$moc_pred_A_prime[train_data$R == 1]
  q_valid_prime_Z_one <-
    q_out$moc_est_valid_Z_one$moc_pred_A_prime[valid_data$R == 1]

  # remove observations that were not sampled in second stage
  train_data <- train_data[R == 1, ]
  valid_data <- valid_data[R == 1, ]

  ## values of the mediator-outcome confounder to integrate over. These are
  ## sorted because unique() returns values in order of first appearance,
  ## whereas the code below relies on Z = 0 preceding Z = 1 when labelling
  ## the outcome-model predictions for the TMLE fluctuation.
  z_vals <- sort(unique(train_data$Z))
  is_binary_z <- length(z_vals) > 1

  ## first, compute components of integral over mediator-outcome confounder
  ## assuming Z in {0,1} for interventional effects. NOTE: other cases (e.g.,
  ## continuous intermediate confounder) not yet supported. For the natural
  ## (in)direct effects, this will loop only over Z = 1.
  v_pseudo <- lapply(z_vals, function(z_val) {
    ## training data
    train_data_z_interv <- data.table::copy(train_data)
    train_data_z_interv[, obs_weights := two_phase_weights * obs_weights]
    train_data_z_interv[, `:=`(
      Z = z_val,
      A = contrast[1]
    )]

    ## tasks for predicting from trained b and q regression models
    b_reg_train_v_subtask <- sl3::sl3_Task$new(
      data = train_data_z_interv,
      weights = "obs_weights",
      covariates = c(m_names, "Z", "A", w_names),
      outcome = "Y"
    )

    ## outcome regression after intervening on mediator-outcome confounder
    b_pred_train_z_interv <- b_out$b_fit$predict(b_reg_train_v_subtask)
    q_train_prime_z_val <- (z_val * q_train_prime_Z_one) +
      (1 - z_val) * (1 - q_train_prime_Z_one)

    ## now on validation set
    valid_data_z_interv <- data.table::copy(valid_data)
    valid_data_z_interv[, obs_weights := two_phase_weights * obs_weights]
    valid_data_z_interv[, `:=`(
      Z = z_val,
      A = contrast[1]
    )]

    ## tasks for predicting from trained b and q regression models
    b_reg_valid_v_subtask <- sl3::sl3_Task$new(
      data = valid_data_z_interv,
      weights = "obs_weights",
      covariates = c(m_names, "Z", "A", w_names),
      outcome = "Y"
    )

    ## outcome regression after intervening on mediator-outcome confounder
    b_pred_valid_z_interv <- b_out$b_fit$predict(b_reg_valid_v_subtask)
    q_valid_prime_z_val <- (z_val * q_valid_prime_Z_one) +
      (1 - z_val) * (1 - q_valid_prime_Z_one)

    ## return partial pseudo-outcome for v nuisance regression
    list(
      training = b_pred_train_z_interv * q_train_prime_z_val,
      validation = b_pred_valid_z_interv * q_valid_prime_z_val,
      b_train = b_pred_train_z_interv,
      b_valid = b_pred_valid_z_interv
    )
  })

  if (is_binary_z) {
    ## for the interventional (in)direct effects with binary Z: compute the
    ## integral via discrete summation over Z = 0 ([[1]]) and Z = 1 ([[2]])
    v_pseudo_train <- v_pseudo[[1]]$training + v_pseudo[[2]]$training
    v_pseudo_valid <- v_pseudo[[1]]$validation + v_pseudo[[2]]$validation

    ## outcome model predictions with intervened Z for TMLE fluctuation
    b_pred_A_prime_Z_zero <- v_pseudo[[1]]$b_valid
    b_pred_A_prime_Z_one <- v_pseudo[[2]]$b_valid
  } else {
    ## for the natural (in)direct effects with "constant" Z
    v_pseudo_train <- v_pseudo[[1]]$training
    v_pseudo_valid <- v_pseudo[[1]]$validation

    ## NOTE: used in a TMLE fluctuation step later, so by setting the estimate
    ##       to zero under the Z = 0 contrast, we can avoid redundant summation
    b_pred_A_prime_Z_zero <- rep(0, nrow(valid_data))
    b_pred_A_prime_Z_one <- v_pseudo[[1]]$b_valid
  }

  ## override choice of learner with intercept model if constant
  if (stats::sd(v_pseudo_train) < .Machine$double.eps) {
    warning("V: constant pseudo-outcome, using intercept model.")
    learners <- sl3::Lrnr_mean$new()
  }

  ## build regression tasks for training and validation sets
  train_data[, V_pseudo := v_pseudo_train]
  suppressWarnings(
    v_task_train <- sl3::sl3_Task$new(
      data = train_data,
      weights = "obs_weights", # NOTE: should not include two_phase_weights
      covariates = c("A", w_names),
      outcome = "V_pseudo",
      outcome_type = "continuous"
    )
  )
  # NOTE: independent implementation from ID sets A to a* as done below
  valid_data[, `:=`(
    V_pseudo = v_pseudo_valid,
    A = contrast[2]
  )]
  suppressWarnings(
    v_task_valid <- sl3::sl3_Task$new(
      data = valid_data,
      weights = "obs_weights", # NOTE: should not include two_phase_weights
      covariates = c("A", w_names),
      outcome = "V_pseudo",
      outcome_type = "continuous"
    )
  )

  ## fit regression model for v on training task, get predictions on validation
  v_param_fit <- learners$train(v_task_train)
  v_valid_pred <- v_param_fit$predict(v_task_valid)

  ## get predictions on training data
  v_train_pred <- v_param_fit$predict(v_task_train)

  ## return fit, predictions, pseudo-outcomes, and intervened-Z predictions
  list(
    v_fit = v_param_fit,
    v_pred = as.numeric(v_valid_pred),
    v_train_pred = as.numeric(v_train_pred),
    v_pseudo = as.numeric(v_pseudo_valid),
    v_pseudo_train = as.numeric(v_pseudo_train),
    b_A_prime_Z_zero = as.numeric(b_pred_A_prime_Z_zero),
    b_A_prime_Z_one = as.numeric(b_pred_A_prime_Z_one)
  )
}

################################################################################

#' Fit estimated efficient influence function conditioning on W, A, Z, R, Y
#'
#' This function estimates the conditional efficient influence function,
#' setting the treatment to \ifelse{html}{\out{a<sup>*</sup>}}{\eqn{a^\star}}
#' and conditioning on W, A, Z, R, and Y. It is used to adjust the full-data
#' efficient influence function to account for two-phase sampling.
#'
#' @details The uncentered efficient influence function is assembled on the
#'   phase-two subsample of \code{train_data} (rows with \code{R == 1}) as the
#'   sum of three weighted residual terms: one for the outcome \code{Y}, one
#'   for the mediator-outcome confounder \code{Z}, and one for the
#'   pseudo-outcome of the \code{v} regression. In each term the inverse
#'   probability weights are rescaled to have mean one. The estimate is
#'   centered by adding the \code{v} predictions and subtracting the plug-in
#'   estimate returned by \code{est_plugin}. The centered values are then
#'   regressed on \code{(W, A, Z, Y)} with \code{learners}, and the fit is used
#'   to predict on \code{valid_data}. The \code{Z} term assumes a binary
#'   \code{Z}.
#'
#' @param train_data A \code{data.table} containing observed data, with columns
#'   in the order specified by the NPSEM (Y, M, R, Z, A, W), with column names
#'   set appropriately based on the input data. Such a structure is a
#'   convenience utility to passing data around to the various core estimation
#'   routines and is automatically generated by \code{\link{medoutcon}}.
#' @param valid_data A holdout data set, with columns exactly matching those
#'   appearing in the preceding argument \code{train_data}, to be used for
#'   estimation via cross-fitting. Not optional for this nuisance parameter.
#' @param contrast A \code{numeric} vector of length two giving the two values
#'   of the intervention \code{A} to be compared. \code{contrast[1]} is used as
#'   \ifelse{html}{\out{a'}}{\eqn{a'}} and \code{contrast[2]} as
#'   \ifelse{html}{\out{a<sup>*</sup>}}{\eqn{a^\star}}. The default value of
#'   \code{c(0, 1)} assumes a binary intervention node \code{A}.
#' @param learners \code{\link[sl3]{Stack}}, or other learner class (inheriting
#'   from \code{\link[sl3]{Lrnr_base}}), containing a set of learners from
#'   \pkg{sl3}, to be used in fitting a model for this nuisance parameter.
#' @param b_out Output from the internal function for fitting the outcome
#'   regression \code{\link{fit_out_mech}}.
#' @param g_out Output from the internal function for fitting the treatment
#'   mechanism without conditioning on mediators \code{\link{fit_treat_mech}}.
#' @param h_out Output from the internal function for fitting the treatment
#'   mechanism conditioning on the mediators \code{\link{fit_treat_mech}}.
#' @param q_out Output from the internal function for fitting the mechanism of
#'   the intermediate confounder while conditioning on the mediators, i.e.,
#'   \code{\link{fit_moc_mech}}, setting \code{type = "q"}.
#' @param r_out Output from the internal function for fitting the mechanism of
#'   the intermediate confounder without conditioning on mediators, i.e.,
#'   \code{\link{fit_moc_mech}}, setting \code{type = "r"}.
#' @param u_out Output from the internal function for fitting the pseudo-outcome
#'   regression conditioning on the mediator-outcome confounder, i.e.,
#'   \code{\link{fit_nuisance_u}}.
#' @param v_out Output from the internal function for fitting the pseudo-outcome
#'   regression conditioning on treatment and baseline covariates, i.e.,
#'   \code{\link{fit_nuisance_v}}.
#' @param m_names A \code{character} vector of the names of the columns that
#'   correspond to mediators (M). The input for this argument is automatically
#'   generated by a call to the wrapper function \code{\link{medoutcon}}.
#'   Currently not referenced by this function; kept for a uniform interface.
#' @param w_names A \code{character} vector of the names of the columns that
#'   correspond to baseline covariates (W). The input for this argument is
#'   automatically generated by \code{\link{medoutcon}}.
#'
#' @return A \code{list} with elements \code{d_fit}, the trained \pkg{sl3}
#'   model for the conditional efficient influence function, and
#'   \code{d_pred}, a \code{numeric} vector of its predictions on
#'   \code{valid_data}.
#'
#' @importFrom data.table copy ":="
#' @importFrom sl3 sl3_Task
#'
#' @keywords internal
fit_nuisance_d <- function(train_data,
                           valid_data,
                           contrast,
                           learners,
                           b_out,
                           g_out,
                           h_out,
                           q_out,
                           r_out,
                           u_out,
                           v_out,
                           m_names,
                           w_names) {
  ## phase-two (R == 1) indicator, computed once and reused for vector subsets
  phase_two <- train_data$R == 1

  ## phase-two subset of the training data; a data.table subset is already a
  ## new table, so there is no need to deep-copy the full training data first
  train_data_phase_two <- train_data[R == 1, ]
  a_obs <- train_data_phase_two[["A"]]
  z_obs <- train_data_phase_two[["Z"]]
  y_obs <- train_data_phase_two[["Y"]]

  ## extract nuisance estimates necessary for constructing pseudo-outcome
  ## NOTE: only the g and q predictions are subset to phase two here; the
  ##       others are used as-is (presumably already restricted to R == 1 by
  ##       their fitting routines -- confirm against those functions)
  b_prime <- b_out$b_est_train$b_pred_A_prime
  h_star <- h_out$treat_est_train$treat_pred_A_star
  g_star <- g_out$treat_est_train$treat_pred_A_star[phase_two]
  h_prime <- h_out$treat_est_train$treat_pred_A_prime
  g_prime <- g_out$treat_est_train$treat_pred_A_prime[phase_two]
  u_prime <- u_out$u_train_pred
  v_star <- v_out$v_train_pred
  q_prime_Z_one <-
    q_out$moc_est_train_Z_one$moc_pred_A_prime[phase_two]
  q_prime_Z_natural <-
    q_out$moc_est_train_Z_natural$moc_pred_A_prime[phase_two]
  r_prime_Z_natural <- r_out$moc_est_train_Z_natural$moc_pred_A_prime

  ## predictions of u(z, a', w) on the phase-two data at z = 1 and z = 0
  ## NOTE: assuming Z in {0,1}; other cases not supported yet
  u_pred_by_z <- lapply(c(1, 0), function(z_val) {
    # intervene on a private copy so the phase-two data keep observed Z and A
    train_data_z_interv <- data.table::copy(train_data_phase_two)
    train_data_z_interv[, `:=`(
      Z = z_val,
      A = contrast[1],
      U_pseudo = u_prime
    )]

    # predict u(z, a', w) using intervened data with treatment set A = a'
    suppressWarnings(
      u_task_train_z_interv <- sl3::sl3_Task$new(
        data = train_data_z_interv,
        weights = "obs_weights", # NOTE: should not include two_phase_weights
        covariates = c("Z", "A", w_names),
        outcome = "U_pseudo",
        outcome_type = "continuous"
      )
    )
    u_out[["u_fit"]]$predict(u_task_train_z_interv)
  })
  ## u(1, a', w) - u(0, a', w), which multiplies the Z residual below
  u_int_eif <- u_pred_by_z[[1]] - u_pred_by_z[[2]]

  # create inverse probability weights
  ipw_a_prime <- as.numeric(a_obs == contrast[1]) / g_prime
  ipw_a_star <- as.numeric(a_obs == contrast[2]) / g_star

  # residual term for outcome component of EIF
  c_star <- (g_prime / g_star) * (q_prime_Z_natural / r_prime_Z_natural) *
    (h_star / h_prime)

  # compute uncentered efficient influence function components; each set of
  # weights is divided by its mean so the weights average to one
  eif_y <- ipw_a_prime * c_star / mean(ipw_a_prime * c_star) *
    (y_obs - b_prime)
  eif_u <- ipw_a_prime / mean(ipw_a_prime) * u_int_eif *
    (z_obs - q_prime_Z_one)
  eif_v <- ipw_a_star / mean(ipw_a_star) * (v_out$v_pseudo_train - v_star)

  # compute the centered eif
  plugin_est <- est_plugin(v_pred = v_star)
  centered_eif <- eif_y + eif_u + eif_v + v_star - plugin_est

  # attach the centered eif to the (already independent) phase-two subset
  train_data_phase_two[, eif := centered_eif]

  # generate the sl3 task
  # NOTE: Purposefully not adding two-phase sampling weights
  suppressWarnings(
    d_task_train <- sl3::sl3_Task$new(
      data = train_data_phase_two,
      weights = "obs_weights", # NOTE: should not include two_phase_weights
      covariates = c(w_names, "A", "Z", "Y"),
      outcome = "eif",
      outcome_type = "continuous"
    )
  )

  ## fit model for nuisance parameter regression on training data
  d_param_fit <- learners$train(d_task_train)

  ## predict the efficient influence function on the validation data
  ## NOTE(review): A is taken from valid_data as supplied; this function does
  ##       not itself set A to a* -- confirm the caller has done so (e.g.,
  ##       fit_nuisance_v assigns A := contrast[2] to its valid_data by
  ##       reference)
  d_task_valid <- sl3::sl3_Task$new(
    data = valid_data,
    weights = "obs_weights",
    covariates = c(w_names, "A", "Z", "Y")
  )

  ## predict from nuisance parameter regression model on validation
  d_valid_pred <- d_param_fit$predict(d_task_valid)

  ## return prediction on validation set
  return(list(
    "d_fit" = d_param_fit,
    "d_pred" = as.numeric(d_valid_pred)
  ))
}
nhejazi/medoutcon documentation built on July 16, 2025, 5:38 p.m.