Nothing
#' @title High-Dimensional Conditional Average Treatment Effects (HDCATE) Estimator
#'
#' @keywords internal
#'
#' @description
#' Use a two-step procedure to estimate the conditional average treatment
#' effects (CATE) for all possible values of the covariate(s).
#'
#' The first stage fits the propensity score and the treated / untreated
#' conditional expectations (by `hdm` lasso routines or by user-supplied
#' fit / predict functions). The second stage runs a local linear regression
#' of the resulting score on the conditioning variable, evaluated on the grid
#' stored in `cond_interval`. `inference()` then builds uniform confidence
#' bands with a multiplier bootstrap.
#'
#' @format [R6::R6Class] object.
#'
#' @family HDCATE
HDCATE_R6Class <- R6Class(
  # different instances for different models
  "HDCATE",
  public = list(
    # ---- data and model specification ----
    data = NA,
    y_name = NA,
    d_name = NA,
    x_formula = NA,
    # ---- conditioning variable and its evaluation grid ----
    cond_var_name = NA,
    cond_var_lower = NA,
    cond_var_upper = NA,
    cond_var_eval_step = 0.01,
    # ---- first-stage configuration ----
    use_cross_fitting = FALSE,
    user_defined_first_stage = FALSE,
    user_defined_fit_untreated = NA,
    user_defined_fit_treated = NA,
    user_defined_fit_ps = NA,
    user_defined_predict_untreated = NA,
    user_defined_predict_treated = NA,
    user_defined_predict_ps = NA,
    propensity_est_method = 'lasso',
    expectation_est_method = 'lasso',
    # ---- fitted models (appended to on every estimation call) ----
    propensity_score_model_list = list(),
    treated_cond_exp_model_list = list(),
    untreated_cond_exp_model_list = list(),
    local_reg_model_list = list(),
    # ---- second-stage configuration ----
    # 'rule-of-thumb', 'plug-in' or a numeric bandwidth
    bw = 'rule-of-thumb',
    k_fold = NA,
    folds = NULL,
    local_weight = NA,
    # ---- results ----
    hdcate = NA,
    sigma_hat = NA,
    eta_hat_func_list = NA,
    eta_hat_list = NA,
    d = NA,
    h = NA,
    n_obs = NA,
    alpha = NA,
    n_alpha = NA,
    i_two_side = NA,
    i_left_side = NA,
    i_right_side = NA,
    cond_interval = NA,
    length_grid = NA,
    ate = NA,
    vate = NA,

    # Store the data and the variable names / covariate formula, and
    # initialise the observation weights to one.
    initialize = function(data, y_name, d_name, x_formula) {
      # necessary param
      self$data <- data
      self$y_name <- y_name
      self$d_name <- d_name
      self$x_formula <- x_formula
      # equal weight
      self$local_weight <- rep(1, nrow(data))
      self$n_obs <- nrow(data)
    },

    # Fit the propensity score model `d_name ~ x_formula` on `data` (the full
    # sample when `data` is not a data frame), append the fitted model to
    # `propensity_score_model_list` and return it.
    propensity_hd_estimate = function(data=NA, verbose=FALSE) {
      data <- private$check_data(data)
      method <- self$propensity_est_method
      d_name <- self$d_name
      x_formula <- self$x_formula
      if (verbose) {
        message(paste0('Start estimating model for propensity score, method=', method))
      }
      if (!self$user_defined_first_stage) {
        if (method == "lasso") {
          model <- hdm::rlassologit(as.formula(paste0(d_name, "~", x_formula)), data = data)
        } else {
          stop('When estimating propensity score, a valid method should be provided, such as "lasso".')
        }
      } else {
        # the user-defined fitter receives the model frame, not the raw data
        model <- self$user_defined_fit_ps(
          model.frame(as.formula(paste0(d_name, "~", x_formula)), data)
        )
      }
      if (verbose) {
        message(paste0('Finish estimating propensity score.'))
      }
      self$propensity_score_model_list <- append(self$propensity_score_model_list, list(model))
      return(model)
    },

    # Fit the outcome model `y_name ~ x_formula` separately on the untreated
    # (d == 0) and treated (d == 1) rows of `data`, append both models to the
    # corresponding model lists and return them as
    # `list(predictor_untreated, predictor_treated)`.
    conditional_expectations_hd_estimate = function(data=NA, verbose=FALSE) {
      data <- private$check_data(data)
      method <- self$expectation_est_method
      y_name <- self$y_name
      x_formula <- self$x_formula
      if (verbose) {
        message(paste0('Start estimating conditional expectations, method=', method))
      }
      filter_treated <- data[, self$d_name] == 1
      filter_untreated <- data[, self$d_name] == 0
      if (!self$user_defined_first_stage) {
        if (method == 'lasso') {
          model_untreated <- hdm::rlasso(as.formula(paste0(y_name, "~", x_formula)),
                                         data = data[filter_untreated, ])
          model_treated <- hdm::rlasso(as.formula(paste0(y_name, "~", x_formula)),
                                       data = data[filter_treated, ])
        } else {
          stop('When estimating conditional expectations, a valid method should be provided, such as "lasso".')
        }
      } else {
        model_untreated <- self$user_defined_fit_untreated(
          model.frame(as.formula(paste0(y_name, "~", x_formula)), data[filter_untreated, ])
        )
        model_treated <- self$user_defined_fit_treated(
          model.frame(as.formula(paste0(y_name, "~", x_formula)), data[filter_treated, ])
        )
      }
      if (verbose) {
        message(paste0('Finish estimating conditional expectations.'))
      }
      self$untreated_cond_exp_model_list <- append(self$untreated_cond_exp_model_list, list(model_untreated))
      self$treated_cond_exp_model_list <- append(self$treated_cond_exp_model_list, list(model_treated))
      return(list(predictor_untreated = model_untreated, predictor_treated = model_treated))
    },

    # Run both first-stage fits on `data` and return a list with elements
    # `predictor_untreated`, `predictor_treated` and `predictor_propensity`.
    first_stage = function(data=NA, verbose=FALSE) {
      data <- private$check_data(data)
      predictor_propensity <- list(
        predictor_propensity = self$propensity_hd_estimate(data, verbose = verbose)
      )
      predictor_cond_expectations <- self$conditional_expectations_hd_estimate(data, verbose = verbose)
      predictor_eta_hat <- append(predictor_cond_expectations, predictor_propensity)
      return(predictor_eta_hat)
    },

    # Second stage: local linear regression of the score `phi` on the
    # conditioning variable.
    #
    # predictor_eta_hat: fitted first-stage models (used when `eta_hat` is
    #   not supplied) to predict on the rows in `subsample_idx`.
    # eta_hat: already computed first-stage predictions
    #   (`untreated`, `treated`, `propensity`); used by `inference()`.
    # subsample_idx: row indices of `self$data` to use; all rows when NULL/NA.
    # local_weight: observation weights for the local regression; defaults to
    #   `self$local_weight` restricted to `subsample_idx`.
    # estimate_std: whether to compute `sigma_hat` on the grid.
    # save_model: whether to append the local regression fit to
    #   `local_reg_model_list`.
    #
    # Returns list(hdcate, sigma_hat, vate, ate, eta_hat) and stores the
    # bandwidth used in `self$h`.
    second_stage = function(predictor_eta_hat=NA, eta_hat=NA, subsample_idx=NULL, local_weight=NULL, estimate_std=TRUE, verbose=FALSE, save_model=TRUE) {
      if (verbose) {
        if (self$use_cross_fitting) {
          message(paste0('Start estimating HDCATE in a fold.'))
        } else {
          message(paste0('Start estimating HDCATE.'))
        }
      }
      if (private$is_unset(subsample_idx)) {
        # full-sample
        subsample_idx <- seq_len(nrow(self$data))
      }
      if (private$is_unset(local_weight)) {
        # equal weight (out of bootstrap)
        local_weight <- self$local_weight[subsample_idx]
      }
      # Compute phi
      if (private$is_unset(eta_hat)) {
        # for fitting case
        eval_data <- self$data[subsample_idx, ]
        if (!self$user_defined_first_stage) {
          mu0_hat <- predict(predictor_eta_hat$predictor_untreated, eval_data)
          mu1_hat <- predict(predictor_eta_hat$predictor_treated, eval_data)
          pi_hat <- predict(predictor_eta_hat$predictor_propensity, eval_data, type = "response")
        } else {
          mu0_hat <- self$user_defined_predict_untreated(predictor_eta_hat$predictor_untreated, eval_data)
          mu1_hat <- self$user_defined_predict_treated(predictor_eta_hat$predictor_treated, eval_data)
          pi_hat <- self$user_defined_predict_ps(predictor_eta_hat$predictor_propensity, eval_data)
        }
        eta_hat <- list(untreated = mu0_hat, treated = mu1_hat, propensity = pi_hat)
      } else {
        # for inference case
        mu0_hat <- eta_hat$untreated
        mu1_hat <- eta_hat$treated
        pi_hat <- eta_hat$propensity
      }
      D <- self$data[subsample_idx, self$d_name]
      Y <- self$data[subsample_idx, self$y_name]
      # augmented inverse-probability-weighted (doubly robust) score
      phi <- D * (Y - mu1_hat) / pi_hat + mu1_hat - (1 - D) * (Y - mu0_hat) / (1 - pi_hat) - mu0_hat
      # Local linear regression
      # Gaussian kernel
      ker <- locpol::gaussK
      cond_var <- self$data[subsample_idx, self$cond_var_name]
      cond_interval <- self$cond_interval
      # ATE, repeated over the grid so it can be drawn as a flat line
      ate <- mean(phi)
      vate <- ate * rep(1, length(cond_interval))
      # bandwidth; `phi` is aligned with `subsample_idx`, which the plug-in
      # selector needs in order to pair it with the conditioning variable
      h <- self$get_bw(
        phi = phi,
        use_sample_idx = subsample_idx
      )
      model <- locpol::locPolSmootherC(x = cond_var, y = phi, xeval = cond_interval, bw = h,
                                       weig = local_weight, deg = 1, kernel = ker)
      if (save_model) {
        self$local_reg_model_list <- append(self$local_reg_model_list, list(model))
      }
      # get HDCATE estimator, i,e. tau_hat in the paper
      hdcate <- model$beta0
      if (verbose) {
        message(paste0('Finish estimating HDCATE.'))
      }
      if (estimate_std) {
        if (verbose) {
          message(paste0('Start estimating standard errors.'))
        }
        length_grid <- self$length_grid
        fX_hat <- numeric(length_grid)
        sigmasq_hat <- numeric(length_grid)
        d <- self$d
        for (i in seq_len(length_grid)) {
          # kernel weights around the i-th grid point, reused for both terms
          ker_i <- ker((cond_var - cond_interval[i]) / h)
          fX_hat[i] <- mean(ker_i) / h
          sigmasq_hat[i] <- mean((phi - hdcate[i])^2 * ker_i^2) / (fX_hat[i]^2) / (h^d)
        }
        sigma_hat <- sqrt(sigmasq_hat)
        if (verbose) {
          message(paste0('Finish estimating standard errors.'))
        }
      } else {
        sigma_hat <- NULL
      }
      self$h <- h
      return(list(hdcate = hdcate, sigma_hat = sigma_hat, vate = vate, ate = ate, eta_hat = eta_hat))
    },

    # Choose the bandwidth according to `self$bw`.
    #
    # phi: score values, aligned with `use_sample_idx` (only used by the
    #   plug-in selector).
    # use_sample_idx: rows to use for the plug-in selector; every other
    #   setting uses the full sample.
    #
    # The pilot bandwidth is multiplied by n^(1/5) * n^(-2/7)
    # (NOTE(review): presumably to undersmooth relative to the selector's
    # rate -- confirm against the paper). Falls back to 0.01 when the
    # selector returns NA/NaN.
    get_bw = function(phi, use_sample_idx) {
      bw <- self$bw
      is_plug_in <- isTRUE(bw == 'plug-in')
      if (!is_plug_in) {
        # use full sample to choose bandwidth
        use_sample_idx <- seq_len(nrow(self$data))
      }
      # bandwidth
      cond_var <- self$data[use_sample_idx, self$cond_var_name]
      sample_size <- NROW(cond_var)
      if (isTRUE(bw == 'rule-of-thumb')) {
        # Silverman (1986), rule-of-thumb
        h_hat <- stats::bw.nrd0(cond_var)
        h <- h_hat * sample_size^(1 / 5) * sample_size^(-2 / 7)
      } else if (is_plug_in) {
        # Ruppert,Sheather and Wand (1995), the plug-in method
        h_hat <- KernSmooth::dpill(x = cond_var, y = phi)
        h <- h_hat * sample_size^(1 / 5) * sample_size^(-2 / 7)
      } else if (is.numeric(bw)) {
        # bandwidth can be specified by user
        h <- bw
      } else {
        stop('`bw` should be "rule-of-thumb", "plug-in" or a numeric bandwidth.')
      }
      # avoid h=NaN
      if (is.null(h) || anyNA(h)) {
        h <- 0.01
      }
      return(h)
    },

    #' @description Fit the HDCATE function
    #'
    #' @param verbose if `TRUE`, print progress messages of both stages.
    #'
    #' @return Invisibly, the first-stage predictions stored in `eta_hat_list`.
    #' The estimates are stored in the fields `hdcate`, `ate`, `vate` and
    #' `sigma_hat`.
    fit = function(verbose=FALSE) {
      if (!self$use_cross_fitting) {
        # Full-sample estimator
        message(paste0('Use full-sample estimator.'))
        # First stage
        predictor_eta_hat <- self$first_stage(verbose = verbose)
        # Second stage
        res_second_stage <- self$second_stage(predictor_eta_hat = predictor_eta_hat, verbose = verbose)
        self$hdcate <- res_second_stage$hdcate
        self$ate <- res_second_stage$ate
        self$vate <- res_second_stage$vate
        self$sigma_hat <- res_second_stage$sigma_hat
        self$eta_hat_list <- list(full_sample = res_second_stage$eta_hat)
      } else {
        if (is.numeric(self$k_fold) && (self$k_fold %% 1 == 0) && (self$k_fold >= 2)) {
          # K-fold cross-fitting estimator
          message(paste0('Use ', self$k_fold, '-fold cross-fitting estimator.'))
          if (private$is_unset(self$folds)) {
            self$folds <- caret::createFolds(self$data[, self$y_name], k = self$k_fold)
          }
          cv_res <- lapply(
            self$folds,
            function(idx) {
              # first stage on the complement of the fold ...
              data_kc <- self$data[-idx, ]
              predictor_eta_hat_kc <- self$first_stage(data = data_kc, verbose = verbose)
              # ... second stage on the fold itself
              self$second_stage(
                predictor_eta_hat = predictor_eta_hat_kc,
                subsample_idx = idx,
                verbose = verbose
              )
            }
          )
          # average the fold-wise estimates
          self$hdcate <- private$average_over_folds(cv_res, 'hdcate')
          self$ate <- mean(vapply(cv_res, function(obj) obj$ate, numeric(1)))
          self$vate <- private$average_over_folds(cv_res, 'vate')
          self$sigma_hat <- private$average_over_folds(cv_res, 'sigma_hat')
          self$eta_hat_list <- lapply(cv_res, function(obj) obj$eta_hat)
        } else {
          stop(paste0('`k_fold`=', self$k_fold, ' is not supported for cross-fitting method'))
        }
      }
      invisible(self$eta_hat_list)
    },

    # Construct uniform confidence bands by a multiplier bootstrap of the
    # second stage (first-stage predictions are reused from `eta_hat_list`).
    #
    # sig_level: significance level(s); may be a vector such as
    #   c(0.10, 0.05, 0.01).
    # boot_method: 'normal', 'bayes' or 'wild' (see `draw_weights`).
    # n_rep_boot: number of bootstrap replications.
    #
    # Stores `alpha`, `n_alpha`, `i_left_side`, `i_right_side` and
    # `i_two_side` (lower bands in the first `n_alpha` columns, upper bands in
    # the last `n_alpha` columns). One-sided bands use the critical value of
    # the signed statistic, two-sided bands that of the absolute statistic.
    inference = function(sig_level=0.01, boot_method='normal', n_rep_boot=1000, verbose=FALSE) {
      # uniform confidence bands
      message(paste0('Constructing uniform confidence bands...'))
      # validation
      fitted <- !private$is_unset(self$hdcate) &&
        !private$is_unset(self$sigma_hat) &&
        !private$is_unset(self$eta_hat_list)
      if (!fitted) {
        stop('Missing fitted models. You should call `HDCATE.fit()` before calling `HDCATE.inference()`.')
      }
      # bootstrap: one column of weights per replication
      weights <- self$draw_weights(boot_method, n_rep_boot, self$n_obs)
      hdcate_stat_boot_one_sided <- numeric(n_rep_boot)
      hdcate_stat_boot_two_sided <- numeric(n_rep_boot)
      for (b in seq_len(n_rep_boot)) {
        if (private$is_unset(self$folds)) {
          # full_sample
          boot_loc_res <- self$second_stage(
            eta_hat = self$eta_hat_list[[1]],
            subsample_idx = NULL,
            estimate_std = FALSE,
            local_weight = weights[, b],
            verbose = verbose,
            save_model = FALSE
          )
          hdcate_boot <- boot_loc_res$hdcate
        } else {
          cv_res <- mapply(
            function(idx, list_idx) {
              self$second_stage(
                eta_hat = self$eta_hat_list[[list_idx]],
                subsample_idx = idx,
                estimate_std = FALSE,
                local_weight = weights[idx, b],
                verbose = verbose,
                save_model = FALSE
              )
            },
            self$folds,
            seq_along(self$folds),
            SIMPLIFY = FALSE
          )
          hdcate_boot <- private$average_over_folds(cv_res, 'hdcate')
        }
        # `self$h` is (re)set by `second_stage`, so read it after the call
        root_nh <- sqrt(self$n_obs * (self$h^self$d))
        deviation <- root_nh * (hdcate_boot - self$hdcate) / self$sigma_hat
        hdcate_stat_boot_one_sided[b] <- max(deviation)
        hdcate_stat_boot_two_sided[b] <- max(abs(deviation))
      }
      # CI
      # alpha can be a vector like c(0.10, 0.05, 0.01)
      alpha <- sig_level
      n_alpha <- length(alpha)
      c_alpha_one_sided <- quantile(hdcate_stat_boot_one_sided, 1 - alpha, na.rm = TRUE)
      c_alpha_two_sided <- quantile(hdcate_stat_boot_two_sided, 1 - alpha, na.rm = TRUE)
      # Uniform confidence bands
      grid_labels <- paste0('CATE on ', self$cond_var_name, '=', self$cond_interval)
      level_labels <- paste0('Sigificant level ', alpha)
      i_left_side <- matrix(0, self$length_grid, n_alpha,
                            dimnames = list(grid_labels, level_labels))
      i_right_side <- matrix(0, self$length_grid, n_alpha,
                             dimnames = list(grid_labels, level_labels))
      i_two_side <- matrix(0, self$length_grid, n_alpha * 2,
                           dimnames = list(grid_labels, rep(level_labels, 2)))
      scaled_se <- self$sigma_hat / sqrt(self$n_obs * (self$h^self$d))
      for (idx in seq_len(n_alpha)) {
        i_left_side[, idx] <- self$hdcate - c_alpha_one_sided[idx] * scaled_se
        i_right_side[, idx] <- self$hdcate + c_alpha_one_sided[idx] * scaled_se
        # two-sided bands need the critical value of the absolute statistic
        i_two_side[, idx] <- self$hdcate - c_alpha_two_sided[idx] * scaled_se
        i_two_side[, idx + n_alpha] <- self$hdcate + c_alpha_two_sided[idx] * scaled_se
      }
      self$alpha <- alpha
      self$n_alpha <- n_alpha
      self$i_two_side <- i_two_side
      self$i_left_side <- i_left_side
      self$i_right_side <- i_right_side
      message(paste0('Uniform confidence bands are constructed.'))
      invisible(NULL)
    },

    #' @description Plot the results.
    #'
    #' @param output_pdf if `TRUE`, save image to a pdf file named as `pdf_name`
    #' @param pdf_name the name of the output PDF file
    #' @param include_band if `TRUE`, plot uniform confidence bands as well.
    #' @param test_side `'both'` for a 2-sided test, `'left'` for a left-sided test or `'right'` for a right-sided test
    #' @param y_axis_min the lowest value plotted in Y axis, the default is `'auto'`
    #' @param y_axis_max the largest value plotted in Y axis, the default is `'auto'`
    #' @param display.hdcate the name of HDCATE function in the legend, the default is 'HDCATEF'
    #' @param display.ate the name of average treatment effect in the legend, the default is 'ATE'
    #' @param display.siglevel the name of the significant level for confidence bands in the legend, the default is 'sig_level'
    plot = function(output_pdf=FALSE, pdf_name='hdcate_plot.pdf', include_band=TRUE, test_side='both', y_axis_min='auto', y_axis_max='auto', display.hdcate='HDCATEF', display.ate='ATE', display.siglevel='sig_level') {
      min_value <- min(self$hdcate, self$vate)
      max_value <- max(self$hdcate, self$vate)
      if (private$is_unset(self$i_two_side)) {
        # no bands available (inference() has not been run)
        include_band <- FALSE
      }
      if (include_band) {
        cb_matrix <- self$get_confidence_bands(test_side)
        min_value <- min(min_value, cb_matrix)
        max_value <- max(max_value, cb_matrix)
      }
      # set y axis: pad the data range by 25% unless the user gave a limit
      distance <- max_value - min_value
      ylower <- if (identical(y_axis_min, 'auto')) min_value - distance * 0.25 else y_axis_min
      yupper <- if (identical(y_axis_max, 'auto')) max_value + distance * 0.25 else y_axis_max
      y_range <- c(ylower, yupper)
      # open the output device first so that `par` is saved from, and
      # restored to, the device that is actually drawn on; the device is
      # closed even when plotting fails
      if (output_pdf) {
        pdf(pdf_name)
        on.exit(dev.off(), add = TRUE)
      }
      # reset user's settings when the function is exited (before dev.off)
      oldpar <- par(no.readonly = TRUE)
      on.exit(par(oldpar), add = TRUE, after = FALSE)
      # draw one more curve on top of the existing plot
      overlay <- function(y, ...) {
        par(new = TRUE)
        plot(self$cond_interval, y, ylim = y_range, xlab = "", ylab = "", type = "l", ...)
      }
      plot(self$cond_interval, self$hdcate, ylim = y_range, xlab = self$cond_var_name,
           ylab = "", type = "l", lty = 1, lwd = 2)
      overlay(self$vate, lty = 4, lwd = 1.25)
      lty_set <- c(1, 4)
      lwd_set <- c(2, 1.25)
      if (include_band) {
        one_sided <- test_side == 'left' || test_side == 'right'
        # one line type per significance level
        band_lty <- c(3, 2, 5, 6)[seq_len(self$n_alpha) %% 4 + 1]
        for (idx in seq_len(self$n_alpha)) {
          overlay(cb_matrix[, idx], lty = band_lty[idx])
          if (!one_sided) {
            # two-side test: the upper band sits n_alpha columns to the right
            overlay(cb_matrix[, idx + self$n_alpha], lty = band_lty[idx])
          }
        }
        # the legend has one entry per level, so it gets one lty/lwd per level
        lty_set <- c(lty_set, band_lty)
        lwd_set <- c(lwd_set, rep(1, self$n_alpha))
        legend("topright", c(display.hdcate, display.ate, paste0(display.siglevel, '=', self$alpha)),
               lty = lty_set, lwd = lwd_set, cex = 1.2)
      } else {
        legend("topright", c(display.hdcate, display.ate), lty = lty_set, lwd = lwd_set, cex = 1.2)
      }
      invisible(NULL)
    },

    # Return the matrix of confidence bands for the requested side:
    # `i_left_side` for 'left', `i_right_side` for 'right', otherwise
    # `i_two_side`. Stops when `inference()` has not produced bands yet.
    get_confidence_bands = function(test_side='both') {
      if (private$is_unset(self$i_two_side)) {
        stop('Missing results. Make sure you have called `HDCATE.fit()` and `HDCATE.inference()` before calling this method.')
      }
      if (test_side == 'left') {
        cb_matrix <- self$i_left_side
      } else if (test_side == 'right') {
        cb_matrix <- self$i_right_side
      } else {
        # 2-side test
        cb_matrix <- self$i_two_side
      }
      return(cb_matrix)
    },

    # Draw multiplier-bootstrap weights and return them as an
    # `n_obs` x `n_rep_boot` matrix (one column per replication).
    # method: "bayes" (Exp(1)), "normal" (N(1, 1)) or "wild".
    draw_weights = function(method, n_rep_boot, n_obs) {
      n_draws <- n_rep_boot * n_obs
      if (method == "bayes") {
        weights <- stats::rexp(n_draws, rate = 1)
      } else if (method == "normal") {
        weights <- stats::rnorm(n_draws, mean = 1, sd = 1)
      } else if (method == "wild") {
        weights <- stats::rnorm(n_draws) / sqrt(2) +
          (stats::rnorm(n_draws)^2 - 1) / 2 + 1
      } else {
        stop("invalid boot method")
      }
      # column-major fill: identical to filling an n_rep_boot x n_obs matrix
      # by row and transposing it
      matrix(weights, nrow = n_obs, ncol = n_rep_boot)
    },

    # Update the conditioning variable and rebuild the evaluation grid.
    # Arguments left as NA/NULL keep the value currently stored in the object.
    # Stops when name, lower bound or upper bound is still missing afterwards.
    set_condition_var = function(name=NA, min=NA, max=NA, step=NA) {
      if (!private$is_unset(name)) {
        self$cond_var_name <- name
        self$d <- length(name)
      }
      if (!private$is_unset(min)) {
        self$cond_var_lower <- min
      }
      if (!private$is_unset(max)) {
        self$cond_var_upper <- max
      }
      if (!private$is_unset(step)) {
        self$cond_var_eval_step <- step
      }
      # validation (step=0.01 by default, defined in hdcate R6 class)
      incomplete <- private$is_unset(self$cond_var_name) ||
        private$is_unset(self$cond_var_lower) ||
        private$is_unset(self$cond_var_upper)
      if (incomplete) {
        warning(paste0('Failed to update conditional variable to: name=', name, ', min=', min, ', max=', max, ', step=', step, '.'))
        stop('Missing some of these params for the conditional variable: `name`, `min` and `max`.')
      } else {
        self$cond_interval <- seq(self$cond_var_lower, self$cond_var_upper, self$cond_var_eval_step)
        self$length_grid <- length(self$cond_interval)
        # report the settings now in effect, not the (possibly NA) arguments
        message(paste0(
          'Updated conditional variable to: name=', self$cond_var_name,
          ', min=', self$cond_var_lower,
          ', max=', self$cond_var_upper,
          ', step=', self$cond_var_eval_step, '.'
        ))
      }
    }
  ),
  private = list(
    # Return `data` when it is a data frame (split-sample estimator),
    # otherwise fall back to the full sample stored in the object.
    check_data = function(data) {
      if (!(is.data.frame(data))) {
        # for full-sample estimator
        return(self$data)
      } else {
        # for split-sample estimator
        return(data)
      }
    },

    # TRUE when `x` is NULL or contains NA, i.e. has not been set yet.
    is_unset = function(x) {
      is.null(x) || anyNA(x)
    },

    # Average element `field` of every fold result, grid point by grid point,
    # and return a one-column matrix. `cbind` keeps a matrix even when the
    # grid has a single point.
    average_over_folds = function(cv_res, field) {
      per_fold <- do.call(cbind, lapply(cv_res, function(obj) obj[[field]]))
      as.matrix(rowMeans(per_fold))
    }
  )
)
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.