R/mi.nft2.R

Defines functions mi.nft2

## Copyright (C) 2024 Rodney A. Sparapani

## This file is part of nftbart.
## mi.nft2.R

## nftbart is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 2 of the License, or
## (at your option) any later version.

## nftbart is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.

## You should have received a copy of the GNU General Public License
## along with this program.  If not, see <http://www.gnu.org/licenses/>.

## Author contact information
## Rodney A. Sparapani: rsparapa@mcw.edu

## Multiply-imputed NFT BART: fit nft2() 'mult.impute' times and pool the
## fits into a single object of class 'nft2mi'.
##
## Each fit is given ceiling(ndpost/mult.impute) posterior draws, so the
## pooled total may slightly exceed 'ndpost'.  Every other argument is
## passed through to nft2() unchanged.  Missing covariates are presumably
## cold-decked inside nft2() (TODO confirm), so each fit may see different
## covariate values.
##
## Value: the first fit, modified as follows:
##  * every component named in 'ptr' becomes a list with one element per
##    imputation; xftrain/xstrain/xftest/xstest get one extra element
##    (index mult.impute+1) holding the matrices exactly as supplied;
##  * draw matrices (f.train, s.train, f.varcount, s.varcount, LIO$C,
##    LIO$states and, when xftest has rows, f.test/s.test) are stacked
##    by row across imputations;
##  * per-fit values (soffset, ssd, elapsed, tree-depth summaries, LIO
##    hyperparameters, s.train.max. when present) become vectors indexed
##    by imputation, and LIO$phi becomes a list;
##  * posterior summaries are recomputed from the pooled draws, and
##    ndpost (total pooled draws), mult.impute and elapsed.sum are set.
mi.nft2 <- function(## data
               xftrain, xstrain, times, delta=NULL, 
               xftest=matrix(nrow=0, ncol=0),
               xstest=matrix(nrow=0, ncol=0),
               mult.impute=4L,
               rm.const=TRUE, rm.dupe=TRUE,
               ##edraws2=matrix(nrow=0, ncol=0),
               ##zdraws2=matrix(nrow=0, ncol=0),
               ##impute.bin=NULL, impute.prob=NULL,
               ## multi-threading
               tc=getOption("mc.cores", 1), ##OpenMP thread count
               ##MCMC
               nskip=1000, ndpost=2000, 
               nadapt=1000, adaptevery=100, 
               chvf = NULL, chvs = NULL,
               method="spearman", use="pairwise.complete.obs",
               pbd=c(0.7, 0.7), pb=c(0.5, 0.5),
               stepwpert=c(0.1, 0.1), probchv=c(0.1, 0.1),
               minnumbot=c(5, 5),
               ## BART and HBART prior parameters
               ntree=c(50, 10), numcut=100,
               xifcuts=NULL, xiscuts=NULL,
               power=c(2, 2), base=c(0.95, 0.95),
               ## f function
               fmu=NA, k=5, tau=NA, dist='weibull', 
               ## s function
               total.lambda=NA, total.nu=10, mask=NULL,
               ## survival analysis 
               K=100, events=NULL, TSVS=FALSE,
               ## DPM LIO
               drawDPM=1L, 
               alpha=1, alpha.a=1, alpha.b=0.1, alpha.draw=1,
               neal.m=2, constrain=1, 
               m0=0, k0.a=1.5, k0.b=7.5, k0=1, k0.draw=1,
               ##a0=1.5, b0.a=0.5,
               a0=3, b0.a=2, b0.b=1, b0=1, b0.draw=1,
               ## misc
               na.rm=FALSE, probs=c(0.025, 0.975), printevery=100,
               transposed=FALSE, pred=FALSE
               )
{
    ## the imputation count is used as a loop bound and as a list index,
    ## so anything other than a whole number >= 1 is meaningless
    if(!is.numeric(mult.impute) || length(mult.impute) != 1L ||
       is.na(mult.impute) || mult.impute < 1 ||
       mult.impute != round(mult.impute))
        stop("'mult.impute' must be a single whole number >= 1", call.=FALSE)

    ## posterior draws per imputation, rounded up
    mi.ndpost <- ceiling(ndpost/mult.impute)
    ## imputations after the first; empty when mult.impute == 1
    ## (2:mult.impute would give c(2, 1) and index past the fitted models)
    later <- seq_len(mult.impute)[-1L]
    ## keep the caller's design matrices as supplied, before cold-decking
    orig.x <- list(xftrain=xftrain, xstrain=xstrain,
                   xftest=xftest, xstest=xstest)
    ## NROW() so a NULL xftest means "no test rows" rather than an error
    np <- NROW(xftest)

    fits <- vector("list", mult.impute)
    for(i in seq_len(mult.impute)) fits[[i]] <- nft2(
               xftrain, xstrain, times, delta=delta, xftest=xftest, xstest=xstest,
               rm.const=rm.const, rm.dupe=rm.dupe,
               ## multi-threading
               tc=tc,
               ##MCMC
               nskip=nskip, ndpost=mi.ndpost, nadapt=nadapt, adaptevery=adaptevery, 
               chvf = chvf, chvs = chvs, method=method, use=use,
               pbd=pbd, pb=pb, stepwpert=stepwpert, probchv=probchv, minnumbot=minnumbot,
               ## BART and HBART prior parameters
               ntree=ntree, numcut=numcut, xifcuts=xifcuts, xiscuts=xiscuts, power=power, base=base,
               ## f function
               fmu=fmu, k=k, tau=tau, dist=dist,
               ## s function
               total.lambda=total.lambda, total.nu=total.nu, mask=mask,
               ## survival analysis 
               K=K, events=events, TSVS=TSVS,
               ## DPM LIO
               drawDPM=drawDPM, 
               alpha=alpha, alpha.a=alpha.a, alpha.b=alpha.b, alpha.draw=alpha.draw,
               neal.m=neal.m, constrain=constrain, 
               m0=m0, k0.a=k0.a, k0.b=k0.b, k0=k0, k0.draw=k0.draw,
               a0=a0, b0.a=b0.a, b0.b=b0.b, b0=b0, b0.draw=b0.draw,
               ## misc
               na.rm=na.rm, probs=probs, printevery=printevery,
               transposed=transposed, pred=pred
               )

    ## the first fit is the template for the pooled object
    res. <- fits[[1L]]

    ## components that cannot be pooled across imputations (tree strings,
    ## design matrices, DPM states, ...) are kept per imputation as lists
    xvars <- c('xftrain', 'xstrain', 'xftest', 'xstest')
    ptr <- c('ots', 'oid', 'ovar', 'oc', 'otheta', 
             'sts', 'sid', 'svar', 'sc', 'stheta',
             xvars,
             'f.trees', 's.trees', 's.train.mask', 
             'dpalpha', 'dpn', 'dpn.', 'dpC', 
             'dpmu', 'dpsd', 'dpmu.', 'dpsd.', 'dpwt.') 
    for(var in ptr) {
        res.[[var]] <- list(res.[[var]])
        for(i in later) res.[[var]][[i]] <- fits[[i]][[var]]
        ## extra trailing element: the matrices before cold-decking
        if(var %in% xvars) res.[[var]][[mult.impute+1L]] <- orig.x[[var]]
    }

    res.$mult.impute <- mult.impute
    res.$LIO$phi <- list(res.$LIO$phi)

    for(i in later) {
        fit <- fits[[i]]
        ## per-fit values: one entry per imputation
        res.$soffset[i] <- fit$soffset
        if(length(res.$s.train.max.)>0) res.$s.train.max.[i] <- fit$s.train.max.
        res.$ssd[i] <- fit$ssd
        res.$elapsed[i] <- fit$elapsed
        res.$f.tavgd[i] <- fit$f.tavgd
        res.$f.tmaxd[i] <- fit$f.tmaxd
        res.$f.tmind[i] <- fit$f.tmind
        res.$s.tavgd[i] <- fit$s.tavgd
        res.$s.tmaxd[i] <- fit$s.tmaxd
        res.$s.tmind[i] <- fit$s.tmind
        res.$LIO$hyper$alpha[i] <- fit$LIO$hyper$alpha
        res.$LIO$hyper$k0[i] <- fit$LIO$hyper$k0
        res.$LIO$hyper$b0[i] <- fit$LIO$hyper$b0
        res.$LIO$phi[[i]] <- fit$LIO$phi
        ## posterior draws: stacked by row into one pooled sample
        res.$f.train <- rbind(res.$f.train, fit$f.train)
        res.$s.train <- rbind(res.$s.train, fit$s.train)
        res.$f.varcount <- rbind(res.$f.varcount, fit$f.varcount)
        res.$s.varcount <- rbind(res.$s.varcount, fit$s.varcount)
        res.$LIO$C <- rbind(res.$LIO$C, fit$LIO$C)
        res.$LIO$states <- rbind(res.$LIO$states, fit$LIO$states)
        if(np>0) {
            res.$f.test <- rbind(res.$f.test, fit$f.test)
            res.$s.test <- rbind(res.$s.test, fit$s.test)
        }
    }

    res.$elapsed.sum <- sum(res.$elapsed)
    ## total number of pooled draws (mult.impute*mi.ndpost)
    res.$ndpost <- nrow(res.$f.train)

    ## posterior summaries recomputed over the pooled draws
    res.$f.train.mean <- apply(res.$f.train, 2, mean)
    res.$f.train.min  <- apply(res.$f.train, 2, min)
    res.$f.train.max  <- apply(res.$f.train, 2, max)
    res.$f.varcount.mean <- apply(res.$f.varcount, 2, mean)
    res.$f.varprob <- res.$f.varcount.mean/sum(res.$f.varcount.mean)
    res.$s.train.mean <- apply(res.$s.train, 2, mean)
    res.$s.train.min  <- apply(res.$s.train, 2, min)
    res.$s.train.max  <- apply(res.$s.train, 2, max)
    res.$s.varcount.mean <- apply(res.$s.varcount, 2, mean)
    res.$s.varprob <- res.$s.varcount.mean/sum(res.$s.varcount.mean)

    if(np>0) {
        res.$f.test.mean <- apply(res.$f.test, 2, mean)
        res.$s.test.mean <- apply(res.$s.test, 2, mean)
    }

    class(res.) <- 'nft2mi'
    res.
}

Try the nftbart package in your browser

Any scripts or data that you put into this service are public.

nftbart documentation built on Dec. 3, 2025, 9:06 a.m.