R/predict.R

Defines functions dfextract predict2.bart

Documented in predict2.bart

#' @title predict() for spatial use of BART models
#'
#' @description
#' A predict() wrapper for combining BART models with spatial input data, to
#' generate a RasterStack of predicted outputs. This includes the ability to
#' predict from random intercept (rbart) models, which can be used to deal
#' with clustering in space and time of outcome variables.
#'
#' @param object A bart or rbart model object generated by the dbarts package
#' @param x.layers An object of class RasterStack whose layer names include
#'   every predictor used by the model
#' @param quantiles Probabilities of posterior quantiles to extract in
#'   addition to the posterior mean (e.g. \code{c(0.05, 0.95)} for a 90\%
#'   credible interval). Defaults to none.
#' @param ri.data If 'object' is an rbart model, either one consistent value
#'   (e.g. a prediction year) or a RasterLayer giving the random effect value
#'   for each cell
#' @param ri.name The name of the random intercept in the rbart model
#' @param ri.pred Should the random intercept be *included* in the prediction
#'   value or dropped? Defaults to FALSE (treats the random intercept as noise
#'   to be excluded)
#' @param splitby If set to a value other than 1, the cells are predicted in
#'   roughly this many chunks (one extra chunk may hold the remainder), which
#'   can reduce peak memory use and enables a time estimate and progress bar.
#'   Must be a single non-negative number.
#' @param quiet If TRUE, suppress the time estimate and the progress bar
#'
#' @return A RasterStack whose first layer is the posterior mean prediction
#'   and whose remaining layers are the requested quantiles, in order. Cells
#'   with a missing predictor (or random effect) value are NA.
#'
#' @export
#'
predict2.bart <- function(object,
                          x.layers,
                          quantiles = c(),
                          ri.data = NULL,
                          ri.name = NULL,
                          ri.pred = FALSE,
                          splitby = 1,
                          quiet = FALSE) {

  is.rbart <- inherits(object, "rbart")
  if (!is.rbart && !inherits(object, "bart")) {
    stop("object must be a 'bart' or 'rbart' model from the dbarts package")
  }
  if (!is.numeric(splitby) || length(splitby) != 1 || is.na(splitby) ||
      splitby < 0) {
    stop("splitby must be a single non-negative number")
  }

  # Keep only (and order by) the predictor layers the model was fitted on.
  if (is.rbart) {
    if (is.null(ri.data)) {
      stop("ERROR: Input either a value or a raster in ri.data")
    }
    if (is.null(ri.name)) {
      stop("ERROR: Input the correct random effect variable name in the model object in ri.name")
    }
    xnames <- attr(object$fit[[1]]$data@x, "term.labels")
    if (!all(xnames %in% c(names(x.layers), ri.name))) {
      stop("Variable names of RasterStack don't match the requested names")
    }
    x.layers <- x.layers[[xnames[xnames != ri.name]]]
  } else {
    xnames <- attr(object$fit$data@x, "term.labels")
    if (!all(xnames %in% names(x.layers))) {
      stop("Variable names of RasterStack don't match the requested names")
    }
    x.layers <- x.layers[[xnames]]
  }

  # One row per raster cell, one column per predictor. Names are set
  # explicitly because getValues() on a single RasterLayer is unnamed.
  input.matrix <- as.matrix(raster::getValues(x.layers))
  colnames(input.matrix) <- names(x.layers)
  n.cells <- nrow(input.matrix)
  keep <- complete.cases(input.matrix)

  # Random effect value for every cell (rbart only). It is kept apart from the
  # numeric predictor matrix so cbind() cannot coerce the predictors' type,
  # and it is aligned with the cells *before* incomplete rows are dropped.
  ri.values <- NULL
  if (is.rbart) {
    if (inherits(ri.data, "RasterLayer")) {
      ri.values <- raster::getValues(ri.data)
      if (length(ri.values) != n.cells) {
        stop("ri.data must have the same number of cells as x.layers")
      }
      keep <- keep & !is.na(ri.values)
    } else {
      if (length(ri.data) != 1) {
        stop("ri.data must be a single value or a RasterLayer")
      }
      ri.values <- rep(ri.data, n.cells)
    }
  }

  whichvals <- which(keep)
  if (length(whichvals) == 0) {
    stop("No cells have complete predictor values to predict from")
  }
  input.matrix <- input.matrix[whichvals, , drop = FALSE]
  ri.values <- ri.values[whichvals]

  # Posterior draws for a subset (by row index) of the complete cells.
  predict_rows <- function(rows) {
    newdata <- input.matrix[rows, , drop = FALSE]
    if (!is.rbart) {
      return(dbarts:::predict.bart(object, newdata))
    }
    # ri.pred = FALSE excludes the random intercept ('bart'); TRUE keeps it.
    dbarts:::predict.rbart(object,
                           newdata,
                           group.by = ri.values[rows],
                           type = if (ri.pred) "ppd" else "bart")
  }

  n.rows <- length(whichvals)
  if (splitby == 1) {
    chunks <- list(seq_len(n.rows))
  } else {
    # floor() can leave a remainder, so there may be splitby + 1 chunks; the
    # size is at least 1 so that fewer rows than splitby cannot divide by 0.
    chunk.size <- max(1, floor(n.rows / splitby))
    chunks <- split(seq_len(n.rows), (seq_len(n.rows) - 1) %/% chunk.size)
  }

  show.timing <- splitby != 1 && !quiet
  summaries <- vector("list", length(chunks))
  pb <- NULL
  start_time <- Sys.time()
  for (i in seq_along(chunks)) {
    # Summarise each chunk straight away so only its draws are held at once.
    summaries[[i]] <- as.matrix(dfextract(predict_rows(chunks[[i]]),
                                          quant = quantiles))
    if (show.timing) {
      if (i == 1) {
        # difftime() picks its own units; ask for minutes explicitly.
        elapsed <- as.numeric(difftime(Sys.time(), start_time,
                                       units = "mins"))
        cat("Estimated time to total prediction (mins):\n")
        cat(length(chunks) * elapsed)
        cat("\n")
        pb <- txtProgressBar(min = 0, max = length(chunks), style = 3)
      }
      setTxtProgressBar(pb, i)
    }
  }
  if (!is.null(pb)) {
    close(pb)
  }
  # Rows are the complete cells in their original order; columns are the
  # mean followed by each requested quantile.
  pred.summary <- do.call(rbind, summaries)

  if (is.rbart) {
    # NOTE(review): pnorm() is applied to the summaries of rbart predictions
    # only, presumably mapping a probit-scale result to probabilities. With
    # ri.pred = TRUE ('ppd' draws) confirm the draws are on that scale.
    pred.summary <- pnorm(pred.summary)
  }

  # Re-insert NA for the cells that were dropped as incomplete.
  output <- matrix(NA_real_, nrow = n.cells, ncol = ncol(pred.summary))
  output[whichvals, ] <- pred.summary

  template <- x.layers[[1]]
  outlist <- lapply(seq_len(ncol(output)), function(j) {
    # Cell values are in row-major order: fill one raster row per column,
    # then transpose so matrix rows are raster rows.
    output.m <- t(matrix(output[, j],
                         nrow = ncol(x.layers),
                         ncol = nrow(x.layers)))
    raster(output.m,
           xmn = xmin(template), xmx = xmax(template),
           ymn = ymin(template), ymx = ymax(template),
           crs = template@crs)
  })

  stack(outlist)
}

# Summarise posterior draws column by column: returns the vector of column
# means when 'quant' is empty; otherwise a data frame whose first column
# (named "colMeans.df.") holds the means, followed by one column per
# requested quantile probability from colQuantiles().
dfextract <- function(df, quant) {
  post.mean <- colMeans(df)
  if (length(quant) != 0) {
    return(cbind(data.frame(colMeans.df. = post.mean),
                 colQuantiles(df, probs = quant)))
  }
  post.mean
}
cjcarlson/embarcadero documentation built on Sept. 9, 2023, 10:47 p.m.