R/AIPW.R

#' @title Augmented Inverse Probability Weighting (AIPW)
#'
#' @description An R6 class implementing augmented inverse probability weighting (AIPW) for estimating average causal effects from user inputs of exposure, outcome, covariates and the
#' libraries used for estimating the efficient influence function.
#'
#' @details An AIPW object is constructed by `new()` from user inputs of data and causal structure; `fit()` then fits the data using the
#' libraries in `Q.SL.library` and `g.SL.library` with `k_split` cross-fitting, and `summary()` provides the results.
#' After calling `fit()` and/or `summary()`, propensity scores and inverse probability weights by exposure status can be
#' examined with `plot.p_score()` and `plot.ip_weights()`, respectively.
#'
#' If the outcome has missing values, the analysis assumes they are missing at random (MAR) and estimates propensity scores for I(A=a, observed=1) using all covariates `W`
#' (`W.Q` and `W.g` are disabled). Missing values in the exposure are not supported.
#'
#' See examples for illustration.
#'
#' @section Constructor:
#' \code{AIPW$new(Y = NULL, A = NULL, W = NULL, W.Q = NULL, W.g = NULL, Q.SL.library = NULL, g.SL.library = NULL, k_split = 10, verbose = TRUE, save.sl.fit = FALSE)}
#'
#' ## Constructor Arguments
#' \tabular{lll}{
#' \strong{Argument}      \tab   \strong{Type}     \tab     \strong{Details} \cr
#' \code{Y}               \tab   Numeric           \tab     A vector of outcomes (binary (0, 1) or continuous) \cr
#' \code{A}               \tab   Integer           \tab     A vector of binary exposures (0 or 1) \cr
#' \code{W}               \tab   Data              \tab     Covariates for \strong{both} exposure and outcome models. \cr
#' \code{W.Q}             \tab   Data              \tab     Covariates for the \strong{outcome} model (Q).\cr
#' \code{W.g}             \tab   Data              \tab     Covariates for the \strong{exposure} model (g). \cr
#' \code{Q.SL.library}    \tab   SL.library        \tab     Algorithms used for the \strong{outcome} model (Q). \cr
#' \code{g.SL.library}    \tab   SL.library         \tab    Algorithms used for the \strong{exposure} model (g). \cr
#' \code{k_split}         \tab   Integer           \tab    Number of folds for splitting (Default = 10).\cr
#' \code{verbose}         \tab   Logical           \tab    Whether to print the result (Default = TRUE) \cr
#' \code{save.sl.fit}     \tab   Logical          \tab     Whether to save Q.fit and g.fit (Default = FALSE) \cr
#' }
#'
#' ## Constructor Argument Details
#' \describe{
#'   \item{\code{W}, \code{W.Q} & \code{W.g}}{Each can be a vector, matrix or data.frame. `W.Q` and `W.g` are used only when `W` is `NULL`. }
#'   \item{\code{Q.SL.library} & \code{g.SL.library}}{Machine learning algorithms from [SuperLearner] libraries}
#'   \item{\code{k_split}}{Ranges from 1 to the number of observations minus 1.
#'                         If `k_split=1`, no cross-fitting is used; if `k_split>=2`, cross-fitting is used
#'                         (e.g., with `k_split=10`, each fold is predicted by a model fitted on the other 9/10 of the data).
#'                          \strong{NOTE: cross-fitting is recommended.} }
#'  \item{\code{save.sl.fit}}{This option allows users to save the fitted SuperLearner objects (libs$Q.fit & libs$g.fit) for debugging.
#'                             \strong{Warning: saving the fitted SuperLearner objects can substantially increase storage/memory use.}}
#' }
#'
#'
#' @section Public Methods:
#'  \tabular{lll}{
#'  \strong{Methods}      \tab   \strong{Details}                                      \tab \strong{Link}     \cr
#'  \code{fit()}          \tab   Fit the data to the [AIPW] object                     \tab   [fit.AIPW]  \cr
#'  \code{stratified_fit()}\tab   Fit the data to the [AIPW] object stratified by `A`  \tab   [stratified_fit.AIPW]  \cr
#'  \code{summary()}      \tab   Summary of the average treatment effects from AIPW    \tab   [summary.AIPW_base]\cr
#'  \code{plot.p_score()} \tab   Plot the propensity scores by exposure status         \tab   [plot.p_score]\cr
#'  \code{plot.ip_weights()} \tab   Plot the inverse probability weights using truncated propensity scores  \tab   [plot.ip_weights]\cr
#'  }
#'
#' @section Public Variables:
#'  \tabular{lll}{
#'  \strong{Variable}     \tab   \strong{Generated by}      \tab     \strong{Return} \cr
#'  \code{n}              \tab   Constructor                \tab     Number of observations \cr
#'  \code{stratified_fitted} \tab   `stratified_fit()`      \tab    An indicator of whether the outcome model was fitted stratified by exposure status \cr
#'  \code{obs_est}        \tab   `fit()` & `summary()`      \tab     Components calculating average causal effects \cr
#'  \code{estimates}      \tab   `summary()`                \tab     A list of risk difference, risk ratio and odds ratio estimates \cr
#'  \code{result}         \tab   `summary()`                \tab     A matrix containing RD, ATT, ATC, RR and OR with their SEs and 95\% CIs \cr
#'  \code{g.plot}         \tab   `plot.p_score()`           \tab     A density plot of propensity scores by exposure status\cr
#'  \code{ip_weights.plot}         \tab   `plot.ip_weights()`           \tab     A box plot of inverse probability weights \cr
#'  \code{libs}           \tab   `fit()`                    \tab     [SuperLearner] libraries and their fitted objects \cr
#'  \code{sl.fit}         \tab   Constructor                \tab     A wrapper function for fitting [SuperLearner] \cr
#'  \code{sl.predict}     \tab   Constructor                \tab     A wrapper function using \code{sl.fit} to predict \cr
#'  }
#'
#' ## Public Variable Details
#' \describe{
#'    \item{\code{stratified_fitted}}{An indicator of whether the outcome model was fitted stratified by exposure status (set by `stratified_fit()` rather than `fit()`).
#'    Only when `stratified_fit()` has set `stratified_fitted = TRUE` does `summary()` also output average treatment effects among the treated (ATT) and the controls (ATC).}
#'    \item{\code{obs_est}}{After using `fit()` and `summary()` methods, this list contains the propensity scores (`p_score`),
#'    counterfactual predictions (`mu`, `mu1` & `mu0`) and
#'    efficient influence functions (`aipw_eif1` & `aipw_eif0`) for later average treatment effect calculations.}
#'    \item{\code{g.plot}}{This plot is generated by `ggplot2::geom_density`}
#'    \item{\code{ip_weights.plot}}{This plot uses truncated propensity scores stratified by exposure status (`ggplot2::geom_boxplot`)}
#' }
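#'
#' ## Estimator Sketch
#' A sketch of the standard AIPW form that these components correspond to (see references):
#' \preformatted{
#' # with propensity score g(W) = P(A=1|W):
#' # aipw_eif1 = A/g(W) * (Y - mu1) + mu1
#' # aipw_eif0 = (1-A)/(1-g(W)) * (Y - mu0) + mu0
#' # risk difference = mean(aipw_eif1) - mean(aipw_eif0)
#' }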
#'
#' @return \code{AIPW} object
#'
#' @references Zhong Y, Kennedy EH, Bodnar LM, Naimi AI (2021, In Press). AIPW: An R Package for Augmented Inverse Probability Weighted Estimation of Average Causal Effects. \emph{American Journal of Epidemiology}.
#' @references Robins JM, Rotnitzky A (1995). Semiparametric efficiency in multivariate regression models with missing data. \emph{Journal of the American Statistical Association}.
#' @references Chernozhukov V, Chetverikov V, Demirer M, et al (2018). Double/debiased machine learning for treatment and structural parameters. \emph{The Econometrics Journal}.
#' @references Kennedy EH, Sjolander A, Small DS (2015). Semiparametric causal inference in matched cohort studies. \emph{Biometrika}.
#'
#'
#' @examples
#' library(SuperLearner)
#' library(ggplot2)
#'
#' #create an object
#' aipw_sl <- AIPW$new(Y=rbinom(100,1,0.5), A=rbinom(100,1,0.5),
#'                     W.Q=rbinom(100,1,0.5), W.g=rbinom(100,1,0.5),
#'                     Q.SL.library="SL.mean",g.SL.library="SL.mean",
#'                     k_split=1,verbose=FALSE)
#'
#' #fit the object
#' aipw_sl$fit()
#' # or use `aipw_sl$stratified_fit()` to estimate ATE and ATT/ATC
#'
#' #calculate the results
#' aipw_sl$summary(g.bound = 0.025)
#'
#' #check the propensity scores by exposure status after truncation
#' aipw_sl$plot.p_score()
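#'
#' #check the inverse probability weights by exposure status after truncation
#' aipw_sl$plot.ip_weights()
#'
#' #a minimal sketch of a missing-outcome (assumed MAR) analysis;
#' #covariates must then be supplied via `W` (`W.Q` and `W.g` are disabled)
#' Y_mis <- rbinom(100,1,0.5)
#' Y_mis[sample(100,10)] <- NA
#' aipw_mis <- AIPW$new(Y=Y_mis, A=rbinom(100,1,0.5),
#'                      W=rbinom(100,1,0.5),
#'                      Q.SL.library="SL.mean", g.SL.library="SL.mean",
#'                      k_split=1, verbose=FALSE)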
#'
#' @export
AIPW <- R6::R6Class(
  "AIPW",
  portable = TRUE,
  inherit = AIPW_base,
  public = list(
    #-------------------------public fields-----------------------------#
    libs =list(Q.SL.library=NULL,
               Q.fit = NULL,
               g.SL.library=NULL,
               g.fit = NULL,
               validation_index = NULL,
               validation_index.Q = NULL),
    sl.fit = NULL,
    sl.predict = NULL,


    #-------------------------constructor-----------------------------#
    initialize = function(Y=NULL, A=NULL, verbose=TRUE,
                          W=NULL, W.Q=NULL, W.g=NULL,
                          Q.SL.library=NULL, g.SL.library=NULL,
                          k_split=10, save.sl.fit=FALSE){
      #-----initialize from AIPW_base class-----#
      super$initialize(Y=Y,A=A,verbose=verbose)
      #decide covariate set(s): W.Q and W.g only work when W is NULL.
      if (is.null(W) & private$Y.missing==FALSE){
        if (any(is.null(W.Q),is.null(W.g))) {
          stop("Insufficient covariate sets were provided.")
        } else{
          tryCatch({
            private$Q.set=cbind(A, W.Q)
            }, error = function(e) stop('Covariates dimension error: nrow(W.Q) != length(A)'))
          private$g.set=W.g
        }
      } else if (is.null(W) & private$Y.missing==TRUE){
        stop("`W.Q` and `W.g` are disabled when missing outcome is detected. Please provide covariates in `W`")
      } else{
        tryCatch({
          private$Q.set=cbind(A, W)
          }, error = function(e) stop('Covariates dimension error: nrow(W) != length(A)'))
        private$g.set=W
      }

      #coerce covariate sets to data.frame; a single-column g set is named "Z"
      private$Q.set = as.data.frame(private$Q.set)
      private$g.set = as.data.frame(private$g.set)
      if (ncol(private$g.set)==1) {
        colnames(private$g.set) <- "Z"
      }

      #save input into private fields
      private$k_split=k_split
      #whether to save sl.fit  (Q.fit and g.fit)
      private$save.sl.fit = save.sl.fit
      #check data length
      if (length(private$Y)!=dim(private$Q.set)[1] | length(private$A)!=dim(private$g.set)[1]){
        stop("Please check the dimension of the covariates")
      }
      #-----validate the SuperLearner libraries and set up wrapper functions-----#
      if (is.character(Q.SL.library) && is.character(g.SL.library) &&
          any(grepl("SL.", Q.SL.library)) && any(grepl("SL.", g.SL.library))) {
        #package to load on future workers
        private$sl.pkg <- "SuperLearner"
        #create a new local env for SuperLearner
        private$sl.env = new.env()
        #find user-defined SL. learners in the global env and copy them into sl.env
        private$sl.learners = grep("SL.", lsf.str(globalenv()), value = TRUE)
        lapply(private$sl.learners, function(x) assign(x=x, value=get(x, globalenv()), envir=private$sl.env))
        #wrapper for fitting a SuperLearner
        self$sl.fit = function(Y, X, SL.library, CV){
          suppressMessages({
            fit <- SuperLearner::SuperLearner(Y = Y, X = X, SL.library = SL.library, family= private$Y.type,
                                              env=private$sl.env, cvControl = CV)
          })
          return(fit)
        }
        #wrapper for predicting from a fitted SuperLearner
        self$sl.predict = function(fit, newdata){
          suppressMessages({
            pred <- as.numeric(predict(fit,newdata = newdata)$pred)
          })
          return(pred)
        }
      } else {
        stop("Input Q.SL.library and/or g.SL.library is not a valid SuperLearner library")
      }

      #input sl libraries
      self$libs$Q.SL.library=Q.SL.library
      self$libs$g.SL.library=g.SL.library
      #------input checking-----#
      #check k_split value
      if (private$k_split>=self$n){
        stop("`k_split` >= number of observations is not allowed.")
      }else if (private$k_split < 1){
        stop("`k_split` < 1 is not allowed.")
      }
      #check verbose value
      if (!is.logical(private$verbose)){
        stop("`verbose` is not valid")
      }
      #check if the SuperLearner package is loaded
      if (!any(names(sessionInfo()$otherPkgs) %in% c("SuperLearner"))){
        warning("`SuperLearner` package is not loaded.")
      }
      #-------use future.apply for parallelization if loaded; otherwise fall back to lapply------#
      if (any(names(sessionInfo()$otherPkgs) %in% c("future.apply"))){
        private$.f_lapply = function(iter,func) {
          future.apply::future_lapply(iter,func,future.seed = T,future.packages = private$sl.pkg,future.globals = TRUE)
        }
      }else{
        private$.f_lapply = function(iter,func) lapply(iter,func)
        }
    },
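
    # Illustrative note (a sketch, not run): the constructor copies `SL.`-prefixed
    # functions from the global environment into its internal SuperLearner
    # environment, so user-defined learner wrappers must live in globalenv(), e.g.
    #   SL.glm.custom <- function(...) SuperLearner::SL.glm(...)  # hypothetical wrapper
    #   AIPW$new(..., Q.SL.library = c("SL.mean", "SL.glm.custom"), ...)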


    #-------------------------fit method-----------------------------#
    fit = function(){
      self$stratified_fitted = FALSE
      #----------create index for cross-fitting---------#
      private$cv$k_index <- sample(rep(1:private$k_split,ceiling(self$n/private$k_split))[1:self$n],replace = F)
      private$cv$fold_index = split(1:self$n, private$cv$k_index)
      private$cv$fold_length = sapply(private$cv$fold_index,length)
      #create non-missing index for the outcome model
      if (private$Y.missing) {
        private$cv$fold_index.Q = lapply(private$cv$fold_index, function(x) x[x %in% which(private$observed==1)])
        private$cv$fold_length.Q = sapply(private$cv$fold_index.Q,length)
      } else{
        private$cv$fold_index.Q = private$cv$fold_index
        private$cv$fold_length.Q = private$cv$fold_length
      }


      iter <- 1:private$k_split

      #----------------progress bar setup----------#
      #check if progressr is loaded
      if (any(names(sessionInfo()$otherPkgs) %in% c("progressr"))){
        private$isLoaded_progressr = TRUE
        pb <- progressr::progressor(along = iter)
      }

      #---------parallelization with future.apply------#
      fitted <- private$.f_lapply(
        iter=iter,
        func=function(i,...){
          #when k_split is 1 or 2, SuperLearner's default cvControl is used (no custom validation rows)
          if (private$k_split==1){
            train_index <- validation_index <- as.numeric(unlist(private$cv$fold_index))
            cv_param <- list()
          } else if (private$k_split==2){
            train_index <- as.numeric(unlist(private$cv$fold_index[-i]))
            validation_index <- as.numeric(unlist(private$cv$fold_index[i]))
            cv_param <- list()
          } else{
            train_index <- as.numeric(unlist(private$cv$fold_index[-i]))
            validation_index <- as.numeric(unlist(private$cv$fold_index[i]))
            cv_param <- list(V=private$k_split-1,
                             validRows= private$.new_cv_index(val_fold=i , fold_length = private$cv$fold_length))
          }

          #when outcome is missing, subset the complete case for Q estimation
          if (private$Y.missing){
            if (private$k_split==1){
              train_index.Q <- validation_index.Q <-as.numeric(unlist(private$cv$fold_index.Q))
              cv_param.Q <- list()
            } else if (private$k_split==2) {
              train_index.Q <- as.numeric(unlist(private$cv$fold_index.Q[-i]))
              validation_index.Q <- as.numeric(unlist(private$cv$fold_index.Q[i]))
              cv_param.Q <- list()
            } else {#special care for cross-fitting indices when outcome is missing
              train_index.Q <- as.numeric(unlist(private$cv$fold_index.Q[-i]))
              validation_index.Q <- as.numeric(unlist(private$cv$fold_index.Q[i]))
              cv_param.Q <- list(V=private$k_split-1,
                                 validRows= private$.new_cv_index(val_fold=i, fold_length =private$cv$fold_length.Q))
            }
          } else{
            train_index.Q = train_index
            validation_index.Q = validation_index
            cv_param.Q <- cv_param
          }

          #split the sample based on the index
          #Q outcome set
          train_set.Q <- private$Q.set[train_index.Q,]
          validation_set.Q <- private$Q.set[validation_index.Q,]
          #g exposure set
          train_set.g <- data.frame(private$g.set[train_index,])
          validation_set.g <- data.frame(private$g.set[validation_index,])
          colnames(train_set.g)=colnames(validation_set.g)=colnames(private$g.set) #keep g data.frame colnames consistent

          #Q model(outcome model: g-comp)
          #fit with train set
          Q.fit <- self$sl.fit(Y = private$Y[train_index.Q],
                               X = train_set.Q,
                               SL.library = self$libs$Q.SL.library,
                               CV= cv_param.Q)
          # predict on validation set
          mu0 <- self$sl.predict(Q.fit,newdata=transform(validation_set.Q, A = 0)) #Q0_pred
          mu1 <- self$sl.predict(Q.fit,newdata=transform(validation_set.Q, A = 1)) #Q1_pred

          #g model(exposure model: propensity score)
          # fit with train set
          g.fit <- self$sl.fit(Y=private$AxObserved[train_index],
                               X=train_set.g,
                               SL.library = self$libs$g.SL.library,
                               CV= cv_param)
          # predict on validation set
          raw_p_score  <- self$sl.predict(g.fit,newdata = validation_set.g)  #g_pred

          #add metadata
          names(validation_index) <- rep(i,length(validation_index))

          if (private$isLoaded_progressr){
            pb(sprintf("No.%g iteration", i,private$k_split))
          }

          if (private$save.sl.fit){
            output <- list(validation_index, validation_index.Q, Q.fit, mu0, mu1, g.fit, raw_p_score)
            names(output) <- c("validation_index","validation_index.Q","Q.fit","mu0","mu1","g.fit","raw_p_score")
          } else {
            output <- list(validation_index, validation_index.Q, mu0, mu1, raw_p_score)
            names(output) <- c("validation_index","validation_index.Q","mu0","mu1","raw_p_score")
          }

          return(output)
        })

      #store fitted values from future to member variables
      for (i in fitted){
        #add estimates based on the val index
        self$obs_est$mu0[i$validation_index.Q] <- i$mu0
        self$obs_est$mu1[i$validation_index.Q] <- i$mu1
        self$obs_est$raw_p_score[i$validation_index] <- i$raw_p_score
        #append fitted objects
        if (private$save.sl.fit) {
          self$libs$Q.fit = append(self$libs$Q.fit, list(i$Q.fit))
          self$libs$g.fit = append(self$libs$g.fit, list(i$g.fit))
        }
        self$libs$validation_index = append(self$libs$validation_index, i$validation_index)
        self$libs$validation_index.Q = append(self$libs$validation_index.Q, i$validation_index.Q)
      }
      self$obs_est$mu[private$observed==1]  <- self$obs_est$mu0[private$observed==1]*(1-private$A[private$observed==1]) +
        self$obs_est$mu1[private$observed==1]*(private$A[private$observed==1])#Q_pred

      if (private$verbose){
        message("Done!\n")
      }

      invisible(self)
    },



    #-------------------------stratified_fit method-----------------------------#
    stratified_fit = function(){
      self$stratified_fitted = TRUE
      #----------create index for cross-fitting---------#
      private$cv$k_index <- sample(rep(1:private$k_split,ceiling(self$n/private$k_split))[1:self$n],replace = F)
      private$cv$fold_index = split(1:self$n, private$cv$k_index)
      private$cv$fold_length = sapply(private$cv$fold_index,length)
      #create non-missing index for the outcome model
      if (private$Y.missing) {
        private$cv$fold_index.Q = lapply(private$cv$fold_index, function(x) x[x %in% which(private$observed==1)])
        private$cv$fold_length.Q = sapply(private$cv$fold_index.Q,length)
      } else{
        private$cv$fold_index.Q = private$cv$fold_index
        private$cv$fold_length.Q = private$cv$fold_length
      }


      iter <- 1:private$k_split

      #----------------progress bar setup----------#
      #check if progressr is loaded
      if (any(names(sessionInfo()$otherPkgs) %in% c("progressr"))){
        private$isLoaded_progressr = TRUE
        pb <- progressr::progressor(along = iter)
      }

      #---------parallelization with future.apply------#
      fitted <- private$.f_lapply(
        iter=iter,
        func=function(i,...){
          #split train/validation indices; SuperLearner's default cvControl is used for all folds
          if (private$k_split==1){
            train_index <- validation_index <- as.numeric(unlist(private$cv$fold_index))
          } else if (private$k_split>=2){
            train_index <- as.numeric(unlist(private$cv$fold_index[-i]))
            validation_index <- as.numeric(unlist(private$cv$fold_index[i]))
          }

          cv_param <- list()
          #when outcome is missing, subset the complete case for Q estimation
          if (private$Y.missing){
            if (private$k_split==1){
              train_index.Q <- validation_index.Q <-as.numeric(unlist(private$cv$fold_index.Q))
            } else if (private$k_split>=2) {
              train_index.Q <- as.numeric(unlist(private$cv$fold_index.Q[-i]))
              validation_index.Q <- as.numeric(unlist(private$cv$fold_index.Q[i]))
            }
            cv_param.Q <- list()
          } else{
            train_index.Q = train_index
            validation_index.Q = validation_index
            cv_param.Q <- cv_param
          }

          #Q model(outcome model: g-comp)
          #fit with train set
          #A==0
          train_index.Q0 <- intersect(train_index.Q, which(private$A==0))
          Q0.fit <- self$sl.fit(Y = private$Y[train_index.Q0],
                                X = private$Q.set[train_index.Q0,],
                                SL.library = self$libs$Q.SL.library,
                                CV= cv_param.Q)
          #A==1
          train_index.Q1 <- intersect(train_index.Q, which(private$A==1))
          Q1.fit <- self$sl.fit(Y = private$Y[train_index.Q1],
                                X = private$Q.set[train_index.Q1,],
                                SL.library = self$libs$Q.SL.library,
                                CV= cv_param.Q)
          # predict on validation set
          mu0 <- self$sl.predict(Q0.fit,newdata=private$Q.set[validation_index.Q,]) #Q0_pred
          mu1 <- self$sl.predict(Q1.fit,newdata=private$Q.set[validation_index.Q,]) #Q1_pred


          #g model(exposure model: propensity score)
          #g exposure set
          train_set.g <- data.frame(private$g.set[train_index,])
          validation_set.g <- data.frame(private$g.set[validation_index,])
          colnames(train_set.g)=colnames(validation_set.g)=colnames(private$g.set) #keep g data.frame colnames consistent
          # fit with train set
          g.fit <- self$sl.fit(Y=private$AxObserved[train_index],
                               X=train_set.g,
                               SL.library = self$libs$g.SL.library,
                               CV= cv_param)
          # predict on validation set
          raw_p_score  <- self$sl.predict(g.fit,newdata = validation_set.g)  #g_pred

          #add metadata
          names(validation_index) <- rep(i,length(validation_index))

          if (private$isLoaded_progressr){
            pb(sprintf("No.%g iteration", i,private$k_split))
          }

          if (private$save.sl.fit){
            Q.fit <- list(Q0=Q0.fit, Q1= Q1.fit)
            output <- list(validation_index, validation_index.Q, Q.fit, mu0, mu1, g.fit, raw_p_score)
            names(output) <- c("validation_index","validation_index.Q","Q.fit","mu0","mu1","g.fit","raw_p_score")
          } else {
            output <- list(validation_index, validation_index.Q, mu0, mu1, raw_p_score)
            names(output) <- c("validation_index","validation_index.Q","mu0","mu1","raw_p_score")
          }

          return(output)
        })

      #store fitted values from future to member variables
      for (i in fitted){
        #add estimates based on the val index
        self$obs_est$mu0[i$validation_index.Q] <- i$mu0
        self$obs_est$mu1[i$validation_index.Q] <- i$mu1
        self$obs_est$raw_p_score[i$validation_index] <- i$raw_p_score
        #append fitted objects
        if (private$save.sl.fit) {
          self$libs$Q.fit = append(self$libs$Q.fit, list(i$Q.fit))
          self$libs$g.fit = append(self$libs$g.fit, list(i$g.fit))
        }
        self$libs$validation_index = append(self$libs$validation_index, i$validation_index)
        self$libs$validation_index.Q = append(self$libs$validation_index.Q, i$validation_index.Q)
      }
      self$obs_est$mu[private$observed==1]  <- self$obs_est$mu0[private$observed==1]*(1-private$A[private$observed==1]) +
        self$obs_est$mu1[private$observed==1]*(private$A[private$observed==1])#Q_pred

      if (private$verbose){
        message("Done!\n")
      }

      invisible(self)
    }
  ),

  #-------------------------private fields and methods----------------------------#
  private = list(
    #input
    Q.set=NULL,
    g.set=NULL,
    k_split=NULL,
    save.sl.fit=FALSE,
    cv = list(
      #a vector stores the groups for splitting
      k_index= NULL,
      #a list of indices for each fold
      fold_index= NULL,
      fold_index.Q = NULL,
      #a vector of length(fold_index[[i]])
      fold_length = NULL,
      fold_length.Q = NULL
    ),
    fitted=NULL,
    sl.pkg =NULL,
    sl.env=NULL,
    sl.learners = NULL,
    isLoaded_progressr = FALSE,
    #private methods
    #lapply or future_lapply
    .f_lapply =NULL,
    #remap validation rows for SuperLearner's internal CV after the current
    #cross-fitting fold is removed: training rows are renumbered 1..n_train,
    #grouped by their original folds
    .new_cv_index = function(val_fold,fold_length=private$cv$fold_length, k_split=private$k_split){
      train_fold_length = c(0,fold_length[-val_fold])
      train_fold_cumsum = cumsum(train_fold_length)
      new_train_index= lapply(1:(k_split-1),
                              function(x) {
                                (1:train_fold_length[[x+1]])+ train_fold_cumsum[[x]]
                              }
      )
      names(new_train_index) = names(train_fold_length[-1])
      return(new_train_index)
    }
  )
)



#' @name fit
#' @aliases fit.AIPW
#' @title Fit the data to the [AIPW] object
#'
#' @description
#' Fit the data in the [AIPW] object, with or without cross-fitting, to estimate the efficient influence functions
#'
#' @section R6 Usage:
#' \code{$fit()}
#'
#' @return A fitted [AIPW] object with `obs_est` and `libs` (public variables)
#'
#' @seealso [AIPW]
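#'
#' @examples
#' #a minimal sketch (assuming the SuperLearner package is loaded):
#' library(SuperLearner)
#' aipw <- AIPW$new(Y=rbinom(50,1,0.5), A=rbinom(50,1,0.5),
#'                  W=rbinom(50,1,0.5),
#'                  Q.SL.library="SL.mean", g.SL.library="SL.mean",
#'                  k_split=2, verbose=FALSE)
#' aipw$fit()
#' aipw$summary(g.bound = 0.025)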
NULL

#' @name stratified_fit
#' @aliases stratified_fit.AIPW
#' @title Fit the data to the [AIPW] object stratified by `A` for the outcome model
#'
#' @description
#' Fit the data in the [AIPW] object, with or without cross-fitting, to estimate the efficient influence functions.
#' The outcome model is fitted stratified by exposure status `A`.
#'
#' @section R6 Usage:
#' \code{$stratified_fit()}
#'
#' @return A fitted [AIPW] object with `obs_est` and `libs` (public variables)
#'
#' @seealso [AIPW]
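#'
#' @examples
#' #a minimal sketch: a stratified fit enables ATT/ATC output from summary()
#' library(SuperLearner)
#' aipw <- AIPW$new(Y=rbinom(50,1,0.5), A=rbinom(50,1,0.5),
#'                  W=rbinom(50,1,0.5),
#'                  Q.SL.library="SL.mean", g.SL.library="SL.mean",
#'                  k_split=2, verbose=FALSE)
#' aipw$stratified_fit()
#' aipw$summary(g.bound = 0.025)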
NULL
