# R/wbart.R

#  borrowr: estimate population average treatment effects with borrowing between data sources.
#  Copyright (C) 2019  Jeffrey A. Verdoliva Boatman
#  This is a modified version of the code from the BART R package.
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.

#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.

#  You should have received a copy of the GNU General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.

# Fit a weighted BART model for a continuous response by calling the compiled
# routine "cwbart" (modified from the BART package's wbart).
#
# Arguments (only what this function itself does with them is described;
# anything merely forwarded to the compiled routine is marked as such):
#   x.train     Training predictors. Unless `transposed` is TRUE they are run
#               through bartModelMatrix() and then transposed, so that columns
#               correspond to observations.
#   y.train     Response vector; centred by `fmean` before fitting.
#   x.test      Optional test predictors, pre-processed like x.train and
#               restricted to the columns kept for x.train.
#   sparse, theta, omega, a, b, augment
#               Forwarded unchanged to the compiled routine.
#   rho         Forwarded to the compiled routine; defaults to the number of
#               predictors when empty/NULL.
#   xinfo, usequants, cont, rm.const
#               Forwarded to bartModelMatrix() (ignored when transposed = TRUE,
#               except xinfo which is also passed to the compiled routine).
#   sigest, sigdf, sigquant, lambda
#               If lambda is NA it is computed from sigest (itself estimated by
#               a linear model, or sd(y.train) when p >= n), sigdf and sigquant;
#               otherwise sigest is set to sqrt(lambda).
#   k           Not used by the current body (only referenced in retired code).
#   power, base Forwarded unchanged to the compiled routine.
#   gamma       Forwarded to the compiled routine; defaults to 1 / ntree.
#   fmean       Value subtracted from y.train and added back to fitted values.
#   w           Per-observation weights, forwarded to the compiled routine.
#   ntree, numcut, nskip, printevery
#               Forwarded to the compiled routine (numcut may be replaced by
#               the value returned from bartModelMatrix()).
#   ndpost, keepevery
#               ndpost * keepevery is passed as the draw count.
#   nkeeptrain, nkeeptest, nkeeptestmean, nkeeptreedraws
#               Each is reset to ndpost when non-zero and not a divisor of
#               ndpost.
#   transposed  TRUE when x.train / x.test are already pre-processed and
#               transposed by the caller.
#
# Returns the list produced by the compiled routine, augmented with
# proc.time, mu, varcount.mean, varprob.mean and rm.const, with fitted values
# shifted back by fmean, and with class "wbart".
wbart <- function (x.train, y.train, x.test = matrix(0, 0, 0), sparse = FALSE,
    theta = 0, omega = 1, a = 0.5, b = 1, augment = FALSE, rho = NULL,
    xinfo = matrix(0, 0, 0), usequants = FALSE, cont = FALSE,
    rm.const = TRUE, sigest = NA, sigdf = 3, sigquant = 0.9,
    k = 2, power = 2, base = 0.95, lambda = NA, gamma = NA,
    fmean = mean(y.train), w = rep(1, length(y.train)), ntree = 200L,
    numcut = 100L, ndpost = 1000L, nskip = 100L, keepevery = 1L,
    nkeeptrain = ndpost, nkeeptest = ndpost, nkeeptestmean = ndpost,
    nkeeptreedraws = ndpost, printevery = 100L, transposed = FALSE)
{
    n <- length(y.train)
    if (!transposed) {
        temp <- bartModelMatrix(x.train, numcut, usequants = usequants,
            cont = cont, xinfo = xinfo, rm.const = rm.const)
        x.train <- t(temp$X)
        numcut <- temp$numcut
        xinfo <- temp$xinfo
        if (length(x.test) > 0) {
            x.test <- bartModelMatrix(x.test)
            # drop = FALSE keeps the matrix shape when only one row (or one
            # column) is selected; without it a single-row x.test collapses to
            # a vector and t() returns 1 x p instead of p x 1.
            x.test <- t(x.test[, temp$rm.const, drop = FALSE])
        }
        rm.const <- temp$rm.const
        grp <- temp$grp
        rm(temp)
    }
    else {
        rm.const <- NULL
        grp <- NULL
    }
    # From here on x.train is predictors-by-observations.
    if (n != ncol(x.train))
        stop("The length of y.train and the number of rows in x.train must be identical")
    p <- nrow(x.train)
    np <- ncol(x.test)
    if (length(rho) == 0)
        rho <- p
    if (length(rm.const) == 0)
        rm.const <- seq_len(p)
    if (length(grp) == 0)
        grp <- seq_len(p)
    y.train <- y.train - fmean
    # Each nkeep* count must divide ndpost; otherwise fall back to ndpost.
    if ((nkeeptrain != 0) && ((ndpost %% nkeeptrain) != 0)) {
        nkeeptrain <- ndpost
    }
    if ((nkeeptest != 0) && ((ndpost %% nkeeptest) != 0)) {
        nkeeptest <- ndpost
    }
    if ((nkeeptestmean != 0) && ((ndpost %% nkeeptestmean) != 0)) {
        nkeeptestmean <- ndpost
    }
    if ((nkeeptreedraws != 0) && ((ndpost %% nkeeptreedraws) != 0)) {
        nkeeptreedraws <- ndpost
    }
    nu <- sigdf
    if (is.na(lambda)) {
        if (is.na(sigest)) {
            if (p < n) {
                # Residual standard error of a linear fit on all predictors.
                df <- data.frame(t(x.train), y.train)
                lmf <- lm(y.train ~ ., df)
                sigest <- summary(lmf)$sigma
            }
            else {
                sigest <- sd(y.train)
            }
        }
        qchi <- qchisq(1 - sigquant, nu)
        lambda <- (sigest * sigest * qchi) / nu
    }
    else {
        sigest <- sqrt(lambda)
    }
    # Retired code computed `tau` from `sigmaf`/`k` at this point:
    #     tau = (max(y.train) - min(y.train))/(2 * k * sqrt(ntree))
    #     tau = sigmaf/sqrt(ntree)
    # This modified version passes `gamma` to the compiled routine instead.
    if (is.na(gamma)) gamma <- 1 / ntree
    ptm <- proc.time()
    res <- .Call("cwbart", n, p, np, x.train, y.train, x.test,
        ntree, numcut, ndpost * keepevery, nskip, power, base,
        nu, lambda, gamma, sigest, w, sparse, theta, omega, grp,
        a, b, rho, augment, nkeeptrain, nkeeptest, nkeeptestmean,
        nkeeptreedraws, printevery, xinfo)
    res$proc.time <- proc.time() - ptm
    # Undo the centring applied to y.train above.
    res$mu <- fmean
    res$yhat.train.mean <- res$yhat.train.mean + fmean
    res$yhat.train <- res$yhat.train + fmean
    res$yhat.test.mean <- res$yhat.test.mean + fmean
    res$yhat.test <- res$yhat.test + fmean
    # Label per-predictor outputs with the predictor names (rows of x.train).
    if (nkeeptreedraws > 0)
        names(res$treedraws$cutpoints) <- dimnames(x.train)[[1]]
    dimnames(res$varcount)[[2]] <- as.list(dimnames(x.train)[[1]])
    dimnames(res$varprob)[[2]] <- as.list(dimnames(x.train)[[1]])
    res$varcount.mean <- apply(res$varcount, 2, mean)
    res$varprob.mean <- apply(res$varprob, 2, mean)
    res$rm.const <- rm.const
    class(res) <- "wbart"
    res
}

# Try the borrowr package in your browser

# Any scripts or data that you put into this service are public.

# borrowr documentation built on Dec. 8, 2020, 5:08 p.m.