# R/SOptim_RasterOutputs.R
#
# Defines functions predictSegments getTrainRasterSegments
#
# Documented in getTrainRasterSegments predictSegments

#' Generate a raster data set with train segments
#' 
#' A utility function used to create a raster dataset containing train segments for a given segmentation 
#' solution and a train raster.
#' 
#' @inheritParams getTrainData
#' 
#' @param trainData Input train data. The input can be a \code{SpatRaster}, an integer vector (containing 
#' raster data after using \code{\link[terra]{values}} function) or a character string with a path to a 
#' raster layer. If trainData is an integer vector then it should have a length equal to the number of 
#' pixels in rstSegm.
#' 
#' @param filename A filename/path used to write the output raster (default: NULL)
#' 
#' @param verbose Print comments with function progress? (default: TRUE)
#' 
#' @param ... Additional options passed to \code{writeRaster} function. The option \code{datatype} 
#' is already internally set to \emph{INT2U}.
#' 
#' @return A SpatRaster object with train segments.
#' 
#' @importFrom terra values
#' @importFrom terra rast
#' @importFrom terra ncell
#' @importFrom terra writeRaster
#' @export
#' 


getTrainRasterSegments <- function(trainData, rstSegm, filename = NULL, useThresh = TRUE, thresh = 0.5, 
                                   na.rm = TRUE, dup.rm = TRUE, minImgSegm = 30, ignore = FALSE, 
                                   verbose = TRUE, ...){
  
  # Builds a raster with the same geometry as rstSegm in which every pixel holds the 
  # train value assigned (by getTrainData) to the segment it belongs to. Pixels whose 
  # segment has no train value are set to NA. The raster is written to disk (INT2U) only 
  # when filename is given; it is always returned.
  
  if(verbose) cat("-> Generating train data...\n") ## --------------------------------------------------- ##
  
  # rstSegm is handed over untouched so that getTrainData keeps seeing the caller's input
  trainDataDF <- getTrainData(x = trainData, rstSegm = rstSegm, useThresh = useThresh, thresh = thresh, 
                              na.rm = na.rm, dup.rm = dup.rm, minImgSegm = minImgSegm, ignore = ignore)
  
  # The lookup below needs a table holding the SID and train columns; fail early with a 
  # readable message instead of an obscure error further down
  if(!is.data.frame(trainDataDF))
    stop("Unable to create the train segments raster: getTrainData did not return a data table with train segments")
  
  if(verbose) cat("done.\n\n")
  
  if(verbose) cat("-> Loading segmented raster...\n") ## ------------------------------------------------ ##
  
  # A file path cannot be queried with terra::values()/terra::rast(template), so open it first
  # NOTE(review): assumes getTrainData also accepts a path in rstSegm - confirm
  if(is.character(rstSegm)){
    rstSegm <- terra::rast(rstSegm)
  }
  
  # Segment ID of every pixel, in cell order (first layer only)
  segmIDs <- terra::values(rstSegm)[,1]
  
  if(verbose) cat("done.\n\n")
  
  if(verbose) cat("-> Joining and re-ordering train data with the segmented raster ...\n") ##------------ ##
  
  # Direct lookup by segment ID: the result always has one value per pixel and is already 
  # in cell order. A merge() followed by a re-sort would yield extra rows if a segment ID 
  # were repeated in the train data, making the vector unusable as raster values.
  trainValues <- trainDataDF$train[match(segmIDs, trainDataDF$SID)]
  
  if(verbose) cat("done.\n\n")
  
  if(verbose) cat("-> Creating a new raster dataset with train segments...\n") ## --------------------- ##
  
  # Empty raster sharing the geometry of rstSegm, then filled with the train values
  newRstTrainData <- terra::rast(rstSegm)
  terra::values(newRstTrainData) <- trainValues
  
  if(verbose) cat("done.\n\n")
  
  if(!is.null(filename)){
    if(verbose) cat("-> Writing data...\n") ## ---------------------------------------------------------- ##
    
    # datatype is fixed here, hence it must not be passed again through ...
    terra::writeRaster(newRstTrainData, filename = filename, datatype="INT2U", ...)
    
    if(verbose) cat("done.\n\n")
  }
  
  return(newRstTrainData)
}




#' Predict class labels for image segments
#' 
#' This function uses an input \code{SOptim.Classifier} object to predict class labels for train or all segments in the 
#' input segmented image.
#' 
#' @param classifierObj An object of class \code{SOptim.Classifier} containing a classification algorithm 
#' generated by function \code{\link{calibrateClassifier}} with option \code{runFullCalibration = TRUE}.
#' 
#' @param calData An object of class \code{SOptim.CalData} generated by function \code{\link{prepareCalData}} containing 
#' calibration data for train segments and the entire image.
#' 
#' @param rstSegm A string defining the path to the raster with segment IDs, a \code{SpatRaster} object or 
#' a \code{SOptim.SegmentationResult} object (generated by any segmentation function; check \link{segmentationGeneric}). 
#' 
#' @param predictFor Either option \code{"train"} which predicts class labels only for train segments 
#' or option \code{"all"} which predicts for all existing segments in \code{rstSegm} (default: all).
#'  
#' @param filename A file name/path used to write the output raster (default: NULL).
#' 
#' @param verbose Print comments with function progress? (default: TRUE).
#' 
#' @param na.rm Remove NA's? (default: TRUE).
#' 
#' @param ... Additional arguments for \code{\link[terra]{writeRaster}} function except \code{datatype} 
#' which is already internally set to 'INT4U'. 
#' 
#' @param forceWriteByLine Use memory-safe writing of raster output by line? (default: FALSE). 
#' If \code{forceWriteByLine = TRUE}, then \code{filename} must define a valid file path.
#' 
#' @details By default the function uses the classifier trained with the 'full' dataset (i.e., no train/test splits) 
#' for making class label predictions. In case of single-class problems the threshold that maximizes the selected 
#' evaluation metric (check \code{evalMetric} in \code{\link{calibrateClassifier}}) is used to dichotomize predictions.   
#' For multi-class problems the output class label is set for the one with highest probability value.  
#' 
#' @return An object of class \code{SpatRaster} containing the predicted class labels for 
#' each image segment. If the file name is defined, the function will write a file containing the output 
#' raster. The output data type is INT4U (see \code{\link[terra]{datatype}} for more details) which means 
#' negative values for class labels are not valid.
#' 
#' @importFrom terra ncell
#' @importFrom terra values
#' @importFrom terra rast
#' @importFrom terra writeRaster
#' @importFrom terra readStart 
#' @importFrom terra readStop
#' @importFrom terra writeStart
#' @importFrom terra writeStop
#' @importFrom terra writeValues
#' @importFrom stats predict
#' @importFrom stats na.omit
#' @importFrom data.table setDT
#' @importFrom data.table setkey
#' @importFrom data.table setorder
#' @importFrom utils txtProgressBar
#' @importFrom utils setTxtProgressBar
#' 
#' @export

predictSegments <- function(classifierObj, calData, rstSegm, predictFor = "all", 
                            filename = NULL, verbose = TRUE, na.rm = TRUE,
                            forceWriteByLine = FALSE, ...){
  
  # Predicts a class label for each segment (train segments or all segments) using the 
  # classifier stored in classifierObj$ClassObj$FULL and burns those labels into a raster 
  # with the geometry of rstSegm. Output is produced either in memory (default) or row by 
  # row straight to filename (forceWriteByLine = TRUE).
  
  ## --------------------------------------------------------------------------------------- ##
  ## Input checks ----
  ## --------------------------------------------------------------------------------------- ##
  
  if(!inherits(classifierObj, "SOptim.Classifier"))
    stop("classifierObj must be an object of class SOptim.Classifier generated by 
         calibrateClassifier function with option runFullCalibration = TRUE")
  
  if(!inherits(calData, "SOptim.CalData"))
    stop("calData must be an object of class SOptim.CalData generated by prepareCalData function")
  
  # The length check keeps a vector input from reaching if() with a multi-element condition
  if(length(predictFor) != 1 || !(predictFor %in% c("train","all")))
    stop("predictFor must be either \"train\" (only for train segments) or \"all\"!")
  
  # The output mode is decided by this flag alone, so it has to be a single non-missing value
  if(length(forceWriteByLine) != 1 || is.na(forceWriteByLine))
    stop("forceWriteByLine must be either TRUE or FALSE!")
  
  # TODO: Replace canProcessInMemory (automatic detection of rasters too large for memory); 
  # until then the memory-safe path is only taken when forceWriteByLine = TRUE
  
  # Writing by line streams directly to disk, hence an output path is mandatory
  if(forceWriteByLine && is.null(filename)){
    stop("If forceWriteByLine = TRUE please define a valid output path in parameter filename.\n")
  }
  
  if(verbose) cat("-> Pre-processing data...\n") ## ----------------------------------------- ##
  
  # Normalise rstSegm to a SpatRaster: unwrap segmentation results and open file paths
  if(inherits(rstSegm,"SOptim.SegmentationResult")){
    rstSegm <- terra::rast(rstSegm$segm)
  }else if(is.character(rstSegm)){
    rstSegm <- terra::rast(rstSegm)
  }
  
  # Get the classifier trained with the full dataset
  clObj <- classifierObj$ClassObj$FULL
  
  # Get metadata on number of classes
  nClassType <- classifierObj$ClassParams$nClassType
  
  if(length(nClassType) != 1 || !(nClassType %in% c("single-class","multi-class")))
    stop("classifierObj$ClassParams$nClassType must be either \"single-class\" or \"multi-class\"")
  
  isSingleClass <- nClassType == "single-class"
  isMultiClass  <- nClassType == "multi-class"
  isTrain       <- predictFor == "train"
  
  # Threshold used to dichotomize single-class probability predictions: the last element 
  # of the stored thresholds is the one applied to the classifier used here
  if(isSingleClass){
    threshVals <- classifierObj$Thresh
    threshValFullClassif <- threshVals[length(threshVals)]
  }
  
  # Remove NA's from the data set that will be used for predicting
  if(na.rm){
    if(predictFor == "train") calData$calData <- stats::na.omit(calData$calData)
    if(predictFor == "all") calData$classifFeatData <- stats::na.omit(calData$classifFeatData)
  }
  
  # Data set to predict on: train segments only or every segment in the image
  if(isTrain){
    predData <- calData$calData
  }else{
    predData <- calData$classifFeatData
  }
  
  if(verbose) cat("done.\n\n") ## ----- ##
  
  
  if(verbose) cat("-> Classifying raster segments...\n") ## -------------------------------- ##
  
  # Stays NULL if the classifier is of none of the supported classes (checked afterwards)
  pred <- NULL
  
  ## --------------------------------------------------------------------------------------- ##
  ## Random forest classifier ----
  ## --------------------------------------------------------------------------------------- ##
  
  if(inherits(clObj, "randomForest")){
    
    if(isSingleClass){
      if(isTrain){
        rfProbs <- stats::predict(clObj, type="prob") ## Gets the out-of-bag predictions
      }else{
        rfProbs <- stats::predict(clObj, type="prob", newdata = predData)
      }
      # Second column of the probability matrix is compared against the threshold
      pred <- as.integer(rfProbs[,2] > threshValFullClassif)
    }
    
    if(isMultiClass){
      if(isTrain){
        pred <- stats::predict(clObj, type = "response") ## Gets the out-of-bag predictions
      }else{
        pred <- stats::predict(clObj, type = "response", newdata = predData)
      }
    }
  }
  
  ## --------------------------------------------------------------------------------------- ##
  ## GBM classifier ----
  ## --------------------------------------------------------------------------------------- ##
  
  if(inherits(clObj, "gbm")){
    
    # Get the number of trees in GBM classifier
    nt <- clObj$n.trees
    
    gbmPred <- stats::predict(clObj, newdata = predData, type = "response", n.trees = nt)
    
    if(isSingleClass){
      pred <- as.integer(gbmPred > threshValFullClassif)
    }
    
    if(isMultiClass){
      # Use only the first slice [,,1] of the prediction array. It is rebuilt explicitly as 
      # a matrix because plain [,,1] collapses to a vector when a single segment is 
      # predicted, which would break colnames() and apply() below.
      gbmProbs <- matrix(gbmPred[, , 1], 
                         nrow     = dim(gbmPred)[1], 
                         ncol     = dim(gbmPred)[2], 
                         dimnames = dimnames(gbmPred)[1:2])
      
      # Class label (taken from the column names) with the highest probability in each row
      pred <- as.integer(colnames(gbmProbs))[apply(gbmProbs, 1, which.max)]
    }
  }
  
  ## --------------------------------------------------------------------------------------- ##
  ## SVM classifier ----
  ## --------------------------------------------------------------------------------------- ##
  
  if(inherits(clObj, "svm")){
    
    if(isSingleClass){
      svmProbs <- attr(stats::predict(clObj, newdata = predData, probability = TRUE), "probabilities")
      pred <- as.integer(svmProbs[,2] > threshValFullClassif)
    }
    
    if(isMultiClass){
      pred <- stats::predict(clObj, newdata = predData)
    }
  }
  
  ## --------------------------------------------------------------------------------------- ##
  ## KNN classifier ----
  ## --------------------------------------------------------------------------------------- ##
  
  if(inherits(clObj, "knn")){
    
    # Same procedure for single- and multi-class problems
    if(isTrain){
      pred <- stats::predict(clObj, type = "response")
    }else{
      
      k <- classifierObj$ClassParams$classificationMethodParams$KNN$k
      balanceTrainData <- classifierObj$ClassParams$balanceTrainData
      balanceMethod <- classifierObj$ClassParams$balanceMethod
      
      # The first column of classifFeatData and the first two columns of calData are 
      # excluded from the feature set handed to the classifier
      pred <- stats::predict(clObj, 
                             newdata   = calData$classifFeatData[,-1], 
                             traindata = calData$calData[,-c(1:2)], 
                             cl        = calData$calData$train, 
                             k         = k, 
                             type      = "response",
                             balanceTrainData = balanceTrainData,
                             balanceMethod    = balanceMethod)
    }
  }
  
  ## --------------------------------------------------------------------------------------- ##
  ## FDA classifier ----
  ## --------------------------------------------------------------------------------------- ##
  
  if(inherits(clObj, "fda")){
    
    if(isSingleClass){
      fdaProbs <- stats::predict(clObj, newdata = predData, type = "posterior")
      pred <- as.integer(fdaProbs[,2] > threshValFullClassif)
    }
    
    if(isMultiClass){
      pred <- stats::predict(clObj, newdata = predData, type = "class")
    }
  }
  
  # None of the branches above produced predictions: report it explicitly
  if(is.null(pred))
    stop("Unable to predict class labels: the classifier in classifierObj is not one of the supported types (randomForest, gbm, svm, knn or fda)")
  
  if(verbose) cat("done.\n\n")
  
  
  ## --------------------------------------------------------------------------------------- ##
  ## Table of predicted labels by segment (shared by both output modes) ----
  ## --------------------------------------------------------------------------------------- ##
  
  # Convert factors to integers if needed otherwise class integer codes may come out wrong...
  if(is.factor(pred)) pred <- as.integer(levels(pred)[as.integer(pred)])
  
  # Append the segment IDs (SID field) to the predicted object
  predDF <- data.frame(SID = predData$SID, predClass = pred)
  rm(pred)
  
  # Convert to data.table object
  data.table::setDT(predDF) # Convert to data.table by reference
  data.table::setkey(predDF)
  
  
  #if(canProcessInMemory(rstSegm) && !forceWriteByLine){
  if(!forceWriteByLine){
    
    ## ------------------------------------------------------------------------------ ##
    ## In-memory ops ----
    ## ------------------------------------------------------------------------------ ##
    
    if(verbose) cat("-> Loading segmented raster data...\n") ## ----- ##
    
    # One row per pixel: cell index and the ID of the segment it belongs to
    rstSegmDF <- data.frame(cell_ID = seq_len(terra::ncell(rstSegm)), SID = terra::values(rstSegm)[,1])
    colnames(rstSegmDF) <- c("cell_ID", "SID")
    
    data.table::setDT(rstSegmDF) # Convert to data.table by reference
    data.table::setkey(rstSegmDF, SID) # Set key
    
    if(verbose) cat("done.\n\n")
    
    if(verbose) cat("-> Merging data and creating a new raster dataset with predicted class labels...\n") ## ----- ##
    
    # Join and update column by reference
    rstSegmDF[predDF, on = 'SID', predClass := predClass]
    
    # Reorder by cell_ID otherwise it will give problems when writing data to raster
    data.table::setorder(rstSegmDF, cell_ID, na.last = FALSE)
    
    # Make new raster to fill with predicted class labels
    newRstPred <- terra::rast(rstSegm)
    terra::values(newRstPred) <- rstSegmDF[["predClass"]]
    
    if(verbose) cat("done.\n\n")
    
    if(!is.null(filename)){ ## ------------------------------------------------------ ##
      if(verbose) cat("-> Writing data...\n")
      
      # Output data type is set to INT4U, hence datatype must not be passed again through ...
      terra::writeRaster(newRstPred, filename = filename, datatype="INT4U", ...)
      
      if(verbose) cat("done.\n\n")
    }
    
    return(newRstPred)
    
  }
  
  ## ------------------------------------------------------------------------------ ##
  ## Use memory-safe ops ----
  ## ------------------------------------------------------------------------------ ##
  
  # Raster dimensions are loop invariants
  nRows <- nrow(rstSegm)
  nCols <- ncol(rstSegm)
  
  # State used by the exit handler to release whatever is still open if something fails
  pb        <- NULL
  readOpen  <- FALSE
  writeOpen <- FALSE
  
  on.exit({
    if(writeOpen) try(terra::writeStop(r_out), silent = TRUE)
    if(readOpen) try(terra::readStop(rstSegm), silent = TRUE)
    if(!is.null(pb)) close(pb)
  }, add = TRUE)
  
  if(verbose){
    cat("-> Outputting data using memory-safe operations...\n\n")
    # min = 0 so that a raster with a single row does not yield min == max
    pb <- utils::txtProgressBar(min = 0, max = nRows, style = 3)
  }
  
  suppressMessages(
    suppressWarnings(terra::readStart(rstSegm)))
  readOpen <- TRUE
  
  # Output raster with the same geometry as the segmented raster, opened for writing
  r_out <- terra::rast(rstSegm)
  suppressMessages(
    suppressWarnings(
      terra::writeStart(r_out, filename = filename, datatype="INT4U", ...)))
  writeOpen <- TRUE
  
  for(nr in seq_len(nRows)){
    
    # Get data on raster segments by line
    rstSegmDF <- data.frame(cell_ID = seq_len(nCols), 
                            SID     = terra::values(rstSegm, row = nr, nrows = 1)[,1])
    
    colnames(rstSegmDF) <- c("cell_ID", "SID")
    
    data.table::setDT(rstSegmDF) # Convert to data.table by reference
    data.table::setkey(rstSegmDF, SID) # Set key
    
    # Join and update column by reference
    rstSegmDF[predDF, on = 'SID', predClass := predClass]
    
    # Reorder by cell_ID otherwise it will give problems when writing data to raster
    data.table::setorder(rstSegmDF, cell_ID, na.last = FALSE)
    
    # Write the predicted labels of this line to the output file
    terra::writeValues(r_out, rstSegmDF$predClass, nr, 1)
    
    if(verbose){
      utils::setTxtProgressBar(pb, nr)
    }
  }
  
  # Close files before re-opening the written raster
  terra::writeStop(r_out)
  writeOpen <- FALSE
  
  terra::readStop(rstSegm)
  readOpen <- FALSE
  
  newRstPred <- terra::rast(filename)
  return(newRstPred)
  
}
# joaofgoncalves/SegOptim documentation built on Feb. 5, 2024, 11:10 p.m.