R/repeatcv.R

Defines functions cat_parallel print.summary.repeatcv summary.repeatcv print.repeatcv repeatfolds repeatcv

Documented in repeatcv repeatfolds

#' Repeated nested CV
#'
#' Performs repeated calls to a `nestedcv` model to determine performance across
#' repeated runs of nested CV.
#' 
#' @param expr An expression containing a call to [nestcv.glmnet()],
#'   [nestcv.train()], [nestcv.SuperLearner()] or [outercv()].
#' @param n Number of repeats
#' @param repeat_folds Optional list containing fold indices to be applied to
#'   the outer CV folds.
#' @param keep Logical whether to save repeated outer CV fitted models for
#'   variable importance, SHAP etc. Note this can make the resulting object very
#'   large.
#' @param extra Logical whether additional performance metrics are gathered for
#'   binary classification models. See [metrics()].
#' @param progress Logical whether to show progress.
#' @param rep_parallel Either "mclapply" or "future". This determines which
#'   parallel backend to use.
#' @param rep.cores Integer specifying number of cores/threads to invoke.
#' Ignored if `rep_parallel = "future"`.
#' @details
#' We recommend using this with the R pipe `|>` (see examples).
#' 
#' When comparing models, it is recommended to fix the sets of outer CV folds
#' used across each repeat for comparing performance between models. The
#' function [repeatfolds()] can be used to create a fixed set of outer CV folds
#' for each repeat.
#' 
#' Parallelisation over repeats is performed using `parallel::mclapply` (not
#' available on Windows) or `future` depending on how `rep_parallel` is set.
#' Beware that `cv.cores` can still be set within calls to `nestedcv` models (=
#' nested parallelisation). This means that `rep.cores` x `cv.cores` number of
#' processes/forks will be spawned, so be careful not to overload your CPU. In
#' general parallelisation of repeats using `rep.cores` is faster than
#' parallelisation using `cv.cores`. `rep.cores` is ignored if you are using
#' future. Set the number of workers for future using `future::plan()`.
#' 
#' @returns List of S3 class 'repeatcv' containing:
#' \item{call}{the model call}
#' \item{result}{matrix of performance metrics}
#' \item{output}{a matrix or dataframe containing the outer CV predictions from
#' each repeat}
#' \item{roc}{(binary classification models only) a ROC curve object based on 
#' predictions across all repeats as returned in `output`, generated by 
#' `pROC::roc()`}
#' \item{fits}{(if `keep = TRUE`) list of length `n` containing slimmed 
#' 'nestedcv' model objects for calculating variable importance or SHAP values}
#' @importFrom utils setTxtProgressBar txtProgressBar
#' @examples
#' \donttest{
#' data("iris")
#' dat <- iris
#' y <- dat$Species
#' x <- dat[, 1:4]
#'
#' res <- nestcv.glmnet(y, x, family = "multinomial", alphaSet = 1,
#'                      n_outer_folds = 4) |>
#'        repeatcv(3, rep.cores = 2)
#' res
#' summary(res)
#' 
#' ## set up fixed fold indices
#' set.seed(123, "L'Ecuyer-CMRG")
#' folds <- repeatfolds(y, repeats = 3, n_outer_folds = 4)
#' res <- nestcv.glmnet(y, x, family = "multinomial", alphaSet = 1,
#'                      n_outer_folds = 4) |>
#'        repeatcv(3, repeat_folds = folds, rep.cores = 2)
#' res
#' }
#' @export

repeatcv <- function(expr, n = 5, repeat_folds = NULL, keep = FALSE,
                     extra = FALSE, progress = TRUE,
                     rep_parallel = "mclapply", rep.cores = 1L) {
  start <- Sys.time()
  # Frame that repeatcv() was called from. The captured model call and its
  # arguments are always evaluated here, so objects are resolved the way the
  # user wrote them: this works when repeatcv() is called from inside another
  # function, and it does not depend on the call stack of the worker function
  # used for the repeats.
  caller_env <- parent.frame()
  if (!is.null(repeat_folds) && length(repeat_folds) != n)
    stop("mismatch between n and repeat_folds")
  # ex0 is the call as supplied (stored in the result); ex is the working copy
  # whose arguments are modified below.
  ex0 <- ex <- substitute(expr)
  if (!is.call(ex)) stop("expr must be a function call")

  # Value of an argument captured in the model call, evaluated in the caller's
  # frame. Returns NULL if the argument is absent or cannot be evaluated.
  arg_value <- function(arg) {
    if (is.null(arg)) return(NULL)
    tryCatch(eval(arg, caller_env), error = function(e) NULL)
  }

  # modify args in expr call
  # Drop any `pkg::` / `pkg:::` prefix so namespace-qualified calls are
  # recognised as well.
  fun_name <- sub("^.*::", "", deparse(ex[[1]])[1L])
  if (fun_name %in% c("nestcv.glmnet", "nestcv.train")) ex$finalCV <- NA
  if (fun_name %in% c("nestcv.SuperLearner", "outercv")) ex$final <- FALSE
  # d is the label shown in progress output
  d <- fun_name
  if (fun_name == "nestcv.train") {
    method <- arg_value(ex$method)
    d <- if (is.character(method) && length(method) == 1L) method else "train"
  }
  d <- gsub("nestcv.", "", d, fixed = TRUE)

  rep_parallel <- match.arg(rep_parallel, c("mclapply", "future"))
  if (Sys.info()["sysname"] == "Windows" && rep.cores > 1 &&
      rep_parallel == "mclapply") {
    message("'rep.cores' > 1 is not supported on Windows. Set to 1.")
    rep.cores <- 1L
  }

  # cv.cores is only used for the progress message; fall back to 1 when it is
  # not set or is not a single number.
  cv.cores <- arg_value(ex$cv.cores)
  if (!is.numeric(cv.cores) || length(cv.cores) != 1L || is.na(cv.cores)) {
    cv.cores <- 1
  }
  if (progress) {
    if (rep_parallel != "future") {
      if (rep.cores == 1) {
        pb <- txtProgressBar2(title = d)
      } else {
        cat_parallel("Nested cv with ", n, " repeats")
        if (cv.cores > 1) {
          message_parallel(":\n", rep.cores, " cores for repeats x ",
                           cv.cores, " cores for CV = ",
                           rep.cores * cv.cores, " cores total (",
                           parallel::detectCores(logical = FALSE), "-core CPU)")
        } else {
          message_parallel(" over ", rep.cores, " cores (",
                           parallel::detectCores(logical = FALSE), "-core CPU)")
        }
        cat_parallel(d, "  |")
      }
    } else {
      cat_parallel("Nested cv with ", n, " repeats  ")
    }
  }

  # disable openMP multithreading (fix for xgboost)
  if (rep.cores >= 2) {
    threads <- RhpcBLASctl::omp_get_max_threads()
    if (!is.na(threads) && threads > 1) {
      RhpcBLASctl::omp_set_num_threads(1L)
      # restore the previous thread count when repeatcv() exits
      on.exit(RhpcBLASctl::omp_set_num_threads(threads), add = TRUE)
    }
  }

  if (rep_parallel == "mclapply") {
    res <- mclapply(seq_len(n), function(i) {
      # the first repeat of each batch of rep.cores repeats reports progress
      if (progress && rep.cores > 1 && i %% rep.cores == 1) {
        pc <- round(((i-1) / rep.cores) / ceiling(n / rep.cores) * 100)
        if (pc > 0 && pc < 100) cat_parallel(pc, "%")
        ex$verbose <- 2
      } else ex$verbose <- 0
      if (!is.null(repeat_folds)) ex$outer_folds <- repeat_folds[[i]]
      fit <- try(eval(ex, caller_env), silent = TRUE)
      if (progress && rep.cores == 1) setTxtProgressBar(pb, i / n)
      if (inherits(fit, "try-error")) {
        # failed repeat: NA placeholders, error message kept as an attribute
        ret <-  if (keep) list(NA, NA, NA) else list(NA, NA)
        if (progress) {
          if (rep.cores > 1) cat_parallel("x")
        }
        attr(ret, "error") <- fit[1]
        return(ret)
      }
      s <- metrics(fit, extra = extra)
      output <- fit$output
      output$rep <- i
      if (!keep) return(list(s, output))
      list(s, output, slim(fit))
    }, mc.cores = rep.cores)

    if (progress) {
      if (rep.cores == 1) {
        close(pb)
      } else {
        end <- Sys.time()
        message_parallel("|  (", format(end - start, digits = 3), ")")
      }
    }

  } else if (rep_parallel == "future") {
    ex$verbose <- 0
    # make call function and args available inside future_lapply
    ex_fun_name <- ex[[1]]
    ex_fun <- eval(ex_fun_name)
    ex_arg_exprs <- as.list(ex[-1])
    ex_args <- lapply(ex_arg_exprs, eval, envir = caller_env)

    res <- future_lapply(seq_len(n), function(i) {
      if (!is.null(repeat_folds)) ex_args$outer_folds <- repeat_folds[[i]]
      fit <- try(do.call(ex_fun, ex_args), silent = TRUE)
      if (inherits(fit, "try-error")) {
        ret <-  if (keep) list(NA, NA, NA) else list(NA, NA)
        attr(ret, "error") <- fit[1]
        return(ret)
      }
      s <- metrics(fit, extra = extra)
      output <- fit$output
      output$rep <- i
      if (!keep) return(list(s, output))
      list(s, output, slim(fit))
    }, future.seed = TRUE) |> suppressMessages()

    if (progress) {
      end <- Sys.time()
      message_parallel("(", format(end - start, digits = 3), ")")
    }
  }

  # error messages: each distinct error from a failed repeat becomes a warning
  errs <- unique(unlist(lapply(res, function(i) attr(i, "error"))))
  if (length(errs) > 0) {
    for (i in errs) {warning(i, call. = FALSE)}
  }

  # wrap up
  # family is evaluated so that both `family = "mgaussian"` and a variable
  # holding that string select the mgaussian branch
  fam <- arg_value(ex$family)
  if (identical(fam, "mgaussian")) {
    # glmnet mgaussian: one table of metrics per element of the metrics list
    res1 <- lapply(res, "[[", 1)
    yn <- length(res1[[1]])
    result <- lapply(seq_len(yn), function(i) {
      res1b <- lapply(res1, "[[", i)
      res1b <- do.call(rbind, res1b)
      rownames(res1b) <- seq_len(nrow(res1b))
      res1b
    })
    names(result) <- names(res1[[1]])
    res2 <- lapply(res, "[[", 2)
    output <- do.call(rbind, res2)
    fits <- if (keep) lapply(res, "[[", 3) else NULL
    out <- list(call = ex0, result = result, output = output, fits = fits)
  } else {
    # all other models: one row of metrics per repeat
    res1 <- lapply(res, "[[", 1)
    result <- do.call(rbind, res1)
    rownames(result) <- seq_len(nrow(result))
    res2 <- lapply(res, "[[", 2)
    output <- do.call(rbind, res2)
    fits <- if (keep) lapply(res, "[[", 3) else NULL
    out <- list(call = ex0, result = result, output = output, fits = fits)
    if ("AUC" %in% colnames(result) && !all(is.na(output))) {
      out$roc <- pROC::roc(output$testy, output$predyp, direction = "<",
                           quiet = TRUE)
    }
  }

  class(out) <- c("repeatcv")
  out
}


#' Create folds for repeated nested CV
#' 
#' @param y Outcome vector
#' @param repeats Number of repeats
#' @param n_outer_folds Number of outer CV folds
#' @returns List containing indices of outer CV folds
#' @examples
#' \donttest{
#' data("iris")
#' dat <- iris
#' y <- dat$Species
#' x <- dat[, 1:4]
#' 
#' ## set up fixed fold indices
#' set.seed(123, "L'Ecuyer-CMRG")
#' folds <- repeatfolds(y, repeats = 3, n_outer_folds = 4)
#' 
#' res <- nestcv.glmnet(y, x, family = "multinomial", alphaSet = 1,
#'                      n_outer_folds = 4, cv.cores = 2) |>
#'        repeatcv(3, repeat_folds = folds)
#' res
#' }
#' @export
#' 
repeatfolds <- function(y, repeats = 5, n_outer_folds = 10) {
  # Index vector named Rep1, Rep2, ...; lapply() carries the names over to
  # the returned list.
  rep_ids <- seq_len(repeats)
  names(rep_ids) <- paste0("Rep", rep_ids)
  # One independent call to createFolds() per repeat
  lapply(rep_ids, function(r) {
    createFolds(y, k = n_outer_folds)
  })
}


# Print method for 'repeatcv' objects: shows the stored model call followed by
# the table of performance metrics. Returns `x` invisibly, as print methods
# conventionally do, so the object can be piped or assigned after printing.
#' @export
print.repeatcv <- function(x, digits = max(3L, getOption("digits") - 3L),
                           ...) {
  cat("Call:\n")
  print(x$call)
  cat("\n")
  print(x$result, digits = digits)
  invisible(x)
}


# Summary method for 'repeatcv' objects. For every performance metric it
# reports the mean, standard deviation and standard error of the mean across
# repeats. Returns a 'summary.repeatcv' list with elements `call`, `n` (number
# of repeats, i.e. rows of the result table) and `summary` (a data frame, or a
# list of data frames when `object$result` is a list, as for mgaussian models).
#' @export
summary.repeatcv <- function(object, ...) {
  # Column-wise statistics for one table of metrics (rows = repeats).
  # NA rows (e.g. from failed repeats) are dropped from the mean and SD, so
  # the SEM is based on the number of non-missing values in each column
  # rather than on the total number of rows.
  col_stats <- function(tab) {
    avg <- colMeans(tab, na.rm = TRUE)
    spread <- apply(tab, 2, sd, na.rm = TRUE)
    n_used <- colSums(!is.na(tab))
    data.frame(mean = avg, sd = spread, sem = spread / sqrt(n_used))
  }

  res <- object$result
  if (is.list(res)) {
    # mgaussian: one table of metrics per list element
    df <- lapply(res, col_stats)
    n <- nrow(res[[1]])
  } else {
    df <- col_stats(res)
    n <- nrow(res)
  }

  structure(list(call = object$call, n = n, summary = df),
            class = "summary.repeatcv")
}


# Print method for 'summary.repeatcv' objects: shows the model call, the number
# of repeats and the summary table(s). Returns `x` invisibly, as print methods
# conventionally do.
#' @export
print.summary.repeatcv <- function(x,
                                   digits = max(3L, getOption("digits") - 3L),
                                   ...) {
  cat("Call:\n")
  print(x$call)
  cat(x$n, "repeats\n")
  print(x$summary, digits = digits)
  invisible(x)
}


# Emits text through the shell's `echo` so that it shows up when called from
# inside mclapply under RStudio. Does nothing (returns NULL) unless the
# RSTUDIO environment variable is "1".
cat_parallel <- function(...) {
  in_rstudio <- Sys.getenv("RSTUDIO") == "1"
  if (!in_rstudio) return()
  # Paste the pieces together; the trailing \c is passed to echo so that no
  # newline is appended.
  payload <- paste0(..., '\\c"', collapse = "")
  shell_cmd <- sprintf('echo "%s', payload)
  system(shell_cmd)
}

Try the nestedcv package in your browser

Any scripts or data that you put into this service are public.

nestedcv documentation built on April 4, 2025, 2:21 a.m.