Nothing
# WARNING - Generated by {fusen} from dev/flat_teaching.Rmd: do not edit by hand
#' Apply Cross-Validation to Nested Data
#'
#' @description
#' The `nest_cv` function applies cross-validation splits to nested data frames or data tables within a data table. It uses the `rsample` package's `vfold_cv` function to create cross-validation splits for predictive modeling and analysis on nested datasets.
#'
#' @param nest_dt A `data.frame` or `data.table` containing at least one nested
#' `data.frame` or `data.table` column.
#' - Supports multi-level nested structures
#' - Requires at least one nested data column
#' @inheritParams rsample::vfold_cv
#'
#' @details
#' The function performs the following steps:
#' \enumerate{
#' \item Checks if the input `nest_dt` is non-empty and contains at least one nested column of `data.frame`s or `data.table`s.
#' \item Identifies the nested columns and non-nested columns within `nest_dt`.
#' \item Applies `rsample::vfold_cv` to each nested data frame in the first nested column, creating the cross-validation splits.
#' \item Expands the cross-validation splits and associates them with the non-nested columns.
#' \item Extracts the training and validation data for each split and adds them to the output data table.
#' }
#'
#' If the `strata` parameter is provided, stratified sampling is performed during the cross-validation. Additional arguments can be passed to `rsample::vfold_cv` via `...`.
#'
#' @return A `data.table` containing the cross-validation splits for each nested dataset. It includes:
#' \itemize{
#' \item Original non-nested columns from `nest_dt`.
#' \item `splits`: The cross-validation split objects returned by `rsample::vfold_cv`.
#' \item `train`: The training data for each split.
#' \item `validate`: The validation data for each split.
#' }
#'
#' @note
#' \itemize{
#' \item The `nest_dt` must contain at least one nested column of `data.frame`s or `data.table`s.
#' \item The function converts `nest_dt` to a `data.table` internally to ensure efficient data manipulation.
#' \item The `strata` parameter should be a column name present in the nested data frames.
#' \item If `strata` is specified, ensure that the specified column exists in all nested data frames.
#' \item The `breaks` and `pool` parameters are used when `strata` is a numeric variable and control how stratification is handled.
#' \item Additional arguments passed through `...` are forwarded to `rsample::vfold_cv`.
#' }
#'
#'
#' @seealso
#' \itemize{
#' \item [`rsample::vfold_cv()`] Underlying cross-validation function
#' \item [`rsample::training()`] Extract training set
#' \item [`rsample::testing()`] Extract test set
#' }
#'
#' @import data.table
#' @importFrom rsample vfold_cv training testing
#' @export
#'
#' @examples
#' # Example: Cross-validation for nested data.table demonstrations
#'
#' # Setup test data
#' dt_nest <- w2l_nest(
#' data = iris, # Input dataset
#' cols2l = 1:2 # Nest first 2 columns
#' )
#'
#' # Example 1: Basic 2-fold cross-validation
#' nest_cv(
#' nest_dt = dt_nest, # Input nested data.table
#' v = 2 # Number of folds (2-fold CV)
#' )
#'
#' # Example 2: Repeated 2-fold cross-validation
#' nest_cv(
#' nest_dt = dt_nest, # Input nested data.table
#' v = 2, # Number of folds (2-fold CV)
#' repeats = 2 # Number of repetitions
#' )
nest_cv <- function(nest_dt, v = 10, repeats = 1, strata = NULL, breaks = 4, pool = 0.1, ...) {
  # Bind the column names referenced through data.table's non-standard
  # evaluation so that R CMD check does not report undefined globals.
  cv_split <- splits <- NULL

  # Validate the input up front so callers get a clear message instead of an
  # obscure failure from nrow() or from data.table later on.
  if (!is.data.frame(nest_dt)) {
    stop("Input 'nest_dt' must be a data.frame or data.table")
  }
  if (nrow(nest_dt) == 0) {
    stop("Input 'nest_dt' cannot be empty")
  }

  # A nested column is a list column whose every element is a data.frame
  # (data.tables qualify as well). vapply() keeps the result type stable.
  is_df_list <- function(col) {
    is.list(col) &&
      all(vapply(col, inherits, logical(1), what = c("data.frame", "data.table")))
  }
  nested_cols <- names(nest_dt)[vapply(nest_dt, is_df_list, logical(1))]

  if (length(nested_cols) == 0) {
    stop("Input 'nest_dt' must contain at least one nested column of data.frames or data.tables")
  }

  # Cross-validation always runs on the first nested column; tell the user
  # whenever that column is not the conventional "data" column.
  target_col <- nested_cols[1]
  if (!identical(target_col, "data")) {
    message("Available nested columns: ", paste(nested_cols, collapse = ", "))
    message("Using first nested column '", target_col, "' for cross-validation")
  }

  # Work on a private data.table: copy() protects a data.table input from
  # by-reference modification, and as.data.table() makes plain data.frames
  # (and tibbles) usable with `:=`, which would otherwise fail.
  dt <- if (data.table::is.data.table(nest_dt)) {
    data.table::copy(nest_dt)
  } else {
    data.table::as.data.table(nest_dt)
  }

  # Columns whose elements are all lists are treated as nested and are left
  # out of the grouping keys; every other column is carried into the output.
  is_nested_list <- vapply(
    dt,
    function(col) all(vapply(col, is.list, logical(1))),
    logical(1)
  )
  non_nested_cols <- names(dt)[!is_nested_list]

  # Build the folds for one nested data set. `strata` is only forwarded when
  # supplied, so the non-stratified call leaves that argument missing.
  make_folds <- function(x) {
    if (!is.null(strata)) {
      rsample::vfold_cv(
        data = x,
        v = v,
        repeats = repeats,
        strata = strata,
        breaks = breaks,
        pool = pool,
        ...
      )
    } else {
      rsample::vfold_cv(
        data = x,
        v = v,
        repeats = repeats,
        breaks = breaks,
        pool = pool,
        ...
      )
    }
  }

  # Wrap the per-row results in list() so they are stored as ONE list column;
  # a bare length-1 list (single-row input) would be unwrapped on assignment.
  data.table::set(
    dt,
    j = "cv_split",
    value = list(lapply(dt[[target_col]], make_folds))
  )

  # Group by a temporary row id in addition to the non-nested columns so each
  # input row is expanded on its own. Without it, rows sharing the same
  # non-nested values (or inputs without non-nested columns) would silently
  # keep only the first row's folds.
  row_key <- make.unique(c(names(dt), ".nest_cv_row"))[length(names(dt)) + 1L]
  data.table::set(dt, j = row_key, value = seq_len(nrow(dt)))
  group_cols <- c(row_key, non_nested_cols)

  out <- dt[, cv_split[[1L]], by = group_cols]
  out[, (row_key) := NULL]

  # Materialise the training and validation sets for every split.
  out[, `:=`(
    train = lapply(splits, rsample::training),
    validate = lapply(splits, rsample::testing)
  )][]
}
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.