R/ce_estimate.R

#' Causal inference with multiple treatments using observational data
#'
#' The function \code{ce_estimate} implements six families of methods
#' (RA, VM, BART, TMLE, IPTW and RAMS) for causal inference with
#' multiple treatments using observational data.
#'
#' @param y A numeric vector of 0s and 1s representing a binary outcome.
#' @param x A data frame containing all the covariates but not the treatment.
#' @param w A numeric vector representing the treatment groups.
#' @param method A character string. Users can select from the
#' following methods: \code{"RA"}, \code{"VM"}, \code{"BART"},
#' \code{"TMLE"}, \code{"IPTW-Multinomial"}, \code{"IPTW-GBM"},
#' \code{"IPTW-SL"}, \code{"RAMS-Multinomial"}, \code{"RAMS-GBM"} and
#' \code{"RAMS-SL"}. \code{"VM"} supports only \code{estimand = "ATT"},
#' and \code{"TMLE"} supports only \code{estimand = "ATE"}.
#' @param formula A \code{\link[stats]{formula}} object representing
#' the variables used for the analysis.
#' The default is to use all terms specified in \code{x}.
#' @param discard A logical indicating whether to use the discarding rules
#' for the BART based methods. The default is \code{FALSE}.
#' @param estimand A character string representing the type of causal estimand.
#' Only \code{"ATT"} or \code{"ATE"} is allowed.
#' When \code{estimand = "ATT"}, users also need to specify the reference
#' treatment group with the \code{reference_trt} argument.
#' @param trim_perc A numeric vector of length 2 giving the lower and upper
#' percentiles at which the inverse probability of treatment weights should
#' be trimmed. The default is \code{NULL}.
#' @param sl_library A character vector of prediction algorithms.
#' The prediction algorithms available in the SuperLearner package
#' can be listed with \code{\link[SuperLearner:listWrappers]{listWrappers}}.
#' @param reference_trt A numeric value indicating the reference
#' treatment group for the ATT.
#' @param boot A logical indicating whether or not to use nonparametric
#' bootstrap to calculate the 95\% confidence intervals of the causal
#' effect estimates. The default is \code{FALSE}.
#' @param verbose_boot A logical value indicating whether to
#' print the progress of nonparametric bootstrap.
#' The default is \code{TRUE}.
#' @param nboots A numeric value representing the number of bootstrap samples.
#' Required when \code{boot = TRUE}.
#' @param ndpost A numeric value indicating the number of posterior draws
#' for the Bayesian methods (\code{"BART"} and \code{"RA"}).
#' @param caliper A numeric value denoting the caliper to be used
#' when matching (\code{method = "VM"}) on the logit of the generalized
#' propensity score (GPS) within each cluster formed by K-means clustering.
#' The caliper is in standardized units. For example, \code{caliper = 0.25}
#' means that all matches farther apart than 0.25 standard deviations of the
#' logit of GPS are dropped. The default value is 0.25.
#' @param n_cluster A numeric value denoting the number of clusters to form
#' using K-means clustering on the logit of GPS when \code{method = "VM"}.
#' The default value is 5.
#' @param ... Other parameters that can be passed through to functions.
#'
#' @return A summary of the effect estimates can be obtained
#' with the \code{summary} function. For VM, the output contains the number
#' of matched individuals. For BART with \code{discard = TRUE},
#' the output contains the number of discarded individuals. For the
#' IPTW-related methods with \code{boot = FALSE}, the weight distributions
#' can be visualized using the \code{plot} function. For BART and RA, the
#' output contains a list of the posterior samples of the causal estimands.
#' @import SuperLearner
#' @importFrom stringr str_sub
#' @references
#'
#' Hu, L., Gu, C., Lopez, M., Ji, J., & Wisnivesky, J. (2020).
#' Estimation of causal effects of multiple treatments in
#' observational studies with a binary outcome.
#' \emph{Statistical Methods in Medical Research}, \strong{29}(11), 3218–3234.
#'
#' Hu, L., & Gu, C. (2021).
#' Estimation of causal effects of multiple treatments in healthcare
#' database studies with rare outcomes.
#' \emph{Health Services and Outcomes Research Methodology}, \strong{21}, 287–308.
#'
#' Sparapani, R., Spanbauer, C., & McCulloch, R. (2021).
#' Nonparametric Machine Learning and Efficient Computation
#' with Bayesian Additive Regression Trees: The BART R Package.
#' \emph{Journal of Statistical Software}, \strong{97}(1), 1-66.
#'
#' Hadley Wickham, Romain François, Lionel Henry and Kirill Müller (2021).
#' \emph{dplyr: A Grammar of Data Manipulation}.
#' R package version 1.0.7.
#' URL: \url{https://CRAN.R-project.org/package=dplyr}
#'
#' Venables, W. N. & Ripley, B. D. (2002)
#' \emph{Modern Applied Statistics with S}.
#' Fourth Edition. Springer, New York. ISBN 0-387-95457-0
#'
#' Matthew Cefalu, Greg Ridgeway, Dan McCaffrey, Andrew Morral,
#' Beth Ann Griffin and Lane Burgette (2021).
#' \emph{twang: Toolkit for Weighting and Analysis of Nonequivalent Groups}.
#' R package version 2.5.
#' URL: \url{https://CRAN.R-project.org/package=twang}
#'
#' Noah Greifer (2021).
#' \emph{WeightIt: Weighting for Covariate Balance in Observational Studies}.
#' R package version 0.12.0.
#' URL: \url{https://CRAN.R-project.org/package=WeightIt}
#'
#' Hadley Wickham (2019).
#' \emph{stringr: Simple, Consistent Wrappers for Common String Operations}.
#' R package version 1.4.0.
#' URL: \url{https://CRAN.R-project.org/package=stringr}
#'
#' Andrew Gelman and Yu-Sung Su (2020).
#' \emph{arm: Data Analysis Using Regression and
#' Multilevel/Hierarchical Models}.
#' R package version 1.11-2.
#' URL: \url{https://CRAN.R-project.org/package=arm}
#'
#' Wood, S.N. (2011)
#' Fast stable restricted maximum likelihood and marginal likelihood
#' estimation of semiparametric generalized linear models.
#' \emph{Journal of the Royal Statistical Society (B)} \strong{73}(1):3-36
#'
#' Eric Polley, Erin LeDell, Chris Kennedy and Mark van der Laan (2021).
#' \emph{SuperLearner: Super Learner Prediction}.
#' R package version 2.0-28.
#' URL: \url{https://CRAN.R-project.org/package=SuperLearner}
#'
#' Susan Gruber, Mark J. van der Laan (2012).
#' tmle: An R Package for Targeted Maximum Likelihood Estimation.
#' \emph{Journal of Statistical Software},
#' \strong{51}(13), 1-35.
#'
#' Jasjeet S. Sekhon (2011).
#' Multivariate and Propensity Score Matching Software with
#' Automated Balance Optimization: The Matching Package for R.
#' \emph{Journal of Statistical Software}, \strong{42}(7), 1-52.
#'
#' Hadley Wickham (2016).
#' \emph{ggplot2: Elegant Graphics for Data Analysis}.
#' Springer-Verlag New York.
#'
#' Claus O. Wilke (2020).
#' \emph{cowplot: Streamlined Plot Theme and Plot Annotations for 'ggplot2'}.
#' R package version 1.1.1.
#' URL: \url{https://CRAN.R-project.org/package=cowplot}
#'
#' Elio Campitelli (2021).
#' \emph{metR: Tools for Easier Analysis of Meteorological Fields}.
#' R package version 0.11.0.
#' URL:\url{https://github.com/eliocamp/metR}
#'
#' Hadley Wickham (2021).
#' \emph{tidyr: Tidy Messy Data}. R package version 1.1.4.
#' URL: \url{https://CRAN.R-project.org/package=tidyr}
#'
#' Microsoft Corporation and Steve Weston (2020).
#' \emph{doParallel: Foreach Parallel Adaptor for the 'parallel' Package}.
#' R package version 1.0.16.
#' URL: \url{https://CRAN.R-project.org/package=doParallel}
#'
#' Microsoft and Steve Weston (2020).
#' \emph{foreach: Provides Foreach Looping Construct}.
#' R package version 1.5.1.
#' URL: \url{https://CRAN.R-project.org/package=foreach}
#' @export
#'
#' @examples
#' lp_w_all <-
#'   c(
#'     ".4*x1 + .1*x2  - .1*x4 + .1*x5", # w = 1
#'     ".2 * x1 + .2 * x2  - .2 * x4 - .3 * x5" # w = 2
#'   )
#' nlp_w_all <-
#'   c(
#'     "-.5*x1*x4  - .1*x2*x5", # w = 1
#'     "-.3*x1*x4 + .2*x2*x5" # w = 2
#'   )
#' lp_y_all <- rep(".2*x1 + .3*x2 - .1*x3 - .1*x4 - .2*x5", 3)
#' nlp_y_all <- rep(".7*x1*x1  - .1*x2*x3", 3)
#' X_all <- c(
#'   "rnorm(0, 0.5)", # x1
#'   "rbeta(2, .4)", # x2
#'   "runif(0, 0.5)", # x3
#'   "rweibull(1,2)", # x4
#'   "rbinom(1, .4)" # x5
#' )
#'
#' set.seed(111111)
#' data <- data_sim(
#'   sample_size = 300,
#'   n_trt = 3,
#'   x = X_all,
#'   lp_y = lp_y_all,
#'   nlp_y = nlp_y_all,
#'   align = FALSE,
#'   lp_w = lp_w_all,
#'   nlp_w = nlp_w_all,
#'   tau = c(-1.5, 0, 1.5),
#'   delta = c(0.5, 0.5),
#'   psi = 1
#' )
#' ce_estimate(
#'   y = data$y, x = data$covariates, w = data$w,
#'   ndpost = 100, method = "RA", estimand = "ATE"
#' )
ce_estimate <-
  function(y,
           x,
           w,
           method,
           formula = NULL,
           discard = FALSE,
           estimand,
           trim_perc = NULL,
           sl_library,
           reference_trt,
           boot = FALSE,
           nboots,
           verbose_boot = TRUE,
           ndpost = 1000,
           caliper = 0.25,
           n_cluster = 5,
           ...) {
    if (!(estimand %in% c("ATE", "ATT")))
      stop("Estimand only supported for \"ATT\" or \"ATE\"", call. = FALSE)
    supported_methods <- c(
      "RA", "VM", "BART", "TMLE",
      "IPTW-Multinomial", "IPTW-GBM", "IPTW-SL",
      "RAMS-Multinomial", "RAMS-GBM", "RAMS-SL"
    )
    if (!(method %in% supported_methods))
      stop(
        "Currently method is only supported for ",
        paste0("\"", supported_methods, "\"", collapse = ", "),
        ". Please double check the entered method argument.",
        call. = FALSE
      )
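    # For the ATT, reference_trt must be supplied and must be an observed group
    if (estimand == "ATT" &&
        (missing(reference_trt) || !(reference_trt %in% unique(w))))
      stop(
        paste0(
          "Please set the reference_trt from ",
          paste0(sort(unique(w)), collapse = ", "),
          "."
        ),
        call. = FALSE
      )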
    if (length(w) != length(y) || length(w) != nrow(x))
      stop(
        "The length of y, the length of w and the number of rows of x ",
        "should be equal. Please double check the input.",
        call. = FALSE
      )
    if (!is.null(formula)) {
      x <-
        as.data.frame(stats::model.matrix(object = formula, cbind(y, x)))
      # drop = FALSE keeps a data frame even when a single covariate remains
      x <- x[, names(x) != "(Intercept)", drop = FALSE]
    }
    if (method == "RA" && estimand == "ATE") {
      result <- ce_estimate_ra_ate(
        y = y,
        x = x,
        w = w,
        ndpost = ndpost
      )
    } else if (method == "RA" && estimand == "ATT") {
      result <- ce_estimate_ra_att(
        y = y,
        x = x,
        w = w,
        ndpost = ndpost,
        reference_trt = reference_trt
      )
    } else if (method == "VM" && estimand == "ATT" && boot == FALSE) {
      result <- ce_estimate_vm_att(
        y = y,
        x = x,
        w = w,
        reference_trt = reference_trt,
        caliper = caliper,
        n_cluster = n_cluster
      )
    } else if (method == "VM" && estimand == "ATT" && boot == TRUE) {
      stop(
        "Bootstrap confidence intervals are not appropriate for VM ",
        "based on Lopez and Gutman (2017). The current version of CIMTx ",
        "does not support the standard error calculation for VM.",
        call. = FALSE
      )
    } else if (method == "BART" && estimand == "ATE") {
      result <- ce_estimate_bart_ate(
        y = y,
        x = x,
        w = w,
        ndpost = ndpost,
        discard = discard,
        ...
      )
    } else if (method == "BART" && estimand == "ATT") {
      result <- ce_estimate_bart_att(
        y = y,
        x = x,
        w = w,
        ndpost = ndpost,
        reference_trt = reference_trt,
        discard = discard,
        ...
      )
    } else if (method == "TMLE" &&
               estimand == "ATE" && boot == FALSE) {
      result <- ce_estimate_tmle_ate(
        y = y,
        x = x,
        w = w,
        sl_library = sl_library,
        ...
      )
    } else if (method == "TMLE" &&
               estimand == "ATE" && boot == TRUE) {
      result <- ce_estimate_tmle_ate_boot(
        y = y,
        x = x,
        w = w,
        sl_library = sl_library,
        nboots = nboots,
        verbose_boot = verbose_boot,
        ...
      )
    } else if (method %in% c("RAMS-Multinomial", "RAMS-GBM", "RAMS-SL") &&
               estimand == "ATE" && boot == FALSE) {
      result <- ce_estimate_rams_ate(
        y = y,
        x = x,
        w = w,
        method = method,
        verbose_boot = verbose_boot,
        ...
      )
    } else if (method %in% c("RAMS-Multinomial", "RAMS-GBM", "RAMS-SL") &&
               estimand == "ATE" && boot == TRUE) {
      result <- ce_estimate_rams_ate_boot(
        y = y,
        x = x,
        w = w,
        method = method,
        nboots = nboots,
        verbose_boot = verbose_boot,
        ...
      )
    } else if (method %in% c("RAMS-Multinomial", "RAMS-GBM", "RAMS-SL") &&
               estimand == "ATT" && boot == FALSE) {
      result <- ce_estimate_rams_att(
        y = y,
        x = x,
        w = w,
        method = method,
        reference_trt = reference_trt,
        ...
      )
    } else if (method %in% c("RAMS-Multinomial", "RAMS-GBM", "RAMS-SL") &&
               estimand == "ATT" && boot == TRUE) {
      result <- ce_estimate_rams_att_boot(
        y = y,
        x = x,
        w = w,
        method = method,
        nboots = nboots,
        reference_trt = reference_trt,
        verbose_boot = verbose_boot,
        ...
      )
    } else if (method %in% c("IPTW-Multinomial", "IPTW-GBM", "IPTW-SL") &&
               estimand == "ATE" && boot == FALSE) {
      result <- ce_estimate_iptw_ate(
        y = y,
        x = x,
        w = w,
        method = method,
        ...
      )
    } else if (method %in% c("IPTW-Multinomial", "IPTW-GBM", "IPTW-SL") &&
               estimand == "ATE" && boot == TRUE) {
      result <- ce_estimate_iptw_ate_boot(
        y = y,
        x = x,
        w = w,
        method = method,
        nboots = nboots,
        verbose_boot = verbose_boot,
        ...
      )
    } else if (method %in% c("IPTW-Multinomial", "IPTW-GBM", "IPTW-SL") &&
               estimand == "ATT" && boot == FALSE) {
      result <- ce_estimate_iptw_att(
        y = y,
        x = x,
        w = w,
        method = method,
        reference_trt = reference_trt,
        ...
      )
    } else if (method %in% c("IPTW-Multinomial", "IPTW-GBM", "IPTW-SL") &&
               estimand == "ATT" && boot == TRUE) {
      result <- ce_estimate_iptw_att_boot(
        y = y,
        x = x,
        w = w,
        method = method,
        reference_trt = reference_trt,
        nboots = nboots,
        verbose_boot = verbose_boot,
        ...
      )
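    } else {
      # e.g. VM with estimand = "ATE" or TMLE with estimand = "ATT"
      stop(
        "method = \"", method, "\" is not supported for estimand = \"",
        estimand, "\".",
        call. = FALSE
      )
    }
    return(result)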
  }
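
When `formula` is supplied, `ce_estimate` replaces `x` with the design matrix returned by `stats::model.matrix()`, minus the intercept column. The base-R sketch below reproduces that step so you can check which covariate columns a formula produces before calling the function. The covariates `x1` and `x2` and the helper `expand_covariates` are hypothetical illustrations, not part of CIMTx.

```r
# Hypothetical covariates and binary outcome (illustrative only)
x <- data.frame(x1 = c(0.2, 1.5, -0.3, 0.8), x2 = c(1, 0, 1, 1))
y <- c(0, 1, 1, 0)

# Hypothetical helper mirroring the formula branch of ce_estimate():
# expand the formula into a design matrix, then drop the intercept column.
expand_covariates <- function(formula, y, x) {
  design <- as.data.frame(stats::model.matrix(object = formula, cbind(y, x)))
  # drop = FALSE keeps a data frame even if only one column is left
  design[, names(design) != "(Intercept)", drop = FALSE]
}

x_expanded <- expand_covariates(y ~ x1 + x2 + x1:x2 + I(x1^2), y, x)
print(names(x_expanded))

# A single-covariate formula still yields a data frame
x_single <- expand_covariates(y ~ x1, y, x)
print(class(x_single))
```

Keep in mind that `model.matrix()` drops rows with missing covariates by default (through the `na.action` option), which would leave `x` shorter than `y` and `w`.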

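The IPTW methods weight each individual by the inverse of their generalized propensity score (GPS) for the treatment actually received. The sketch below shows a textbook normalized (Hajek) IPTW estimator of each treatment's outcome probability and the pairwise risk differences. It is not CIMTx's implementation: here the true GPS from a simulated multinomial logit is plugged in, whereas the `"IPTW-Multinomial"`, `"IPTW-GBM"` and `"IPTW-SL"` methods estimate it. The helper `iptw_ate_sketch` is hypothetical and assumes treatments are coded `1..K`, with `gps[, k]` equal to P(W = k | X).

```r
# Hypothetical helper: normalized (Hajek) IPTW estimate of E[Y(k)] for each
# treatment k, plus all pairwise risk differences.
# Assumes treatments are coded 1..K and gps[, k] = P(W = k | X).
iptw_ate_sketch <- function(y, w, gps) {
  trts <- sort(unique(w))
  mu <- vapply(trts, function(k) {
    wt <- (w == k) / gps[, k]  # inverse probability weights, zero outside group k
    sum(wt * y) / sum(wt)      # normalized weighted mean of the outcome
  }, numeric(1))
  names(mu) <- paste0("w", trts)
  pairs <- utils::combn(trts, 2)
  rd <- mu[match(pairs[1, ], trts)] - mu[match(pairs[2, ], trts)]
  names(rd) <- paste0("RD", pairs[1, ], "_", pairs[2, ])
  list(mu = mu, rd = rd)
}

# Simulated data with a known GPS from a multinomial logit (treatment 1 = baseline)
set.seed(2022)
n <- 5000
x1 <- rnorm(n)
eta <- cbind(0, 0.5 * x1, -0.5 * x1)
gps <- exp(eta) / rowSums(exp(eta))
w <- vapply(seq_len(n), function(i) sample(3, 1, prob = gps[i, ]), integer(1))
y <- rbinom(n, 1, plogis(-0.5 + 0.5 * x1 + 0.4 * (w == 2) + 0.8 * (w == 3)))

est <- iptw_ate_sketch(y, w, gps)
print(est$rd)
```

With a constant GPS the weights are equal within each group, so the estimator reduces to the raw group means; that is a quick sanity check for any IPTW code.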

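The documentation of `trim_perc` does not spell out how trimming is applied. One common reading is truncation: weights below the lower percentile are raised to it and weights above the upper percentile are capped at it. The hypothetical helper `trim_weights` below sketches that reading, assuming `trim_perc` holds probabilities such as `c(0.05, 0.95)`; it is an assumption, not a description of CIMTx internals.

```r
# Hypothetical helper: cap weights at the lower and upper percentiles in
# trim_perc (assumed to be probabilities on the 0-1 scale).
trim_weights <- function(weights, trim_perc) {
  stopifnot(length(trim_perc) == 2, trim_perc[1] < trim_perc[2])
  # Default quantile definition (type 7) from the stats package
  bounds <- stats::quantile(weights, probs = trim_perc, names = FALSE)
  pmin(pmax(weights, bounds[1]), bounds[2])
}

wts <- c(0.5, 1, 1.2, 2, 50)
trimmed <- trim_weights(wts, c(0.05, 0.95))
print(trimmed)  # returns c(0.6, 1, 1.2, 2, 40.4)
```

Truncation keeps every individual in the analysis while limiting the influence of extreme weights such as the 50 above.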

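The `caliper` argument is in standardized units: a match is dropped when the two logit-GPS values differ by more than `caliper` standard deviations. The hypothetical helper `caliper_match` below illustrates only that rule with one-to-one nearest-neighbor matching (with replacement) on a single logit-GPS value. The full VM procedure also matches within K-means clusters and on several GPS components, and using the pooled standard deviation of both groups here is an assumption of this sketch.

```r
# Hypothetical helper illustrating a caliper on the logit GPS.
# Matches farther apart than caliper * SD are flagged as dropped.
# The pooled SD over both groups is an assumption of this sketch.
caliper_match <- function(logit_ref, logit_other, caliper = 0.25) {
  threshold <- caliper * stats::sd(c(logit_ref, logit_other))
  # Nearest comparison unit for each reference unit (matching with replacement)
  nearest <- vapply(logit_ref, function(v) which.min(abs(logit_other - v)),
                    integer(1))
  dist <- abs(logit_other[nearest] - logit_ref)
  data.frame(ref = seq_along(logit_ref), match = nearest,
             dist = dist, kept = dist <= threshold)
}

# Illustrative logit-GPS values for reference and comparison units
res <- caliper_match(logit_ref = c(0, 1, 5), logit_other = c(0.05, 1.1, 2))
print(res)
```

Here the pooled SD is about 1.86, so the threshold is about 0.46: the first two matches are kept and the third (distance 3) is dropped.
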
CIMTx documentation built on June 24, 2022, 9:07 a.m.