#' @title Estimation of ATT
#'
#' @description Differentiated Confounder Balancing (DCB) estimation of the
#'   average treatment effect on the treated (ATT). The routine alternates
#'   between an elastic-net update of the confounder coefficients `beta`
#'   (control weights held fixed) and a gradient step on the control weights
#'   `W` (coefficients held fixed) until the objective stops changing.
#'
#' @param Y Numeric outcome vector, one entry per unit.
#' @param T Treatment indicator, one entry per unit: `0` marks a control unit
#'   and `1` a treated unit. Must not contain missing values.
#' @param M Numeric covariate matrix (or an object coercible with
#'   `as.matrix()`), one row per unit and one column per covariate.
#' @param gp If `TRUE` (default) the balancing target is the covariate mean of
#'   the treated units; if `FALSE` it is the covariate mean of all units.
#' @param lambda Weight of the weighted outcome-regression loss in the
#'   objective, its gradient, and the rescaled elastic-net design.
#' @param delta Penalty on the squared control weights, `sum(W^2)`.
#' @param mu Ridge penalty on `beta` in the reported objective value.
#' @param upsilon Lasso penalty on `beta` in the reported objective value.
#'   Note that `mu` and `upsilon` only enter the objective used for the
#'   convergence check; the elastic-net fit itself uses the `alpha`/`lambda`
#'   pair selected once by 10-fold cross-validation.
#' @param thold Convergence threshold on the absolute change of the objective
#'   between two consecutive iterations.
#' @param max_iter Maximum number of alternating iterations.
#'
#' @return A list with elements
#'   \describe{
#'     \item{alpha, lambda}{Elastic-net tuning parameters chosen by
#'       cross-validation.}
#'     \item{beta}{Final coefficient vector.}
#'     \item{weight}{Control-unit weights as a one-column matrix; they are
#'       non-negative and sum to one.}
#'     \item{ATT}{`mean(Y[T == 1]) - sum(weight * Y[T == 0])`.}
#'     \item{update}{Objective values recorded after every `beta` update and
#'       every `W` update.}
#'   }
#'
#' @examples
#' \dontrun{
#' set.seed(1)
#' n <- 200
#' X <- matrix(rnorm(n * 5), n, 5)
#' trt <- rbinom(n, 1, plogis(X[, 1]))
#' Y <- as.vector(X %*% rep(1, 5)) + trt + rnorm(n)
#' fit <- ATT.DCB(Y, trt, X)
#' fit$ATT
#' }
#'
#' @export ATT.DCB
ATT.DCB <- function(Y,T,M,gp=TRUE,lambda=10,delta=0.001,mu=0.001,upsilon=0.001,thold=1e-4,max_iter=100000){
  ####################################################################
  ########################## Input checks ############################
  ####################################################################
  M <- as.matrix(M)
  if (!is.numeric(M)) {
    stop("M must be numeric", call. = FALSE)
  }
  if (length(Y) != nrow(M) || length(T) != nrow(M)) {
    stop("Y, T and M must describe the same number of units", call. = FALSE)
  }
  if (anyNA(T)) {
    stop("T must not contain missing values", call. = FALSE)
  }
  is_ctrl <- T == 0
  is_trt <- T == 1
  n_c <- sum(is_ctrl)
  if (n_c == 0 || sum(is_trt) == 0) {
    stop("T must contain at least one control (0) and one treated (1) unit",
         call. = FALSE)
  }
  p <- ncol(M)

  ####################################################################
  ########################## Given variables #########################
  ####################################################################
  # Outcomes and covariates are centred on the control group.
  # drop = FALSE keeps matrices intact when a group has a single row or
  # M has a single column.
  Y_c <- Y[is_ctrl] - mean(Y[is_ctrl])
  ctrl_means <- colMeans(M[is_ctrl, , drop = FALSE])
  M_c <- sweep(M[is_ctrl, , drop = FALSE], 2, ctrl_means)
  M_t <- sweep(M[is_trt, , drop = FALSE], 2, ctrl_means)
  M_norm <- sweep(M, 2, ctrl_means)
  if (gp) {
    M_t_bar <- colMeans(M_t)
  } else {
    M_t_bar <- colMeans(M_norm)
  }

  # Part of the objective that depends on the weights. lambda scales the
  # weighted residual term so that this function matches both the gradient
  # in grr() and the rescaled design used for the elastic-net step.
  weight_loss <- function(W, beta) {
    gap <- M_t_bar - t(M_c) %*% W
    resid_sq <- (Y_c - M_c %*% beta)^2
    (sum(beta * gap))^2 + lambda * sum((1 + W) * resid_sq) + delta * sum(W^2)
  }
  # Full objective: weight-dependent loss plus ridge and lasso penalties.
  obj_func <- function(W, beta) {
    weight_loss(W, beta) + mu * sum(beta^2) + upsilon * sum(abs(beta))
  }
  # The weights are parameterised as W = omega^2 to keep them non-negative.
  # Both closures read the current `beta` from the enclosing frame, so they
  # only need to be defined once.
  fr <- function(omega) {
    weight_loss(omega^2, beta)
  }
  grr <- function(omega) {
    W <- omega^2
    gap <- M_t_bar - t(M_c) %*% W
    fitted <- M_c %*% beta
    comp1 <- -4 * sum(beta * gap) * matrixcalc::hadamard.prod(fitted, omega)
    comp2 <- 4 * delta * matrixcalc::hadamard.prod(W, omega)
    comp3 <- 2 * lambda * matrixcalc::hadamard.prod(omega, (Y_c - fitted)^2)
    as.vector(comp1 + comp2 + comp3)
  }
  # Augmented regression problem for the beta update: control rows rescaled
  # by sqrt(lambda * (1 + W)) plus one extra row carrying the balance gap
  # with response 0.
  augmented <- function(W) {
    row_scale <- as.vector(sqrt(lambda * (1 + W)))
    gap <- as.vector(M_t_bar - t(M_c) %*% W)
    list(x = rbind(row_scale * M_c, diff = gap), y = c(row_scale * Y_c, 0))
  }

  ####################################################################
  ############################# Initialize ###########################
  ####################################################################
  W <- rep(1 / n_c, n_c)
  beta <- rep(1 / p, p)
  prev_obj <- obj_func(W, beta)

  #### line search: the step size is chosen once and reused in every iteration
  omega <- sqrt(W)
  eta <- rje::armijo(fun = fr, x = omega, dx = -grr(omega))$adj

  #### cross validation: tune the elastic net once on the initial weights
  aug <- augmented(W)
  mydata <- as.data.frame(cbind(aug$y, aug$x))
  names(mydata)[1] <- "Y_prime_0"
  model <- caret::train(
    Y_prime_0 ~ ., data = mydata, method = "glmnet",
    trControl = caret::trainControl("cv", number = 10),
    tuneLength = 10
  )
  alpha1 <- model$bestTune[["alpha"]]
  lambda1 <- model$bestTune[["lambda"]]

  values <- numeric(0)
  for (ind in seq_len(max_iter)) {
    ####################################################################
    ######################## Fix W, update beta ########################
    ####################################################################
    aug <- augmented(W)
    fit <- glmnet::glmnet(x = aug$x, y = aug$y,
                          family = "gaussian", alpha = alpha1, lambda = lambda1)
    beta <- as.vector(fit$beta)
    values[2 * ind - 1] <- obj_func(W, beta)

    ####################################################################
    ######################## Fix beta, update W ########################
    ####################################################################
    omega <- sqrt(W)
    omega <- omega - eta * grr(omega)
    # Rescale omega to unit length so that the weights W = omega^2 sum to one.
    omega <- omega / sqrt(sum(matrixcalc::hadamard.prod(omega, omega)))
    W <- matrixcalc::hadamard.prod(omega, omega)

    ####################################################################
    #################### Objective function value ######################
    ####################################################################
    cur_obj <- obj_func(W, beta)
    values[2 * ind] <- cur_obj
    if (abs(cur_obj - prev_obj) < thold) {
      break
    }
    prev_obj <- cur_obj
  }

  list("alpha" = alpha1, "lambda" = lambda1, "beta" = beta, "weight" = W,
       "ATT" = mean(Y[is_trt]) - sum(W * Y[is_ctrl]), "update" = values)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.