surv.bart: Nonparametric survival analysis with BART


Description

Here we implement a simple and direct approach to using BART in survival analysis; it is very flexible and akin to discrete-time survival analysis. Following the capabilities of BART, we allow for maximum flexibility in modeling the dependence of survival times on covariates. In particular, we do not impose proportional hazards.

To elaborate, consider data in the usual form: (t, delta, x) where t is the observed time (event or right-censoring), delta is an indicator distinguishing events (delta=1) from right-censoring (delta=0), x is a vector of covariates, and i=1, ..., N (i suppressed for convenience) indexes subjects.

We denote the K distinct event/censoring times by 0<t(1)<...<t(K)<infinity, so that t(j) is the j'th order statistic among the distinct observation times; for convenience, t(0)=0. Now consider event indicators y(j) for each subject i at each distinct time t(j) up to and including the subject's observation time t=t(n), where n is the number of distinct times not exceeding t, i.e. n=sum I[t(j)<=t] with the sum taken over j=1, ..., K. This means y(j)=0 if j<n and y(n)=delta.
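The construction of the indicators y(j) described above can be sketched in base R. This is only an illustration on made-up toy data: the helper expand.subject and the column names are hypothetical, and it is not claimed to reproduce the exact output layout of surv.pre.bart.

```r
## Toy data (made up for illustration): observation time, event indicator, one covariate
times <- c(2, 5, 3, 5)
delta <- c(1, 0, 1, 1)
x     <- c(60, 72, 55, 68)

## Distinct ordered event/censoring times t(1) < ... < t(K)
grid <- sort(unique(times))
K    <- length(grid)

## Hypothetical helper: one row per subject per grid time up to the observation time
expand.subject <- function(i) {
    n <- sum(grid <= times[i])                   ## n = sum I[t(j) <= t]
    data.frame(id = i,
               t  = grid[seq_len(n)],            ## t(1), ..., t(n)
               x  = x[i],
               y  = c(rep(0, n - 1), delta[i]))  ## y(j)=0 for j<n, y(n)=delta
}

long <- do.call(rbind, lapply(seq_along(times), expand.subject))
long
```

Each subject contributes n rows, so the binary regression is fit to sum of n over subjects rows rather than N rows; this is why coarsening the time scale reduces the computational burden.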

We then denote by p(j) the probability of an event at time t(j) conditional on no previous event. We now write the model for y(j) as a nonparametric probit regression of y(j) on the time t(j) and the covariates x, and then utilize BART for binary responses. Specifically, y(j) = delta I[t=t(j)], j=1, ..., n ; we have p(j) = F(mu(j)), mu(j) = mu0+f(t(j), x) where F denotes the standard normal cdf (probit link). As in the binary response case, f is the sum of many tree models.
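Because p(j) is the probability of an event at t(j) given no previous event, the survival function on the time grid follows as S(t(j)|x) = (1-p(1)) * ... * (1-p(j)). A minimal sketch with made-up values of mu(j), which are not output from any fitted model:

```r
## Made-up values of mu(j) = mu0 + f(t(j), x) on a grid of K = 4 times
mu <- c(-1.8, -1.5, -1.2, -1.0)

p <- pnorm(mu)        ## p(j) = F(mu(j)), probit link
S <- cumprod(1 - p)   ## S(t(j)|x) = product over l <= j of (1 - p(l))

cbind(j = seq_along(mu), p = p, S = S)
```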

Usage

surv.bart( x.train, y.train=NULL, times=NULL, delta=NULL, x.test=matrix(0.0,0,0),
           k=2.0, power=2.0, base=.95, binaryOffset=NULL,
           ntree=50, ndpost=10000, nskip=250, printevery=100, keepevery=10,
           keeptrainfits=TRUE, usequants=FALSE, numcut=100, printcutoffs=0,
           verbose=TRUE,
           seed=99,   ## mc.surv.bart only
           mc.cores=2 ## mc.surv.bart only
         )

mc.surv.bart( x.train, y.train=NULL, times=NULL, delta=NULL, x.test=matrix(0.0,0,0),
              k=2.0, power=2.0, base=.95, binaryOffset=NULL,
              ntree=50, ndpost=10000, nskip=250, printevery=100, keepevery=10,
              keeptrainfits=TRUE, usequants=FALSE, numcut=100, printcutoffs=0,
              verbose=TRUE,
              seed=99,   ## mc.surv.bart only
              mc.cores=2 ## mc.surv.bart only
            )

Arguments

x.train

Explanatory variables for training (in sample) data.
Must be a matrix with (as usual) rows corresponding to observations and columns to variables.
surv.bart will generate draws of f(t, x) for each x which is a row of x.train (note that the definition of x.train is dependent on whether y.train has been specified; see below).

y.train

Binary response dependent variable for training (in sample) data.
If y.train is NULL, then y.train (and x.train and x.test, if specified) are generated by a call to surv.pre.bart (which requires that times and delta be provided: see below); otherwise, y.train (and x.train and x.test, if specified) are used as given, assuming that the data construction has already been performed.

times

The time of event or right-censoring.
If y.train is NULL, then times (and delta) must be provided.

delta

The event indicator: 1 is an event while 0 is censored.
If y.train is NULL, then delta (and times) must be provided.

x.test

Explanatory variables for test (out of sample) data.
Must be a matrix and have the same structure as x.train.
surv.bart will generate draws of f(t, x) for each x which is a row of x.test.

k

k is the number of prior standard deviations f(t, x) is away from +/-3. The bigger k is, the more conservative the fitting will be.

power

Power parameter for tree prior.

base

Base parameter for tree prior.

binaryOffset

The model is P(Y=1 | t, x) = F(f(t, x) + mu0) where mu0 is specified by binaryOffset.
The idea is that f is shrunk towards 0, so the offset allows you to shrink towards a probability other than .5.
If binaryOffset=NULL when times and delta were provided, then an exponential distribution offset is assumed independent of the covariates, i.e. binaryOffset=qnorm(1-exp(-mean.diff*sum(delta)/sum(times))) where mean.diff is the mean of the differences of the distinct ordered adjacent times, i.e. mean(t(1)-t(0), ..., t(K)-t(K-1)).
If binaryOffset=NULL when times and delta were not provided, then binaryOffset=0.
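The default offset formula above can be evaluated by hand. The sketch below uses made-up toy data and simply transcribes the stated expression; it is not extracted from the package source.

```r
## Toy data (made up for illustration)
times <- c(2, 5, 3, 5, 8)
delta <- c(1, 0, 1, 1, 0)

grid      <- sort(unique(times))       ## t(1) < ... < t(K)
mean.diff <- mean(diff(c(0, grid)))    ## mean of t(j) - t(j-1), with t(0) = 0
rate      <- sum(delta) / sum(times)   ## exponential rate: events per unit of follow-up

binaryOffset <- qnorm(1 - exp(-mean.diff * rate))

binaryOffset          ## offset on the probit scale
pnorm(binaryOffset)   ## implied per-interval event probability
```

Here 1 - exp(-mean.diff * rate) is the probability that an exponential event time falls within an interval of average length, so f is shrunk towards that baseline probability instead of .5.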

ntree

The number of trees in the sum.

ndpost

The number of posterior draws after burn-in. Only ndpost/keepevery draws will actually be returned.

nskip

Number of MCMC iterations to be treated as burn in.

printevery

As the MCMC runs, a message is printed every printevery draws.

keepevery

Every keepevery'th draw is kept to be returned to the user.
A “draw” will consist of values f*(t, x) at x = rows from the train (optionally) and test data, where f* denotes the current draw of f.

keeptrainfits

If true the draws of f(t, x) for x = rows of x.train are returned.

usequants

Decision rules in the tree are of the form x <= c vs. x > c for each variable corresponding to a column of x.train. usequants determines how the set of possible c is determined. If usequants is true, then the c are a subset of the values (xs[i]+xs[i+1])/2 where xs is the vector of unique sorted values obtained from the corresponding column of x.train. If usequants is false, the cutoffs are equally spaced across the range of values taken on by the corresponding column of x.train.

numcut

The number of possible values of c (see usequants). If a single number is given, it is used for all variables. Otherwise a vector with length equal to ncol(x.train) is required, where the i'th element gives the number of c used for the i'th variable in x.train. If usequants is false, numcut equally spaced cutoffs are used covering the range of values in the corresponding column of x.train. If usequants is true, then min(numcut, the number of unique values in the corresponding column of x.train - 1) c values are used.
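The two cutoff schemes can be illustrated for a single column. This is a sketch under stated assumptions: the documentation does not say which subset of midpoints is chosen when numcut is smaller than the number of midpoints, nor exactly where the equally spaced cutoffs sit relative to the range endpoints, so the example uses all midpoints and assumes interior points only. The names mid, n.cut, cuts.unif and numcut.u are illustrative.

```r
## Unique sorted values of one (made-up) column of x.train
xs     <- sort(unique(c(50, 60, 60, 70, 90, 100)))
numcut <- 100

## usequants = TRUE: candidates are midpoints between adjacent unique values
mid   <- (xs[-length(xs)] + xs[-1]) / 2   ## (xs[i] + xs[i+1]) / 2
n.cut <- min(numcut, length(xs) - 1)      ## number of c values actually used

## usequants = FALSE: equally spaced cutoffs across the range
## (ASSUMPTION: interior points only, endpoints excluded)
numcut.u  <- 4
cuts.unif <- seq(min(xs), max(xs), length.out = numcut.u + 2)[-c(1, numcut.u + 2)]

mid
n.cut
cuts.unif
```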

printcutoffs

The number of cutoff rules c to be printed to screen before the MCMC is run. Give a single integer; the same value will be used for all variables. If 0, nothing is printed.

verbose

Logical; if FALSE, suppress printing.

seed

mc.surv.bart only: seed required for reproducible MCMC.

mc.cores

mc.surv.bart only: number of cores to employ in parallel.

Value

surv.bart returns a list. Besides the items listed below, the list has a binaryOffset component giving the value used, a times component giving the unique times and K which is the number of unique times.

yhat.train

A matrix with (ndpost/keepevery) rows and nrow(x.train) columns. Each row corresponds to a draw f* from the posterior of f and each column corresponds to a row of x.train. The (i,j) value is f*(t, x) for the i'th kept draw of f and the j'th row of x.train.
Burn-in is dropped.

yhat.test

Same as yhat.train but now the x's are the rows of the test data.

surv.test

The survival function, S(t|x), where x's are the rows of the test data.

yhat.train.mean

train data fits = mean of yhat.train columns.

yhat.test.mean

test data fits = mean of yhat.test columns.

surv.test.mean

mean of surv.test columns.

varcount

A matrix with (ndpost/keepevery) rows and ncol(x.train) columns. Each row is for a draw. For each variable (corresponding to the columns), the total count of the number of times that variable is used in a tree decision rule (over all trees) is given.

Note that yhat.train and yhat.test are f(t, x) + binaryOffset. If you want draws of the probability P(Y=1 | t, x) you need to apply the normal cdf (pnorm) to these values.
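As a sketch of that transformation, the code below applies pnorm to a simulated stand-in for yhat.test and then accumulates survival within each test subject. It assumes that each subject occupies K contiguous columns, as the indexing in the Examples suggests; the simulated values and the names prob, surv and surv.mean are illustrative, and the code is not claimed to reproduce the package's own surv.test calculation.

```r
set.seed(1)
K       <- 3   ## number of distinct times (post$K in a real fit)
n.test  <- 2   ## number of test covariate settings
n.draws <- 5   ## number of kept posterior draws

## Simulated stand-in for post$yhat.test: draws in rows,
## columns ordered subject by subject with K times each (ASSUMPTION)
yhat.test <- matrix(rnorm(n.draws * n.test * K, mean = -1.5, sd = 0.3),
                    nrow = n.draws)

prob <- pnorm(yhat.test)   ## draws of P(Y=1 | t, x)

## Survival per draw: cumulative product of (1 - prob) within each subject's block
surv <- prob
for (h in seq_len(n.test)) {
    cols <- (h - 1) * K + 1:K
    surv[ , cols] <- t(apply(1 - prob[ , cols, drop = FALSE], 1, cumprod))
}

surv.mean <- colMeans(surv)   ## posterior mean survival at each (subject, time)
surv.mean
```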

Author(s)

Rodney Sparapani: rsparapa@mcw.edu

References

Sparapani, R., Logan, B., McCulloch, R., and Laud, P. (2016) Nonparametric survival analysis using Bayesian Additive Regression Trees (BART). Statistics in Medicine, in press.

See Also

surv.pre.bart

Examples

## Not run: 
require(survbart)

## load survival package for the advanced lung cancer example
require(survival)

group <- -which(is.na(lung[ , 7])) ## remove missing row for ph.karno
times <- lung[group, 2]   ##lung$time
delta <- lung[group, 3]-1 ##lung$status: 1=censored, 2=dead
                          ##delta: 0=censored, 1=dead

## this study reports time in days rather than months like other studies
## coarsening from days to months will reduce the computational burden
times <- ceiling(times/30)

summary(times)
table(delta)

x.train <- as.matrix(lung[group, c(4, 5, 7)]) ## matrix of observed covariates

## lung$age:        Age in years
## lung$sex:        Male=1 Female=2
## lung$ph.karno:   Karnofsky performance score (dead=0:normal=100:by=10)
##                  rated by physician

dimnames(x.train)[[2]] <- c('age(yr)', 'M(1):F(2)', 'ph.karno(0:100:10)')

summary(x.train[ , 1])
table(x.train[ , 2])
table(x.train[ , 3])

x.test <- matrix(nrow=84, ncol=3) ## matrix of covariate scenarios

dimnames(x.test)[[2]] <- dimnames(x.train)[[2]]

i <- 1

for(age in 5*(9:15)) for(sex in 1:2) for(ph.karno in 10*(5:10)) {
    x.test[i, ] <- c(age, sex, ph.karno)
    i <- i+1
}

## run one long MCMC chain in one process
set.seed(99)
post <- surv.bart(x.train=x.train, times=times, delta=delta, x.test=x.test)

## run "mc.cores" number of shorter MCMC chains in parallel processes
## post <- mc.surv.bart(x.train=x.train, times=times, delta=delta, x.test=x.test,
##                      mc.cores=20, seed=99)

##saveRDS(object=post, file='post.rds')
## you can save time by reading in the posterior
## instead of re-generating it every time
## post <- readRDS(file='post.rds')

## let's look at some survival curves
## first, a younger group with a healthier KPS
## age 50 with KPS=90: males and females
## males: row 17, females: row 23
x.test[c(17, 23), ]

low.risk.males <- 16*post$K+1:post$K ## K=unique times including censoring

low.risk.females <- 22*post$K+1:post$K

## second, an older group with a poor KPS
## age 70 with KPS=60: males and females
x.test[c(62, 68), ]

high.risk.males <- 61*post$K+1:post$K

high.risk.females <- 67*post$K+1:post$K

old.par <- par(mfrow=c(2, 2))

plot(post$times, post$surv.test.mean[low.risk.males], type='s', col='blue',
     main='Age 50 with KPS=90',
     xlab='t', ylab='S(t)', ylim=c(0, 1))
points(post$times, post$surv.test.mean[low.risk.females], type='s', col='red')

plot(post$times, post$surv.test.mean[high.risk.males], type='s', col='blue',
     main='Age 70 with KPS=60',
     xlab='t', ylab='S(t)', ylim=c(0, 1))
points(post$times, post$surv.test.mean[high.risk.females], type='s', col='red')

plot(post$times, post$surv.test.mean[low.risk.males], type='s', col='blue',
     main='Males',
     xlab='t', ylab='S(t)', ylim=c(0, 1))
points(post$times, post$surv.test.mean[high.risk.males], type='s')

plot(post$times, post$surv.test.mean[low.risk.females], type='s', col='red',
     main='Females',
     xlab='t', ylab='S(t)', ylim=c(0, 1))
points(post$times, post$surv.test.mean[high.risk.females], type='s')

par(old.par)

## calculate 95% credible intervals
credible95 <- apply(post$surv.test[ , low.risk.females], 2, quantile,
                    probs=c(0.025, 0.975))

plot(post$times, post$surv.test.mean[low.risk.females], type='s', col='red',
     main='Females aged 50/KPS=90',
     xlab='t', ylab='S(t)', ylim=c(0, 1))
points(post$times, credible95[1, ], type='s')
points(post$times, credible95[2, ], type='s')


## End(Not run)

survbart documentation built on May 2, 2019, 5:47 p.m.
