Nothing
## Copyright (C) 2024 Rodney A. Sparapani
## This file is part of nftbart.
## predict.nft2mi.R
## nftbart is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 2 of the License, or
## (at your option) any later version.
## nftbart is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <http://www.gnu.org/licenses/>.
## Author contact information
## Rodney A. Sparapani: rsparapa@mcw.edu
## predict.nft2mi: predict from a multiply-imputed NFT BART fit.
##
## `object` holds one fit per imputation: xftrain/xstrain and the
## components listed in `ptr` below are lists indexed by imputation, and
## object$mult.impute gives their number.  Each imputation is rebuilt as
## an 'nft2' object and passed to predict() (S3 dispatch on class 'nft2')
## with this function's arguments.  The posterior draws are then stacked
## row-wise across imputations and summarized column-wise.
##
## xftest, xstest: one matrix used for every imputation, or a list with
##   one matrix per imputation (the first is used for dimension checks).
## K, events, take.logs, RMST.max: normalized once here, then forwarded.
##   When events are log-transformed here, take.logs=FALSE is forwarded so
##   the per-imputation predict() does not transform them a second time.
## XPtr: ignored; XPtr=FALSE is always forwarded.
## soffset: a scalar is recycled to one value per imputation.
## All other arguments (including ...) are forwarded unchanged.
##
## Returns the first imputation's result with:
##   - draws stacked across imputations (surv.fpd/pdf.fpd/haz.fpd when FPD,
##     otherwise f.test, s.test, surv.test, pdf.test, haz.test);
##   - <name>.mean/.lower/.upper column summaries (lower/upper use
##     min(probs)/max(probs)) for each of those draws that is present;
##   - soffset and elapsed as per-imputation vectors, plus elapsed.sum.
predict.nft2mi = function(
  ## data
  object,
  xftest=object$xftrain,
  xstest=object$xstrain,
  ## multi-threading
  tc=getOption("mc.cores", 1), ##OpenMP thread count
  ## current process fit vs. previous process fit
  XPtr=FALSE, ## external pointers not working here
  ## predictions
  K=0,
  events=object$events,
  FPD=FALSE,
  probs=c(0.025, 0.975),
  take.logs=TRUE,
  na.rm=FALSE,
  RMST.max=NULL,
  ##seed=NULL,
  ## default settings for NFT:BART/HBART/DPM
  fmu=object$NFT$fmu,
  soffset=object$soffset,
  drawDPM=object$drawDPM,
  ## etc.
  ...)
{
  if (is.null(object)) stop("No fitted model specified!\n")

  ## Test covariates: one matrix for all imputations, or a per-imputation list.
  xftest.list <- NULL
  xstest.list <- NULL
  if (is.list(xftest)) {
    xftest.list <- xftest
    xftest <- xftest.list[[1]]
  }
  if (is.list(xstest)) {
    xstest.list <- xstest
    xstest <- xstest.list[[1]]
  }

  n <- nrow(object$xftrain[[1]])
  np <- nrow(xftest)
  if (np != nrow(xstest))
    stop('The number of rows in xftest and xstest must be the same!')
  pf <- ncol(object$xftrain[[1]])
  if (pf != ncol(xftest))
    stop('The number of columns in xftrain and xftest must be the same!')
  ps <- ncol(object$xstrain[[1]])
  if (ps != ncol(xstest))
    stop('The number of columns in xstrain and xstest must be the same!')
  ## FPD requires the test rows to form whole blocks of n rows
  if (FPD && np %% n != 0)
    stop('The number of FPD blocks must be an integer')

  ## Normalize K/events once; the same values go to every imputation.
  events.matrix <- FALSE
  if (length(RMST.max) > 0) {
    K <- 0
  } else if (length(K) == 0) {
    K <- 0
    take.logs <- FALSE
  } else if (K > 0) {
    if (length(events) == 0) {
      ## default grid: K quantiles of object$times (not log-transformed)
      events <- unique(quantile(object$times, probs = (1:K)/(K+1)))
      attr(events, 'names') <- NULL
      take.logs <- FALSE
      K <- length(events)
    } else if (length(events) != K) {
      stop("K and the length of events don't match")
    }
  } else if (K == 0 && length(events) > 0) {
    events.matrix <- is.matrix(events)
    if (events.matrix) {
      if (FPD)
        stop("Friedman's partial dependence function: can't be used with a matrix of events")
      ## NOTE(review): K = ncol(events) is forwarded together with the
      ## matrix; if the 'nft2' method requires length(events) == K when
      ## K > 0 this path will stop there -- confirm against predict.nft2.
      K <- ncol(events)
    } else {
      K <- length(events)
    }
  }
  if (K > 0 && take.logs) {
    events <- log(events)
    ## events are now transformed; do not ask the callee to log them again
    take.logs <- FALSE
  }

  object. <- object
  mult.impute <- object$mult.impute
  if (length(soffset) == 1) soffset <- rep(soffset, mult.impute)
  attr(object, 'class') <- 'nft2'

  ## components of the fit stored as one element per imputation
  ptr <- c('ots', 'oid', 'ovar', 'oc', 'otheta',
           'sts', 'sid', 'svar', 'sc', 'stheta',
           'f.trees', 's.trees', 's.train.mask',
           'dpmu', 'dpsd', 'dpmu.', 'dpsd.', 'dpwt.')

  res.list <- vector('list', mult.impute)
  for (i in seq_len(mult.impute)) {
    object$xftrain <- object.$xftrain[[i]]
    object$xstrain <- object.$xstrain[[i]]
    ## exact-name extraction: `$` would partially match e.g. 'dpmu' -> 'dpmu.'
    for (var in ptr) object[[var]] <- object.[[var]][[i]]
    if (is.list(xftest.list)) xftest <- xftest.list[[i]]
    if (is.list(xstest.list)) xstest <- xstest.list[[i]]
    res.list[i] <- list(predict(
      ## data
      object,
      xftest=xftest,
      xstest=xstest,
      ## multi-threading
      tc=tc, ##OpenMP thread count
      ## current process fit vs. previous process fit
      XPtr = FALSE, ## external pointers not working here
      ## predictions
      K=K,
      events=events,
      FPD=FPD,
      probs=probs,
      take.logs=take.logs,
      na.rm=na.rm,
      RMST.max=RMST.max,
      ##seed=NULL,
      ## default settings for NFT:BART/HBART/DPM
      fmu=fmu,
      soffset=soffset[i],
      drawDPM=drawDPM,
      ## etc.
      ...))
  }

  ## names of the posterior-draw matrices to stack and summarize
  draws <- if (FPD) {
    c('surv.fpd', 'pdf.fpd', 'haz.fpd')
  } else {
    c('f.test', 's.test', 'surv.test', 'pdf.test', 'haz.test')
  }

  res <- res.list[[1]]
  for (nm in draws)
    res[[nm]] <- do.call(rbind, lapply(res.list, `[[`, nm))
  ## seq_len(.)[-1] is empty when there is a single imputation
  for (i in seq_len(mult.impute)[-1]) {
    res$soffset[i] <- res.list[[i]]$soffset
    res$elapsed[i] <- res.list[[i]]$elapsed
  }

  lower <- min(probs)
  upper <- max(probs)
  for (nm in draws) {
    draw <- res[[nm]]
    ## skip draws the per-imputation fits did not return
    if (is.null(draw)) next
    res[[paste0(nm, '.mean')]] <- apply(draw, 2, mean)
    res[[paste0(nm, '.lower')]] <- apply(draw, 2, quantile, probs = lower)
    res[[paste0(nm, '.upper')]] <- apply(draw, 2, quantile, probs = upper)
  }
  res$elapsed.sum <- sum(res$elapsed)
  res
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.