xbart: Crossvalidation For Bayesian Additive Regression Trees

Description

Fits the BART model against varying k, power, base, and n.trees parameters using K-fold or repeated random subsampling crossvalidation, sharing burn-in between parameter settings. Results are returned as an array of evaluations of a loss function on the held-out sets.

Usage

xbart(
    formula, data, subset, weights, offset, verbose = FALSE, n.samples = 200L,
    method = c("k-fold", "random subsample"), n.test = c(5, 0.2),
    n.reps = 40L, n.burn = c(200L, 150L, 50L),
    loss = c("rmse", "log", "mcr"), n.threads = dbarts::guessNumCores(), n.trees = 75L,
    k = NULL, power = 2, base = 0.95, drop = TRUE,
    resid.prior = chisq, control = dbarts::dbartsControl(), sigma = NA_real_,
    seed = NA_integer_)

Arguments

formula

An object of class formula, following a model description syntax analogous to that of lm. For backwards compatibility, can also be the bart matrix x.train. See dbarts.

data

An optional data frame, list, or environment containing predictors to be used with the model. For backwards compatibility, can also be the bart vector y.train.

subset

An optional vector specifying a subset of observations to be used in the fitting process.

weights

An optional vector of weights to be used in the fitting process. When present, BART fits a model with observations y \mid x \sim N(f(x), \sigma^2 / w), where f(x) is the unknown function.

offset

An optional vector specifying an offset from 0 for the relationship between the underlying function, f(x), and the response y. Only useful for binary responses, in which case the model is fit assuming P(Y = 1 \mid X = x) = \Phi(f(x) + \mathrm{offset}), where \Phi is the standard normal cumulative distribution function.

verbose

A logical determining if additional output is printed to the console.

n.samples

A positive integer, setting the number of posterior samples drawn for each fit of training data and used by the loss function.

method

Character string, either "k-fold" or "random subsample".

n.test

For each fit, the test sample size or proportion. For method "k-fold", it is expected to be the number of folds, and in [2, n]. For method "random subsample", it can be a real number in (0, 1) or a positive integer in (1, n). When given as a proportion, the number of test observations used is the proportion times the sample size, rounded to the nearest integer.

n.reps

A positive integer setting the number of cross validation steps that will be taken. For "k-fold", each replication corresponds to fitting each of the K folds in turn, while for "random subsample" a replication is a single fit.

n.burn

Between one and three positive integers, specifying 1) the initial burn-in, 2) the burn-in when moving from one parameter setting to another, and 3) the burn-in between each random subsample replication. The third value is also the burn-in when moving between folds in "k-fold" crossvalidation.

loss

Either one of the pre-set loss functions as a character string (mcr: misclassification rate for binary responses; rmse: root-mean-squared error for continuous responses; log: negative log-loss for binary responses, with rmse serving this purpose for continuous responses), a function, or a function-evaluation environment list-pair. Functions should have prototypes of the form function(y.test, y.test.hat, weights), where y.test is the held-out test subsample, y.test.hat is a matrix of dimension length(y.test) * n.samples, and weights is an optional vector of user-supplied weights. See examples.
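As a rough illustration of the function prototype described above, the following base-R sketch defines a weighted root-mean-squared-error loss and checks it on synthetic inputs without calling xbart. The name weighted_rmse is hypothetical and not part of dbarts, and the fallback to equal weights when none are supplied is an assumption about how such a function might be written defensively.

```r
# Illustrative loss with the documented prototype
# function(y.test, y.test.hat, weights); the name is hypothetical.
weighted_rmse <- function(y.test, y.test.hat, weights) {
    # average the posterior samples (columns) for each held-out observation
    y.hat <- rowMeans(y.test.hat)
    # assumption: use equal weights when none are supplied
    if (missing(weights) || is.null(weights))
        weights <- rep(1, length(y.test))
    sqrt(sum(weights * (y.test - y.hat)^2) / sum(weights))
}

# Stand-alone check on synthetic inputs, independent of xbart:
# every posterior sample over-predicts each observation by 0.5
y.test     <- c(1, 2, 3, 4)
n.samples  <- 5L
y.test.hat <- matrix(y.test + 0.5, length(y.test), n.samples)
loss.value <- weighted_rmse(y.test, y.test.hat, rep(1, length(y.test)))
```

A function of this shape could then be passed as the loss argument in place of one of the character-string presets.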

n.threads

Across different sets of parameters (k \times power \times base \times n.trees) and n.reps, results are independent. For n.threads > 1, evaluations of the above are divided into chunks of approximately equal size and executed in parallel. The default uses guessNumCores, which should work across the most common operating system/hardware pairs. A value of NA is interpreted as 1.

n.trees

A vector of positive integers setting the BART hyperparameter for the number of trees in the sum-of-trees formulation. See bart.

k

A vector of positive real numbers, setting the BART hyperparameter for the node-mean prior standard deviation. If NULL, the default of bart2 will be used: 2 for a continuous response and a Chi hyperprior for a binary response. Crossvalidation of the hyperprior is not possible at this time.

power

A vector of real numbers greater than one, setting the BART hyperparameter for the tree prior's growth probability, given by base / (1 + depth)^power.
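To make the growth-probability formula concrete, here is a minimal base-R sketch that evaluates it at a few node depths. The helper growth_prob is hypothetical and not part of dbarts, and the sketch assumes the root node has depth 0.

```r
# Hypothetical helper, not part of dbarts: prior probability that a node
# at a given depth is split, computed as base / (1 + depth)^power.
growth_prob <- function(depth, base = 0.95, power = 2) {
    base / (1 + depth)^power
}

# assumption: the root node has depth 0
depths <- 0:3
probs  <- growth_prob(depths)
print(round(probs, 4))
```

Larger values of power make the probability fall off faster with depth, which favors shallower trees; smaller values of base lower the probability at every depth.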

base

A vector of real numbers in (0, 1), setting the BART hyperparameter for the tree prior's growth probability.

drop

Logical, determining if dimensions with a single value are dropped from the result.

resid.prior

An expression of the form chisq or chisq(df, quant) that sets the prior used on the residual/error variance.

control

An object inheriting from dbartsControl, created by the dbartsControl function.

sigma

A positive numeric estimate of the residual standard deviation. If NA, a linear model is used with all of the predictors to obtain one.

seed

Optional integer specifying the desired pRNG seed. It should not be needed when running single-threaded, where set.seed will suffice; when running multi-threaded, it can be used to obtain reproducible results. See the Reproducibility section of bart.

Details

Crossvalidates n.reps replications against the crossproduct of the given hyperparameter vectors n.trees \times k \times power \times base. For each fit, either one fold is withheld as test data and n.test - 1 folds are used as training data (for "k-fold"), or n * n.test observations are withheld as test data and n * (1 - n.test) are used as training data (for "random subsample" with n.test given as a proportion). A replication corresponds to fitting all K folds in "k-fold" crossvalidation or a single fit with "random subsample". The training data are used to fit a model and make predictions on the test data; those predictions are then used together with the test data itself to evaluate the loss function.
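The two ways of forming training and test sets described above can be sketched in base R as follows. This only illustrates the index bookkeeping under the stated sizes and is not the internal implementation of xbart; all variable names are chosen for the illustration.

```r
set.seed(42)
n      <- 100L
n.test <- 0.2

# "random subsample": hold out round(n * n.test) observations for one fit
n.held.out <- round(n * n.test)
test.idx   <- sample.int(n, n.held.out)
train.idx  <- setdiff(seq_len(n), test.idx)

# "k-fold": assign each observation to one of K folds; within a
# replication each fold is held out once while the others are used to train
K          <- 5L
fold.id    <- sample(rep_len(seq_len(K), n))
fold.sizes <- as.integer(table(fold.id))
```

With n = 100, the random subsample holds out 20 observations and trains on the other 80, while the five folds contain 20 observations each.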

Loss functions are either the default (average negative log-loss for binary outcomes and root-mean-squared error for continuous outcomes), the misclassification rate for binary outcomes, or a function with arguments y.test and y.test.hat. y.test.hat has dimensions length(y.test) \times n.samples. A third option is to pass a list of the form list(function, evaluationEnvironment), so as to provide default bindings. RMSE is a monotonic transformation of the average log-loss for continuous outcomes, so specifying log-loss in that case calculates RMSE instead.

Value

An array of dimensions n.reps \times length(n.trees) \times length(k) \times length(power) \times length(base). If drop is TRUE, dimensions of length 1 are omitted. If all hyperparameters are of length 1, then the result will be a vector of length n.reps. When the result is an array, the dimnames of the result shall be set to the corresponding hyperparameters.

For method "k-fold", each element is an average across the K fits. For "random subsample", each element represents a single fit.
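One common way to use an array of this shape is to average over the replication dimension and pick the hyperparameter setting with the smallest mean loss. The base-R sketch below does this on a stand-in array filled with made-up values, so it runs without dbarts; the array res and its values are purely illustrative and are not output from xbart.

```r
# Stand-in for an xbart result over n.trees and k only:
# n.reps x length(n.trees) x length(k), filled with made-up loss values
n.reps  <- 4L
n.trees <- c(5L, 7L)
k       <- c(1, 2, 4)
set.seed(7)
res <- array(runif(n.reps * length(n.trees) * length(k)),
             dim = c(n.reps, length(n.trees), length(k)),
             dimnames = list(NULL, as.character(n.trees), as.character(k)))

# average over replications (first dimension), keeping the hyperparameter grid
mean.loss <- apply(res, c(2L, 3L), mean)

# locate the setting with the smallest average loss
best         <- which(mean.loss == min(mean.loss), arr.ind = TRUE)
best.n.trees <- n.trees[best[1L, 1L]]
best.k       <- k[best[1L, 2L]]
```

The same pattern extends to the full five-dimensional result by averaging over the first dimension and keeping dimensions 2 through 5.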

Author(s)

Vincent Dorie: vdorie@gmail.com

See Also

bart, dbarts

Examples

f <- function(x) {
    10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
        10 * x[,4] + 5 * x[,5]
}

set.seed(99)
sigma <- 1.0
n     <- 100

x  <- matrix(runif(n * 10), n, 10)
Ey <- f(x)
y  <- rnorm(n, Ey, sigma)

mad <- function(y.train, y.train.hat, weights) {
    # note, weights are ignored
    mean(abs(y.train - apply(y.train.hat, 1L, mean)))
}



## low iteration numbers to run quickly
xval <- xbart(x, y, n.samples = 15L, n.reps = 4L, n.burn = c(10L, 3L, 1L),
              n.trees = c(5L, 7L),
              k = c(1, 2, 4),
              power = c(1.5, 2),
              base = c(0.75, 0.8, 0.95), n.threads = 1L,
              loss = mad)

dbarts documentation built on May 29, 2024, 3:31 a.m.