View source: R/predict.mbart.R
predict.mbart | R Documentation |
BART is a Bayesian “sum-of-trees” model.
For a numeric response y, we have y = f(x) + \epsilon, where \epsilon \sim N(0,\sigma^2).
f is the sum of many tree models.
The goal is to have very flexible inference for the unknown function f.
In the spirit of “ensemble models”, each tree is constrained by a prior to be a weak learner so that it contributes only a small amount to the overall fit.
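For reference, the sum-of-trees structure described above is commonly written as follows in the BART literature. The symbols m, T_j, M_j and g are notation assumed here for illustration; they do not appear elsewhere on this page.

```latex
y = f(x) + \epsilon, \qquad
f(x) = \sum_{j=1}^{m} g(x; T_j, M_j), \qquad
\epsilon \sim N(0, \sigma^2)
```

Here T_j denotes the j-th tree, M_j its set of leaf values, and g(x; T_j, M_j) the leaf value assigned to x by that tree.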
## S3 method for class 'mbart'
predict(object, newdata, mc.cores=1, openmp=(mc.cores.openmp()>0), ...)
## S3 method for class 'mbart2'
predict(object, newdata, mc.cores=1, openmp=(mc.cores.openmp()>0), ...)
object |
|
newdata |
Matrix of covariates to predict the distribution of |
mc.cores |
Number of threads to utilize. |
openmp |
Logical value dictating whether OpenMP is utilized for parallel
processing. Of course, this depends on whether OpenMP is available
on your system which, by default, is verified with |
... |
Other arguments which will be passed on to |
BART is a Bayesian MCMC method.
At each MCMC iteration, we produce a draw from the joint posterior (f,\sigma) | (x,y) in the numeric y case and just f in the binary y case.
Thus, unlike many other modelling methods in R, we do not produce a single model object from which fits and summaries may be extracted.
The output consists of values f^*(x) (and \sigma^* in the numeric case), where * denotes a particular draw.
The x is either a row from the training data (x.train) or the test data (x.test).
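Because the output is a collection of draws rather than a single fit, summaries are computed across draws. The following is a minimal base-R sketch of that idea, assuming the draws are arranged with one row per posterior draw and one column per covariate row. The objects f.true and f.draws are simulated stand-ins, not output from this package.

```r
## Simulated stand-in for posterior draws of f at 3 covariate rows
## (hypothetical data; a real fit would supply these draws)
set.seed(1)
ndpost <- 1000                      ## number of posterior draws
f.true <- c(-1, 0, 2)               ## made-up centers for the 3 rows
f.draws <- matrix(rnorm(ndpost * 3, mean = rep(f.true, each = ndpost), sd = 0.1),
                  nrow = ndpost, ncol = 3)

## posterior mean of f at each covariate row: average over draws
f.mean <- colMeans(f.draws)

## 95% posterior interval at each covariate row: quantiles over draws
f.ci <- apply(f.draws, 2, quantile, probs = c(0.025, 0.975))

f.mean
f.ci
```

The same column-wise pattern applies to any functional of the draws, such as a difference between two covariate rows computed draw by draw before summarizing.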
Returns an object of type mbart with predictions corresponding to newdata.
mbart, mbart2
## load the advanced lung cancer example
data(lung)
group <- -which(is.na(lung[ , 7])) ## remove missing row for ph.karno
times <- lung[group, 2] ##lung$time
delta <- lung[group, 3]-1 ##lung$status: 1=censored, 2=dead
##delta: 0=censored, 1=dead
## this study reports time in days rather than months like other studies
## coarsening from days to months will reduce the computational burden
times <- ceiling(times/30)
summary(times)
table(delta)
x.train <- as.matrix(lung[group, c(4, 5, 7)]) ## matrix of observed covariates
## lung$age: Age in years
## lung$sex: Male=1 Female=2
## lung$ph.karno: Karnofsky performance score (dead=0:normal=100:by=10)
## rated by physician
dimnames(x.train)[[2]] <- c('age(yr)', 'M(1):F(2)', 'ph.karno(0:100:10)')
summary(x.train[ , 1])
table(x.train[ , 2])
table(x.train[ , 3])
x.test <- matrix(nrow=84, ncol=3) ## matrix of covariate scenarios
dimnames(x.test)[[2]] <- dimnames(x.train)[[2]]
i <- 1
for(age in 5*(9:15)) for(sex in 1:2) for(ph.karno in 10*(5:10)) {
    x.test[i, ] <- c(age, sex, ph.karno)
    i <- i+1
}
## this x.test is relatively small, but often you will want to
## predict for a large x.test matrix which may cause problems
## due to consumption of RAM so we can predict separately
## mcparallel/mccollect do not exist on windows
if(.Platform$OS.type=='unix') {
    ##test BART with token run to ensure installation works
    set.seed(99)
    post <- surv.bart(x.train=x.train, times=times, delta=delta, nskip=5, ndpost=5, keepevery=1)
    pre <- surv.pre.bart(x.train=x.train, times=times, delta=delta, x.test=x.test)
    pred <- predict(post, pre$tx.test)
    ##pred. <- surv.pwbart(pre$tx.test, post$treedraws, post$binaryOffset)
}
## Not run:
## run one long MCMC chain in one process
set.seed(99)
post <- surv.bart(x.train=x.train, times=times, delta=delta)
## run "mc.cores" number of shorter MCMC chains in parallel processes
## post <- mc.surv.bart(x.train=x.train, times=times, delta=delta,
## mc.cores=5, seed=99)
pre <- surv.pre.bart(x.train=x.train, times=times, delta=delta, x.test=x.test)
pred <- predict(post, pre$tx.test)
## let's look at some survival curves
## first, a younger group with a healthier KPS
## age 50 with KPS=90: males and females
## males: row 17, females: row 23
x.test[c(17, 23), ]
low.risk.males <- 16*post$K+1:post$K ## K=unique times including censoring
low.risk.females <- 22*post$K+1:post$K
plot(post$times, pred$surv.test.mean[low.risk.males], type='s', col='blue',
main='Age 50 with KPS=90', xlab='t', ylab='S(t)', ylim=c(0, 1))
points(post$times, pred$surv.test.mean[low.risk.females], type='s', col='red')
## End(Not run)
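The index arithmetic in the example relies on the predictions for each x.test row occupying K consecutive positions, one per time point, so that row r maps to positions (r-1)*K + 1:K. The sketch below illustrates that layout in base R without fitting a model. The value of K, the helper block() and the use of expand.grid are assumptions made for illustration; the example itself takes K from post$K and builds x.test with a loop.

```r
## Rebuild the 84-row scenario grid. expand.grid varies its first factor
## fastest, so ph.karno is listed first to match the loop order in the
## example (age slowest, ph.karno fastest).
grid <- expand.grid(ph.karno = 10 * (5:10), sex = 1:2, age = 5 * (9:15))
x.grid <- as.matrix(grid[, c("age", "sex", "ph.karno")])

x.grid[c(17, 23), ]   ## age 50 with KPS=90: male (row 17), female (row 23)

## K is hypothetical here; in the example it comes from post$K
K <- 4

## positions of the K time points belonging to one covariate row
block <- function(row, K) (row - 1) * K + 1:K

block(17, K)   ## same positions as 16*K + 1:K
block(23, K)   ## same positions as 22*K + 1:K
```

Writing the offset as (row - 1) * K makes the link between the x.test row number and the multipliers 16 and 22 in the example explicit.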