nft2: Fit NFT BART models.


nft2    R Documentation

Fit NFT BART models.

Description

The nft2() and nft() functions fit NFT BART (Nonparametric Failure Time Bayesian Additive Regression Tree) models. nft2() accepts different train/test matrices for the f and sd functions, whereas nft() takes a single train/test matrix.

Usage

nft2(
    ## data
    xftrain, xstrain, times, delta=NULL,
    xftest=matrix(nrow=0, ncol=0),
    xstest=matrix(nrow=0, ncol=0),
    rm.const=TRUE, rm.dupe=TRUE,
    ## multi-threading
    tc=getOption("mc.cores", 1), 
    ##MCMC
    nskip=1000, ndpost=2000, nadapt=1000, adaptevery=100,
    chvf=NULL, chvs=NULL,
    method="spearman", use="pairwise.complete.obs",
    pbd=c(0.7, 0.7), pb=c(0.5, 0.5),
    stepwpert=c(0.1, 0.1), probchv=c(0.1, 0.1),
    minnumbot=c(5, 5),
    ## BART and HBART prior parameters
    ntree=c(50, 10), numcut=100, 
    xifcuts=NULL, xiscuts=NULL,
    power=c(2, 2), base=c(0.95, 0.95),
    ## f function
    fmu=NA, k=5, tau=NA, dist='weibull', 
    ## s function
    total.lambda=NA, total.nu=10, mask=NULL,
    ## survival analysis 
    K=100, events=NULL, TSVS=FALSE,
    ## DPM LIO
    drawDPM=1L, 
    alpha=1, alpha.a=1, alpha.b=0.1, alpha.draw=1,
    neal.m=2, constrain=1, 
    m0=0, k0.a=1.5, k0.b=7.5, k0=1, k0.draw=1,
    a0=3, b0.a=2, b0.b=1, b0=1, b0.draw=1,
    ## misc
    na.rm=FALSE, probs=c(0.025, 0.975), printevery=100,
    transposed=FALSE, pred=FALSE
)

nft(
    ## data
    x.train, times, delta=NULL, x.test=matrix(nrow=0, ncol=0),
    rm.const=TRUE, rm.dupe=TRUE,
    ## multi-threading
    tc=getOption("mc.cores", 1), 
    ##MCMC
    nskip=1000, ndpost=2000, nadapt=1000, adaptevery=100,
    chv=NULL,
    method="spearman", use="pairwise.complete.obs",
    pbd=c(0.7, 0.7), pb=c(0.5, 0.5),
    stepwpert=c(0.1, 0.1), probchv=c(0.1, 0.1),
    minnumbot=c(5, 5),
    ## BART and HBART prior parameters
    ntree=c(50, 10), numcut=100, xicuts=NULL,
    power=c(2, 2), base=c(0.95, 0.95),
    ## f function
    fmu=NA, k=5, tau=NA, dist='weibull', 
    ## s function
    total.lambda=NA, total.nu=10, mask=NULL,
    ## survival analysis 
    K=100, events=NULL, TSVS=FALSE,
    ## DPM LIO
    drawDPM=1L, 
    alpha=1, alpha.a=1, alpha.b=0.1, alpha.draw=1,
    neal.m=2, constrain=1, 
    m0=0, k0.a=1.5, k0.b=7.5, k0=1, k0.draw=1,
    a0=3, b0.a=2, b0.b=1, b0=1, b0.draw=1,
    ## misc
    na.rm=FALSE, probs=c(0.025, 0.975), printevery=100,
    transposed=FALSE, pred=FALSE
)

Arguments

xftrain

n x pf matrix of predictor variables for the f function in the training data.

xstrain

n x ps matrix of predictor variables for the sd function in the training data.

x.train

n x p matrix of predictor variables for the training data.

times

n x 1 vector of the observed times for the training data.

delta

n x 1 vector of the time type for the training data: 0 for right-censoring, 1 for an event, and 2 for left-censoring.

xftest

m x pf matrix of predictor variables for the f function in the test set.

xstest

m x ps matrix of predictor variables for the sd function in the test set.

x.test

m x p matrix of predictor variables for the test set.

rm.const

To remove constant variables or not.

rm.dupe

To remove duplicate variables or not.

tc

Number of OpenMP threads to use.

nskip

Number of MCMC iterations to burn-in and discard.

ndpost

Number of MCMC iterations kept after burn-in.

nadapt

Number of MCMC iterations for adaptation prior to burn-in.

adaptevery

Adapt MCMC proposal distributions every adaptevery iterations.

chvf,chvs,chv

Predictor correlation matrix used as a pre-conditioner for MCMC change-of-variable proposals.

method,use

Correlation options for change-of-variable proposal pre-conditioner.
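
As a rough illustration of the method and use options, the sketch below computes a Spearman correlation matrix from pairwise-complete observations with base R's cor(). The toy matrix x is hypothetical, and it is only an assumption (not stated on this page) that a matrix of this form is what gets built when chvf, chvs or chv is left as NULL.

```r
## Hypothetical toy predictor matrix with one missing value
set.seed(1)
x <- cbind(a = rnorm(20), b = runif(20), c = rbinom(20, 1, 0.5))
x[3, "b"] <- NA

## Spearman rank correlation over pairwise-complete rows; the two
## arguments mirror the defaults method="spearman" and
## use="pairwise.complete.obs" shown in Usage
chv <- cor(x, method = "spearman", use = "pairwise.complete.obs")
print(round(chv, 2))
```

A matrix computed this way could presumably be passed as chv (or chvf/chvs) when a different correlation measure or missing-data rule is wanted.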

pbd

Probability of performing a birth/death proposal, otherwise perform a rotate proposal.

pb

Probability of performing a birth proposal given that we choose to perform a birth/death proposal.

stepwpert

Initial width of proposal distribution for perturbing cut-points.

probchv

Probability of performing a change-of-variable proposal. Otherwise, only do a perturb proposal.

minnumbot

Minimum number of observations required in leaf (terminal) nodes.

ntree

Vector of length two for the number of trees used for the mean model and the number of trees used for the variance model.

numcut

Number of cutpoints to use for each predictor variable.

xifcuts,xiscuts,xicuts

More detailed construction of cut-points can be specified by the xicuts function and provided here.

power

Power parameter in the tree depth penalizing prior.

base

Base parameter in the tree depth penalizing prior.

fmu

Prior parameter for the center of the mean model.

k

Prior parameter for the mean model.

tau

Desired SD/ntree for f function leaf prior if known.

dist

Distribution to be passed to intercept-only AFT model to center y.train.

total.lambda

A rudimentary estimate of the process standard deviation. Used in calibrating the variance prior.

total.nu

Shape parameter for the variance prior.

mask

If a proportion is provided, then said quantile of max.i sd(x.i) is used to mask non-stationary departures (with respect to convergence) above this threshold.

K

Number of grid points for which to estimate survival probability.

events

Grid points for which to estimate survival probability.

TSVS

Setting to TRUE will avoid unnecessary processing for Thompson sampling variable selection, i.e., all that is needed is the variable counts from the tree branch decision rules.

drawDPM

Whether to utilize DPM or not.

alpha

Initial value of DPM concentration parameter.

alpha.a

Gamma prior parameter setting for DPM concentration parameter where E[alpha]=alpha.a/alpha.b.

alpha.b

See alpha.a above.

alpha.draw

Whether to draw alpha or fix it at its initial value.

neal.m

The number of additional atoms for Neal 2000 DPM algorithm 8.

constrain

Whether to perform constrained DPM or unconstrained.

m0

Center of the error distribution: defaults to zero.

k0.a

First Gamma prior argument for k0.

k0.b

Second Gamma prior argument for k0.

k0

Initial value of k0.

k0.draw

Whether to fix k0 or draw it from the DPM LIO prior hierarchy: k0~Gamma(k0.a, k0.b), i.e., E[k0]=k0.a/k0.b.

a0

First Gamma prior argument for tau.

b0.a

First Gamma prior argument for b0.

b0.b

Second Gamma prior argument for b0.

b0

Initial value of b0.

b0.draw

Whether to fix b0 or draw it from the DPM LIO prior hierarchy: b0~Gamma(b0.a, b0.b), i.e., E[b0]=b0.a/b0.b.
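
The Gamma hyperpriors for alpha, k0 and b0 are all described through their means, E[X]=a/b, which corresponds to a shape/rate parameterization. The sketch below simply evaluates those means at the default values from Usage and runs a Monte Carlo check with base R's rgamma(); it does not call nftbart.

```r
## Prior means implied by the defaults in Usage, using E[X] = a/b
## (shape/rate parameterization)
alpha.a <- 1;   alpha.b <- 0.1
k0.a    <- 1.5; k0.b    <- 7.5
b0.a    <- 2;   b0.b    <- 1

prior.mean <- c(alpha = alpha.a / alpha.b,
                k0    = k0.a / k0.b,
                b0    = b0.a / b0.b)
print(prior.mean)

## Monte Carlo check for k0: the sample mean should be close to k0.a/k0.b
set.seed(7)
k0.mc <- mean(rgamma(1e5, shape = k0.a, rate = k0.b))
print(k0.mc)
```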

na.rm

Value to be passed to the predict function.

probs

Value to be passed to the predict function.

printevery

Outputs MCMC algorithm status every printevery iterations.

transposed

Specify TRUE if all of the pre-processing for xftrain/xstrain/xftest/xstest has been conducted prior to the call (including transposing).

pred

Specify TRUE if you want to return the pred item that is used to calculate soffset.
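
The pbd and pb arguments above jointly determine how often each tree-structure proposal is attempted. The sketch below works out the implied probabilities at the default values. Treating the two vector elements as one setting per tree ensemble is an assumption based on the length-two convention described for ntree.

```r
## Defaults from Usage; assumed to hold one value per ensemble
pbd <- c(0.7, 0.7)   # P(birth/death proposal)
pb  <- c(0.5, 0.5)   # P(birth | birth/death proposal)

p.birth  <- pbd * pb         # birth:  0.35 for each element
p.death  <- pbd * (1 - pb)   # death:  0.35 for each element
p.rotate <- 1 - pbd          # rotate: 0.30 for each element

print(rbind(p.birth, p.death, p.rotate))
```

The three rows sum to one in each column, so raising pbd trades rotate proposals for birth/death proposals, while pb only shifts the balance between births and deaths.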

Details

nft2() and nft() fit models to time-to-event data. The most general form of the model allowed is Y(x) = mu + f(x) + sd(x) Z, where Z follows a nonparametric error distribution by default. Each function returns a fit object of S3 class nft2 or nft, respectively, that is essentially a list containing the following items.
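
To make the model form concrete, here is a minimal base R sketch that simulates data of the shape Y(x) = mu + f(x) + sd(x) Z and turns it into the times/delta inputs. Everything in it is hypothetical: the functions f and s, the censoring mechanism, and the use of a standard normal Z (chosen only for simplicity; the fitted model's error distribution is nonparametric by default). Y is taken to be log time, in line with the z.train item.

```r
set.seed(42)
n  <- 200
x  <- matrix(runif(n * 2), n, 2)     # hypothetical predictors
mu <- 1                              # constant
f  <- function(x) 0.5 * x[, 1]       # hypothetical mean function
s  <- function(x) exp(-x[, 2])       # hypothetical sd function (positive)
Z  <- rnorm(n)                       # illustration only

y       <- mu + f(x) + s(x) * Z      # log event time
t.event <- exp(y)
cens    <- rexp(n, rate = 0.1)       # hypothetical censoring times

times <- pmin(t.event, cens)               # observed times
delta <- as.integer(t.event <= cens)       # 1 = event, 0 = right-censored
print(table(delta))
```

Because sd(x) varies with the second predictor here, data like these would call for different or overlapping predictor sets in xftrain and xstrain.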

Value

ots,oid,ovar,oc,otheta

These are XPtrs to the BART f(x) objects in RAM that are only available for fits generated in the current R session.

sts,sid,svar,sc,stheta

Similarly, these are XPtrs to the HBART sd(x) objects.

fmu

The constant mu.

f.train,s.train

The trained f(x) and sd(x) respectively: matrices with ndpost rows and n columns.

f.train.mean,s.train.mean

The posterior mean of the trained f(x) and sd(x) respectively: vectors of length n.

f.trees,s.trees

Character strings representing the trained fits of f(x) and sd(x) respectively to facilitate usage of the predict function when XPtrs are unavailable.

dpalpha

The draws of the DPM concentration parameter alpha.

dpn,dpn.

The number of atom clusters per DPM, J, for all draws including burn-in and excluding burn-in respectively.

dpmu

The draws of the DPM parameter mu[i] where i=1,...,n indexes subjects: a matrix with ndpost rows and n columns.

dpmu.

The draws of the DPM parameter mu[j] where j=1,...,J indexes atom clusters: a matrix with ndpost rows and J columns.

dpwt.

The weights for efficient DPM calculations by atom clusters (as opposed to subjects) for use with dpmu. (and dpsd.; see below): a matrix with ndpost rows and J columns.

dpsd,dpsd.

Similarly, the draws of the DPM parameter tau[i] transformed into the standard deviation sigma[i] for convenience.

dpC

The indices j for each subject i corresponding to their shared atom cluster.

z.train

The data values/augmentation draws of log t.

f.tmind/f.tavgd/f.tmaxd

The min/average/max tier degree of trees in the f ensemble.

s.tmind/s.tavgd/s.tmaxd

The min/average/max tier degree of trees in the s ensemble.

f.varcount,s.varcount

Variable importance counts of branch decision rules for each x of f and s respectively: matrices with ndpost rows and p columns.

f.varcount.mean,s.varcount.mean

Similarly, the posterior mean of the variable importance counts for each x of f and s respectively: vectors of length p.

f.varprob,s.varprob

Similarly, re-weighting the posterior mean of the variable importance counts as sum-to-one probabilities for each x of f and s respectively: vectors of length p.

LPML

The log Pseudo-Marginal Likelihood as typically calculated for right-/left-censoring.

pred

The object returned from the predict function where x.test=x.train in order to calculate the soffset item that is needed to use predict when XPtrs are not available.

soffset

See pred above.

aft

The AFT model fit used to initialize NFT BART.

elapsed

The elapsed time of the run in seconds.
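
The relationship between the varcount, varcount.mean and varprob items described above can be sketched with a small stand-in matrix. The counts are invented for illustration, and the exact normalization used inside the package is an assumption inferred from the descriptions of these items.

```r
## Hypothetical stand-in for post$f.varcount: 4 draws, 3 predictors
varcount <- matrix(c(3, 1, 0,
                     2, 2, 0,
                     4, 0, 0,
                     3, 1, 0), nrow = 4, byrow = TRUE,
                   dimnames = list(NULL, c("x1", "x2", "x3")))

## Posterior mean of the counts: one value per predictor
varcount.mean <- colMeans(varcount)

## Re-weight the means so that they sum to one
varprob <- varcount.mean / sum(varcount.mean)
print(varprob)
```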

Author(s)

Rodney Sparapani: rsparapa@mcw.edu

References

Sparapani R., Logan B., Maiers M., Laud P., McCulloch R. (2023) Nonparametric Failure Time: Time-to-event Machine Learning with Heteroskedastic Bayesian Additive Regression Trees and Low Information Omnibus Dirichlet Process Mixtures. Biometrics (ahead of print) <doi:10.1111/biom.13857>.

See Also

predict.nft2, predict.nft

Examples


##library(nftbart)
data(lung)
N=length(lung$status)

##lung$status: 1=censored, 2=dead
##delta: 0=censored, 1=dead
delta=lung$status-1

## this study reports time in days rather than weeks or months
times=lung$time
times=times/7  ## weeks

## matrix of covariates
x.train=cbind(lung[ , -(1:3)])
## lung$sex:        Male=1 Female=2

## token run just to test installation
post=nft2(x.train, x.train, times, delta, K=0,
         nskip=0, ndpost=10, nadapt=4, adaptevery=1)


set.seed(99)
post=nft2(x.train, x.train, times, delta, K=0)
XPtr=TRUE

x.test = rbind(x.train, x.train)
x.test[ , 2]=rep(1:2, each=N)
K=75
events=seq(0, 150, length.out=K+1)
pred = predict(post, x.test, x.test, K=K, events=events[-1],
               XPtr=XPtr, FPD=TRUE)

plot(events, c(1, pred$surv.fpd.mean[1:K]), type='l', col=4,
     ylim=0:1, 
     xlab=expression(italic(t)), sub='weeks',
     ylab=expression(italic(S)(italic(t), italic(x))))
lines(events, c(1, pred$surv.fpd.upper[1:K]), lty=2, lwd=2, col=4)
lines(events, c(1, pred$surv.fpd.lower[1:K]), lty=2, lwd=2, col=4)
lines(events, c(1, pred$surv.fpd.mean[K+1:K]), lwd=2, col=2)
lines(events, c(1, pred$surv.fpd.upper[K+1:K]), lty=2, lwd=2, col=2)
lines(events, c(1, pred$surv.fpd.lower[K+1:K]), lty=2, lwd=2, col=2)
legend('topright', c('Adv. lung cancer\nmortality example',
                     'M', 'F'), lwd=2, col=c(0, 4, 2), lty=1)



nftbart documentation built on May 1, 2023, 1:08 a.m.
