R/deprecated.R

#' (deprecated) One vs. all causal forest for multiple treatment effect estimation
#'
#' Since policytree version 1.1 this function is deprecated in favor of the new estimator
#' `multi_arm_causal_forest` available in GRF (version 2+). This function will continue to work
#' for now, but it issues a warning and passes its arguments on to the "conformable"
#' `multi_arm_causal_forest` in GRF.
#' (Note: for policy learning this forest works as before, but individual point predictions
#' differ, because `multi_arm_causal_forest` predicts contrasts.
#' See the GRF documentation example for details.)
#'
#' @param X The covariates used in the causal regression.
#' @param Y The outcome (must be a numeric vector with no NAs).
#' @param W The treatment assignment (must be a categorical vector with no NAs).
#' @param Y.hat Estimates of the expected responses E\[Y | Xi\], marginalizing
#'              over treatment. If Y.hat = NULL, these are estimated using
#'              a separate regression forest. See section 6.1.1 of the GRF paper for
#'              further discussion of this quantity. Default is NULL.
#' @param W.hat Matrix with estimates of the treatment propensities E\[Wk | Xi\]. If W.hat = NULL,
#'              these are estimated using k separate regression forests. Default is NULL.
#' @param num.trees Number of trees grown in the forest. Note: Getting accurate
#'                  confidence intervals generally requires more trees than
#'                  getting accurate predictions. Default is 2000.
#' @param sample.weights (experimental) Weights given to each sample in estimation.
#'                       If NULL, each observation receives the same weight.
#'                       Note: To avoid introducing confounding, weights should be
#'                       independent of the potential outcomes given X. Default is NULL.
#' @param clusters Vector of integers or factors specifying which cluster each observation corresponds to.
#'  Default is NULL (ignored).
#' @param equalize.cluster.weights If FALSE, each unit is given the same weight (so that bigger
#'  clusters get more weight). If TRUE, each cluster is given equal weight in the forest. In this case,
#'  during training, each tree uses the same number of observations from each drawn cluster: If the
#'  smallest cluster has K units, then when we sample a cluster during training, we only give a random
#'  K elements of the cluster to the tree-growing procedure. When estimating average treatment effects,
#'  each observation is given weight 1/cluster size, so that the total weight of each cluster is the
#'  same. Note that, if this argument is FALSE, sample weights may also be directly adjusted via the
#'  sample.weights argument. If this argument is TRUE, sample.weights must be set to NULL. Default is
#'  FALSE.
#' @param sample.fraction Fraction of the data used to build each tree.
#'                        Note: If honesty = TRUE, these subsamples will
#'                        further be cut by a factor of honesty.fraction. Default is 0.5.
#' @param mtry Number of variables tried for each split. Default is
#'             \eqn{\sqrt p + 20}, rounded up and capped at p, where p is the number of variables.
#' @param min.node.size A target for the minimum number of observations in each tree leaf. Note that nodes
#'                      with size smaller than min.node.size can occur, as in the original randomForest package.
#'                      Default is 5.
#' @param honesty Whether to use honest splitting (i.e., sub-sample splitting). Default is TRUE.
#'  For a detailed description of honesty, honesty.fraction, honesty.prune.leaves, and recommendations for
#'  parameter tuning, see the grf
#'  \href{https://grf-labs.github.io/grf/REFERENCE.html#honesty-honesty-fraction-prune-empty-leaves}{algorithm reference}.
#' @param honesty.fraction The fraction of data that will be used for determining splits if honesty = TRUE. Corresponds
#'                         to set J1 in the notation of the paper. Default is 0.5 (i.e. half of the data is used for
#'                         determining splits).
#' @param honesty.prune.leaves If TRUE, prunes the estimation sample tree such that no leaves
#'  are empty. If FALSE, keeps the same tree as determined in the splits sample (if an empty leaf is encountered, that
#'  tree is skipped and does not contribute to the estimate). Setting this to FALSE may improve performance on
#'  small/marginally powered data, but requires more trees (note: tuning does not adjust the number of trees).
#'  Only applies if honesty is enabled. Default is TRUE.
#' @param alpha A tuning parameter that controls the maximum imbalance of a split. Default is 0.05.
#' @param imbalance.penalty A tuning parameter that controls how harshly imbalanced splits are penalized. Default is 0.
#' @param stabilize.splits Whether or not the treatment should be taken into account when
#'                         determining the imbalance of a split. Default is TRUE.
#' @param ci.group.size The forest will grow ci.group.size trees on each subsample.
#'                      In order to provide confidence intervals, ci.group.size must
#'                      be at least 2. Default is 2.
#' @param tune.parameters A vector of parameter names to tune.
#'  If "all": all tunable parameters are tuned by cross-validation. The following parameters are
#'  tunable: ("sample.fraction", "mtry", "min.node.size", "honesty.fraction",
#'   "honesty.prune.leaves", "alpha", "imbalance.penalty"). If honesty is FALSE, the honesty.* parameters
#'  are not tuned. Default is "none" (no parameters are tuned).
#' @param tune.num.trees The number of trees in each 'mini forest' used to fit the tuning model. Default is 200.
#' @param tune.num.reps The number of forests used to fit the tuning model. Default is 50.
#' @param tune.num.draws The number of random parameter values considered when using the model
#'                          to select the optimal parameters. Default is 1000.
#' @param compute.oob.predictions Whether OOB predictions on training set should be precomputed. Default is TRUE.
#' @param orthog.boosting Deprecated and unused after version 1.0.4.
#' @param num.threads Number of threads used in training. By default, the number of threads is set
#'                    to the maximum hardware concurrency.
#' @param seed The seed of the C++ random number generator.
#'
#' @return The object returned by the new estimator `multi_arm_causal_forest`, to which this
#'  function passes its arguments. A deprecation warning is issued on every call.
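#'
#' @examples
#' \donttest{
#' # A minimal migration sketch, not a definitive recipe: it calls the replacement
#' # estimator directly instead of this deprecated wrapper. The simulated data
#' # (n, p, the arm labels "A", "B", "C", and the outcome model) are illustrative
#' # assumptions only.
#' n <- 500
#' p <- 5
#' X <- matrix(rnorm(n * p), n, p)
#' # W is converted to a factor here, since the GRF estimator expects one.
#' W <- as.factor(sample(c("A", "B", "C"), n, replace = TRUE))
#' Y <- X[, 1] + X[, 2] * (W == "B") - 1.5 * X[, 2] * (W == "C") + rnorm(n)
#'
#' # Replaces multi_causal_forest(X, Y, W).
#' forest <- grf::multi_arm_causal_forest(X, Y, W)
#'
#' # Point predictions are contrasts against the first factor level of W,
#' # not one estimate per arm as in the old one-vs-all forest.
#' tau.hat <- predict(forest)$predictions
#'
#' # Policy learning works as before: doubly robust scores, then a shallow tree.
#' Gamma.matrix <- double_robust_scores(forest)
#' tree <- policy_tree(X, Gamma.matrix, depth = 2)
#' }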
#'
#' @export
multi_causal_forest <- function(X, Y, W,
                                Y.hat = NULL,
                                W.hat = NULL,
                                num.trees = 2000,
                                sample.weights = NULL,
                                clusters = NULL,
                                equalize.cluster.weights = FALSE,
                                sample.fraction = 0.5,
                                mtry = min(ceiling(sqrt(ncol(X)) + 20), ncol(X)),
                                min.node.size = 5,
                                honesty = TRUE,
                                honesty.fraction = 0.5,
                                honesty.prune.leaves = TRUE,
                                alpha = 0.05,
                                imbalance.penalty = 0,
                                stabilize.splits = TRUE,
                                ci.group.size = 2,
                                tune.parameters = "none",
                                tune.num.trees = 200,
                                tune.num.reps = 50,
                                tune.num.draws = 1000,
                                compute.oob.predictions = TRUE,
                                orthog.boosting = FALSE,
                                num.threads = NULL,
                                seed = runif(1, 0, .Machine$integer.max)) {
  warning(paste0("\nDeprecation warning:\n",
                 "This forest is deprecated. ",
                 "Returning an object of type `multi_arm_causal_forest` available in GRF 2.0+"))
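  # Forward to the GRF estimator. The tune.* arguments and orthog.boosting are
  # accepted for backward compatibility but are not passed on.
  grf::multi_arm_causal_forest(X, Y, W,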
                               Y.hat = Y.hat,
                               W.hat = W.hat,
                               num.trees = num.trees,
                               sample.weights = sample.weights,
                               clusters = clusters,
                               equalize.cluster.weights = equalize.cluster.weights,
                               sample.fraction = sample.fraction,
                               mtry = mtry,
                               min.node.size = min.node.size,
                               honesty = honesty,
                               honesty.fraction = honesty.fraction,
                               honesty.prune.leaves = honesty.prune.leaves,
                               alpha = alpha,
                               imbalance.penalty = imbalance.penalty,
                               stabilize.splits = stabilize.splits,
                               ci.group.size = ci.group.size,
                               compute.oob.predictions = compute.oob.predictions,
                               num.threads = num.threads,
                               seed = seed)
}

policytree documentation built on July 9, 2023, 6:30 p.m.