# R/cv_index.R

# The 'cv_index' function
# Written by Kevin Potter
# email: kevin.w.potter@gmail.com
# Please email me directly if you
# have any questions or comments
# Last updated 2018-10-31

# Table of contents
# 1) cv_index

###
### 1) cv_index
###

#' Create Index for Cross-Validation
#'
#' Creates an index with randomized order, useful for splitting data
#' into training and test subsets for cross-validation.
#'
#' @param divisions The number of partitions by which to split the data.
#'   If less than 1, instead denotes the proportion of data to withhold
#'   for the test sample.
#' @param n_obs The number of observations in the data. When \code{x}
#'   is provided the counts are taken from \code{x} instead.
#' @param rng_seed An optional integer to set the random number
#'   generator (RNG) state to ensure reproducibility. If \code{x} is
#'   provided, a vector with one seed per distinct value of \code{x}
#'   (in sorted order); when fewer seeds are given, the last seed is
#'   incremented by 1 for each remaining value and a warning is
#'   issued. Note that supplying a seed changes the global RNG state.
#' @param x An optional discrete variable of length \code{n_obs},
#'   typically a binary variable, without missing values. If provided,
#'   the function will attempt to balance the distinct values of the
#'   variable for more even representation across folds.
#'
#' @return An index from 1 to \code{divisions} indicating
#'   group membership for each fold. Indices are
#'   distributed so that each fold has approximately
#'   equal numbers of observations, unless \code{divisions}
#'   is less than 1. In this case, the index \code{1} represents
#'   observations assigned to the training sample and the index
#'   \code{2} represents observations assigned to the test sample.
#'
#' @export
#' @examples
#' # 3 folds for 90 observations
#' index = cv_index( 3, 90 )
#' table( index )
#' # Create unbalanced data
#' x = c( rep( 1, 10 ), rep( 0, 90 ) )
#' index = cv_index( 10, length(x), x = x )
#' table( index[ x == 1 ] )
#' table( index[ x == 0 ] )
#' # Withhold 30% for test sample
#' index = cv_index( .3, 100 )
#' table( index )
#' # Create unbalanced data
#' x = c( rep( 1, 10 ), rep( 0, 90 ) )
#' index = cv_index( .3, length(x), x = x )
#' table( index[ x == 1 ] )
#' table( index[ x == 0 ] )

cv_index = function( divisions, n_obs, rng_seed = NULL, x = NULL ) {

  # TRUE when 'divisions' is the proportion of observations
  # to withhold for the test sample rather than a fold count
  holdout = divisions < 1

  # Permute a vector by shuffling its positions. Calling
  # 'sample( values )' directly would draw from 1:values whenever
  # 'values' is a single number, returning the wrong labels.
  shuffle = function( values ) {
    values[ sample.int( length( values ) ) ]
  }

  # Ordered (not yet randomized) labels for 'n' observations
  make_labels = function( n ) {

    if ( holdout ) {
      # 1 = training sample, 2 = test sample
      n_train = round( n * ( 1 - divisions ) )
      return( rep( c( 1, 2 ), c( n_train, n - n_train ) ) )
    }

    # As many complete passes over the folds as will fit...
    labels = rep( 1:divisions, floor( n / divisions ) )

    # ...with leftover observations assigned to the first folds
    extra_n = n - length( labels )
    if ( extra_n > 0 ) {
      labels = c( labels, 1:extra_n )
    }

    return( labels )
  }

  if ( !holdout && divisions > n_obs ) {
    warning( 'More folds than observations, set to number of observations',
             call. = FALSE )
  }

  # No grouping variable, a single shuffle over all observations
  if ( is.null( x ) ) {

    index = make_labels( n_obs )

    # Randomize
    if ( !is.null( rng_seed ) ) {
      set.seed( rng_seed )
    }

    return( shuffle( index ) )
  }

  # Grouping variable provided, create balanced folds

  if ( anyNA( x ) ) {
    stop( "'x' cannot contain missing values", call. = FALSE )
  }

  # Frequencies per distinct value; factor codes identify group
  # membership without relying on string comparisons
  groups = factor( x )
  n_levels = nlevels( groups )
  counts = tabulate( groups, nbins = n_levels )
  codes = as.integer( groups )

  # Each distinct value needs at least one observation per fold
  # (or per training/test sample) to be fully balanced
  n_needed = if ( holdout ) 2 else divisions
  if ( n_levels > 0 && n_needed > min( counts ) ) {
    warning( 'More folds than counts, cannot fully balance',
             call. = FALSE )
  }

  # Check that one seed per distinct value was provided
  if ( !is.null( rng_seed ) && length( rng_seed ) < n_levels ) {

    n_given = length( rng_seed )
    rng_seed = c(
      rng_seed,
      rng_seed[ n_given ] + seq_len( n_levels - n_given )
    )

    seed_msg = 'Adjusted RNG seed by 1 for second instance'
    if ( n_levels > 2 ) {
      seed_msg = 'Adjusted RNG seed by 1 for each additional instance'
    }
    warning( seed_msg, call. = FALSE )
  }

  # Initialize the index
  index = rep( 0, length( x ) )

  # Loop over discrete categories, randomizing each separately
  for ( i in seq_len( n_levels ) ) {

    if ( !is.null( rng_seed ) ) {
      set.seed( rng_seed[i] )
    }

    index[ codes == i ] = shuffle( make_labels( counts[i] ) )

  }

  return( index )
}
# rettopnivek/binclass documentation built on May 13, 2019, 4:46 p.m.