# declare symbols that are only referenced through data.table non-standard
# evaluation (e.g., `data[R == 1, two_phase_weights]`), so that R CMD check
# does not report "no visible binding for global variable" notes
utils::globalVariables(c(
  "..w_names", "A", "Z", "Y", "R", "v_star", "two_phase_weights"
))
#' EIF for natural and interventional (in)direct effects
#'
#' Fits the nuisance regressions on the training split of a single
#' cross-validation fold and evaluates the efficient influence function (EIF)
#' on the corresponding validation split.
#'
#' @param fold Object specifying cross-validation folds as generated by a call
#' to \code{\link[origami]{make_folds}}.
#' @param data_in A \code{data.table} containing the observed data, with
#' columns in the order specified by the NPSEM (Y, M, R, Z, A, W) and with
#' column names set appropriately based on the data. Such a structure is
#' merely a convenience utility to passing data around to the various core
#' estimation routines and is automatically generated by
#' \code{\link{medoutcon}}.
#' @param contrast A \code{numeric} vector of length two giving the values of
#' the intervention \code{A} to be compared: the first element is a' and the
#' second is a*, e.g., \code{c(0, 1)}.
#' @param g_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a model for the propensity score.
#' @param h_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a model for a parameterization of the
#' propensity score that conditions on the mediators.
#' @param b_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a model for the outcome regression.
#' @param q_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a model for a nuisance regression of
#' the intermediate confounder, conditioning on the treatment and potential
#' baseline covariates. Ignored (replaced by an intercept-only model) when
#' \code{effect_type = "natural"}.
#' @param r_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a model for a nuisance regression of
#' the intermediate confounder, conditioning on the mediators, the treatment,
#' and potential baseline confounders. Ignored (replaced by an intercept-only
#' model) when \code{effect_type = "natural"}.
#' @param u_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a pseudo-outcome regression required
#' for the efficient influence function.
#' @param v_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a pseudo-outcome regression required
#' for the efficient influence function.
#' @param d_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit an initial efficient influence
#' function regression when computing the efficient influence function in a
#' two-phase sampling design.
#' @param effect_type A \code{character} indicating whether components of the
#' interventional or natural (in)direct effects are to be estimated. The value
#' is matched against \code{c("interventional", "natural")} via
#' \code{\link[base]{match.arg}}, so the default is \code{"interventional"}.
#' In the case of the natural (in)direct effects, estimation of several
#' nuisance parameters is unnecessary.
#' @param w_names A \code{character} vector of the names of the columns that
#' correspond to baseline covariates (W). The input for this argument is
#' automatically generated by \code{\link{medoutcon}}.
#' @param m_names A \code{character} vector of the names of the columns that
#' correspond to mediators (M). The input for this argument is automatically
#' generated by \code{\link{medoutcon}}.
#' @param g_bounds A \code{numeric} vector containing two values, the
#' first being the minimum allowable estimated propensity score value and the
#' second being the maximum allowable for estimated propensity score value.
#'
#' @return A \code{list} with elements \code{tmle_components} (a
#' \code{data.table} of nuisance estimates for the validation observations
#' with \code{R == 1}, the un-centered EIF \code{D_star}, and the fold index),
#' \code{D_star} (the un-centered EIF, adjusted for two-phase sampling when
#' applicable), and \code{D_pred} (predictions of the centered EIF from the
#' two-phase regression, or \code{NA} when no two-phase adjustment is made).
#'
#' @importFrom assertthat assert_that
#' @importFrom data.table data.table copy
#' @importFrom origami training validation fold_index
#' @importFrom sl3 Lrnr_mean
#'
#' @keywords internal
cv_eif <- function(fold,
                   data_in,
                   contrast,
                   g_learners,
                   h_learners,
                   b_learners,
                   q_learners,
                   r_learners,
                   u_learners,
                   v_learners,
                   d_learners,
                   effect_type = c("interventional", "natural"),
                   w_names,
                   m_names,
                   g_bounds = c(0.005, 0.995)) {
  # resolve the effect type to a single string; using the length-two default
  # directly in `if ()` below is an error in R >= 4.2
  effect_type <- match.arg(effect_type)

  # make training and validation data
  train_data <- origami::training(data_in)
  valid_data <- origami::validation(data_in)

  # 1) fit regression for propensity score regression
  g_out <- fit_treat_mech(
    train_data = train_data,
    valid_data = valid_data,
    contrast = contrast,
    learners = g_learners,
    w_names = w_names,
    m_names = m_names,
    type = "g",
    bounds = g_bounds
  )

  # 2) fit clever regression for treatment, conditional on mediators
  h_out <- fit_treat_mech(
    train_data = train_data,
    valid_data = valid_data,
    contrast = contrast,
    learners = h_learners,
    w_names = w_names,
    m_names = m_names,
    type = "h",
    bounds = g_bounds
  )

  # 3) fit outcome regression
  b_out <- fit_out_mech(
    train_data = train_data,
    valid_data = valid_data,
    contrast = contrast,
    learners = b_learners,
    m_names = m_names,
    w_names = w_names
  )

  # 4) fit mediator-outcome confounder regression, excluding mediator(s)
  if (effect_type == "natural") {
    # NOTE: in this case Z := 1 in the wrapper function, so overriding the
    #       provided learner with an intercept model guarantees predictions
    #       that are returned will be uniformly 1
    q_learners <- sl3::Lrnr_mean$new()
  }
  q_out <- fit_moc_mech(
    train_data = train_data,
    valid_data = valid_data,
    contrast = contrast,
    learners = q_learners,
    m_names = m_names,
    w_names = w_names,
    type = "q"
  )

  # 5) fit mediator-outcome confounder regression, conditioning on mediator(s)
  if (effect_type == "natural") {
    # NOTE: in this case Z := 1 in the wrapper function, so overriding the
    #       provided learner with an intercept model guarantees predictions
    #       that are returned will be uniformly 1
    r_learners <- sl3::Lrnr_mean$new()
  }
  r_out <- fit_moc_mech(
    train_data = train_data,
    valid_data = valid_data,
    contrast = contrast,
    learners = r_learners,
    m_names = m_names,
    w_names = w_names,
    type = "r"
  )

  # extract components; NOTE: only do this for observations in validation set
  # (g and q predictions are subset to R == 1 explicitly here; the others are
  # used as returned by the fitting helpers)
  b_prime <- b_out$b_est_valid$b_pred_A_prime
  h_star <- h_out$treat_est_valid$treat_pred_A_star
  g_star <- g_out$treat_est_valid$treat_pred_A_star[valid_data$R == 1]
  h_prime <- h_out$treat_est_valid$treat_pred_A_prime
  g_prime <- g_out$treat_est_valid$treat_pred_A_prime[valid_data$R == 1]
  q_prime_Z_one <-
    q_out$moc_est_valid_Z_one$moc_pred_A_prime[valid_data$R == 1]
  r_prime_Z_one <- r_out$moc_est_valid_Z_one$moc_pred_A_prime
  q_prime_Z_natural <-
    q_out$moc_est_valid_Z_natural$moc_pred_A_prime[valid_data$R == 1]
  r_prime_Z_natural <- r_out$moc_est_valid_Z_natural$moc_pred_A_prime

  # need pseudo-outcome regressions with intervention set to a contrast
  # NOTE: training fits of these nuisance functions must be performed using the
  #       data corresponding to the natural intervention value but predictions
  #       are only needed for u(z,a',w) and v(a*,w) as per form of the EIF
  valid_data_a_prime <- data.table::copy(valid_data)[, A := contrast[1]]
  valid_data_a_star <- data.table::copy(valid_data)[, A := contrast[2]]
  u_out <- fit_nuisance_u(
    train_data = train_data,
    valid_data = valid_data_a_prime,
    learners = u_learners,
    b_out = b_out,
    q_out = q_out,
    r_out = r_out,
    g_out = g_out,
    h_out = h_out,
    w_names = w_names
  )
  u_prime <- u_out$u_pred
  v_out <- fit_nuisance_v(
    train_data = train_data,
    valid_data = valid_data_a_star,
    contrast = contrast,
    learners = v_learners,
    b_out = b_out,
    q_out = q_out,
    m_names = m_names,
    w_names = w_names
  )
  v_star <- v_out$v_pred

  # NOTE: assuming Z in {0,1}; other cases not supported yet
  u_int_eif <- lapply(c(1, 0), function(z_val) {
    # intervene on the validation data (R == 1 rows only)
    valid_data_z_interv <- data.table::copy(valid_data[R == 1, ])
    valid_data_z_interv[, `:=`(
      Z = z_val,
      A = contrast[1],
      U_pseudo = u_prime
    )]
    # predict u(z, a', w) using intervened data with treatment set A = a'
    # TODO(review): confirm whether obs_weights should exclude
    #               two_phase_weights here
    suppressWarnings(
      u_task_valid_z_interv <- sl3::sl3_Task$new(
        data = valid_data_z_interv,
        weights = "obs_weights",
        covariates = c("Z", "A", w_names),
        outcome = "U_pseudo",
        outcome_type = "continuous"
      )
    )
    # return predictions of u(z, a', w) on the intervened validation data
    u_out[["u_fit"]]$predict(u_task_valid_z_interv)
  })
  # difference u(1, a', w) - u(0, a', w)
  u_int_eif <- do.call(`-`, u_int_eif)

  # create inverse probability weights
  ipw_a_prime <- as.numeric(valid_data[R == 1, A] == contrast[1]) / g_prime
  ipw_a_star <- as.numeric(valid_data[R == 1, A] == contrast[2]) / g_star

  # residual term for outcome component of EIF
  c_star <- (g_prime / g_star) * (q_prime_Z_natural / r_prime_Z_natural) *
    (h_star / h_prime)

  # compute uncentered efficient influence function components
  # (weights are normalized by their empirical means)
  eif_y <- ipw_a_prime * c_star / mean(ipw_a_prime * c_star) *
    (valid_data[R == 1, Y] - b_prime)
  eif_u <- ipw_a_prime / mean(ipw_a_prime) * u_int_eif *
    (valid_data[R == 1, Z] - q_prime_Z_one)
  eif_v <- ipw_a_star / mean(ipw_a_star) * (v_out$v_pseudo - v_star)

  # SANITY CHECK: EIF_U should be ~ZERO~ for natural (in)direct effects
  if (effect_type == "natural") {
    assertthat::assert_that(all(eif_u == 0))
  }

  # un-centered efficient influence function
  eif <- eif_y + eif_u + eif_v + v_star

  # adjust un-centered EIF for two-phase sampling design
  if (!all(data_in$R == 1) || !(all(data_in$two_phase_weights == 1))) {
    # compute a centered EIF
    plugin_est <- est_plugin(v_pred = v_star)
    centered_eif <- eif - plugin_est

    # estimate the conditional EIF using the validation data
    d_out <- fit_nuisance_d(
      train_data = train_data,
      valid_data = valid_data,
      contrast = contrast,
      learners = d_learners,
      b_out = b_out,
      g_out = g_out,
      h_out = h_out,
      q_out = q_out,
      r_out = r_out,
      u_out = u_out,
      v_out = v_out,
      m_names = m_names,
      w_names = w_names
    )
    centered_eif_pred <- d_out$d_pred

    # compute the two-phase sampling un-centered EIF
    full_eif <- two_phase_eif(
      R = valid_data$R,
      two_phase_weights = valid_data$two_phase_weights,
      eif = centered_eif,
      eif_predictions = centered_eif_pred,
      plugin_est = plugin_est
    )
  } else {
    full_eif <- eif
    centered_eif_pred <- NA
  }

  # output list
  list(
    tmle_components = data.table::data.table(
      # components necessary for fluctuation step of TMLE
      g_prime = g_prime, g_star = g_star, h_prime = h_prime, h_star = h_star,
      q_prime_Z_natural = q_prime_Z_natural, q_prime_Z_one = q_prime_Z_one,
      r_prime_Z_natural = r_prime_Z_natural, r_prime_Z_one = r_prime_Z_one,
      v_star = v_star, u_int_diff = u_int_eif,
      b_prime = b_prime, b_prime_Z_zero = v_out$b_A_prime_Z_zero,
      b_prime_Z_one = v_out$b_A_prime_Z_one, D_star = eif,
      # fold IDs
      fold = origami::fold_index()
    ),
    D_star = full_eif,
    D_pred = centered_eif_pred
  )
}
###############################################################################
#' Plug-in estimator
#'
#' Computes the plug-in estimate as the empirical mean of the supplied
#' predictions of the v(a, w) nuisance parameter.
#'
#' @param v_pred A \code{numeric} vector of the predicted values of the v(a, w)
#' nuisance parameter.
#'
#' @return A \code{numeric} representing the plug-in estimate of the estimand.
#'
#' @keywords internal
est_plugin <- function(v_pred) {
  plugin_estimate <- base::mean(x = v_pred)
  plugin_estimate
}
###############################################################################
#' Two-phase sampling adjusted, un-centered efficient influence function
#'
#' Adjust the efficient influence function to account for the use of two-phase
#' sampling designs in measuring the mediators.
#'
#' @param R A \code{logical} (or 0/1 \code{numeric}) vector indicating whether
#' a sampled observation's mediators were measured in the second phase of a
#' two-phase sampling design.
#' @param two_phase_weights A \code{numeric} vector of known observation-level
#' weights corresponding to the inverse probability of the mediators being
#' measured. These weights should only be provided if the two-phase sampling
#' indicator \code{R} is specified.
#' @param eif A \code{numeric} vector of the (centered) complete-data efficient
#' influence function, with exactly one entry per observation having
#' \code{R == 1}, in the same order as those observations appear in \code{R}.
#' @param eif_predictions A \code{numeric} vector of the predicted efficient
#' influence function, conditioning on the mediator being measured; expected
#' to have one entry per element of \code{R}.
#' @param plugin_est A \code{numeric} corresponding to the plug-in estimate for
#' the given contrast.
#'
#' @return An un-centered efficient influence function, with one entry per
#' element of \code{R}, that accounts for the two-phase sampling design.
#'
#' @keywords internal
two_phase_eif <- function(R,
                          two_phase_weights,
                          eif,
                          eif_predictions,
                          plugin_est) {
  measured <- R == 1
  # the complete-data EIF is only available where the mediators were measured;
  # a mismatch previously produced silent NA values or dropped entries
  stopifnot(length(eif) == sum(measured))

  # compute the weights for the EIF update; these are zero wherever R == 0
  ipw_two_phase <- R * two_phase_weights

  # scatter the complete-data EIF back to all observations, using 0 as a
  # placeholder where R == 0 (it is multiplied by a zero weight below)
  full_eif <- numeric(length(R))
  full_eif[measured] <- eif

  # compute updated observed-data EIF by projection of complete-data EIF
  # NOTE: D_{obs} = R/g_R * D_{full} -
  #                 (R/g_R - 1) * E[D_{full} | R = 1, W, A, Z, Y]
  centered_two_phase_eif <- ipw_two_phase * full_eif +
    (1 - ipw_two_phase) * eif_predictions

  # re-add the plug-in estimate to return the un-centered two-phase EIF
  centered_two_phase_eif + plugin_est
}
###############################################################################
#' One-step estimator for natural and interventional (in)direct effects
#'
#' @param data A \code{data.table} containing the observed data, with columns
#' in the order specified by the NPSEM (Y, M, R, Z, A, W), with column names
#' set appropriately based on the input data. Such a structure is merely a
#' convenience utility to passing data around to the various core estimation
#' routines and is automatically generated by \code{\link{medoutcon}}.
#' @param contrast A \code{numeric} vector of length two giving the values of
#' the intervention \code{A} to be compared: the first element is a' and the
#' second is a*, e.g., \code{c(0, 1)}.
#' @param g_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a model for the propensity score.
#' @param h_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a model for a parameterization of the
#' propensity score that conditions on the mediators.
#' @param b_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a model for the outcome regression.
#' @param q_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a model for a nuisance regression of
#' the intermediate confounder, conditioning on the treatment and potential
#' baseline covariates.
#' @param r_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a model for a nuisance regression of
#' the intermediate confounder, conditioning on the mediators, the treatment,
#' and potential baseline confounders.
#' @param u_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a pseudo-outcome regression required
#' for the efficient influence function.
#' @param v_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a pseudo-outcome regression required
#' for the efficient influence function.
#' @param d_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit an initial efficient influence
#' function regression when computing the efficient influence function in a
#' two-phase sampling design.
#' @param w_names A \code{character} vector of the names of the columns that
#' correspond to baseline covariates (W). The input for this argument is
#' automatically generated by \code{\link{medoutcon}}.
#' @param m_names A \code{character} vector of the names of the columns that
#' correspond to mediators (M). The input for this argument is automatically
#' generated by \code{\link{medoutcon}}.
#' @param y_bounds A \code{numeric} double indicating the minimum and maximum
#' observed values of the outcome variable Y prior to its being re-scaled to
#' the unit interval.
#' @param g_bounds A \code{numeric} vector containing two values, the first
#' being the minimum allowable estimated propensity score value and the second
#' being the maximum allowable for estimated propensity score value.
#' @param effect_type A \code{character} indicating whether components of the
#' interventional or natural (in)direct effects are to be estimated. The value
#' is matched against \code{c("interventional", "natural")} via
#' \code{\link[base]{match.arg}}, so the default is \code{"interventional"}.
#' In the case of the natural (in)direct effects, estimation of several
#' nuisance parameters is unnecessary.
#' @param svy_weights A \code{numeric} vector of observation-level weights that
#' have been computed externally. Such weights are used in the construction of
#' a re-weighted estimator.
#' @param cv_folds A \code{numeric} integer specifying the number of folds to
#' be created for cross-validation. Use of cross-validation allows for entropy
#' conditions on the one-step estimator to be relaxed. For compatibility with
#' \code{\link[origami]{make_folds}}, this value specified must be greater
#' than or equal to 2; the default is to create 10 folds.
#' @param cv_strat A \code{logical} atomic vector indicating whether V-fold
#' cross-validation should stratify the folds based on the outcome variable.
#' If \code{TRUE}, the folds are stratified by passing the outcome variable to
#' the \code{strata_ids} argument of \code{\link[origami]{make_folds}}. While
#' the default is \code{FALSE}, an override is triggered when the incidence of
#' the binary outcome variable falls below the tolerance in \code{strat_pmin}.
#' @param strat_pmin A \code{numeric} atomic vector indicating a tolerance for
#' the minimum proportion of cases (for a binary outcome variable) below which
#' stratified V-fold cross-validation is invoked if \code{cv_strat} is set to
#' \code{TRUE} (default is \code{FALSE}). The default tolerance is 0.1.
#'
#' @return A \code{list} with the one-step estimate (\code{theta}), the
#' plug-in estimate (\code{theta_plugin}), the variance estimate (\code{var}),
#' the centered efficient influence function (\code{eif}), and the estimator
#' type (\code{type = "onestep"}).
#'
#' @importFrom assertthat assert_that
#' @importFrom stats var weighted.mean
#' @importFrom origami make_folds cross_validate folds_vfold
#'
#' @keywords internal
est_onestep <- function(data,
                        contrast,
                        g_learners,
                        h_learners,
                        b_learners,
                        q_learners,
                        r_learners,
                        u_learners,
                        v_learners,
                        d_learners,
                        w_names,
                        m_names,
                        y_bounds,
                        g_bounds = c(0.005, 0.995),
                        effect_type = c("interventional", "natural"),
                        svy_weights = NULL,
                        cv_folds = 10L,
                        cv_strat = FALSE,
                        strat_pmin = 0.1) {
  # resolve the effect type to a single string before it reaches cv_eif,
  # where it is used in scalar `if ()` conditions
  effect_type <- match.arg(effect_type)

  # make sure that more than one fold is specified
  assertthat::assert_that(cv_folds > 1L)

  # create cross-validation folds
  if (cv_strat && data[, mean(Y) <= strat_pmin]) {
    # check that outcome is binary for stratified V-fold cross-validation
    assertthat::assert_that(data[, all(unique(Y) %in% c(0, 1))])
    # if outcome is binary and rare, use stratified V-fold cross-validation
    folds <- origami::make_folds(
      data,
      fold_fun = origami::folds_vfold,
      V = cv_folds,
      strata_ids = data$Y
    )
  } else {
    # just use standard V-fold cross-validation
    folds <- origami::make_folds(
      data,
      fold_fun = origami::folds_vfold,
      V = cv_folds
    )
  }

  # estimate the EIF on a per-fold basis
  cv_eif_results <- origami::cross_validate(
    cv_fun = cv_eif,
    folds = folds,
    data_in = data,
    contrast = contrast,
    g_learners = g_learners,
    h_learners = h_learners,
    b_learners = b_learners,
    q_learners = q_learners,
    r_learners = r_learners,
    u_learners = u_learners,
    v_learners = v_learners,
    d_learners = d_learners,
    effect_type = effect_type,
    w_names = w_names,
    m_names = m_names,
    g_bounds = g_bounds,
    use_future = FALSE,
    .combine = FALSE
  )

  # get v(a*, w) predictions (by name rather than position) for the plug-in
  v_star <- do.call(rbind, cv_eif_results$tmle_components)$v_star

  # per-fold EIF values are concatenated in fold order; reorder them back to
  # the row order of `data`
  obs_valid_idx <- do.call(c, lapply(folds, `[[`, "validation_set"))
  cv_eif_est <- unlist(cv_eif_results$D_star)[order(obs_valid_idx)]

  # re-scale efficient influence function
  eif_est_rescaled <- cv_eif_est %>%
    scale_from_unit(y_bounds[2], y_bounds[1])

  # compute one-step estimate and variance from efficient influence function
  if (is.null(svy_weights)) {
    os_est <- mean(eif_est_rescaled)
    eif_est_out <- eif_est_rescaled
  } else {
    # compute a re-weighted one-step, with re-weighted influence function
    os_est <- stats::weighted.mean(eif_est_rescaled, svy_weights)
    eif_est_out <- eif_est_rescaled * svy_weights
  }
  os_var <- stats::var(eif_est_out) / length(eif_est_out)

  # output
  list(
    theta = os_est,
    theta_plugin = est_plugin(v_star),
    var = os_var,
    eif = (eif_est_out - os_est),
    type = "onestep"
  )
}
###############################################################################
#' TML estimator for natural and interventional (in)direct effects
#'
#' @param data A \code{data.table} containing the observed data, with columns
#' in the order specified by the NPSEM (Y, M, R, Z, A, W), with column names
#' set appropriately based on the input data. Such a structure is merely a
#' convenience utility to passing data around to the various core estimation
#' routines and is automatically generated by \code{\link{medoutcon}}.
#' @param contrast A \code{numeric} double indicating the two values of the
#' intervention \code{A} to be compared. The default value of \code{NULL} has
#' no effect, as the value of the argument \code{effect} is instead used to
#' define the contrasts. To override \code{effect}, provide a \code{numeric}
#' double vector, giving the values of a' and a*, e.g., \code{c(0, 1)}.
#' @param g_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a model for the propensity score.
#' @param h_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a model for a parameterization of the
#' propensity score that conditions on the mediators.
#' @param b_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a model for the outcome regression.
#' @param q_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a model for a nuisance regression of
#' the intermediate confounder, conditioning on the treatment and potential
#' baseline covariates.
#' @param r_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a model for a nuisance regression of
#' the intermediate confounder, conditioning on the mediators, the treatment,
#' and potential baseline confounders.
#' @param u_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a pseudo-outcome regression required
#' for the efficient influence function.
#' @param v_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit a pseudo-outcome regression required
#' for the efficient influence function.
#' @param d_learners A \code{\link[sl3]{Stack}} object, or other learner class
#' (inheriting from \code{\link[sl3]{Lrnr_base}}), containing instantiated
#' learners from \pkg{sl3}; used to fit an initial efficient influence
#' function regression when computing the efficient influence function in a
#' two-phase sampling design.
#' @param w_names A \code{character} vector of the names of the columns that
#' correspond to baseline covariates (W). The input for this argument is
#' automatically generated by \code{\link{medoutcon}}.
#' @param m_names A \code{character} vector of the names of the columns that
#' correspond to mediators (M). The input for this argument is automatically
#' generated by \code{\link{medoutcon}}.
#' @param y_bounds A \code{numeric} double indicating the minimum and maximum
#' observed values of the outcome variable Y prior to its being re-scaled to
#' the unit interval.
#' @param g_bounds A \code{numeric} vector containing two values, the first
#' being the minimum allowable estimated propensity score value and the second
#' being the maximum allowable for estimated propensity score value.
#' @param effect_type A \code{character} indicating whether components of the
#' interventional or natural (in)direct effects are to be estimated. In the
#' case of the natural (in)direct effects, estimation of several nuisance
#' parameters is unnecessary.
#' @param svy_weights A \code{numeric} vector of observation-level weights that
#' have been computed externally. Such weights are used in the construction of
#' a re-weighted estimator.
#' @param cv_folds A \code{numeric} value specifying the number of folds to be
#' created for cross-validation. Use of cross-validation allows for entropy
#' conditions on the TML estimator to be relaxed. Note: for compatibility with
#' \code{\link[origami]{make_folds}}, this value must be greater than or
#' equal to 2; the default is to create 10 folds.
#' @param cv_strat A \code{logical} atomic vector indicating whether V-fold
#' cross-validation should stratify the folds based on the outcome variable.
#' If \code{TRUE}, the folds are stratified by passing the outcome variable to
#' the \code{strata_ids} argument of \code{\link[origami]{make_folds}}. While
#' the default is \code{FALSE}, an override is triggered when the incidence of
#' the binary outcome variable falls below the tolerance in \code{strat_pmin}.
#' @param strat_pmin A \code{numeric} atomic vector indicating a tolerance for
#' the minimum proportion of cases (for a binary outcome variable) below which
#' stratified V-fold cross-validation is invoked if \code{cv_strat} is set to
#' \code{TRUE} (default is \code{FALSE}). The default tolerance is 0.1.
#' @param max_iter A \code{numeric} integer giving the maximum number of steps
#' to be taken for the iterative procedure to construct a TML estimator.
#' @param tiltmod_tol A \code{numeric} indicating the maximum step size to be
#' taken when performing TMLE updates based on logistic tilting models. When
#' the step size of a given update exceeds this value, the update is avoided.
#'
#' @importFrom dplyr "%>%"
#' @importFrom assertthat assert_that
#' @importFrom origami make_folds cross_validate folds_vfold
#' @importFrom stats var as.formula plogis qlogis coef predict weighted.mean
#' binomial
#' @importFrom glm2 glm2
#'
#' @keywords internal
est_tml <- function(data,
contrast,
g_learners,
h_learners,
b_learners,
q_learners,
r_learners,
u_learners,
v_learners,
d_learners,
w_names,
m_names,
y_bounds,
g_bounds = c(0.005, 0.95),
effect_type = c("interventional", "natural"),
svy_weights = NULL,
cv_folds = 10L,
cv_strat = FALSE,
strat_pmin = 0.1,
max_iter = 10L,
tiltmod_tol = 5) {
# make sure that more than one fold is specified
assertthat::assert_that(cv_folds > 1L)
# create cross-validation folds
if (cv_strat && data[, mean(Y) <= strat_pmin]) {
# check that outcome is binary for stratified V-fold cross-validation
assertthat::assert_that(data[, all(unique(Y) %in% c(0, 1))])
# if outcome is binary and rare, use stratified V-fold cross-validation
folds <- origami::make_folds(
data,
fold_fun = origami::folds_vfold,
V = cv_folds,
strata_ids = data$Y
)
} else {
# just use standard V-fold cross-validation
folds <- origami::make_folds(
data,
fold_fun = origami::folds_vfold,
V = cv_folds
)
}
# perform the cv_eif procedure on a per-fold basis
cv_eif_results <- origami::cross_validate(
cv_fun = cv_eif,
folds = folds,
data_in = data,
contrast = contrast,
g_learners = g_learners,
h_learners = h_learners,
b_learners = b_learners,
q_learners = q_learners,
r_learners = r_learners,
u_learners = u_learners,
v_learners = v_learners,
d_learners = d_learners,
effect_type = effect_type,
w_names = w_names,
m_names = m_names,
g_bounds = g_bounds,
use_future = FALSE,
.combine = FALSE
)
# concatenate nuisance function function estimates
# make sure that data is in the same order as the concatenated validations
# sets
cv_eif_est <- do.call(rbind, cv_eif_results[[1]])
obs_valid_idx <- do.call(c, lapply(folds, `[[`, "validation_set"))
data <- data[obs_valid_idx]
# extract nuisance function estimates and auxiliary quantities
g_prime <- cv_eif_est$g_prime
h_prime <- cv_eif_est$h_prime
g_star <- cv_eif_est$g_star
h_star <- cv_eif_est$h_star
q_prime_Z_one <- cv_eif_est$q_prime_Z_one
q_prime_Z_natural <- cv_eif_est$q_prime_Z_natural
r_prime_Z_one <- cv_eif_est$r_prime_Z_one
r_prime_Z_natural <- cv_eif_est$r_prime_Z_natural
b_prime_Z_one <- cv_eif_est$b_prime_Z_one
b_prime_Z_zero <- cv_eif_est$b_prime_Z_zero
b_prime_Z_natural <- cv_eif_est$b_prime
# generate inverse weights and multiplier for auxiliary covariates
ipw_prime <- as.numeric(data[R == 1, A] == contrast[1]) / g_prime
ipw_star <- as.numeric(data[R == 1, A] == contrast[2]) / g_star
c_star_mult <- (g_prime / g_star) * (h_star / h_prime)
# prepare for iterative targeting
eif_stop_crit <- FALSE
n_iter <- 0
n_obs <- nrow(data)
se_eif <- sqrt(var(cv_eif_est$D_star) / n_obs)
tilt_stop_crit <- se_eif / log(n_obs)
b_score <- q_score <- Inf
tilt_two_phase_weights <- sum(data$R) != nrow(data)
d_pred <- unlist(cv_eif_results$D_pred)[order(obs_valid_idx)]
# perform iterative targeting
while (!eif_stop_crit && n_iter <= max_iter) {
# NOTE: check convergence condition for outcome regression
if (mean(b_score) > tilt_stop_crit) {
# compute auxiliary covariates from updated estimates
c_star_Z_natural <- (q_prime_Z_natural / r_prime_Z_natural) * c_star_mult
c_star_Z_one <- (q_prime_Z_one / r_prime_Z_one) * c_star_mult
if (effect_type == "natural") {
# NOTE: this exception handles 0/0 division, since q(1|a',...) = 1
# and r(1|a',...) = 1, improperly yielding 0/0 => NaN
c_star_Z_zero <- (q_prime_Z_one / r_prime_Z_one) * c_star_mult
} else if (effect_type == "interventional") {
c_star_Z_zero <- ((1 - q_prime_Z_one) / (1 - r_prime_Z_one)) *
c_star_mult
}
# bound and transform nuisance estimates for tilting regressions
b_prime_Z_natural_logit <- b_prime_Z_natural %>%
bound_precision() %>%
stats::qlogis()
b_prime_Z_one_logit <- b_prime_Z_one %>%
bound_precision() %>%
stats::qlogis()
b_prime_Z_zero_logit <- b_prime_Z_zero %>%
bound_precision() %>%
stats::qlogis()
# fit tilting model for the outcome mechanism
c_star_b_tilt <- c_star_Z_natural
if (tilt_two_phase_weights) {
weights_b_tilt <- as.numeric(data[R == 1, A] == contrast[1]) /
g_prime * as.numeric(data[R == 1, two_phase_weights])
} else {
weights_b_tilt <- data$obs_weights * (data$A == contrast[1]) / g_prime
}
suppressWarnings(
b_tilt_fit <- glm2::glm2(
stats::as.formula("y_scaled ~ -1 + offset(b_prime_logit) + c_star"),
data = data.table::as.data.table(list(
y_scaled = data[R == 1, Y],
b_prime_logit = b_prime_Z_natural_logit,
c_star = c_star_b_tilt
)),
weights = weights_b_tilt,
family = stats::binomial(),
start = 0
)
)
if (is.na(stats::coef(b_tilt_fit))) {
b_tilt_fit$coefficients <- 0
} else if (!b_tilt_fit$converged || abs(max(stats::coef(b_tilt_fit))) >
tiltmod_tol) {
b_tilt_fit$coefficients <- 0
}
b_tilt_coef <- unname(stats::coef(b_tilt_fit))
# update nuisance estimates via tilting regressions for outcome
b_prime_Z_natural <- stats::plogis(b_prime_Z_natural_logit +
b_tilt_coef * c_star_Z_natural)
b_prime_Z_one <- stats::plogis(b_prime_Z_one_logit +
b_tilt_coef * c_star_Z_one)
b_prime_Z_zero <- stats::plogis(b_prime_Z_zero_logit +
b_tilt_coef * c_star_Z_zero)
# compute efficient score for outcome regression component
b_score <- data[R == 1, two_phase_weights] *
ipw_prime * c_star_Z_natural * (data[R == 1, Y] - b_prime_Z_natural)
} else {
b_score <- 0
}
# NOTE: check convergence condition for intermediate confounding
if (mean(q_score) > tilt_stop_crit) {
# perform iterative targeting for intermediate confounding
q_prime_Z_one_logit <- q_prime_Z_one %>%
bound_precision() %>%
stats::qlogis()
# fit tilting regressions for intermediate confounding
u_prime_diff_q_tilt <- cv_eif_est$u_int_diff
if (tilt_two_phase_weights) {
weights_q_tilt <- as.numeric(data[R == 1, A] == contrast[1]) /
g_prime * as.numeric(data[R == 1, two_phase_weights])
} else {
weights_q_tilt <- data$obs_weights * (data$A == contrast[1]) / g_prime
}
suppressWarnings(
q_tilt_fit <- glm2::glm2(
stats::as.formula("Z ~ -1 + offset(q_prime_logit) + u_prime_diff"),
data = data.table::as.data.table(list(
Z = data[R == 1, Z],
q_prime_logit = q_prime_Z_one_logit,
u_prime_diff = u_prime_diff_q_tilt
)),
weights = weights_q_tilt,
family = stats::binomial(),
start = 0
)
)
# NOTE: for the natural (in)direct effects, the regressor on the RHS
# is uniquely ZERO so estimated parameter should always be NaN
if (effect_type == "natural") {
q_tilt_fit$coefficients <- NA
}
if (is.na(stats::coef(q_tilt_fit))) {
q_tilt_fit$coefficients <- 0
} else if (!q_tilt_fit$converged || abs(max(stats::coef(q_tilt_fit))) >
tiltmod_tol) {
q_tilt_fit$coefficients <- 0
}
q_tilt_coef <- unname(stats::coef(q_tilt_fit))
# update nuisance estimates via tilting of intermediate confounder
if (effect_type == "natural") {
# for the natural (in)direct effects, no updates necessary
q_prime_Z_one <- data[R == 1, Z]
q_prime_Z_natural <- data[R == 1, Z]
} else {
q_prime_Z_one <- stats::plogis(q_prime_Z_one_logit + q_tilt_coef *
cv_eif_est$u_int_diff)
q_prime_Z_natural <- (data[R == 1, Z] * q_prime_Z_one) +
((1 - data[R == 1, Z]) * (1 - q_prime_Z_one))
}
# compute efficient score for intermediate confounding component
q_score <- ipw_prime * cv_eif_est$u_int_diff *
(data[R == 1, Z] - q_prime_Z_one) *
(data[R == 1, two_phase_weights])
} else {
q_score <- 0
}
# check convergence and iterate the counter
eif_stop_crit <- all(
abs(c(mean(b_score), mean(q_score))) < tilt_stop_crit
)
n_iter <- n_iter + 1
}
# update auxiliary covariates after completion of iterative targeting
c_star_Z_natural <- (q_prime_Z_natural / r_prime_Z_natural) * c_star_mult
c_star_Z_one <- (q_prime_Z_one / r_prime_Z_one) * c_star_mult
if (effect_type == "natural") {
# NOTE: this exception handles 0/0 division, since q(1|a',...) = 1 and
# r(1|a',...) = 1, improperly yielding 0/0 => NaN
c_star_Z_zero <- (q_prime_Z_one / r_prime_Z_one) * c_star_mult
} else if (effect_type == "interventional") {
c_star_Z_zero <- ((1 - q_prime_Z_one) / (1 - r_prime_Z_one)) * c_star_mult
}
# compute updated substitution estimator and prepare for tilting regression
v_pseudo <- ((b_prime_Z_one * q_prime_Z_one) +
(b_prime_Z_zero * (1 - q_prime_Z_one))) %>%
bound_precision()
v_star_logit <- cv_eif_est$v_star %>%
bound_precision() %>%
stats::qlogis()
# fit tilting regression for substitution estimator
if (tilt_two_phase_weights) {
weights_v_tilt <- (as.numeric(data[R == 1, A]) == contrast[2]) / g_star *
(as.numeric(data[R == 1, two_phase_weights]))
} else {
weights_v_tilt <- data$obs_weights * (data$A == contrast[2]) / g_star
}
suppressWarnings(
v_tilt_fit <- glm2::glm2(
stats::as.formula("v_pseudo ~ offset(v_star_logit)"),
data = data.table::as.data.table(list(
v_pseudo = v_pseudo,
v_star_logit = v_star_logit
)),
weights = weights_v_tilt,
family = stats::binomial(),
start = 0
)
)
v_star_tmle <- unname(stats::predict(v_tilt_fit, type = "response"))
# compute influence function with centering at the TML estimate
# make sure that it's in the same order as the original data
eif_est <- unlist(cv_eif_results$D_star)[order(obs_valid_idx)]
# re-scale efficient influence function
v_star_tmle_rescaled <- v_star_tmle %>%
scale_from_unit(y_bounds[2], y_bounds[1])
eif_est_rescaled <- eif_est %>%
scale_from_unit(y_bounds[2], y_bounds[1])
# compute TML estimator and variance from efficient influence function
if (is.null(svy_weights)) {
tml_est <- mean(v_star_tmle_rescaled)
eif_est_out <- eif_est_rescaled
} else {
# compute a re-weighted TMLE, with re-weighted influence function
# NOTE: make sure that survey weights are ordered like the concatenated
# validation sets
svy_weights <- svy_weights[obs_valid_idx]
tml_est <- stats::weighted.mean(v_star_tmle_rescaled, svy_weights)
eif_est_out <- eif_est_rescaled * svy_weights
}
tmle_var <- stats::var(eif_est_out) / length(eif_est_out)
# output
tmle_out <- list(
theta = tml_est,
theta_plugin = est_plugin(cv_eif_est$v_star),
var = tmle_var,
eif = (eif_est_out - tml_est),
n_iter = n_iter,
type = "tmle"
)
return(tmle_out)
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.