#' @include CATE_estimators.R
# S-RF class -------------------------------------------------------------------
setClass(
"S_RF",
contains = "MetaLearner",
slots = list(
feature_train = "data.frame",
tr_train = "numeric",
yobs_train = "numeric",
forest = "forestry",
hyperparameter_list = "list",
creator = "function"
),
validity = function(object)
{
if (!all(object@tr_train %in% c(0, 1))) {
return("TR is the treatment and must be either 0 or 1")
}
return(TRUE)
}
)
# S_RF generator ---------------------------------------------------------------
#' @title S-Learners
#' @description \code{S_RF} is an implementation of the S-Learner combined with
#' Random Forests (Breiman 2001).
#' @rdname Slearners
#' @name S-Learner
#' @details
#' In the S-Learner, the outcome is estimated using all of the features and the
#' treatment indicator without giving the treatment indicator a special role.
#' The predicted CATE for an individual unit is then the difference between the
#' predicted values when the treatment assignment indicator is changed from
#' control to treatment:
#' \enumerate{
#' \item
#' Estimate the joint response function \deqn{\mu(x, w) = E[Y | X = x, W =
#' w]} using the
#' base learner.
#' We denote the estimate as \eqn{\hat \mu}.
#' \item
#' Define the CATE estimate as
#' \deqn{\tau(x) = \hat \mu_1(x, 1) - \hat \mu_0(x, 0).}
#' }
#' @param mu.forestry A list containing the hyperparameters for the
#' \code{forestry} package that are used in \eqn{\hat \mu_0}.
#' These hyperparameters are passed to the \code{forestry} package.
#' @return Object of class \code{S_RF}. It should be used with one of the
#' following functions \code{EstimateCATE}, \code{CateCI}, and
#' \code{CateBIAS}. The object has the following slots:
#' \item{\code{feature_train}}{A copy of feat.}
#' \item{\code{tr_train}}{A copy of tr.}
#' \item{\code{yobs_train}}{A copy of yobs.}
#' \item{\code{m_0}}{An object of class forestry that is fitted with the
#' observed outcomes of the control group as the dependent variable.}
#' \item{\code{m_1}}{An object of class forestry that is fitted with the
#' observed outcomes of the treated group as the dependent variable.}
#' \item{\code{hyperparameter_list}}{A list containting the hyperparameters of
#' the three random forest algorithms used.}
#' \item{\code{creator}}{Function call of \code{S_RF}. This is used for different
#' bootstrap procedures.}
#' @inherit X-Learner
#' @family metalearners
#' @export
S_RF <-
function(feat,
tr,
yobs,
nthread = 0,
verbose = TRUE,
mu.forestry =
list(
relevant.Variable = 1:ncol(feat),
ntree = 1000,
replace = TRUE,
sample.fraction = 0.9,
mtry = ncol(feat),
nodesizeSpl = 1,
nodesizeAvg = 3,
splitratio = .5,
middleSplit = FALSE
)) {
# Cast input data to a standard format -------------------------------------
feat <- as.data.frame(feat)
# Catch misspecification erros ---------------------------------------------
if (!(nthread - round(nthread) == 0) | nthread < 0) {
stop("nthread must be a positive integer!")
}
if (!is.logical(verbose)) {
stop("verbose must be either TRUE or FALSE.")
}
catch_input_errors(feat, yobs, tr)
# Set relevant relevant.Variable -------------------------------------------
# User often sets the relevant variables by column names and not numerical
# values. We translate it here to the index of the columns.
if (is.null(mu.forestry$relevant.Variable)) {
mu.forestry$relevant.Variable <- 1:ncol(feat)
} else{
if (is.character(mu.forestry$relevant.Variable))
mu.forestry$relevant.Variable <-
which(colnames(feat) %in% mu.forestry$relevant.Variable)
}
# Translate the settings to a feature list ---------------------------------
general_hyperpara <- list("nthread" = nthread)
hyperparameter_list <- list(
"general" = general_hyperpara,
"mu.forestry" = mu.forestry
)
return(
S_RF_fully_specified(
feat = feat,
tr = tr,
yobs = yobs,
hyperparameter_list = hyperparameter_list,
verbose = verbose
)
)
}
# S-RF basic constructor -------------------------------------------------------
S_RF_fully_specified <-
function(feat,
tr,
yobs,
hyperparameter_list,
verbose) {
m <- forestry::forestry(
x = cbind(feat[, hyperparameter_list[["mu.forestry"]]$relevant.Variable],
tr),
y = yobs,
ntree = hyperparameter_list[["mu.forestry"]]$ntree,
replace = hyperparameter_list[["mu.forestry"]]$replace,
sample.fraction = hyperparameter_list[["mu.forestry"]]$sample.fraction,
mtry = hyperparameter_list[["mu.forestry"]]$mtry,
nodesizeSpl = hyperparameter_list[["mu.forestry"]]$nodesizeSpl,
nodesizeAvg = hyperparameter_list[["mu.forestry"]]$nodesizeAvg,
nthread = hyperparameter_list[["general"]]$nthread,
splitrule = "variance",
splitratio = hyperparameter_list[["mu.forestry"]]$splitratio
)
new(
"S_RF",
feature_train = feat,
tr_train = tr,
yobs_train = yobs,
forest = m,
hyperparameter_list = hyperparameter_list,
creator = function(feat, tr, yobs) {
S_RF_fully_specified(feat,
tr,
yobs,
hyperparameter_list,
verbose)
}
)
}
############################
### Estimate CATE Method ###
############################
#' EstimateCate-S_hRF
#' @rdname EstimateCate
#' @inherit EstimateCate
#' @exportMethod EstimateCate
setMethod(
f = "EstimateCate",
signature = "S_RF",
definition = function(theObject, feature_new)
{
feature_new <- as.data.frame(feature_new)
catch_feat_input_errors(feature_new)
return(
predict(theObject@forest, cbind(feature_new, tr = 1)) -
predict(theObject@forest, cbind(feature_new, tr = 0))
)
}
)
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.