## ----------------------------------------
## K-FOLD CROSS VALIDATION FUNCTIONS
## Ceres June 03 2020
## ----------------------------------------
#' CROSS-VALIDATION FUNCTION
#'
#' @param fullDT `data.table` with full dataset
#' @param statsModel the statistical model to validate. Only works with gamlss models
#' @param k integer with number of chunks that the data should be partitioned in
#' @param idCol column with pixel/observation IDs (optional)
#' @param origData the data used to fit the statsModule, needs to be passed to [`gamlss::predictAll()`]
#' (it may not be able to access it) but also to make sure newdata in [`gamlss::predictAll()`]
#' has the same variables (even if they're not used in the model)
#' @param level passed to `gamlss:::predict`
#' @param cacheObj1 object used by [`reproducible::Cache()`] for digesting,
#' to avoid digesting the (potentially) large data arguments
#' @param cacheObj2 object used by [`reproducible::Cache()`] for digesting,
#' to avoid digesting the (potentially) large data arguments
#' @param parallel logical. Uses [`future.apply::future_lapply()`] to parallelise
#' model fitting across the k-folds, using `plan(multiprocess)`. Defaults to FALSE
#' @param ... further arguments passed to [`future::plan()`]
#' @param cacheArgs a named `list` of arguments passed to inner `Cache` calls
#'
#' @export
crossValidFunction <- function(fullDT, statsModel, origData, k = 4, idCol,
parallel = FALSE, cacheObj1 = NULL, cacheObj2 = NULL,
cacheArgs = NULL, level = NULL, ...) {
if (!requireNamespace("gamlss", quietly = TRUE)) {
stop("'gamlss' is not installed. Please install using:",
"\ninstall.packages('gamlss')")
}
if (!is.null(idCol))
origDataVars <- c(names(origData), idCol)
## remove NAs from the data without subsetting columns
if (any(is.na(fullDT[, ..origDataVars])))
stop("Please remove NAs from the variables going in the model")
## partition data into roughly equal chunks
sampDT <- unique(fullDT[, ..idCol])
sampDT[, sampID := sample(1:k, size = length(get(idCol)), replace = TRUE)]
## join samp IDs with data
fullDT <- sampDT[fullDT, on = idCol]
origDataVars <- c(origDataVars, "sampID")
message(paste("Starting cross-validation using", k, "folds"))
if (parallel) {
if (!requireNamespace("future", quietly = TRUE)) {
stop("'future' is not installed. Please install using:",
"\ninstall.packages('future')")
}
if (!requireNamespace("future.apply", quietly = TRUE)) {
stop("'future.apply' is not installed. Please install using:",
"\ninstall.packages('future.apply')")
}
if (Sys.info()[["sysname"]] == "Windows") {
future::plan(future::multisession, gc = TRUE, ...)
} else future::plan(future::multicore, ...)
crossValidResults <- future.apply::future_lapply(unique(fullDT$sampID), FUN = calcCrossValidMetrics,
fullDT = fullDT, origData = origData,
statsModel = statsModel, origDataVars = origDataVars,
level = level, cacheArgs = cacheArgs)
## Explicitly close workers
future:::ClusterRegistry("stop")
} else {
crossValidResults <- lapply(unique(fullDT$sampID), FUN = calcCrossValidMetrics,
fullDT = fullDT, origData = origData,
statsModel = statsModel, origDataVars = origDataVars,
level = level, cacheArgs = cacheArgs)
}
return(crossValidResults)
}
#' CALCULATE VALIDATION METRICS AND CONFUSION MATRIX
#'
#' to allow caching without digesting the large data table
#'
#' @param samp the sample number to pick to use as the test data set
#' @param fullDT the full dataset (not necessarily the one used to fit
#' `statsModel`, which could have been a subset (e.g. fewer columns))
#' @param origData the data used to fit `statsModel`
#' @param statsModel the fitted model
#' @param level passed to `gamlss:::predict`
#' @param origDataVars a character vector of the variables used in model fitting (including response variable and random effects.)
#' @param cacheArgs a named `list` of arguments passed to `Cache`
#'
#' @return a list with 2 entries
#'
#' @importFrom reproducible Cache
#' @importFrom data.table as.data.table set
#' @importFrom stats update
#'
#' @export
calcCrossValidMetrics <- function(samp, fullDT, origData, statsModel, origDataVars, level = NULL, cacheArgs = NULL) {
if (!requireNamespace("gamlss", quietly = TRUE)) {
stop("'gamlss' is not installed. Please install using:",
"\ninstall.packages('gamlss')")
}
if (!requireNamespace("caret", quietly = TRUE)) {
stop("'caret' is not installed. Please install using:",
"\ninstall.packages('caret')")
}
message(paste("Fold", samp))
## predict requires the original and new data to have the same columns
if (!all(names(origData) %in% names(fullDT)))
stop("'fullDT' needs to include all the columns in 'origData'")
## subset
trainData <<- fullDT[sampID != samp, ..origDataVars]
testData <- fullDT[sampID == samp, ..origDataVars]
if (any(is.na(trainData)) | any(is.na(testData)))
stop("Please remove NAs from the variables going in the model")
## trainData an testData cannot have extra cols with respect to those in the original
## data used to fit the model
cols <- names(origData)
trainData <<- trainData[, .SD, .SDcols = cols] ## need to export to .Global for gamlss...
testData <- testData[, .SD, .SDcols = cols]
## refit model on training sample then predict
trainModel <- tryCatch(update(object = statsModel, data = trainData), error = function(e) e)
if (is(trainModel, "error")) {
message("Model could not be re-fit. Error:")
message(trainModel)
validMetrics <- c("RMSE" = NA, "Rsquared" = NA, "MAE" = NA,
"Rsq" = NA, "TGD" = NA, "predictError" = NA)
} else {
params <- c("mu", "nu", "tau")
names(params) <- params
predictionsDT <- lapply(params, FUN = function(param) {
predict(trainModel, what = param,
newdata = testData, data = trainData,
type = "response", level = level)
})
predictionsDT <- as.data.table(do.call(cbind, predictionsDT))
## add response variable
set(predictionsDT, NULL, "invRobust", testData$invRobust)
## predict using meanBEINF approach
if (trainModel$family[1] != "BEINF")
stop("the object does not have a BEINF distribution")
predictionsDT[, predinvRobust := calcMeanBEINF(mu, nu, tau),
by = row.names(predictionsDT)]
## VALIDATION STATISTICS WITH CONTINUOUS VARIABLE -----------------------
RsqGAMLSS <- gamlss::Rsq(trainModel)
TGDstats <- gamlss::getTGD(trainModel, newdata = testData, data = trainData)
validMetrics <- c(caret::defaultSummary(data.frame(obs = predictionsDT$invRobust, pred = predictionsDT$predinvRobust)),
"Rsq" = RsqGAMLSS,
TGD = TGDstats$TGD,
predictError = TGDstats$predictError)
}
list(validMetrics = validMetrics)
}
#' CACHE-COMPATIBLE MODEL UPDATE
#'
#' to allow caching without digesting the large data table and model
#'
#' @param ... arguments passed to `update`
#' @param cacheObj1 an object used by Cache for digesting, to avoid digesting
#' the (potentially) large data arguments e.g.: model coefficients and a
#' sample of a column drawn with a set seed.
#' @param cacheObj2 an object used by Cache for digesting, to avoid digesting
#' the (potentially) large data arguments e.g.: model coefficients and a
#' sample of a column drawn with a set seed.
#'
#' @export
updateModelCached <- function(..., cacheObj1 = NULL, cacheObj2 = NULL) {
updatedModel <- update(...)
return(updatedModel)
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.