#' Generate a raster data set with train segments
#'
#' A utility function used to create a raster dataset containing train segments for a given segmentation
#' solution and a train raster.
#'
#' @inheritParams getTrainData
#'
#' @param trainData Input train data. The input can be a \code{SpatRaster}, an integer vector (containing
#' raster data after using \code{\link[terra]{values}} function) or a character string with a path to a
#' raster layer. If trainData is an integer vector then it should have a length equal to the number of
#' pixels in rstSegm.
#'
#' @param filename A filename/path used to write the output raster (default: NULL)
#' @param verbose Print comments with function progress? (default: TRUE)
#'
#' @param ... Additional options passed to \code{writeRaster} function. The option \code{datatype}
#' is already internally set to \emph{INT2U}.
#'
#' @return A SpatRaster object with train segments.
#'
#' @importFrom terra values
#' @importFrom terra rast
#' @importFrom terra ncell
#' @importFrom terra writeRaster
#' @export
#'
getTrainRasterSegments <- function(trainData, rstSegm, filename = NULL, useThresh = TRUE, thresh = 0.5,
                                   na.rm = TRUE, dup.rm = TRUE, minImgSegm = 30, ignore = FALSE,
                                   verbose = TRUE, ...){

  # Purpose: build a raster, aligned with rstSegm, in which every cell carries the
  # train value of the segment it belongs to (NA where the segment has no train value).
  # Returns the new SpatRaster; it is also written to disk when filename is given.

  if(verbose) cat("-> Generating train data...\n") ## --------------------------------------------------- ##

  # rstSegm is forwarded untouched so getTrainData keeps handling the input exactly as before
  trainDataDF <- getTrainData(x = trainData, rstSegm = rstSegm, useThresh = useThresh, thresh = thresh,
                              na.rm = na.rm, dup.rm = dup.rm, minImgSegm = minImgSegm, ignore = ignore)

  # The lookup below is keyed on the SID column (merge(by = "SID") used to enforce this)
  if(!("SID" %in% names(trainDataDF)))
    stop("The train data generated by getTrainData must contain a column named SID")

  if(verbose) cat("done.\n\n")


  if(verbose) cat("-> Loading segmented raster...\n") ## ------------------------------------------------ ##

  # Normalise rstSegm to a SpatRaster before reading cell values. File paths and
  # SOptim.SegmentationResult objects are converted the same way predictSegments
  # converts a segmentation result.
  if(inherits(rstSegm, "SOptim.SegmentationResult")){
    rstSegm <- terra::rast(rstSegm$segm)
  }else if(is.character(rstSegm)){
    rstSegm <- terra::rast(rstSegm)
  }

  # Segment ID of every cell (first layer), in cell order
  cellSIDs <- terra::values(rstSegm)[, 1]

  if(verbose) cat("done.\n\n")


  if(verbose) cat("-> Joining and re-ordering train data with the segmented raster ...\n") ##------------ ##

  # Look up the train value of each cell's segment. match() keeps the result in cell
  # order and always returns exactly one value per cell, so a train table holding
  # repeated SIDs (e.g., dup.rm = FALSE) cannot inflate the vector the way a
  # merge() would (first occurrence wins).
  trainByCell <- trainDataDF$train[match(cellSIDs, trainDataDF$SID)]

  if(verbose) cat("done.\n\n")


  if(verbose) cat("-> Creating a new raster dataset with train segments...\n") ## --------------------- ##

  # Empty raster with the same geometry as rstSegm, filled with the train values
  newRstTrainData <- terra::rast(rstSegm)
  terra::values(newRstTrainData) <- trainByCell

  if(verbose) cat("done.\n\n")


  if(!is.null(filename)){

    if(verbose) cat("-> Writing data...\n") ## ---------------------------------------------------------- ##

    # datatype is fixed to INT2U; other writeRaster options arrive through ...
    terra::writeRaster(newRstTrainData, filename = filename, datatype = "INT2U", ...)

    if(verbose) cat("done.\n\n")
  }

  return(newRstTrainData)
}
#' Predict class labels for image segments
#'
#' This function uses an input \code{SOptim.Classifier} object to predict class labels for train or all segments in the
#' input segmented image.
#'
#' @param classifierObj An object of class \code{SOptim.Classifier} containing a classification algorithm
#' generated by function \code{\link{calibrateClassifier}} with option \code{runFullCalibration = TRUE}.
#'
#' @param calData An object of class \code{SOptim.CalData} generated by function \code{\link{prepareCalData}} containing
#' calibration data for train segments and the entire image.
#'
#' @param rstSegm A string defining the path to the raster with segment IDs, a \code{SpatRaster} object or
#' a \code{SOptim.SegmentationResult} object (generated by any segmentation function; check \link{segmentationGeneric}).
#'
#' @param predictFor Either option \code{"train"} which predicts class labels only for train segments
#' or option \code{"all"} which predicts for all existing segments in \code{rstSegm} (default: all).
#'
#' @param filename A file name/path used to write the output raster (default: NULL).
#'
#' @param verbose Print comments with function progress? (default: TRUE).
#'
#' @param na.rm Remove NA's? (default: TRUE).
#'
#' @param ... Additional arguments for \code{\link[terra]{writeRaster}} function except \code{datatype}
#' which is already internally set to 'INT4U'.
#'
#' @param forceWriteByLine Use memory-safe writing of raster output by line? (default: FALSE).
#' If \code{forceWriteByLine = TRUE}, then \code{filename} must define a valid file path.
#'
#' @details By default the function uses the classifier run with the 'full' dataset (i.e., no train/test splits)
#' for making class label predictions. In case of single-class problems the threshold that maximizes the selected
#' evaluation metric (check \code{evalMetric} in \code{\link{calibrateClassifier}}) is used to dichotomize predictions.
#' For multi-class problems the output class label is set for the one with highest probability value.
#'
#' @return An object of class \code{SpatRaster} containing the predicted class labels for
#' each image segment. If the file name is defined, the function will write a file containing the output
#' raster. The output data type is INT4U (see \code{\link[terra]{datatype}} for more details) which means
#' negative values for class labels are not valid.
#'
#' @importFrom terra ncell
#' @importFrom terra values
#' @importFrom terra rast
#' @importFrom terra writeRaster
#' @importFrom terra readStart
#' @importFrom terra readStop
#' @importFrom terra writeStart
#' @importFrom terra writeStop
#' @importFrom terra writeValues
#' @importFrom stats predict
#' @importFrom stats na.omit
#' @importFrom data.table setDT
#' @importFrom data.table setkey
#' @importFrom data.table setorder
#' @importFrom utils txtProgressBar
#' @importFrom utils setTxtProgressBar
#'
#' @export
predictSegments <- function(classifierObj, calData, rstSegm, predictFor = "all",
                            filename = NULL, verbose = TRUE, na.rm = TRUE,
                            forceWriteByLine = FALSE, ...){

  # Purpose: predict a class label per segment with the 'FULL' classifier stored in
  # classifierObj and map the labels back onto the cells of rstSegm. The result is
  # built in memory (default) or streamed to disk row by row (forceWriteByLine = TRUE).

  ## --------------------------------------------------------------------------------------- ##
  ## Input checks ----
  ## --------------------------------------------------------------------------------------- ##

  if(!inherits(classifierObj, "SOptim.Classifier"))
    stop("classifierObj must be an object of class SOptim.Classifier generated by\ncalibrateClassifier function with option runFullCalibration = TRUE")

  if(!inherits(calData, "SOptim.CalData"))
    stop("calData must be an object of class SOptim.CalData generated by prepareCalData function")

  # The length test makes NULL / vector input end in the intended message rather
  # than in a condition-length error raised by if()
  if(length(predictFor) != 1L || !(predictFor %in% c("train","all")))
    stop("predictFor must be either \"train\" (only for train segments) or \"all\"!")

  # TODO: Replace canProcessInMemory
  #
  # # If large raster filename must be defined to write data to disk
  # if(!canProcessInMemory(rstSegm) && is.null(filename)){
  #   stop("Large raster detected! To process data please define a valid output path in parameter filename.\n")
  # }

  # Writing by line streams directly to a file, hence a target path is mandatory
  if(forceWriteByLine && is.null(filename)){
    stop("If forceWriteByLine = TRUE please define a valid output path in parameter filename.\n")
  }


  if(verbose) cat("-> Pre-processing data...\n") ## ----------------------------------------- ##

  # Normalise rstSegm to a SpatRaster: segmentation results carry the raster in $segm
  # and a character string is taken as the path to the raster with segment IDs
  if(inherits(rstSegm, "SOptim.SegmentationResult")){
    rstSegm <- terra::rast(rstSegm$segm)
  }else if(is.character(rstSegm)){
    rstSegm <- terra::rast(rstSegm)
  }

  # Get the classifier trained with the full dataset
  clObj <- classifierObj$ClassObj$FULL

  # Get metadata on number of classes
  nClassType <- classifierObj$ClassParams$nClassType

  if(length(nClassType) != 1L || !(nClassType %in% c("single-class", "multi-class")))
    stop("Unknown classification type in classifierObj$ClassParams$nClassType (expected \"single-class\" or \"multi-class\")")

  isSingleClass <- nClassType == "single-class"
  useTrain      <- predictFor == "train"

  # Threshold used to dichotomize single-class probability predictions: the last
  # element of classifierObj$Thresh (the one used for the full classifier)
  threshValFullClassif <- NULL
  if(isSingleClass){
    threshVals <- classifierObj$Thresh
    threshValFullClassif <- threshVals[length(threshVals)]
  }

  # Remove NA's from data
  if(na.rm){
    if(useTrain){
      calData$calData <- stats::na.omit(calData$calData)
    }else{
      calData$classifFeatData <- stats::na.omit(calData$classifFeatData)
    }
  }

  # Feature table to predict on (taken after the NA removal above)
  featData <- if(useTrain) calData$calData else calData$classifFeatData

  if(verbose) cat("done.\n\n") ## ----- ##


  if(verbose) cat("-> Classifying raster segments...\n") ## -------------------------------- ##

  # Stays NULL when clObj is none of the supported classifier classes
  pred <- NULL

  ## --------------------------------------------------------------------------------------- ##
  ## Random forest classifier ----
  ## --------------------------------------------------------------------------------------- ##

  if(inherits(clObj, "randomForest")){

    rfType <- if(isSingleClass) "prob" else "response"

    if(useTrain){
      pred <- stats::predict(clObj, type = rfType) ## Gets the out-of-bag predictions
    }else{
      pred <- stats::predict(clObj, type = rfType, newdata = featData)
    }

    # Binarize the probability stored in the second column
    if(isSingleClass) pred <- as.integer(pred[, 2] > threshValFullClassif)
  }

  ## --------------------------------------------------------------------------------------- ##
  ## GBM classifier ----
  ## --------------------------------------------------------------------------------------- ##

  if(inherits(clObj, "gbm")){

    # Get the number of trees in GBM classifier
    nt <- clObj$n.trees

    if(isSingleClass){
      pred <- stats::predict(clObj, newdata = featData, type = "response", n.trees = nt)
      pred <- as.integer(pred > threshValFullClassif)
    }else{
      # Use only the first element of the prediction array in [,,1] and extract the class
      # with highest probability. drop = FALSE keeps the class names on the second
      # dimension even when a single segment is predicted.
      pred <- stats::predict(clObj, newdata = featData, type = "response", n.trees = nt)[, , 1, drop = FALSE]
      pred <- as.integer(colnames(pred))[apply(pred, 1, which.max)]
    }
  }

  ## --------------------------------------------------------------------------------------- ##
  ## SVM classifier ----
  ## --------------------------------------------------------------------------------------- ##

  if(inherits(clObj, "svm")){

    if(isSingleClass){
      # NOTE(review): assumes the second column of "probabilities" is the positive class - confirm
      pred <- attr(stats::predict(clObj, newdata = featData, probability = TRUE), "probabilities")
      pred <- as.integer(pred[, 2] > threshValFullClassif)
    }else{
      pred <- stats::predict(clObj, newdata = featData)
    }
  }

  ## --------------------------------------------------------------------------------------- ##
  ## KNN classifier ----
  ## --------------------------------------------------------------------------------------- ##

  if(inherits(clObj, "knn")){

    if(useTrain){
      pred <- stats::predict(clObj, type = "response")
    }else{
      k <- classifierObj$ClassParams$classificationMethodParams$KNN$k
      balanceTrainData <- classifierObj$ClassParams$balanceTrainData
      balanceMethod <- classifierObj$ClassParams$balanceMethod

      # First column of classifFeatData and first two columns of calData are dropped
      # from the feature matrices; labels come from calData$calData$train
      pred <- stats::predict(clObj,
                             newdata          = calData$classifFeatData[,-1],
                             traindata        = calData$calData[,-c(1:2)],
                             cl               = calData$calData$train,
                             k                = k,
                             type             = "response",
                             balanceTrainData = balanceTrainData,
                             balanceMethod    = balanceMethod)
    }
  }

  ## --------------------------------------------------------------------------------------- ##
  ## FDA classifier ----
  ## --------------------------------------------------------------------------------------- ##

  if(inherits(clObj, "fda")){

    if(isSingleClass){
      pred <- stats::predict(clObj, newdata = featData, type = "posterior")
      pred <- as.integer(pred[, 2] > threshValFullClassif)
    }else{
      pred <- stats::predict(clObj, newdata = featData, type = "class")
    }
  }

  # Fail with an explicit message instead of a later "object 'pred' not found"
  if(is.null(pred))
    stop("Unsupported classifier object in classifierObj$ClassObj$FULL! It must inherit from randomForest, gbm, svm, knn or fda")

  if(verbose) cat("done.\n\n")


  ## --------------------------------------------------------------------------------------- ##
  ## Table linking segment IDs and predicted labels (shared by both output modes) ----
  ## --------------------------------------------------------------------------------------- ##

  # Convert factors to integers if needed otherwise class integer codes may come out wrong...
  if(is.factor(pred)) pred <- as.integer(levels(pred)[as.integer(pred)])

  # Segment IDs (SID field) of the rows that were predicted
  segIDs <- if(useTrain) calData$calData$SID else calData$classifFeatData$SID

  # data.frame() would silently recycle a shorter vector whose length divides the other
  if(NROW(pred) != length(segIDs))
    stop("The number of predictions does not match the number of segment IDs (SID) in calData")

  predDF <- data.frame(SID = segIDs, predClass = pred)
  rm(pred)

  data.table::setDT(predDF) # Convert to data.table by reference
  data.table::setkey(predDF)


  #if(canProcessInMemory(rstSegm) && !forceWriteByLine){
  if(!forceWriteByLine){

    ## ------------------------------------------------------------------------------ ##
    ## In-memory ops ----
    ## ------------------------------------------------------------------------------ ##

    if(verbose) cat("-> Loading segmented raster data...\n") ## ----- ##

    # One row per raster cell with the segment ID found in the first layer
    segmDT <- data.frame(cell_ID = seq_len(terra::ncell(rstSegm)), SID = terra::values(rstSegm)[,1])
    colnames(segmDT) <- c("cell_ID", "SID")

    data.table::setDT(segmDT) # Convert to data.table by reference
    data.table::setkey(segmDT, SID) # Set key

    if(verbose) cat("done.\n\n")


    if(verbose) cat("-> Merging data and creating a new raster dataset with predicted class labels...\n") ## ----- ##

    # Join and update column by reference
    segmDT[predDF, on = 'SID', predClass := predClass]

    # Reorder by cell_ID otherwise it will give problems when writing data to raster
    data.table::setorder(segmDT, cell_ID, na.last = FALSE)

    # Make new raster to fill with predicted class labels
    newRstPred <- terra::rast(rstSegm)
    terra::values(newRstPred) <- segmDT[["predClass"]]

    if(verbose) cat("done.\n\n")


    if(!is.null(filename)){ ## ------------------------------------------------------ ##

      if(verbose) cat("-> Writing data...\n")

      # Output data type is INT4U (0 - 4,294,967,295)
      terra::writeRaster(newRstPred, filename = filename, datatype = "INT4U", ...)

      if(verbose) cat("done.\n\n")
    }

    return(newRstPred)
  }

  ## ------------------------------------------------------------------------------ ##
  ## Use memory-safe ops ----
  ## ------------------------------------------------------------------------------ ##

  nRows <- nrow(rstSegm)
  nCols <- ncol(rstSegm)

  if(verbose){
    cat("-> Outputting data using memory-safe operations...\n\n")
    # min = 0 so that a single-row raster does not violate max > min
    pb <- utils::txtProgressBar(min = 0, max = nRows, style = 3)
    on.exit(close(pb), add = TRUE)
  }

  # Make sure the read/write handles are released even if the loop below fails
  readOpen  <- FALSE
  writeOpen <- FALSE
  on.exit({
    if(writeOpen) try(terra::writeStop(r_out), silent = TRUE)
    if(readOpen)  try(terra::readStop(rstSegm), silent = TRUE)
  }, add = TRUE)

  suppressMessages(
    suppressWarnings(terra::readStart(rstSegm)))
  readOpen <- TRUE

  r_out <- terra::rast(rstSegm)

  suppressMessages(
    suppressWarnings(
      terra::writeStart(r_out, filename = filename, datatype = "INT4U", ...)))
  writeOpen <- TRUE

  for(nr in seq_len(nRows)){

    # Get data on raster segments by line
    rowDT <- data.frame(cell_ID = seq_len(nCols),
                        SID     = terra::values(rstSegm, row = nr, nrows = 1)[,1])
    colnames(rowDT) <- c("cell_ID", "SID")

    data.table::setDT(rowDT) # Convert to data.table by reference
    data.table::setkey(rowDT, SID) # Set key

    # Join and update column by reference
    rowDT[predDF, on = 'SID', predClass := predClass]

    # Reorder by cell_ID otherwise it will give problems when writing data to raster
    data.table::setorder(rowDT, cell_ID, na.last = FALSE)

    # Write the current row to the output file
    terra::writeValues(r_out, rowDT$predClass, nr, 1)

    if(verbose){
      utils::setTxtProgressBar(pb, nr)
    }
  }

  # Close files (flags are cleared so the exit handler does not close them twice)
  terra::writeStop(r_out)
  writeOpen <- FALSE
  terra::readStop(rstSegm)
  readOpen <- FALSE

  newRstPred <- terra::rast(filename)
  return(newRstPred)
}
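# NOTE: the two lines below are leftover web-page text, not R code; kept as comments so the file parses.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.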