R/tmle3_Spec_mediation_v2.R

Defines functions tmle_mediation

#' Defines a TML Estimator (except for the data)
#'
#' Longitudinal Mediation Targets; v2 aims for all common targets.
#'
#' An R6 specification inheriting from \code{tmle3_Spec} that builds the
#' mediation task, its initial likelihood and the target parameters. The
#' target of each parameter is chosen via the \code{options} argument of
#' \code{make_params()} / \code{make_params_survival()}:
#' \describe{
#'   \item{"tc"}{treatment level for the treatment intervention and control
#'     level for the control intervention (the default)}
#'   \item{"tt"}{treatment level for both interventions}
#'   \item{"cc"}{control level for both interventions}
#' }
#'
#' @importFrom R6 R6Class
#'
#' @export
tmle3_Spec_mediation <- R6Class(
  classname = "tmle3_Spec_mediation",
  portable = TRUE,
  class = TRUE,
  inherit = tmle3_Spec,
  public = list(
    # Forward the treatment/control levels (and any extra options in `...`)
    # to the tmle3_Spec initializer.
    initialize = function(treatment_level, control_level, ...) {
      super$initialize(
        treatment_level = treatment_level,
        control_level = control_level, ...
      )
    },

    # Build the mediation task from `data` and `node_list`, using the
    # `variable_types` option. `if_drop_censored = TRUE` selects
    # mediation_task_drop_censored(); NULL or FALSE select mediation_task().
    # Any other value (e.g. NA) is rejected. `...` is accepted for interface
    # compatibility and is not used.
    make_tmle_task = function(data, node_list, if_drop_censored = NULL, ...) {
      variable_types <- self$options$variable_types
      if (private$med_flag(if_drop_censored, "if_drop_censored")) {
        mediation_task_drop_censored(data, node_list, variable_types)
      } else {
        mediation_task(data, node_list, variable_types)
      }
    },

    # Return a trained likelihood: the `likelihood_override` option trained
    # on `tmle_task` when supplied, otherwise mediation_likelihood()
    # (see mediation_helper) fitted with `learner_list`.
    make_initial_likelihood = function(tmle_task, learner_list = NULL) {
      override <- self$options$likelihood_override
      if (!is.null(override)) {
        override$train(tmle_task)
      } else {
        mediation_likelihood(tmle_task, learner_list)
      }
    },

    # Build one mediation parameter per entry of `options` ("tc", "tt" or
    # "cc"; defaults to "tc"). Treatment nodes are the npsem nodes whose name
    # contains "A". `if_projection = TRUE` builds
    # Param_mediation_projection (using `static_likelihood`, which defaults to
    # `likelihood`, and `n_resampling`); NULL or FALSE build Param_mediation.
    # Returns an unnamed list of parameter objects.
    make_params = function(tmle_task, likelihood, options = NULL,
                           if_projection = NULL, static_likelihood = NULL,
                           n_resampling = NULL) {
      if (is.null(static_likelihood)) static_likelihood <- likelihood
      if (is.null(options)) options <- list("tc")

      node_names <- names(tmle_task$npsem)
      A_names <- node_names[grepl("A", node_names)]

      build_param <- if (private$med_flag(if_projection, "if_projection")) {
        function(treatment, control) {
          Param_mediation_projection$new(
            likelihood, treatment, control, static_likelihood, n_resampling
          )
        }
      } else {
        function(treatment, control) {
          Param_mediation$new(likelihood, treatment, control)
        }
      }
      private$med_build_params(tmle_task, A_names, options, build_param)
    },

    # Survival analogue of make_params(). Nodes whose name contains "delta_"
    # are ignored (wide format fitting). The treatment nodes are those whose
    # name contains "A_E". The parameter classes are Param_mediation_survival
    # and Param_mediation_projection_survival.
    make_params_survival = function(tmle_task, likelihood, options = NULL,
                                    if_projection = NULL,
                                    static_likelihood = NULL,
                                    n_resampling = NULL) {
      if (is.null(static_likelihood)) static_likelihood <- likelihood
      if (is.null(options)) options <- list("tc")

      node_names <- names(tmle_task$npsem)
      # remove delta nodes for wide format fitting
      node_names <- node_names[!grepl("delta_", node_names)]
      A_names <- node_names[grepl("A_E", node_names)]

      build_param <- if (private$med_flag(if_projection, "if_projection")) {
        function(treatment, control) {
          Param_mediation_projection_survival$new(
            likelihood, treatment, control, static_likelihood, n_resampling
          )
        }
      } else {
        function(treatment, control) {
          Param_mediation_survival$new(likelihood, treatment, control)
        }
      }
      private$med_build_params(tmle_task, A_names, options, build_param)
    }
  ),
  active = list(),
  private = list(
    # Interpret an optional on/off argument. NULL means FALSE. Anything else
    # must coerce to a single non-NA logical, otherwise a clear error is
    # raised (previously NA hit R's generic "missing value" error).
    med_flag = function(value, arg) {
      if (is.null(value)) {
        return(FALSE)
      }
      flag <- as.logical(value)
      if (length(flag) != 1L || is.na(flag)) {
        stop(sprintf("`%s` needs to be either TRUE, FALSE or NULL.", arg),
             call. = FALSE)
      }
      flag
    },

    # Map one target option to the arm levels used for the treatment and
    # control interventions. Unknown options are an error instead of leaving
    # the values unbound.
    med_arm_values = function(option) {
      option <- as.character(option)
      if (length(option) != 1L || is.na(option)) {
        stop("Each target option must be a single string: ",
             "\"tc\", \"tt\" or \"cc\".", call. = FALSE)
      }
      trt <- self$options$treatment_level
      ctl <- self$options$control_level
      switch(option,
        tc = list(treatment = trt, control = ctl),
        tt = list(treatment = trt, control = trt),
        cc = list(treatment = ctl, control = ctl),
        stop(sprintf(
          "Unknown target option \"%s\"; expected \"tc\", \"tt\" or \"cc\".",
          option
        ), call. = FALSE)
      )
    },

    # Build list(treatment = , control = ), each a list of LF_static
    # definitions named by `A_names`, for the target `option`.
    med_interventions = function(A_names, A_levels, option) {
      values <- private$med_arm_values(option)
      if (!is.null(A_levels)) {
        # possible arm levels are stored as factor levels on the A node
        values <- lapply(values, factor, levels = A_levels)
        if (any(vapply(values, anyNA, logical(1)))) {
          stop(sprintf(
            "Intervention level not among the treatment levels (%s).",
            paste(A_levels, collapse = ", ")
          ), call. = FALSE)
        }
      }
      as_static_lfs <- function(value) {
        lfs <- lapply(A_names, function(node) {
          define_lf(LF_static, node, value = value)
        })
        names(lfs) <- A_names
        lfs
      }
      lapply(values, as_static_lfs)
    },

    # Shared driver for make_params() / make_params_survival(). It returns one
    # parameter per entry of `options`, each built by
    # build_param(treatment, control).
    med_build_params = function(tmle_task, A_names, options, build_param) {
      if (length(A_names) == 0L) {
        stop("No treatment nodes found in tmle_task$npsem.", call. = FALSE)
      }
      # ZW todo: in future can be dynamic (levels come from the first A node)
      A_levels <- tmle_task$npsem[[A_names[[1]]]]$variable_type$levels
      lapply(options, function(option) {
        arms <- private$med_interventions(A_names, A_levels, option)
        build_param(arms$treatment, arms$control)
      })
    }
  )
)

#' Longitudinal Mediation Targets
#'
#' Convenience constructor for a
#' \code{\link{tmle3_Spec_mediation}} specification.
#'
#' Observed data: O = (L0, A1, R1, Z1, L1, (Y1), ..., An, Rn, Zn, Ln, Yn), where
#' \itemize{
#'   \item L0: baseline covariates
#'   \item A: treatment (binary or categorical)
#'   \item Z: mediators
#'   \item R: (time-varying) covariates before the mediator
#'   \item L: (time-varying) covariates after the mediator
#'   \item Y: outcome (binary or bounded continuous)
#' }
#' @importFrom sl3 make_learner Lrnr_mean
#' @param treatment_level the level of A that corresponds to treatment
#' @param control_level the level of A that corresponds to a control or
#'   reference level
#' @param ... additional options passed to \code{tmle3_Spec_mediation$new()}
#'   and on to the \code{tmle3_Spec} initializer, e.g. \code{variable_types}
#'   or \code{likelihood_override}, which the spec reads from its options.
#' @return A \code{tmle3_Spec_mediation} object.
#' @export
tmle_mediation <- function(treatment_level, control_level, ...) {
  # TODO: unclear why this has to be in a factory function
  # NOTE(review): forwarding `...` assumes tmle3_Spec stores extra
  # initializer arguments in `$options` -- confirm against tmle3.
  tmle3_Spec_mediation$new(treatment_level, control_level, ...)
}
zy-wang1/calm documentation built on July 30, 2024, 10:51 a.m.