Here we implement a simple, direct and very flexible approach to using BART in survival analysis, akin to discrete-time survival analysis. Building on the capabilities of BART, we allow maximum flexibility in modeling the dependence of survival times on covariates; in particular, we do not impose proportional hazards.
To elaborate, consider data in the usual form (t, delta, x), where t is the event or right-censoring time, delta is an indicator distinguishing events (delta=1) from right-censoring (delta=0), x is a vector of covariates, and i=1, ..., N indexes subjects (the subscript i is suppressed for convenience).
We denote the K distinct event/censoring times by 0 < t(1) < ... < t(K) < infinity, so that t(j) is the j'th order statistic among the distinct observation times; for convenience, t(0)=0. Now consider event indicators y(j) for each subject i at each distinct time t(j) up to and including the subject's observation time t=t(n), where n = sum_j I[t(j) <= t]. Thus y(j)=0 for j<n and y(n)=delta.
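The expansion of each subject's (t, delta, x) into the indicators y(j) can be sketched in R as below. expand_surv() is a hypothetical helper written for illustration only; it is not a survbart function, and the package's internal data layout may differ.

```r
## Hypothetical helper (not part of survbart): expand (times, delta, x) into
## one row per subject per distinct time t(j) <= t, with indicator y(j).
expand_surv <- function(times, delta, x) {
  x <- as.matrix(x)
  grid <- sort(unique(times))              # t(1) < ... < t(K)
  rows <- lapply(seq_along(times), function(i) {
    n <- sum(grid <= times[i])             # n = sum_j I[t(j) <= t]
    y <- c(rep(0, n - 1), delta[i])        # y(j) = 0 for j < n, y(n) = delta
    data.frame(id = i, t = grid[1:n], y = y,
               x[rep(i, n), , drop = FALSE], row.names = NULL)
  })
  do.call(rbind, rows)
}

d <- expand_surv(times = c(3, 1, 3, 2), delta = c(1, 0, 0, 1),
                 x = matrix(c(10, 20, 30, 40), ncol = 1,
                            dimnames = list(NULL, "age")))
```

Here the distinct times are 1, 2, 3, so the first subject (t = 3, delta = 1) contributes three rows with y = 0, 0, 1, while the censored second subject (t = 1) contributes a single row with y = 0.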
We then denote by p(j) the probability of an event at time t(j), conditional on no previous event. We write the model for y(j) as a nonparametric probit regression of y(j) on the time t(j) and the covariates x, and then use BART for binary responses. Specifically, with y(j) = delta I[t=t(j)], j=1, ..., n, we have p(j) = F(mu(j)) and mu(j) = mu0 + f(t(j), x), where F denotes the standard normal cdf (probit link). As in the binary response case, f is a sum of many tree models.
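Because p(j) is the probability of an event at t(j) given no earlier event (a discrete-time hazard), the survival function at the distinct times and each subject's contribution to the probit likelihood follow directly from the definitions above. Stated compactly, using the same F, mu0 and f:

```latex
\begin{aligned}
p(j) &= F\bigl(\mu_0 + f(t_{(j)}, x)\bigr), \\
S\bigl(t_{(j)} \mid x\bigr) &= \prod_{l=1}^{j} \bigl(1 - p(l)\bigr), \\
\prod_{j=1}^{n} p(j)^{y(j)} \bigl(1 - p(j)\bigr)^{1 - y(j)}
  &= p(n)^{\delta} \bigl(1 - p(n)\bigr)^{1 - \delta} \prod_{j=1}^{n-1} \bigl(1 - p(j)\bigr).
\end{aligned}
```

The last line uses y(j)=0 for j<n and y(n)=delta.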
surv.bart(x.train, y.train=NULL, times=NULL, delta=NULL, x.test=matrix(0.0,0,0),
          k=2.0, power=2.0, base=.95, binaryOffset=NULL,
          ntree=50, ndpost=10000, nskip=250, printevery=100, keepevery=10,
          keeptrainfits=TRUE, usequants=FALSE, numcut=100, printcutoffs=0,
          verbose=TRUE,
          seed=99,    ## mc.surv.bart only
          mc.cores=2  ## mc.surv.bart only
          )

mc.surv.bart(x.train, y.train=NULL, times=NULL, delta=NULL, x.test=matrix(0.0,0,0),
             k=2.0, power=2.0, base=.95, binaryOffset=NULL,
             ntree=50, ndpost=10000, nskip=250, printevery=100, keepevery=10,
             keeptrainfits=TRUE, usequants=FALSE, numcut=100, printcutoffs=0,
             verbose=TRUE,
             seed=99,    ## mc.surv.bart only
             mc.cores=2  ## mc.surv.bart only
             )
x.train: Explanatory variables for training (in sample) data.
y.train: Binary response dependent variable for training (in sample) data.
times: The time of event or right-censoring.
delta: The event indicator: 1 is an event while 0 is censored.
x.test: Explanatory variables for test (out of sample) data.
k: k is the number of prior standard deviations f(t, x) is away from +/-3. The bigger k is, the more conservative the fitting will be.
power: Power parameter for tree prior.
base: Base parameter for tree prior.
binaryOffset: The model is P(Y=1 | t, x) = F(f(t, x) + mu0), where mu0 is specified by binaryOffset.
ntree: The number of trees in the sum.
ndpost: The number of posterior draws after burn-in; ndpost/keepevery of them are actually returned.
nskip: Number of MCMC iterations to be treated as burn-in.
printevery: As the MCMC runs, a message is printed every printevery draws.
keepevery: Every keepevery'th draw is kept to be returned to the user.
keeptrainfits: If TRUE, the draws of f(t, x) for x = rows of x.train are returned.
usequants: Decision rules in the tree are of the form x <= c vs. x > c for each variable corresponding to a column of x.train. usequants controls how the set of possible c is chosen. If usequants is TRUE, the c are a subset of the values (xs[i]+xs[i+1])/2, where xs is the sorted unique values of the corresponding column of x.train. If usequants is FALSE, the cutoffs are equally spaced across the range of values taken on by the corresponding column of x.train.
numcut: The number of possible values of c (see usequants). If a single number is given, it is used for all variables. Otherwise a vector with length equal to ncol(x.train) is required, where the i^th element gives the number of c used for the i^th variable in x.train. If usequants is FALSE, numcut equally spaced cutoffs are used, covering the range of values in the corresponding column of x.train. If usequants is TRUE, then min(numcut, the number of unique values in the corresponding column of x.train - 1) c values are used.
printcutoffs: The number of cutoff rules c to be printed to the screen before the MCMC is run. If a single integer is given, the same value is used for all variables. If 0, nothing is printed.
verbose: Logical; if FALSE, suppress printing.
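seed: Random number seed for the parallel MCMC chains; used by mc.surv.bart only (for surv.bart, call set.seed() beforehand, as in the example).
mc.cores: The number of parallel processes, each running a shorter MCMC chain; used by mc.surv.bart only.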
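The two cutoff schemes described under usequants and numcut can be sketched as follows. cutoff_grid() is a hypothetical helper, not the package's code. Two details the text leaves open are assumptions here: the numcut equally spaced cutoffs are placed strictly inside the range, and when there are more candidate midpoints than numcut they are thinned at evenly spaced positions.

```r
## Hypothetical sketch of the cutoff grids for one column x of x.train.
cutoff_grid <- function(x, numcut = 100, usequants = FALSE) {
  if (!usequants) {
    ## ASSUMPTION: numcut equally spaced cutoffs strictly inside range(x)
    r <- range(x)
    return(r[1] + (1:numcut) * diff(r) / (numcut + 1))
  }
  xs <- sort(unique(x))
  mids <- (xs[-1] + xs[-length(xs)]) / 2     # (xs[i] + xs[i+1]) / 2
  m <- min(numcut, length(xs) - 1)           # number of c values used
  ## ASSUMPTION: keep m midpoints at (roughly) evenly spaced positions
  mids[unique(round(seq(1, length(mids), length.out = m)))]
}

cutoff_grid(c(0, 10), numcut = 4)                        # 2 4 6 8
cutoff_grid(c(1, 2, 2, 4), numcut = 100, usequants = TRUE) # 1.5 3
```

With usequants=TRUE and few unique values, every midpoint is a candidate cutoff, which is why the second call returns only two values despite numcut=100.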
surv.bart returns a list. Besides the items listed below, the list has a binaryOffset component giving the value used, a times component giving the unique times, and a K component giving the number of unique times.
yhat.train: A matrix with (ndpost/keepevery) rows and nrow(x.train) columns. Each row corresponds to a draw f* from the posterior of f, and each column corresponds to a row of x.train. The (i,j) value is f*(t, x) for the i^th kept draw of f and the j^th row of x.train.
yhat.test: Same as yhat.train, but now the x's are the rows of the test data.
surv.test: The survival function, S(t|x), where the x's are the rows of the test data.
yhat.train.mean: Train data fits = mean of yhat.train columns.
yhat.test.mean: Test data fits = mean of yhat.test columns.
surv.test.mean: Mean of surv.test columns.
varcount: A matrix with (ndpost/keepevery) rows and one column per variable. Each row is for a draw. For each variable (corresponding to the columns), the total count of the number of times that variable is used in a tree decision rule (over all trees) is given.
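One common summary of varcount is the share of all decision rules in each draw that split on each variable, averaged over draws, as a rough indication of variable importance. A minimal sketch on a made-up 3-draw, 2-variable matrix standing in for post$varcount (the numbers and column names are illustrative only):

```r
## Made-up stand-in for post$varcount: 3 kept draws (rows) x 2 variables (columns)
varcount <- matrix(c(5, 7, 6,
                     15, 13, 14), nrow = 3,
                   dimnames = list(NULL, c("age", "ph.karno")))
## share of all tree decision rules in each draw that use each variable
varprob <- varcount / rowSums(varcount)
## posterior mean share per variable
colMeans(varprob)  # age 0.3, ph.karno 0.7
```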
Note that yhat.train and yhat.test are f(t, x) + binaryOffset. If you want draws of the probability P(Y=1 | t, x), you need to apply the normal cdf (pnorm) to these values.
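Since P(Y=1 | t(j), x) is the discrete-time hazard p(j), the survival function at t(j) is the running product of 1 - p(l) for l <= j. A minimal sketch of that step follows; surv_from_yhat() is a hypothetical helper (surv.bart already returns surv.test), and the assumption that each test row occupies K contiguous columns ordered t(1), ..., t(K) is inferred from the indexing 16*post$K + 1:post$K in the example, not from documented behavior.

```r
## Hypothetical helper, not part of survbart: survival draws from draws of
## mu = f(t, x) + binaryOffset. ASSUMPTION: each test row occupies K
## contiguous columns for times t(1), ..., t(K).
surv_from_yhat <- function(yhat, K) {
  yhat <- as.matrix(yhat)
  q <- 1 - pnorm(yhat)                 # 1 - p(j), where p(j) = F(mu(j))
  S <- q                               # S(t(1) | x) = 1 - p(1)
  for (b in seq_len(ncol(q) %/% K)) {
    first <- (b - 1) * K + 1           # column holding t(1) for test row b
    if (K > 1) for (j in 2:K) {
      col <- first + j - 1
      S[, col] <- S[, col - 1] * q[, col]   # S(t(j)) = S(t(j-1)) * (1 - p(j))
    }
  }
  S
}

## with a fitted object, under the column-layout assumption:
## S <- surv_from_yhat(post$yhat.test, post$K)
```

For example, if every draw has mu = 0 then p(j) = 0.5 at each time, and each block of S reads 0.5, 0.25, 0.125, ...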
Rodney Sparapani: rsparapa@mcw.edu
Sparapani, R., Logan, B., McCulloch, R., and Laud, P. (2016) Nonparametric survival analysis using Bayesian Additive Regression Trees (BART). Statistics in Medicine, in press.
## Not run:
require(survbart)
## load survival package for the advanced lung cancer example
require(survival)
group <- which(!is.na(lung[ , 7])) ## keep rows with non-missing ph.karno
times <- lung[group, 2] ##lung$time
delta <- lung[group, 3]-1 ##lung$status: 1=censored, 2=dead
##delta: 0=censored, 1=dead
## this study reports time in days rather than months like other studies
## coarsening from days to months will reduce the computational burden
times <- ceiling(times/30)
summary(times)
table(delta)
x.train <- as.matrix(lung[group, c(4, 5, 7)]) ## matrix of observed covariates
## lung$age: Age in years
## lung$sex: Male=1 Female=2
## lung$ph.karno: Karnofsky performance score (dead=0:normal=100:by=10)
## rated by physician
dimnames(x.train)[[2]] <- c('age(yr)', 'M(1):F(2)', 'ph.karno(0:100:10)')
summary(x.train[ , 1])
table(x.train[ , 2])
table(x.train[ , 3])
x.test <- matrix(nrow=84, ncol=3) ## matrix of covariate scenarios
dimnames(x.test)[[2]] <- dimnames(x.train)[[2]]
i <- 1
for(age in 5*(9:15)) for(sex in 1:2) for(ph.karno in 10*(5:10)) {
x.test[i, ] <- c(age, sex, ph.karno)
i <- i+1
}
## run one long MCMC chain in one process
set.seed(99)
post <- surv.bart(x.train=x.train, times=times, delta=delta, x.test=x.test)
## run "mc.cores" number of shorter MCMC chains in parallel processes
## post <- mc.surv.bart(x.train=x.train, times=times, delta=delta, x.test=x.test,
## mc.cores=20, seed=99)
##saveRDS(object=post, file='post.rds')
## you can save time by reading in the posterior
## instead of re-generating it every time
## post <- readRDS(file='post.rds')
## let's look at some survival curves
## first, a younger group with a healthier KPS
## age 50 with KPS=90: males and females
## males: row 17, females: row 23
x.test[c(17, 23), ]
low.risk.males <- 16*post$K+1:post$K ## K=unique times including censoring
low.risk.females <- 22*post$K+1:post$K
## second, an older group with a poor KPS
## age 70 with KPS=60: males and females
x.test[c(62, 68), ]
high.risk.males <- 61*post$K+1:post$K
high.risk.females <- 67*post$K+1:post$K
old.par <- par(mfrow=c(2, 2))
plot(post$times, post$surv.test.mean[low.risk.males], type='s', col='blue',
main='Age 50 with KPS=90',
xlab='t', ylab='S(t)', ylim=c(0, 1))
points(post$times, post$surv.test.mean[low.risk.females], type='s', col='red')
plot(post$times, post$surv.test.mean[high.risk.males], type='s', col='blue',
main='Age 70 with KPS=60',
xlab='t', ylab='S(t)', ylim=c(0, 1))
points(post$times, post$surv.test.mean[high.risk.females], type='s', col='red')
plot(post$times, post$surv.test.mean[low.risk.males], type='s', col='blue',
main='Males',
xlab='t', ylab='S(t)', ylim=c(0, 1))
points(post$times, post$surv.test.mean[high.risk.males], type='s')
plot(post$times, post$surv.test.mean[low.risk.females], type='s', col='red',
main='Females',
xlab='t', ylab='S(t)', ylim=c(0, 1))
points(post$times, post$surv.test.mean[high.risk.females], type='s')
par(old.par)
## calculate pointwise 95% credible intervals for females aged 50 with KPS=90
credible95 <- apply(post$surv.test[ , low.risk.females], 2, quantile,
probs=c(0.025, 0.975))
plot(post$times, post$surv.test.mean[low.risk.females], type='s', col='red',
main='Females aged 50/KPS=90',
xlab='t', ylab='S(t)', ylim=c(0, 1))
points(post$times, credible95[1, ], type='s')
points(post$times, credible95[2, ], type='s')
## End(Not run)
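The example locates its scenarios by hard-coding rows 17, 23, 62 and 68 of x.test and the column offsets 16*post$K + 1:post$K. The sketch below derives both from the scenario values instead. scenario_row() and scenario_cols() are hypothetical helpers, and scenario_cols() assumes K contiguous columns per test row, which is what the example's indexing relies on.

```r
## Same 84 scenarios as the nested for() loops in the example: expand.grid()
## varies its first argument fastest, so ph.karno changes fastest, age slowest.
g <- expand.grid(ph.karno = 10 * (5:10), sex = 1:2, age = 5 * (9:15))
x.test <- as.matrix(g[, c("age", "sex", "ph.karno")])

## Hypothetical helpers (not part of survbart). Columns are used by position,
## so they also work with the example's x.test and its longer column names.
scenario_row <- function(x.test, age, sex, karno)
  which(x.test[, 1] == age & x.test[, 2] == sex & x.test[, 3] == karno)

## Columns of surv.test / surv.test.mean for a test row, assuming K contiguous
## columns per row (as in 16*post$K + 1:post$K).
scenario_cols <- function(row, K) (row - 1) * K + 1:K

scenario_row(x.test, 50, 1, 90)  # 17 (males aged 50, KPS=90)
scenario_row(x.test, 70, 2, 60)  # 68 (females aged 70, KPS=60)
```

With the fitted object, scenario_cols(scenario_row(x.test, 50, 2, 90), post$K) gives the same indices as low.risk.females.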