Nothing
#' Wrapper for bartMachine learner
#'
#' Support Bayesian additive regression trees via the bartMachine package.
#'
#' @param Y Outcome variable. For \code{family = "binomial"} it must be coded
#'   0/1; it is converted to a factor internally because bartMachine only fits
#'   a classification model when the response is a factor.
#' @param X Covariate dataframe
#' @param newX Dataframe on which to predict the outcome
#' @param family "gaussian" for regression, "binomial" for binary
#'   classification. May be a character string or a family object (e.g.
#'   \code{binomial()}), in which case its \code{$family} element is used.
#' @param obsWeights Observation-level weights (not used: they are not passed
#'   to bartMachine).
#' @param id Optional id to group observations from the same unit (not used
#'   currently).
#' @param num_trees The number of trees to be grown in the sum-of-trees model.
#' @param num_burn_in Number of MCMC samples to be discarded as "burn-in".
#' @param num_iterations_after_burn_in Number of MCMC samples to draw from the
#'   posterior distribution of f(x).
#' @param alpha Base hyperparameter in tree prior for whether a node is
#'   nonterminal or not.
#' @param beta Power hyperparameter in tree prior for whether a node is
#'   nonterminal or not.
#' @param k For regression, k determines the prior probability that E(Y|X) is
#'   contained in the interval (y_min, y_max), based on a normal
#'   distribution. For example, when k=2, the prior probability is 95\%. For
#'   classification, k determines the prior probability that E(Y|X) is between
#'   (-3,3). Note that a larger value of k results in more shrinkage and a more
#'   conservative fit.
#' @param q Quantile of the prior on the error variance at which the data-based
#'   estimate is placed. Note that the larger the value of q, the more
#'   aggressive the fit as you are placing more prior weight on values lower
#'   than the data-based estimate. Not used for classification.
#' @param nu Degrees of freedom for the inverse chi^2 prior. Not used for
#'   classification.
#' @param verbose Prints information about progress of the algorithm to the
#'   screen.
#' @param ... Additional arguments (not used)
#'
#' @return A list with elements \code{pred}, the predictions for \code{newX}
#'   (predicted outcome for regression, estimated P(Y = 1) for binary
#'   classification), and \code{fit}, an object of class
#'   \code{"SL.bartMachine"} holding the fitted model in \code{fit$object}.
#'
#' @encoding utf-8
#' @export
SL.bartMachine <- function(Y, X, newX, family, obsWeights, id,
                           num_trees = 50, num_burn_in = 250, verbose = FALSE,
                           alpha = 0.95, beta = 2, k = 2, q = 0.9, nu = 3,
                           num_iterations_after_burn_in = 1000,
                           ...) {
  .SL.require("bartMachine")

  # SuperLearner passes a family object (e.g. binomial()); a plain string is
  # accepted too, as described in the documentation.
  family_name <- if (is.character(family)) family else family$family

  if (identical(family_name, "binomial")) {
    # bartMachine treats a numeric 0/1 response as regression, so it must be a
    # factor. bartMachine regards the first level as the "positive" class
    # whose probability predict() returns, hence "1" is listed first.
    Y <- factor(Y, levels = c("1", "0"))
    if (anyNA(Y)) {
      stop("For family = \"binomial\", Y must be coded 0/1 with no missing ",
           "values.", call. = FALSE)
    }
  }

  model <- bartMachine::bartMachine(
    X, Y,
    num_trees = num_trees,
    num_burn_in = num_burn_in,
    verbose = verbose,
    alpha = alpha, beta = beta, k = k, q = q, nu = nu,
    num_iterations_after_burn_in = num_iterations_after_burn_in
  )

  # Regression: predicted outcome. Classification: predict() defaults to
  # probability of the first factor level, i.e. P(Y = 1) given the levels
  # set above.
  pred <- predict(model, newX)

  # The fit carries everything predict.SL.bartMachine needs.
  fit <- list(object = model)
  class(fit) <- "SL.bartMachine"

  list(pred = pred, fit = fit)
}
#' bartMachine prediction
#'
#' Predict method for the fit object returned by \code{SL.bartMachine}.
#'
#' @param object Fit object of class \code{"SL.bartMachine"}, i.e. the
#'   \code{fit} element returned by \code{SL.bartMachine}; the fitted
#'   bartMachine model is stored in \code{object$object}.
#' @param newdata Dataframe to predict the outcome
#' @param family "gaussian" for regression, "binomial" for binary
#'   classification. (Not used)
#' @param Y Outcome variable (not used)
#' @param X Covariate dataframe (not used)
#' @param ... Additional arguments (not used)
#'
#' @return Predictions for \code{newdata} from the stored bartMachine model,
#'   on the same scale as the \code{pred} element returned by
#'   \code{SL.bartMachine}.
#'
#' @export
predict.SL.bartMachine <- function(object, newdata, family, X = NULL, Y = NULL,
                                   ...) {
  .SL.require("bartMachine")
  # Dispatches to bartMachine's own predict method for the stored model.
  predict(object$object, newdata)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.