#' CATE
#'
#' This function estimates heterogeneous treatment effects (HTEs), i.e., conditional average treatment effects (CATEs), defined as E(Y^1 - Y^0 | V = v0).
#' @param data A data frame containing the dataset.
#' @param learner A character string specifying which learner to use. Currently only "dr" (the DR-Learner) is implemented.
#' @param x_names A character vector specifying the names of the confounding variables.
#' @param y_name A character string specifying the outcome variable.
#' @param a_name A character string specifying the treatment variable.
#' @param v_names A character vector specifying the names of the effect modifiers.
#' @param v0 A matrix of evaluation points, i.e., values of V for which the CATE is estimated (E(Y^1 - Y^0 | V = v0)).
#' @param mu1.x A function taking arguments (y, a, x, new.x). It trains a model estimating
#' E(Y | A = 1, X) and returns a list of 3 elements: res, model and fit. \emph{res} is a vector of predictions of the model evaluated at
#' new.x, \emph{model} is the model object used to estimate E(Y | A = 1, X) and \emph{fit} is a function with argument new.x that
#' returns the predictions of the model. See examples.
#' @param mu0.x A function taking arguments (y, a, x, new.x). It trains a model estimating
#' E(Y | A = 0, X) and returns a list of 3 elements: res, model and fit. \emph{res} is a vector of predictions of the model evaluated at
#' new.x, \emph{model} is the model object used to estimate E(Y | A = 0, X) and \emph{fit} is a function with argument new.x that
#' returns the predictions of the model. See examples.
#' @param pi.x A function taking arguments (a, x, new.x). It trains a model estimating
#' P(A = 1 | X) and returns a list of 3 elements: res, model and fit. \emph{res} is a vector of predictions of the model evaluated at
#' new.x, \emph{model} is the model object used to estimate P(A = 1 | X) and \emph{fit} is a function with argument new.x that
#' returns the predictions of the model. See examples.
#' @param drl.v A function taking arguments (pseudo, v, new.v). It trains a model estimating E(Y^1 - Y^0 | V) by
#' regressing a pseudo-outcome \emph{pseudo} onto v and returns a list of 3 elements: res, model and fit.
#' \emph{res} is a vector of predictions of the model evaluated at new.v,
#' \emph{model} is the model object used to estimate E(Y^1 - Y^0 | V) (after possible model selection)
#' and \emph{fit} is a function with argument new.v that returns the predictions of the model. See examples.
#' @param drl.x A function taking arguments (pseudo, x, new.x). It trains a model estimating E(Y^1 - Y^0 | X) by
#' regressing a pseudo-outcome \emph{pseudo} onto x and returns a list of 3 elements: res, model and fit.
#' \emph{res} is a vector of predictions of the model evaluated at new.x,
#' \emph{model} is the model object used to estimate E(Y^1 - Y^0 | X) (after possible model selection)
#' and \emph{fit} is a function with argument new.x that returns the predictions of the model. See examples.
#' @param nsplits An integer indicating the number of splits used for cross-validation. Ignored if foldid is specified.
#' @param foldid An optional vector specifying fold assignments for cross-validation.
#' @param univariate_reg A logical indicating whether to perform univariate regression for estimating the CATE
#' as a function of each effect modifier separately (default: FALSE).
#' @param partial_dependence A logical indicating whether to compute partial dependence estimates of the CATE (default: FALSE).
#' @param partially_linear A logical indicating whether to compute partially linear approximations
#' via Robinson's transformation (default: FALSE).
#' @param additive_approx A logical indicating whether to compute the CATE assuming an additive structure (default: FALSE).
#' @param variable_importance A logical indicating whether to compute variable importance measures (default: FALSE).
#' @param vimp_num_splits An integer specifying the number of splits for variable importance computation (default: 1).
#' @param bw.stage2 A list of length equal to the number of effect modifiers considered, where each element is a vector of
#' candidate bandwidths for the second-stage regression of the pseudo-outcome onto the effect modifier,
#' used to estimate either the univariate CATE or the partial dependence measure (default: NULL).
#' It must be provided if \emph{univariate_reg} or \emph{partial_dependence} is set to TRUE.
#' @param sample.split.cond.dens A logical indicating whether to do sample-splitting for conditional density estimation
#' (default: FALSE).
#' @param cond.dens A list of functions, one per effect modifier. Element j takes arguments (v1, v2),
#' trains a model for the conditional density of the effect modifier \emph{v1} given the remaining
#' effect modifiers \emph{v2}, and returns a list containing a function \emph{predict.cond.dens}
#' with arguments (v1, v2, new.v1, new.v2) that returns the estimated conditional density values.
#' Required if \emph{partial_dependence} is TRUE.
#' @param cate.w A list of functions, one per effect modifier. Element j takes arguments (tau, w, new.w),
#' regresses the CATE estimate \emph{tau} onto w = (v_j, v_{-j}), and returns a list containing a
#' function \emph{fit} with argument new.w that returns the predictions of the model.
#' Required if \emph{partial_dependence} is TRUE.
#' @param cate.not.j A list of functions, one per effect modifier, used in the partially linear
#' (Robinson) step to regress the pseudo-outcome onto the effect modifiers other than v_j.
#' Required if \emph{partially_linear} is TRUE.
#' @param reg.basis.not.j A list of functions, one per effect modifier, used in the partially linear
#' (Robinson) step to regress basis expansions of v_j onto the remaining effect modifiers.
#' Required if \emph{partially_linear} is TRUE.
#' @param pl.dfs A list of length equal to the number of effect modifiers considered, where each element is a vector of
#' candidate numbers of basis elements for the partially linear approximation computed via Robinson's transformation.
#' Required if \emph{partially_linear} is TRUE.
#' @return A list containing, for each learner, the estimated CATE at v0 (averaged across folds and per fold),
#' the pseudo-outcomes, and, when requested, the univariate regression, partial dependence, additive,
#' partially linear (Robinson) and variable importance results.
#' @export
#' @references Kennedy, EH. (2020). Optimal Doubly Robust Estimation of
#' Heterogeneous Causal Effects. \emph{arXiv preprint arXiv:2004.14497}.
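#' @examples
#' \dontrun{
#' ## A minimal sketch of the learner interface (not run). It assumes a data
#' ## frame `df` with outcome y, binary treatment a, and confounders x1 and x2,
#' ## where x1 is also the effect modifier of interest. The glm/lm learners
#' ## below are illustrative placeholders for arbitrary regression methods.
#' mu1.x <- function(y, a, x, new.x) {
#'   dat <- data.frame(y = y[a == 1], x[a == 1, , drop = FALSE])
#'   model <- stats::glm(y ~ ., data = dat)
#'   fit <- function(new.x) {
#'     as.numeric(stats::predict(model, newdata = as.data.frame(new.x),
#'                               type = "response"))
#'   }
#'   list(res = fit(new.x), model = model, fit = fit)
#' }
#' mu0.x <- function(y, a, x, new.x) {
#'   dat <- data.frame(y = y[a == 0], x[a == 0, , drop = FALSE])
#'   model <- stats::glm(y ~ ., data = dat)
#'   fit <- function(new.x) {
#'     as.numeric(stats::predict(model, newdata = as.data.frame(new.x),
#'                               type = "response"))
#'   }
#'   list(res = fit(new.x), model = model, fit = fit)
#' }
#' pi.x <- function(a, x, new.x) {
#'   dat <- data.frame(a = a, x)
#'   model <- stats::glm(a ~ ., data = dat, family = stats::binomial())
#'   fit <- function(new.x) {
#'     as.numeric(stats::predict(model, newdata = as.data.frame(new.x),
#'                               type = "response"))
#'   }
#'   list(res = fit(new.x), model = model, fit = fit)
#' }
#' ## The second-stage regression returns a matrix with three columns per
#' ## evaluation point (estimate and pointwise confidence limits).
#' drl.v <- function(pseudo, v, new.v) {
#'   dat <- data.frame(pseudo = pseudo, v)
#'   model <- stats::lm(pseudo ~ ., data = dat)
#'   fit <- function(new.v) {
#'     unname(as.matrix(stats::predict(model, newdata = as.data.frame(new.v),
#'                                     interval = "confidence")))
#'   }
#'   list(res = fit(new.v), model = model, fit = fit)
#' }
#' drl.x <- function(pseudo, x, new.x) drl.v(pseudo = pseudo, v = x, new.v = new.x)
#' res <- cate(data = df, learner = "dr",
#'             x_names = c("x1", "x2"), y_name = "y", a_name = "a",
#'             v_names = "x1",
#'             v0 = matrix(seq(-2, 2, by = 0.5), ncol = 1,
#'                         dimnames = list(NULL, "x1")),
#'             mu1.x = mu1.x, mu0.x = mu0.x, pi.x = pi.x,
#'             drl.v = drl.v, drl.x = drl.x, nsplits = 2)
#' }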
cate <- function(data, learner, x_names, y_name, a_name, v_names, v0,
mu1.x, mu0.x, pi.x, drl.v, drl.x,
nsplits=5,
foldid=NULL,
univariate_reg=FALSE,
partial_dependence=FALSE,
partially_linear=FALSE,
additive_approx=FALSE,
variable_importance=FALSE,
vimp_num_splits=1,
bw.stage2=NULL,
sample.split.cond.dens=FALSE,
cond.dens=NULL,
cate.w=NULL,
cate.not.j=NULL,
reg.basis.not.j=NULL,
pl.dfs=NULL) {
if(any(learner != "dr")) stop("Only learner = dr is currently implemented.")
dta <- get_input(data=data, x_names=x_names, y_name=y_name,
a_name=a_name, v_names=v_names, v0=v0)
a <- dta$a
v <- dta$v
v0.long <- dta$v0
v0.short <- dta$unique.v0
y <- dta$y
x <- dta$x
n <- length(y)
n.eval.pts <- nrow(v0.long)
n.eff.modif <- ncol(v)
if(is.null(foldid)) {
s <- sample(rep(1:nsplits, ceiling(n/nsplits))[1:n])
} else {
s <- foldid
nsplits <- length(unique(foldid))
}
est <- est.pi <- replicate(length(learner),
array(NA, dim=c(n.eval.pts, 3, nsplits)),
simplify=FALSE)
univariate_res <- pd_res <- additive_res <- robinson_res <-
replicate(length(learner), vector("list", n.eff.modif), simplify=FALSE)
cate.w.fit <- stage2.reg.data.pd <-
replicate(n.eff.modif, vector("list", nsplits), simplify=FALSE)
pseudo.y <- replicate(length(learner), rep(NA, n), simplify=FALSE)
pseudo.y.tr <- ites.x.tr <-
replicate(length(learner), vector("list", nsplits), simplify=FALSE)
pseudo.y.pd <- theta.bar <- cond.dens.vals <- cate.w.vals <-
replicate(length(learner), matrix(NA, ncol=n.eff.modif, nrow=n),
simplify=FALSE)
ites_v <- ites_x <- replicate(length(learner),
matrix(NA, ncol=3, nrow=n), simplify=FALSE)
names(est) <- names(est.pi) <- names(pseudo.y) <- names(ites_v) <-
names(ites_x) <- names(pseudo.y.pd) <- names(theta.bar) <-
names(cond.dens.vals) <- names(cate.w.vals) <- names(pseudo.y.tr) <- learner
stage2.reg.data.v <- stage2.reg.data.x <- reg.model <-
vector("list", nsplits)
tmp <- tryCatch(
{
for(k in 1:nsplits) {
print(paste0("Considering split # ", k, " out of ", nsplits))
test.idx <- k == s
train.idx <- k != s
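## With a single fold (or a degenerate foldid), fall back to training on the evaluation fold itself.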
if(all(!train.idx)) train.idx <- test.idx
n.te <- sum(test.idx)
n.tr <- sum(train.idx)
x.tr <- x[train.idx, , drop=FALSE]
v.tr <- v[train.idx, , drop=FALSE]
a.tr <- a[train.idx]
y.tr <- y[train.idx]
x.te <- x[test.idx, , drop=FALSE]
v.te <- v[test.idx, , drop=FALSE]
a.te <- a[test.idx]
y.te <- y[test.idx]
## Estimate nuisance functions using all folds but k and predict on fold k ##
pihat.vals <- pi.x(a=a.tr, x=x.tr, new.x=rbind(x.te, x.tr))$res
pihat.te <- pihat.vals[1:n.te]
pihat.tr <- pihat.vals[-c(1:n.te)]
mu0hat.vals <- mu0.x(y=y.tr, a=a.tr, x=x.tr, new.x=rbind(x.te, x.tr))$res
mu0hat.te <- mu0hat.vals[1:n.te]
mu0hat.tr <- mu0hat.vals[-c(1:n.te)]
mu1hat.vals <- mu1.x(y=y.tr, a=a.tr, x=x.tr, new.x=rbind(x.te, x.tr))$res
mu1hat.te <- mu1hat.vals[1:n.te]
mu1hat.tr <- mu1hat.vals[-c(1:n.te)]
cate.tr <- mu1hat.tr-mu0hat.tr
cate.te <- mu1hat.te-mu0hat.te
for(alg in learner) {
## compute IF values, i.e., the pseudo-outcomes for the DR-Learner ##
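## pseudo = (A - pihat) / {pihat * (1 - pihat)} * (Y - A*mu1hat - (1-A)*mu0hat) + (mu1hat - mu0hat);
## its conditional mean given V equals the CATE when the nuisance estimates are consistent.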
pseudo.te <- (a.te-pihat.te)/(pihat.te*(1-pihat.te)) *
(y.te-a.te*mu1hat.te - (1-a.te)*mu0hat.te) + cate.te
pseudo.tr <- (a.tr-pihat.tr)/(pihat.tr*(1-pihat.tr)) *
(y.tr-a.tr*mu1hat.tr - (1-a.tr)*mu0hat.tr) + cate.tr
drl.v.out <- drl.v(pseudo=pseudo.te, v=v.te, new.v=rbind(v0.long, v.te))
drl.v.out.pi <- drl.v(pseudo=cate.te, v=v.te, new.v=rbind(v0.long, v.te)) # plug-in
stage2.reg.data.v[[k]] <- cbind(data.frame(pseudo=pseudo.te,
mu1hat=mu1hat.te,
mu0hat=mu0hat.te,
pihat=pihat.te,
y=y.te,
a=a.te,
fold.id=k), v.te)
stage2.reg.data.x[[k]] <- cbind(data.frame(pseudo=pseudo.te,
mu1hat=mu1hat.te,
mu0hat=mu0hat.te,
pihat=pihat.te,
y=y.te,
a=a.te,
fold.id=k), x.te)
# drl.form[[k]] <- drl.res$drl.form
reg.model[[k]] <- drl.v.out$model
drl.v.res <- drl.v.out$res
drl.v.res.pi <- drl.v.out.pi$res
drl.x.res <- drl.x(pseudo=pseudo.tr, x=x.tr, new.x=rbind(x.te, x.tr))$res
est[[alg]][, , k] <- drl.v.res[1:n.eval.pts, ]
est.pi[[alg]][, , k] <- drl.v.res.pi[1:n.eval.pts, ]
pseudo.y[[alg]][test.idx] <- pseudo.te
pseudo.y.tr[[alg]][[k]] <- pseudo.tr
ites.x.tr[[alg]][[k]] <- drl.x.res[-c(1:n.te), 1]
ites_v[[alg]][test.idx, ] <- drl.v.res[-c(1:n.eval.pts), ]
ites_x[[alg]][test.idx, ] <- drl.x.res[1:n.te, ]
if(partial_dependence) {
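## Partial dependence: for each modifier j, regress the CATE estimate on
## W = (V_j, V_{-j}) and average the fit over the empirical distribution of
## V_{-j} to obtain theta.bar(v_j) = E[tauhat(v_j, V_{-j})].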
for(j in 1:ncol(v)) {
v1.j.tr <- v.tr[, j]
v1.j.te <- v.te[, j]
not.v1.j.tr <- v.tr[, -j, drop=FALSE]
not.v1.j.te <- v.te[, -j, drop=FALSE]
w.tr <- cbind(v1j=v1.j.tr, not.v1.j.tr)
w.te <- cbind(v1j=v1.j.te, not.v1.j.te)
if(sample.split.cond.dens){
cond.dens.fit <- cond.dens[[j]](v1=v1.j.tr, v2=not.v1.j.tr)
cond.dens.vals.te <-
cond.dens.fit$predict.cond.dens(v1=v1.j.tr, v2=not.v1.j.tr,
new.v1=v1.j.te, new.v2=not.v1.j.te)
if(sum(cond.dens.vals.te < 0.001) > 0) {
warning(paste0("Effect modifier # ", j, ". There are ",
sum(cond.dens.vals.te < 0.001),
" conditional density values < 0.001. They will ",
"be truncated at 0.001."))
cond.dens.vals.te[cond.dens.vals.te < 0.001] <- 0.001
}
cond.dens.vals[[alg]][test.idx, j] <- cond.dens.vals.te
}
cate.w.fit[[j]][[k]] <- cate.w[[j]](tau=cate.tr, w=w.tr, new.w=w.tr)
cate.w.te <- cate.w.fit[[j]][[k]]$fit(new.w=w.te)
cate.w.vals[[alg]][test.idx, j] <- cate.w.te
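## Two ways to average over V_{-j}: for large test folds, evaluate the CATE
## fit on a grid of 100 points (or the factor levels) and interpolate;
## otherwise average over all n.te x n.te (v_j, V_{-j}) pairs exactly.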
if(n.te > 1000) {
if(is.factor(v1.j.te)) {
v1.j.seq <- factor(levels(v1.j.te), levels=levels(v1.j.te))
}
else {
v1.j.seq <- seq(min(v1.j.te), max(v1.j.te), length.out=100)
}
tmp.cate.w.fit.fn <- Vectorize(function(u) {
mean(cate.w.fit[[j]][[k]]$fit(new.w=cbind(v1j=u, not.v1.j.te)))
}, vectorize.args = "u")
cate.w.avg.vals <- tmp.cate.w.fit.fn(v1.j.seq)
if(is.factor(v1.j.te)) {
theta.bar.vals <- rep(NA, length(v1.j.te))
for(u in levels(v1.j.te)) {
theta.bar.vals[v1.j.te==u] <- cate.w.avg.vals[v1.j.seq==u]
}
} else {
theta.bar.vals <- approx(x=v1.j.seq, y=cate.w.avg.vals,
xout=v1.j.te, rule=2)$y
}
}
else {
w.long.test <- cbind(v1j=rep(v1.j.te, each=n.te),
not.v1.j.te[rep(1:n.te, n.te), , drop=FALSE])
if(sample.split.cond.dens) {
cond.dens.preds <-
cond.dens.fit$predict.cond.dens(v1=v1.j.tr, v2=not.v1.j.tr,
new.v1=w.long.test[, 1],
new.v2=w.long.test[, -1, drop=FALSE])
marg.dens <- colMeans(matrix(cond.dens.preds, ncol=n.te, nrow=n.te))
}
cate.preds <- cate.w.fit[[j]][[k]]$fit(new.w=w.long.test)
theta.bar.vals <- colMeans(matrix(cate.preds, nrow=n.te, ncol=n.te))
}
theta.bar[[alg]][test.idx, j] <- theta.bar.vals
data.pd <- data.frame(pseudo.cate=pseudo.te,
mu1hat=mu1hat.te,
mu0hat=mu0hat.te,
pihat=pihat.te,
tauhat.w=cate.w.te,
theta.bar=theta.bar.vals,
y=y.te,
a=a.te,
fold.id=k)
stage2.reg.data.pd[[j]][[k]] <- cbind(data.pd, v.te)
}
}
}
}
print("Done with fitting nuisance functions.")
},
error = function(cond) {
message(conditionMessage(cond))
}
)
if(is.null(tmp)){
warning("Encountered error while fitting nuisance functions")
univ.res <- pd.res <- add.res <- rob.res <-
data.frame(theta=rep(NA, 10), theta.debias=rep(NA, 10))
return(list(univariate_res=list(dr=list(list(res=univ.res))),
pd_res=list(dr=list(list(res=pd.res))),
additive_res=list(dr=list(list(res=add.res))),
robinson_res=list(dr=list(list(res=rob.res)))))
}
for(alg in learner) {
if(alg != "dr") stop("Only learner = dr is currently implemented.")
if(additive_approx) {
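## Additive approximation: regress the pooled pseudo-outcomes on an additive
## basis in all modifiers; keep the terms object to rebuild design matrices.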
additive_model <- drl.basis.additive(y=pseudo.y[[alg]], x=v, new.x=v)
tt <- delete.response(terms(additive_model$model))
}
if(univariate_reg | partial_dependence | additive_approx | partially_linear) {
for(j in 1:ncol(v)){
vj <- v[, j]
is.var.factor <- paste0(class(vj), collapse=" ") %in% c("factor", "ordered factor")
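## Partially linear approximation via Robinson's transformation: residualize
## both the pseudo-outcome and basis expansions of V_j on the remaining
## modifiers, then regress residual on residual.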
if(partially_linear) {
j.robinson <- robinson(pseudo=pseudo.y[[alg]],
w=v[, -j, drop=FALSE],
v=v[, j],
new.v=v0.short[[j]],
s=s,
cate.not.j=cate.not.j[[j]],
reg.basis.not.j=reg.basis.not.j[[j]],
dfs=pl.dfs[[j]])
rob.data <- data.frame(eval.pts=v0.short[[j]],
theta=j.robinson$res$preds,
ci.ll.pts=j.robinson$res$ci.ll,
ci.ul.pts=j.robinson$res$ci.uu,
ci.ll.unif=NA,
ci.ul.unif=NA)
robinson_res[[alg]][[j]] <- list(res=rob.data,
model=j.robinson$model,
risk=j.robinson$risk,
fits=j.robinson$fits)
}
if(partial_dependence) {
if(!sample.split.cond.dens) {
cond.dens.fit <- cond.dens[[j]](v1=v[, j], v2=v[, -j, drop=FALSE])
cond.dens.vals[[alg]][, j] <-
cond.dens.fit$predict.cond.dens(v1=v[, j], v2=v[, -j, drop=FALSE],
new.v1=v[, j],
new.v2=v[, -j, drop=FALSE])
if(sum(cond.dens.vals[[alg]][, j] < 0.001) > 0) {
warning(paste0("Effect modifier # ", j, ". There are ",
sum(cond.dens.vals[[alg]][, j] < 0.001),
" conditional density values < 0.001. They will ",
"truncated at 0.001."))
cond.dens.vals[[alg]][cond.dens.vals[[alg]][, j] < 0.001, j] <- 0.001
}
}
if(length(unique(vj)) < 15 | is.var.factor) {
marg.dens <- rep(NA, length(vj))
for(u in unique(vj)) marg.dens[vj==u] <- mean(vj==u)
} else {
marg.dens <- ks::kde(x=vj, eval.points=vj, density=TRUE)$estimate
}
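## One-step correction for the partial dependence curve:
## pseudo.pd = {pseudo - tauhat(W)} * f(V_j) / f(V_j | V_{-j}) + theta.bar(V_j).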
ghat <- marg.dens/cond.dens.vals[[alg]][, j]
pseudo.y.pd[[alg]][, j] <-
(pseudo.y[[alg]]-cate.w.vals[[alg]][, j])*ghat + theta.bar[[alg]][, j]
}
if(length(unique(vj)) < 15 | is.var.factor) {
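## Discrete modifier: theta(u) is estimated by a subgroup mean of the
## pseudo-outcome, with Wald intervals from the empirical variance of the
## influence-function values.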
if(univariate_reg) {
res.empVar <- NULL
for(ll in 1:length(unique(vj))) {
pts.vj <- unique(vj)[ll]
if.vals <- pseudo.y[[alg]]*I(vj==pts.vj) / mean(I(vj==pts.vj))
tmp <- data.frame(eval.pts=pts.vj,
theta=mean(if.vals),
ci.ll.pts=mean(if.vals) - 1.96*sqrt(var(if.vals)/n),
ci.ul.pts=mean(if.vals) + 1.96*sqrt(var(if.vals)/n),
ci.ul.unif=NA,
ci.ll.unif=NA)
res.empVar <- rbind(res.empVar, tmp)
}
univariate_res[[alg]][[j]] <-
list(data=data.frame(pseudo=pseudo.y[[alg]], exposure=vj),
res=lm.discrete.v(y=pseudo.y[[alg]], x=vj, new.x=unique(vj)),
res.empVar=res.empVar)
}
if(partial_dependence) {
res.empVar <- NULL
for(ll in 1:length(unique(vj))) {
pts.vj <- unique(vj)[ll]
if.vals <- (pseudo.y[[alg]]-cate.w.vals[[alg]][, j]) *
I(vj==pts.vj)/cond.dens.vals[[alg]][, j] +
mean(theta.bar[[alg]][I(vj==pts.vj), j])
tmp <- data.frame(eval.pts=pts.vj,
theta=mean(if.vals),
ci.ll.pts=mean(if.vals) - 1.96*sqrt(var(if.vals)/n),
ci.ul.pts=mean(if.vals) + 1.96*sqrt(var(if.vals)/n),
ci.ul.unif=NA,
ci.ll.unif=NA)
res.empVar <- rbind(res.empVar, tmp)
}
pd_res[[alg]][[j]] <-
list(data=data.frame(pseudo=pseudo.y.pd[[alg]][, j], exposure=vj),
stage2.reg.data.pd=stage2.reg.data.pd,
res=lm.discrete.v(y=pseudo.y.pd[[alg]][, j], x=vj, new.x=unique(vj)),
res.empVar=res.empVar)
}
}
else {
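## Continuous modifier: kernel-based local regression of the pseudo-outcome
## (Gaussian kernel, LOOCV bandwidth selection), with and without debiasing.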
if(univariate_reg) {
univ.inf <- debiased_inference(A=vj,
pseudo.out=pseudo.y[[alg]],
eval.pts=v0.short[[j]],
debias=FALSE,
bandwidth.method="LOOCV",
kernel.type="gau",
bw.seq=bw.stage2[[j]])
univ.debias.inf <- debiased_inference(A=vj,
pseudo.out=pseudo.y[[alg]],
eval.pts=v0.short[[j]],
debias=TRUE,
bandwidth.method="LOOCV",
kernel.type="gau",
bw.seq=bw.stage2[[j]])
univ.res <- data.frame(eval.pts=univ.inf$res$eval.pts,
theta=univ.inf$res$theta,
theta.debias=univ.debias.inf$res$theta,
ci.ul.pts=univ.inf$res$ci.ul.pts,
ci.ll.pts=univ.inf$res$ci.ll.pts,
ci.ul.pts.debias=univ.debias.inf$res$ci.ul.pts,
ci.ll.pts.debias=univ.debias.inf$res$ci.ll.pts,
ci.ul.unif=univ.inf$res$ci.ul.unif,
ci.ll.unif=univ.inf$res$ci.ll.unif,
ci.ul.unif.debias=univ.debias.inf$res$ci.ul.unif,
ci.ll.unif.debias=univ.debias.inf$res$ci.ll.unif,
bias=univ.inf$res$theta-univ.debias.inf$res$theta,
if.val.sd=univ.inf$res$if.val.sd,
if.val.sd.debias=univ.debias.inf$res$if.val.sd,
unif.quantile=univ.inf$res$unif.quantile,
unif.quantile.debias=univ.debias.inf$res$unif.quantile,
h=univ.inf$res$h,
b=univ.inf$res$b,
h.debias=univ.debias.inf$res$h,
b.debias=univ.debias.inf$res$b)
univariate_res[[alg]][[j]] <-
list(data=data.frame(pseudo=pseudo.y[[alg]], exposure=vj),
res=univ.res,
risk=list(risk=univ.inf$risk, risk.debias=univ.debias.inf$risk),
res.list=univ.inf$res.list,
res.list.debias=univ.debias.inf$res.list)
}
if(partial_dependence) {
muhat.vals <- .get.muhat(splits.id=s, cate.w.fit=cate.w.fit[[j]],
v1=vj, v2=v[, -j, drop=FALSE],
max.n.integral=1000)
pd.inf <- debiased_inference(A=vj, debias=FALSE,
pseudo.out=pseudo.y.pd[[alg]][, j],
eval.pts=v0.short[[j]],
mhat.obs=theta.bar[[alg]][, j],
muhat.vals=muhat.vals,
bandwidth.method="LOOCV",
kernel.type="gau",
bw.seq=bw.stage2[[j]])
pd.debias.inf <- debiased_inference(A=vj, debias=TRUE,
pseudo.out=pseudo.y.pd[[alg]][, j],
eval.pts=v0.short[[j]],
mhat.obs=theta.bar[[alg]][, j],
muhat.vals=muhat.vals,
bandwidth.method="LOOCV",
kernel.type="gau",
bw.seq=bw.stage2[[j]])
pd.inf.res <- data.frame(eval.pts=pd.inf$res$eval.pts,
theta=pd.inf$res$theta,
theta.debias=pd.debias.inf$res$theta,
ci.ul.pts=pd.inf$res$ci.ul.pts,
ci.ll.pts=pd.inf$res$ci.ll.pts,
ci.ul.pts.debias=pd.debias.inf$res$ci.ul.pts,
ci.ll.pts.debias=pd.debias.inf$res$ci.ll.pts,
ci.ul.unif=pd.inf$res$ci.ul.unif,
ci.ll.unif=pd.inf$res$ci.ll.unif,
ci.ul.unif.debias=pd.debias.inf$res$ci.ul.unif,
ci.ll.unif.debias=pd.debias.inf$res$ci.ll.unif,
bias=pd.inf$res$theta-pd.debias.inf$res$theta,
if.val.sd=pd.inf$res$if.val.sd,
if.val.sd.debias=pd.debias.inf$res$if.val.sd,
unif.quantile=pd.inf$res$unif.quantile,
unif.quantile.debias=pd.debias.inf$res$unif.quantile,
h=pd.inf$res$h,
b=pd.inf$res$b,
h.debias=pd.debias.inf$res$h,
b.debias=pd.debias.inf$res$b)
pd_res[[alg]][[j]] <-
list(data=data.frame(pseudo=pseudo.y.pd[[alg]][, j],
cond.dens.vals=cond.dens.vals[[alg]][, j],
exposure=vj),
res=pd.inf.res,
risk=list(risk=pd.inf$risk, risk.debias=pd.debias.inf$risk),
res.list=pd.inf$res.list,
res.list.debias=pd.debias.inf$res.list)
}
}
if(additive_approx){
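## Predict the additive fit along v0 for modifier j, holding the other
## modifiers at a baseline level; a sandwich variance gives pointwise CIs.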
new.dat.additive <- as.data.frame(matrix(0, nrow=length(v0.short[[j]]),
ncol=ncol(v),
dimnames=list(NULL, colnames(v))))
for(l in 1:ncol(v)) {
if(l == j) {
new.dat.additive[, l] <- v0.short[[j]]
} else {
if(is.factor(v[, l])) {
new.dat.additive[, l] <- factor(levels(v[, l])[1], levels=levels(v[, l]))
} else {
new.dat.additive[, l] <- min(v[, l])
}
}
}
preds.j.additive <- predict.lm(additive_model$model, newdata=new.dat.additive)
m <- model.frame(tt, new.dat.additive)
design.mat <- model.matrix(tt, m)
design.mat[, apply(design.mat, 2, function(u) length(unique(u))==1)] <- 0
preds.j.additive <- design.mat %*% coef(additive_model$model)
beta.vcov <- sandwich::vcovHC(additive_model$model)
sigma2hat <- diag(design.mat %*% beta.vcov %*% t(design.mat))
ci.l <- preds.j.additive-1.96*sqrt(sigma2hat)
ci.u <- preds.j.additive+1.96*sqrt(sigma2hat)
additive_res[[alg]][[j]] <- list(res=data.frame(eval.pts=v0.short[[j]],
theta=preds.j.additive,
ci.ll.pts=ci.l,
ci.ul.pts=ci.u,
ci.ll.unif=NA,
ci.ul.unif=NA),
drl.form=additive_model$drl.form,
model=additive_model$model,
risk=additive_model$risk,
res.list=additive_model$fits)
}
}
}
if(variable_importance){
print("Start computing variable importance.")
tau_hat <- ites_x[[alg]][, 1]
pseudo_hat <- pseudo.y[[alg]]
vimp_df <- get_VIMP(tau_hat, pseudo_hat, x, y, a, v_names,
vimp_num_splits=vimp_num_splits)
draw_VIMP(vimp_df)
} else{
vimp_df <- data.frame(matrix(ncol = 4, nrow = length(v_names)))
}
}
out <- lapply(learner, function(w) apply(est[[w]], c(1, 2), mean))
names(out) <- learner
cate.v.res <- list(est=out,
fold.est=est,
fold.est.pi=est.pi,
pseudo=pseudo.y,
pseudo.tr=pseudo.y.tr,
cate.v.sample=ites_v,
v0.long=v0.long,
v0.short=v0.short,
stage2.reg.data.v=stage2.reg.data.v,
stage2.reg.model=reg.model)
cate.x.res <- list(pseudo=pseudo.y,
pseudo.tr=pseudo.y.tr,
cate.x.sample=ites_x,
cate.x.sample.tr=ites.x.tr,
stage2.reg.data.x=stage2.reg.data.x)
ret <- list(cate.v.res=cate.v.res,
cate.x.res=cate.x.res,
univariate.res=univariate_res,
pd.res=pd_res,
additive.res=additive_res,
robinson.res=robinson_res,
vimp.df=vimp_df,
v0.long=v0.long, v0.short=v0.short,
foldid=s,
x=x, y=y, a=a, v=v,
drl.x=drl.x, drl.v=drl.v)
return(ret)
}