mc.crisk.pwbart: Predicting new observations with a previously fitted BART...

mc.crisk.pwbart    R Documentation

Predicting new observations with a previously fitted BART model

Description

BART is a Bayesian “sum-of-trees” model.
For a numeric response y, we have y = f(x) + \epsilon, where \epsilon \sim N(0,\sigma^2).

f is the sum of many tree models. The goal is to have very flexible inference for the unknown function f.

In the spirit of “ensemble models”, each tree is constrained by a prior to be a weak learner so that it contributes a small amount to the overall fit.
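The sum-of-trees structure described above can be written out as follows. The symbols m (number of trees), T_j (the j-th tree structure), and M_j (its terminal-node values) are standard BART notation but do not appear elsewhere on this page, so treat them as assumptions of this sketch:

```latex
y = f(x) + \epsilon, \qquad \epsilon \sim N(0, \sigma^2), \qquad
f(x) = \sum_{j=1}^{m} g(x;\, T_j, M_j)
```

Here g(x; T_j, M_j) denotes the value that tree j assigns to x; the prior keeps each such contribution small.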

Usage

mc.crisk.pwbart( x.test, x.test2,
                 treedraws, treedraws2,
                 binaryOffset=0, binaryOffset2=0,
                 mc.cores=2L, type='pbart',
                 transposed=FALSE, nice=19L
               )

Arguments

x.test

Matrix of covariates to predict y for cause 1.

x.test2

Matrix of covariates to predict y for cause 2.

treedraws

The $treedraws component of a previously fitted model for cause 1.

treedraws2

The $treedraws2 component of a previously fitted model for cause 2.

binaryOffset

Mean to add on to y prediction for cause 1.

binaryOffset2

Mean to add on to y prediction for cause 2.

mc.cores

Number of threads to utilize.

type

Which binary-outcome model to employ: 'pbart' (probit BART, Albert-Chib) or 'lbart' (logistic BART, Holmes-Held).

transposed

When running pwbart or mc.pwbart in parallel, it is more memory-efficient to transpose x.test prior to calling the internal versions of these functions.

nice

Set the job niceness. The default niceness is 19: niceness ranges from 0 (highest priority) to 19 (lowest priority).
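The type argument determines how a latent prediction is turned into a probability. The following is a minimal base-R sketch, assuming (as is conventional for probit and logit models) that the offset is added on the latent scale before the link is inverted. The values of latent and offset are illustrative and are not output of this package:

```r
## Illustrative latent draws of f(x); not produced by BART itself
latent <- c(-1.5, 0, 0.75)
offset <- 0  # plays the role of binaryOffset (assumption: added on the latent scale)

## 'pbart': probit link, inverted with the standard normal CDF
prob.probit <- pnorm(offset + latent)

## 'lbart': logit link, inverted with the logistic CDF
prob.logit <- plogis(offset + latent)

print(round(cbind(latent, prob.probit, prob.logit), 4))
```

Both links map a latent value of 0 to a probability of 0.5; they differ in how quickly the probability approaches 0 or 1 in the tails.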

Details

BART is a Bayesian MCMC method. At each MCMC iteration, we produce a draw from the joint posterior (f,\sigma) | (x,y) in the numeric y case and just f in the binary y case.

Thus, unlike a lot of other modelling methods in R, we do not produce a single model object from which fits and summaries may be extracted. The output consists of values f^*(x) (and \sigma^* in the numeric case) where * denotes a particular draw. The x is either a row from the training data (x.train) or the test data (x.test).
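Because the output is a collection of draws rather than a single fitted object, summaries are computed across the draws. The sketch below uses a simulated matrix as a stand-in for a draws matrix (rows are draws, columns are observations); the names draws, post.mean, and post.ci are illustrative only:

```r
set.seed(1)
ndpost <- 200  # number of kept draws (illustrative)
n.test <- 3    # number of test rows (illustrative)

## Stand-in for a draws matrix: rows are posterior draws, columns are test rows.
## Column k is simulated around the k-th value of c(-1, 0, 1).
draws <- matrix(rnorm(ndpost * n.test, mean = rep(c(-1, 0, 1), each = ndpost)),
                nrow = ndpost, ncol = n.test)

## Posterior mean for each test row: average over the draws (rows)
post.mean <- colMeans(draws)

## 95% posterior interval for each test row
post.ci <- apply(draws, 2, quantile, probs = c(0.025, 0.975))

print(rbind(mean = post.mean, post.ci))
```

The same column-wise pattern applies to any of the draws matrices returned by the function.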

Value

Returns an object of type criskbart which is essentially a list with components:

yhat.test

A matrix with ndpost rows and nrow(x.test) columns. Each row corresponds to a draw f^* from the posterior of f and each column corresponds to a row of x.test. The (i,j) value is f^*(x) for the i^{th} kept draw of f and the j^{th} row of x.test.
Burn-in is dropped.

surv.test

test data fits for survival probability.

surv.test.mean

mean of surv.test over the posterior samples.

prob.test

The probability of suffering an event from cause 1, which is occasionally useful, e.g., in calculating the concordance.

prob.test2

The probability of suffering an event from cause 2, which is occasionally useful, e.g., in calculating the concordance.

cif.test

The cumulative incidence function of cause 1, F_1(t, x), where x's are the rows of the test data.

cif.test2

The cumulative incidence function of cause 2, F_2(t, x), where x's are the rows of the test data.

yhat.test.mean

test data fits = mean of yhat.test columns.

cif.test.mean

mean of cif.test columns for cause 1.

cif.test2.mean

mean of cif.test2 columns for cause 2.
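To illustrate how survival and the two cumulative incidence functions relate, here is a sketch of one common discrete-time competing-risks construction. It is an assumption made for illustration and is not taken from this page: p1 is the probability of a cause 1 event at a grid time given no earlier event, and p2 is the probability of a cause 2 event at that time given no earlier event and no cause 1 event at that time. All names and values are hypothetical:

```r
## Hypothetical conditional probabilities on a grid of K = 4 time points
p1 <- c(0.10, 0.05, 0.08, 0.02)  # cause 1 at t_j, given event-free before t_j
p2 <- c(0.03, 0.04, 0.01, 0.06)  # cause 2 at t_j, given event-free and no cause 1 at t_j

## Probability of remaining event-free through t_j
surv <- cumprod((1 - p1) * (1 - p2))

## Event-free probability just before t_j (equal to 1 before the first time point)
surv.lag <- c(1, head(surv, -1))

## Cumulative incidence: accumulate the probability of each cause over time
cif1 <- cumsum(surv.lag * p1)
cif2 <- cumsum(surv.lag * (1 - p1) * p2)

print(round(cbind(surv, cif1, cif2, total = surv + cif1 + cif2), 4))
```

Under this construction the three quantities sum to 1 at every time point, which is a useful sanity check on returned values.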

See Also

pwbart, crisk.bart, mc.crisk.bart

Examples


data(transplant)

delta <- (as.numeric(transplant$event)-1)
## recode so that delta=1 is cause of interest; delta=2 otherwise
delta[delta==1] <- 4
delta[delta==2] <- 1
delta[delta>1] <- 2
table(delta, transplant$event)

times <- pmax(1, ceiling(transplant$futime/7)) ## weeks
##times <- pmax(1, ceiling(transplant$futime/30.5)) ## months
table(times)

typeO <- 1*(transplant$abo=='O')
typeA <- 1*(transplant$abo=='A')
typeB <- 1*(transplant$abo=='B')
typeAB <- 1*(transplant$abo=='AB')
table(typeA, typeO)

x.train <- cbind(typeO, typeA, typeB, typeAB)

x.test <- cbind(1, 0, 0, 0)
dimnames(x.test)[[2]] <- dimnames(x.train)[[2]]

## parallel::mcparallel/mccollect do not exist on windows
if(.Platform$OS.type=='unix') {
##test BART with token run to ensure installation works
        post <- mc.crisk.bart(x.train=x.train, times=times, delta=delta,
                               seed=99, mc.cores=2, nskip=5, ndpost=5,
                               keepevery=1)

        pre <- surv.pre.bart(x.train=x.train, x.test=x.test,
                             times=times, delta=delta)

        K <- post$K

        pred <- mc.crisk.pwbart(pre$tx.test, pre$tx.test,
                                post$treedraws, post$treedraws2,
                                post$binaryOffset, post$binaryOffset2)
}

## Not run: 

## run one long MCMC chain in one process
## set.seed(99)
## post <- crisk.bart(x.train=x.train, times=times, delta=delta, x.test=x.test)

## in the interest of time, consider speeding it up by parallel processing
## run "mc.cores" number of shorter MCMC chains in parallel processes
post <- mc.crisk.bart(x.train=x.train,
                       times=times, delta=delta,
                       x.test=x.test, seed=99, mc.cores=8)

check <- mc.crisk.pwbart(post$tx.test, post$tx.test,
                          post$treedraws, post$treedraws2,
                          post$binaryOffset,
                          post$binaryOffset2, mc.cores=8)
## check <- predict(post, newdata=post$tx.test, newdata2=post$tx.test2,
##                  mc.cores=8)

print(c(post$surv.test.mean[1], check$surv.test.mean[1],
        post$surv.test.mean[1]-check$surv.test.mean[1]), digits=22)

print(all(round(post$surv.test.mean, digits=9)==
    round(check$surv.test.mean, digits=9)))

print(c(post$cif.test.mean[1], check$cif.test.mean[1],
        post$cif.test.mean[1]-check$cif.test.mean[1]), digits=22)

print(all(round(post$cif.test.mean, digits=9)==
    round(check$cif.test.mean, digits=9)))

print(c(post$cif.test2.mean[1], check$cif.test2.mean[1],
        post$cif.test2.mean[1]-check$cif.test2.mean[1]), digits=22)

print(all(round(post$cif.test2.mean, digits=9)==
    round(check$cif.test2.mean, digits=9)))


## End(Not run)
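The checks above compare individual elements of the mean vectors. For inspection it can help to arrange such a vector with one row per test observation and one column per grid time. The sketch below assumes the vector holds K consecutive entries per row of x.test; that layout is an assumption and should be verified against your own fit. The data are made up, and cif.mean and cif.mat are illustrative names:

```r
K <- 4        # number of grid time points (post$K in the examples above)
n.test <- 2   # number of rows of x.test

## Stand-in for a vector such as cif.test.mean, assumed ordered subject by subject
cif.mean <- c(0.10, 0.15, 0.21, 0.23,   # test row 1 at times 1..K
              0.05, 0.09, 0.12, 0.20)   # test row 2 at times 1..K

## One row per test observation, one column per grid time
cif.mat <- matrix(cif.mean, nrow = n.test, ncol = K, byrow = TRUE)
print(cif.mat)
```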

BART documentation built on June 22, 2024, 11:33 a.m.