R/tmle.R

if(getRversion() >= "2.15.1") {
  utils::globalVariables(c("Qk_hat", "cum.inc", "mean_Q", "EIC_i", "TMLE_Var", "IC.St"))
}

## ------------------------------------------------------------------------------------------
## TO DO:
## ------------------------------------------------------------------------------------------
## *) The definition of Qperiods below needs to be based on actual periods observed in the data.
## *) Stochastic interventions ***
##    intervened_TRT/intervened_MONITOR are vectors of counterfactual probs -> need to allow each to be multivariate (>1 cols)
## *) Accessing correct ModelQlearn ***
##     Need to come up with a good way of accessing correct terminal ModelQlearn to get final Q-predictions in private$probAeqa.
##     These predictions are for the final n observations that were used for evaluating E[Y^a].
##     However, if we do stratify=TRUE and use a stack version of g-comp with different strata all in one stacked database it will be difficult
##     to pull the right ModelQlearn.
##     Alternative is to ignore that completely and go directly for the observed data (by looking up the right column/row combos).
## *) Pooling across regimens ***
##     Since new regimen results in new Qkplus1 and hence new outcome -> requires a separate regression for each regimen:
##     => Can either pool all Q.regimens at once (fit one model for repeated observations, one for each rule, smoothing over rules).
##     => Can use the same stacked dataset of regime-followers, but with a separate stratification for each regimen.
## ------------------------------------------------------------------------------------------

## ------------------------------------------------------------------------------------------
## Sequential (Recursive) G-COMP Algorithm:
## ------------------------------------------------------------------------------------------
## 1. At each t iteration:
##     Fitting:
##        Outcome is always Qkplus1 (either by regimen or not). Fitting subset excludes all obs censored at t.
##     Prediction:
##        Swap the entire column(s) A[t] with A^*[t] by renaming them in OData$dat.sVar (might do this for more than one regimen in the future, if pooling Q's)
##        Subset all observation by t (including all censored and non-followers).
##        PredictP1() for everybody in the subset, then save the prediction in rows Qkplus1[t] and rows Qkplus1[t-1] for all obs that were used in prediction.
##        Results in correct algorithm even when stratifyQ_by_rule=TRUE: next iteration (t-1) will do fit only based on obs that followed the rule at t-1.
## 2. At next iteration t-1:
##     Fitting:
##        Use the predictions in Qkplus1[t-1] as new outcomes and repeat until reached minimum t.
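##
## Illustrative sketch of the recursion above (not package code): a toy wide-format data.frame with
## hypothetical columns L1, A1, L2, A2, Y, two time-points, and the static intervention a^* = (1, 1).
## The package performs the analogous steps on the long-format OData$dat.sVar via the ModelQlearn classes.
# set.seed(1); n <- 500
# toy <- data.frame(L1 = rnorm(n)); toy$A1 <- rbinom(n, 1, plogis(toy$L1))
# toy$L2 <- rnorm(n, mean = toy$A1); toy$A2 <- rbinom(n, 1, plogis(toy$L2))
# toy$Y  <- rbinom(n, 1, plogis(-1 + toy$A1 + toy$A2 + toy$L1))
# ## t = 2: fit Q on the observed data, then predict with A2 swapped for A2^* = 1:
# Q2.fit <- glm(Y ~ L1 + A1 + L2 + A2, family = quasibinomial(), data = toy)
# toy$Qkplus1 <- predict(Q2.fit, newdata = transform(toy, A2 = 1), type = "response")
# ## t = 1: the predictions in Qkplus1 become the new outcome; predict with A1 swapped for A1^* = 1:
# Q1.fit <- glm(Qkplus1 ~ L1 + A1, family = quasibinomial(), data = toy)
# mean(predict(Q1.fit, newdata = transform(toy, A1 = 1), type = "response"))  ## estimate of E[Y^{a = (1,1)}]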

## ------------------------------------------------------------------------------------------
## Stochastic interventions:
## ------------------------------------------------------------------------------------------
## When the intervention node values are <1 and >0, those are interpreted as probabilities,
## i.e., stochastic interventions.
## In more detail, a column value 0 < intervened_TRT[t] < 1 defines the counterfactual probability P(TRT[t]^* = 1) = intervened_TRT[t].
## Those are dealt with in ModelQlearn class by directly integrating, i.e.,
##  evaluating a weighted sum of predicted Q's with weights given by the probabilities in g.star columns
## (on A and N).
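##
## Minimal sketch of that weighted sum (assumed toy vectors, not package code), for a single binary
## treatment node at one time-point, with p.star playing the role of the intervened_TRT column:
# Q.pred.A1 <- c(0.20, 0.35, 0.50)   ## predicted Q's with A(t) set to 1
# Q.pred.A0 <- c(0.05, 0.10, 0.30)   ## predicted Q's with A(t) set to 0
# p.star    <- c(0.70, 0.50, 1.00)   ## counterfactual P(A^*(t) = 1) for each row
# Qkplus1   <- p.star * Q.pred.A1 + (1 - p.star) * Q.pred.A0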

## ------------------------------------------------------------------------------------------
##  *** rule_followers:
## ------------------------------------------------------------------------------------------
##     Once a person goes off the treatment for the first time, that is it: the person is treated as censored for the rest of the follow-up.
##     Rule followers are evaluated automatically by comparing (intervened_TRT and TRT and CENS) and (intervened_MONITOR and MONITOR and CENS).
##     Rule followers are: (intervened_TRT = 1) & (TRT == 1) or (intervened_TRT = 0) & (TRT == 0) or (intervened_TRT > 0 & intervened_TRT < 1) & (Not Censored).
##     Exactly the same logic is also applied to (intervened_MONITOR & MONITOR) when these are specified.
##     NOT TESTED: If either TRT or MONITOR is multivariate (>1 col), this logic needs to be applied FOR EACH COLUMN of TRT/MONITOR.
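##
## Minimal sketch of the rule-follower logic above (assumed toy vectors, not package code), for a
## single binary TRT column at one time-point:
# TRT            <- c(1, 0, 1, 0)
# intervened_TRT <- c(1, 1, 0.6, 0)            ## 0.6 = stochastic intervention => follower as long as uncensored
# uncensored     <- c(TRUE, TRUE, TRUE, FALSE)
# follow_rule <- ((intervened_TRT == 1 & TRT == 1) |
#                 (intervened_TRT == 0 & TRT == 0) |
#                 (intervened_TRT > 0 & intervened_TRT < 1)) & uncensored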
## ------------------------------------------------------------------------------------------
## *** Pooling & performing only one TMLE update across all Q.k ***
## ------------------------------------------------------------------------------------------
##     In general, we can pool all Q.kplus (across all t's) and do a single update on all of them
##     Would make this TMLE iterative, but would not require refitting on the initial Q's

## ------------------------------------------------------------------------------------------------------------------------
## Example call: defineNodeGstarGCOMP(OData, intervened_TRT, nodes$Anodes, useonly_t_TRT, stratifyQ_by_rule)
## ------------------------------------------------------------------------------------------------------------------------
## OData - R6 objects that stores input data and associated data management functions
## intervened_NODE - column name(s) w/ counterfactual intervention values of A^* or N^*; can be len > 1 for multivariate A or N
## NodeNames - column names with observed exposures for A or N; can be multivariate (len > 1)
## useonly_t_NODE - a single logical expression or a list of logical expressions, intervene only on the subset of exposure nodes 
##                  (rows that evaluate to TRUE by expr in useonly_t_NODE)
## stratifyQ_by_rule - define observations that are following the rule (using entire duration of follow-up)
## stratify_by_last - define observations that are following the rule at current time-point only
## ... When useonly_t_TRT or useonly_t_MONITOR is specified, we set the A^* nodes to their observed values (A), rather than the counterfactual values, for rows outside that subset
defineNodeGstarGCOMP <- function(OData, intervened_NODE, NodeNames, useonly_t_NODE, stratifyQ_by_rule, stratify_by_last) {
  # if intervened_NODE returns more than one rule-column, evaluate g^* for each and then multiply to get a single joint (for each time point)
  if (!is.null(intervened_NODE) && !any(is.na(intervened_NODE))) {
    gstar.NODEs <- intervened_NODE
    for (intervened_NODE_col in intervened_NODE) CheckVarNameExists(OData$dat.sVar, intervened_NODE_col)
    assert_that(length(intervened_NODE) == length(NodeNames))
    intervened_NODE_l <- as.list(intervened_NODE)
    names(intervened_NODE_l) <- NodeNames

    ## assert that useonly_t_NODE is either a character vector of length 1, NULL, or a list
    if (is.list(useonly_t_NODE)) {
      useonly_t_NODE_l <- useonly_t_NODE
    } else if (is.null(useonly_t_NODE)) { ## for NULL entry, convert to named list of NULL elements
      useonly_t_NODE_l <- sapply(intervened_NODE, function(x) NULL)
    } else if (is.character(useonly_t_NODE) & length(useonly_t_NODE)==1) {
      useonly_t_NODE_l <- as.list(rep.int(useonly_t_NODE, length(intervened_NODE)))
      names(useonly_t_NODE_l) <- intervened_NODE
    } else {
      stop("useonly_t_NODE must be either 'NULL', a single character expression (e.g., 't == 5') or a named list of character expressions")
    }
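    ## Examples of valid inputs (hypothetical node names, for illustration only):
    ##   useonly_t_NODE = NULL                          ## intervene on all rows
    ##   useonly_t_NODE = "t == 0"                      ## same subsetting expression for every intervened node
    ##   useonly_t_NODE = list("TI.star1" = "t == 0",   ## per-node expressions, named by intervened_NODE
    ##                         "TI.star2" = NULL)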
    ## check subset of intervened nodes is specified by a list
    assert_that(is.list(useonly_t_NODE_l))
    ## add a name check: all names in subset node list are part of intervened_NODE
    assert_that(all(names(useonly_t_NODE_l) %in% intervened_NODE))
    ## add a check that list input consists of either: a) characters or b) NULLs
    assert_that(all(sapply(useonly_t_NODE_l, function(x) is.null(x) | is.character(x))))

    ## ------------------------------------------------------------------------------------------
    ## Modify the observed intervened_NODE (A^*) in OData$dat.sVar with values from NodeName (A) for repl_subset_idx:
    ## ------------------------------------------------------------------------------------------
    ## 1. Loop over each node in NodeNames (source_node)
    ## 2. Pick the corresponding A^* name (node_to_repl) and useonly_t_NODE expression
    ## 3. Evaluate the subset index of rows that need replacement (repl_subset_idx)
    ## 4. Replace the node values in node_to_repl with source_node for rows in repl_subset_idx only
    for (source_node in NodeNames) {
      node_to_repl <- intervened_NODE_l[[source_node]]
      useonly_t = useonly_t_NODE_l[[node_to_repl]]
      subset_idx <- OData$evalsubst(subset_exprs = useonly_t)
      ## inverse of subset_idx: replace A^*[idx] nodes with A[idx] only when we DO NOT want to intervene on A[idx]
      repl_subset_idx <- setdiff(1:nrow(OData$dat.sVar), subset_idx)
      OData$replaceNodesVals(repl_subset_idx, nodes_to_repl = node_to_repl, source_for_repl = source_node)
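      ## conceptually (an assumption about what replaceNodesVals does, shown only as a data.table sketch):
      ## OData$dat.sVar[repl_subset_idx, (node_to_repl) := .SD[[source_node]], .SDcols = source_node]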
    }

    ## ------------------------------------------------------------------------------------------
    ## update rule followers for trt if doing stratified G-COMP:
    ## Note this will define rule followers based on REPLACED intervened_NODE in dat.sVar (i.e., modified n^*(t) under N.D.E.)
    ## FOR NDE BASED TMLE THE DEFINITION OF RULE-FOLLOWERS CHANGES ACCORDINGLY based on modified n^*(t) and a^*(t)
    # ------------------------------------------------------------------------------------------
    if (stratifyQ_by_rule) {
      for (Anode_i in NodeNames) {
        if (!stratify_by_last) {
          follow_rule <- OData$eval_follow_rule(NodeName = Anode_i, gstar.NodeName = intervened_NODE_l[[Anode_i]])
        } else {
          follow_rule <- OData$eval_follow_rule_each_t(NodeName = Anode_i, gstar.NodeName = intervened_NODE_l[[Anode_i]])
        }
        OData$follow_rule <- follow_rule & OData$follow_rule & OData$uncensored
      }
    }
  } else {
    # use the actual (observed) node names under g0:
    gstar.NODEs <- NodeNames
  }
  return(gstar.NODEs)
}

# ---------------------------------------------------------------------------------------
#' Iterative TMLE wrapper for \code{fit_GCOMP}
#'
#' Calls \code{fit_GCOMP} with argument \code{iterTMLE = TRUE}.
#' @param ... Arguments that will be passed down to the underlying function \code{fit_GCOMP}
#' @return \code{data.table} with survival by time for sequential GCOMP and iterative TMLE
#' @seealso \code{\link{fit_GCOMP}}
#' @example tests/examples/2_building_blocks_example.R
#' @export
fit_iterTMLE <- function(...) {
  fit_GCOMP(TMLE = FALSE, iterTMLE = TRUE, ...)
}

# ---------------------------------------------------------------------------------------
#' TMLE wrapper for \code{fit_GCOMP}
#'
#' Calls \code{fit_GCOMP} with argument \code{TMLE = TRUE}.
#' @param ... Arguments that will be passed down to the underlying function \code{fit_GCOMP}
#' @return \code{data.table} with TMLE survival by time
#' @seealso \code{\link{fit_GCOMP}}
#' @example tests/examples/2_building_blocks_example.R
#' @export
fit_TMLE <- function(...) {
  fit_GCOMP(TMLE = TRUE, ...)
}

# ---------------------------------------------------------------------------------------
#' CV-TMLE wrapper for \code{fit_GCOMP}
#'
#' Calls \code{fit_GCOMP} with arguments \code{TMLE = TRUE} and \code{CVTMLE = TRUE}.
#' @param ... Arguments that will be passed down to the underlying function \code{fit_GCOMP}
#' @return \code{data.table} with TMLE survival by time
#' @seealso \code{\link{fit_GCOMP}}
#' @example tests/examples/2_building_blocks_example.R
#' @export
fit_CVTMLE <- function(...) {
  fit_GCOMP(TMLE = TRUE, CVTMLE = TRUE, ...)
}
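
## Illustrative usage sketch for the wrappers above (hypothetical column names; see
## tests/examples/2_building_blocks_example.R for the packaged example). Assumes 'OData' was created
## with importData() and that the input data contains a column "TI.star" holding the counterfactual
## treatment assignments (or probabilities) at each time-point:
# surv_tmle <- fit_TMLE(OData, tvals = 0:10,
#                       Qforms = rep.int("Qkplus1 ~ L + A + N", 11),
#                       intervened_TRT = "TI.star",
#                       stratifyQ_by_rule = FALSE)
# surv_tmle$estimates[]   ## data.table with St.TMLE and SE.TMLE by time-point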


# ---------------------------------------------------------------------------------------
#' Fit sequential GCOMP and TMLE for survival
#'
#' Interventions on up to 3 nodes are allowed: \code{CENS}, \code{TRT} and \code{MONITOR}.
#' TMLE adjustment will be based on the inverse of the propensity score fits for the observed likelihood (g0.C, g0.A, g0.N),
#' multiplied by the indicator of not being censored and the probability of each intervention in \code{intervened_TRT} and \code{intervened_MONITOR}.
#' Requires column name(s) that specify the counterfactual node values or the counterfactual probabilities of each node being 1 (for stochastic interventions).
#' @param OData Input data object created by \code{importData} function.
#' @param tvals Vector of time-points in the data for which the survival function (and risk) should be estimated
#' @param Qforms Regression formulas, one formula per Q. Only main-terms are allowed.
#' @param Qstratify Placeholder for future user-defined model stratification for fitting Qs (CURRENTLY NOT FUNCTIONAL, WILL RESULT IN ERROR).
#' @param intervened_TRT Column name in the input data with the probabilities (or indicators) of counterfactual treatment nodes being equal to 1 at each time point.
#' Leave the argument unspecified (\code{NULL}) when not intervening on treatment node(s).
#' @param intervened_MONITOR Column name in the input data with probabilities (or indicators) of counterfactual monitoring nodes being equal to 1 at each time point.
#' Leave the argument unspecified (\code{NULL}) when not intervening on the monitoring node(s).
#' @param useonly_t_TRT Use for intervening only on some subset of the observation- and time-specific treatment nodes.
#' Should be a character string with a logical expression that defines the subset of intervention observations.
#' For example, using \code{TRT==0} will intervene only at observations with the value of \code{TRT} being equal to zero.
#' The expression can contain any variable name that was defined in the input dataset.
#' Leave as \code{NULL} when intervening on all observations/time-points.
#' @param useonly_t_MONITOR Same as \code{useonly_t_TRT}, but for monitoring nodes.
#' @param rule_name Optional name for the treatment/monitoring regimen.
#' @param stratifyQ_by_rule Set to \code{TRUE} for stratifying the fit of Q (the outcome model) by rule-followers only.
#' There are two ways to do this stratification. The first option is to use \code{stratify_by_last=TRUE}  (default),
#' which would fit the outcome model only among the observations that were receiving their supposed
#' counterfactual treatment at the current time-point (ignoring the past history of treatments leading up to time-point t).
#' The second option is to set \code{stratify_by_last=FALSE} in which case the outcome model will be fit only
#' among the observations who followed their counterfactual treatment regimen throughout the entire treatment history up to
#' current time-point t (rule followers). For the latter option, the observation would be considered a non-follower if
#' the person's treatment did not match their supposed counterfactual treatment at any time-point up to and including current
#' time-point t.
#' @param stratify_by_last Only used when \code{stratifyQ_by_rule} is \code{TRUE}.
#' Set to \code{TRUE} for stratification by last time-point, set to \code{FALSE} for stratification by all time-points (rule-followers).
#' See \code{stratifyQ_by_rule} for more details.
#' @param TMLE Set to \code{TRUE} to run the usual longitudinal TMLE algorithm (with a separate TMLE update of Q for every sequential regression).
#' @param iterTMLE Set to \code{TRUE} to run the iterative univariate TMLE instead of the usual longitudinal TMLE.
#' When set to \code{TRUE} this will also provide the standard sequential GCOMP estimates as part of the output.
#' @param CVTMLE Set to \code{TRUE} to run the CV-TMLE algorithm instead of the usual TMLE algorithm.
#' Must set either \code{TMLE}=\code{TRUE} or \code{iterTMLE}=\code{TRUE} for this argument to have any effect.
#' @param byfold_Q (ADVANCED USE) Fit iterative means (Q parameter) using "by-fold" (aka "fold-specific" or "split-specific") cross-validation approach.
#' Only works with \code{fit_method}=\code{"origamiSL"}.
#' @param IPWeights (Optional) result of calling function \code{getIPWeights} for running TMLE (evaluated automatically when missing)
#' @param trunc_weights Specify the numeric weight truncation value. All final weights exceeding the value in \code{trunc_weights} will be truncated.
#' @param models Optional parameters specifying the models for fitting the iterative (sequential) G-Computation formula.
#' Must be an object of class \code{ModelStack} specified with \code{gridisl::defModel} function.
#' @param weights Optional \code{data.table} with additional observation- and time-specific weights.  Must contain columns \code{ID}, \code{t} and \code{weight}.
#' The column named \code{weight} is merged back into the original data according to (\code{ID}, \code{t}). Not implemented yet.
#' @param max_iter For iterative TMLE only: Integer, set to maximum number of iterations for iterative TMLE algorithm.
#' @param adapt_stop For iterative TMLE only: Choose between two stopping criteria for iterative TMLE, default is \code{TRUE},
#' which will stop the iterative TMLE algorithm in an adaptive way. Specifically, the iterations will stop when the mean estimate
#' of the efficient influence curve is less than or equal to 1 / (\code{adapt_stop_factor}*sqrt(\code{N})), where
#' N is the total number of unique subjects in data and \code{adapt_stop_factor} is set to 10 by default.
#' When \code{TRUE}, the argument \code{tol_eps} is ignored and TMLE stops when either \code{max_iter} has been reached or this criterion has been satisfied.
#' When \code{FALSE}, the stopping criterion is determined by the values of \code{max_iter} and \code{tol_eps}.
#' @param adapt_stop_factor For iterative TMLE only: The adaptive factor to choose the stopping criteria for iterative TMLE when
#' \code{adapt_stop} is set to \code{TRUE}. Default is 10.
#' TMLE will keep iterating until
#' the mean estimate of the efficient influence curve is less than 1 / (\code{adapt_stop_factor}*sqrt(\code{N})) or until the number of iterations reaches \code{max_iter}.
#' @param tol_eps For iterative TMLE only: Numeric error tolerance for the iterative TMLE update.
#' The iterative TMLE algorithm will stop when the absolute value of the TMLE intercept update is below \code{tol_eps}
#' @param parallel Set to \code{TRUE} to run the sequential G-COMP or TMLE in parallel (uses \code{foreach} with \code{dopar} and
#' requires a previously defined parallel back-end cluster)
#' @param fit_method Model selection approach. Can be either \code{"none"} - no model selection or
#' \code{"cv"} - discrete Super Learner using V fold cross-validation that selects the best model according to lowest cross-validated MSE (must specify the column name that contains the fold IDs) or
#' \code{"origamiSL"} - continuous Super Learner that uses the \code{origami} R package to select the
#' convex combination of the model predictions (aka model stacking).
#  \code{"holdout"} - model selection by splitting the data into training and validation samples according to lowest validation sample MSE (must specify the column of \code{TRUE} / \code{FALSE} indicators,
#  where \code{TRUE} indicates that this row will be selected as part of the model validation sample).
#' @param fold_column The column name in the input data (ordered factor) that contains the fold IDs to be used as part of the validation sample.
#' Use the provided function \code{\link{define_CVfolds}} to
#' define such folds or define the folds using your own method.
#' @param return_wts Applies only when \code{TMLE = TRUE}.
#' Return the data.table with subject-specific IP weights as part of the output.
#' Note: for large datasets setting this to \code{TRUE} may lead to extremely large object sizes!
#' @param return_fW When \code{TRUE}, will return the object fit for the last Q regression as part of the output table.
#' Can be used for obtaining subject-specific predictions of the counterfactual functional E(Y_{d}|W_i).
#' @param reg_Q (ADVANCED USE ONLY) Directly specify the Q regressions, separately for each time-point.
#' @param intervened_type_TRT (ADVANCED FUNCTIONALITY) Set to \code{NULL} by default; can be a character string set to either
#' \code{"bin"}, \code{"shift"} or \code{"MSM"}.
#' Provides support for different types of interventions on \code{TRT} (treatment) node (counterfactual treatment node \code{A^*(t)}).
#' The default behavior is the same as \code{"bin"}, which assumes that \code{A^*(t)} is binary and 
#' is set equal to either \code{0}, \code{1} or \code{p(t)}, where 0<=\code{p(t)}<=1. 
#' Here, \code{p(t)} denotes the probability that counterfactual A^*(t) is equal to 1, i.e., P(A^*(t)=1)=\code{p(t)} 
#' and it can change in time and subject to subject.
#' For \code{"shift"}, it is assumed that the intervention node \code{A^*(t)} is a shift in the value of the continuous treatment \code{A}, 
#' i.e., \code{A^*(t)}=\code{A(t)}+delta(t).
#' Finally, for \code{"MSM"} it is assumed that we simply want the final intervention density \code{g^*(t)} to be set to a constant 1.
#' This is useful for static MSMs.
#' @param intervened_type_MONITOR (ADVANCED FUNCTIONALITY) Same as \code{intervened_type_TRT}, but for monitoring intervention node 
#' (counterfactual monitoring node \code{N^*(t)}).
#' @param maxpY Maximum probability that the cumulative incidence of the outcome Y(t) is equal to 1.
#' Useful for upper-bounding the rare-outcomes.
#' @param TMLE_updater Function for performing the TMLE update. Default is the TMLE updater based on speedglm (called \code{"TMLE.updater.speedglm"}).
#' Other possible options include \code{"TMLE.updater.glm"}, \code{"linear.TMLE.updater.speedglm"} and \code{"iTMLE.updater.xgb"}.
#' @param verbose Set to \code{TRUE} to print auxiliary messages during model fitting.
#' @param ... When the \code{models} argument is NOT specified, these additional arguments will be passed on directly to all \code{GridSL}
#' modeling functions that are called from this routine,
#' e.g., \code{family = "binomial"} can be used to specify the model family.
#' Note that all such arguments must be named.
#' @return An output list containing the \code{data.table} with survival estimates over time saved as \code{"estimates"}.
#' @seealso \code{\link{stremr-package}} for the general overview of the package.
#' @example tests/examples/2_building_blocks_example.R
#' @export
# proposed name change:
fit_GCOMP <- function(OData,
                        tvals,
                        Qforms,
                        intervened_TRT = NULL,
                        intervened_MONITOR = NULL,
                        rule_name = paste0(c(intervened_TRT, intervened_MONITOR), collapse = ""),
                        models = NULL,
                        fit_method = stremrOptions("fit_method"),
                        fold_column = stremrOptions("fold_column"),
                        TMLE = FALSE,
                        stratifyQ_by_rule = FALSE,
                        stratify_by_last = TRUE,
                        Qstratify = NULL,
                        useonly_t_TRT = NULL,
                        useonly_t_MONITOR = NULL,
                        iterTMLE = FALSE,
                        CVTMLE = FALSE,
                        byfold_Q = FALSE,
                        IPWeights = NULL,
                        trunc_weights = 10^6,
                        weights = NULL,
                        max_iter = 15,
                        adapt_stop = TRUE,
                        adapt_stop_factor = 10,
                        tol_eps = 0.001,
                        parallel = FALSE,
                        return_wts = FALSE,
                        return_fW = FALSE,
                        reg_Q = NULL,
                        intervened_type_TRT = NULL,
                        intervened_type_MONITOR = NULL,
                        maxpY = 1.0,
                        TMLE_updater = "TMLE.updater.speedglm",
                        verbose = getOption("stremr.verbose"), ...) {

  if (!is.null(weights)) stop("optional argument 'weights' is not implemented yet for TMLE or GCOMP")

  gvars$verbose <- verbose
  nodes <- OData$nodes
  new.factor.names <- OData$new.factor.names

  assert_that(is.logical(adapt_stop))

  if (TMLE & iterTMLE) stop("Either 'TMLE' or 'iterTMLE' must be set to FALSE. Cannot estimate both within a single algorithm run.")

  if (byfold_Q && !(fit_method %in% "origamiSL"))
    stop("Running byfold_Q=TRUE requires setting fit_method='origamiSL'")

  if (missing(rule_name)) rule_name <- paste0(c(intervened_TRT,intervened_MONITOR), collapse = "")

  added_rule_col <- FALSE
  if (!"rule.name" %in% names(OData$dat.sVar)) {
    added_rule_col <- TRUE
    rule_name <- paste0(c(intervened_TRT,intervened_MONITOR), collapse = "")
    OData$dat.sVar[, "rule.name" := rule_name]
  } else {
    rule_name <- unique(OData$dat.sVar[["rule.name"]])
  }

  # ------------------------------------------------------------------------------------------------
  # **** Evaluate the uncensored and initialize rule followers (everybody is a follower by default)
  # ------------------------------------------------------------------------------------------------
  OData$uncensored <- OData$eval_uncensored()
  OData$follow_rule <- rep.int(TRUE, nrow(OData$dat.sVar)) # (everybody is a follower by default)

  opt_params <- capture.exprs(...)
  if (!("family" %in% names(opt_params))) opt_params[["family"]] <- quasibinomial()

  if (!is.null(models)) {
    assert_that(is.ModelStack(models) || is(models, "Lrnr_base"))
  } else {
    # models <- do.call(sl3::Lrnr_glm_fast$new, opt_params)
    opt_params_est = c(estimator = "speedglm__glm", opt_params)
    models <- do.call(gridisl::defModel, opt_params_est)
    # print("--------------------------------------------------")
    # print("opt_params_est: "); print(opt_params_est)
    # print("models: "); print(models)
    # print("--------------------------------------------------")
  }

  models_control <- c(list(models = models), 
                      list(reg_Q = reg_Q), 
                      opt_params = list(opt_params))
  models_control[["fit_method"]] <- fit_method[1L]
  models_control[["fold_column"]] <- fold_column

  # ------------------------------------------------------------------------------------------------
  # **** Add weights if TMLE=TRUE and if weights were defined
  # NOTE: This needs to be done only once if evaluating survival over several time-points
  # ------------------------------------------------------------------------------------------------
  if (TMLE || iterTMLE) {
    if (is.null(IPWeights)) {
      if (gvars$verbose) message("...evaluating IPWeights for TMLE...")
      IPWeights <- getIPWeights(OData,
                                intervened_TRT,
                                intervened_MONITOR,
                                useonly_t_TRT,
                                useonly_t_MONITOR,
                                rule_name,
                                holdout = CVTMLE,
                                eval_stabP = FALSE,
                                intervened_type_TRT = intervened_type_TRT,
                                intervened_type_MONITOR = intervened_type_MONITOR)
      if (trunc_weights < Inf) IPWeights[eval(as.name("cum.IPAW")) > trunc_weights, ("cum.IPAW") := trunc_weights]
    } else {
      getIPWeights_fun_call <- attributes(IPWeights)[['getIPWeights_fun_call']]
      if (gvars$verbose) message("applying user-specified IPWeights obtained with a call: \n" %+% deparse(getIPWeights_fun_call)[[1]])
      assert_that(all.equal(attributes(IPWeights)[['intervened_TRT']], intervened_TRT))
      assert_that(all.equal(attributes(IPWeights)[['intervened_MONITOR']], intervened_MONITOR))
    }
    assert_that(is.data.table(IPWeights))
    assert_that("cum.IPAW" %in% names(IPWeights))
    OData$IPwts_by_regimen <- IPWeights

    ## Add additional observation-specific weights to the cumulative weights:
    # IPWeights <- process_opt_wts(IPWeights, weights, nodes, adjust_outcome = FALSE)
  }

  if (missing(tvals)) stop("must specify survival 'tvals' of interest (time period values from column " %+% nodes$tnode %+% ")")

  # ------------------------------------------------------------------------------------------
  # Create a back-up of the observed input gstar nodes (created by user in input data):
  # Will add new columns (that were not backed up yet) TO SAME backup data.table
  # ------------------------------------------------------------------------------------------
  OData$backupNodes(c(intervened_TRT,intervened_MONITOR))

  # ------------------------------------------------------------------------------------------------
  # Define the intervention nodes
  # Modify the observed input intervened_NODE in OData$dat.sVar with values from NodeNames for repl_subset_idx
  # ------------------------------------------------------------------------------------------------
  gstar.A <- defineNodeGstarGCOMP(OData, intervened_TRT, nodes$Anodes, useonly_t_TRT, stratifyQ_by_rule, stratify_by_last)
  gstar.N <- defineNodeGstarGCOMP(OData, intervened_MONITOR, nodes$Nnodes, useonly_t_MONITOR, stratifyQ_by_rule, stratify_by_last)
  interventionNodes.g0 <- c(nodes$Anodes, nodes$Nnodes)
  interventionNodes.gstar <- c(gstar.A, gstar.N)

  # When the same node name belongs to both g0 and gstar it doesn't need to be intervened upon, so exclude it
  common_names <- intersect(interventionNodes.g0, interventionNodes.gstar)
  interventionNodes.g0 <- interventionNodes.g0[!interventionNodes.g0 %in% common_names]
  interventionNodes.gstar <- interventionNodes.gstar[!interventionNodes.gstar %in% common_names]

  OData$interventionNodes.g0 <- interventionNodes.g0
  OData$interventionNodes.gstar <- interventionNodes.gstar

  ## ------------------------------------------------------------------------------------------------
  ## RUN GCOMP OR TMLE FOR SEVERAL TIME-POINTS EITHER IN PARALLEL OR SEQUENTIALLY
  ## For est of S(t) over vector of ts, estimate for highest t first going down to smallest t
  ## This is more efficient when parallelizing, since larger t implies more model runs & longer run time
  ## ------------------------------------------------------------------------------------------------
  est_name <- ifelse(TMLE, "TMLE", ifelse(iterTMLE, "GCOMP & iterTMLE", "GCOMP"))
  tmle.run.res <- try(
    if (parallel) {
      mcoptions <- list(preschedule = FALSE)
      '%dopar%' <- foreach::'%dopar%'
      res_byt <- foreach::foreach(t_idx = rev(seq_along(tvals)), .options.multicore = mcoptions) %dopar% {
        t_period <- tvals[t_idx]
        res <- fit_GCOMP_onet(OData, t_period, Qforms, Qstratify, stratifyQ_by_rule,
                                TMLE = TMLE, iterTMLE = iterTMLE, CVTMLE = CVTMLE, byfold_Q = byfold_Q,
                                models = models_control, max_iter = max_iter, adapt_stop = adapt_stop,
                                adapt_stop_factor = adapt_stop_factor, tol_eps = tol_eps,
                                return_fW = return_fW, maxpY = maxpY, TMLE_updater = TMLE_updater, verbose = verbose)
        return(res)
      }
      res_byt[] <- res_byt[rev(seq_along(tvals))] # re-assign to order results by increasing t
    } else {
      res_byt <- vector(mode = "list", length = length(tvals))
      for (t_idx in rev(seq_along(tvals))) {
        t_period <- tvals[t_idx]
        res <- fit_GCOMP_onet(OData, t_period, Qforms, Qstratify, stratifyQ_by_rule,
                                TMLE = TMLE, iterTMLE = iterTMLE, CVTMLE = CVTMLE, byfold_Q = byfold_Q,
                                models = models_control, max_iter = max_iter, adapt_stop = adapt_stop,
                                adapt_stop_factor = adapt_stop_factor, tol_eps = tol_eps,
                                return_fW = return_fW, maxpY = maxpY, TMLE_updater = TMLE_updater, verbose = verbose)
        res_byt[[t_idx]] <- res
      }
    }
  )

  ## using delayed (works with multisession, errors with multcore)
  # subtasks <- lapply(rev(seq_along(tvals)), function(t_idx) {
  #   t_period <- tvals[t_idx]
  #   delayed_Q <- delayed_fun(fit_GCOMP_onet)(OData, t_period, Qforms, Qstratify, stratifyQ_by_rule,
  #                  TMLE = TMLE, iterTMLE = iterTMLE, CVTMLE = CVTMLE, byfold_Q = byfold_Q,
  #                  models = models_control, max_iter = max_iter, adapt_stop = adapt_stop,
  #                  adapt_stop_factor = adapt_stop_factor, tol_eps = tol_eps,
  #                  return_fW = return_fW, maxpY = maxpY, TMLE_updater = TMLE_updater, verbose = verbose)
  #   delayed_Q$expect_error <- TRUE
  #   return(delayed_Q)
  # })
  # bundle_Q <- bundle_delayed(subtasks)

  # tmle.run.res <- try(
  #   res_byt <- bundle_Q$compute(verbose = TRUE)
  #   # res_byt <- bundle_Q$compute(nworkers = 2, verbose = TRUE)
  # )

  ## using native future interface (error: result is too long a vector):
  # f <- list()
  # for (ii in seq_along(tvals)) {
  #   t_period <- tvals[ii]
  #   f[[ii]] <- future({
  #     fit_GCOMP_onet(OData, t_period, Qforms, Qstratify, stratifyQ_by_rule,
  #                    TMLE = TMLE, iterTMLE = iterTMLE, CVTMLE = CVTMLE, byfold_Q = byfold_Q,
  #                    models = models_control, max_iter = max_iter, adapt_stop = adapt_stop,
  #                    adapt_stop_factor = adapt_stop_factor, tol_eps = tol_eps,
  #                    return_fW = return_fW, maxpY = maxpY, TMLE_updater = TMLE_updater, verbose = verbose)
  #   })
  # }
  # #
  # res <- f
  # res <- lapply(f, FUN = value)

  ## Restore backed up nodes, even in the event of failure (otherwise the input data is corrupted for good)
  OData$restoreNodes(c(intervened_TRT,intervened_MONITOR))
  ## remove the newly added rule name column from observed data
  if (added_rule_col) {
    OData$dat.sVar[, "rule.name" := NULL]
  }

  if (inherits(tmle.run.res, "try-error")) {
    stop(
"...attempt at running TMLE for one or several time-points has failed for unknown reason;
If this error cannot be fixed, consider creating a replicable example and filing a bug report at:
  https://github.com/osofr/stremr/issues
", call. = TRUE)
  }

  resultDT <- rbindlist(res_byt)
  ## to extract the EIC estimates by time point (no longer returning those separately, only as part of the estimates data.table)
  # ICs_byt <- resultDT[["IC.St"]]
  # IC.Var.S.d <- t(do.call("cbind", ICs_byt))

  resultDT <- cbind(est_name = est_name, resultDT)
  if (!"rule.name" %in% names(resultDT)) {
    resultDT[, "rule.name" := eval(as.character(rule_name))]
  }

  attr(resultDT, "estimator_short") <- est_name
  attr(resultDT, "estimator_long") <- est_name
  attr(resultDT, "nID") <- OData$nuniqueIDs
  attr(resultDT, "rule_name") <- rule_name
  attr(resultDT, "stratifyQ_by_rule") <- stratifyQ_by_rule
  attr(resultDT, "stratify_by_last") <- stratify_by_last
  attr(resultDT, "trunc_weights") <- trunc_weights
  attr(resultDT, "time") <- resultDT[["time"]]

  res_out <- list(
              # est_name = est_name,
              # periods = tvals,
              # IC.Var.S.d = list(IC.S = IC.Var.S.d),
              # nID = OData$nuniqueIDs,
              wts_data = { if ((TMLE || iterTMLE) && return_wts) { IPWeights } else { NULL } },
              # rule_name = rule_name,
              # trunc_weights = trunc_weights
              estimates = resultDT
              )

  attr(res_out, "estimator_short") <- est_name
  attr(res_out, "estimator_long") <- est_name
  return(res_out)
}

# ------------------------------------------------------------------------------------------------
# ITERATIVE (UNIVARIATE) TMLE ALGORITHM for a single time-point.
# Called as part of the fit_GCOMP_onet()
# ------------------------------------------------------------------------------------------------
iterTMLE_onet <- function(OData, Qlearn.fit, Qreg_idx, max_iter = 15, adapt_stop = TRUE, adapt_stop_factor = 10, tol_eps = 0.001, TMLE_updater = "TMLE.updater.speedglm") {
  get_field_Qclass <- function(allQmodels, fieldName) {
    lapply(allQmodels, function(Qclass) Qclass[[fieldName]])
    # lapply(allQmodels, function(Qclass) Qclass$getPsAsW.models()[[1]][[fieldName]])
  }

  eval_idx_used_to_fit_initQ <- function(allQmodels, OData) {
    lapply(allQmodels, function(Qclass) which(Qclass$define_idx_to_fit_initQ(data = OData)))
    # lapply(allQmodels, function(Qclass) which(Qclass$getPsAsW.models()[[1]]$define_idx_to_fit_initQ(data = OData)))
  }

  # clean up left-over indices after we are done iterating
  eval_wipe.all.indices <- function(allQmodels, OData) {
    lapply(allQmodels, function(Qclass) Qclass$wipe.all.indices)
    # lapply(allQmodels, function(Qclass) Qclass$getPsAsW.models()[[1]]$wipe.all.indices)
  }

  Propagate_TMLE_fits <- function(allQmodels, OData, TMLE.fit) {
    lapply(allQmodels,
      function(Qclass) Qclass$Propagate_TMLE_fit(data = OData, new.TMLE.fit = TMLE.fit))
      # function(Qclass) Qclass$getPsAsW.models()[[1]]$Propagate_TMLE_fit(data = OData, new.TMLE.fit = TMLE.fit))
    return(invisible(allQmodels))
  }

  TMLE_updater <- get(TMLE_updater)

  allQmodels <- Qlearn.fit$getPsAsW.models() # Get the individual Qlearning classes
  # res_all_subset_idx <- as.vector(sort(unlist(get_field_Qclass(allQmodels, "subset_idx"))))
  # use_subset_idx <- res_all_subset_idx
  # use_subset_idx <- res_idx_used_to_fit_initQ

  # one cat'ed vector of all observations that were used for fitting init Q & updating TMLE (across all t's):
  res_idx_used_to_fit_initQ <- as.vector(sort(unlist(get_field_Qclass(allQmodels, "idx_used_to_fit_initQ"))))
  # res_idx_used_to_fit_initQ_2 <- as.vector(sort(unlist(eval_idx_used_to_fit_initQ(allQmodels, OData))))
  # all.equal(res_idx_used_to_fit_initQ, res_idx_used_to_fit_initQ_2)

  # Consider only observations with non-zero weights, these are the only obs that are needed for the TMLE update:
  idx_all_wts_above0 <- which(OData$IPwts_by_regimen[["cum.IPAW"]] > 0)
  use_subset_idx <- intersect(idx_all_wts_above0, res_idx_used_to_fit_initQ)
  wts_TMLE <- OData$IPwts_by_regimen[use_subset_idx, "cum.IPAW", with = FALSE][[1]]

  for (iter in 1:(max_iter+1)) {
    Qkplus1 <- OData$dat.sVar[use_subset_idx, "Qkplus1", with = FALSE][[1]]
    Qk_hat <- OData$dat.sVar[use_subset_idx, "Qk_hat", with = FALSE][[1]]

    # ------------------------------------------------------------------------------------
    # ESTIMATE OF THE EIC:
    # ------------------------------------------------------------------------------------
    # Get t-specific and i-specific components of the EIC for all t > t.init:
    EIC_i_tplus <- wts_TMLE * (Qkplus1 - Qk_hat)

    # Get t-specific and i-specific components of the EIC for all t = t.init (mean pred from last reg, all n obs)
    res_lastPredQ <- allQmodels[[Qreg_idx[1]]]$predictAeqa()  # Qreg_idx[1] is the index for the last Q-fit
    # res_lastPredQ <- allQmodels[[length(allQmodels)]]$predictAeqa()

    EIC_i_t0 <- (res_lastPredQ - mean(res_lastPredQ))
    # Sum them all up and divide by N -> obtain the estimate of P_n(D^*_n):
    EIC_est <- (sum(EIC_i_t0) + sum(EIC_i_tplus)) / OData$nuniqueIDs
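    ## i.e., P_n(D^*) = (1/N) * [ sum_i (res_lastPredQ_i - mean(res_lastPredQ)) + sum_{i,t} cum.IPAW_{i,t} * (Qkplus1_{i,t} - Qk_hat_{i,t}) ]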
    if (gvars$verbose) print("Mean estimate of the EIC: " %+% EIC_est)

    # Quit loop once the stopping criterion has been reached or max_iter updates have been performed
    if (iter > 1) {
      if (adapt_stop) {
        # if (abs(EIC_est) < (1 / OData$nuniqueIDs)) break
        if (abs(EIC_est) < (1 / (adapt_stop_factor*sqrt(OData$nuniqueIDs)))) break
      } else {
        if (!is.null(tol_eps) && (abs(TMLE.fit$object$coef) <= tol_eps)) break
      }
      # note: the loop runs to max_iter+1 so that the EIC can be re-evaluated after the last update
      if (iter > max_iter) break
    }

    Y <- Qkplus1
    newX <- X <- data.frame(intercept = rep(1L, length(Qk_hat)), offset = qlogis(Qk_hat))
    TMLE.fit <- TMLE_updater(Y, X, newX, obsWeights = wts_TMLE)
    TMLE.fit <- TMLE.fit$fit
    # TMLE.fit <- tmle.update(Qkplus1 = Qkplus1, Qk_hat = Qk_hat, IPWts = wts_TMLE, lower_bound_zero_Q = FALSE, skip_update_zero_Q = FALSE)
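    ## Conceptual sketch of the update (an assumption about what the updater amounts to, not its exact
    ## implementation): a weighted intercept-only logistic fluctuation with logit(Qk_hat) as offset, e.g.:
    # eps.fit <- glm(Y ~ offset(qlogis(Qk_hat)), family = quasibinomial(), weights = wts_TMLE)
    # Qk_hat.updated <- plogis(qlogis(Qk_hat) + coef(eps.fit)[["(Intercept)"]])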

    Propagate_TMLE_fits(allQmodels, OData, TMLE.fit)
  }

  if (gvars$verbose) print("iterative TMLE ran for N iter: " %+% iter)

  # clean up after we are done iterating (set the indices in Q classes to NULL to conserve memory)
  tmp <- eval_wipe.all.indices(allQmodels)

  # EVALUATE THE t- and i-specific components of the EIC (estimates):
  # Qkplus1 <- OData$dat.sVar[use_subset_idx, "Qkplus1", with = FALSE][[1]]
  # Qk_hat <- OData$dat.sVar[use_subset_idx, "Qkplus1", with = FALSE][[1]]
  # EIC_i_t_calc <- wts_TMLE * (Qkplus1 - Qk_hat)
  # OData$dat.sVar[use_subset_idx, ("EIC_i_t") := EIC_i_t_calc]
  OData$dat.sVar[use_subset_idx, ("EIC_i_t") := EIC_i_tplus]

  return(invisible(Qlearn.fit))
}

# GCOMP_onet <- function(OData,
fit_GCOMP_onet <- function(OData,
                             t_period,
                             Qforms,
                             Qstratify,
                             stratifyQ_by_rule,
                             TMLE,
                             iterTMLE,
                             CVTMLE = FALSE,
                             byfold_Q = FALSE,
                             models,
                             max_iter = 10,
                             adapt_stop = TRUE,
                             adapt_stop_factor = 10,
                             tol_eps = 0.001,
                             return_fW = FALSE,
                             maxpY = 1.0,
                             TMLE_updater = "TMLE.updater.speedglm",
                             verbose = getOption("stremr.verbose")) {
  gvars$verbose <- verbose
  nodes <- OData$nodes
  new.factor.names <- OData$new.factor.names

  # ------------------------------------------------------------------------------------------------
  # Defining the t periods to loop over FOR A SINGLE RUN OF THE iterative G-COMP/TMLE (one survival point)
  # **** TO DO: The stratification by follow-up has to be based only on 't' values that were observed in the data****
  # ------------------------------------------------------------------------------------------------
  Qperiods <- rev(OData$min.t:t_period)
  Qreg_idx <- rev(seq_along(Qperiods))
  Qstratas_by_t <- as.list(nodes[['tnode']] %+% " == " %+% (Qperiods))
  names(Qstratas_by_t) <- rep.int("Qkplus1", length(Qstratas_by_t))
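  ## e.g., with OData$min.t = 0, t_period = 2 and the time variable named "t", this gives:
  ##   Qperiods = c(2, 1, 0); Qreg_idx = c(3, 2, 1); Qstratas_by_t = list("t == 2", "t == 1", "t == 0")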

  # Adding user-specified stratas to each Q(t) regression:
  all_Q_stratify <- Qstratas_by_t
  if (!is.null(Qstratify)) {
    stop("...Qstratify is not implemented yet...")
    assert_that(is.vector(Qstratify))
    assert_that(is.character(Qstratify))
    for (idx in seq_along(all_Q_stratify)) {
      all_Q_stratify[[idx]] <- paste0(all_Q_stratify[[idx]], " & ", Qstratify)
    }
  }

  # ------------------------------------------------------------------------------------------------
  # **** Process the input formulas and stratification settings
  # **** TO DO: Add checks that Qforms has correct number of regressions in it
  # ------------------------------------------------------------------------------------------------
  Qforms.default <- rep.int("Qkplus1 ~ Lnodes + Anodes + Cnodes + Nnodes", length(Qperiods))
  if (missing(Qforms)) {
    Qforms_single_t <- Qforms.default
  } else {
    # Qforms_single_t <- Qforms[seq_along(Qperiods)]
    Qforms_single_t <- Qforms[Qreg_idx]
  }

  # ------------------------------------------------------------------------------------------------
  #  ****** Qkplus1 THIS NEEDS TO BE MOVED OUT OF THE MAIN data.table *****
  # Running fit_GCOMP_onet might conflict with different Qkplus1
  # ------------------------------------------------------------------------------------------------
  # **** G-COMP: Initiate Qkplus1 - (could be multiple if more than one regimen)
  # That column keeps the tabs on the running Q-fit (SEQ G-COMP)
  # ------------------------------------------------------------------------------------------------
  OData$dat.sVar[, ("EIC_i_t") := 0.0] # set the initial (default values of the t-specific and i-specific EIC estimates)
  OData$dat.sVar[, "Qkplus1" := as.numeric(get(OData$nodes$Ynode))] # set the initial values of Q (the observed outcome node)

  if ("Qk_hat" %in% names(OData$dat.sVar)) {
    OData$dat.sVar[, "Qk_hat" := NULL]
  }
  if ("res_lastPredQ" %in% names(OData$dat.sVar)) {
    OData$dat.sVar[, "res_lastPredQ" := NULL]
  }

  # ------------------------------------------------------------------------------------------------
  # **** Define regression classes for Q.Y and put them in a single list of regressions.
  # **** TO DO: This could also be done only once in the main routine, then just subset the appropriate Q_regs_list
  # ------------------------------------------------------------------------------------------------
  Q_regs_list <- vector(mode = "list", length = length(Qstratas_by_t))
  names(Q_regs_list) <- unlist(Qstratas_by_t)
  class(Q_regs_list) <- c(class(Q_regs_list), "ListOfRegressionForms")

  for (i in seq_along(Q_regs_list)) {
    regform <- process_regform(as.formula(Qforms_single_t[[i]]), sVar.map = nodes, factor.map = new.factor.names)
    if (!is.null(models[["reg_Q"]])) {
      models[["models"]] <- models[["reg_Q"]][[Qreg_idx[i]]]
    }
    last.reg <- Qreg_idx[i] == 1L
    reg <- RegressionClassQlearn$new(Qreg_counter = Qreg_idx[i],
                                     all_Qregs_indx = Qreg_idx,
                                     t_period = Qperiods[i],
                                     TMLE = TMLE,
                                     CVTMLE = CVTMLE,
                                     byfold_Q = byfold_Q,
                                     keep_idx = ifelse(iterTMLE, TRUE, FALSE),
                                     # keep_model_fit = ifelse(iterTMLE, TRUE, last.reg),
                                     keep_model_fit = ifelse(iterTMLE, TRUE, FALSE),
                                     stratifyQ_by_rule = stratifyQ_by_rule,
                                     outvar = "Qkplus1",
                                     predvars = regform$predvars,
                                     outvar.class = list("Qlearn"),
                                     subset_vars = list("Qkplus1"),
                                     subset_exprs = all_Q_stratify[i],
                                     model_contrl = models,
                                     censoring = FALSE,
                                     maxpY = maxpY,
                                     TMLE_updater = TMLE_updater)

    ## For Q-learning this reg class always represents a terminal model class,
    ## since there cannot be any additional model-tree splits by values of subset_vars, subset_exprs, etc.
    ## The following two lines allow for a slightly simplified (shallower) tree representation of ModelGeneric-type classes.
    ## This also means that stratifying Q fits by some covariate value will not be possible with this approach
    ## (i.e., such stratifications would have to be implemented locally by the actual model fitting functions).
    reg_i <- reg$clone()
    reg <- reg_i$ChangeManyToOneRegresssion(1, reg)
    Q_regs_list[[i]] <- reg
  }

  # Run all Q-learning regressions (one for each subset defined above); the predictions of each regression form the outcomes for the next one:
  Qlearn.fit <- ModelGeneric$new(reg = Q_regs_list, DataStorageClass.g0 = OData)
  Qlearn.fit$fit(data = OData, Qlearn.fit = Qlearn.fit)
  # Qlearn.fit$getPsAsW.models()[[1]]
  OData$Qlearn.fit <- Qlearn.fit

  # When the model is fit with user-defined stratas, need special functions to extract the final fit (current approach will not work):
  # allQmodels[[1]]$getPsAsW.models()[[1]]$getPsAsW.models()

  # get the individual TMLE updates and evaluate if any updates have failed
  allQmodels <- Qlearn.fit$getPsAsW.models()

  lastQ_inx <- Qreg_idx[1] # The index for the last Q-fit (first time-point)
  lastQ.fit <- allQmodels[[lastQ_inx]]
  subset_vars <- lastQ.fit$subset_vars
  subset_exprs <- lastQ.fit$subset_exprs
  subset_idx <- OData$evalsubst(subset_exprs = subset_exprs) ## subset_vars = subset_vars,
  
  Qreg_fail <- lastQ.fit$Qreg_fail ## indicator that one of the Q model fit has failed

  ## Get the previously saved mean prediction for Q from the very last regression (first time-point, all n obs):
  res_lastPredQ <- lastQ.fit$predictAeqa()
  ## remove vector of predictions from the modeling class (to free up memory):
  if (!iterTMLE) {
    lastQ.fit$wipe.probs
  }

  ## Can instead grab the last prediction of Q (Qk_hat) directly from the data, using the appropriate strata-subsetting expression:
  mean_est_t_2 <- OData$dat.sVar[subset_idx, list("cum.inc" = mean(Qk_hat)), by = "rule.name"]

  ## Evaluate the last predictions by rule (which rule each prediction belongs to):
  OData$dat.sVar[subset_idx, "res_lastPredQ" := res_lastPredQ]
  mean_est_t <- OData$dat.sVar[subset_idx, list("cum.inc" = mean(res_lastPredQ)), by = "rule.name"]

  if (Qreg_fail) {
    mean_est_t[, cum.inc := NA]
    mean_est_t_2[, cum.inc := NA]
  }

  if (gvars$verbose) {
    print("mean est: ");  print(mean_est_t)
    print("mean est 2: "); print(mean_est_t_2)
    print("No. of obs for last prediction of Q: " %+% length(res_lastPredQ))
  }

  resDF_onet <- data.table(time = t_period,
                           St.GCOMP = NA,
                           St.TMLE = NA,
                           type = ifelse(stratifyQ_by_rule, "stratified", "pooled"),
                           rule.name = unique(OData$dat.sVar[["rule.name"]])
                          )

  est_name <- ifelse(TMLE, "St.TMLE", "St.GCOMP")
  resDF_onet <- resDF_onet[mean_est_t_2, on = "rule.name"]

  resDF_onet[, (est_name) := (1 - cum.inc)]

  # ------------------------------------------------------------------------------------------------
  # RUN ITERATIVE TMLE (updating all Q's at once):
  # ------------------------------------------------------------------------------------------------
  if (iterTMLE) {
    res <- iterTMLE_onet(OData, Qlearn.fit, Qreg_idx, max_iter = max_iter, adapt_stop = adapt_stop, adapt_stop_factor = adapt_stop_factor, tol_eps = tol_eps, TMLE_updater = TMLE_updater)
    ## 1a. Grab the mean prediction from the very last regression (over all n observations);
    lastQ_inx <- Qreg_idx[1]
    res_lastPredQ <- allQmodels[[lastQ_inx]]$predictAeqa()

    mean_est_t <- mean(res_lastPredQ)
    if (gvars$verbose) print("Iterative TMLE surv estimate: " %+% (1 - mean_est_t))

    ## 1b. Grab it directly from the data, using the appropriate strata-subsetting expression
    resDF_onet[, ("St.TMLE") := (1 - mean_est_t)]
  }

  # ------------------------------------------------------------------------------------------------
  # TMLE INFERENCE
  # ------------------------------------------------------------------------------------------------
  IC_i_onet <- vector(mode = "numeric", length = OData$nuniqueIDs)
  IC_i_onet[] <- NA
  ## save the i-specific estimates of the EIC as a separate column:
  resDF_onet[, ("IC.St") := list(list(IC_i_onet))]

  if (TMLE || iterTMLE) {
    resDF_onet[, ("IC.St") := NULL]
    IC_dt <- OData$dat.sVar[, list("EIC_i_tplus" = sum(eval(as.name("EIC_i_t")))), by = c(nodes$IDnode, "rule.name")]

    IC_dt_t0 <- OData$dat.sVar[subset_idx,  c(nodes$IDnode, "res_lastPredQ", "rule.name"), with = FALSE][,
      "mean_Q" := mean(res_lastPredQ), by = c("rule.name")][,
      "EIC_i_t0" := res_lastPredQ-mean_Q][,
      c("mean_Q", "res_lastPredQ") := list(NULL, NULL)
      ]

    ## remove newly added column from observed dataset
    OData$dat.sVar[, "res_lastPredQ" := NULL]

    IC_dt <- IC_dt[IC_dt_t0, on = c(nodes$IDnode, "rule.name")]
    IC_dt[, ("EIC_i") := EIC_i_t0 + EIC_i_tplus]
    IC_dt[, c("EIC_i_t0", "EIC_i_tplus") :=  list(NULL, NULL)]

    ## asymptotic variance (var of the EIC), by rule:
    IC_Var <- IC_dt[,
      list("IC_Var" = (1 / (.N)) * sum(EIC_i^2),
           "TMLE_Var" = (1 / (.N)) * sum(EIC_i^2) / .N
            ), by = "rule.name"][,
      "SE.TMLE" := sqrt(TMLE_Var)]
    IC_Var[, "IC_Var" := NULL][, "TMLE_Var" := NULL]
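    ## Note: a Wald-type 95% CI for S(t) could then be formed as, e.g., St.TMLE +/- qnorm(0.975) * SE.TMLE
    ## (the CI itself is not computed here).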

    if (Qreg_fail) {
      IC_Var[, SE.TMLE := NA]
    }

    if (gvars$verbose) {
      print("...empirical mean of the estimated EIC: ")
      print(IC_dt[, list("mean_EIC" = mean(EIC_i)), by = "rule.name"])
      print("...estimated TMLE variance: "); print(IC_Var)
    }

    resDF_onet <- resDF_onet[IC_Var, on = "rule.name"]
    data.table::setkeyv(IC_dt, c("rule.name", nodes$IDnode))
    IC_i_onet_byrule <- IC_dt[, (nodes$IDnode) := NULL] %>%
                        tibble::as_tibble() %>%
                        tidyr::nest(EIC_i, .key = "IC.St") %>%
                        dplyr::mutate(IC.St = purrr::map(IC.St, ~ .x[[1]]))
    data.table::setDT(IC_i_onet_byrule)
    # IC_i_onet_byrule[, "IC.St" := list(list(as.numeric(IC.St[[1]][[1]]))), by = "rule.name"]
    resDF_onet <- resDF_onet[IC_i_onet_byrule, on = "rule.name"]
  }

  fW_fit <- lastQ.fit$getfit
  resDF_onet[, ("fW_fit") := { if (return_fW) {list(list(fW_fit))} else {list(list(NULL))} }]
  
  ## re-order the columns so that "rule.name" col is always last
  data.table::setcolorder(resDF_onet, c(setdiff(names(resDF_onet), "rule.name"), "rule.name"))
  return(resDF_onet)
}