View source: R/causal_survival_forest.R
predict.causal_survival_forest | R Documentation |
Gets estimates of tau(X) using a trained causal survival forest.
## S3 method for class 'causal_survival_forest'
predict(
object,
newdata = NULL,
num.threads = NULL,
estimate.variance = FALSE,
...
)
object |
The trained forest. |
newdata |
Points at which predictions should be made. If NULL, makes out-of-bag predictions on the training set instead (i.e., provides predictions at Xi using only trees that did not use the i-th training example). Note that this matrix should have the same number of columns as the training matrix, and that the columns must appear in the same order. |
num.threads |
Number of threads used in prediction. If set to NULL, the software automatically selects an appropriate amount. |
estimate.variance |
Whether variance estimates for tau(X) are desired (for confidence intervals). |
... |
Additional arguments (currently ignored). |
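The column requirement on newdata can be checked before calling predict. The following is a minimal sketch in base R, assuming both matrices carry column names; align_newdata is a hypothetical helper written for illustration and is not part of the package.

```r
# Hypothetical helper (not part of grf): reorder newdata so that its columns
# match the training matrix by name, and fail early if any are missing.
align_newdata <- function(newdata, train.colnames) {
  newdata <- as.matrix(newdata)
  missing.cols <- setdiff(train.colnames, colnames(newdata))
  if (length(missing.cols) > 0) {
    stop("newdata is missing columns: ", paste(missing.cols, collapse = ", "))
  }
  # Selecting by name fixes the order and drops any extra columns.
  newdata[, train.colnames, drop = FALSE]
}

# Toy data standing in for a training matrix and a shuffled test matrix.
X.train <- matrix(runif(20), 4, 5, dimnames = list(NULL, paste0("X", 1:5)))
X.new <- X.train[, c("X3", "X1", "X5", "X2", "X4")]
X.aligned <- align_newdata(X.new, colnames(X.train))
```

The aligned matrix could then be passed as newdata. If the training matrix has no column names, this check does not apply and the column order has to be guaranteed by construction.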
Vector of predictions along with optional variance estimates.
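When variance estimates are requested, they can be turned into pointwise confidence intervals. The sketch below assumes the returned object has columns named predictions and variance.estimates, and uses a normal approximation; prediction_intervals is a hypothetical helper, and the toy data frame only stands in for real predict output.

```r
# Hypothetical helper (not part of grf): pointwise normal-approximation
# intervals from a data frame of predictions and variance estimates.
prediction_intervals <- function(pred, level = 0.95) {
  # Assumed column names; adjust if the predict output differs.
  stopifnot(all(c("predictions", "variance.estimates") %in% names(pred)))
  z <- qnorm(1 - (1 - level) / 2)
  std.err <- sqrt(pred$variance.estimates)
  data.frame(
    estimate = pred$predictions,
    lower = pred$predictions - z * std.err,
    upper = pred$predictions + z * std.err
  )
}

# Stand-in for the output of predict(..., estimate.variance = TRUE).
toy.pred <- data.frame(
  predictions = c(0.10, 0.25),
  variance.estimates = c(0.0004, 0.0009)
)
toy.ci <- prediction_intervals(toy.pred)
```

With a trained forest, the same helper would be applied to the result of predict with estimate.variance = TRUE. These intervals are pointwise, not simultaneous across test points.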
# Train a causal survival forest targeting a Restricted Mean Survival Time (RMST)
# with maximum follow-up time set to `horizon`.
n <- 2000
p <- 5
X <- matrix(runif(n * p), n, p)
W <- rbinom(n, 1, 0.5)
horizon <- 1
failure.time <- pmin(rexp(n) * X[, 1] + W, horizon)
censor.time <- 2 * runif(n)
Y <- pmin(failure.time, censor.time)
D <- as.integer(failure.time <= censor.time)
# Save computation time by constraining the event grid by discretizing (rounding) continuous events.
cs.forest <- causal_survival_forest(X, round(Y, 2), W, D, horizon = horizon)
# Or do so more flexibly by defining your own time grid using the failure.times argument.
# grid <- seq(min(Y), max(Y), length.out = 150)
# cs.forest <- causal_survival_forest(X, Y, W, D, horizon = horizon, failure.times = grid)
# Predict using the forest.
X.test <- matrix(0.5, 10, p)
X.test[, 1] <- seq(0, 1, length.out = 10)
cs.pred <- predict(cs.forest, X.test)
# Predict on out-of-bag training samples.
cs.pred <- predict(cs.forest)
# Predict with confidence intervals; growing more trees is now recommended.
cs.pred <- predict(cs.forest, X.test, estimate.variance = TRUE)
# Compute a doubly robust estimate of the average treatment effect.
average_treatment_effect(cs.forest)
# Compute the best linear projection on the first covariate.
best_linear_projection(cs.forest, X[, 1])
# See if a causal survival forest succeeded in capturing heterogeneity by plotting
# the TOC and calculating a 95% CI for the AUTOC.
train <- sample(1:n, n / 2)
eval <- -train
train.forest <- causal_survival_forest(X[train, ], Y[train], W[train], D[train], horizon = horizon)
eval.forest <- causal_survival_forest(X[eval, ], Y[eval], W[eval], D[eval], horizon = horizon)
rate <- rank_average_treatment_effect(eval.forest,
                                      predict(train.forest, X[eval, ])$predictions)
plot(rate)
paste("AUTOC:", round(rate$estimate, 2), "+/-", round(1.96 * rate$std.err, 2))