Lrnr_grfcate	R Documentation

Generalized Random Forests for Conditional Average Treatment Effects

Description

This learner implements the causal forest estimator of the conditional average treatment effect (CATE) using the causal_forest function from the grf package. It is intended for use in the tmle3mopttx package, which needs to fit the CATE and then predict CATE values from new covariate data. Accordingly, the treatment/exposure node must be specified through the A parameter.
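The fit-then-predict workflow described above can be approximated by calling grf::causal_forest directly. This is a minimal sketch under stated assumptions, not the learner's internal code: it assumes the treatment column (vs) is removed from the covariate matrix X, which may differ from how Lrnr_grfcate actually builds its inputs, and the values num.trees = 500 and seed = 1 are arbitrary choices made only to keep the run small and reproducible.

```r
library(grf)

data(mtcars)

# Assumption: the treatment column "vs" is excluded from the covariates X
covariates <- c("cyl", "disp", "hp", "drat", "wt", "qsec", "am")
X <- as.matrix(mtcars[, covariates])
Y <- mtcars$mpg   # outcome
W <- mtcars$vs    # binary treatment/exposure

# Fit the causal forest (arbitrary num.trees and seed for a quick, reproducible run)
cf <- causal_forest(X = X, Y = Y, W = W, num.trees = 500, seed = 1)

# Out-of-bag CATE estimates for the training rows
tau_oob <- predict(cf)$predictions

# CATE estimates for "new" covariate data (here, simply the first five rows)
tau_new <- predict(cf, newdata = X[1:5, , drop = FALSE])$predictions
```

Calling predict without newdata returns out-of-bag estimates, which avoids evaluating each row with trees that were trained on it.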

Format

An R6Class object inheriting from Lrnr_base.

Value

A learner object inheriting from Lrnr_base with methods for training and prediction. For a full list of learner functionality, see the complete documentation of Lrnr_base.

Parameters

  • A: Name of the column among the sl3_Task's covariates that indicates the treatment/exposure of interest. The treatment assignment must be a binary or real-valued numeric vector with no missing values (NAs).

  • ...: Other parameters passed to causal_forest. See its documentation for details.
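Per the parameter list, any argument other than A is passed on to causal_forest. The sketch below passes num.trees and seed, both of which are causal_forest arguments; the specific values are arbitrary and chosen only for illustration.

```r
library(sl3)

data(mtcars)
mtcars_task <- sl3_Task$new(
  data = mtcars,
  covariates = c("cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am"),
  outcome = "mpg"
)

# num.trees and seed are not used by the learner itself; they are
# forwarded to grf::causal_forest through `...`
grfcate_lrnr <- Lrnr_grfcate$new(A = "vs", num.trees = 500, seed = 1)
grfcate_fit <- grfcate_lrnr$train(mtcars_task)

# CATE predictions for the training task
cate_preds <- grfcate_fit$predict()
```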

See Also

Other Learners: Custom_chain, Lrnr_HarmonicReg, Lrnr_arima, Lrnr_bartMachine, Lrnr_base, Lrnr_bayesglm, Lrnr_caret, Lrnr_cv_selector, Lrnr_cv, Lrnr_dbarts, Lrnr_define_interactions, Lrnr_density_discretize, Lrnr_density_hse, Lrnr_density_semiparametric, Lrnr_earth, Lrnr_expSmooth, Lrnr_gam, Lrnr_ga, Lrnr_gbm, Lrnr_glm_fast, Lrnr_glm_semiparametric, Lrnr_glmnet, Lrnr_glmtree, Lrnr_glm, Lrnr_grf, Lrnr_gru_keras, Lrnr_gts, Lrnr_h2o_grid, Lrnr_hal9001, Lrnr_haldensify, Lrnr_hts, Lrnr_independent_binomial, Lrnr_lightgbm, Lrnr_lstm_keras, Lrnr_mean, Lrnr_multiple_ts, Lrnr_multivariate, Lrnr_nnet, Lrnr_nnls, Lrnr_optim, Lrnr_pca, Lrnr_pkg_SuperLearner, Lrnr_polspline, Lrnr_pooled_hazards, Lrnr_randomForest, Lrnr_ranger, Lrnr_revere_task, Lrnr_rpart, Lrnr_rugarch, Lrnr_screener_augment, Lrnr_screener_coefs, Lrnr_screener_correlation, Lrnr_screener_importance, Lrnr_sl, Lrnr_solnp_density, Lrnr_solnp, Lrnr_stratified, Lrnr_subset_covariates, Lrnr_svm, Lrnr_tsDyn, Lrnr_ts_weights, Lrnr_xgboost, Pipeline, Stack, define_h2o_X(), undocumented_learner

Examples

library(sl3)
data(mtcars)
mtcars_task <- sl3_Task$new(
  data = mtcars,
  covariates = c("cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am"),
  outcome = "mpg"
)
# fit a causal forest with vs as the binary treatment, then predict CATEs
grfcate_lrnr <- Lrnr_grfcate$new(A = "vs")
grfcate_fit <- grfcate_lrnr$train(mtcars_task)
grfcate_predictions <- grfcate_fit$predict()
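Because the learner is meant to predict CATE values from new covariate data, a fitted learner can also be applied to a second task. The sketch below uses a hypothetical split of mtcars (first 24 rows for training, last 8 as "new" data) purely for illustration; with so few rows the estimates will be very noisy.

```r
library(sl3)

data(mtcars)
covs <- c("cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am")

# Hypothetical split: train on the first 24 cars, predict for the remaining 8
train_task <- sl3_Task$new(data = mtcars[1:24, ], covariates = covs, outcome = "mpg")
new_task <- sl3_Task$new(data = mtcars[25:32, ], covariates = covs, outcome = "mpg")

grfcate_fit <- Lrnr_grfcate$new(A = "vs")$train(train_task)

# CATE predictions for the held-out rows
new_cate <- grfcate_fit$predict(new_task)
```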

tlverse/sl3 documentation built on Nov. 18, 2024, 12:46 a.m.