# R/barp.R

# Defines functions barp

# Documented in barp

#' BARP
#' 
#' This function uses Bayesian Additive Regression Trees (BART) to extrapolate survey data to a level of geographic aggregation at which the original survey was not designed to be representative.
#' @param y Outcome of interest. A character string giving the name of the column containing the outcome variable.
#' @param x Prognostic covariates. Should be a vector of column names corresponding to the covariates used to predict the outcome variable of interest.
#' @param dat Survey data containing the x and y column names.
#' @param census Census data containing the x column names. If the user provides raw census data, BARP counts the census rows in each unique bin of the x covariates and uses those counts as post-stratification weights. Otherwise, the researcher must calculate bin weights and indicate the column that contains them (via \code{proportion}), either as proportions/percentages or as raw counts.
#' @param geo.unit The column name corresponding to the unit at which outcomes should be aggregated.
#' @param algorithm Algorithm for predicting opinions. Can be any algorithm(s) included in the \code{SuperLearner} package. If multiple algorithms are listed, predicted opinions are provided for each separately, as well as for the weighted ensemble. Defaults to \code{BARP} which implements Bayesian Additive Regression Trees via \code{bartMachine}.
#' @param proportion The column name corresponding to the proportions (or counts) for covariate bins in the census data. If left at the default \code{"None"}, BARP assumes raw census data and computes bin counts automatically.
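#' @param cred_int A vector giving the lower and upper quantiles of the credible interval for the predictions. For BART fits without bootstrapping, a symmetric interval with coverage \code{cred_int[2] - cred_int[1]} is computed; with \code{BSSD = TRUE} the quantiles are taken over the bootstrap estimates, and with several non-BART algorithms they are taken across the algorithms' predictions.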
#' @param BSSD Calculate bootstrapped standard deviation. Defaults to \code{FALSE} in which case the standard deviation is generated by BART's default.
#' @param nsims The number of bootstrap simulations. Only used when \code{BSSD = TRUE}.
#' @param setSeed Seed to control random number generation.
#' @param ... Additional arguments to be passed to bartMachine or SuperLearner.
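#' @return Returns an object of class \code{"barp"}, containing a list of the following components:
#' \item{pred.opn}{A \code{data.frame} where each row corresponds to a geographic unit (first column, named after \code{geo.unit}) and the remaining columns give the predicted outcome and the lower (\code{opn.lb}) and upper (\code{opn.ub}) bounds for the given credible interval (\code{cred_int}). Without bootstrapping there is one prediction column per algorithm, plus the ensemble (\code{ens}) when several algorithms are used; with \code{BSSD = TRUE} there is a single \code{pred.opn} column.}
#' \item{trees}{A \code{bartMachine} object, or \code{NULL} when no BART algorithm is used. See \code{\link{bartMachine}} for details.}
#' \item{algorithms}{The \code{SuperLearner} library names that were fitted.}
#' \item{BSSD}{The value of the \code{BSSD} argument.}
#' \item{factors}{The classes of the factor and character columns of \code{dat}.}
#' \item{risk}{A \code{data.frame} containing the cross-validation risk for each algorithm and the associated weight used in the ensemble predictions. Only useful when multiple algorithms are used.}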
#' @keywords MRP BARP Mister P multilevel regression poststratification
#' @seealso \code{\link{bartMachine}} which this program uses to implement Bayesian Additive Regression Trees.
#' @seealso \code{\link{SuperLearner}} which this program uses to implement alternative regularizers.
#' @examples 
#' data("gaymar")
#' barp <- barp(y = "supp_gaymar",
#'              x = c("pvote","religcon","age","educ","gXr","stateid","region"),
#'              dat = svy,
#'              census = census06,
#'              geo.unit = "stateid",
#'              proportion = "n",
#'              cred_int = c(0.025,0.975))
#' @export

barp <- function(y, x, dat, census, geo.unit, algorithm = "BARP",
                 proportion = "None", cred_int = c(0.025, 0.975),
                 BSSD = FALSE, nsims = 200, setSeed = NULL, ...) {
  # set.seed(NULL) re-initialises the RNG, so this is harmless without a seed.
  set.seed(setSeed)

  # Input checks ----
  # Report every missing column, one per line, rather than gluing the
  # messages together into a single run-on string.
  missing.dat <- x[!(x %in% colnames(dat))]
  if (length(missing.dat) > 0) {
    stop(paste0("Variable '", missing.dat, "' is not in your survey data.",
                collapse = "\n"))
  }
  missing.census <- x[!(x %in% colnames(census))]
  if (length(missing.census) > 0) {
    stop(paste0("Variable '", missing.census,
                "' is not in your census data.", collapse = "\n"))
  }
  if (!(y %in% colnames(dat))) {
    stop(paste0("Outcome '", y, "' is not in your survey data."))
  }
  if (!(geo.unit %in% x)) {
    stop(paste0("The geographic unit '", geo.unit,
                "' is not among your x controls."))
  }

  # Post-stratification weights (column `n`) ----
  # Raw census data: count rows per covariate cell. Otherwise copy the
  # user-supplied proportion/count column.
  if (proportion == "None") {
    census <- census %>%
      dplyr::group_by(dplyr::across(dplyr::all_of(x))) %>%
      summarise(n = n())
  } else {
    census$n <- census[[proportion]]
  }

  # Optional parameters ----
  dots <- list(...)
  dots$seed <- setSeed
  # A `method` supplied by the caller is never overridden by the defaults
  # applied below (NNLS, or AUC for binary outcomes).
  user.method <- !is.null(dots[["method"]])
  bart.args <- names(as.list(args(bartMachine)))
  serial <- FALSE
  uses.bart <- grepl("barp|BARP|bart|BART|bartMachine",
                     paste(algorithm, collapse = " "))
  if (any(names(dots) %in% bart.args) && uses.bart) {
    # bartMachine arguments configure a custom learner; everything else is
    # forwarded to SuperLearner. NULL entries are dropped.
    dots.barp <- Filter(Negate(is.null), dots[names(dots) %in% bart.args])
    dots.oth <- Filter(Negate(is.null),
                       dots[setdiff(names(dots), bart.args)])
    if ("serialize" %in% names(dots.barp) &&
        (length(algorithm) > 1 || BSSD)) {
      # With several learners or bootstrapping, the serialized model is
      # refit separately ("BARPserial" learner) and the estimation learners
      # stay unserialized. Only serialize = TRUE requests that refit.
      serial <- isTRUE(dots.barp[["serialize"]])
      dots.barp[["serialize"]] <- NULL
    }
    # create.Learner() is deliberately called here, not in a helper, so the
    # learners it defines are visible to the barp.SL() calls below.
    # NOTE(review): relies on create.Learner's default env = parent.frame().
    barp_aug <- create.Learner("SL.bartMachine", params = dots.barp)
    algorithm[grepl("barp|BARP|bartMachine", algorithm)] <- barp_aug$names
    if (serial) {
      dots.barp[["serialize"]] <- TRUE
    }
    if (!is.null(setSeed)) {
      dots.barp[["seed"]] <- setSeed
    }
    barp_serial <- create.Learner("SL.bartMachine", params = dots.barp,
                                  name_prefix = "BARPserial")
  } else {
    dots.oth <- dots
  }
  if (!user.method) {
    dots.oth[["method"]] <- "method.NNLS"
  }

  # Normalise to SuperLearner wrapper names, e.g. "BARP" -> "SL.bartMachine".
  algorithm <- paste0("SL.", sub("^SL\\.", "", algorithm))
  algorithm <- gsub("barp|BARP", "bartMachine", algorithm)
  uses.bartMachine <- any(grepl("bartMachine", algorithm))

  # Converting to data.frame format
  dat <- as.data.frame(dat)
  census <- as.data.frame(census)
  geo.ids <- census[[geo.unit]]
  geos <- unique(geo.ids)
  cell.n <- census[["n"]]

  # do.call(barp.SL, ...) stays in this function body (not in a helper) so
  # the learners created above remain in scope for it.
  if (BSSD) {
    boot.geo <- matrix(NA, nrow = length(geos), ncol = nsims)
    cat("Starting bootstrapped estimation\n")
    pb <- txtProgressBar(min = 0, max = nsims, style = 3)
    for (i in seq_len(nsims)) {
      boot.rows <- sample.int(nrow(dat), replace = TRUE)
      prepped <- .barp_prep(dat[boot.rows, , drop = FALSE], census, x, y)
      if (.barp_binary_without_family(prepped$y.train, dots.oth)) {
        dots.oth$family <- binomial()
      }
      input <- .barp_sl_input(prepped, algorithm, dots.oth)
      input$BSSD <- BSSD
      sl <- do.call(barp.SL, input)
      # For bootstrapped intervals, use the ensemble predictions.
      boot.geo[, i] <- .barp_geo_mean(as.vector(sl$SL.predict), cell.n,
                                      geo.ids, geos)
      setTxtProgressBar(pb, i)
    }
    close(pb)
    barp.geo <- data.frame(
      "geo.unit" = geos,
      "pred.opn" = apply(boot.geo, 1, mean),
      "opn.lb" = apply(boot.geo, 1, quantile, cred_int[1]),
      "opn.ub" = apply(boot.geo, 1, quantile, cred_int[2])
    )
    if (uses.bartMachine) {
      # Refit BART once on the full survey so the fitted trees are returned.
      prepped <- .barp_prep(dat, census, x, y)
      if (.barp_binary_without_family(prepped$y.train, dots.oth)) {
        dots.oth$family <- binomial()
      }
      barpalg <- if (serial) {
        barp_serial$names
      } else {
        algorithm[grepl("bartMachine", algorithm)][1]
      }
      input <- .barp_sl_input(prepped, barpalg, dots.oth)
      sl <- do.call(barp.SL, input)
      trees <- .barp_bart_fit(sl)
    } else {
      if (length(algorithm) > 1) {
        prepped <- .barp_prep(dat, census, x, y)
        if (.barp_binary_without_family(prepped$y.train, dots.oth)) {
          dots.oth$family <- binomial()
        }
        input <- .barp_sl_input(prepped, algorithm, dots.oth)
        sl <- do.call(barp.SL, input)
      }
      trees <- NULL
    }
  } else {
    prepped <- .barp_prep(dat, census, x, y)
    if (.barp_binary_without_family(prepped$y.train, dots.oth)) {
      dots.oth$family <- binomial()
      if (!user.method) {
        dots.oth$method <- "method.AUC"
      }
    }
    input <- .barp_sl_input(prepped, algorithm, dots.oth)
    sl <- do.call(barp.SL, input)

    if (uses.bartMachine) {
      if (serial) {
        input$SL.library <- barp_serial$names
        trees <- .barp_bart_fit(do.call(barp.SL, input))
      } else {
        trees <- .barp_bart_fit(sl)
      }
      # Computed for both branches so the serialized path has intervals too.
      # bartMachine intervals are symmetric with this total coverage.
      ints <- calc_credible_intervals(trees, prepped$x_test,
                                      ci_conf = cred_int[2] - cred_int[1])
    } else {
      ints <- matrix(NA, nrow = nrow(sl$SL.predict), ncol = 2)
      if (length(algorithm) > 1) {
        # Spread of the individual learners' and the ensemble's predictions.
        ints <- t(apply(cbind(sl$library.predict, sl$SL.predict), 1,
                        quantile, cred_int))
      }
      colnames(ints) <- c("ci_lower_bd", "ci_upper_bd")
      cat("\n")
      trees <- NULL
    }
    lib.pred <- as.data.frame(sl$library.predict, stringsAsFactors = FALSE)
    if (length(algorithm) > 1) {
      # No need to have the ensemble if there's just one algorithm
      lib.pred$ens <- as.vector(sl$SL.predict)
    }
    barp.geo <- data.frame("geo.unit" = geos)
    for (pred.col in colnames(lib.pred)) {
      barp.geo[[pred.col]] <- .barp_geo_mean(lib.pred[[pred.col]], cell.n,
                                             geo.ids, geos)
    }
    # Column 1 is the lower bound, column 2 the upper bound.
    barp.geo$opn.lb <- .barp_geo_mean(ints[, 1], cell.n, geo.ids, geos)
    barp.geo$opn.ub <- .barp_geo_mean(ints[, 2], cell.n, geo.ids, geos)
  }
  colnames(barp.geo)[1] <- geo.unit

  # Categorical survey columns; ordered factors are reported as "factor"
  # and the result is always a named character vector.
  col.class <- vapply(dat, function(col) {
    if (is.factor(col)) "factor" else class(col)[1]
  }, character(1))
  factors <- col.class[col.class %in% c("factor", "character")]

  risk <- data.frame(cbind(sl$cvRisk, sl$coef))
  method.label <- if (is.character(input[["method"]])) {
    sub("^method\\.", "", input[["method"]])
  } else {
    "cvRisk"
  }
  colnames(risk) <- c(method.label, "coef")
  barp.obj <- list(pred.opn = barp.geo, trees = trees, algorithms = algorithm,
                   BSSD = BSSD, factors = factors, risk = risk)
  class(barp.obj) <- "barp"
  barp.obj
}

# Run prep.dat() and coerce its training/test designs to data frames.
.barp_prep <- function(dat, census, x, y) {
  prepped <- prep.dat(dat = dat, census = census, x = x, y = y)
  list(x_train = as.data.frame(prepped$x_train),
       x_test = as.data.frame(prepped$x_test),
       y.train = prepped$y.train)
}

# TRUE when the outcome is binary and no `family` argument has been chosen,
# in which case the caller supplies binomial() to SuperLearner.
.barp_binary_without_family <- function(y.train, dots.oth) {
  length(unique(y.train)) == 2 &&
    !grepl("family", paste(names(dots.oth), collapse = " "))
}

# Build the argument list for barp.SL(): the prepared data plus the extra
# SuperLearner arguments, applied in order (later entries overwrite).
.barp_sl_input <- function(prepped, library, dots.oth) {
  input <- list(Y = prepped$y.train, X = prepped$x_train,
                newX = prepped$x_test, SL.library = library)
  for (k in seq_along(dots.oth)) {
    input[[names(dots.oth)[k]]] <- dots.oth[[k]]
  }
  input
}

# Extract the first fitted bartMachine object from a SuperLearner fit. The
# learner is named "...bartMachine..." or, when serialized, "BARPserial_*".
.barp_bart_fit <- function(sl) {
  index <- grep("bartMachine|BARPserial", names(sl$fitLibrary))[1]
  sl$fitLibrary[[index]]$object
}

# Weighted mean of `values` within each unit of `geos` (in that order), using
# `weights` (census cell counts); all three row vectors are aligned.
.barp_geo_mean <- function(values, weights, groups, geos) {
  vapply(geos, function(g) {
    in.geo <- groups == g
    sum(weights[in.geo] * values[in.geo]) / sum(weights[in.geo])
  }, numeric(1))
}
# jbisbee1/BARP documentation built on Jan. 5, 2023, 9:15 a.m.