R/fit_mechanisms.R

Defines functions fit_phi_mech fit_m_mech fit_e_mech fit_g_mech

Documented in fit_e_mech fit_g_mech fit_m_mech fit_phi_mech

# Symbols used through data.table non-standard evaluation (e.g. `A := 1` and
# `train_data[, ..w_names]`) are registered here so that R CMD check does not
# report "no visible binding for global variable" NOTEs for them.
utils::globalVariables(names = c("A", ".N", "..w_names"))

#' Fit propensity score with incremental stochastic shift intervention
#'
#' @param data A \code{data.table} containing the observed data, with columns
#'  in the order specified by the NPSEM (Y, Z, A, W), with column names set
#'  appropriately based on the original input data. Such a structure is merely
#'  a convenience utility to passing data around to the various core estimation
#'  routines and is automatically generated by \code{\link{medshift}}.
#' @param valid_data A holdout data set, with columns exactly matching those
#'  appearing in the preceding argument \code{data}, to be used for estimation
#'  via cross-fitting. Optional, defaulting to \code{NULL}.
#' @param delta A \code{numeric} value indicating the degree of shift in the
#'  intervention to be used in defining the causal quantity of interest. In the
#'  case of binary interventions, this takes the form of an incremental
#'  propensity score shift, acting as a multiplier of the odds with which a
#'  given observational unit receives the intervention (EH Kennedy, 2018, JASA;
#'  <doi:10.1080/01621459.2017.1422737>). Must be a single finite, strictly
#'  positive number.
#' @param learners A \code{\link[sl3]{Stack}} (or other learner class that
#'  inherits from \code{\link[sl3]{Lrnr_base}}), containing a single or set of
#'  instantiated learners from \pkg{sl3}, to be used in fitting the propensity
#'  score, i.e., g = P(A | W).
#' @param w_names A \code{character} vector of the names of the columns that
#'  correspond to baseline covariates (W). The input for this argument is
#'  automatically generated by \code{\link{medshift}}.
#'
#' @return A \code{list} with two elements. \code{g_est} is a
#'  \code{data.table} with columns \code{g_pred_A1}, \code{g_pred_A0},
#'  \code{g_pred_shifted_A1} and \code{g_pred_shifted_A0}. These hold the
#'  natural and delta-shifted propensity score predictions for the prediction
#'  data (\code{valid_data} if supplied, otherwise \code{data}), after
#'  \code{bound_precision} and \code{bound_propensity} have been applied.
#'  \code{g_fit} is the trained learner object.
#'
#' @importFrom data.table as.data.table copy ":="
#' @importFrom sl3 sl3_Task
fit_g_mech <- function(data, valid_data = NULL,
                       delta, learners, w_names) {
  # the incremental shift multiplies the odds of treatment by delta, so it is
  # only defined for a finite, strictly positive scalar; other values yield NaN
  # or out-of-range "probabilities" that the bounding below would silently mask
  if (!is.numeric(delta) || length(delta) != 1L || !is.finite(delta) ||
      delta <= 0) {
    stop("`delta` must be a single finite, strictly positive number.",
         call. = FALSE)
  }

  # construct task for propensity score fit
  g_task <- sl3::sl3_Task$new(
    data = data,
    covariates = w_names,
    outcome = "A",
    id = "ids"
  )

  # fit the propensity score model on the training data
  g_fit_stack <- learners$train(g_task)

  # predict on the holdout set when cross-fitting, otherwise on the full data
  if (is.null(valid_data)) {
    data_pred <- data.table::copy(data)
  } else {
    data_pred <- data.table::copy(valid_data)
  }

  # create task for estimating propensity score P(A = 1 | W)
  g_task_pred <- sl3::sl3_Task$new(
    data = data_pred,
    covariates = w_names,
    outcome = "A",
    id = "ids"
  )
  g_pred_A1 <- g_fit_stack$predict(g_task_pred)

  # compute A = 0 case by symmetry
  g_pred_A0 <- 1 - g_pred_A1

  # directly compute the shifted propensity score: the odds of A = 1 are
  # multiplied by delta, i.e., delta * g / (delta * g + (1 - g))
  g_pred_shifted_A1 <- (delta * g_pred_A1) /
    (delta * g_pred_A1 + (1 - g_pred_A1))

  # compute shifted propensity score for A = 0 by symmetry
  g_pred_shifted_A0 <- 1 - g_pred_shifted_A1

  # bounding to numerical precision
  g_pred_A1 <- bound_precision(g_pred_A1)
  g_pred_A0 <- bound_precision(g_pred_A0)
  g_pred_shifted_A1 <- bound_precision(g_pred_shifted_A1)
  g_pred_shifted_A0 <- bound_precision(g_pred_shifted_A0)

  # bounding for potential positivity issues
  g_pred_A1 <- bound_propensity(g_pred_A1)
  g_pred_A0 <- bound_propensity(g_pred_A0)
  g_pred_shifted_A1 <- bound_propensity(g_pred_shifted_A1)
  g_pred_shifted_A0 <- bound_propensity(g_pred_shifted_A0)

  list(
    g_est = data.table::data.table(cbind(
      g_pred_A1 = g_pred_A1,
      g_pred_A0 = g_pred_A0,
      g_pred_shifted_A1 = g_pred_shifted_A1,
      g_pred_shifted_A0 = g_pred_shifted_A0
    )),
    g_fit = g_fit_stack
  )
}

###############################################################################

#' Fit propensity score regression while conditioning on mediators
#'
#' @param data A \code{data.table} containing the observed data, with columns
#'  in the order specified by the NPSEM (Y, Z, A, W), with column names set
#'  appropriately based on the original input data. Such a structure is merely
#'  a convenience utility to passing data around to the various core estimation
#'  routines and is automatically generated by \code{\link{medshift}}.
#' @param valid_data A holdout data set, with columns exactly matching those
#'  appearing in the preceding argument \code{data}, to be used for estimation
#'  via cross-fitting. Optional, defaulting to \code{NULL}.
#' @param learners A \code{\link[sl3]{Stack}} (or other learner class that
#'  inherits from \code{\link[sl3]{Lrnr_base}}), containing a single or set of
#'  instantiated learners from \pkg{sl3}, to be used in fitting a propensity
#'  score that conditions on the mediators, i.e., e = P(A | Z, W).
#' @param w_names A \code{character} vector of the names of the columns that
#'  correspond to baseline covariates (W). The input for this argument is
#'  automatically generated by a call to the wrapper function \code{medshift}.
#' @param z_names A \code{character} vector of the names of the columns that
#'  correspond to mediators (Z). The input for this argument is automatically
#'  generated by \code{\link{medshift}}.
#'
#' @return A \code{list} with two elements. \code{e_est} is a
#'  \code{data.table} with columns \code{e_pred_A1} and \code{e_pred_A0},
#'  holding the bounded predictions of P(A = 1 | Z, W) and its complement on
#'  \code{valid_data} (or on \code{data} when no holdout is given).
#'  \code{e_fit} is the trained learner object.
#'
#' @importFrom data.table as.data.table copy
#' @importFrom sl3 sl3_Task
fit_e_mech <- function(data, valid_data = NULL,
                       learners, z_names, w_names) {
  # mediators and baseline covariates jointly form the predictor set
  e_covars <- c(z_names, w_names)

  # learn e(Z, W) = P(A = 1 | Z, W) on the training data
  e_fit_stack <- learners$train(
    sl3::sl3_Task$new(
      data = data,
      covariates = e_covars,
      outcome = "A",
      id = "ids"
    )
  )

  # predictions go on the holdout set when cross-fitting, else on all of data
  pred_source <- if (is.null(valid_data)) data else valid_data
  e_task_pred <- sl3::sl3_Task$new(
    data = data.table::copy(pred_source),
    covariates = e_covars,
    outcome = "A",
    id = "ids"
  )
  e_pred_A1 <- e_fit_stack$predict(e_task_pred)

  # P(A = 0 | Z, W) is the complement of the raw prediction; each is bounded
  # to numerical precision first, then via bound_propensity() to guard
  # against positivity issues
  e_pred_A1_bounded <- bound_propensity(bound_precision(e_pred_A1))
  e_pred_A0_bounded <- bound_propensity(bound_precision(1 - e_pred_A1))

  list(
    e_est = data.table::data.table(cbind(
      e_pred_A1 = e_pred_A1_bounded,
      e_pred_A0 = e_pred_A0_bounded
    )),
    e_fit = e_fit_stack
  )
}

###############################################################################

#' Fit outcome regression
#'
#' @param data A \code{data.table} containing the observed data, with columns
#'  in the order specified by the NPSEM (Y, Z, A, W), with column names set
#'  appropriately based on the original input data. Such a structure is merely
#'  a convenience utility to passing data around to the various core estimation
#'  routines and is automatically generated by \code{\link{medshift}}.
#' @param valid_data A holdout data set, with columns exactly matching those
#'  appearing in the preceding argument \code{data}, to be used for estimation
#'  via cross-fitting. Optional, defaulting to \code{NULL}.
#' @param learners A \code{\link[sl3]{Stack}} (or other learner class that
#'  inherits from \code{\link[sl3]{Lrnr_base}}), containing a single or set of
#'  instantiated learners from \pkg{sl3}, to be used in fitting the outcome
#'  regression, i.e., m(A, Z, W).
#' @param w_names A \code{character} vector of the names of the columns that
#'  correspond to baseline covariates (W). The input for this argument is
#'  automatically generated by \code{\link{medshift}}.
#' @param z_names A \code{character} vector of the names of the columns that
#'  correspond to mediators (Z). The input for this argument is automatically
#'  generated by \code{\link{medshift}}.
#'
#' @return A \code{list} with three elements.
#'  \itemize{
#'   \item \code{m_est}: a \code{data.table} with columns \code{m_pred_A1} and
#'    \code{m_pred_A0}, the predictions of m(A, Z, W) with A set to 1 and 0 on
#'    \code{valid_data} (or on \code{data} when no holdout is given).
#'   \item \code{m_est_training}: the same two columns computed on \code{data};
#'    these are all \code{NA} when \code{valid_data} is \code{NULL}.
#'   \item \code{m_fit_sl}: the trained learner object.
#'  }
#'
#' @importFrom data.table as.data.table copy
#' @importFrom sl3 sl3_Task
fit_m_mech <- function(data, valid_data = NULL,
                       learners, z_names, w_names) {
  m_covars <- c("A", z_names, w_names)

  # construct task for the outcome regression fit and train the learners
  m_task <- sl3::sl3_Task$new(
    data = data,
    covariates = m_covars,
    outcome = "Y",
    id = "ids"
  )
  m_fit_stack <- learners$train(m_task)

  # predict m(a, Z, W) on a copy of `dat` with the intervention set to A = a;
  # copying keeps the caller's data.table from being modified by reference,
  # and set() avoids NSE so `a` can never be captured by a column of `dat`
  predict_at_a <- function(dat, a) {
    dat_a <- data.table::copy(dat)
    data.table::set(dat_a, j = "A", value = a)
    task_a <- sl3::sl3_Task$new(
      data = dat_a,
      covariates = m_covars,
      outcome = "Y",
      id = "ids"
    )
    m_fit_stack$predict(task_a)
  }

  if (is.null(valid_data)) {
    # no holdout: predict on the full data; training-set predictions are not
    # produced separately and are reported as NA
    pred_data <- data
    m_pred_train_A1 <- rep(NA, nrow(data))
    m_pred_train_A0 <- rep(NA, nrow(data))
  } else {
    pred_data <- valid_data
    # NOTE: to fit nuisance regression phi, we need estimates on training set
    # to construct the relevant pseudo-outcome, i.e., m(A=1,Z,W) - m(A=0,Z,W)
    m_pred_train_A1 <- predict_at_a(data, 1)
    m_pred_train_A0 <- predict_at_a(data, 0)
  }

  # counterfactual predictions under A = 1 and A = 0 on the prediction data
  m_pred_A1 <- predict_at_a(pred_data, 1)
  m_pred_A0 <- predict_at_a(pred_data, 0)

  list(
    m_est = data.table::data.table(cbind(
      m_pred_A1 = m_pred_A1,
      m_pred_A0 = m_pred_A0
    )),
    m_est_training = data.table::data.table(cbind(
      m_pred_A1 = m_pred_train_A1,
      m_pred_A0 = m_pred_train_A0
    )),
    m_fit_sl = m_fit_stack
  )
}

###############################################################################

#' Fit intervention-specific exponential tilt nuisance parameter
#'
#' @param train_data A \code{data.table} containing the observed data, with
#'  columns in the order specified by the NPSEM (Y, Z, A, W), with column names
#'  set appropriately based on the input data. Such a structure is merely a
#'  convenience utility to passing data around to the various core estimation
#'  routines and is automatically generated by \code{\link{medshift}}.
#' @param valid_data A holdout data set, with columns exactly matching those
#'  appearing in the preceding argument \code{train_data}, to be used for
#'  estimation via cross-fitting. Not optional for this nuisance parameter.
#' @param learners A \code{\link[sl3]{Stack}} (or other learner class that
#'  inherits from \code{\link[sl3]{Lrnr_base}}), containing a single or set of
#'  instantiated learners from \pkg{sl3}, to be used in a regression of a
#'  pseudo-outcome on the baseline covariates, i.e.,
#'  phi(W) = E[m(A = 1, Z, W) - m(A = 0, Z, W) | W].
#' @param m_output Object containing results from fitting the outcome
#'  regression, as produced by \code{\link{fit_m_mech}}. Its
#'  \code{m_est_training} element must contain training-set predictions, which
#'  requires that \code{fit_m_mech} was called with \code{valid_data}.
#' @param w_names A \code{character} vector of the names of the columns that
#'  correspond to baseline covariates (W). The input for this argument is
#'  automatically generated by \code{\link{medshift}}.
#'
#' @return The predictions of phi(W) for the observations in
#'  \code{valid_data}, as returned by the trained learner's \code{predict}
#'  method.
#'
#' @importFrom data.table data.table as.data.table
#' @importFrom sl3 sl3_Task
fit_phi_mech <- function(train_data, valid_data, learners, m_output,
                         w_names) {
  # phi is learned on the training split and predicted on the holdout split,
  # so a holdout set is required
  if (is.null(valid_data)) {
    stop("`valid_data` must be supplied to fit the nuisance parameter phi.",
         call. = FALSE)
  }

  # regression on pseudo-outcome for this nuisance parameter
  # NOTE: first, learn the regression model using the training data
  m_pred_train_A1 <- m_output$m_est_training$m_pred_A1
  m_pred_train_A0 <- m_output$m_est_training$m_pred_A0
  m_pred_train_diff <- m_pred_train_A1 - m_pred_train_A0

  # fit_m_mech() fills the training predictions with NA when it is run without
  # a holdout set; fail early instead of inside the learner's training step
  if (anyNA(m_pred_train_diff)) {
    stop("Training-set outcome regression predictions contain missing ",
         "values; `fit_m_mech()` must be called with `valid_data` supplied.",
         call. = FALSE)
  }

  # construct data structure for use with task objects
  phi_train_data <- data.table::data.table(
    m_diff = m_pred_train_diff,
    train_data[, ..w_names],
    ids = train_data[["ids"]]
  )
  phi_train_task <- sl3::sl3_Task$new(
    data = phi_train_data,
    covariates = w_names,
    outcome = "m_diff",
    outcome_type = "continuous",
    id = "ids"
  )

  # fit stack of learners to learn the regression model for phi
  phi_fit <- learners$train(phi_train_task)

  # NOW, predict on the validation data
  # NOTE: first, as before, must construct the pseudo-outcome
  m_pred_valid_A1 <- m_output$m_est$m_pred_A1
  m_pred_valid_A0 <- m_output$m_est$m_pred_A0
  m_pred_valid_diff <- m_pred_valid_A1 - m_pred_valid_A0

  # construct data structure for use with task objects
  phi_valid_data <- data.table::data.table(
    m_diff = m_pred_valid_diff,
    valid_data[, ..w_names],
    ids = valid_data[["ids"]]
  )
  phi_valid_task <- sl3::sl3_Task$new(
    data = phi_valid_data,
    covariates = w_names,
    outcome = "m_diff",
    outcome_type = "continuous",
    id = "ids"
  )

  # predict and return for validation set only
  phi_fit$predict(phi_valid_task)
}
nhejazi/medshift documentation built on Feb. 8, 2022, 10:55 p.m.