# The 'cv_index' function
# Written by Kevin Potter
# email: kevin.w.potter@gmail.com
# Please email me directly if you
# have any questions or comments
# Last updated 2018-10-31
# Table of contents
# 1) cv_index
###
### 1) cv_index
###
#' Create Index for Cross-Validation
#'
#' Creates an index with randomized order, useful for splitting data
#' into training and test subsets for cross-validation.
#'
#' @param divisions The number of partitions by which to split the data.
#'   If less than 1, instead denotes the proportion of data to withhold
#'   for the test sample. Must be a single non-negative number.
#' @param n_obs The number of observations in the data. Ignored when
#'   building the index if \code{x} is provided (the length of \code{x}
#'   is used instead), but still used to check whether there are more
#'   folds than observations.
#' @param rng_seed An optional integer (or, if \code{x} is provided,
#'   one integer per distinct value of \code{x}, e.g., a pair of
#'   integers for a binary variable) to set the random number generator
#'   (RNG) state to ensure reproducibility. If fewer seeds than distinct
#'   values of \code{x} are given, the missing seeds are created by
#'   incrementing the last supplied seed by 1 (with a warning).
#'   Note that supplying a seed calls \code{set.seed} and therefore
#'   changes the global RNG state.
#' @param x An optional discrete variable of length \code{n_obs},
#'   typically a binary variable, with no missing values. If provided,
#'   the function will attempt to balance the distinct values of the
#'   variable for more even representation across folds.
#'
#' @return An index from 1 to \code{divisions} indicating
#'   group membership for each fold. Indices are
#'   distributed so that each fold has approximately
#'   equal numbers of observations, unless \code{divisions}
#'   is less than 1. In this case, the index \code{1} represents
#'   observations assigned to the training sample and the index
#'   \code{2} represents observations assigned to the test sample.
#'
#' @export
#' @examples
#' # 3 folds for 90 observations
#' index = cv_index( 3, 90 )
#' table( index )
#' # Create unbalanced data
#' x = c( rep( 1, 10 ), rep( 0, 90 ) )
#' index = cv_index( 10, length(x), x = x )
#' table( index[ x == 1 ] )
#' table( index[ x == 0 ] )
#' # Balanced data
#' x = rep( c( 0, 1 ), each = 50 )
#' index = cv_index( 5, length(x), x = x )
#' table( x, index )
#' # Withhold 30% for test sample
#' index = cv_index( .3, 100 )
#' table( index )
#' # Create unbalanced data
#' x = c( rep( 1, 10 ), rep( 0, 90 ) )
#' index = cv_index( .3, length(x), x = x )
#' table( index[ x == 1 ] )
#' table( index[ x == 0 ] )
cv_index = function( divisions, n_obs, rng_seed = NULL, x = NULL ) {

  # Fail early with a clear message instead of an obscure error
  # from 'rep' or 'if' further down
  if ( !is.numeric( divisions ) || length( divisions ) != 1 ||
       is.na( divisions ) || divisions < 0 ) {
    stop( "'divisions' must be a single non-negative number",
          call. = FALSE )
  }

  # A value below 1 is a proportion to withhold for the test sample
  # rather than a number of folds
  hold_out = divisions < 1

  # Builds the (unshuffled) labels for a group of 'n' observations
  make_labels = function( n ) {

    if ( hold_out ) {
      # 1 = training sample, 2 = test sample
      lbl = rep( 1, round( n * ( 1 - divisions ) ) )
      lbl = c( lbl, rep( 2, n - length( lbl ) ) )
    } else {
      # As many complete sets of folds as will fit...
      lbl = rep( 1:divisions, floor( n / divisions ) )
      # ...then assign left-over observations to the first folds
      n_extra = n - length( lbl )
      if ( n_extra > 0 ) {
        lbl = c( lbl, seq_len( n_extra ) )
      }
    }

    return( lbl )
  }

  # Randomly permutes 'v', optionally after setting the RNG seed.
  # Indexing with 'sample.int' draws the same permutation as
  # 'sample( v )' but avoids the special case in which a length-one
  # numeric 'v' (e.g., 2) is expanded to a permutation of 1:v
  shuffle = function( v, seed ) {

    if ( !is.null( seed ) ) {
      set.seed( seed )
    }

    return( v[ sample.int( length( v ) ) ] )
  }

  if ( !hold_out ) {
    if ( divisions > n_obs ) {
      warning( 'More folds than observations, set to number of observations',
               call. = FALSE )
    }
  }

  # Simple case: no variable to balance over
  if ( is.null( x ) ) {
    return( shuffle( make_labels( n_obs ), rng_seed ) )
  }

  # Balanced case: labels are created and shuffled separately
  # within each distinct value of 'x'
  if ( anyNA( x ) ) {
    stop( "'x' cannot contain missing values", call. = FALSE )
  }

  # Frequencies for each distinct value
  counts = table( x )
  lev = names( counts )
  n_lev = length( lev )

  # Initialize the index
  index = rep( 0, length( x ) )
  if ( n_lev == 0 ) {
    return( index )
  }

  # Each value needs at least one observation per fold (or one each
  # for the training and test samples) to be fully balanced
  min_needed = if ( hold_out ) 2 else divisions
  if ( min_needed > min( counts ) ) {
    warning( 'More folds than counts, cannot fully balance',
             call. = FALSE )
  }

  # One seed is needed per distinct value; extend by incrementing
  # the last supplied seed if too few were provided
  seeds = rng_seed
  if ( !is.null( seeds ) && length( seeds ) < n_lev ) {
    n_short = n_lev - length( seeds )
    seeds = c( seeds, seeds[ length( seeds ) ] + seq_len( n_short ) )
    if ( n_lev == 2 ) {
      warning( 'Adjusted RNG seed by 1 for second instance',
               call. = FALSE )
    } else {
      warning( 'Adjusted RNG seed by 1 for each additional instance',
               call. = FALSE )
    }
  }

  # Loop over discrete categories
  for ( i in seq_len( n_lev ) ) {
    sel = x == lev[i]
    index[sel] = shuffle( make_labels( sum( sel ) ), seeds[i] )
  }

  return( index )
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.