# Temporary wrapper that needs to be fixed in SuperLearner
#' Wrapper for bartMachine learner
#'
#' Support bayesian additive regression trees via the bartMachine package.
#'
#' @param Y Outcome variable
#' @param X Covariate dataframe
#' @param newX Optional dataframe to predict the outcome
#' @param obsWeights Optional observation-level weights (accepted for the
#'   SuperLearner wrapper interface but not currently passed to bartMachine).
#' @param id Optional id to group observations from the same unit (not used
#' currently).
#' @param family "gaussian" for regression, "binomial" for binary
#' classification
#' @param num_trees The number of trees to be grown in the sum-of-trees model.
#' @param num_burn_in Number of MCMC samples to be discarded as "burn-in".
#' @param num_iterations_after_burn_in Number of MCMC samples to draw from the
#' posterior distribution of f(x).
#' @param alpha Base hyperparameter in tree prior for whether a node is
#' nonterminal or not.
#' @param beta Power hyperparameter in tree prior for whether a node is
#' nonterminal or not.
#' @param k For regression, k determines the prior probability that E(Y|X) is
#' contained in the interval (y_{min}, y_{max}), based on a normal
#' distribution. For example, when k=2, the prior probability is 95\%. For
#' classification, k determines the prior probability that E(Y|X) is between
#' (-3,3). Note that a larger value of k results in more shrinkage and a more
#' conservative fit.
#' @param q Quantile of the prior on the error variance at which the data-based
#' estimate is placed. Note that the larger the value of q, the more
#' aggressive the fit as you are placing more prior weight on values lower
#' than the data-based estimate. Not used for classification.
#' @param nu Degrees of freedom for the inverse chi^2 prior. Not used for
#' classification.
#' @param verbose Prints information about progress of the algorithm to the
#' screen.
#' @param serialize If TRUE, bartMachine results can be saved to a file, but
#' will require additional RAM.
#' @param ... Additional arguments (not used)
#'
#' @encoding UTF-8
#' @export
SL.bartMachine2 <- function(Y, X, newX, family, obsWeights, id,
                            num_trees = 50, num_burn_in = 250, verbose = FALSE,
                            alpha = 0.95, beta = 2, k = 2, q = 0.9, nu = 3,
                            num_iterations_after_burn_in = 1000,
                            serialize = FALSE,
                            ...) {
  # Fail fast with a clear message rather than an opaque "object not found"
  # error if the suggested package is not installed.
  if (!requireNamespace("bartMachine", quietly = TRUE)) {
    stop("SL.bartMachine2 requires the 'bartMachine' package.", call. = FALSE)
  }
  if (family$family == "binomial") {
    # bartMachine only does classification when Y is a factor, and it treats
    # the FIRST factor level as the positive class, so fix the level order.
    Y <- factor(Y, levels = c("1", "0"))
  }
  # NOTE(review): obsWeights and id are part of the SuperLearner wrapper
  # interface but are not forwarded to bartMachine here.
  model <- bartMachine::bartMachine(X, Y, num_trees = num_trees,
    num_burn_in = num_burn_in, verbose = verbose,
    alpha = alpha, beta = beta, k = k, q = q, nu = nu,
    num_iterations_after_burn_in = num_iterations_after_burn_in,
    serialize = serialize)
  # predict() dispatches to predict.bartMachine and returns predicted
  # responses on the scale of the outcome.
  pred <- predict(model, newX)
  fit <- list(object = model)
  # Deliberately reuse the predict.SL.bartMachine S3 method shipped with
  # SuperLearner, so no separate predict method is needed for this wrapper.
  class(fit) <- c("SL.bartMachine")
  list(pred = pred, fit = fit)
}
# (Website-scrape boilerplate retained as comments; bare text here would be a
# parse error in R.)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.