Nothing
#' @title
#' Causal rule ensemble
#'
#' @description
#' Performs the Causal Rule Ensemble on a data set with a response variable,
#' a treatment variable, and various features.
#'
#' @param y An observed response vector.
#' @param z A treatment vector.
#' @param X A covariate matrix (or a data frame). Should be provided as
#' numerical values.
#' @param method_params The list of parameters to define the models used,
#' including:
#'   - *Parameters for Honest Splitting*
#'     - *ratio_dis*: The ratio of data delegated to rules discovery
#'     (default: 0.5).
#'   - *Parameters for Discovery and Inference*
#'     - *ite_method*: The method for ITE (pseudo-outcome) estimation
#'     (default: \code{"aipw"}, options: \code{"aipw"} for Augmented Inverse
#'     Probability Weighting, \code{"cf"} for Causal Forest, \code{"bart"} for
#'     Causal Bayesian Additive Regression Trees, \code{"slearner"} for
#'     S-Learner, \code{"tlearner"} for T-Learner, \code{"xlearner"} for
#'     X-Learner, \code{"tpoisson"} for T-Learner with Poisson regression).
#'     - *learner_ps*: The model for the propensity score estimation
#'     (default: \code{"SL.xgboost"}, options: any SuperLearner prediction
#'     model, e.g., \code{"SL.lm"}, \code{"SL.svm"}; used only for
#'     \code{"aipw"}, \code{"bart"} and \code{"cf"} ITE estimators).
#'     - *learner_y*: The model for the outcome estimation
#'     (default: \code{"SL.xgboost"}, options: any SuperLearner prediction
#'     model, e.g., \code{"SL.lm"}, \code{"SL.svm"}; used only for
#'     \code{"aipw"}, \code{"slearner"}, \code{"tlearner"} and
#'     \code{"xlearner"} ITE estimators).
#'     - *offset*: Name of the covariate to use as offset (e.g., \code{"x1"})
#'     for T-Poisson ITE estimation. Use `NULL` if offset is not used
#'     (default: `NULL`).
#' @param hyper_params The list of hyper parameters to fine-tune the method,
#' including:
#'   - *General hyper parameters*
#'     - *intervention_vars*: Array with the names of the intervention-able
#'     covariates used for rules generation. An empty or `NULL` array means
#'     that all the covariates are considered intervention-able
#'     (default: `NULL`).
#'     - *ntrees*: The number of decision trees for random forest
#'     (default: 20).
#'     - *node_size*: Minimum size of the trees' terminal nodes (default: 20).
#'     - *max_rules*: Maximum number of generated candidate rules
#'     (default: 50).
#'     - *max_depth*: Maximum rules length (default: 3).
#'     - *t_decay*: The decay threshold for rules pruning. Higher values will
#'     carry out a more aggressive pruning (default: 0.025).
#'     - *t_ext*: The threshold to truncate too generic or too specific
#'     (extreme) rules (default: 0.01, range: `[0, 0.5)`).
#'     - *t_corr*: The threshold to define correlated rules (default: 1,
#'     range: `[0, +inf)`).
#'     - *stability_selection*: Method for stability selection for selecting
#'     the rules. \code{"vanilla"} for stability selection,
#'     \code{"error_control"} for stability selection with error control and
#'     \code{"no"} for no stability selection (default: \code{"vanilla"}).
#'     - *B*: Number of bootstrap samples for stability selection in rules
#'     selection and uncertainty quantification in estimation (default: 20).
#'     - *subsample*: Bootstrap ratio subsample for stability selection in
#'     rules selection and uncertainty quantification in estimation
#'     (default: 0.5).
#'   - *Method specific hyper parameters*
#'     - *offset*: Accepted here for backward compatibility; it is used only
#'     when `method_params` does not set `offset` (see `method_params`).
#'     - *cutoff*: Threshold (percentage) defining the minimum cutoff value
#'     for the stability scores for Stability Selection (default: 0.9).
#'     - *pfer*: Upper bound for the per-family error rate (tolerated amount
#'     of falsely selected rules) for Error Control Stability Selection
#'     (default: 1).
#'
#' @param ite The estimated ITE vector. If given, the ITE estimation steps in
#' both Discovery and Inference are skipped (default: `NULL`).
#'
#' @return
#' An S3 object of class \code{"cre"} composed of:
#' \item{M}{the number of Decision Rules extracted at each step,}
#' \item{CATE}{the data.frame of Conditional Average Treatment Effect
#' decomposition estimates with corresponding uncertainty quantification,}
#' \item{method_params}{the list of method parameters,}
#' \item{hyper_params}{the list of hyper parameters,}
#' \item{rules}{the list of rules (implicit form) decomposing the CATE, or
#' `NULL` when no rule is retained.}
#'
#' @note
#' - If `intervention_vars` are provided, the individual treatment effect is
#' still estimated using all covariates; only rules generation and rules
#' interpretation are restricted to the intervention-able covariates.
#' @export
#'
#' @examples
#'
#' \donttest{
#' set.seed(123)
#' dataset <- generate_cre_dataset(n = 400,
#'                                 rho = 0,
#'                                 n_rules = 2,
#'                                 p = 10,
#'                                 effect_size = 2,
#'                                 binary_covariates = TRUE,
#'                                 binary_outcome = FALSE,
#'                                 confounding = "no")
#' y <- dataset[["y"]]
#' z <- dataset[["z"]]
#' X <- dataset[["X"]]
#'
#' method_params <- list(ratio_dis = 0.5,
#'                       ite_method = "aipw",
#'                       learner_ps = "SL.xgboost",
#'                       learner_y = "SL.xgboost",
#'                       offset = NULL)
#'
#' hyper_params <- list(intervention_vars = NULL,
#'                      ntrees = 20,
#'                      node_size = 20,
#'                      max_rules = 50,
#'                      max_depth = 3,
#'                      t_decay = 0.025,
#'                      t_ext = 0.025,
#'                      t_corr = 1,
#'                      stability_selection = "vanilla",
#'                      cutoff = 0.6,
#'                      pfer = 1,
#'                      B = 20,
#'                      subsample = 0.5)
#'
#' cre_results <- cre(y, z, X, method_params, hyper_params)
#'}
#'
cre <- function(y, z, X,
                method_params = NULL, hyper_params = NULL, ite = NULL) {
  # timing the function
  st_time_cre <- proc.time()

  # Input checks ---------------------------------------------------------------
  check_input_data(y, z, X, ite)
  method_params <- check_method_params(y = y,
                                       ite = ite,
                                       params = method_params)
  # Covariate names: used for the hyper parameter checks and, when no
  # intervention-able subset is given, for interpreting the rules.
  X_names <- colnames(as.data.frame(X))
  hyper_params <- check_hyper_params(X_names = X_names,
                                     params = hyper_params)

  # An empty intervention_vars array is documented as "all covariates", so
  # treat it exactly like NULL instead of subsetting to zero columns.
  intervention_vars <- getElement(hyper_params, "intervention_vars")
  if (length(intervention_vars) == 0) {
    intervention_vars <- NULL
  }

  ite_method <- getElement(method_params, "ite_method")
  learner_ps <- getElement(method_params, "learner_ps")
  learner_y <- getElement(method_params, "learner_y")
  # `offset` was historically documented under hyper_params while being read
  # from method_params; honour both, giving precedence to method_params.
  offset <- getElement(method_params, "offset")
  if (is.null(offset)) {
    offset <- getElement(hyper_params, "offset")
  }

  # Estimate the ITE on one honest-splitting subsample (a list with `y`, `z`
  # and `X` entries) using the configured ITE estimator. All covariates are
  # used here, regardless of `intervention_vars`.
  estimate_subsample_ite <- function(subsample) {
    estimate_ite(y = subsample[["y"]],
                 z = subsample[["z"]],
                 X = subsample[["X"]],
                 ite_method = ite_method,
                 learner_ps = learner_ps,
                 learner_y = learner_y,
                 offset = offset)
  }

  # Keep only the intervention-able covariates (all of them when NULL).
  select_covariates <- function(X_sub) {
    if (is.null(intervention_vars)) {
      return(X_sub)
    }
    X_sub[, intervention_vars, drop = FALSE]
  }

  # Honest Splitting -----------------------------------------------------------
  subgroups <- honest_splitting(y, z, X,
                                getElement(method_params, "ratio_dis"), ite)
  discovery_set <- subgroups[["discovery"]]
  inference_set <- subgroups[["inference"]]

  # Discovery ------------------------------------------------------------------
  logger::log_info("Starting rules discovery...")
  st_time_rd <- proc.time()
  if (is.null(ite)) {
    ite_dis <- estimate_subsample_ite(discovery_set)
  } else {
    logger::log_info("Using the provided ITE estimations...")
    ite_dis <- discovery_set[["ite"]]
  }
  discovered <- discover_rules(select_covariates(discovery_set[["X"]]),
                               ite_dis,
                               method_params,
                               hyper_params)
  rules <- discovered[["rules"]]
  M <- discovered[["M"]]
  en_time_rd <- proc.time()
  logger::log_info("Done with rules discovery. ",
                   "(WC: {g_wc_str(st_time_rd, en_time_rd)}", ".)")

  # Inference ------------------------------------------------------------------
  logger::log_info("Starting inference...")
  st_time_inf <- proc.time()
  if (is.null(ite)) {
    ite_inf <- estimate_subsample_ite(inference_set)
  } else {
    logger::log_info("Skipped generating ITE. ",
                     "The provided ITE will be used.")
    ite_inf <- inference_set[["ite"]]
  }
  X_inf <- select_covariates(inference_set[["X"]])

  # Generate rules matrix and the explicit (human-readable) rules
  if (length(rules) == 0) {
    rules_matrix_inf <- NA
    rules_explicit <- NULL
  } else {
    rules_matrix_inf <- generate_rules_matrix(X_inf, rules)
    covariate_names <- if (is.null(intervention_vars)) {
      X_names
    } else {
      intervention_vars
    }
    rules_explicit <- interpret_rules(rules, covariate_names)
  }

  # Estimate CATE
  cate_inf <- estimate_cate(rules_matrix_inf,
                            rules_explicit,
                            ite_inf,
                            getElement(hyper_params, "B"),
                            getElement(hyper_params, "subsample"))

  # The first CATE row is not a rule (presumably the baseline/ATE term --
  # confirm in estimate_cate); the remaining rows are the retained rules.
  cate_rules <- cate_inf[["Rule"]]
  n_selected <- length(cate_rules) - 1L
  M["select_significant"] <- n_selected
  if (n_selected > 0) {
    rules <- rules[rules_explicit %in% cate_rules[-1]]
  } else {
    rules <- NULL
  }
  en_time_inf <- proc.time()
  logger::log_info("Done with inference. ",
                   "(WC: {g_wc_str(st_time_inf, en_time_inf)}", ".)")

  # Generate final results S3 object
  results <- list("M" = M,
                  "CATE" = cate_inf,
                  "method_params" = method_params,
                  "hyper_params" = hyper_params,
                  "rules" = rules)
  class(results) <- "cre"

  # Return Results -------------------------------------------------------------
  end_time_cre <- proc.time()
  logger::log_info("Done with running CRE function! ",
                   "(WC: {g_wc_str(st_time_cre, end_time_cre)}", ".)")
  results
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.