R/predict.nft2.R

Defines functions predict.nft2

Documented in predict.nft2

## Copyright (C) 2021-2022 Rodney A. Sparapani

## This file is part of nftbart.
## predict.nft2.R

## nftbart is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 2 of the License, or
## (at your option) any later version.

## nftbart is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.

## You should have received a copy of the GNU General Public License
## along with this program.  If not, see <http://www.gnu.org/licenses/>.

## Author contact information
## Rodney A. Sparapani: rsparapa@mcw.edu

predict.nft2 = function(
                       ## data
                       object,
                       xftest=object$xftrain,
                       xstest=object$xstrain,
                       ## multi-threading
                       tc=getOption("mc.cores", 1), ##OpenMP thread count
                       ## current process fit vs. previous process fit
                       XPtr=TRUE,
                       ## predictions
                       K=0,
                       events=object$events,
                       FPD=FALSE,
                       probs=c(0.025, 0.975),
                       take.logs=TRUE,
                       na.rm=FALSE,
                       seed=NULL,
                       ## default settings for NFT:BART/HBART/DPM
                       fmu=object$NFT$fmu,
                       soffset=object$soffset,
                       drawDPM=object$drawDPM,
                       ## etc.
                       ...)
{
    if(is.null(object)) stop("No fitted model specified!\n")
    np = nrow(xftest)
    if(np!=nrow(xstest))
        stop('The number of rows in xftest and xstest must be the same')
    ptm <- proc.time()
    xftest = t(xftest)
    xstest = t(xstest)
    ndpost=object$ndpost
    if(length(fmu)==0 && length(object$fmu)==1) fmu=object$fmu 
    m=object$ntree[1]
    if(length(object$ntree)==2) mh=object$ntree[2]
    else mh=object$ntreeh
    xif=object$xifcuts
    xis=object$xiscuts
    n = nrow(object$xftrain)
    if(FPD && np!=(n*(np%/%n)))
        stop('the number of FPD blocks must be an integer')
    pf = ncol(object$xftrain)
    ps = ncol(object$xstrain)
    events.matrix=FALSE
    
    if(length(K)==0) {
        K=0
        take.logs=FALSE
    } else if(K>0) {
        if(length(events)==0) {
            events = unique(quantile(object$z.train.mean,
                                      probs=(1:K)/(K+1)))
            attr(events, 'names') = NULL
            take.logs=FALSE
            K = length(events)
        } else if(length(events)!=K) {
            stop("K and the length of events don't match")
        }
    } else if(K==0 && length(events)>0) {
        events.matrix=(class(events)[1]=='matrix')
        if(events.matrix) {
            if(FPD)
                stop("Friedman's partial dependence function: can't be used with a matrix of events")
            K=ncol(events)
        } else K = length(events)
    }
    if(K>0 && take.logs) events=log(events)
    
    if(length(drawDPM)==0) drawDPM=0

    draw.logt=(length(seed)>0)
    nd=length(object$s.train.mask)
    mask=(nd>0)
    if(!mask) nd=ndpost
    
    q.lower=min(probs)
    q.upper=max(probs)

    if(XPtr) {
        res=.Call("cpsambrt_predict2",
                  xftest, xstest,
                  m,
                  mh,
                  ndpost,
                  xif, xis,
                  tc,
                  object,
                  PACKAGE="nftbart"
                  )
        if(mask) {
            res$f.test.=res$f.test.[object$s.train.mask, ]
            res$s.test.=res$s.test.[object$s.train.mask, ]
        }
        res$f.test.=res$f.test.+fmu
        res$f.test.mean.=apply(res$f.test.,2,mean)
        res$f.test.lower.=
            apply(res$f.test.,2,quantile,probs=q.lower,na.rm=na.rm)
        res$f.test.upper.=
            apply(res$f.test.,2,quantile,probs=q.upper,na.rm=na.rm)
        res$s.test.mean.=apply(res$s.test.,2,mean)
        res$s.test.lower.=
            apply(res$s.test.,2,quantile,probs=q.lower,na.rm=na.rm)
        res$s.test.upper.=
            apply(res$s.test.,2,quantile,probs=q.upper,na.rm=na.rm)
    } else {
        res=list()
    }

    if(np>0) {
        res.=.Call("cprnft2",
                   object,
                   xftest, xstest,
                   xif, xis,
                   tc,
                   PACKAGE="nftbart"
                   )
        if(mask) {
            res.$f.test=res.$f.test[object$s.train.mask, ]
            res.$s.test=res.$s.test[object$s.train.mask, ]
        }
        res$f.test=res.$f.test+fmu
        res$fmu=fmu
        
        m=length(soffset)
        if(m==0) soffset=0
        else if(m>1) 
            soffset=sqrt(mean(soffset^2, na.rm=TRUE)) ## Inf set to NA
        
        res$s.test=exp(res.$s.test-soffset)
        res$s.test.mean =apply(res$s.test, 2, mean)
        res$s.test.lower=apply(res$s.test, 2,
                               quantile, probs=q.lower, na.rm=na.rm)
        res$s.test.upper=apply(res$s.test, 2,
                               quantile, probs=q.upper, na.rm=na.rm)
        res$soffset=soffset

        res$f.test.mean =apply(res$f.test, 2, mean)
        res$f.test.lower=apply(res$f.test, 2,
                               quantile,probs=q.lower,na.rm=na.rm)
        res$f.test.upper=apply(res$f.test, 2,
                               quantile,probs=q.upper,na.rm=na.rm)
        
        if(K>0) {
            ## if(length(events)==0) {
            ##     events <- unique(quantile(object$z.train.mean,
            ##                               probs=(1:K)/(K+1)))
            ##     attr(events, 'names') <- NULL
            ## } else if(take.logs) events=log(events)
            ## events.matrix=(class(events)[1]=='matrix')
            ## if(events.matrix) K=ncol(events)
            ## else K <- length(events)

            if(FPD) {
                H=np%/%n
                if(drawDPM>0) {
                    for(h in 1:H) {
                        if(h==1) {
                            mu. = object$dpmu
                            sd. = object$dpsd
                        } else {
                            mu. = cbind(mu., object$dpmu)
                            sd. = cbind(sd., object$dpsd)
                        }
                    }
                    mu. = mu.*res$s.test
                    sd. = sd.*res$s.test
                    mu. = mu.+res$f.test
                } else {
                    mu. = res$f.test
                    sd. = res$s.test
                }

                ## these FPD predict objects are often too large
                ## to be re-loaded in typical interactive sessions
                ## therefore, we drop the intermediate results
                ##if(K>1) {
                    res$f.test = NULL
                    if(length(res$f.test.)>0) res$f.test. = NULL
                    res$s.test = NULL
                    if(length(res$s.test.)>0) res$s.test. = NULL
                ##}
                
                surv.fpd=list()
                surv.test=list()
                pdf.fpd =list()
                pdf.test =list()
                for(i in 1:H) {
                    h=(i-1)*n+1:n
                    for(j in 1:K) {
                        if(j==1) {
                            surv.fpd[[i]]=list()
                            surv.test[[i]]=list()
                            pdf.fpd[[i]] =list()
                            pdf.test[[i]] =list()
                        }
                        z=events[j]
                        t=exp(z)
                        surv.fpd[[i]][[j]]=
                            apply(matrix(pnorm(z, mu.[ , h], sd.[ , h], lower.tail=FALSE),
                                         nrow=nd, ncol=n), 1, mean)
                        pdf.fpd[[i]][[j]]=
                            apply(matrix(dnorm(z, mu.[ , h], sd.[ , h])/(t*sd.[ , h]),
                                         nrow=nd, ncol=n), 1, mean)
                        if(i==1 && j==1) {
                            res$surv.fpd=cbind(surv.fpd[[1]][[1]])
                            res$pdf.fpd =cbind(pdf.fpd[[1]][[1]])
                        } else {
                            res$surv.fpd=cbind(res$surv.fpd, surv.fpd[[i]][[j]])
                            res$pdf.fpd =cbind(res$pdf.fpd, pdf.fpd[[i]][[j]])
                        }
                        ## if(K==1) {
                        ##     surv.test[[i]][[j]]=matrix(pnorm(z, mu.[ , h], sd.[ , h], lower.tail=FALSE),
                        ##                                nrow=nd, ncol=n)
                        ##     pdf.test[[i]][[j]] =matrix(dnorm(z, mu.[ , h], sd.[ , h])/(t*sd.[ , h]),
                        ##                                nrow=nd, ncol=n)
                        ##     if(i==1 && j==1) {
                        ##         res$surv.test=cbind(surv.test[[1]][[1]])
                        ##         res$pdf.test =cbind(pdf.test[[1]][[1]])
                        ##     } else {
                        ##         res$surv.test=cbind(res$surv.test, surv.test[[i]][[j]])
                        ##         res$pdf.test =cbind(res$pdf.test, pdf.test[[i]][[j]])
                        ##     }
                        ## }
                    }
                }
                
                res$surv.fpd.mean=apply(cbind(res$surv.fpd), 2, mean)
                res$surv.fpd.lower=
                    apply(cbind(res$surv.fpd), 2, quantile, probs=q.lower)
                res$surv.fpd.upper=
                    apply(cbind(res$surv.fpd), 2, quantile, probs=q.upper)
                res$haz.fpd=cbind(res$pdf.fpd/res$surv.fpd)
                    res$haz.fpd.mean =apply(cbind(res$haz.fpd), 2, mean)
                    res$haz.fpd.lower=
                        apply(cbind(res$haz.fpd), 2, quantile, probs=q.lower)
                    res$haz.fpd.upper=
                        apply(cbind(res$haz.fpd), 2, quantile, probs=q.upper)
                res$pdf.fpd.mean =apply(cbind(res$pdf.fpd), 2, mean)
                ## if(K==1) {
                ##     res$surv.test.mean=apply(res$surv.test, 2, mean)
                ##     res$haz.test=res$pdf.test/res$surv.test
                ##     res$haz.test.mean =apply(res$haz.test,  2, mean)
                ##     res$pdf.test.mean =apply(res$pdf.test,  2, mean)
                ## }
            } else if(drawDPM>0) {
                res$surv.test=matrix(0, nrow=nd, ncol=np*K)
                res$pdf.test =matrix(0, nrow=nd, ncol=np*K)
                H=ncol(object$dpwt.)
                ##H=max(c(object$dpn.))
                events.=events
                
                for(i in 1:np) {
                    if(events.matrix) events.=events[i, ]
                    mu.=res$f.test[ , i]
                    sd.=res$s.test[ , i]
                    for(j in 1:K) {
                        k=(i-1)*K+j
                        z=events.[j]
                        t=exp(z)
                        for(h in 1:H) {
                            res$surv.test[ , k]=res$surv.test[ , k]+
                                object$dpwt.[ , h]*
                                pnorm(z, mu.+sd.*object$dpmu.[ , h],
                                      sd.*object$dpsd.[ , h], FALSE)
                            res$pdf.test[ , k]=res$pdf.test[ , k]+
                                object$dpwt.[ , h]*
                                    dnorm(z, mu.+sd.*object$dpmu.[ , h],
                                      sd.*object$dpsd.[ , h])/
                                    (t*sd.*object$dpsd.[ , h])
                        }
                    }
                }
                res$surv.test.mean=apply(res$surv.test, 2, mean)
                res$surv.test.lower=apply(res$surv.test, 2, quantile,
                                          probs=q.lower)
                res$surv.test.upper=apply(res$surv.test, 2, quantile,
                                          probs=q.upper)
                res$pdf.test.mean=apply(res$pdf.test, 2, mean)
                res$pdf.test.lower=apply(res$pdf.test, 2, quantile,
                                          probs=q.lower)
                res$pdf.test.upper=apply(res$pdf.test, 2, quantile,
                                          probs=q.upper)
                res$haz.test=res$pdf.test/res$surv.test
                res$haz.test.mean =apply(res$haz.test, 2, mean)
                res$haz.test.lower=apply(res$haz.test, 2, quantile,
                                          probs=q.lower)
                res$haz.test.upper=apply(res$haz.test, 2, quantile,
                                          probs=q.upper)
            }

            if(take.logs) res$events=exp(events)
            else res$events=events
        }
        res$K=K
        
        if(draw.logt) {
            set.seed(seed)

            if(drawDPM>0) {
                res$logt.test=matrix(0, nrow=nd, ncol=np)
                H=ncol(object$dpwt.)
                for(i in 1:np) {
                    mu. = res$f.test[ , i]
                    sd. = res$s.test[ , i]
                    for(h in 1:H) {
                        res$logt.test[ , i]=res$logt.test[ , i]+
                            object$dpwt.[ , h]*
                            rnorm(nd, mu.+sd.*object$dpmu.[ , h],
                                  sd.*object$dpsd.[ , h])
                    }
                }
            } else {
                res$logt.test=matrix(nrow=nd, ncol=np)
                for(i in 1:np) {
                    mu. = res$f.test[ , i]
                    sd. = res$s.test[ , i]
                    res$logt.test[ , i]=rnorm(nd, mu., sd.)
                }
            }
            res$logt.test.mean=apply(res$logt.test, 2, mean)
        }
    } else {
        res$f.test=res$f.test.
        res$s.test=res$s.test.
    }

    res$probs=probs
    res$elapsed <- (proc.time()-ptm)['elapsed']
    attr(res$elapsed, 'names')=NULL

    return(res)
}

Try the nftbart package in your browser

Any scripts or data that you put into this service are public.

nftbart documentation built on May 1, 2023, 1:08 a.m.