R/nest_cv.R

# WARNING - Generated by {fusen} from dev/flat_teaching.Rmd: do not edit by hand

#' Apply Cross-Validation to Nested Data
#'
#' @description
#' `nest_cv()` creates cross-validation splits for each nested `data.frame` or `data.table` stored in a list column of `nest_dt`. It calls `rsample::vfold_cv()` on every nested dataset and returns one row per split, with the training and validation data extracted.
#'
#' @param nest_dt A `data.frame` or `data.table` containing at least one nested
#' `data.frame` or `data.table` column.
#'   - Supports multi-level nested structures
#'   - Cross-validation is applied to the first nested column found
#' @inheritParams rsample::vfold_cv
#'
#' @details
#' The function performs the following steps:
#' \enumerate{
#'   \item Checks if the input `nest_dt` is non-empty and contains at least one nested column of `data.frame`s or `data.table`s.
#'   \item Identifies the nested columns and non-nested columns within `nest_dt`.
#'   \item Applies `rsample::vfold_cv` to each nested data frame in the first nested column, creating the cross-validation splits. When no nested column is named `data`, a message reports which column is used.
#'   \item Expands the cross-validation splits and associates them with the non-nested columns.
#'   \item Extracts the training and validation data for each split and adds them to the output data table.
#' }
#'
#' If the `strata` parameter is provided, stratified sampling is performed during the cross-validation. Additional arguments can be passed to `rsample::vfold_cv` via `...`.
#' 
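#' @return A `data.table` with one row per cross-validation split of each nested dataset. It includes:
#' \itemize{
#'   \item Original non-nested columns from `nest_dt` (the nested columns themselves are dropped).
#'   \item `splits`: The cross-validation split objects returned by `rsample::vfold_cv`.
#'   \item `id` (and `id2` when `repeats > 1`): The fold and repeat identifiers returned by `rsample::vfold_cv`.
#'   \item `train`: The training data for each split, extracted with `rsample::training`.
#'   \item `validate`: The validation data for each split, extracted with `rsample::testing`.
#' }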
#'
#' @note
#' \itemize{
#'   \item The `nest_dt` must contain at least one nested column of `data.frame`s or `data.table`s.
#'   \item The function copies `nest_dt` and converts it to a `data.table` internally, so the input is not modified.
#'   \item If `strata` is specified, it must name a column that exists in every nested data frame.
#'   \item `breaks` sets the number of bins used to stratify a numeric `strata` variable; `pool` sets the proportion of data below which a small stratum is pooled into another.
#'   \item Additional arguments passed through `...` are forwarded to `rsample::vfold_cv`.
#' }
#'
#' @seealso
#' \itemize{
#'   \item [`rsample::vfold_cv()`] Underlying cross-validation function
#'   \item [`rsample::training()`] Extract training set
#'   \item [`rsample::testing()`] Extract test set
#' }
#'
#' @import data.table
#' @importFrom rsample vfold_cv training testing
#' @export
#'
#' @examples
#' # Example: Cross-validation on a nested data.table
#'
#' # Setup test data
#' dt_nest <- w2l_nest(
#'   data = iris,                   # Input dataset
#'   cols2l = 1:2                   # Nest first 2 columns
#' )
#'
#' # Example 1: Basic 2-fold cross-validation
#' nest_cv(
#'   nest_dt = dt_nest,             # Input nested data.table
#'   v = 2                          # Number of folds (2-fold CV)
#' )
#'
#' # Example 2: Repeated 2-fold cross-validation
#' nest_cv(
#'   nest_dt = dt_nest,             # Input nested data.table
#'   v = 2,                         # Number of folds (2-fold CV)
#'   repeats = 2                    # Number of repetitions
#' )
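#'
#' # Example 3: Build the nested input directly with data.table
#' # A minimal sketch that does not rely on w2l_nest(). The names
#' # `dt_manual` and `cv_manual` are illustrative only; the nested
#' # column is called `data`, so no column-selection message is expected.
#' dt_manual <- data.table::as.data.table(iris)[
#'   , .(data = list(.SD)), by = Species   # One nested data.table per species
#' ]
#' cv_manual <- nest_cv(
#'   nest_dt = dt_manual,           # Input nested data.table
#'   v = 5                          # Number of folds (5-fold CV)
#' )
#' names(cv_manual)                 # Non-nested columns plus split columns
#' nrow(cv_manual$train[[1]])       # Training rows in the first split
#' nrow(cv_manual$validate[[1]])    # Validation rows in the first split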
nest_cv <- function(nest_dt, v = 10, repeats = 1, strata = NULL, breaks = 4, pool = 0.1, ...) {
  # Initialize local variables to avoid global binding warnings
  cv_split <- data <- splits <- NULL
  
  # Validate input data is not empty
  if (nrow(nest_dt) == 0) {
    stop("Input 'nest_dt' cannot be empty")
  }
  
  # Identify nested data.frame or data.table columns
  nested_cols <- names(nest_dt)[sapply(nest_dt, function(x) {
    is.list(x) && all(sapply(x, function(y) {
      inherits(y, c("data.frame", "data.table"))
    }))
  })]
  
  # Ensure at least one nested column exists
  if (length(nested_cols) == 0) {
    stop("Input 'nest_dt' must contain at least one nested column of data.frames or data.tables")
  }
  
  # Check if "data" column exists in nested columns
  if (!"data" %in% nested_cols) {
    message("Available nested columns: ", paste(nested_cols, collapse = ", "))
    message("Using first nested column '", nested_cols[1], "' for cross-validation")
  }
  
  # Convert to a data.table and copy it so the original input is not modified
  # (a plain data.frame would otherwise fail on the `:=` calls below)
  dt <- data.table::copy(data.table::as.data.table(nest_dt))
  
  # Identify nested list columns
  is_nested_list <- sapply(dt, function(x) all(vapply(x, is.list, logical(1))))
  
  # Extract non-nested column names
  non_nested_cols <- names(dt)[!is_nested_list]
  
  # Apply cross-validation with flexible stratification
  dt[, cv_split := lapply(get(nested_cols[1]), function(x) {
    if (!is.null(strata)) {
      rsample::vfold_cv(
        data = x, 
        v = v, 
        repeats = repeats,
        strata = strata,
        breaks = breaks, 
        pool = pool, 
        ...
      )
    } else {
      rsample::vfold_cv(
        data = x, 
        v = v, 
        repeats = repeats,
        breaks = breaks, 
        pool = pool, 
        ...
      )
    }
  })
  ][, cv_split[[1]], by = non_nested_cols
  ][, ':='(
    train = lapply(splits, rsample::training),
    validate = lapply(splits, rsample::testing)
  )][]
}

mintyr documentation built on April 4, 2025, 2:56 a.m.