#' onestep TMLE of average treatment effect on survival probabilities
#'
#' updating the hazard using constrained step size update
#'
#' @docType class
#' @importFrom R6 R6Class
#' @export
#' @keywords data
#' @return Object of \code{\link{R6Class}} with methods
#' @format \code{\link{R6Class}} object.
#' @examples
#' \donttest{
#' MOSS_hazard_ate$new(
#' A,
#' T_tilde,
#' Delta,
#' density_failure,
#' density_censor,
#' density_failure_0 = density_failure_0,
#' density_censor_0 = density_censor_0,
#' g1W,
#' A_intervene = 1,
#' k_grid = 1:max(T_tilde)
#' )
#' }
#' @field A vector of treatment
#' @field T_tilde vector of last follow up time
#' @field Delta vector of censoring indicator
#' @field density_failure survival_curve object of the predicted counterfactual
#' failure event survival curve under A = 1
#' @field density_censor survival_curve object of the predicted counterfactual
#' censoring event survival curve under A = 1
#' @field density_failure_0 survival_curve object of the predicted counterfactual
#' failure event survival curve under A = 0
#' @field density_censor_0 survival_curve object of the predicted counterfactual
#' censoring event survival curve under A = 0
#' @field g1W propensity score
#' @field k_grid vector of interested time points
#' @section Methods:
#' iterate_onestep update the initial estimator
#' @export
MOSS_hazard_ate <- R6Class("MOSS_hazard_ate",
  inherit = MOSS_hazard,
  public = list(
    # Failure-event and censoring curves that are paired with A_intervene = 0.
    # Their A_intervene = 1 counterparts are the density_failure and
    # density_censor fields set by the parent class initializer.
    density_failure_0 = NULL,
    density_censor_0 = NULL,
    # Best A_intervene = 0 failure curve found so far by iterate_onestep().
    q_best_0 = NULL,

    # Store the two A = 0 curves; every other argument is forwarded untouched
    # to MOSS_hazard's initializer through `...`. Because the two A = 0 curves
    # are the first formals, callers should pass them by name.
    initialize = function(
      density_failure_0,
      density_censor_0,
      ...
    ) {
      super$initialize(...)
      self$density_failure_0 <- density_failure_0
      self$density_censor_0 <- density_censor_0
      return(self)
    },

    # Fit the fluctuation parameter epsilon of the logistic hazard submodel
    # and return the updated hazards of both arms.
    #
    # method:   "glm" (unpenalised fit, rescaled so its L2 norm does not exceed
    #           `clipping`), "l2" (ridge) or "l1" (lasso); the penalised fits
    #           pick the least penalised solution on the glmnet path whose norm
    #           is still <= `clipping`.
    # clipping: upper bound on the norm of epsilon (the step size).
    #
    # Returns list(hazard_new_1, hazard_new_0): the results of
    # survival_curve$hazard_to_survival() for the A = 1 and A = 0 arms.
    fit_epsilon = function(method = "l2", clipping = Inf) {
      if (!(length(method) == 1 && method %in% c("glm", "l2", "l1"))) {
        # previously this fell through to "object 'epsilon_n' not found"
        stop(
          "`method` must be one of \"glm\", \"l2\" or \"l1\".",
          call. = FALSE
        )
      }
      n_subject <- length(self$A)
      n_time <- max(self$T_tilde)

      dNt <- self$create_dNt()
      h_matrix_1 <- self$construct_long_data(
        A_intervene = 1,
        density_failure = self$density_failure,
        density_censor = self$density_censor
      )
      h_matrix_0 <- self$construct_long_data(
        A_intervene = 0,
        density_failure = self$density_failure_0,
        density_censor = self$density_censor_0
      )
      h_matrix <- h_matrix_1 - h_matrix_0

      # as.vector(t(.)) lays the hazards out subject by subject (all time
      # points of subject 1, then subject 2, ...), matching the byrow = TRUE
      # reshape used below.
      offset_submodel_1 <- logit(as.vector(t(self$density_failure$hazard)))
      offset_submodel_0 <- logit(as.vector(t(self$density_failure_0$hazard)))

      # Offset under the observed arm: rows of subjects with A == 0 take the
      # A = 0 hazard. The per-subject indicator has to be repeated once per
      # long-format row of that subject; indexing with the bare length-n
      # vector would be recycled and select the wrong rows.
      rows_per_subject <- length(offset_submodel_1) / n_subject
      is_control_row <- rep(self$A == 0, each = rows_per_subject)
      offset_submodel <- offset_submodel_1
      offset_submodel[is_control_row] <- offset_submodel_0[is_control_row]

      if (method == "glm") {
        submodel_fit <- glm.fit(
          x = h_matrix,
          y = dNt,
          family = binomial(),
          offset = offset_submodel,
          intercept = FALSE
        )
        epsilon_n <- submodel_fit$coefficients
        l2_norm <- sqrt(sum(epsilon_n ^ 2))
        if (l2_norm >= clipping) {
          # clipping the step size
          epsilon_n <- epsilon_n / l2_norm * clipping
        }
      } else {
        if (method == "l2") {
          alpha <- 0
          norm_func <- norm_l2
          lambda_min_ratio <- 1e-2
        } else {
          alpha <- 1
          norm_func <- norm_l1
          lambda_min_ratio <- 9e-1
        }
        ind <- 1
        while (ind == 1) {
          enet_fit <- glmnet::glmnet(
            x = h_matrix,
            y = dNt,
            # same observed-arm offset as the "glm" branch
            offset = offset_submodel,
            family = "binomial",
            alpha = alpha,
            standardize = FALSE,
            intercept = FALSE,
            lambda.min.ratio = lambda_min_ratio,
            nlambda = 2e2
          )
          norms <- apply(enet_fit$beta, 2, norm_func)
          # last solution on the path that still respects the norm bound
          ind <- max(which(norms <= clipping))
          if (ind > 1) break
          # only the first solution qualifies: refit on a narrower lambda
          # range so the path moves in smaller steps
          lambda_min_ratio <- (lambda_min_ratio + 1) / 2
        }
        # lambda_best <- enet_fit$lambda[ind]
        epsilon_n <- enet_fit$beta[, ind]
        # epsilon_n <- fit_enet_constrained(
        #   X = h_matrix,
        #   Y = dNt,
        #   beta_init = rep(0, ncol(h_matrix)),
        #   norm_max = clipping,
        #   offset = offset_submodel,
        #   type = method
        # )
      }

      # Shift one arm's logit-hazard by its own clever covariate times epsilon
      # and reshape back to a subject x time matrix.
      updated_curve <- function(offset, h_arm) {
        hazard_new <- matrix(
          expit(offset + as.vector(h_arm %*% epsilon_n)),
          nrow = n_subject,
          ncol = n_time,
          byrow = TRUE
        )
        survival_curve$new(
          t = seq_len(n_time), hazard = hazard_new
        )$hazard_to_survival()
      }

      # the new hazard for failure
      return(
        list(
          hazard_new_1 = updated_curve(offset_submodel_1, h_matrix_1),
          hazard_new_0 = updated_curve(offset_submodel_0, h_matrix_0)
        )
      )
    },

    # Column means of the difference between the eic objects' all_t() output
    # for A_intervene = 1 and A_intervene = 0, evaluated on `k_grid` with
    # psi = psi_n.
    compute_mean_eic = function(psi_n, k_grid) {
      eic_fit_1 <- eic$new(
        A = self$A,
        T_tilde = self$T_tilde,
        Delta = self$Delta,
        density_failure = self$density_failure,
        density_censor = self$density_censor,
        g1W = self$g1W,
        psi = psi_n,
        A_intervene = 1
      )$all_t(k_grid = k_grid)
      eic_fit_0 <- eic$new(
        A = self$A,
        T_tilde = self$T_tilde,
        Delta = self$Delta,
        density_failure = self$density_failure_0,
        density_censor = self$density_censor_0,
        g1W = self$g1W,
        psi = psi_n,
        A_intervene = 0
      )$all_t(k_grid = k_grid)
      eic_fit <- eic_fit_1 - eic_fit_0
      mean_eic <- colMeans(eic_fit)
      return(mean_eic)
    },

    # Repeatedly update both hazards with fit_epsilon() until the L2 norm of
    # the mean EIC drops below tmle_tolerance * sqrt(max(k_grid)), the norm
    # becomes NA/infinite, or max_num_interation updates have been made.
    #
    # method:             passed to fit_epsilon().
    # epsilon:            step-size bound, passed to fit_epsilon() as
    #                     `clipping`.
    # max_num_interation: maximum number of updates.
    # tmle_tolerance:     stopping tolerance; NULL means
    #                     1 / self$density_failure$n().
    # verbose:            print progress for each iteration and a summary.
    #
    # Returns the column means of (survival under A = 1) - (survival under
    # A = 0), computed from the candidate with the smallest mean-EIC norm;
    # density_failure / density_failure_0 are left at that candidate.
    iterate_onestep = function(
      method = "l2",
      epsilon = 1e-2,
      max_num_interation = 1e2,
      tmle_tolerance = NULL,
      verbose = FALSE
    ) {
      self$epsilon <- epsilon
      self$max_num_interation <- max_num_interation
      if (is.null(tmle_tolerance)) {
        self$tmle_tolerance <- 1 / self$density_failure$n()
      } else {
        self$tmle_tolerance <- tmle_tolerance
      }
      k_grid <- seq_len(max(self$T_tilde))

      psi_n <- colMeans(
        self$density_failure$survival - self$density_failure_0$survival
      )
      mean_eic <- self$compute_mean_eic(psi_n = psi_n, k_grid = k_grid)

      num_iteration <- 0
      mean_eic_inner_prod_current <- sqrt(sum(mean_eic ^ 2))
      mean_eic_inner_prod_best <- mean_eic_inner_prod_current
      self$q_best <- self$density_failure$clone(deep = TRUE)
      self$q_best_0 <- self$density_failure_0$clone(deep = TRUE)

      # do not start iterating from a criterion that is already NA/infinite
      to_iterate <- !(
        is.infinite(mean_eic_inner_prod_current) ||
          is.na(mean_eic_inner_prod_current)
      )
      while (
        to_iterate &&
          mean_eic_inner_prod_current >=
            self$tmle_tolerance * sqrt(max(k_grid))
      ) {
        if (verbose) {
          df_debug <- data.frame(
            num_iteration, mean_eic_inner_prod_current, mean(psi_n)
          )
          colnames(df_debug) <- NULL
          print(df_debug)
        }
        # update
        new_hazard <- self$fit_epsilon(method = method, clipping = self$epsilon)
        self$density_failure <- new_hazard$hazard_new_1
        self$density_failure_0 <- new_hazard$hazard_new_0
        psi_n <- colMeans(
          self$density_failure$survival - self$density_failure_0$survival
        )
        mean_eic <- self$compute_mean_eic(psi_n = psi_n, k_grid = k_grid)

        # new stopping
        mean_eic_inner_prod_current <- sqrt(sum(mean_eic ^ 2))
        num_iteration <- num_iteration + 1
        if (
          is.infinite(mean_eic_inner_prod_current) ||
            is.na(mean_eic_inner_prod_current)
        ) {
          warning("stopping criteria diverged. Reporting best result so far.")
          break
        }
        if (mean_eic_inner_prod_current < mean_eic_inner_prod_best) {
          # the update caused PnEIC to beat the current best
          # update our best candidate
          self$q_best <- self$density_failure$clone(deep = TRUE)
          self$q_best_0 <- self$density_failure_0$clone(deep = TRUE)
          mean_eic_inner_prod_best <- mean_eic_inner_prod_current
        }
        if (num_iteration >= self$max_num_interation) {
          # warn before leaving the loop; placed after `break` the warning
          # could never be raised
          warning("Max number of iteration reached, stop TMLE")
          break
        }
      }

      # always output the best candidate for final result
      self$density_failure <- self$q_best
      self$density_failure_0 <- self$q_best_0
      psi_n <- colMeans(
        self$density_failure$survival - self$density_failure_0$survival
      )
      if (verbose) {
        message(paste(
          "Pn(EIC)=",
          formatC(mean_eic_inner_prod_best, format = "e", digits = 2),
          "Psi=",
          formatC(mean(psi_n), format = "e", digits = 2)
        ))
      }
      return(psi_n)
    }
  )
)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.