rbart | R Documentation |
Fits a varying intercept/random effect BART model.
rbart_vi(
formula, data, test, subset, weights, offset, offset.test = offset,
group.by, group.by.test, prior = cauchy,
sigest = NA_real_, sigdf = 3.0, sigquant = 0.90,
k = 2.0,
power = 2.0, base = 0.95,
n.trees = 75L,
n.samples = 1500L, n.burn = 1500L,
n.chains = 4L, n.threads = min(dbarts::guessNumCores(), n.chains),
combineChains = FALSE,
n.cuts = 100L, useQuantiles = FALSE,
n.thin = 5L, keepTrainingFits = TRUE,
printEvery = 100L, printCutoffs = 0L,
verbose = TRUE,
keepTrees = TRUE, keepCall = TRUE,
seed = NA_integer_,
keepSampler = keepTrees,
keepTestFits = TRUE,
callback = NULL,
...)
## S3 method for class 'rbart'
plot(
x, plquants = c(0.05, 0.95), cols = c('blue', 'black'), ...)
## S3 method for class 'rbart'
fitted(
object,
type = c("ev", "ppd", "bart", "ranef"),
sample = c("train", "test"),
...)
## S3 method for class 'rbart'
extract(
object,
type = c("ev", "ppd", "bart", "ranef", "trees"),
sample = c("train", "test"),
combineChains = TRUE,
...)
## S3 method for class 'rbart'
predict(
object, newdata, group.by, offset,
type = c("ev", "ppd", "bart", "ranef"),
combineChains = TRUE,
...)
## S3 method for class 'rbart'
residuals(object, ...)
group.by |
Grouping factor. Can be an integer vector/factor, or a reference to such in |
group.by.test |
Grouping factor for test data, of the same type as |
prior |
A function or symbolic reference to built-in priors. Determines the prior over the standard deviation of the random effects. Supplied functions take two arguments, |
n.thin |
The number of tree jumps taken for every stored sample, but also the number of samples from the posterior of the standard deviation of the random effects before one is kept. |
keepTestFits |
Logical where, if false, test fits are obtained while running but not returned. Useful with |
callback |
Optional function of |
formula, data, test, subset, weights, offset, offset.test, sigest, sigdf, sigquant, k, power, base, n.trees, n.samples, n.burn, n.chains, n.threads, combineChains, n.cuts, useQuantiles, keepTrainingFits, printEvery, printCutoffs, verbose, keepTrees, keepCall, seed, keepSampler, ... |
Same as in |
object |
A fitted |
newdata |
Same as |
type |
One of |
sample |
One of |
x, plquants, cols |
Same as in |
Fits a BART model with additive random intercepts, one for each factor level of group.by. For continuous responses:
y_i \sim N(f(x_i) + \alpha_{g[i]}, \sigma^2)
\alpha_j \sim N(0, \tau^2).
For binary outcomes the response model is changed to P(Y_i = 1) = \Phi(f(x_i) + \alpha_{g[i]}). Here i indexes observations, g[i] is the group index of observation i, f(x) and \sigma come from a BART model, and \alpha_j are the independent and identically distributed random intercepts. Draws from the posterior of \tau are made using a slice sampler, with a width dynamically determined by assessing the curvature of the posterior distribution at its mode.
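As a rough illustration of the slice-sampling idea, the sketch below draws \tau conditional on a fixed set of random intercepts using a generic stepping-out slice sampler in base R. It is not the package's implementation: the names log_post_tau and slice_step are hypothetical, the half-Cauchy scale of 2.5 is an assumption, and the width w is held fixed rather than derived from the curvature at the mode.

```r
# Log posterior of tau given intercepts alpha, up to an additive constant.
# ASSUMPTION: half-Cauchy prior with scale 2.5; the scale is illustrative only.
log_post_tau <- function(tau, alpha, scale = 2.5) {
  if (tau <= 0) return(-Inf)
  sum(dnorm(alpha, 0, tau, log = TRUE)) + dcauchy(tau, 0, scale, log = TRUE)
}

# One slice-sampler update with stepping out and shrinkage.
# ASSUMPTION: fixed width w, unlike the curvature-based width described above.
slice_step <- function(x0, log_f, w) {
  log_y <- log_f(x0) - rexp(1)              # log height of the slice
  left  <- x0 - runif(1) * w                # randomly placed initial interval
  right <- left + w
  while (log_f(left)  > log_y) left  <- left  - w   # step out to the left
  while (log_f(right) > log_y) right <- right + w   # step out to the right
  repeat {
    x1 <- runif(1, left, right)
    if (log_f(x1) > log_y) return(x1)       # accept a point inside the slice
    if (x1 < x0) left <- x1 else right <- x1  # otherwise shrink the interval
  }
}

set.seed(1)
alpha <- rnorm(50, 0, 1.5)                  # stand-in random intercepts
n_draws <- 2000L
tau_draws <- numeric(n_draws)
tau <- 1
for (s in seq_len(n_draws)) {
  tau <- slice_step(tau, function(t) log_post_tau(t, alpha), w = 1)
  tau_draws[s] <- tau
}
summary(tau_draws)
```

In a full sampler the intercepts would themselves be redrawn at every iteration; they are held fixed here only to keep the sketch short.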
Predicting random effects for groups not in the training sample is supported by sampling from their posterior predictive distribution; that is, a draw is taken from p(\alpha \mid y) = \int p(\alpha \mid \tau) p(\tau \mid y) d\tau. For out-of-sample groups in the test data, these random effect draws can be kept with the saved object. Those supplied to predict cannot be kept and may change between calls.
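The posterior predictive draw for an unseen group can be sketched in base R: for each posterior draw of \tau, sample one intercept from N(0, \tau^2). The tau_draws vector below is a made-up stand-in for posterior samples, not output from a fitted model, and the variable names are hypothetical.

```r
set.seed(2)

# ASSUMPTION: stand-in for posterior draws of tau (one per saved sample).
tau_draws <- abs(rnorm(1000, mean = 1.5, sd = 0.2))

# One predictive draw of the new group's intercept per posterior draw of tau.
alpha_new <- rnorm(length(tau_draws), mean = 0, sd = tau_draws)

# Repeating the step gives different draws, which is why predictions for
# unseen groups can differ from one call to the next.
alpha_new_again <- rnorm(length(tau_draws), mean = 0, sd = tau_draws)

# Interval summary of the predictive distribution for the unseen group.
quantile(alpha_new, c(0.05, 0.5, 0.95))
```

Because the draws are centered at zero, predictions for an unseen group differ from the BART component only by added uncertainty of scale \tau.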
See the generics section of bart.
An object of class rbart. Contains all of the same elements of an object of class bart, as well as the elements:
ranef |
Samples from the posterior of the random effects. An array/matrix of posterior samples. The |
ranef.mean |
Posterior mean of the random effects, obtained by averaging the posterior samples for each group index. |
tau |
Matrix of posterior samples of |
first.tau |
Burn-in draws of |
callback |
Optional results of |
Vincent Dorie: vdorie@gmail.com
bart, dbarts
## Friedman-style test function; only the first five columns matter
f <- function(x) {
  10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
    10 * x[,4] + 5 * x[,5]
}

set.seed(99)
sigma <- 1.0
n <- 100

x <- matrix(runif(n * 10), n, 10)
Ey <- f(x)
y <- rnorm(n, Ey, sigma)

## assign each observation to one of n.g groups and add a group intercept
n.g <- 10
g <- sample(n.g, length(y), replace = TRUE)
sigma.b <- 1.5
b <- rnorm(n.g, 0, sigma.b)

y <- y + b[g]

df <- as.data.frame(x)
colnames(df) <- paste0("x_", seq_len(ncol(x)))
df$y <- y
df$g <- g

## low numbers to reduce run time
rbartFit <- rbart_vi(y ~ . - g, df, group.by = g,
                     n.samples = 40L, n.burn = 10L, n.thin = 2L,
                     n.chains = 1L,
                     n.trees = 25L, n.threads = 1L)
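A fit like the one above might then be queried with the methods listed in the usage section. The following is a self-contained sketch rather than part of the original example: it refits a smaller, hypothetical data set, passes keepTrees = TRUE explicitly on the assumption that predict needs the stored trees, and makes no claim about the exact shape of the returned objects.

```r
library(dbarts)

set.seed(99)
n <- 100
n.g <- 10

# Small simulated data set (hypothetical; simpler than the example above).
x <- matrix(runif(n * 3), n, 3)
g <- sample(n.g, n, replace = TRUE)
b <- rnorm(n.g, 0, 1.5)
df <- as.data.frame(x)
colnames(df) <- paste0("x_", seq_len(ncol(x)))
df$y <- rnorm(n, 10 * df$x_1 + b[g], 1)
df$g <- g

## low numbers to reduce run time
fit <- rbart_vi(y ~ . - g, df, group.by = g,
                n.samples = 40L, n.burn = 10L, n.thin = 2L,
                n.chains = 1L, n.trees = 25L, n.threads = 1L,
                keepTrees = TRUE, verbose = FALSE)

# Fitted random intercepts and their posterior samples.
ranef_mean  <- fitted(fit, type = "ranef")
ranef_draws <- extract(fit, type = "ranef")

# Posterior samples of the random effect standard deviation.
tau_draws <- fit$tau

# Posterior draws of the expected value for a few rows of known groups.
newdata  <- df[1:5, ]
ev_draws <- predict(fit, newdata = newdata, group.by = newdata$g, type = "ev")
```

With so few samples the results are only useful for checking that the calls run; a real analysis would use far more samples and several chains.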