integrate_function: The main integration function

View source: R/integrate_function.R

integrate_function    R Documentation

The main integration function

Description

This function performs the integration described in the paper. Here, we pass the mean of the posterior estimates of the observed joint probabilities rather than the full posterior estimates.

Usage

integrate_function(
  intframe,
  constraint = T,
  f = f,
  n_cores = n_cores,
  lambda = 0
)

Arguments

intframe

the data frame of interest (in the Examples, the output of BARTpred)

constraint

logical; whether to enforce the constraint b1 >= b0 in the integration

f

a function of u giving the distribution f(u) to integrate over (in the Examples, a normal density)

n_cores

the number of cores to use (note that parallel::detectCores() reports only the total number of cores, not necessarily the number available)

lambda

regularization term in constraint, multiplied by exp(b1)+b0-b0)^2
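
A minimal sketch of preparing the f and n_cores arguments, using only base R and the parallel package. The standard deviation sd_u is an illustrative placeholder and not a recommended value. The "leave one core free" rule is likewise an assumption for illustration, not something the package prescribes.

```r
# Candidate density for u; sd_u is a hypothetical, illustrative value.
sd_u <- 1
f <- function(u) dnorm(u, mean = 0, sd = sd_u)

# A proper density should integrate to (approximately) 1 over the real line.
total_mass <- integrate(f, lower = -Inf, upper = Inf)$value

# detectCores() reports the machine's total core count and may return NA,
# so leave one core free and fall back to a single core if unknown.
total_cores <- parallel::detectCores()
n_cores <- if (is.na(total_cores)) 1L else max(1L, total_cores - 1L)
```

Checking total_mass before a long run is a cheap way to catch a mis-specified f, such as a negative or unnormalized function.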

Examples

#Set a seed and generate data from the bivariate probit
set.seed(0)

N <- 500 # Number of random samples
a=1

x1=runif(N, -a,a)
x2=runif(N, -a,a)
x3=runif(N,-a,a)
x4=runif(N,-a,a)
x5=runif(N, -a,a)
beta1= -0.2
alpha1= 0.7
beta0= -0
alpha0= -0.5
mu1 <- beta0+beta1*(x1+x2+x3+x4+x5)
mu2 <- alpha0+alpha1*(x1+x2+x3+x4+x5)
mu<-matrix(c(mu1, mu2), nrow=N, ncol=2)
rho=.5
gamma=1
B1.true=pnorm(mu2+gamma)
B0.true=pnorm(mu2)
sigma <- matrix(c(1, rho,rho,1),
               2) # Covariance matrix
sim_data=t(sapply(1:N, function(i)MASS::mvrnorm(1, mu = mu[i,], Sigma = sigma )))
# generate the binary treatments
G <- as.numeric(sim_data[, 1] >= 0)
# generate the binary outcomes
B <- as.numeric(sim_data[, 2] >= -1 * gamma * G)
print(table(G,B))
covariates=data.frame(x1,x2,x3,x4,x5,B, G)
vars=c('x1','x2','x3','x4','x5')
intframe=BARTpred(covariates, treat='G', Outcome='B',vars, mono=TRUE)
#pick a function to integrate over
#here the standard deviation is chosen to match the generated data
f=function(u){
 dnorm(u, mean=0, sd=sqrt(rho/(1-rho)))
}
treat_frame=integrate_function(intframe, constraint=TRUE, f=f, n_cores=1, lambda=0)
library(rpart)
#merge the ites and covariates
dataset=data.frame(covariates, tau=treat_frame[, 'tau'])
tree_fit<-rpart(tau~x1+x2+x3+x4+x5,data=dataset)
rpart.plot::rpart.plot(tree_fit)
# Optional: a quick histogram of the ITEs
library(ggplot2)
library(dplyr)
dataset %>%
  ggplot(aes(x = tau)) +
  geom_histogram(aes(y = after_stat(count / sum(count))), color = 'white',
                 fill = '#1d2951') +
  ggtitle('All observations') +
  ylab('Density') + xlab('Individual Treatment Effects') +
  xlim(0, .25) + theme_minimal(base_size = 16) +
  theme(plot.title = element_text(hjust = 0.5, size = 16))
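
The simulation above defines B1.true and B0.true but does not use them afterwards. The sketch below assumes that the individual treatment effect is the difference between those two outcome probabilities. That reading is an assumption based on the variable names, not something this page states. It gives a reference distribution that could be compared with the estimated tau. The covariates are regenerated here so the block runs on its own, so the draws will not match the example's x1 to x5 exactly. The names X and tau_true are introduced only for illustration.

```r
set.seed(0)
N <- 500
a <- 1

# Five uniform covariates, as in the example (drawn here as one matrix).
X <- matrix(runif(N * 5, -a, a), ncol = 5)

# Outcome-equation parameters copied from the example.
alpha0 <- -0.5
alpha1 <- 0.7
gamma <- 1
mu2 <- alpha0 + alpha1 * rowSums(X)

# Assumed definition: outcome probability under treatment minus
# outcome probability under control.
tau_true <- pnorm(mu2 + gamma) - pnorm(mu2)
summary(tau_true)
```

Because gamma is positive, every value of tau_true is positive under this definition.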

demetrios1/Causallysensitive documentation built on Nov. 18, 2022, 5:27 p.m.