causalOT was developed to reproduce the methods in Optimal transport methods for causal inference. The functions in the package construct weights that make the covariate distributions of the treatment groups more similar and then estimate causal effects. We recommend the Causal Optimal Transport (COT) methods since they are semi- to non-parametric. This document describes some simple uses of the functions in the package and should be enough to get users started.
The main functions of the package, calc_weight and estimate_effect, take arguments x, a numeric matrix of covariates; z, a treatment indicator in c(0,1); and y, a numeric vector with the outcome data.
If the data are already in this form, we can supply them directly.
# packages
library(causalOT)
library(torch)

# reproducible seeds
set.seed(1111)
torch_manual_seed(3249)

# generate some data
hainmueller <- Hainmueller$new(n = 512)
hainmueller$gen_data()
x <- hainmueller$get_x()
z <- hainmueller$get_z()
y <- hainmueller$get_y()

# NOT RUN
# weights <- calc_weight(x = x, z = z, y = y)
Note that the calc_weight function will not use the outcome data y when calculating the weights and will not pass it to the internal data constructor. It must be passed to the estimate_effect function later.
However, sometimes we get data as a data.frame and users may not know how to turn it into the required objects. In this case, we've supplied the df2dataHolder function to create the required data object like so:
df <- data.frame(y = y, z = z, x)
df2dH <- df2dataHolder(treatment.formula = "z ~ .",
                       outcome.formula = "y ~ .",
                       data = df)

# NOT RUN
# weights <- calc_weight(x = df2dH, z = NULL, y = NULL)
In this case, we can pass the dataHolder object directly to the calc_weight function inside the x argument and ignore the others; the dataHolder object already contains the x, y, and z data internally. Note that this function needs both outcome and treatment formulae since it needs to know which columns are actually confounders for the purposes of calculating the weights!
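For comparison, the manual conversion that df2dataHolder saves us from can be sketched in base R. This is only an illustration on a made-up data.frame (the object and column names below are hypothetical), not the package's internal code; it assumes that model.matrix is an acceptable way to expand factor covariates into numeric columns.

```r
# Hypothetical toy data.frame with one factor covariate (not from the package)
toy <- data.frame(
  y   = c(1.2, 0.4, 2.3, 1.9, 0.7, 1.1),
  z   = c(0L, 1L, 0L, 1L, 1L, 0L),
  age = c(34, 51, 29, 42, 38, 45),
  sex = factor(c("f", "m", "m", "f", "f", "m"))
)

# Covariates as a numeric matrix: model.matrix() expands the factor into a
# 0/1 dummy column, and [, -1] drops the intercept column.
x_toy <- model.matrix(~ age + sex, data = toy)[, -1, drop = FALSE]
z_toy <- as.integer(toy$z)   # treatment indicator in c(0, 1)
y_toy <- as.numeric(toy$y)   # numeric outcome vector

dim(x_toy)
colnames(x_toy)
```

The three objects x_toy, z_toy, and y_toy then have the shapes described for the x, z, and y arguments.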
Finally, if you so desire, you can create a dataHolder object directly.
dH <- dataHolder(x = x, z = z, y = y)

# NOT RUN
# weights <- calc_weight(x = dH, z = NULL, y = NULL)
This may be useful if you plan on reusing the data object.
The weights can be estimated by using the calc_weight function in the package. We select optimal hyperparameters through our bootstrap-based algorithm and target the average treatment effect.
weights <- calc_weight(x = x, z = z, method = "COT",
                       estimand = "ATE",
                       options = list(lambda.bootstrap = Inf,
                                      nboot = 1000L)
                       )
These weights balance the covariate distributions between the treatment groups, which reduces confounding bias in the treatment effect estimates, provided the relevant confounders are among the covariates.
We can then estimate effects with
tau_hat <- estimate_effect(causalWeights = weights, y = y)
The estimator generated here is a simple weighted difference in observed outcomes between treatment groups. Moreover, note that we must supply the outcome information in argument y since the calc_weight function does not store it when we pass data matrices.
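A weighted difference in outcomes of this kind can be written out by hand. The following is a minimal base R sketch on made-up numbers; the function name weighted_diff is hypothetical, and normalizing the weights within each group is an assumption of the sketch rather than a statement about the package internals.

```r
# Toy outcomes, treatment indicator, and weights (made up for illustration)
y_toy <- c(1, 3, 2, 6, 4)
z_toy <- c(0, 0, 0, 1, 1)
w_toy <- c(0.5, 0.25, 0.25, 0.75, 0.25)

# Hypothetical helper: weighted mean of treated minus weighted mean of controls
weighted_diff <- function(y, z, w) {
  # normalize so that the weights sum to one within each treatment group
  w1 <- w[z == 1] / sum(w[z == 1])
  w0 <- w[z == 0] / sum(w[z == 0])
  sum(w1 * y[z == 1]) - sum(w0 * y[z == 0])
}

tau_toy <- weighted_diff(y_toy, z_toy, w_toy)
tau_toy  # 5.5 - 1.75 = 3.75
```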
The estimate_effect function returns an object of class causalEffect, which can be fed into the native R function confint to calculate asymptotic confidence intervals,
ci_tau <- confint(object = tau_hat, level = 0.95)
or into vcov to calculate the variance of your estimate using the semiparametrically efficient variance formula:
var_tau <- vcov(object = tau_hat)
This then gives the following treatment effect estimate, variance, and C.I.
print(coef(tau_hat))
#>    estimate 
#> 0.007949831
print(var_tau)
#>           estimate
#> estimate 0.1887772
print(ci_tau)
#>               2.5 %    97.5 %
#> estimate -0.8436251 0.8595248
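As a sanity check on how these three numbers relate, the printed interval is consistent with a normal-approximation (Wald) interval built from the estimate and its variance. The sketch below simply plugs in the values printed above; that confint uses exactly this construction internally is an inference from the numbers, not something stated by the package.

```r
est   <- 0.007949831   # coef(tau_hat) as printed above
v     <- 0.1887772     # vcov(tau_hat) as printed above
level <- 0.95

# two-sided normal critical value, about 1.96 for a 95% interval
crit <- qnorm(1 - (1 - level) / 2)
ci_by_hand <- est + c(-1, 1) * crit * sqrt(v)
ci_by_hand
```

The result agrees with the printed 2.5 % and 97.5 % limits up to rounding of the inputs.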
The function estimate_effect can also use models to estimate the treatment effects. There are also several additional arguments that will be demonstrated below. These are:
- model.function: either a character or function with the model you want to run
- estimate.separately: TRUE or FALSE. Should the model be estimated separately on each treatment group (TRUE) or jointly on the full data (FALSE)?
- augment.estimate: Should an augmented estimator be used to calculate the final treatment effect? (TRUE or FALSE)
- normalize.weights: Should the weights be normalized to sum to one in each treatment group before being used? For methods except "Logistic", "Probit", or "CBPS", the weights are by definition normalized to sum to one, so this option will not have an effect for most methods in the package.

The model functions we can use need to have a few components:
- a formula argument
- a data argument that accepts a data.frame
- a weights argument
- a newdata argument.

One such function we could use is lm.
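A quick way to see whether a candidate model function has these components is to inspect its formal arguments. The check below is a base R sketch for lm only; how estimate_effect itself validates a model function is not shown here, and for lm the newdata argument belongs to its predict method rather than to lm itself.

```r
# lm() exposes the fitting arguments directly ...
fit_args <- names(formals(stats::lm))
has_fit_args <- c("formula", "data", "weights") %in% fit_args

# ... while newdata is an argument of the predict method for lm objects
pred_args <- names(formals(utils::getS3method("predict", "lm")))
has_newdata <- "newdata" %in% pred_args

all(has_fit_args) && has_newdata
```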
lm
tau_hat_lm <- estimate_effect(causalWeights = weights, y = y,
                              model.function = lm,
                              estimate.separately = TRUE,
                              augment.estimate = FALSE,
                              normalize.weights = TRUE)
In this case, separate models will be fit to treated and controls and the predictions from the models will be used to estimate treatment effects. We can also calculate the augmented (aka doubly robust) estimate with argument augment.estimate.
tau_hat_dr <- estimate_effect(causalWeights = weights, y = y,
                              model.function = lm,
                              estimate.separately = TRUE,
                              augment.estimate = TRUE,
                              normalize.weights = TRUE)
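To make the idea of an augmented estimate concrete, here is a base R sketch of one common form: fit a model in each arm, average the difference in predictions over everyone, then add a weighted average of the residuals in each arm. The simulated data, the uniform placeholder weights, and this exact formula are assumptions of the sketch; the package's own implementation may differ in detail.

```r
set.seed(42)
n_sim  <- 200
x1_sim <- rnorm(n_sim)
z_sim  <- rbinom(n_sim, 1, plogis(0.5 * x1_sim))
y_sim  <- 1 + 2 * z_sim + 0.5 * x1_sim     # noiseless toy outcome, true effect = 2
dat    <- data.frame(y = y_sim, z = z_sim, x1 = x1_sim)

# placeholder weights: uniform within each arm, summing to one per arm
w_sim <- ifelse(z_sim == 1, 1 / sum(z_sim == 1), 1 / sum(z_sim == 0))

# separate outcome models for treated and controls
fit1 <- lm(y ~ x1, data = dat[dat$z == 1, ])
fit0 <- lm(y ~ x1, data = dat[dat$z == 0, ])
mu1  <- predict(fit1, newdata = dat)
mu0  <- predict(fit0, newdata = dat)

# model-based estimate plus weighted residual corrections
tau_model <- mean(mu1 - mu0)
corr1 <- sum(w_sim[z_sim == 1] * (y_sim[z_sim == 1] - mu1[z_sim == 1]))
corr0 <- sum(w_sim[z_sim == 0] * (y_sim[z_sim == 0] - mu0[z_sim == 0]))
tau_aug <- tau_model + corr1 - corr0
tau_aug
```

Because the toy outcome is exactly linear, both models fit perfectly and the estimate recovers the true effect of 2; with real data the residual terms are what give the estimator its double robustness.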
We can also fit a weighted OLS by specifying estimate.separately = FALSE:
tau_hat_wols <- estimate_effect(causalWeights = weights, y = y,
                                model.function = lm,
                                estimate.separately = FALSE,
                                augment.estimate = FALSE,
                                normalize.weights = TRUE)
This fits a single weighted OLS model on the entire data.
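Outside the package, a single weighted OLS fit of this kind looks as follows in base R. The data and weights are simulated placeholders, and reading the effect off the coefficient of the treatment indicator is an assumption of this sketch about what the joint model reports.

```r
set.seed(7)
n_sim  <- 100
x1_sim <- runif(n_sim)
z_sim  <- rep(c(0, 1), each = n_sim / 2)
y_sim  <- 3 + 1.5 * z_sim - 2 * x1_sim     # noiseless toy outcome, true effect = 1.5
w_sim  <- runif(n_sim, 0.5, 1.5)           # placeholder balancing weights
dat    <- data.frame(y = y_sim, z = z_sim, x1 = x1_sim, w = w_sim)

# one model on the full data, with the weights passed to lm()
fit_wols <- lm(y ~ z + x1, data = dat, weights = w)
tau_wols <- coef(fit_wols)[["z"]]
tau_wols
```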
An outcome model that is particular to this package is the function barycentric_projection that estimates, as the name implies, barycentric projections of the outcome data. To use this function, there are a couple of steps. Unlike the case for linear models with lm, we need to think carefully about what sample the data arise from.
To use this function outside of the main causalOT functions, we would do
df <- data.frame(z = z, y = y, x)
bp <- barycentric_projection(formula = "y ~ x + z",
                             data = df,
                             weights = weights,
                             separate.samples.on = "z",
                             penalty = 0.01,
                             cost_function = NULL,
                             p = 2,
                             debias = FALSE,
                             cost.online = "auto",
                             diameter = NULL,
                             niter = 1000,
                             tol = 1e-7)
This will run the optimal transport problem between the samples denoted by "z" and get the dual potentials for the Sinkhorn Divergence problem.
Then, we can run a predict function to see what the outcomes would be if the samples had arisen from a different distribution. Let's say that everyone had actually been treated:
newdf <- df
newdf$z <- 1L
preds <- predict(object = bp, newdata = newdf, source.sample = df$z)
head(preds)
#> [1] -2.80749441  0.52369862  5.79605734 -1.55036673  0.01254699
#> [6]  1.55753697
The argument source.sample should be a vector that denotes the original treatment group of the samples. This allows the function to use the appropriate dual potentials to calculate the expected outcome.
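For intuition about what a barycentric projection computes, the sketch below solves a small entropy-penalized optimal transport problem with plain Sinkhorn iterations in base R and then averages the treated outcomes under the resulting plan. It is a toy stand-in under stated assumptions (simulated data, squared Euclidean cost, an arbitrary penalty, no debiasing, a transport plan instead of dual potentials) and is not how barycentric_projection is implemented.

```r
set.seed(1)
x0 <- matrix(rnorm(20), ncol = 2)    # 10 control units, 2 covariates
x1 <- matrix(rnorm(16), ncol = 2)    # 8 treated units
y1 <- rnorm(8)                       # outcomes observed for the treated
a  <- rep(1 / nrow(x0), nrow(x0))    # control sample weights
b  <- rep(1 / nrow(x1), nrow(x1))    # treated sample weights

# squared Euclidean cost between every control and every treated unit
cost <- outer(rowSums(x0^2), rowSums(x1^2), "+") - 2 * x0 %*% t(x1)

eps <- 1                             # entropic penalty (arbitrary choice)
K <- exp(-cost / eps)
u <- rep(1, length(a))
v <- rep(1, length(b))
for (iter in seq_len(500)) {         # Sinkhorn fixed-point updates
  u <- a / as.vector(K %*% v)
  v <- b / as.vector(crossprod(K, u))
}
plan <- diag(u) %*% K %*% diag(v)    # entropic transport plan

# barycentric projection for a squared cost: plan-weighted mean of treated outcomes
y0_as_treated <- as.vector(plan %*% y1) / rowSums(plan)
head(y0_as_treated)
```

Each entry of y0_as_treated is a convex combination of the treated outcomes, weighted toward the treated units that are closest in covariate space to the given control unit.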
In the context of the estimate_effect function, we need to supply some extra arguments in the ... argument.
tau_hat_bp <- estimate_effect(causalWeights = weights, y = y,
                              model.function = barycentric_projection,
                              estimate.separately = FALSE,
                              augment.estimate = TRUE,
                              normalize.weights = TRUE,
                              # special args for barycentric_projection
                              separate.samples.on = "z",
                              penalty = 0.01,
                              cost_function = NULL,
                              p = 3,
                              debias = FALSE,
                              cost.online = "tensorized",
                              diameter = NULL,
                              niter = 1000L,
                              tol = 1e-7,
                              line_search_fn = "strong_wolfe"
                              )
print(tau_hat_bp@estimate)
#> [1] -0.02549428
This method currently doesn't have a variance estimator accounting for the weight uncertainty but we can use the asymptotic variance estimator of Hahn (1998):
vcov(tau_hat_bp)
#>             estimate
#> estimate 0.008051722
vcov(tau_hat_wols)
#>            estimate
#> estimate 0.01705504
vcov(tau_hat)
#>           estimate
#> estimate 0.1887772
In none of these cases did we feed data or a formula to the model function. By default, the estimate_effect function will regress the outcome in argument y on all of the covariates from the calc_weight function and adjust for the treatment indicator as appropriate given the selected options. If you want the outcome model to use different covariates from those used to estimate the weights, you can provide new covariates in an argument x:
estimate_effect(causalWeights = weights, x = x_new, y = y)
Note that these data must have the same observation order as the previous data and must also be an object of class matrix.
Diagnostics are also an important part of deciding whether the weights perform well. We will explore three areas: effective sample size, mean balance, and distributional balance.
Typically, the effective sample size with weights is calculated as $1/\sum_i w_i^2$, for weights normalized to sum to one, and gives us a measure of how much information is in the sample. The lower the effective sample size (ESS), the higher the variance, and a low ESS means that a few individuals carry most of the weight. Of course, we can calculate this in causalOT!
ESS(weights)
#>   Control   Treated 
#> 116.66625  84.32589
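The usual (Kish) effective sample size is short enough to write by hand. The helper below, with the hypothetical name ess_by_hand, is a base R sketch of that formula and is not claimed to match the ESS function line for line.

```r
# Hypothetical helper: effective sample size of a weight vector
ess_by_hand <- function(w) {
  w <- w / sum(w)    # normalize the weights to sum to one
  1 / sum(w^2)
}

ess_uniform  <- ess_by_hand(rep(1, 247))            # uniform weights recover the group size
ess_dominant <- ess_by_hand(c(0.97, rep(0.01, 3)))  # one dominant weight gives an ESS close to 1
c(ess_uniform, ess_dominant)
```

With uniform weights the effective sample size equals the number of observations, which is why unweighted groups report their raw sizes.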
However, this measure can fail to flag problems when the weights are highly variable. In response, Vehtari et al. use Pareto smoothed importance sampling (PSIS). We offer a thin wrapper to adapt the class causalWeights to the loo package:
raw_psis <- PSIS(weights)
This will also return the Pareto smoothed weights and log weights.
If we want to easily examine the PSIS diagnostics, we can pull those out too:
PSIS_diag(raw_psis)
#> $w0
#> $w0$pareto_k
#> [1] -0.2320877
#> 
#> $w0$n_eff
#> [1] 115.7568
#> 
#> 
#> $w1
#> $w1$pareto_k
#> [1] 0.2478191
#> 
#> $w1$n_eff
#> [1] 82.99825
We can see that all of the $k$ values are below the recommended 0.5, indicating finite variance and that the central limit theorem holds. Note that the effective sample sizes are a bit lower than those from the ESS method above.
Many authors consider the standardized absolute mean difference as a marker of balance: see Stuart (2010). That is
$$ \frac{|\overline{X}_c - \overline{X}_t|}{\sigma_{\text{pool}}},$$ where
$\overline{X}_c$ is the mean in the controls, $\overline{X}_t$ is the mean in the treated, and $\sigma_{\text{pool}}$ is the pooled standard deviation.
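The formula can be checked on a small example in base R. The helper name std_mean_diff is hypothetical, and the pooled standard deviation used here (the square root of the average of the two unweighted group variances) is one common convention that is assumed for the sketch; the package may pool differently.

```r
# Hypothetical helper: standardized absolute mean difference per covariate
std_mean_diff <- function(x, z, w = NULL) {
  if (is.null(w)) w <- rep(1, length(z))
  w0 <- w[z == 0] / sum(w[z == 0])
  w1 <- w[z == 1] / sum(w[z == 1])
  m0 <- colSums(x[z == 0, , drop = FALSE] * w0)   # weighted control means
  m1 <- colSums(x[z == 1, , drop = FALSE] * w1)   # weighted treated means
  # assumption: pool the unweighted group variances with equal weight
  s_pool <- sqrt((apply(x[z == 0, , drop = FALSE], 2, var) +
                  apply(x[z == 1, , drop = FALSE], 2, var)) / 2)
  abs(m0 - m1) / s_pool
}

x_toy <- cbind(X1 = c(0, 1, 2, 2, 3, 4), X2 = c(1, 2, 3, 1, 2, 3))
z_toy <- c(0, 0, 0, 1, 1, 1)

smd_raw <- std_mean_diff(x_toy, z_toy)                            # X1 = 2, X2 = 0
smd_wtd <- std_mean_diff(x_toy, z_toy, w = c(0, 0, 1, 1, 0, 0))   # X1 drops to 0
smd_raw
```

In the toy data the X1 means differ by two pooled standard deviations before weighting, and weights that concentrate on the overlapping units remove that difference.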
We offer such checks in causalOT as well.
First, we consider mean balance between treatment groups before weighting
mean_balance(x = hainmueller)
#>        X1        X2        X3        X4        X5        X6 
#> 1.0889178 0.9320099 0.9322327 0.3741322 0.2798986 0.2644929
and then mean balance between treatment groups after weighting
mean_balance(x = hainmueller, weights = weights)
#>          X1          X2          X3          X4          X5          X6 
#> 0.036900048 0.006577199 0.003137536 0.003228830 0.002445118 0.060551316
Pretty good! However, mean balance doesn't ensure distributional balance.
Ultimately, distributional balance is what we care about in causal inference. Fortunately, we can measure that too. We consider the 2-Sinkhorn divergence of Genevay et al. since it metrizes convergence in distribution.
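For reference, the Sinkhorn divergence between two weighted samples $a$ and $b$ is usually written as the entropy-penalized optimal transport cost with its self-comparison terms subtracted. The notation below ($\mathrm{OT}_\lambda$ for the penalized cost with penalty $\lambda$) is introduced here for illustration, and we assume this debiased form is what the debias = TRUE option refers to:

$$ S_\lambda(a, b) = \mathrm{OT}_\lambda(a, b) - \frac{1}{2}\mathrm{OT}_\lambda(a, a) - \frac{1}{2}\mathrm{OT}_\lambda(b, b). $$

By construction $S_\lambda(a, a) = 0$, so values near zero indicate that the two weighted samples are close in distribution.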
Before weighting, distributional balance looks poor:
# controls
ot_distance(x1 = hainmueller$get_x0(),
            x2 = hainmueller$get_x(),
            a = NULL, b = rep(1/512, 512),
            p = 2, penalty = 1e3, debias = TRUE)
#> [1] 0.5311378

# treated
ot_distance(x1 = hainmueller$get_x1(),
            x2 = hainmueller$get_x(),
            a = NULL, b = rep(1/512, 512),
            p = 2, penalty = 1e3, debias = TRUE)
#> [1] 0.4612781
But after weighting, it looks much better!
# controls
ot_distance(x1 = hainmueller$get_x0(),
            x2 = hainmueller$get_x(),
            a = weights@w0, b = rep(1/512, 512),
            p = 2, penalty = 1e3, debias = TRUE)
#> [1] 0.003670398

# treated
ot_distance(x1 = hainmueller$get_x1(),
            x2 = hainmueller$get_x(),
            a = weights@w1, b = rep(1/512, 512),
            p = 2, penalty = 1e3, debias = TRUE)
#> [1] 0.002123788
After Causal Optimal Transport, the distributions are much more similar.
We can also simply feed the output of calc_weight directly into the ot_distance function:
ot_distance(x1 = weights, p = 2, penalty = 1e3, debias = TRUE)
#> $pre
#>   control   treated 
#> 0.5311378 0.4612781 
#> 
#> $post
#>     control     treated 
#> 0.003670398 0.002123788
and the S4 method dispatch takes care of the rest.
Finally, we can construct a summary of the optimal transport distances, Pareto k statistics, effective sample size, and mean balance using the summary method:
summarized_cw <- summary(weights, penalty = 1000)
We can then print the object to the screen:
summarized_cw
#> Diagnostics for causalWeights for estimand ATE
#> Control group
#>                                pre          post
#> OT distance              0.5311378   0.003670398
#> Pareto k                        NA  -0.232087710
#> N eff                  247.0000000 115.756836449
#> Avg. std. mean balance   0.3067516   0.017646355
#> 
#> Treated group
#>                                pre         post
#> OT distance              0.4612781 2.123788e-03
#> Pareto k                        NA 2.478191e-01
#> N eff                  265.0000000 8.299825e+01
#> Avg. std. mean balance   0.2859156 9.871006e-04
or we can make some diagnostic plots too!
plot(summarized_cw)
#> Plotting diagnostics for causalWeights for estimand ATE
The calc_weight function can also handle other methods. We have implemented methods for logistic or probit regression, the covariate balancing propensity score (CBPS), stable balancing weights (SBW), entropy balancing weights (EntropyBW), and the synthetic control method (SCM).
calc_weight(x = hainmueller, method = "Logistic", estimand = "ATE")
calc_weight(x = hainmueller, method = "Probit", estimand = "ATE")
calc_weight(x = hainmueller, method = "CBPS", estimand = "ATE")
calc_weight(x = hainmueller, method = "SBW")
calc_weight(x = hainmueller, method = "EntropyBW")
calc_weight(x = hainmueller, method = "SCM")
The function also accepts methods "EnergyBW", for Energy Balancing Weights of Huling and Mak (2020), and "NNM", for nearest neighbor matching with replacement, but these are special cases of COT with the penalty parameter $\lambda$ forced to be $\infty$ and $0$, respectively.
The argument options is a little vague, so we also provide a function cotOptions to help construct it. The documentation provides more details. The other optimization methods "SBW", "EntropyBW", and "SCM" provide their own options functions. The options for "Logistic" and "Probit" pass arguments to glm, and "CBPS" will pass arguments to the CBPS function in the package of the same name.
The package also provides more flexible optimal transport weights and modeling through object-oriented programming with the R6 package. These functions don't have as many safeguards and everything is done by reference, so you have to be more careful about what you do. However, this also gives you more flexibility in the types of problems you can solve. To learn more, see the vignette on object-oriented solvers.