R/BARP_barp.R

Defines functions barps

Documented in barps

#' Bayesian Additive Regression Trees with Post-Stratification (BARP)
#'
#' This function uses Bayesian Additive Regression Trees (BART) to extrapolate survey data to a level of geographic aggregation at which the original survey was not designed to be representative.
#' This is a modified version of the \code{barp} function from the \pkg{BARP} package that additionally allows the random seed to be fixed (\url{https://github.com/jbisbee1/BARP}).
#' @source https://github.com/jbisbee1/BARP
#' @param y Outcome of interest. Should be a character string naming the column that contains the variable of interest.
#' @param x Prognostic covariates. Should be a vector of column names corresponding to the covariates used to predict the outcome variable of interest. Must include \code{geo.unit}.
#' @param dat Survey data containing the x and y column names. The explanatory variables X included in the model must be converted to factors prior to input.
#' @param census Census data containing the x column names. It must also have the same structure as X. If the user provides raw census data, BARP counts the rows in each unique bin of the x covariates. Otherwise, the researcher must calculate bin proportions and indicate the column name that contains the proportions, either as percentages or as raw counts.
#' @param geo.unit The column name corresponding to the unit at which outcomes should be aggregated. Must be one of the names in \code{x}.
#' @param algorithm Algorithm for predicting opinions. Can be any algorithm(s) included in the \pkg{SuperLearner} package. If multiple algorithms are listed, predicted opinions are provided for each separately, as well as for the weighted ensemble. Defaults to \code{BARP}, which implements Bayesian Additive Regression Trees via \code{bartMachine}.
#' @param proportion The column name corresponding to the proportions for covariate bins in the census data. If left at the default \code{"None"}, BARP assumes raw census data and computes bin counts automatically.
#' @param cred_int A vector giving the lower and upper bounds on the credible interval for the predictions.
#' @param BSSD Calculate bootstrapped standard deviation. If \code{TRUE}, the survey data are resampled \code{nsims} times and the geographic-level estimates and bounds are the mean and the \code{cred_int} quantiles across bootstrap draws. Defaults to \code{FALSE}, in which case the bounds come from BART's posterior (\code{calc_credible_intervals}) when BART is used, from the spread of the individual algorithms' predictions when several non-BART algorithms are used, and are \code{NA} otherwise.
#' @param nsims The number of bootstrap simulations. Only used when \code{BSSD = TRUE}, in which case it must be at least 1.
#' @param setSeed Seed to control random number generation. When not \code{NULL}, \code{set.seed(setSeed)} is called (which changes the global random number state), the seed is forwarded to the internal model-fitting routine, and, when BART is used, it is passed as the \code{seed} argument of \code{bartMachine}.
#' @param ... Additional arguments. Arguments whose names match formal arguments of \code{bartMachine} configure the BART learner; the others are passed to SuperLearner (e.g. \code{method}, \code{family}). If \code{method} is not supplied, \code{"method.NNLS"} is used. If the outcome has exactly two unique values and no \code{family} is supplied, \code{family = binomial()} is used and, when \code{BSSD = FALSE}, \code{method} is set to \code{"method.AUC"}.
#' @return Returns an object of class \code{barp}, a list with the following components:
#' \item{pred.opn}{A \code{data.frame} where each row corresponds to a geographic unit (first column, named after \code{geo.unit}). The remaining columns hold the predicted outcome (one column per algorithm, plus the ensemble \code{ens} when several algorithms are used; a single \code{pred.opn} column when \code{BSSD = TRUE}) and the lower (\code{opn.lb}) and upper (\code{opn.ub}) bounds for the given credible interval (\code{cred_int}).}
#' \item{trees}{A \code{bartMachine} object, or \code{NULL} when BART is not among the algorithms.}
#' \item{algorithms}{The SuperLearner library names that were actually fitted.}
#' \item{BSSD}{The value of the \code{BSSD} argument.}
#' \item{factors}{The classes of the factor or character columns of \code{dat}.}
#' \item{risk}{A \code{data.frame} containing the cross-validation risk for each algorithm and the associated weight used in the ensemble predictions. Only useful when multiple algorithms are used.}
#' \item{barp.dat}{The census data with the estimates (and credible intervals) for each row. With \code{BSSD = TRUE}, these are the predictions from the last bootstrap draw.}
#' \item{setSeed}{The random seed value employed during model estimation using bartMachine.}
#' \item{proportion}{The value of the \code{proportion} argument.}
#' \item{x}{The names of the explanatory variables included in the model, sorted.}
#' @seealso 
#' \code{barps} is used to implement Bayesian Additive Regression Trees based on the \pkg{bartMachine} package. 
#' For detailed options, see \url{https://CRAN.R-project.org/package=bartMachine}.
#' 
#' \code{barps} also uses the \pkg{SuperLearner} package to implement alternative regularizers. 
#' For more details, see \url{https://CRAN.R-project.org/package=SuperLearner}.
#' @export

barps <- function(y, x, dat, census, geo.unit, algorithm = "BARP",
                  setSeed = NULL, proportion = "None",
                  cred_int = c(0.025, 0.975), BSSD = FALSE, nsims = 200, ...) {
  # Seed fixation: this deliberately changes the caller's global RNG state.
  if (!is.null(setSeed)) {
    set.seed(setSeed)
  }

  # Error and warning checks ----
  # One line per missing variable; stop() would otherwise glue several
  # messages together with no separator.
  missing_survey <- x[!(x %in% colnames(dat))]
  if (length(missing_survey) > 0) {
    stop(paste(paste0("Variable '", missing_survey, "' is not in your survey data."),
               collapse = "\n"))
  }
  missing_census <- x[!(x %in% colnames(census))]
  if (length(missing_census) > 0) {
    stop(paste(paste0("Variable '", missing_census, "' is not in your census data."),
               collapse = "\n"))
  }
  if (!(y %in% colnames(dat))) {
    stop(paste0("Outcome '", y, "' is not in your survey data."))
  }
  if (!(geo.unit %in% x)) {
    stop(paste0("The geographic unit '", geo.unit, "' is not among your x controls."))
  }
  if (BSSD && (length(nsims) != 1L || !is.numeric(nsims) || is.na(nsims) || nsims < 1)) {
    stop("'nsims' must be a single number of at least 1 when BSSD = TRUE.")
  }

  # Census bin weights ----
  if (proportion == "None") {
    # Raw census: count the rows in every unique combination of covariates.
    census <- census %>%
      dplyr::group_by(dplyr::across(dplyr::all_of(x))) %>%
      dplyr::summarise(n = dplyr::n(), .groups = "drop")
  } else {
    census$n <- census[[proportion]]
  }

  # Optional parameters ----
  dots <- list(...)
  dots$seed <- setSeed
  # Split `...` into bartMachine arguments (dots.barp) and everything else,
  # which is forwarded to the SuperLearner fit (dots.oth).
  bart.args <- names(as.list(args(bartMachine)))
  serial <- FALSE
  uses_bart <- grepl("barp|BARP|bart|BART|bartMachine", paste(algorithm, collapse = " "))
  if (any(names(dots) %in% bart.args) && uses_bart) {
    dots.barp <- list()
    for (par in which(names(dots) %in% bart.args)) {
      dots.barp[[names(dots)[par]]] <- dots[[par]]
    }

    others <- setdiff(names(dots), bart.args)
    dots.oth <- list()
    for (par in others) {
      dots.oth[[par]] <- dots[[par]]
    }
    # Same default ensemble method as the other branch; without it the
    # risk table below would get an NA column name.
    if (is.null(dots.oth[["method"]])) {
      dots.oth$method <- "method.NNLS"
    }

    # A serialized bartMachine is only fitted separately (barp_serial) when
    # the main fit is an ensemble or is repeated by the bootstrap.
    if ("serialize" %in% names(dots.barp) && (length(algorithm) > 1 || BSSD)) {
      dots.barp[["serialize"]] <- NULL
      serial <- TRUE
    }
    # create.Learner() defines the learner functions in this function's
    # frame, so the barp.SL.seed() calls below must stay in this frame.
    barp_aug <- create.Learner("SL.bartMachine", params = dots.barp)
    algorithm[grepl("barp|BARP|bartMachine", algorithm)] <- barp_aug$names
    if (serial) {
      dots.barp[["serialize"]] <- TRUE
    }
    if (!is.null(setSeed)) {
      dots.barp[["seed"]] <- setSeed
    }
    barp_serial <- create.Learner("SL.bartMachine", params = dots.barp,
                                  name_prefix = "BARPserial")
  } else {
    dots.oth <- dots
    # Honour a user-supplied `method`; default to NNLS otherwise.
    if (is.null(dots.oth[["method"]])) {
      dots.oth$method <- "method.NNLS"
    }
    # NOTE(review): when setSeed is set, dots.oth also carries `seed` here;
    # confirm barp.SL.seed() accepts it.
  }

  # Normalise library names to SuperLearner's "SL.<name>" convention.
  algorithm <- paste0("SL.", sub("^SL\\.", "", algorithm))
  algorithm <- gsub("barp|BARP", "bartMachine", algorithm)

  # Converting to data.frame format
  dat <- as.data.frame(dat)
  census <- as.data.frame(census)

  # Local helpers ----
  # Design matrices and outcome for one (possibly resampled) survey.
  prepare_fit_data <- function(survey) {
    prepped <- prep.dat(dat = survey, census = census, x = x, y = y)
    list(x_test = as.data.frame(prepped$x_test),
         x_train = as.data.frame(prepped$x_train),
         y_train = prepped$y.train)
  }
  # TRUE for a two-valued outcome when the caller chose no family.
  needs_binomial <- function(y_train, sl_args) {
    length(unique(y_train)) == 2 && !("family" %in% names(sl_args))
  }
  # Argument list for barp.SL.seed(); extra arguments are added in order.
  build_input <- function(fit_data, sl_library, sl_args) {
    input <- list(Y = fit_data$y_train, X = fit_data$x_train,
                  newX = fit_data$x_test, SL.library = sl_library)
    for (k in seq_along(sl_args)) {
      input[[names(sl_args)[k]]] <- sl_args[[k]]
    }
    input
  }
  # Census-count-weighted mean of per-bin `values` within each geo unit.
  geo_col <- census[[geo.unit]]
  geo_ids <- unique(geo_col)
  census_weights <- census[["n"]]
  weighted_geo_mean <- function(values) {
    vapply(geo_ids, function(j) {
      in_geo <- geo_col == j
      sum(census_weights[in_geo] * values[in_geo]) / sum(census_weights[in_geo])
    }, numeric(1))
  }

  if (BSSD) {
    # Bootstrap: refit on resampled survey data, aggregate each draw.
    barp.geo.tmp <- matrix(NA, nrow = length(geo_ids), ncol = nsims)
    message("Starting bootstrapped estimation.")
    pb <- txtProgressBar(min = 0, max = nsims, style = 3)
    on.exit(close(pb), add = TRUE)
    for (i in seq_len(nsims)) {
      inds <- sample.int(nrow(dat), replace = TRUE)
      fit_data <- prepare_fit_data(dat[inds, ])
      if (needs_binomial(fit_data$y_train, dots.oth)) {
        dots.oth$family <- binomial()
      }
      input <- build_input(fit_data, algorithm, dots.oth)
      input$BSSD <- BSSD
      input$setSeed <- setSeed
      sl <- do.call(barp.SL.seed, input)

      # For bootstrapped intervals, use the ensemble predictions
      p <- as.vector(sl$SL.predict)
      barp.dat <- as.data.frame(cbind(census, p))
      barp.geo.tmp[, i] <- weighted_geo_mean(p)
      setTxtProgressBar(pb, i)
    }
    close(pb)
    y.pred <- apply(barp.geo.tmp, 1, mean)
    y.lb <- apply(barp.geo.tmp, 1, quantile, cred_int[1])
    y.ub <- apply(barp.geo.tmp, 1, quantile, cred_int[2])
    barp.geo <- data.frame("geo.unit" = geo_ids,
                           "pred.opn" = y.pred, "opn.lb" = y.lb, "opn.ub" = y.ub)

    if (grepl("bartMachine", paste(algorithm, collapse = " "))) {
      # Refit BART once on the full survey to return its trees.
      fit_data <- prepare_fit_data(dat)
      if (needs_binomial(fit_data$y_train, dots.oth)) {
        dots.oth$family <- binomial()
      }
      barpalg <- if (serial) {
        barp_serial$names
      } else {
        algorithm[grepl("bartMachine", algorithm)][1]
      }
      input <- build_input(fit_data, barpalg, dots.oth)
      input$setSeed <- setSeed
      sl <- do.call(barp.SL.seed, input)
      # The serialized learner is named "BARPserial_*", not "*bartMachine*".
      index <- which(grepl("bartMachine|BARPserial", names(sl$fitLibrary)))[1]
      trees <- sl$fitLibrary[[index]]$object
    } else {
      if (length(algorithm) > 1) {
        # Full-data fit so the risk table reflects all observations.
        fit_data <- prepare_fit_data(dat)
        if (needs_binomial(fit_data$y_train, dots.oth)) {
          dots.oth$family <- binomial()
        }
        input <- build_input(fit_data, algorithm, dots.oth)
        input$setSeed <- setSeed
        sl <- do.call(barp.SL.seed, input)
      }
      trees <- NULL
    }
  } else {
    fit_data <- prepare_fit_data(dat)
    if (needs_binomial(fit_data$y_train, dots.oth)) {
      dots.oth$family <- binomial()
      dots.oth$method <- "method.AUC"
    }
    input <- build_input(fit_data, algorithm, dots.oth)
    input$setSeed <- setSeed
    sl <- do.call(barp.SL.seed, input)

    if (grepl("bartMachine", paste(algorithm, collapse = " "))) {
      index <- which(grepl("bartMachine", names(sl$fitLibrary)))[1]
      if (serial) {
        input$SL.library <- barp_serial$names
        tree <- do.call(barp.SL.seed, input)
        serial_index <- which(grepl("BARPserial", names(tree$fitLibrary)))[1]
        trees <- tree$fitLibrary[[serial_index]]$object
      } else {
        trees <- sl$fitLibrary[[index]]$object
      }
      # Credible intervals come from the (non-serialized) BART fit in `sl`;
      # previously the serial branch never defined `ints`.
      ints <- calc_credible_intervals(sl$fitLibrary[[index]]$object, fit_data$x_test,
                                      ci_conf = cred_int[2] - cred_int[1])
    } else {
      ints <- cbind(rep(NA, nrow(sl$SL.predict)), rep(NA, nrow(sl$SL.predict)))
      if (length(algorithm) > 1) {
        # Interval = spread of the individual and ensemble predictions.
        ints <- t(apply(cbind(sl$library.predict, sl$SL.predict), 1, quantile, cred_int))
      }
      colnames(ints) <- c("ci_lower_bd", "ci_upper_bd")
      cat("\n")
      trees <- NULL
    }

    lib.pred <- as.data.frame(sl$library.predict, stringsAsFactors = FALSE)
    if (length(algorithm) > 1) {
      # No need to have the ensemble if there's just one algorithm
      lib.pred$ens <- as.vector(sl$SL.predict)
    }
    barp.dat <- as.data.frame(cbind(census, lib.pred, ints))

    barp.geo <- data.frame("geo.unit" = geo_ids)
    for (p in colnames(lib.pred)) {
      barp.geo <- cbind(barp.geo, tmp = weighted_geo_mean(lib.pred[[p]]))
      colnames(barp.geo)[colnames(barp.geo) == "tmp"] <- p
    }
    barp.geo$opn.lb <- weighted_geo_mean(barp.dat[["ci_lower_bd"]])
    barp.geo$opn.ub <- weighted_geo_mean(barp.dat[["ci_upper_bd"]])
  }

  colnames(barp.geo)[1] <- geo.unit
  factors <- sapply(dat, class)
  factors <- factors[which(factors %in% c("factor", "character"))]
  risk <- data.frame(cbind(sl$cvRisk, sl$coef))
  # Label the risk column after the ensemble method ("method.NNLS" -> "NNLS");
  # a non-character (list) method gets a generic label.
  method_used <- input[["method"]]
  risk_label <- if (is.character(method_used) && length(method_used) == 1L) {
    sub("^method\\.", "", method_used)
  } else {
    "cvRisk"
  }
  colnames(risk) <- c(risk_label, "coef")
  barp.obj <- list(pred.opn = barp.geo, trees = trees, algorithms = algorithm,
                   BSSD = BSSD, factors = factors, risk = risk, barp.dat = barp.dat,
                   setSeed = setSeed, proportion = proportion, x = x[order(x)])
  class(barp.obj) <- "barp"
  barp.obj
}

Try the bartXViz package in your browser

Any scripts or data that you put into this service are public.

bartXViz documentation built on Aug. 8, 2025, 6:23 p.m.