Nothing
#' @importFrom geex m_estimate
#' @import caret
G_computation <- R6::R6Class(
"G_computation",
inherit = TEstimator,
#-------------------------public fields-----------------------------#
public = list(
po.est = list(
y1.hat = NULL,
y0.hat = NULL
),
po.est.var = list(
y1.hat.rep = NULL,
y0.hat.rep = NULL
),
resi = NULL,
ps.est = NULL,
id = "G_computation",
initialize = function(df, vars_name, name,
gc.method, gc.formula,
var_approach = "Bias_adjusted", isTrial, ...) {
#browser()
super$initialize(df, vars_name, name)
private$gc.method <- gc.method
private$gc.formula <- gc.formula
private$var_approach <- var_approach
self$model <- private$fit(...)
po.est <- private$est_potentialOutcomes()
self$data$y1.hat <- po.est$y1.hat
self$data$y0.hat <- po.est$y0.hat
self$resi <- private$est_residual()
private$set_ATE()
private$set_CATE(private$outcome_predictors,TRUE)
private$isTrial <- isTrial
self$id <- paste(self$id, private$gc.method, sep = "/")
},
diagnosis_t_ignorability = function(stratification, stratification_joint=TRUE){
#browser()
if(missing(stratification)){
stratification <- private$outcome_predictors
}
residuals.overall <- mean(self$resi)
residuals.subgroups <- private$aggregate_residual(stratification)
if(test_binary(self$data[,private$outcome_name])){
colnames.subgroups <- colnames(residuals.subgroups)
var_names <- colnames.subgroups[!colnames.subgroups %in% c("sample.size","res.mean", "res.se")]
var_names_data <- residuals.subgroups[,var_names]
subgroup_name_level <- apply(var_names_data, 1, function(x) paste(var_names, x, sep = "=", collapse = ","))
subgroup_name_level <- factor(subgroup_name_level, levels = subgroup_name_level, ordered = T)
df <- residuals.subgroups %>%
select(res.mean, res.se, sample.size) %>%
mutate(group=subgroup_name_level,
ci_l=res.mean-1.98*res.se,
ci_u=res.mean+1.98*res.se)
plot.res <- ggplot2::ggplot(data = df, aes(x = res.mean, y = group)) +
geom_point(position = position_dodge(0.5), aes(size=sample.size)) +
geom_errorbar(aes(xmin = ci_l, xmax = ci_u),
width = .3, position = position_dodge(0.5)) +
geom_vline(xintercept = 0, color = "black", linetype = "dashed", alpha = .5) +
ggtitle("model fit: mean of squared residual (1.98se)")+
theme(plot.title = element_text(),
legend.position = "none")
out <- list(residuals.overall = residuals.overall,
residuals.subgroups = residuals.subgroups,
plot.res = plot.res)
} else {
colnames.subgroups <- colnames(residuals.subgroups)
var_names <- colnames.subgroups[!colnames.subgroups %in% c("sample.size","res.mean", "res.se","msr","sesr")]
var_names_data <- residuals.subgroups[,var_names]
subgroup_name_level <- apply(var_names_data, 1, function(x) paste(var_names, x, sep = "=", collapse = ","))
subgroup_name_level <- factor(subgroup_name_level, levels = subgroup_name_level, ordered = T)
df <- residuals.subgroups %>%
select(res.mean, res.se, sample.size) %>%
mutate(group=subgroup_name_level,
ci_l=res.mean-1.98*res.se,
ci_u=res.mean+1.98*res.se)
plot.res.subgroup <- ggplot2::ggplot(data = df, aes(x = res.mean, y = group)) +
geom_point(position = position_dodge(0.5), aes(size=sample.size)) +
geom_errorbar(aes(xmin = ci_l, xmax = ci_u),
width = .3, position = position_dodge(0.5)) +
geom_vline(xintercept = 0, color = "black", linetype = "dashed", alpha = .5) +
ggtitle("mean(1.98se) residual")+
theme(plot.title = element_text(),
legend.position = "none")
df <- residuals.subgroups %>%
select(msr, sesr, sample.size) %>%
mutate(group=subgroup_name_level,
ci_l=msr-1.98*sesr,
ci_u=msr+1.98*sesr)
plot.msr.subgroup <- ggplot2::ggplot(data = df, aes(x = msr, y = group)) +
geom_point(position = position_dodge(0.5), aes(size=sample.size)) +
geom_errorbar(aes(xmin = ci_l, xmax = ci_u),
width = .3, position = position_dodge(0.5)) +
geom_vline(xintercept = 0, color = "black", linetype = "dashed", alpha = .5) +
ggtitle("mean squared error")+
theme(plot.title = element_text(),
legend.position = "none")
df <- data.frame(residual=self$resi)
plot.res.overall <- ggplot(data = df, aes(x=residual)) + geom_density()
plot.res.mse <- ggpubr::ggarrange(plot.res.subgroup, plot.res.overall, plot.msr.subgroup, ncol = 3, nrow = 1)
out <- list(residuals.overall = residuals.overall,
residuals.subgroups = residuals.subgroups,
plot.res.mse = plot.res.mse)
}
out
}
),
#-------------------------private fields and methods----------------------------#
private = list(
gc.method = NULL,
gc.formula = NULL,
#treatment_method = "glm",
var_approach = NULL,
iterations = 1,
fit = function(...) {
if (private$gc.method == "BART"){
x.train <- as.matrix(self$data[, c(private$outcome_predictors, private$treatment_name)])
y.train <- as.matrix(self$data[, private$outcome_name])
model <- BART::wbart(
x.train = x.train,
y.train = y.train,
...
)
} else {
if (is.null(private$gc.formula)) {
model <- caret::train(
x = self$data[, c(private$outcome_predictors, private$treatment_name)],
y = self$data[, private$outcome_name],
method = private$gc.method,
...
)
} else {
print(private$gc.formula)
model <- caret::train(
form = private$gc.formula,
data = self$data,
method = private$gc.method,
...
)
}
}
return(model)
},
est_ATE_SE = function(index) {
#browser()
y1.hat <- self$data$y1.hat[index]
y0.hat <- self$data$y0.hat[index]
data <- as.data.frame(cbind(y1.hat, y0.hat))
colnames(data) <- c("y1.hat","y0.hat")
results <- m_estimate(estFUN = private$gc_estfun,
data = data,
root_control = setup_root_control(start = c(0,0,0)))
y1.hat.mu <- results@estimates[1]
y0.hat.mu <- results@estimates[2]
est <- results@estimates[3]
se <- sqrt(results@vcov[3,3])
return(list(y1.hat = y1.hat.mu, y0.hat = y0.hat.mu, est = est, se = se))
},
est_weighted_ATE_SE = function(index, weight) {
#browser()
y1.hat <- self$data$y1.hat[index]*weight
y0.hat <- self$data$y0.hat[index]*weight
data <- as.data.frame(cbind(y1.hat, y0.hat))
colnames(data) <- c("y1.hat","y0.hat")
results <- m_estimate(estFUN = private$gc_estfun,
data = data,
root_control = setup_root_control(start = c(0,0,0)))
y1.hat.mu <- results@estimates[1]
y0.hat.mu <- results@estimates[2]
est <- results@estimates[3]
se <- sqrt(results@vcov[3,3])
return(list(y1.hat = y1.hat.mu, y0.hat = y0.hat.mu, est = est, se = se))
},
gc_estfun = function(data){
#browser()
Y1 <- data$y1.hat
Y0 <- data$y0.hat
function(theta) {
c(Y1 - theta[1],
Y0 - theta[2],
theta[1]-theta[2] - theta[3]
)
}
},
# compute deviance for continuous outcome/binary outcome
est_residual = function() {
#browser()
if (inherits(self$data[, private$outcome_name],"numeric")) {
resi <- residuals(self$model)
} else {
y <- self$data[, private$outcome_name]
y <- as.numeric(levels(y))[y]
y.hat <- predict(self$model, newdata = self$data, type = "prob")
y.hat.0 <- y.hat[,"0"]
y.hat.1 <- y.hat[,"1"]
resi <- -2 * (y * log(y.hat.1) + (1-y)*log(y.hat.0))
}
return(resi)
},
est_potentialOutcomes = function() {
#browser()
data0 <- data1 <- self$data[, c(private$outcome_predictors, private$treatment_name)]
t.level <- unique(self$data[, private$treatment_name])
level.order <- order(t.level)
data0[, private$treatment_name] <- t.level[match(1, level.order)]
data1[, private$treatment_name] <- t.level[match(2, level.order)]
if (!is.factor(self$data[, private$outcome_name])) {
y1.hat <- predict(self$model, newdata = data1)
y0.hat <- predict(self$model, newdata = data0)
} else {
y1.hat <- predict(self$model, newdata = data1, type = "prob")[, 2]
y0.hat <- predict(self$model, newdata = data0, type = "prob")[, 2]
}
return(list(y1.hat = y1.hat, y0.hat = y0.hat))
},
est_ps = function() {
ps <- predict(self$model$treatment, newdata = self$data, type = "prob")[, 2]
return(ps)
},
aggregate_residual = function(stratification) {
group_data <- self$data %>%
group_by(across(all_of(stratification)))
group_strata <- group_data %>% group_keys()
group_id <- group_data %>% group_indices()
n_groups <- dim(group_strata)[1]
group_sample_size <- group_size(group_data)
if(test_binary(self$data[,private$outcome_name])){
res.mean <- res.se <- sample.size <- NULL
for (i in seq(n_groups)) {
# for binary outcome, how to evaluate model fit?
subgroup.id.in.data <- self$data[group_id == i, "id"]
res.mean[i] <- sum(self$resi[subgroup.id.in.data])/group_sample_size[i]
res.se[i] <- sqrt(var(self$resi[subgroup.id.in.data])/group_sample_size[i])
sample.size[i] <- group_sample_size[i]
}
res <- cbind(group_strata, sample.size, res.mean, res.se)
} else {
res.mean <- res.se <- sample.size <- msr <- sesr <- NULL
for (i in seq(n_groups)) {
subgroup.id.in.data <- self$data[group_id == i, "id"]
res.mean[i] <- sum(self$resi[subgroup.id.in.data])/group_sample_size[i]
res.se[i] <- sqrt(var(self$resi[subgroup.id.in.data])/group_sample_size[i])
# mean of squared residual
msr[i] <- sum((self$resi[subgroup.id.in.data])^2)/group_sample_size[i]
sesr[i] <- sqrt(var((self$resi[subgroup.id.in.data])^2)/group_sample_size[i])
sample.size[i] <- group_sample_size[i]
}
res <- cbind(group_strata, sample.size, res.mean, res.se, msr, sesr)
}
res <- as.data.frame(res)
return(res)
}
)
)
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.