#' Generic cross-validation function
#'
#' Runs repeated k-fold cross-validation for an arbitrary model described by a
#' fitting function and a prediction function. Optionally, a fraction of the
#' data is held out as a validation set on which every fitted model is also
#' evaluated.
#'
#' @details
#' For every fold and repeat, \code{fit_func} is called with arguments
#' \code{x}, \code{y} (the training rows) and the elements of
#' \code{fit_params}. Predictions are obtained by calling
#' \code{predict_func(fit_obj, newdata = ...)}; if that call fails,
#' \code{predict_func(fit_obj, newx = ...)} is tried instead; if both fail the
#' predictions are set to \code{NA}.
#'
#' The global random number generator is seeded with \code{seed} (via
#' \code{set.seed}) before the validation split, before the folds are created
#' and before each model fit.
#'
#' @param x input covariates' matrix (coerced with \code{as.matrix})
#' @param y response variable; a vector
#' @param fit_func a function for fitting the model
#' @param predict_func a function for predicting values from the model
#' @param fit_params a list; additional (model-specific) parameters to be passed
#' to \code{fit_func}
#' @param k an integer (at least 2); number of folds in k-fold cross validation
#' @param repeats an integer (at least 1); number of repeats for the k-fold
#' cross validation
#' @param p a double; proportion of data in the training/testing set, default is 1 and
#' must be between 0.5 and 1. If \code{p} < 1, a validation set error is calculated on the
#' remaining 1-\code{p} fraction data
#' @param seed random seed for reproducibility of results
#' @param eval_metric a function of \code{(preds, actual)} measuring the test
#' errors; if not provided: RMSE for regression and accuracy for classification
#' (i.e. when \code{y} is a factor)
#' @param cl an integer; the number of clusters for parallel execution. If
#' \code{NULL} (the default) or not positive, execution is sequential
#' @param errorhandling specifies how a task evaluation error should be handled.
#' If value is "stop", then execution will be stopped if an error occurs. If value
#' is "remove", the result for that task will not be returned. If value is "pass",
#' then the error object generated by task evaluation will be included with the
#' rest of the results. The default value is "stop".
#' @param packages character vector of packages that the tasks depend on
#' @param verbose logical flag enabling verbose messages. This can be very useful for
#' troubleshooting.
#' @param show_progress logical; show a progress bar and the elapsed time
#' @param ... additional parameters (currently not used by the function body)
#'
#' @return A list. If \code{p} is 1: \code{folds} (a matrix of errors with one
#' row per fold and one column per repeat), \code{mean}, \code{sd} and
#' \code{median} of these errors. If \code{p} < 1: \code{folds} (training and
#' validation errors interleaved by fold), \code{mean_training},
#' \code{mean_validation}, \code{sd_training}, \code{sd_validation},
#' \code{median_training} and \code{median_validation}. Summaries ignore
#' missing values.
#' @export
#'
#' @examples
#'
#'# dataset
#'
#' set.seed(123)
#' n <- 1000 ; p <- 10
#' X <- matrix(rnorm(n * p), n, p)
#' y <- rnorm(n)
#'
#'# linear model example -----
#'
#' crossval::crossval_ml(x = X, y = y, k = 5, repeats = 3)
#'
#'
#'# randomForest example -----
#'
#'library(randomForest)
#'
#'# fit randomForest with mtry = 2
#'
#'crossval::crossval_ml(x = X, y = y, k = 5, repeats = 3,
#' fit_func = randomForest::randomForest, predict_func = predict,
#' packages = "randomForest", fit_params = list(mtry = 2))
#'
#'# fit randomForest with mtry = 4
#'
#'crossval::crossval_ml(x = X, y = y, k = 5, repeats = 3,
#' fit_func = randomForest::randomForest, predict_func = predict,
#' packages = "randomForest", fit_params = list(mtry = 4))
#'
#'# fit randomForest with mtry = 4, with a validation set
#'
#'crossval::crossval_ml(x = X, y = y, k = 5, repeats = 2, p = 0.8,
#' fit_func = randomForest::randomForest, predict_func = predict,
#' packages = "randomForest", fit_params = list(mtry = 4))
#'
crossval_ml <- function(x,
                        y,
                        fit_func = crossval::fit_lm,
                        predict_func = crossval::predict_lm,
                        fit_params = NULL,
                        # and hyperparameters
                        k = 5,
                        repeats = 3,
                        p = 1,
                        seed = 123,
                        eval_metric = NULL,
                        cl = NULL,
                        errorhandling = c('stop', 'remove', 'pass'),
                        packages = c("stats", "Rcpp"),
                        verbose = FALSE,
                        show_progress = TRUE,
                        ...) {
  # ---- input validation -----
  # Everything is checked up front, before any sampling happens and before the
  # global RNG is reseeded, so invalid calls fail without side effects.
  x <- as.matrix(x)
  n_y <- length(y)
  stopifnot(n_y == nrow(x))
  errorhandling <- match.arg(errorhandling)
  stopifnot(floor(k) == k, k >= 2)
  stopifnot(p >= 0.5, p <= 1)
  stopifnot(floor(repeats) == repeats, repeats >= 1)

  # The fold worker below is a closure that may be shipped to other R
  # processes; forcing these arguments makes sure their values (and not
  # unevaluated promises) travel with it.
  force(fit_func)
  force(predict_func)
  force(fit_params)

  # ---- optional validation split -----
  has_validation <- p < 1
  x_validation <- NULL
  y_validation <- NULL
  set.seed(seed)
  if (has_validation) {
    index_train <- sample.int(n_y, size = floor(p * n_y))
    # the held-out rows must be extracted BEFORE x and y are overwritten
    x_validation <- x[-index_train, , drop = FALSE]
    y_validation <- y[-index_train]
    x <- x[index_train, , drop = FALSE]
    y <- y[index_train]
  }

  # ---- evaluation metric for cv error -----
  if (is.null(eval_metric)) {
    if (is.factor(y)) {
      # classification
      eval_metric <- function(preds, actual) {
        res <- mean(preds == actual)
        names(res) <- "accuracy"
        return(res)
      }
    } else {
      # regression
      eval_metric <- function(preds, actual) {
        res <- sqrt(mean((preds - actual) ^ 2))
        names(res) <- "rmse"
        return(res)
      }
    }
  }

  # ---- folds: one independent partition per repeat -----
  set.seed(seed)
  list_folds <- lapply(seq_len(repeats),
                       function(i)
                         crossval::create_folds(y = y, k = k))

  # ---- helpers shared by the sequential and parallel code paths -----

  # Predict on `new_x`, trying the `newdata` argument name first and `newx`
  # second; returns a vector of NAs when both calls fail.
  safe_predict <- function(fit_obj, new_x) {
    preds <- try(predict_func(fit_obj, newdata = new_x), silent = TRUE)
    if (inherits(preds, "try-error")) {
      preds <- try(predict_func(fit_obj, newx = new_x), silent = TRUE)
    }
    if (inherits(preds, "try-error")) {
      preds <- rep(NA, nrow(new_x))
    }
    preds
  }

  # Fit on all rows except fold `i` of repeat `j`, then return the error on
  # that fold (followed by the validation-set error when there is one).
  eval_fold <- function(i, j) {
    # the fold elements are used as row positions of the test observations
    test_index <- list_folds[[j]][[i]]
    train_index <- setdiff(seq_along(y), test_index)
    set.seed(seed) # in case the algo is randomized
    fit_obj <- do.call(what = fit_func,
                       args = c(list(x = x[train_index, , drop = FALSE],
                                     y = y[train_index]),
                                fit_params))
    error_measure <- eval_metric(
      safe_predict(fit_obj, x[test_index, , drop = FALSE]),
      y[test_index]
    )
    if (!has_validation) {
      return(error_measure)
    }
    c(error_measure,
      eval_metric(safe_predict(fit_obj, x_validation), y_validation))
  }

  # ---- progress reporting -----
  pb <- NULL
  if (show_progress) {
    pb <- utils::txtProgressBar(min = 0,
                                max = k,
                                style = 3)
  }
  # no-op unless progress display was requested
  progress <- function(n) {
    if (show_progress) {
      utils::setTxtProgressBar(pb, n)
    }
    invisible(NULL)
  }

  ptm <- proc.time()
  # placeholders for the foreach iteration variables
  i <- NULL
  j <- NULL

  if (!is.null(cl) && cl > 0) {
    # parallel exec.: folds are distributed over the workers, repeats run
    # sequentially inside each task
    cl_SOCK <- parallel::makeCluster(cl, type = "SOCK")
    # release the workers even if a task stops with an error
    on.exit(snow::stopCluster(cl_SOCK), add = TRUE)
    doSNOW::registerDoSNOW(cl_SOCK)
    `%op1%` <- foreach::`%dopar%`
    `%op2%` <- foreach::`%do%`
    opts <- list(progress = progress)
    res <- foreach::foreach(
      i = seq_len(k),
      .packages = packages,
      .combine = rbind,
      .errorhandling = errorhandling,
      .options.snow = opts,
      .verbose = verbose,
      .export = c("create_folds")
    ) %op1% {
      foreach::foreach(
        j = seq_len(repeats),
        .packages = packages,
        .combine = cbind,
        .verbose = FALSE,
        .errorhandling = errorhandling,
        .export = c("fit_params")
      ) %op2% {
        eval_fold(i, j)
      }
    }
  } else {
    # sequential exec.
    `%op%` <- foreach::`%do%`
    res <- foreach::foreach(
      i = seq_len(k),
      .packages = packages,
      .combine = rbind,
      .errorhandling = errorhandling,
      .verbose = verbose,
      .export = c("create_folds")
    ) %op% {
      fold_res <- foreach::foreach(
        j = seq_len(repeats),
        .packages = packages,
        .combine = cbind,
        .verbose = FALSE,
        .errorhandling = errorhandling,
        .export = c("fit_params")
      ) %op% {
        eval_fold(i, j)
      }
      progress(i)
      fold_res
    }
  }

  if (show_progress) {
    close(pb)
    cat("\n")
    print(proc.time() - ptm)
    cat("\n")
  }

  # ---- results without a validation set -----
  if (!has_validation) {
    colnames(res) <- paste0("repeat_", seq_len(ncol(res)))
    rownames(res) <- paste0("fold_", seq_len(nrow(res)))
    return(list(
      folds = res,
      mean = mean(res, na.rm = TRUE),
      sd = stats::sd(res, na.rm = TRUE),
      median = stats::median(res, na.rm = TRUE)
    ))
  }

  # ---- results with a validation set -----
  # Training and validation errors alternate: odd positions hold the fold
  # error, even positions hold the matching validation-set error.
  fold_labels <- paste0(rep(c("fold_training_", "fold_validation_"), k),
                        rep(seq_len(k), each = 2))
  if (repeats > 1) {
    colnames(res) <- paste0("repeat_", seq_len(ncol(res)))
    rownames(res) <- fold_labels
    n_folds <- nrow(res)
    train_test_df <- res[seq(1, n_folds, by = 2), ]
    validation_df <- res[seq(2, n_folds, by = 2), ]
  } else {
    # With a single repeat the combined result may be laid out with one row
    # per fold and two columns (training, validation); transposing first
    # makes the flattened vector alternate training/validation per fold.
    if (is.matrix(res) && ncol(res) == 2L) {
      res <- t(res)
    }
    res <- as.numeric(res)
    names(res) <- fold_labels
    train_test_df <- res[seq(1, length(res), by = 2)]
    validation_df <- res[seq(2, length(res), by = 2)]
  }
  list(
    folds = res,
    mean_training = mean(train_test_df, na.rm = TRUE),
    mean_validation = mean(validation_df, na.rm = TRUE),
    sd_training = stats::sd(train_test_df, na.rm = TRUE),
    sd_validation = stats::sd(validation_df, na.rm = TRUE),
    median_training = stats::median(train_test_df, na.rm = TRUE),
    median_validation = stats::median(validation_df, na.rm = TRUE)
  )
}
# Byte-compile the function and keep the compiled version: without the
# assignment the value returned by cmpfun() is simply discarded.
crossval_ml <- compiler::cmpfun(crossval_ml)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.