View source: R/extract_segments.R
extract_segments

Description:
extract_segments takes as input an rpart object fitted with method = lift_method and returns the resulting segments in table form. See the example below for details on its usage.
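To make "segments in table form" concrete, assume each segment corresponds to a leaf of the fitted tree. The base-R sketch below is an illustration only. It builds a per-leaf lift table from leaf assignments (for example, the `where` component of an rpart fit), a 0/1 outcome and a 0/1 treatment indicator. The function `segment_lift_table()` and its column names are hypothetical. They are not the RCTree implementation, and extract_segments() may report different columns.

```r
# Hypothetical sketch, not RCTree's code: per-segment (per-leaf) lift table.
# leaf: leaf id for each observation (e.g. fit$where from an rpart fit)
# y:    0/1 outcome; tr: 0/1 treatment indicator
segment_lift_table <- function(leaf, y, tr) {
  ids <- sort(unique(leaf))
  rows <- lapply(ids, function(id) {
    in_seg <- leaf == id
    y1 <- y[in_seg & tr == 1]  # treated outcomes in this segment
    y0 <- y[in_seg & tr == 0]  # control outcomes in this segment
    data.frame(
      segment = id,
      n_treat = length(y1),
      n_control = length(y0),
      rate_treat = mean(y1),
      rate_control = mean(y0),
      lift = mean(y1) - mean(y0)  # treated rate minus control rate
    )
  })
  do.call(rbind, rows)
}

# Toy usage with two segments of four observations each
leaf <- c(1, 1, 1, 1, 2, 2, 2, 2)
y    <- c(1, 0, 0, 0, 1, 1, 0, 1)
tr   <- c(1, 1, 0, 0, 1, 0, 1, 0)
tab <- segment_lift_table(leaf, y, tr)
print(tab)
```

In this toy data, segment 1 has a lift of 0.5 and segment 2 has a lift of -0.5. A tree that separates these two groups would therefore expose heterogeneity that a single overall lift would hide.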
Usage:
extract_segments(rpart_fit)
Arguments:
rpart_fit    A fitted rpart object, fitted with method = lift_method.
Value:
A data.frame containing the resulting segments. It includes confidence intervals computed with the alpha value supplied in the parms argument of the rpart() call (alpha = 0.05 in the example below).
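This page does not document how alpha is turned into an interval. One common choice, shown here purely as an assumption, is the Wald (normal-approximation) interval for a difference of two proportions: lift ± qnorm(1 - alpha/2) * SE. The helper `lift_ci()` is hypothetical and is not confirmed to match what extract_segments() computes.

```r
# Hypothetical sketch: Wald confidence interval for lift = p1 - p0 at level 1 - alpha.
# Not confirmed to be the formula used by extract_segments().
lift_ci <- function(y1, y0, alpha = 0.05) {
  p1 <- mean(y1)  # treated response rate
  p0 <- mean(y0)  # control response rate
  se <- sqrt(p1 * (1 - p1) / length(y1) + p0 * (1 - p0) / length(y0))
  z <- qnorm(1 - alpha / 2)  # about 1.96 for alpha = 0.05
  lift <- p1 - p0
  c(lift = lift, lower = lift - z * se, upper = lift + z * se)
}

y1 <- c(rep(1, 60), rep(0, 40))  # treated: 60% respond
y0 <- c(rep(1, 45), rep(0, 55))  # control: 45% respond
ci <- lift_ci(y1, y0, alpha = 0.05)
print(ci)
```

A larger alpha gives a narrower interval, because qnorm(1 - alpha/2) shrinks as alpha grows.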
Examples:
set.seed(1)
library(rpart)
library(RCTree)
# Generate a dataset
p_x <- function(Tr, X1, X2, X3) {
  # Treatment shifts the log-odds of success by X1
  lp <- Tr*X1 - 0.2*X2 + as.numeric(X3)
  exp(lp)/(1+exp(lp))
}
n <- 3000
Tr <- rbinom(n, 1, 0.3)
X1 <- rnorm(n, 0.5)
X2 <- runif(n, -1, 1)
X3 <- factor(sample(LETTERS[1:3], size = n, replace = TRUE))
p <- p_x(Tr, X1, X2, X3)
y <- sapply(p, function(x) rbinom(1, 1, x))
y_mat <- cbind(y, Tr)
dat <- data.frame(y = I(y_mat), X1, X2)
# Fit a causal tree
lift_method <- import_lift_method()
# Overall lift (ATE): response rate among treated minus response rate among controls
baseline_lift <- mean(y_mat[y_mat[, 2] == 1, 1]) - mean(y_mat[y_mat[, 2] == 0, 1])
causal_tree <- rpart(y ~ ., data = dat,
                     method = lift_method,
                     control = rpart.control(cp = -Inf, minbucket = 700),
                     parms = list(baseline_lift = baseline_lift, alpha = 0.05))
# Predict treatment effect and compare with actual treatment effect
tau <- predict(causal_tree, dat)
p_treat <- p_x(rep(1, n), X1, X2, X3)[order(X1)]
p_cont <- p_x(rep(0, n), X1, X2, X3)[order(X1)]
plot(sort(dat$X1), p_treat - p_cont, col = "red")
points(dat$X1, tau)
abline(h = baseline_lift)
# explore the resulting segments
segments <- extract_segments(causal_tree)
print(segments)
# Compare with the ATE:
print(baseline_lift)
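# Compare with a standard rpart tree that uses Tr as an ordinary predictor.
# Because y is numeric 0/1, rpart() fits a regression (anova) tree whose predictions are response rates.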
dat2 <- data.frame(y, X1, X2, X3, Tr)
fit2 <- rpart(y ~ ., data = dat2)
dat2_treat <- dat2; dat2_cont <- dat2
dat2_treat$Tr <- 1L; dat2_cont$Tr <- 0L
tau2 <- predict(fit2, dat2_treat) - predict(fit2, dat2_cont)
plot(sort(dat$X1), p_treat - p_cont, col = "red")
points(dat$X1, tau2)
abline(h = baseline_lift)
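The example stores the outcome and the treatment indicator together as one matrix-valued response, wrapped in I(). The toy snippet below shows why the wrapper matters. Its variable names (`toy_mat`, `toy_x`) are illustrative and are chosen so they do not overwrite the objects above. Without I(), data.frame() splits the matrix into separate columns. The formula y ~ . would then no longer pick up the two-column (outcome, treatment) response that the example passes to the lift method.

```r
# Toy illustration: I() keeps a matrix as a single data.frame column.
toy_mat <- cbind(y = c(1, 0, 1), Tr = c(1, 1, 0))  # outcome and treatment
toy_x <- c(0.2, -0.5, 1.1)

split_df <- data.frame(y = toy_mat, toy_x)     # columns: y.y, y.Tr, toy_x
kept_df  <- data.frame(y = I(toy_mat), toy_x)  # columns: y (a 3 x 2 matrix), toy_x

print(names(split_df))
print(names(kept_df))
print(dim(kept_df$y))
```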