R/createFolds.R

Defines functions createFolds

Documented in createFolds

#' @name createFolds
#' @title Internal function: create cross-validation folds
#' @description Taken from caret package, used to create folds. Observations
#'   are assigned to folds at random within each class of \code{y}; a numeric
#'   \code{y} is first binned on its quantiles so the folds are stratified on
#'   the magnitude of the target.
#' @param y target variable. An object inheriting from \code{"Surv"} is
#'   reduced to its \code{"time"} column, a numeric vector is binned by
#'   quantiles, and anything else is treated as class labels.
#' @param k number of folds
#' @param list logical; if \code{TRUE} a list of index vectors (one element
#'   per fold) is returned, otherwise the integer vector of fold assignments.
#' @param returnTrain logical, only used when \code{list = TRUE}; if
#'   \code{TRUE} each list element holds the indices that are NOT in the fold
#'   instead of the indices that are in it.
#'
#' @return If \code{list = TRUE}, a list of integer index vectors named
#'   \code{Fold1}, \code{Fold2}, ... (zero-padded to a common width);
#'   otherwise an integer vector with one fold number per element of \code{y}.
#'   When \code{k} is not smaller than \code{length(y)} every observation is
#'   placed in its own fold.
#' @keywords internal
createFolds <- function(y, k = 10, list = TRUE, returnTrain = FALSE) {
    # Survival targets are stratified on their "time" column only.
    if (inherits(y, "Surv")) {
        y <- y[, "time"]
    }

    if (is.numeric(y)) {
        # Use between 2 and 5 quantile break points, depending on how many
        # observations are available per fold.
        cuts <- min(max(floor(length(y) / k), 2), 5)
        breaks <- unique(
            stats::quantile(y, probs = seq(0, 1, length.out = cuts))
        )
        if (length(breaks) < 2) {
            # All quantiles coincide (e.g. constant y). A single value passed
            # to cut() would be read as a *number of intervals*, which errors
            # or bins arbitrarily, so treat everything as one class instead.
            y <- rep("constant", length(y))
        } else {
            y <- cut(y, breaks, include.lowest = TRUE)
        }
    }

    if (k < length(y)) {
        y <- factor(as.character(y))
        numInClass <- table(y)
        foldVector <- vector(mode = "integer", length(y))
        # seq_along() keeps the loop from running when there are no classes.
        for (i in seq_along(numInClass)) {
            members <- which(y == names(numInClass)[i])
            classSize <- numInClass[[i]]
            min_reps <- classSize %/% k
            if (min_reps > 0) {
                # Give every fold min_reps members of this class, then hand
                # the remainder to randomly chosen distinct folds.
                spares <- classSize %% k
                seqVector <- rep(1:k, min_reps)
                if (spares > 0) {
                    seqVector <- c(seqVector, sample(1:k, spares))
                }
                foldVector[members] <- sample(seqVector)
            } else {
                # Fewer members than folds: each goes to a different fold.
                foldVector[members] <- sample(1:k, size = classSize)
            }
        }
    } else {
        # Not more observations than folds: one observation per fold.
        foldVector <- seq_along(y)
    }

    if (!list) {
        return(foldVector)
    }

    out <- split(seq_along(y), foldVector)
    # paste0("Fold", character(0)) has length 1, so only name non-empty output.
    if (length(out) > 0) {
        names(out) <- paste0("Fold", gsub(" ", "0", format(seq_along(out))))
    }
    if (returnTrain) {
        allIndex <- seq_along(y)
        out <- lapply(out, function(heldOut) allIndex[-heldOut])
    }
    out
}

Try the superml package in your browser

Any scripts or data that you put into this service are public.

superml documentation built on Nov. 14, 2022, 9:05 a.m.