import_lift_method: Import lift_method for lift modeling with rpart

View source: R/import_lift_method.R

Description

import_lift_method imports a list of functions that serves as a user-defined method for the rpart function. See the example below for more details on its usage.

Usage

import_lift_method()

Details

The rpart function accepts a user-defined list of functions in its method argument. import_lift_method returns such a list, one that implements a causal tree. In addition to method, the user must supply the parms argument, which holds the baseline lift in the population and the significance level for the lift confidence intervals returned in the yval2 object (baseline_lift and alpha in the example below). See the example below for details on how to use it.
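As a rough illustration of the two parms entries, the sketch below computes a baseline lift as the difference in mean outcome between treated and control units, together with a normal-approximation confidence interval at level alpha. The simulated data, the variable names other than baseline_lift and alpha, and the interval formula are assumptions made for illustration; the intervals RCTree stores in yval2 may be computed differently.

```r
set.seed(42)

# Simulated data (assumption): binary treatment and binary outcome
n <- 2000
treated <- rbinom(n, 1, 0.3)
y <- rbinom(n, 1, ifelse(treated == 1, 0.5, 0.4))

# Baseline lift: mean outcome of treated minus mean outcome of control
p1 <- mean(y[treated == 1]); n1 <- sum(treated == 1)
p0 <- mean(y[treated == 0]); n0 <- sum(treated == 0)
baseline_lift <- p1 - p0

# Normal-approximation interval for a difference of two proportions
# (an assumed formula, not necessarily the one used by RCTree)
alpha <- 0.05
se <- sqrt(p1 * (1 - p1) / n1 + p0 * (1 - p0) / n0)
ci <- baseline_lift + c(-1, 1) * qnorm(1 - alpha / 2) * se

# The list that would be passed to rpart(..., parms = parms)
parms <- list(baseline_lift = baseline_lift, alpha = alpha)
```

A smaller alpha widens the interval, so fewer segments would show a lift that is distinguishable from the baseline.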

Value

A list containing eval, split and init functions.
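To make the shape of such a list concrete, here is a minimal sketch of a user-defined rpart method that reproduces ordinary least-squares (anova) splitting, adapted from the pattern in rpart's user-written split functions vignette. It is not the RCTree implementation: the names init_fn, eval_fn, split_fn and toy_method are hypothetical, and the causal tree returned by import_lift_method uses its own node and split criteria.

```r
library(rpart)

# init: check the response and tell rpart how to store and summarise it
init_fn <- function(y, offset, parms, wt) {
  if (is.matrix(y) && ncol(y) > 1) stop("matrix response not supported in this sketch")
  if (length(offset)) y <- y - offset
  summary_fn <- function(yval, dev, wt, ylevel, digits) {
    paste0("  mean=", format(signif(yval, digits)),
           ", MSE=", format(signif(dev / wt, digits)))
  }
  list(y = c(y), parms = NULL, numresp = 1, numy = 1, summary = summary_fn)
}

# eval: node label (weighted mean) and node deviance (weighted RSS)
eval_fn <- function(y, wt, parms) {
  wmean <- sum(y * wt) / sum(wt)
  list(label = wmean, deviance = sum(wt * (y - wmean)^2))
}

# split: goodness of every candidate split on one predictor
split_fn <- function(y, wt, x, parms, continuous) {
  y <- y - sum(y * wt) / sum(wt)            # centre the response
  if (continuous) {
    # rpart passes y already ordered by x; one candidate per cut point
    n <- length(y)
    left_sum <- cumsum(y * wt)[-n]
    left_wt  <- cumsum(wt)[-n]
    dir      <- NULL
  } else {
    # order categories by their mean, then treat them like cut points
    ux    <- sort(unique(x))
    wtsum <- tapply(wt, x, sum)
    ysum  <- tapply(y * wt, x, sum)
    ord   <- order(ysum / wtsum)
    k     <- length(ord)
    left_sum <- cumsum(ysum[ord])[-k]
    left_wt  <- cumsum(wtsum[ord])[-k]
    dir      <- ux[ord]
  }
  right_wt <- sum(wt) - left_wt
  lmean <- left_sum / left_wt
  rmean <- -left_sum / right_wt
  goodness <- (left_wt * lmean^2 + right_wt * rmean^2) / sum(wt * y^2)
  list(goodness = goodness,
       direction = if (continuous) sign(lmean) else dir)
}

# The list passed to rpart's method argument
toy_method <- list(eval = eval_fn, split = split_fn, init = init_fn)

# Fit on a built-in dataset; xval = 0 skips cross-validation
fit <- rpart(mpg ~ wt + hp, data = mtcars, method = toy_method,
             control = rpart.control(xval = 0))
fit$frame[, c("var", "n", "dev", "yval")]
```

A causal tree follows the same three-function layout but works with a two-column response (outcome and treatment indicator), as in the example below.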

See Also

extract_segments

Examples

set.seed(1)
library(rpart)
library(RCTree)

# Generate a dataset
p_x <- function(Tr, X1, X2, X3){
  lp <- Tr*X1 - 0.2*X2 + as.numeric(X3)
  exp(lp)/(1+exp(lp))
}

n <- 3000
Tr <- rbinom(n, 1, 0.3)
X1 <- rnorm(n, 0.5)
X2 <- runif(n, -1, 1)
X3 <- factor(sample(LETTERS[1:3], size = n, replace = TRUE))
p <- p_x(Tr, X1, X2, X3)
y <- sapply(p, function(x) rbinom(1, 1, x))
y_mat <- cbind(y, Tr)
dat <- data.frame(y = I(y_mat), X1, X2)

# Fit a causal tree
lift_method <- import_lift_method()
baseline_lift <- mean(y_mat[y_mat[, 2] == 1, 1]) - mean(y_mat[y_mat[, 2] == 0, 1])
causal_tree <- rpart(y ~ ., data = dat,
              method = lift_method, control = rpart.control(cp = -Inf, minbucket = 700),
              parms = list(baseline_lift = baseline_lift, alpha = 0.05))

# Predict treatment effect and compare with actual treatment effect
tau <- predict(causal_tree, dat)
p_treat <- p_x(rep(1, n), X1, X2, X3)[order(X1)]
p_cont <- p_x(rep(0, n), X1, X2, X3)[order(X1)]

plot(sort(dat$X1), p_treat - p_cont, col = "red")
points(dat$X1, tau)
abline(h = baseline_lift)

# Explore the resulting segments
segments <- extract_segments(causal_tree)
print(segments)

# Compare with the ATE:
print(baseline_lift)

# Compare to a regular classification model
dat2 <- data.frame(y, X1, X2, X3, Tr)
fit2 <- rpart(y ~ ., data = dat2)
dat2_treat <- dat2; dat2_cont <- dat2
dat2_treat$Tr <- 1L; dat2_cont$Tr <- 0L
tau2 <- predict(fit2, dat2_treat) - predict(fit2, dat2_cont)

plot(sort(dat$X1), p_treat - p_cont, col = "red")
points(dat$X1, tau2)
abline(h = baseline_lift)

IyarLin/RCTree documentation built on April 13, 2020, 12:37 a.m.