Nothing
#' Repeated nested CV
#'
#' Performs repeated calls to a `nestedcv` model to determine performance across
#' repeated runs of nested CV.
#'
#' @param expr An expression containing a call to [nestcv.glmnet()],
#' [nestcv.train()], [nestcv.SuperLearner()] or [outercv()]. The call may be
#' namespace-qualified, e.g. `nestedcv::nestcv.glmnet(...)`.
#' @param n Number of repeats
#' @param repeat_folds Optional list containing fold indices to be applied to
#' the outer CV folds. Must have length `n`.
#' @param keep Logical whether to save repeated outer CV predictions for ROC
#' curves etc.
#' @param extra Logical whether additional performance metrics are gathered for
#' binary classification models. See [metrics()].
#' @param progress Logical whether to show progress.
#' @param rep.cores Integer specifying number of cores/threads to invoke.
#' @details
#' We recommend using this with the R pipe `|>` (see examples).
#'
#' The model call in `expr` is evaluated in the environment from which
#' `repeatcv()` is called, so it may refer to local variables of a calling
#' function.
#'
#' When comparing models, it is recommended to fix the sets of outer CV folds
#' used across each repeat for comparing performance between models. The
#' function [repeatfolds()] can be used to create a fixed set of outer CV folds
#' for each repeat.
#'
#' Parallelisation over repeats is performed using `parallel::mclapply` (not
#' available on Windows). Beware that `cv.cores` can still be set within calls
#' to `nestedcv` models (= nested parallelisation). This means that `rep.cores`
#' x `cv.cores` number of processes/forks will be spawned, so be careful not to
#' overload your CPU. In general parallelisation of repeats using `rep.cores` is
#' faster than parallelisation using `cv.cores`.
#'
#' If an individual repeat fails, its performance metrics are recorded as `NA`
#' and the error message is issued as a warning.
#'
#' @returns List of S3 class 'repeatcv' containing:
#' \item{call}{the model call}
#' \item{result}{matrix of performance metrics, one row per repeat (for
#' `family = "mgaussian"` glmnet models, a named list of such matrices, one per
#' response)}
#' \item{output}{(if `keep = TRUE`) a matrix or dataframe containing the outer CV
#' predictions from each repeat}
#' \item{roc}{(binary classification models only) a ROC curve object based on
#' predictions across all repeats as returned in `output`, generated by
#' `pROC::roc()`}
#' @importFrom utils setTxtProgressBar txtProgressBar
#' @examples
#' \donttest{
#' data("iris")
#' dat <- iris
#' y <- dat$Species
#' x <- dat[, 1:4]
#'
#' res <- nestcv.glmnet(y, x, family = "multinomial", alphaSet = 1,
#' n_outer_folds = 4) |>
#' repeatcv(3, rep.cores = 2)
#' res
#' summary(res)
#'
#' ## set up fixed fold indices
#' set.seed(123, "L'Ecuyer-CMRG")
#' folds <- repeatfolds(y, repeats = 3, n_outer_folds = 4)
#' res <- nestcv.glmnet(y, x, family = "multinomial", alphaSet = 1,
#' n_outer_folds = 4) |>
#' repeatcv(3, repeat_folds = folds, rep.cores = 2)
#' res
#' }
#' @export
repeatcv <- function(expr, n = 5, repeat_folds = NULL, keep = TRUE,
                     extra = FALSE,
                     progress = TRUE, rep.cores = 1L) {
  start <- Sys.time()
  # Environment in which the model call is evaluated. It must be captured
  # here: calling eval.parent() from inside the worker function below would
  # evaluate the call in lapply()'s internal frame, where caller-local
  # variables are not visible and a user object named `X` is masked by
  # lapply()'s own argument.
  env <- parent.frame()
  if (!is.numeric(n) || length(n) != 1 || is.na(n) || n < 1) {
    stop("'n' must be a single number >= 1", call. = FALSE)
  }
  if (!is.null(repeat_folds) && length(repeat_folds) != n)
    stop("mismatch between n and repeat_folds")
  ex0 <- ex <- substitute(expr)
  if (!is.call(ex)) {
    stop("'expr' must be a call to a nestedcv model function", call. = FALSE)
  }
  # name of the model function, with any `pkg::` prefix removed
  fun <- sub("^.*::", "", paste(deparse(ex[[1]]), collapse = ""))
  # modify args in the call. NOTE(review): `finalCV = NA` / `final = FALSE`
  # presumably switch off the final model fit, which repeated CV does not
  # need - confirm against nestcv.glmnet()/outercv() docs
  if (fun %in% c("nestcv.glmnet", "nestcv.train")) ex$finalCV <- NA
  if (fun %in% c("nestcv.SuperLearner", "outercv")) ex$final <- FALSE
  # label for progress output: model type, or the caret method for
  # nestcv.train() (falls back to "train" when no method is given)
  d <- fun
  if (fun == "nestcv.train" && !is.null(ex$method)) {
    d <- as.character(ex$method)[1]
  }
  d <- sub("^nestcv\\.", "", d)
  if (.Platform$OS.type == "windows" && rep.cores > 1) {
    message("'rep.cores' > 1 is not supported on Windows. Set to 1.")
    rep.cores <- 1L
  }
  if (progress) {
    if (rep.cores == 1) {
      pb <- txtProgressBar2(title = d)
    } else {
      # cv.cores may be supplied as a variable, so evaluate it in the caller's
      # environment rather than using the unevaluated argument
      cv.cores <- if (is.null(ex$cv.cores)) 1 else eval(ex$cv.cores, env)
      if (!is.numeric(cv.cores) || length(cv.cores) != 1L) cv.cores <- 1
      cat_parallel("Nested cv with ", n, " repeats")
      if (cv.cores > 1) {
        message_parallel(":\n", rep.cores, " cores for repeats x ",
                         cv.cores, " cores for CV = ",
                         rep.cores * cv.cores, " cores total (",
                         parallel::detectCores(logical = FALSE), "-core CPU)")
      } else {
        message_parallel(" over ", rep.cores, " cores (",
                         parallel::detectCores(logical = FALSE), "-core CPU)")
      }
      cat_parallel(d, " |")
    }
  }
  # disable openMP multithreading (fix for xgboost)
  if (rep.cores >= 2) {
    threads <- RhpcBLASctl::omp_get_max_threads()
    if (!is.na(threads) && threads > 1) {
      RhpcBLASctl::omp_set_num_threads(1L)
      on.exit(RhpcBLASctl::omp_set_num_threads(threads), add = TRUE)
    }
  }
  res <- mclapply(seq_len(n), function(i) {
    # in parallel mode, the first repeat of each batch reports progress
    if (progress && rep.cores > 1 && i %% rep.cores == 1) {
      pc <- round(((i - 1) / rep.cores) / ceiling(n / rep.cores) * 100)
      if (pc > 0 && pc < 100) cat_parallel(pc, "%")
      ex$verbose <- 2
    } else {
      ex$verbose <- 0
    }
    if (!is.null(repeat_folds)) ex$outer_folds <- repeat_folds[[i]]
    fit <- try(eval(ex, env), silent = TRUE)
    if (progress && rep.cores == 1) setTxtProgressBar(pb, i / n)
    if (inherits(fit, "try-error")) {
      if (progress && rep.cores > 1) cat_parallel("x")
      ret <- if (keep) list(NA, NA) else NA
      # always keep the error message so it can be reported after the loop
      attr(ret, "error") <- fit[1]
      return(ret)
    }
    s <- metrics(fit, extra = extra)
    if (!keep) return(s)
    output <- fit$output
    output$rep <- i
    list(s, output)
  }, mc.cores = rep.cores)
  if (progress) {
    if (rep.cores == 1) {
      close(pb)
    } else {
      end <- Sys.time()
      message_parallel("| (", format(end - start, digits = 3), ")")
    }
  }
  # report each distinct error from failed repeats (regardless of `progress`)
  errs <- unique(unlist(lapply(res, attr, "error")))
  for (e in errs) warning(e, call. = FALSE)

  # with keep = TRUE each element is list(metrics, outer CV predictions)
  res1 <- if (keep) lapply(res, "[[", 1) else res
  family <- ex$family
  if (is.name(family) || is.call(family)) {
    # family supplied as a variable or call: evaluate it in the caller's env
    family <- tryCatch(eval(family, env), error = function(e) NULL)
  }
  is_mgaussian <- identical(family, "mgaussian")
  if (is_mgaussian) {
    # glmnet mgaussian: metrics come per response, so build one matrix per
    # response with a row per repeat
    result <- lapply(seq_along(res1[[1]]), function(j) {
      m <- do.call(rbind, lapply(res1, "[[", j))
      rownames(m) <- seq_len(nrow(m))
      m
    })
    names(result) <- names(res1[[1]])
  } else {
    # all other models: one row of metrics per repeat
    result <- do.call(rbind, res1)
    rownames(result) <- seq_len(nrow(result))
  }
  out <- list(call = ex0, result = result)
  if (keep) {
    output <- do.call(rbind, lapply(res, "[[", 2))
    out$output <- output
    if (!is_mgaussian && "AUC" %in% colnames(result) &&
        !all(is.na(output))) {
      out$roc <- pROC::roc(output$testy, output$predyp, direction = "<",
                           quiet = TRUE)
    }
  }
  class(out) <- "repeatcv"
  out
}
#' Create folds for repeated nested CV
#'
#' Generates a separate set of outer CV fold indices for every repeat, intended
#' for the `repeat_folds` argument of [repeatcv()]. Fixing the folds up front
#' means several models can be compared on exactly the same data splits.
#'
#' @param y Outcome vector
#' @param repeats Number of repeats
#' @param n_outer_folds Number of outer CV folds
#' @returns List of length `repeats` with elements named `Rep1`, `Rep2`, ...,
#'   each holding the outer CV fold indices produced by `createFolds()`
#' @examples
#' \donttest{
#' data("iris")
#' dat <- iris
#' y <- dat$Species
#' x <- dat[, 1:4]
#'
#' ## set up fixed fold indices
#' set.seed(123, "L'Ecuyer-CMRG")
#' folds <- repeatfolds(y, repeats = 3, n_outer_folds = 4)
#'
#' res <- nestcv.glmnet(y, x, family = "multinomial", alphaSet = 1,
#' n_outer_folds = 4, cv.cores = 2) |>
#' repeatcv(3, repeat_folds = folds)
#' res
#' }
#' @export
#'
repeatfolds <- function(y, repeats = 5, n_outer_folds = 10) {
  rep_ids <- seq_len(repeats)
  fold_sets <- vector("list", length(rep_ids))
  # folds are drawn sequentially, one repeat at a time
  for (r in rep_ids) {
    fold_sets[r] <- list(createFolds(y, k = n_outer_folds))
  }
  names(fold_sets) <- paste0("Rep", rep_ids)
  fold_sets
}
#' @export
# Print method for 'repeatcv' objects: shows the model call followed by the
# per-repeat performance metrics. Returns `x` invisibly, as print methods
# should, so that results can be piped or reassigned.
print.repeatcv <- function(x, digits = max(3L, getOption("digits") - 3L),
                           ...) {
  cat("Call:\n")
  print(x$call)
  cat("\n")
  print(x$result, digits = digits)
  invisible(x)
}
#' @export
# Summarises a 'repeatcv' object: mean, SD and standard error of the mean of
# each performance metric across repeats. Returns an object of class
# 'summary.repeatcv' holding the call, the number of repeats `n` and the
# summary data frame (a list of data frames for mgaussian models).
summary.repeatcv <- function(object, ...) {
  # SEM uses the number of non-missing values per metric, consistent with the
  # NA-removing SD, so failed repeats (NA rows) do not shrink the SEM
  summarise_metrics <- function(mat) {
    m <- colMeans(mat, na.rm = TRUE)
    s <- apply(mat, 2, sd, na.rm = TRUE)
    n_ok <- colSums(!is.na(mat))
    data.frame(mean = m, sd = s, sem = s / sqrt(n_ok))
  }
  if (is.list(object$result)) {
    # mgaussian: one matrix of metrics per response
    df <- lapply(object$result, summarise_metrics)
    n <- nrow(object$result[[1]])
  } else {
    df <- summarise_metrics(object$result)
    n <- nrow(object$result)
  }
  structure(list(call = object$call, n = n, summary = df),
            class = "summary.repeatcv")
}
#' @export
# Print method for 'summary.repeatcv' objects: shows the call, the number of
# repeats and the summary table. Returns `x` invisibly.
print.summary.repeatcv <- function(x,
                                   digits = max(3L, getOption("digits") - 3L),
                                   ...) {
  cat("Call:\n")
  print(x$call)
  cat(x$n, "repeats\n")
  print(x$summary, digits = digits)
  invisible(x)
}
# Prints using shell echo from inside mclapply when run in RStudio; does
# nothing (returns NULL) outside RStudio.
cat_parallel <- function(...) {
  if (Sys.getenv("RSTUDIO") != "1") return(invisible(NULL))
  # `\c` asks echo to suppress its trailing newline. The text is quoted for
  # the shell with shQuote(), so characters such as `"`, `$` or backticks are
  # printed literally instead of being interpreted by the shell.
  txt <- paste0(..., "\\c", collapse = "")
  system(paste("echo", shQuote(txt, type = "sh")))
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.