opt_tmle: opt_tmle

Description Usage Arguments Examples

View source: R/opt_tmle.R

Description

Estimation of the Optimal Treatment rule using Super Learner and mean performance using CV-TMLE. To avoid nesting cross-validation, it uses split-specific estimates of Q and g to estimate the rule, and 'split-specific' estimates of the rule in CV-TMLE to estimate mean performance.

Usage

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
opt_tmle(data, Wnodes = grep("^W", names(data), value = TRUE), Anode = "A",
  Ynode = "Y", Vnodes = Wnodes, stratifyAY = TRUE,
  SL.library = opt_tmle.SL.library, verbose = 3, parallel = FALSE,
  perf_tmle = TRUE, perf_dripcw = FALSE, perf_cv = TRUE,
  perf_full = FALSE, maximize = TRUE, ...)

## S3 method for class 'opt_tmle'
print(obj)

## S3 method for class 'opt_tmle'
plot(obj)

## S3 method for class 'opt_tmle'
predict(obj, ...)

Arguments

data

data.frame containing the relevant variables

Wnodes,

vector of column names indicating covariates

Anode,

column name of treatment

Ynode,

column name of outcome

Vnodes,

vector of column names to base the treatment on

stratifyAY,

logical: should we stratify the cross-validation based on (A,Y) pairs

verbose,

integer that controls the verbosity of the output (higher is more verbose)

parallel,

logical: should foreach parallelization be used?

SL.library,

list of SuperLearner libraries for the various models. See opt_tmle.SL.library for an example.

Examples

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# True conditional mean of the outcome given treatment and covariates,
# used as the data-generating outcome model in this example.
#
# A: treatment value(s); only the values 2 and 3 change the result, any
#    other value contributes no treatment term. Either length 1 or one
#    value per row of W.
# W: matrix or data.frame of covariates; only columns 1 to 3 are read.
#
# Returns a numeric vector of outcome probabilities, one per row of W.
Qbar0 <- function(A, W) {
    W1 <- W[, 1]
    W2 <- W[, 2]
    W3 <- W[, 3]

    # Treatment 2 raises the first logistic term when W1 < -0.5 and
    # treatment 3 raises it when W1 > 0.5.
    tx_term <- plogis(-5 * (A == 2) * (W1 + 0.5) + 5 * (A == 3) * (W1 - 0.5))

    # Treatment-independent term driven by the W2:W3 interaction.
    base_term <- plogis(W2 * W3)

    (1/2) * (tx_term + base_term)
}

# True treatment mechanism used to simulate treatment assignment.
#
# W: matrix or data.frame of covariates; only columns 1 to 3 are read.
#
# Returns the result of normalize_rows() applied to a three-column matrix
# of scores (columns A1, A2, A3), one row per row of W.
g0 <- function(W) {
    scale_factor <- 0.8

    # Unnormalised score for treatment k is a logistic function of the
    # k-th covariate.
    A <- cbind(
        A1 = plogis(scale_factor * W[, 1]),
        A2 = plogis(scale_factor * W[, 2]),
        A3 = plogis(scale_factor * W[, 3])
    )

    # Make sure each row of A sums to 1 (normalize_rows() is defined
    # elsewhere in the package). Returned as a plain expression so the
    # value is visible when g0() is called at top level.
    normalize_rows(A)
}

# Simulate one data set from the example data-generating process.
#
# n: number of observations.
# p: number of standard-normal covariates to generate (named W1..Wp);
#    g0() and Qbar0() read the leading columns, so p must be large enough
#    for them.
#
# Returns a data.frame with the covariates, the sampled treatment A (a
# factor), the binary outcome Y, the true optimal treatment d0, the
# outcome under d0 (Yd0), and two matrix columns: g0W (true treatment
# probabilities) and Q0aW (true mean outcome under each treatment).
gen_data <- function(n = 1000, p = 4) {
    W <- matrix(rnorm(n * p), nrow = n)
    colnames(W) <- paste0("W", seq_len(p))

    # Compute the treatment mechanism once and reuse it below.
    g0W <- g0(W)

    # Draw one treatment per row. Levels are fixed to every treatment arm so
    # that an arm which happens not to be sampled (small n) is still a level.
    # NOTE(review): assumes vals_from_factor() derives its values from the
    # factor levels, so Q0aW keeps one column per arm and the column index
    # returned by max.col() equals the treatment value -- confirm.
    A <- factor(
        apply(g0W, 1, function(pAi) which(rmultinom(1, 1, pAi) == 1)),
        levels = seq_len(ncol(g0W))
    )
    A_vals <- vals_from_factor(A)

    # A single uniform draw per subject is shared by Y and Yd0 so the
    # observed and optimal-rule outcomes are coupled.
    u <- runif(n)
    Y <- as.numeric(u < Qbar0(A, W))

    # True mean outcome under each treatment, one column per treatment.
    Q0aW <- vapply(A_vals, Qbar0, numeric(nrow(W)), W)
    d0 <- max.col(Q0aW)
    Yd0 <- as.numeric(u < Qbar0(d0, W))

    df <- data.frame(W, A, Y, d0, Yd0)

    # Stored as matrix columns so the truth travels with the data.
    df$g0W <- g0W
    df$Q0aW <- Q0aW

    df
}






# Example: estimate the optimal treatment rule on simulated data, then
# compare the estimated rules against the true optimal rule on a large
# independent test set.

# Attach plotting/reshaping packages up front so the ggplot() and melt()
# calls below do not depend on another package having attached them.
library(ggplot2)
library(reshape2)

# Candidate learners for each model fitted by opt_tmle(): the outcome
# regression (Q), the treatment mechanism (g) and the blip/rule model (QaV).
SL.library <- list(
    Q = c("SL.glm", "SL.glmem", "SL.glmnet", "SL.glmnetem", "SL.polymars",
          "SL.step.forward", "SL.gam", "SL.mean"),
    g = c("mnSL.glmnet", "mnSL.multinom", "mnSL.mean", "mnSL.polymars"),
    QaV = c("SL.polymars", "SL.glm", "SL.glmnet", "SL.step.forward", "SL.gam",
            "SL.mean")
)

SL.library$QaV <- sl_to_mv_library(SL.library$QaV)
# SL.library$Q <- sl_to_strat_library(SL.library$Q, 'A')

# Large test set for evaluating rules; small training set for fitting.
testdata <- gen_data(1e+05, 5)
data <- gen_data(1000, 5)

# Node roles, passed explicitly (these match the opt_tmle() defaults).
Wnodes <- grep("^W", names(data), value = TRUE)
Anode <- "A"
Ynode <- "Y"
Vnodes <- Wnodes

system.time({
    result <- opt_tmle(data, Wnodes = Wnodes, Anode = Anode, Ynode = Ynode,
                       Vnodes = Vnodes, SL.library = SL.library)
})

# Backward variable importance, with the truth evaluated on the test set.
vimresult <- backward_vim(result, testdata, Qbar0)

ggplot(vimresult$vimdf, aes(y = Vnode, x = est, xmin = lower, xmax = upper)) +
    geom_point() +
    geom_point(aes(x = test), color = "red") +
    geom_errorbarh() +
    facet_wrap(~metric, scales = "free") +
    theme_bw()

print(result)
plot(result)

# True mean outcome on the test set under each estimated rule.
Wnodes <- result$nodes$Wnodes
test_W <- testdata[, Wnodes]
rule_fits <- c(QaV_perf = "QaV", class_perf = "class", joint_perf = "joint")
rule_perf <- vapply(rule_fits, function(pred_fit) {
    dV <- predict(result, newdata = test_W, pred_fit = pred_fit)$dV
    mean(Qbar0(dV, test_W))
}, numeric(1))

# Benchmark: true mean outcome under the true optimal rule d0.
EYd0_perf <- mean(Qbar0(testdata$d0, test_W))
c(rule_perf, EYd0_perf = EYd0_perf)
# perf of true blip approx=0.748

# Treatment-rule variable importance.
vim <- tx_vim(result)
ggplot(vim, aes(y = node, x = risk_full_fraction, color = model)) +
    geom_point() +
    theme_bw() +
    xlab("VIM")

# Same importance measures in long format, one facet per measure.
long <- melt(vim, id = c("node", "model"))
ggplot(long, aes(y = node, x = value, color = model)) +
    geom_point() +
    facet_wrap(~variable, scales = "free") +
    theme_bw()

jeremyrcoyle/opttx documentation built on May 19, 2019, 5:08 a.m.