# R/TEstimator.R

#' @import ggplot2
#' @import ggpubr
#' @import dplyr
#' @import tidyr
TEstimator <- R6::R6Class(
  "TEstimator",
  #-------------------------public fields-----------------------------#
  public = list(
    # Identifier of the estimator; this class never assigns it.
    id = NA,
    # Name of the estimator, assigned in initialize().
    name = character(),
    # n: number of rows of the data; density_confounders: counts per joint
    # level of predictors, treatment and outcome. Both filled in initialize().
    statistics = list(n = numeric(),
                      density_confounders = data.frame()),

    # The data frame passed to initialize(), with an added `id` column.
    data = NULL,

    # ATE: one-row data frame written by private$set_ATE();
    # CATE: subgroup estimates written by private$set_CATE().
    estimates = list(ATE = data.frame(y1.hat = NA,
                                      y0.hat = NA,
                                      est = NA,
                                      se = NA),
                     CATE = data.frame()),
    # Not written by this class (presumably populated by subclasses).
    model = list(),

    #-------------------------constructor-----------------------------#
    #' @description Create a new \code{TEstimator} object.
    #' @param df A data frame containing the variables named in \code{vars_name}.
    #' @param vars_name A list with elements \code{outcome_predictors} (character vector of adjustment/stratification variables), \code{treatment_name} (name of the treatment variable) and \code{outcome_name} (name of the outcome variable).
    #' @param name A name for the estimator, stored in the \code{name} field.
    initialize = function(df, vars_name, name) {
      self$name <- name
      self$data <- df
      # seq_len() keeps a zero-row df valid; seq(0) would give c(1, 0).
      self$data$id <- seq_len(nrow(df))
      private$outcome_predictors <- vars_name$outcome_predictors
      private$treatment_name <- vars_name$treatment_name
      private$outcome_name <- vars_name$outcome_name
      self$statistics <- list(n = nrow(df),
                              density_confounders = private$est_joint_denstiy())
    },

    #' @description Replicating the average treatment effect of \code{target.obj}. If \code{stratification} is specified, then replicating the conditional average treatment effect stratified by \code{stratification} and \code{stratification_joint} by weighting based on the residual variables, namely, variables that are specified in \code{outcome_predictors} while not in \code{stratification}.
    #' @param target.obj An object of class \code{Estimator} or list.
    #' @param weighting_estimator A string specifying a weighting estimator for generalizing/transporting the estimates to \code{target.obj}. The allowed estimators are: \code{"balancing"}, and \code{"modeling"}.
    #' @param weighting_method A string specifying which model for selection to use. Possible values are found using \code{names(getModelInfo())}. See \url{http://topepo.github.io/caret/train-models-by-tag.html}.
    #' @param stratification An optional string vector containing variables to define subgroup. If \code{!is.NULL(stratification)}, \code{source.obj} will compute both weighted and unweighted conditional average treatment effect based on these variables, \code{target.obj} will calculate the conditional average treatment effect based on these variables.
    #' @param stratification_joint An optional logical defining the subgroup based on joint distribution of variables or univariate distribution in \code{stratification} when \code{stratification} is specified.
    # EstimateRep = function(target.obj,
    #                        weighting_estimator="Balancing", weighting_method="glm",
    #                        stratification=NULL, stratification_joint=FALSE) {
    #   private$weighting_estimator <- weighting_estimator
    #   private$weighting_method <- weighting_method
    #
    #   #browser()
    #   private$set_weighted_ATE_SE(target.obj$data)
    #   if (!is.null(stratification)) {
    #     if(isTRUE(all.equal(stratification, private$confounders_external_name))&stratification_joint){
    #       self$estimates$CATE <- self$get_CATE(stratification = stratification,
    #                                            stratification_joint = stratification_joint)
    #       self$estimates$CATE_weighted <- self$estimates$CATE
    #     } else {
    #       self$estimates$CATE <- self$get_CATE(stratification = stratification,
    #                                            stratification_joint = stratification_joint)
    #       private$set_weighted_CATE_SE(
    #         target.data = target.obj$data,
    #         stratification = stratification,
    #         stratification_joint = stratification_joint
    #       )
    #     }
    #   }
    # },

    #' @description Get the conditional average treatment effect of subgroups defined by \code{stratification} and \code{stratification_joint}. If \code{stratification_joint=FALSE}, the subgroups are the levels of each variable in \code{stratification} taken one variable at a time. Objects inheriting from \code{TEstimator_Synthetic} return the stored \code{estimates$CATE} unchanged.
    #' @param stratification A string vector containing variables to define subgroups.
    #' @param stratification_joint A logical; \code{TRUE} defines subgroups by the joint levels of \code{stratification}, \code{FALSE} by each variable separately.
    #' @return A data frame with one row per subgroup. With \code{stratification_joint=TRUE} it holds one column per variable in \code{stratification} followed by \code{y1.hat}, \code{y0.hat}, \code{cate}, \code{se}, \code{size}, \code{pt} and \code{py}. With \code{stratification_joint=FALSE} it holds \code{name} (stratifying variable), \code{value} (its level) and the same seven columns. \code{size} is the subgroup sample size; \code{pt} and \code{py} are the subgroup means of the treatment and outcome columns.
    get_CATE = function(stratification, stratification_joint = TRUE) {
      if (inherits(self, "TEstimator_Synthetic")) {
        return(self$estimates$CATE)
      }

      if (stratification_joint) {
        private$est_CATEestimation4JointStratification(stratification)
      } else {
        private$est_CATEestimation4SeperateStratification(stratification)
      }
    },


    #' @description Forest plot of the conditional average treatment effect of subgroups defined by \code{stratification} and \code{stratification_joint}. Calls \code{get_CATE(stratification, stratification_joint)} and plots the result.
    #' @param stratification A string vector containing variables to define subgroups.
    #' @param stratification_joint A logical; \code{TRUE} defines subgroups by the joint levels of \code{stratification}, \code{FALSE} by each variable separately.
    #' @return A \code{ggplot} object.
    plot_CATE = function(stratification = private$outcome_predictors,
                         stratification_joint = TRUE) {
      data.cate <- self$get_CATE(stratification, stratification_joint)

      # Every column that is not an estimate column identifies the subgroup.
      estimate_cols <- c("y1.hat", "y0.hat", "cate", "se", "size", "pt", "py")
      all_cols <- colnames(data.cate)
      var_names <- all_cols[!all_cols %in% estimate_cols]

      # Ordered factor so the y axis follows the row order of the CATE table.
      group_labels <- private$build_group_label(data.cate, var_names)
      group_labels <- factor(group_labels,
                             levels = unique(group_labels),
                             ordered = TRUE)

      # NOTE(review): 1.98 is kept from the original code; the usual normal
      # 95% multiplier is 1.96 - confirm which one is intended.
      ci_multiplier <- 1.98
      df <- data.cate %>%
        select(cate, se, size) %>%
        mutate(group = group_labels,
               ci_l = cate - ci_multiplier * se,
               ci_u = cate + ci_multiplier * se)

      ggplot2::ggplot(data = df, aes(x = cate, y = group)) +
        geom_point(position = position_dodge(0.5), aes(size = size)) +
        geom_errorbar(aes(xmin = ci_l, xmax = ci_u),
                      width = .3, position = position_dodge(0.5)) +
        geom_vline(xintercept = 0, color = "black", linetype = "dashed", alpha = .5) +
        ggtitle("Estimate of conditional average treatment effect.") +
        theme(plot.title = element_text())
    },

    #' @description Bar plot of the \code{y1.hat} and \code{y0.hat} columns returned by \code{get_CATE(stratification, stratification_joint)} for every subgroup.
    #' @param stratification A string vector containing variables to define subgroups.
    #' @param stratification_joint A logical; \code{TRUE} defines subgroups by the joint levels of \code{stratification}, \code{FALSE} by each variable separately.
    #' @param seperate A logical; \code{TRUE} draws one facet per potential outcome, \code{FALSE} draws dodged bars filled by potential outcome.
    #' @return A \code{ggplot} object.
    plot_y1_y0 = function(stratification, stratification_joint = TRUE, seperate = FALSE) {
      cate_table <- self$get_CATE(stratification, stratification_joint)

      if (stratification_joint) {
        wide <- cate_table %>%
          select(all_of(stratification), y1.hat, y0.hat)
        wide$group_name <- private$build_group_label(wide, stratification)
      } else {
        wide <- cate_table %>%
          select(name, value, y1.hat, y0.hat)
        wide$group_name <- paste0(wide$name, "=", wide$value)
      }

      # One row per subgroup and potential outcome.
      df <- wide %>%
        tidyr::pivot_longer(cols = all_of(c("y1.hat", "y0.hat")),
                            names_to = "potential_outcome",
                            values_to = "est")

      if (seperate) {
        ggplot2::ggplot(data = df, aes(x = group_name, y = est)) +
          geom_bar(stat = "identity", position = position_dodge()) +
          facet_wrap(~potential_outcome) +
          coord_flip()
      } else {
        ggplot2::ggplot(data = df, aes(x = group_name, y = est,
                                       group = potential_outcome,
                                       fill = potential_outcome)) +
          geom_bar(stat = "identity", position = position_dodge()) +
          coord_flip()
      }
    },

    #' @description Plot, for every subgroup, the proportion and the count of observations per treatment level, side by side.
    #' @param stratification A string vector containing variables to define subgroups; defaults to the \code{outcome_predictors} given at construction.
    #' @param stratification_joint A logical; \code{TRUE} defines subgroups by the joint levels of \code{stratification}, anything else by each variable separately.
    #' @return The arranged plot returned by \code{ggpubr::ggarrange()}.
    diagnosis_t_overlap = function(stratification, stratification_joint = TRUE) {
      vars_name <- if (missing(stratification)) {
        private$outcome_predictors
      } else {
        stratification
      }
      treatment <- private$treatment_name

      if (isTRUE(stratification_joint)) {
        df <- self$data %>%
          select(all_of(c(vars_name, treatment)))
        df$group_name <- private$build_group_label(df, vars_name)

        # position = "fill": a percent stacked barplot
        p.prop <- ggplot(df, aes(x = group_name, fill = factor(.data[[treatment]]))) +
          geom_bar(position = "fill") +
          ylab("proportion") +
          labs(fill = "treatment") +
          coord_flip() +
          theme(legend.position = "none")

        p.count <- ggplot(df, aes(x = group_name, fill = factor(.data[[treatment]]))) +
          geom_bar(stat = "count") +
          labs(fill = "treatment") +
          coord_flip() +
          theme(axis.text.y = element_blank(),
                axis.ticks.y = element_blank(),
                axis.title.y = element_blank())
      } else {
        # Count observations per level and treatment, one variable at a time,
        # then stack the per-variable tables. Iterating over vars_name (not
        # stratification) keeps the default working when the argument is missing.
        per_variable <- lapply(vars_name, function(v) {
          labelled <- self$data
          labelled$group_name <- private$build_group_label(labelled, v)
          labelled %>%
            group_by(across(all_of(c("group_name", treatment)))) %>%
            summarise(count = n(), .groups = "drop")
        })
        data <- bind_rows(per_variable)

        # Counts are precomputed, so geom_bar() needs stat = "identity" and an
        # explicit y aesthetic instead of its default stat = "count".
        p.prop <- ggplot(data = data,
                         aes(x = group_name, y = count, fill = factor(.data[[treatment]]))) +
          geom_bar(stat = "identity", position = "fill") +
          ylab("proportion") +
          labs(fill = "treatment") +
          coord_flip() +
          theme(legend.position = "none")

        p.count <- ggplot(data = data,
                          aes(x = group_name, y = count, fill = factor(.data[[treatment]]))) +
          geom_bar(stat = "identity") +
          ylab("count") +
          labs(fill = "treatment") +
          coord_flip() +
          theme(axis.text.y = element_blank(),
                axis.ticks.y = element_blank(),
                axis.title.y = element_blank())
      }

      tgrob <- ggpubr::text_grob(c("Treatment overlap within subpopulations"))
      ggpubr::ggarrange(tgrob, NULL, p.prop, p.count, ncol = 2, nrow = 2, heights = c(1, 5))
    },

    #' @description Plot the outcome distribution within subgroups, split by treatment. When \code{test_binary()} accepts the outcome column, counts per outcome level are drawn (faceted by treatment) and the underlying count table is printed; otherwise box plots of the outcome are drawn per joint subgroup, and \code{stratification_joint} is not consulted.
    #' @param stratification A string vector containing variables to define subgroups; defaults to the \code{outcome_predictors} given at construction.
    #' @param stratification_joint A logical; \code{TRUE} defines subgroups by the joint levels of \code{stratification}, \code{FALSE} by each variable separately.
    #' @return A \code{ggplot} object.
    diagnosis_y_overlap = function(stratification, stratification_joint = TRUE) {
      if (missing(stratification)) {
        stratification <- private$outcome_predictors
      }
      treatment <- private$treatment_name
      outcome <- private$outcome_name

      # test_binary() is defined outside this file.
      if (test_binary(self$data[[outcome]])) {
        if (stratification_joint == TRUE) {
          df <- self$data %>%
            select(all_of(c(stratification, outcome, treatment)))
          df$group_name <- private$build_group_label(df, stratification)
        } else {
          # One copy of the treatment/outcome columns per stratifying variable,
          # each labelled "<variable>=<level>".
          df <- bind_rows(lapply(stratification, function(v) {
            piece <- self$data %>%
              select(all_of(c(treatment, outcome)))
            piece$group_name <- private$build_group_label(self$data, v)
            piece
          }))
        }

        p <- ggplot(data = df, aes(x = group_name, fill = factor(.data[[outcome]]))) +
          geom_bar(stat = "count") +
          labs(fill = "outcome") +
          ggtitle("Outcome overlap within subpopulations") +
          coord_flip() +
          facet_wrap(ggplot2::vars(.data[[treatment]])) +
          theme(axis.title.y = element_blank())

        # Print the counts behind the plot; "drop_last" is summarise()'s
        # default grouping, stated explicitly to avoid the informational message.
        print(df %>%
                group_by(across(all_of(c("group_name", treatment, outcome)))) %>%
                summarise(count = n(), .groups = "drop_last"))
        p
      } else {
        df <- self$data %>%
          select(all_of(c(stratification, outcome, treatment)))
        df$group_name <- private$build_group_label(df, stratification)

        ggplot(df, aes(x = group_name,
                       y = .data[[outcome]],
                       fill = factor(.data[[treatment]]))) +
          geom_boxplot() +
          ylab("outcome") +
          labs(fill = "treatment") +
          ggtitle("outcome distribution in groups") +
          coord_flip()
      }
    },

    #' @description Placeholder with an empty body; returns \code{NULL}.
    diagnosis_t_ignorability = function(){}

  ),
  #-------------------------private fields and methods----------------------------#
  private = list(

    # Variable names copied from `vars_name` in initialize().
    outcome_predictors = NA,
    treatment_name = NA,
    outcome_name = NA,
    # NOTE(review): "sandwitch" looks like a misspelling of "sandwich"; the
    # value is left untouched because code outside this file may compare it.
    var_method = "sandwitch",
    isTrial = FALSE,


    # Store the estimate for the full sample in self$estimates$ATE.
    set_ATE = function() {
      ATE_SE <- private$est_ATE_SE(self$data$id)
      self$estimates$ATE$y1.hat <- ATE_SE$y1.hat
      self$estimates$ATE$y0.hat <- ATE_SE$y0.hat
      self$estimates$ATE$est <- ATE_SE$est
      self$estimates$ATE$se <- ATE_SE$se
    },

    # Store the result of get_CATE() in self$estimates$CATE.
    set_CATE = function(stratification, stratification_joint) {
      self$estimates$CATE <- self$get_CATE(stratification, stratification_joint)
    },

    # Build one "var1=level1,var2=level2" label per row of `data`.
    # Works column-wise, so a single variable is handled like several and
    # values are not padded the way apply() on a mixed-type data frame pads them.
    build_group_label = function(data, vars) {
      unknown <- setdiff(vars, colnames(data))
      if (length(unknown) > 0) {
        stop("Unknown stratification variable(s): ",
             paste(unknown, collapse = ", "), call. = FALSE)
      }
      # paste() turns zero-length input into "", so handle empty data first.
      if (nrow(data) == 0) {
        return(character(0))
      }
      if (length(vars) == 0) {
        return(rep("", nrow(data)))
      }
      parts <- lapply(vars, function(v) {
        paste(v, as.character(data[[v]]), sep = "=")
      })
      do.call(paste, c(parts, sep = ","))
    },

    # Count rows per joint level of predictors, treatment and outcome.
    # (The misspelled name is kept because it is part of the class interface.)
    est_joint_denstiy = function() {
      count_vars <- c(private$outcome_predictors,
                      private$treatment_name,
                      private$outcome_name)
      self$data %>%
        group_by(across(all_of(count_vars))) %>%
        summarise(count = n(), .groups = "drop") %>%
        as.data.frame()
    },

    # Estimates for every joint level of `stratification`: one row per
    # subgroup with its key columns followed by y1.hat, y0.hat, cate, se,
    # size, pt (mean treatment) and py (mean outcome).
    est_CATEestimation4JointStratification = function(stratification) {
      group_data <- self$data %>%
        group_by(across(all_of(stratification)))
      group_strata <- group_keys(group_data)
      group_id <- group_indices(group_data)
      n_groups <- nrow(group_strata)
      size <- group_size(group_data)

      treatment_values <- self$data[[private$treatment_name]]
      outcome_values <- self$data[[private$outcome_name]]

      y1.hat <- y0.hat <- cate <- se <- pt <- py <- numeric(n_groups)
      for (i in seq_len(n_groups)) {
        in_group <- group_id == i
        cate_y1_y0_se <- private$est_ATE_SE(self$data[["id"]][in_group])
        y1.hat[i] <- cate_y1_y0_se$y1.hat
        y0.hat[i] <- cate_y1_y0_se$y0.hat
        cate[i] <- cate_y1_y0_se$est
        se[i] <- cate_y1_y0_se$se
        # as.character() first so factor columns are read by label, not code.
        pt[i] <- mean(as.numeric(as.character(treatment_values[in_group])))
        py[i] <- mean(as.numeric(as.character(outcome_values[in_group])))
      }
      as.data.frame(cbind(group_strata, y1.hat, y0.hat, cate, se, size, pt, py))
    },

    # Estimates for every level of every variable in `stratification`, one
    # variable at a time: columns name, value, y1.hat, y0.hat, cate, se, size,
    # pt and py.
    est_CATEestimation4SeperateStratification = function(stratification) {
      group_var <- cate <- se <- size <- y1.hat <- y0.hat <- pt <- py <- NULL
      group_level <- list()
      treatment_values <- self$data[[private$treatment_name]]
      outcome_values <- self$data[[private$outcome_name]]

      # The number of subgroups is only known per variable, so results are
      # appended at a running index.
      i <- 1
      for (var_name in stratification) {
        group_data <- self$data %>% group_by(across(all_of(var_name)))
        group_strata <- group_keys(group_data)
        group_id_4each_obs <- group_indices(group_data)
        group_sample_size <- group_size(group_data)
        for (group_id in seq_len(nrow(group_strata))) {
          in_group <- group_id_4each_obs == group_id
          group_var[i] <- var_name
          group_level[[i]] <- group_strata[[1]][group_id]
          cate_y1_y0_se <- private$est_ATE_SE(self$data[["id"]][in_group])
          y1.hat[i] <- cate_y1_y0_se$y1.hat
          y0.hat[i] <- cate_y1_y0_se$y0.hat
          cate[i] <- cate_y1_y0_se$est
          se[i] <- cate_y1_y0_se$se
          size[i] <- group_sample_size[group_id]
          pt[i] <- mean(as.numeric(as.character(treatment_values[in_group])))
          py[i] <- mean(as.numeric(as.character(outcome_values[in_group])))

          i <- i + 1
        }
      }
      # Levels of different variables can have different types, so they are
      # collected in a list and reshaped into a one-column matrix here.
      group_level <- t(as.data.frame(group_level))
      data.frame(
        name = group_var,
        value = group_level,
        y1.hat = y1.hat,
        y0.hat = y0.hat,
        cate = cate,
        se = se,
        size = size,
        pt = pt,
        py = py,
        stringsAsFactors = FALSE
      )
    },

    # Placeholders with empty bodies. This class calls est_ATE_SE() with a
    # vector of row ids and reads y1.hat, y0.hat, est and se from the result,
    # so working implementations are presumably supplied by subclasses.
    fit = function(){},

    est_ATE_SE = function(){},

    est_weighted_ATE_SE = function(){}

  )
)

# Try the RCTrep package in your browser

# Any scripts or data that you put into this service are public.

# RCTrep documentation built on Nov. 2, 2023, 6:10 p.m.