# R/repeatcv.R
#
# Defines functions: cat_parallel, print.summary.repeatcv, summary.repeatcv,
# print.repeatcv, repeatfolds, repeatcv
#
# Documented in: repeatcv, repeatfolds

#' Repeated nested CV
#'
#' Performs repeated calls to a `nestedcv` model to determine performance across
#' repeated runs of nested CV.
#' 
#' @param expr An expression containing a call to [nestcv.glmnet()],
#'   [nestcv.train()], [nestcv.SuperLearner()] or [outercv()].
#' @param n Number of repeats
#' @param repeat_folds Optional list containing fold indices to be applied to
#'   the outer CV folds.
#' @param keep Logical whether to save repeated outer CV predictions for ROC
#'   curves etc.
#' @param extra Logical whether additional performance metrics are gathered for
#'   binary classification models. See [metrics()].
#' @param progress Logical whether to show progress.
#' @param rep.cores Integer specifying number of cores/threads to invoke.
#' @details
#' We recommend using this with the R pipe `|>` (see examples).
#' 
#' When comparing models, it is recommended to fix the sets of outer CV folds
#' used across each repeat for comparing performance between models. The
#' function [repeatfolds()] can be used to create a fixed set of outer CV folds
#' for each repeat.
#' 
#' Parallelisation over repeats is performed using `parallel::mclapply` (not
#' available on Windows, where `rep.cores` is reset to 1 with a message). Beware
#' that `cv.cores` can still be set within calls to `nestedcv` models (= nested
#' parallelisation). This means that `rep.cores` x `cv.cores` number of
#' processes/forks will be spawned, so be careful not to overload your CPU. In
#' general, parallelisation of repeats using `rep.cores` is faster than
#' parallelisation using `cv.cores`.
#' 
#' @returns List of S3 class 'repeatcv' containing:
#' \item{call}{the model call}
#' \item{result}{matrix of performance metrics}
#' \item{output}{(if `keep = TRUE`) a matrix or dataframe containing the outer CV 
#' predictions from each repeat}
#' \item{roc}{(binary classification models only) a ROC curve object based on 
#' predictions across all repeats as returned in `output`, generated by 
#' `pROC::roc()`}
#' @importFrom utils setTxtProgressBar txtProgressBar
#' @examples
#' \donttest{
#' data("iris")
#' dat <- iris
#' y <- dat$Species
#' x <- dat[, 1:4]
#'
#' res <- nestcv.glmnet(y, x, family = "multinomial", alphaSet = 1,
#'                      n_outer_folds = 4) |>
#'        repeatcv(3, rep.cores = 2)
#' res
#' summary(res)
#' 
#' ## set up fixed fold indices
#' set.seed(123, "L'Ecuyer-CMRG")
#' folds <- repeatfolds(y, repeats = 3, n_outer_folds = 4)
#' res <- nestcv.glmnet(y, x, family = "multinomial", alphaSet = 1,
#'                      n_outer_folds = 4) |>
#'        repeatcv(3, repeat_folds = folds, rep.cores = 2)
#' res
#' }
#' @export

repeatcv <- function(expr, n = 5, repeat_folds = NULL, keep = TRUE,
                     extra = FALSE,
                     progress = TRUE, rep.cores = 1L) {
  start <- Sys.time()
  # Frame that called repeatcv(). The model call is evaluated here rather than
  # via eval.parent() inside the worker function, whose parent frame is that of
  # lapply()/mclapply() and not of the user's code.
  caller_env <- parent.frame()
  if (!is.null(repeat_folds) && length(repeat_folds) != n)
    stop("mismatch between n and repeat_folds")
  # `ex0` is the untouched call returned to the user; `ex` is the working copy
  # whose arguments are modified before each evaluation.
  ex0 <- ex <- substitute(expr)
  if (!is.call(ex)) stop("'expr' must be a call to a nestedcv model function")
  # modify args in expr call
  # Function name with any `pkg::` / `pkg:::` prefix stripped, so that
  # namespace-qualified calls are recognised as well.
  d <- sub("^.*:::?", "", paste(deparse(ex[[1]]), collapse = ""))
  if (d == "nestcv.glmnet" || d == "nestcv.train") ex$finalCV <- NA
  if (d == "nestcv.SuperLearner" || d == "outercv") ex$final <- FALSE
  # For nestcv.train the progress label is the `method` argument when the call
  # supplies one; otherwise the function name is kept as the label.
  if (d == "nestcv.train" && !is.null(ex$method)) d <- ex$method
  d <- gsub("nestcv.", "", d, fixed = TRUE)

  if (rep.cores > 1 && Sys.info()["sysname"] == "Windows") {
    message("'rep.cores' > 1 is not supported on Windows. Set to 1.")
    rep.cores <- 1L
  }

  # `cv.cores` is only used for the progress message. It may be given as a
  # variable in the call, so evaluate it in the caller's frame; fall back to 1
  # if it is absent or cannot be evaluated.
  cv.cores <- tryCatch(eval(ex$cv.cores, caller_env),
                       error = function(e) NULL)
  if (is.null(cv.cores)) cv.cores <- 1
  if (progress) {
    if (rep.cores == 1) {
      pb <- txtProgressBar2(title = d)
    } else {
      cat_parallel("Nested cv with ", n, " repeats")
      if (cv.cores > 1) {
        message_parallel(":\n", rep.cores, " cores for repeats x ",
                         cv.cores, " cores for CV = ",
                         rep.cores * cv.cores, " cores total (",
                         parallel::detectCores(logical = FALSE), "-core CPU)")
      } else {
        message_parallel(" over ", rep.cores, " cores (",
                         parallel::detectCores(logical = FALSE), "-core CPU)")
      }
      cat_parallel(d, "  |")
    }
  }

  # disable openMP multithreading (fix for xgboost)
  if (rep.cores >= 2) {
    threads <- RhpcBLASctl::omp_get_max_threads()
    if (!is.na(threads) && threads > 1) {
      RhpcBLASctl::omp_set_num_threads(1L)
      # add = TRUE so that restoring the thread count never replaces another
      # exit handler
      on.exit(RhpcBLASctl::omp_set_num_threads(threads), add = TRUE)
    }
  }

  res <- mclapply(seq_len(n), function(i) {
    # With parallel repeats, only the first repeat of each batch of
    # `rep.cores` prints a percentage and runs the model verbosely.
    if (progress && rep.cores > 1 && i %% rep.cores == 1) {
      pc <- round(((i - 1) / rep.cores) / ceiling(n / rep.cores) * 100)
      if (pc > 0 && pc < 100) cat_parallel(pc, "%")
      ex$verbose <- 2
    } else ex$verbose <- 0
    if (!is.null(repeat_folds)) ex$outer_folds <- repeat_folds[[i]]
    fit <- try(eval(ex, caller_env), silent = TRUE)
    if (progress && rep.cores == 1) setTxtProgressBar(pb, i / n)
    if (inherits(fit, "try-error")) {
      # A failed repeat yields NA placeholders; the error message is always
      # attached so it can be reported after all repeats have finished.
      ret <- if (keep) list(NA, NA) else NA
      if (progress && rep.cores > 1) cat_parallel("x")
      attr(ret, "error") <- fit[1]
      return(ret)
    }
    s <- metrics(fit, extra = extra)
    if (!keep) return(s)
    output <- fit$output
    output$rep <- i
    list(s, output)
  }, mc.cores = rep.cores)
  if (progress) {
    if (rep.cores == 1) {
      close(pb)
    } else {
      end <- Sys.time()
      message_parallel("|  (", format(end - start, digits = 3), ")")
    }
  }
  # Report each distinct error from failed repeats as a warning, whether or
  # not progress output was requested.
  errs <- unique(unlist(lapply(res, function(r) attr(r, "error"))))
  for (msg in errs) warning(msg, call. = FALSE)

  if (!is.null(ex$family) && ex$family == "mgaussian") {
    # glmnet mgaussian: metrics come as one element per response variable, so
    # `result` is a list with one matrix (rows = repeats) per response.
    if (keep) {
      res1 <- lapply(res, "[[", 1)
      yn <- length(res1[[1]])
      result <- lapply(seq_len(yn), function(i) {
        res1b <- lapply(res1, "[[", i)
        res1b <- do.call(rbind, res1b)
        rownames(res1b) <- seq_len(nrow(res1b))
        res1b
      })
      names(result) <- names(res1[[1]])
      res2 <- lapply(res, "[[", 2)
      output <- do.call(rbind, res2)
      out <- list(call = ex0, result = result, output = output)
    } else {
      yn <- length(res[[1]])
      result <- lapply(seq_len(yn), function(i) {
        res1b <- lapply(res, "[[", i)
        res1b <- do.call(rbind, res1b)
        rownames(res1b) <- seq_len(nrow(res1b))
        res1b
      })
      names(result) <- names(res[[1]])
      out <- list(call = ex0, result = result)
    }
  } else {
    # all other models: one row of metrics per repeat
    if (keep) {
      res1 <- lapply(res, "[[", 1)
      result <- do.call(rbind, res1)
      rownames(result) <- seq_len(nrow(result))
      res2 <- lapply(res, "[[", 2)
      output <- do.call(rbind, res2)
      out <- list(call = ex0, result = result, output = output)
      # A ROC curve over the pooled predictions is added only when an AUC
      # column is present and at least some predictions are available.
      if ("AUC" %in% colnames(result) && !all(is.na(output))) {
        out$roc <- pROC::roc(output$testy, output$predyp, direction = "<",
                             quiet = TRUE)
      }
    } else {
      result <- do.call(rbind, res)
      rownames(result) <- seq_len(nrow(result))
      out <- list(call = ex0, result = result)
    }
  }

  class(out) <- c("repeatcv")
  out
}


#' Create folds for repeated nested CV
#' 
#' @param y Outcome vector
#' @param repeats Number of repeats
#' @param n_outer_folds Number of outer CV folds
#' @returns List containing indices of outer CV folds
#' @examples
#' \donttest{
#' data("iris")
#' dat <- iris
#' y <- dat$Species
#' x <- dat[, 1:4]
#' 
#' ## set up fixed fold indices
#' set.seed(123, "L'Ecuyer-CMRG")
#' folds <- repeatfolds(y, repeats = 3, n_outer_folds = 4)
#' 
#' res <- nestcv.glmnet(y, x, family = "multinomial", alphaSet = 1,
#'                      n_outer_folds = 4, cv.cores = 2) |>
#'        repeatcv(3, repeat_folds = folds)
#' res
#' }
#' @export
#' 
repeatfolds <- function(y, repeats = 5, n_outer_folds = 10) {
  rep_ids <- seq_len(repeats)
  # createFolds() is called once per repeat, each time on the full outcome
  # vector with the requested number of folds.
  draw_folds <- function(rep_id) {
    createFolds(y, k = n_outer_folds)
  }
  fold_sets <- lapply(rep_ids, draw_folds)
  # Label the elements Rep1, Rep2, ... in repeat order.
  names(fold_sets) <- paste0("Rep", rep_ids)
  fold_sets
}


#' @export
print.repeatcv <- function(x, digits = max(3L, getOption("digits") - 3L),
                           ...) {
  # Print the stored model call, a blank line, then the performance metrics
  # rounded to `digits` significant digits.
  cat("Call:\n")
  print(x$call)
  cat("\n")
  print(x$result, digits = digits)
  # Print methods return their argument invisibly so the object can be piped
  # or assigned after printing.
  invisible(x)
}


#' @export
summary.repeatcv <- function(object, ...) {
  # Column-wise mean, standard deviation and standard error of the mean for a
  # matrix of metrics (rows = repeats). Missing values are dropped from all
  # three statistics, so the SEM divides by the number of non-missing values in
  # each column rather than by the total number of rows.
  summarise_metrics <- function(mat) {
    avg <- colMeans(mat, na.rm = TRUE)
    sdev <- apply(mat, 2, stats::sd, na.rm = TRUE)
    n_obs <- colSums(!is.na(mat))
    data.frame(mean = avg, sd = sdev, sem = sdev / sqrt(n_obs))
  }

  res <- object$result
  if (is.list(res) && !is.data.frame(res)) {
    # mgaussian: `result` is a list holding one metrics matrix per response
    df <- lapply(res, summarise_metrics)
    n <- nrow(res[[1]])
  } else {
    df <- summarise_metrics(res)
    n <- nrow(res)
  }

  # `n` is the number of repeats (rows), including any that produced NA
  structure(list(call = object$call, n = n, summary = df),
            class = "summary.repeatcv")
}


#' @export
print.summary.repeatcv <- function(x,
                                   digits = max(3L, getOption("digits") - 3L),
                                   ...) {
  # Print the stored model call, the number of repeats, then the summary
  # table(s) rounded to `digits` significant digits.
  cat("Call:\n")
  print(x$call)
  cat(x$n, "repeats\n")
  print(x$summary, digits = digits)
  # Print methods return their argument invisibly so the object can be piped
  # or assigned after printing.
  invisible(x)
}


# Prints using shell echo from inside mclapply when run in Rstudio
cat_parallel <- function(...) {
  # Do nothing (returning NULL) unless the RSTUDIO environment variable is "1".
  in_rstudio <- Sys.getenv("RSTUDIO") == "1"
  if (!in_rstudio) return()
  # Concatenate the arguments and append `\c"`, which closes the quoted string
  # that the echo command below opens.
  payload <- paste0(..., '\\c"', collapse = "")
  shell_cmd <- sprintf('echo "%s', payload)
  system(shell_cmd)
}

# Try the nestedcv package in your browser
#
# Any scripts or data that you put into this service are public.
#
# nestedcv documentation built on June 22, 2024, 11:30 a.m.