Nothing
# WARNING - Generated by {fusen} from dev/flat_teaching.Rmd: do not edit by hand
#' Cross-Validation Split Generator
#'
#' @description
#' A robust cross-validation splitting utility for multiple datasets with
#' optional stratification. Each dataset in `split_dt` is split independently
#' with [rsample::vfold_cv()], and the resulting training and validation sets
#' are materialised alongside the split objects.
#'
#' @param split_dt `list` of input datasets
#' - Must contain `data.frame` or `data.table` elements
#' - Supports multiple dataset processing
#' - Cannot be empty, and no dataset may have zero rows
#' @param strata `NULL` (default) for unstratified sampling, or a single
#'   column name given as a character string. If the column is missing from
#'   a particular dataset, that dataset is split without stratification and a
#'   warning naming the dataset (or its position, if unnamed) is emitted.
#' @inheritParams rsample::vfold_cv
#'
#' @details
#' Advanced Cross-Validation Mechanism:
#' \enumerate{
#' \item Input dataset validation
#' \item Stratified or unstratified sampling
#' \item Flexible fold generation
#' \item Train-validate set creation
#' }
#'
#' Sampling Strategies:
#' \itemize{
#' \item Supports multiple dataset processing
#' \item Handles stratified and unstratified sampling
#' \item Generates reproducible cross-validation splits (set a seed first)
#' }
#'
#' @return A `list` with the same names as `split_dt`, holding one
#'   `data.table` per input dataset. Each `data.table` contains:
#' \itemize{
#' \item `splits`: Cross-validation split objects
#' \item `id`: Fold labels when `repeats == 1`, otherwise repeat labels
#' \item `id2`: Fold labels (only present when `repeats > 1`)
#' \item `train`: Training dataset subsets
#' \item `validate`: Validation dataset subsets
#' }
#'
#' @note Important Constraints:
#' \itemize{
#' \item Requires non-empty input datasets
#' \item All datasets must be `data.frame` or `data.table`
#' \item `strata`, if given, must be a single column name
#' \item Computational resources impact large dataset processing
#' }
#'
#' @seealso
#' \itemize{
#' \item [`rsample::vfold_cv()`] Core cross-validation function
#' }
#'
#' @import data.table
#' @importFrom rsample vfold_cv
#' @export
#' @examples
#' # Prepare example data: Convert first 3 columns of iris dataset to long format and split
#' dt_split <- w2l_split(data = iris, cols2l = 1:3)
#' # dt_split is now a list containing 3 data tables for Sepal.Length, Sepal.Width, and Petal.Length
#'
#' # Example 1: Single cross-validation (no repeats)
#' split_cv(
#' split_dt = dt_split, # Input list of split data
#' v = 3, # Set 3-fold cross-validation
#' repeats = 1 # Perform cross-validation once (no repeats)
#' )
#' # Returns a list where each element contains:
#' # - splits: rsample split objects
#' # - id: fold numbers (Fold1, Fold2, Fold3)
#' # - train: training set data
#' # - validate: validation set data
#'
#' # Example 2: Repeated cross-validation
#' split_cv(
#' split_dt = dt_split, # Input list of split data
#' v = 3, # Set 3-fold cross-validation
#' repeats = 2 # Perform cross-validation twice
#' )
#' # Returns a list where each element contains:
#' # - splits: rsample split objects
#' # - id: repeat numbers (Repeat1, Repeat2)
#' # - id2: fold numbers (Fold1, Fold2, Fold3)
#' # - train: training set data
#' # - validate: validation set data
split_cv <- function(split_dt, v = 10, repeats = 1, strata = NULL, breaks = 4, pool = 0.1, ...) {
  # Declared only to silence R CMD check notes about data.table NSE columns
  id <- splits <- NULL

  # Input validation ----
  if (!is.list(split_dt)) {
    stop("split_dt must be a list", call. = FALSE)
  }
  if (length(split_dt) == 0) {
    stop("The input split_dt cannot be empty", call. = FALSE)
  }
  # vapply (not sapply) guarantees a logical vector regardless of input
  is_valid <- all(vapply(split_dt, function(x) {
    inherits(x, c("data.frame", "data.table"))
  }, logical(1)))
  if (!is_valid) {
    stop("All elements in split_dt must be data.frames or data.tables", call. = FALSE)
  }
  # A non-scalar strata would otherwise reach `if (strata %in% ...)` below and
  # fail with an opaque "condition has length > 1" error.
  if (!is.null(strata) &&
      !(is.character(strata) && length(strata) == 1L && !is.na(strata))) {
    stop("strata must be NULL or a single column name (character string)",
         call. = FALSE)
  }

  # Labels used in messages. Unnamed or blank entries fall back to their
  # position; otherwise sprintf() with a NULL name yields character(0) and
  # the warning text would be empty.
  labels <- names(split_dt)
  if (is.null(labels)) {
    labels <- character(length(split_dt))
  }
  unlabelled <- is.na(labels) | !nzchar(labels)
  labels[unlabelled] <- as.character(which(unlabelled))

  # Enforce the documented "non-empty datasets" constraint up front so the
  # error names the offending dataset(s) instead of failing inside rsample.
  is_empty <- vapply(split_dt, function(x) nrow(x) == 0L, logical(1))
  if (any(is_empty)) {
    stop(sprintf("Dataset(s) %s in split_dt have no rows",
                 paste(labels[is_empty], collapse = ", ")),
         call. = FALSE)
  }

  # Split each dataset ----
  result <- vector("list", length(split_dt))
  names(result) <- names(split_dt)

  for (i in seq_along(split_dt)) {
    current_data <- split_dt[[i]]
    if (!data.table::is.data.table(current_data)) {
      current_data <- data.table::as.data.table(current_data)
    }

    cv_args <- list(
      data = current_data,
      v = v,
      repeats = repeats,
      breaks = breaks,
      pool = pool,
      ...
    )
    # Stratify only when the column exists in this particular dataset
    if (!is.null(strata)) {
      if (strata %in% names(current_data)) {
        cv_args$strata <- strata
      } else {
        warning(sprintf(
          "Strata variable '%s' not found in dataset %s, performing unstratified CV",
          strata, labels[i]
        ), call. = FALSE)
      }
    }

    cv_obj <- do.call(rsample::vfold_cv, cv_args)

    result_dt <- data.table::data.table(
      splits = cv_obj$splits
    )
    # vfold_cv only produces id2 when repeats > 1
    if (repeats == 1) {
      result_dt[, id := cv_obj$id] # fold column
    } else {
      result_dt[, `:=`(
        id = cv_obj$id,   # repeat column
        id2 = cv_obj$id2  # fold column
      )]
    }
    # Materialise the analysis/assessment subsets of each split
    result_dt[, `:=`(
      train = lapply(splits, rsample::training),
      validate = lapply(splits, rsample::testing)
    )]
    result[[i]] <- result_dt
  }
  result
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.