R/predict_parallel.R

Defines functions predict_raster

Documented in predict_raster

#' Spatial model predictions
#'
#' Predicts a fitted model over every cell of a raster in parallel. The raster
#' is split into row blocks (`raster::blockSize()`); each block is read on a
#' worker, cells with any `NA` predictor stay `NA`, and the remaining cells are
#' passed to `predict()` as a data frame whose columns are named after the
#' layers of `x`.
#'
#' @param x Raster* object with predictor layers. Layer names are used as the
#'   column names of `newdata`.
#' @param model Model object with a `predict()` method accepting `newdata`.
#' @param ncores Number of worker processes used for parallelization.
#' @param predict_package Optional. Character vector of packages to load on the
#'   workers (e.g. the package providing the model's `predict()` method).
#'   Defaults to all packages attached in the calling session.
#' @param stack If `FALSE` (default), return a single layer with the
#'   predictions (the first column if `predict()` returns a matrix or data
#'   frame). If `TRUE`, return one layer per prediction column (e.g. one per
#'   class), named after those columns.
#' @param filename Optional. If given, the result is also written to this file,
#'   overwriting any existing file.
#' @param ... Arguments passed on to the predict function (e.g. `type`).
#'
#' @return A RasterLayer with model predictions, or a RasterBrick with one
#'   layer per prediction column when `stack = TRUE`. Predictions are converted
#'   with `as.numeric(as.character(.))`, so factor predictions become `NA`
#'   unless their labels are numeric.
#' @export

predict_raster <- function(x,
                           model,
                           ncores,
                           predict_package = NULL,
                           stack = FALSE,
                           filename = NULL,
                           ...) {
  if (is.null(predict_package)) {
    # Load every package attached here on the workers too, so the model's
    # predict() method can be found there.
    predict_package <- .packages()
  }

  # Local binding so the operator works even when foreach is not attached.
  `%dopar%` <- foreach::`%dopar%`

  # Empty single-layer raster with the geometry of `x`, used to build output.
  template <- raster::raster(x)
  factor_bands <- names(x)[raster::is.factor(x)]
  # Size blocks on `x` itself: every worker reads all layers of `x`.
  bs <- raster::blockSize(x)

  cl <- parallel::makeCluster(ncores)
  # Stop the workers and drop the (then dead) backend even if predict() fails.
  on.exit({
    parallel::stopCluster(cl)
    foreach::registerDoSEQ()
  }, add = TRUE)
  doParallel::registerDoParallel(cl)

  # Each worker returns one block: a numeric vector (stack = FALSE) or a
  # matrix with one row per cell of the block (stack = TRUE).
  blocks <- foreach::foreach(
    i = seq_len(bs$n),
    .packages = c("raster", predict_package)
  ) %dopar% {
    v <- as.matrix(
      raster::getValues(x, row = bs$row[i], nrows = bs$nrows[i])
    )
    complete <- which(rowSums(is.na(v)) == 0)

    if (length(complete) == 0L) {
      # All-NA block: skip predict() on an empty data frame, which some
      # predict methods reject. In stack mode the zero-column matrix is
      # widened to the final layer count after collection.
      if (stack) matrix(NA_real_, nrow(v), 0L) else rep(NA_real_, nrow(v))
    } else {
      # drop = FALSE keeps a single complete row as a one-row data frame.
      newdata <- as.data.frame(v[complete, , drop = FALSE])
      names(newdata) <- names(x)
      # NOTE(review): levels are rebuilt per block from the values present,
      # so they may differ from the training levels; confirm the model
      # tolerates this.
      newdata[factor_bands] <- lapply(newdata[factor_bands], factor)
      preds <- predict(object = model, newdata, ...)

      if (stack) {
        preds <- as.data.frame(preds)
        m <- matrix(NA_real_, nrow(v), ncol(preds),
                    dimnames = list(NULL, names(preds)))
        m[complete, ] <- vapply(
          preds,
          function(col) as.numeric(as.character(col)),
          numeric(nrow(preds))
        )
        m
      } else {
        if (!is.vector(preds)) {
          preds <- as.data.frame(preds)[[1]]
        }
        m <- rep(NA_real_, nrow(v))
        m[complete] <- as.numeric(as.character(preds))
        m
      }
    }
  }

  if (stack) {
    n_layers <- max(0L, vapply(blocks, ncol, integer(1)))
    if (n_layers == 0L) {
      # No cell had complete predictors: fall back to one NA layer per
      # predictor layer, as the original implementation produced.
      n_layers <- raster::nlayers(x)
    }
    values <- do.call(rbind, lapply(blocks, function(b) {
      if (ncol(b) == 0L) matrix(NA_real_, nrow(b), n_layers) else b
    }))
    layers <- lapply(seq_len(n_layers), function(j) {
      raster::setValues(template, values[, j])
    })
    out <- raster::brick(raster::stack(layers))
    if (!is.null(colnames(values))) {
      names(out) <- colnames(values)
    }
  } else {
    out <- raster::setValues(template, unlist(blocks, use.names = FALSE))
  }

  if (!is.null(filename)) {
    raster::writeRaster(out, filename, overwrite = TRUE)
  }
  out
}
juoe/spatialtools documentation built on May 25, 2019, 6:25 p.m.