R/split_cv.R

Defines functions split_cv

Documented in split_cv

# WARNING - Generated by {fusen} from dev/flat_teaching.Rmd: do not edit by hand

#' Cross-Validation Split Generator
#'
#' @description
#' A robust cross-validation splitting utility for multiple datasets with advanced stratification and configuration options.
#'
#' @param split_dt `list` of input datasets
#'   - Must contain `data.frame` or `data.table` elements
#'   - Supports multiple dataset processing
#'   - Cannot be empty
#' @inheritParams rsample::vfold_cv
#' 
#' @details
#' Advanced Cross-Validation Mechanism:
#' \enumerate{
#'   \item Input dataset validation
#'   \item Stratified or unstratified sampling
#'   \item Flexible fold generation
#'   \item Train-validate set creation
#' }
#'
#' Sampling Strategies:
#' \itemize{
#'   \item Supports multiple dataset processing
#'   \item Handles stratified and unstratified sampling
#'   \item Generates reproducible cross-validation splits
#' }
#'
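#' @return `list` of `data.table` objects (one per element of `split_dt`, carrying the same names), each containing:
#'   \itemize{
#'     \item `splits`: Cross-validation split objects
#'     \item `id`: Fold identifier when `repeats = 1`; repeat identifier when `repeats > 1`
#'     \item `id2`: Fold identifier (only present when `repeats > 1`)
#'     \item `train`: Training dataset subsets
#'     \item `validate`: Validation dataset subsets
#'   }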
#'
#' @note Important Constraints:
#' \itemize{
#'   \item Requires non-empty input datasets
#'   \item All datasets must be `data.frame` or `data.table`
#'   \item Strata column must exist if specified
#'   \item Computational resources impact large dataset processing
#' }
#'
#' @seealso
#' \itemize{
#'   \item [`rsample::vfold_cv()`] Core cross-validation function
#' }
#'
#' @import data.table
#' @importFrom rsample vfold_cv
#' @export
#' @examples
#' # Prepare example data: Convert first 3 columns of iris dataset to long format and split
#' dt_split <- w2l_split(data = iris, cols2l = 1:3)
#' # dt_split is now a list containing 3 data tables for Sepal.Length, Sepal.Width, and Petal.Length
#'
#' # Example 1: Single cross-validation (no repeats)
#' split_cv(
#'   split_dt = dt_split,  # Input list of split data
#'   v = 3,                # Set 3-fold cross-validation
#'   repeats = 1           # Perform cross-validation once (no repeats)
#' )
#' # Returns a list where each element contains:
#' # - splits: rsample split objects
#' # - id: fold numbers (Fold1, Fold2, Fold3)
#' # - train: training set data
#' # - validate: validation set data
#'
#' # Example 2: Repeated cross-validation
#' split_cv(
#'   split_dt = dt_split,  # Input list of split data
#'   v = 3,                # Set 3-fold cross-validation
#'   repeats = 2           # Perform cross-validation twice
#' )
#' # Returns a list where each element contains:
#' # - splits: rsample split objects
#' # - id: repeat numbers (Repeat1, Repeat2)
#' # - id2: fold numbers (Fold1, Fold2, Fold3)
#' # - train: training set data
#' # - validate: validation set data
split_cv <- function(split_dt, v = 10, repeats = 1, strata = NULL, breaks = 4, pool = 0.1, ...) {
  # Bind the column symbols used inside data.table's `[` so that R CMD check
  # does not report them as undefined global variables.
  id <- splits <- NULL

  # ---- Input validation -----------------------------------------------------
  if (!is.list(split_dt)) {
    stop("split_dt must be a list")
  }

  if (length(split_dt) == 0) {
    stop("The input split_dt cannot be empty")
  }

  # A data.table is also a data.frame, so a single is.data.frame() test covers
  # both; vapply() keeps the result type-stable.
  is_df <- vapply(split_dt, is.data.frame, logical(1))
  if (!all(is_df)) {
    stop("All elements in split_dt must be data.frames or data.tables")
  }

  # `strata` is tested with a scalar `if` below, so reject anything that is not
  # exactly one value up front with a readable message.
  if (!is.null(strata) && length(strata) != 1L) {
    stop("strata must be NULL or a single column name")
  }

  # ---- Labels used in diagnostics -------------------------------------------
  # Unnamed (or partially named) lists get their position as label; otherwise
  # sprintf() would receive NULL / "" and the warning text would be lost.
  dt_names <- names(split_dt)
  if (is.null(dt_names)) {
    dt_names <- rep("", length(split_dt))
  }
  unnamed <- is.na(dt_names) | !nzchar(dt_names)
  dt_labels <- dt_names
  dt_labels[unnamed] <- as.character(seq_along(split_dt))[unnamed]

  # Extra arguments are forwarded unchanged to rsample::vfold_cv(); capture
  # them once instead of on every iteration.
  extra_args <- list(...)

  # ---- Build one result table per dataset -----------------------------------
  result <- vector("list", length(split_dt))
  names(result) <- names(split_dt)

  for (i in seq_along(split_dt)) {
    current_data <- split_dt[[i]]

    # Work on a data.table regardless of the input class
    if (!data.table::is.data.table(current_data)) {
      current_data <- data.table::as.data.table(current_data)
    }

    cv_args <- c(
      list(
        data = current_data,
        v = v,
        repeats = repeats,
        breaks = breaks,
        pool = pool
      ),
      extra_args
    )

    # Stratify only when the requested column exists in this dataset;
    # otherwise warn and fall back to unstratified folds.
    if (!is.null(strata)) {
      if (strata %in% names(current_data)) {
        cv_args$strata <- strata
      } else {
        warning(sprintf("Strata variable '%s' not found in dataset %s, performing unstratified CV",
                        strata, dt_labels[i]))
      }
    }

    cv_obj <- do.call(rsample::vfold_cv, cv_args)

    # `splits` becomes a list column holding the split objects
    result_dt <- data.table::data.table(
      splits = cv_obj$splits
    )

    # With a single repeat only `id` (fold) is added; with several repeats
    # `id` is the repeat and `id2` the fold.
    if (repeats == 1) {
      result_dt[, id := cv_obj$id]
    } else {
      result_dt[, `:=`(
        id = cv_obj$id,
        id2 = cv_obj$id2
      )]
    }

    # Materialise the training and validation subsets of every split
    result_dt[, `:=`(
      train = lapply(splits, function(x) rsample::training(x)),
      validate = lapply(splits, function(x) rsample::testing(x))
    )]

    result[[i]] <- result_dt
  }

  result
}

Try the mintyr package in your browser

Any scripts or data that you put into this service are public.

mintyr documentation built on April 4, 2025, 2:56 a.m.