extract_segments: Extract lift segments from an rpart object in table form

View source: R/extract_segments.R

Description

extract_segments takes a fitted rpart object that was fitted using lift_method and returns the resulting segments in table form. See the example below for more details on its usage.

Usage

extract_segments(rpart_fit)

Arguments

rpart_fit

An object of class rpart fitted with the lift_method method imported by import_lift_method.

Value

A data.frame containing the resulting segments. It includes confidence intervals computed at the alpha level specified in the parms argument to the rpart function.
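
To make the role of alpha concrete, the following base R sketch computes a lift (the difference in response rate between treated and control units) together with a two-sided Wald interval at level alpha. The helper lift_ci is hypothetical and is not part of RCTree; the package may compute its intervals differently, so treat this only as an illustration of how a smaller or larger alpha widens or narrows the interval.

```r
# Hypothetical helper (NOT part of RCTree): lift with a Wald interval.
# Assumes y is a 0/1 outcome and tr is a 0/1 treatment indicator.
lift_ci <- function(y, tr, alpha = 0.05) {
  y_t <- y[tr == 1]
  y_c <- y[tr == 0]
  p_t <- mean(y_t)
  p_c <- mean(y_c)
  lift <- p_t - p_c
  # Standard error of a difference of two independent proportions
  se <- sqrt(p_t * (1 - p_t) / length(y_t) + p_c * (1 - p_c) / length(y_c))
  z <- qnorm(1 - alpha / 2)
  c(lift = lift, lower = lift - z * se, upper = lift + z * se)
}

# Simulated data, for illustration only
set.seed(1)
tr <- rbinom(1000, 1, 0.3)
y <- rbinom(1000, 1, ifelse(tr == 1, 0.6, 0.4))

ci_95 <- lift_ci(y, tr, alpha = 0.05)
ci_80 <- lift_ci(y, tr, alpha = 0.20)
print(ci_95)
print(ci_80)
```

With alpha = 0.05 the interval is wider than with alpha = 0.20, since a higher confidence level requires a larger normal quantile.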

See Also

import_lift_method

Examples

set.seed(1)
library(rpart)
library(RCTree)

# Generate a dataset
p_x <- function(Tr, X1, X2, X3){
  lp <- Tr*X1 - 0.2*X2 + as.numeric(X3)
  exp(lp)/(1+exp(lp))
}

n <- 3000
Tr <- rbinom(n, 1, 0.3)
X1 <- rnorm(n, 0.5)
X2 <- runif(n, -1, 1)
X3 <- factor(sample(LETTERS[1:3], size = n, replace = TRUE))
p <- p_x(Tr, X1, X2, X3)
y <- sapply(p, function(x) rbinom(1, 1, x))
y_mat <- cbind(y, Tr)
dat <- data.frame(y = I(y_mat), X1, X2)

# Fit a causal tree
lift_method <- import_lift_method()
baseline_lift <- mean(y_mat[y_mat[, 2] == 1, 1]) - mean(y_mat[y_mat[, 2] == 0, 1])
causal_tree <- rpart(y ~ ., data = dat,
              method = lift_method, control = rpart.control(cp = -Inf, minbucket = 700),
              parms = list(baseline_lift = baseline_lift, alpha = 0.05))

# Predict treatment effect and compare with actual treatment effect
tau <- predict(causal_tree, dat)
p_treat <- p_x(rep(1, n), X1, X2, X3)[order(X1)]
p_cont <- p_x(rep(0, n), X1, X2, X3)[order(X1)]

plot(sort(dat$X1), p_treat - p_cont, col = "red")
points(dat$X1, tau)
abline(h = baseline_lift)

# Explore the resulting segments
segments <- extract_segments(causal_tree)
print(segments)

# Compare with the ATE:
print(baseline_lift)

# Compare to a regular classification model
dat2 <- data.frame(y, X1, X2, X3, Tr)
fit2 <- rpart(y ~ ., data = dat2)
dat2_treat <- dat2; dat2_cont <- dat2
dat2_treat$Tr <- 1L; dat2_cont$Tr <- 0L
tau2 <- predict(fit2, dat2_treat) - predict(fit2, dat2_cont)

plot(sort(dat$X1), p_treat - p_cont, col = "red")
points(dat$X1, tau2)
abline(h = baseline_lift)

IyarLin/RCTree documentation built on April 13, 2020, 12:37 a.m.