Description Usage Arguments Examples
Estimation of the optimal treatment rule using Super Learner, and of its mean performance using CV-TMLE. To avoid nested cross-validation, it uses split-specific estimates of Q and g to estimate the rule, and split-specific estimates of the rule in CV-TMLE to estimate mean performance.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | opt_tmle(data, Wnodes = grep("^W", names(data), value = TRUE), Anode = "A",
Ynode = "Y", Vnodes = Wnodes, stratifyAY = TRUE,
SL.library = opt_tmle.SL.library, verbose = 3, parallel = FALSE,
perf_tmle = TRUE, perf_dripcw = FALSE, perf_cv = TRUE,
perf_full = FALSE, maximize = TRUE, ...)
## S3 method for class 'opt_tmle'
print(obj)
## S3 method for class 'opt_tmle'
plot(obj)
## S3 method for class 'opt_tmle'
predict(obj, ...)
|
data |
data.frame containing the relevant variables |
Wnodes, |
vector of column names indicating covariates |
Anode, |
column name of treatment |
Ynode, |
column name of outcome |
Vnodes, |
vector of column names to base the treatment on |
stratifyAY, |
logical: should we stratify the cross-validation based on (A,Y) pairs |
verbose, |
integer that controls the verbosity of the output (higher is more verbose) |
parallel, |
logical: should foreach parallelization be used? |
SL.library, |
list of SuperLearner libraries for the various models (the example below supplies elements named Q, g and QaV). See opt_tmle.SL.library for the default. |
#' True conditional mean of the outcome used by the simulation.
#'
#' Averages two logistic terms: one driven by the treatment and W1, and one
#' driven by the product W2 * W3.
#'
#' @param A Treatment value(s). Only the values 2 and 3 change the
#'   treatment term; any other value contributes zero to it. May be a scalar
#'   or a vector with one entry per row of `W`.
#' @param W Matrix or data.frame of covariates. Only the first three columns
#'   are read, so any input with at least three columns is accepted.
#' @return Numeric vector with one value per row of `W`, each strictly
#'   between 0 and 1.
Qbar0 <- function(A, W) {
  W1 <- W[, 1]
  W2 <- W[, 2]
  W3 <- W[, 3]
  # Arm 2 is favoured when W1 < -0.5 and arm 3 when W1 > 0.5; for any other
  # arm the treatment term is zero.
  tx_term <- -5 * (A == 2) * (W1 + 0.5) + 5 * (A == 3) * (W1 - 0.5)
  (1/2) * (plogis(tx_term) + plogis(W2 * W3))
}
#' True treatment mechanism used by the simulation.
#'
#' Builds one logistic score per treatment arm from the first three
#' covariates and passes the score matrix to `normalize_rows()`.
#'
#' @param W Matrix or data.frame of covariates. Only the first three columns
#'   are read, so any input with at least three columns is accepted.
#' @return The value of `normalize_rows()` applied to a matrix with one row
#'   per row of `W` and columns `A1`, `A2`, `A3`. Returned visibly.
g0 <- function(W) {
  # rep(0.5, nrow(W))
  # Multiplier applied to each covariate before the logistic transform.
  scale_factor <- 0.8
  scores <- cbind(
    A1 = plogis(scale_factor * W[, 1]),
    A2 = plogis(scale_factor * W[, 2]),
    A3 = plogis(scale_factor * W[, 3])
  )
  # make sure A sums to 1
  normalize_rows(scores)
}
#' Simulate one data set for the optimal-treatment example.
#'
#' Draws standard-normal covariates, samples a treatment from `g0()`,
#' draws a binary outcome from `Qbar0()`, and records the arm with the
#' largest true conditional mean together with the outcome under that arm.
#'
#' @param n Number of observations.
#' @param p Number of covariates, named `W1` ... `Wp`. `g0()` and `Qbar0()`
#'   index covariate columns by position, so `p` must be large enough for
#'   them.
#' @return data.frame with the covariates, `A` (factor), `Y`, `d0`, `Yd0`,
#'   and two matrix columns: `g0W` (treatment probabilities) and `Q0aW`
#'   (true conditional mean under each treatment value).
gen_data <- function(n = 1000, p = 4) {
  W <- matrix(rnorm(n * p), nrow = n)
  colnames(W) <- paste0("W", seq_len(p))

  # One multinomial draw per row, using that row's treatment probabilities.
  g0W <- g0(W)
  A <- factor(apply(g0W, 1, function(pAi) which(rmultinom(1, 1, pAi) == 1)))
  A_vals <- vals_from_factor(A)

  # The same uniform draw is reused for Y and Yd0 so the two outcomes are
  # coupled.
  u <- runif(n)
  Y <- as.numeric(u < Qbar0(A, W))

  # One column per treatment value; vapply fixes the column length at n.
  Q0aW <- vapply(A_vals, Qbar0, numeric(n), W)
  # d0 is a column index into Q0aW; assumes A_vals is 1, 2, 3 in order so the
  # index equals the treatment value -- TODO confirm against vals_from_factor.
  # Note that max.col() breaks (near-)ties at random by default.
  d0 <- max.col(Q0aW)
  Yd0 <- as.numeric(u < Qbar0(d0, W))

  df <- data.frame(W, A, Y, d0, Yd0)
  # Reuse the probabilities computed above instead of calling g0() again.
  df$g0W <- g0W
  df$Q0aW <- Q0aW
  df
}
# Candidate learner libraries, one character vector per model, keyed by the
# names Q, g and QaV. Presumably Q is the outcome regression, g the treatment
# mechanism and QaV the rule/blip model -- TODO confirm against the package
# documentation.
SL.library <- list(Q = c("SL.glm", "SL.glmem", "SL.glmnet", "SL.glmnetem", "SL.polymars", "SL.step.forward",
"SL.gam", "SL.mean"), g = c("mnSL.glmnet", "mnSL.multinom", "mnSL.mean", "mnSL.polymars"), QaV = c("SL.polymars",
"SL.glm", "SL.glmnet", "SL.step.forward", "SL.gam", "SL.mean"))
# Replace the QaV learner names with whatever sl_to_mv_library() returns for
# them (presumably multivariate-outcome wrappers -- verify in the package).
SL.library$QaV <- sl_to_mv_library(SL.library$QaV)
# SL.library$Q <- sl_to_strat_library(SL.library$Q, 'A')
# Large simulated sample, used further down to evaluate the fitted rules
# against the true outcome regression Qbar0.
testdata <- gen_data(1e+05, 5)
# Smaller simulated sample that the rule is fitted on.
data <- gen_data(1000, 5)
# Node names. These match the defaults shown in the opt_tmle() usage, and the
# call below does not pass them explicitly.
Wnodes <- grep("^W", names(data), value = TRUE)
Anode <- "A"
Ynode <- "Y"
Vnodes <- Wnodes
# NOTE(review): `nodes` is not referenced again in this example, and its Ynode
# ("Ystar") differs from the Ynode set above -- confirm whether it is needed.
nodes <- list(Wnodes = Wnodes, Anode = Anode, Ynode = "Ystar", Vnodes = Vnodes)
# Fit the optimal treatment rule and report how long the fit takes.
system.time({
result <- opt_tmle(data, SL.library = SL.library)
})
# Variable importance from backward_vim(), given the fit, the test sample and
# the true outcome regression. The plot shows each estimate with its
# lower/upper interval in black and the `test` value in red, one panel per
# metric.
# NOTE(review): ggplot() is used without a visible library(ggplot2) call --
# presumably the package attaches it; confirm.
vimresult <- backward_vim(result, testdata, Qbar0)
ggplot(vimresult$vimdf, aes(y = Vnode, x = est, xmin = lower, xmax = upper)) + geom_point() + geom_point(aes(x = test),
color = "red") + geom_errorbarh() + facet_wrap(~metric, scales = "free") + theme_bw()
# S3 print and plot methods for the fitted object.
print(result)
plot(result)
# Covariate names as stored on the fitted object.
Wnodes <- result$nodes$Wnodes
# For each pred_fit option ("QaV", "class", "joint"): predict the rule dV on
# the test covariates, then average the true conditional mean Qbar0 under it.
QaV_dV <- predict(result, newdata = testdata[, Wnodes], pred_fit = "QaV")$dV
QaV_perf <- mean(Qbar0(QaV_dV, testdata[, Wnodes]))
class_dV <- predict(result, newdata = testdata[, Wnodes], pred_fit = "class")$dV
class_perf <- mean(Qbar0(class_dV, testdata[, Wnodes]))
joint_dV <- predict(result, newdata = testdata[, Wnodes], pred_fit = "joint")$dV
joint_perf <- mean(Qbar0(joint_dV, testdata[, Wnodes]))
# Benchmark: the same average under d0, the arm with the largest true
# conditional mean as computed in gen_data().
EYd0_perf <- mean(Qbar0(testdata$d0, testdata[, Wnodes]))
c(QaV_perf, class_perf, joint_perf, EYd0_perf)
# perf of true blip approx=0.748
# NOTE(review): plot(result) was already called above -- confirm whether this
# second call is intentional.
plot(result)
# Variable importance from tx_vim(), plotted per node and coloured by model.
vim <- tx_vim(result)
ggplot(vim, aes(y = node, x = risk_full_fraction, color = model)) + geom_point() + theme_bw() + xlab("VIM")
# Reshape to long format so every measure column gets its own panel.
library(reshape2)
long <- melt(vim, id = c("node", "model"))
ggplot(long, aes(y = node, x = value, color = model)) + geom_point() + facet_wrap(~variable, scales = "free") +
theme_bw()
|
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.