View source: R/meta_learners.R
| metalearner_ensemble | R Documentation |
metalearner_ensemble implements the S-learner, T-learner, and X-learner for estimating conditional average treatment effects (CATEs), using a super learner (a cross-validated, weighted ensemble of machine learning algorithms) as the base learner. The algorithms in the ensemble are set through SL.learners and can include extreme gradient boosting, glmnet (elastic net regression), random forest, and neural nets.
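To make the difference between the S-learner and the T-learner concrete, here is a minimal base-R sketch on simulated data. It is an illustration under stated assumptions, not this package's implementation: lm() stands in for the super learner, and the data, the variable names (x, w, y), and the true CATE of 2 + x are all made up for the example.

```r
# Minimal S-learner vs. T-learner sketch in base R.
# Assumption: lm() stands in for the super learner; this is NOT the
# package's implementation, and the data below are simulated.
set.seed(1)
n <- 500
x <- rnorm(n)
w <- rbinom(n, 1, 0.5)                 # randomized binary treatment
y <- 1 + x + w * (2 + x) + rnorm(n)    # true CATE is 2 + x
dat <- data.frame(y = y, x = x, w = w)

# S-learner: a single outcome model with the treatment as a feature.
# The x:w interaction is written out because lm() cannot learn it on its
# own; a flexible learner would pick it up from the data.
s_fit <- lm(y ~ x * w, data = dat)
cate_s <- predict(s_fit, transform(dat, w = 1)) -
          predict(s_fit, transform(dat, w = 0))

# T-learner: separate outcome models for treated and control units,
# then the difference of their predictions for every observation.
mu1 <- lm(y ~ x, data = subset(dat, w == 1))
mu0 <- lm(y ~ x, data = subset(dat, w == 0))
cate_t <- predict(mu1, dat) - predict(mu0, dat)

# Both averages should be close to the average true effect, mean(2 + x).
c(S = mean(cate_s), T = mean(cate_t))
```

With a fully interacted linear model the two learners give identical CATEs; they diverge once the base learner is a regularized or tree-based ensemble, which is where the choice of meta.learner.type matters.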
metalearner_ensemble(
  data = NULL,
  train.data = NULL,
  test.data = NULL,
  cov.formula,
  treat.var,
  meta.learner.type,
  SL.learners = c("SL.glmnet", "SL.xgboost", "SL.nnet"),
  nfolds = 5,
  family = gaussian(),
  binary.preds = FALSE,
  conformal = FALSE,
  alpha = 0.1,
  calib_frac = 0.5,
  seed = 1234
)
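An S-learner fits one outcome model that includes the treatment as a predictor, so cov.formula and treat.var together describe that model. The sketch below shows one way to read the outcome and covariates out of cov.formula and append the treatment using only base R stats functions. It is a hypothetical illustration; the names outcome, covariates, and s_formula are introduced here, and the package's internal handling of the formula may differ.

```r
# Hypothetical illustration of how cov.formula and treat.var relate.
# Only base R (stats) is used; the package's internals may differ.
cov.formula <- support_war ~ age + income + employed + job_loss
treat.var <- "strong_leader"

# The outcome is the left-hand side; the covariates are the right-hand terms.
outcome <- all.vars(cov.formula)[1]
covariates <- attr(terms(cov.formula), "term.labels")

# An S-learner style formula that adds the treatment as one more predictor.
s_formula <- reformulate(c(covariates, treat.var), response = outcome)
print(s_formula)
```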
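| Argument | Description |
|---|---|
| data | data.frame containing the outcome, treatment, and covariates. Supply either data or both train.data and test.data. |
| train.data | data.frame used to train the meta-learner; used together with test.data. |
| test.data | data.frame for which CATEs are estimated; used together with train.data. |
| cov.formula | formula of the form y ~ x1 + x2 + ..., with the outcome variable on the left-hand side and the covariates (confounders) on the right-hand side. |
| treat.var | string giving the name of the treatment variable. |
| meta.learner.type | string specifying the meta-learner: "S.Learner", "T.Learner", or "X.Learner". |
| SL.learners | character vector of SuperLearner wrapper names for the ensemble, for example "SL.xgboost" (extreme gradient boosting), "SL.glmnet" (elastic net), "SL.ranger" (random forest), and "SL.nnet" (neural nets). Defaults to c("SL.glmnet", "SL.xgboost", "SL.nnet"). |
| nfolds | number of cross-validation folds. Currently supports up to 5 folds. |
| family | gaussian() or binomial() family for the outcome variable. |
| binary.preds | logical; whether outcome predictions should be binary. |
| conformal | logical; whether to compute conformal prediction intervals. |
| alpha | miscoverage level for the conformal prediction intervals; the intervals target 1 - alpha coverage, so the default 0.1 gives 90% intervals. |
| calib_frac | fraction of the training data held out for calibration in conformal inference. |
| seed | random seed. |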
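The conformal, alpha, and calib_frac arguments point to split conformal inference. The base-R sketch below shows that generic technique for outcome predictions on simulated data, assuming alpha is the miscoverage level and calib_frac is the share of training rows held out for calibration. It is not the package's code: lm() with a cubic polynomial stands in for the super learner, and how the package carries these intervals over to CATEs is not described on this page.

```r
# Generic split-conformal sketch in base R (NOT the package's implementation).
# Assumptions: alpha = miscoverage level, calib_frac = share of training rows
# held out for calibration; lm() stands in for the super learner.
set.seed(2)
alpha <- 0.1
calib_frac <- 0.5

n <- 1000
train <- data.frame(x = runif(n, -2, 2))
train$y <- sin(train$x) + rnorm(n, sd = 0.3)
test <- data.frame(x = runif(200, -2, 2))
test$y <- sin(test$x) + rnorm(200, sd = 0.3)

# Split the training rows into a proper training set and a calibration set.
calib_idx <- sample(n, size = floor(calib_frac * n))
fit <- lm(y ~ poly(x, 3), data = train[-calib_idx, ])

# Conformity scores: absolute residuals on the calibration set.
scores <- abs(train$y[calib_idx] - predict(fit, train[calib_idx, ]))

# Finite-sample corrected quantile: the ceiling((m + 1) * (1 - alpha))-th
# smallest score (this sketch assumes that index does not exceed m).
m <- length(scores)
k <- ceiling((m + 1) * (1 - alpha))
stopifnot(k <= m)
q_hat <- sort(scores)[k]

# Symmetric intervals around the point predictions for the test rows.
pred <- predict(fit, test)
lower <- pred - q_hat
upper <- pred + q_hat
coverage <- mean(test$y >= lower & test$y <= upper)
print(coverage)  # empirical coverage, expected to be near 1 - alpha
```

A larger calib_frac gives a more stable quantile estimate but leaves fewer rows to fit the model, which is the trade-off that argument controls.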
A metalearner_ensemble object containing the predicted outcome values and the CATEs estimated by the meta-learner for each observation.
# load dataset
data(exp_data)
# load SuperLearner package
library(SuperLearner)
# estimate CATEs with S Learner
set.seed(123456)
slearner <- metalearner_ensemble(cov.formula = support_war ~ age +
                                   income + employed + job_loss,
                                 data = exp_data,
                                 treat.var = "strong_leader",
                                 meta.learner.type = "S.Learner",
                                 SL.learners = c("SL.glm"),
                                 nfolds = 5,
                                 binary.preds = FALSE)
print(slearner)
# estimate CATEs with T Learner
set.seed(123456)
tlearner <- metalearner_ensemble(cov.formula = support_war ~ age + income +
                                   employed + job_loss,
                                 data = exp_data,
                                 treat.var = "strong_leader",
                                 meta.learner.type = "T.Learner",
                                 SL.learners = c("SL.xgboost",
                                                 "SL.nnet"),
                                 nfolds = 5,
                                 binary.preds = FALSE)
print(tlearner)
# estimate CATEs with X Learner
set.seed(123456)
xlearner <- metalearner_ensemble(cov.formula = support_war ~ age + income +
                                   employed + job_loss,
                                 test.data = exp_data,
                                 train.data = exp_data,
                                 treat.var = "strong_leader",
                                 meta.learner.type = "X.Learner",
                                 SL.learners = c("SL.glmnet", "SL.xgboost",
                                                 "SL.nnet"),
                                 binary.preds = TRUE)
print(xlearner)
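The X-learner in the last example works in stages: it fits per-arm outcome models as the T-learner does, imputes an individual effect for each unit from the opposite arm's model, regresses those imputed effects on the covariates, and blends the two resulting CATE models with a propensity score. The base-R sketch below follows the generic two-stage X-learner of Künzel et al. on simulated data. It is an assumption-laden illustration, not this package's implementation: lm() and a logistic glm() stand in for the super learner, and the propensity-score weighting is the textbook choice, which the package may not use.

```r
# Generic X-learner sketch in base R (NOT the package's implementation).
# Assumptions: lm()/glm() stand in for the super learner; the data are simulated.
set.seed(3)
n <- 600
x <- rnorm(n)
w <- rbinom(n, 1, plogis(0.5 * x))    # treatment probability depends on x
y <- x + w * (1 + 0.5 * x) + rnorm(n) # true CATE is 1 + 0.5 * x
dat <- data.frame(y, x, w)
treated <- dat[dat$w == 1, ]
control <- dat[dat$w == 0, ]

# Stage 1: separate outcome models for each arm (as in the T-learner).
mu1 <- lm(y ~ x, data = treated)
mu0 <- lm(y ~ x, data = control)

# Stage 2: impute individual effects with the other arm's model,
# then regress the imputed effects on the covariates within each arm.
treated$d <- treated$y - predict(mu0, treated)
control$d <- predict(mu1, control) - control$y
tau1 <- lm(d ~ x, data = treated)
tau0 <- lm(d ~ x, data = control)

# Stage 3: weight the two CATE models by the estimated propensity g(x).
g <- predict(glm(w ~ x, family = binomial(), data = dat), dat, type = "response")
cate_x <- g * predict(tau0, dat) + (1 - g) * predict(tau1, dat)

# Estimation error against the known simulated truth.
summary(cate_x - (1 + 0.5 * dat$x))
```

Weighting tau0 by g(x) leans on the control-based model where treated units are common, which is why the X-learner tends to help when one arm is much larger than the other.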