fold: Create balanced folds for cross-validation

Description

Lifecycle: stable

Divides data into groups by a wide range of methods. Balances a given categorical variable and/or numerical variable between folds and keeps (if possible) all data points with a shared ID (e.g. participant_id) in the same fold. Can create multiple unique fold columns for repeated cross-validation.

Usage

fold(
  data,
  k = 5,
  cat_col = NULL,
  num_col = NULL,
  id_col = NULL,
  method = "n_dist",
  id_aggregation_fn = sum,
  extreme_pairing_levels = 1,
  num_fold_cols = 1,
  unique_fold_cols_only = TRUE,
  max_iters = 5,
  use_of_triplets = "fill",
  handle_existing_fold_cols = "keep_warn",
  parallel = FALSE
)

Arguments

data

data.frame. Can be grouped, in which case the function is applied group-wise.

k

Depends on `method`.

Number of folds (default), fold size, or step size (see `method`).

When `num_fold_cols` > 1, `k` can also be a vector with one `k` per fold column. This allows trying multiple `k` settings at a time. Note that the generated fold columns are not guaranteed to be in the order of `k`.

Given as a whole number or a percentage (0 < `k` < 1).

cat_col

Name of categorical variable to balance between folds.

E.g. when predicting a binary variable (a or b), we usually want both classes represented in every fold.

N.B. If also passing an `id_col`, `cat_col` should be constant within each ID.

num_col

Name of numerical variable to balance between folds.

N.B. When used with `id_col`, values for each ID are aggregated using `id_aggregation_fn` before being balanced.

N.B. When passing `num_col`, the `method` parameter is ignored.

id_col

Name of factor with IDs. This will be used to keep all rows that share an ID in the same fold (if possible).

E.g. if we have measured a participant multiple times and want to see the effect of time, we want to have all observations of this participant in the same fold.

N.B. When `data` is a grouped data.frame (see dplyr::group_by()), IDs that appear in multiple groupings might end up in different folds in those groupings.

method

"n_dist", "n_fill", "n_last", "n_rand", "greedy", or "staircase".

Note: the examples show the sizes of the generated groups for a vector with 57 elements.

n_dist (default)

Divides the data into a specified number of groups and distributes excess data points across groups (e.g. 11, 11, 12, 11, 12).

`k` is the number of groups

n_fill

Divides the data into a specified number of groups and fills up groups with excess data points from the beginning (e.g. 12, 12, 11, 11, 11).

`k` is the number of groups

n_last

Divides the data into a specified number of groups. It finds the most equal group sizes possible, using all data points. Only the last group is able to differ in size (e.g. 11, 11, 11, 11, 13).

`k` is the number of groups

n_rand

Divides the data into a specified number of groups. Excess data points are placed randomly in groups (only 1 per group) (e.g. 12, 11, 11, 11, 12).

`k` is the number of groups

greedy

Divides up the data greedily given a specified group size (e.g. 10, 10, 10, 10, 10, 7).

`k` is the group size

staircase

Uses a step size to divide up the data. The group size increases by one step for every group until there is no more data, so the last group holds whatever remains (e.g. 5, 10, 15, 20, 7).

`k` is the step size
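The group sizes in the method examples above follow from simple integer arithmetic. As a minimal base R sketch, the hypothetical helpers below (not part of groupdata2) reproduce the sizes listed for a 57-element vector. They only compute sizes; fold() additionally decides which rows go into which group. The rule used for `n_last` (equal groups of floor(n / k), with the remainder in the last group) is an assumption that matches the example but may differ from the package.

```r
# Hypothetical helpers: group sizes only, not the groupdata2 implementation.
sizes_n_dist <- function(n, k) {
  # Evenly spaced cut points, rounded down; excess spread across the groups
  diff(floor((0:k) * n / k))
}
sizes_n_fill <- function(n, k) {
  # The first (n %% k) groups each receive one excess data point
  (n %/% k) + (seq_len(k) <= n %% k)
}
sizes_n_last <- function(n, k) {
  # Assumption: k - 1 groups of floor(n / k); the last group takes the rest
  base <- n %/% k
  c(rep(base, k - 1), n - base * (k - 1))
}
sizes_n_rand <- function(n, k) {
  # Excess data points go to randomly chosen groups, at most one per group
  extra <- sample(k, n %% k)
  (n %/% k) + (seq_len(k) %in% extra)
}
sizes_greedy <- function(n, size) {
  # Full groups of `size`, then one smaller group with what is left
  c(rep(size, n %/% size), if (n %% size > 0) n %% size)
}
sizes_staircase <- function(n, step) {
  sizes <- numeric(0)
  next_size <- step
  while (n > 0) {
    take <- min(next_size, n)   # the last group gets the remainder
    sizes <- c(sizes, take)
    n <- n - take
    next_size <- next_size + step
  }
  sizes
}

sizes_n_dist(57, 5)     # 11 11 12 11 12
sizes_n_fill(57, 5)     # 12 12 11 11 11
sizes_n_last(57, 5)     # 11 11 11 11 13
sizes_n_rand(57, 5)     # random, e.g. 12 11 11 11 12
sizes_greedy(57, 10)    # 10 10 10 10 10 7
sizes_staircase(57, 5)  # 5 10 15 20 7
```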

id_aggregation_fn

Function for aggregating values in `num_col` for each ID, before balancing `num_col`.

N.B. Only used when `num_col` and `id_col` are both specified.

extreme_pairing_levels

How many levels of extreme pairing to do when balancing folds by a numerical column (i.e. `num_col` is specified).

Extreme pairing: Rows/pairs are ordered as smallest, largest, second smallest, second largest, etc. If `extreme_pairing_levels` > 1, this is done "recursively" on the extreme pairs. See Details >> num_col for more.

N.B. Larger values work best with large datasets. If set too high, the result might not be stochastic. Always check if an increase actually makes the folds more balanced. See example.

num_fold_cols

Number of fold columns to create. Useful for repeated cross-validation.

If num_fold_cols > 1, columns will be named ".folds_1", ".folds_2", etc. Otherwise simply ".folds".

N.B. If `unique_fold_cols_only` is TRUE, we can end up with fewer columns than specified, see `max_iters`.

N.B. If `data` has existing fold columns, see `handle_existing_fold_cols`.

unique_fold_cols_only

Check if fold columns are identical and keep only unique columns.

As the number of column comparisons can be time consuming, we can run this part in parallel. See `parallel`.

N.B. We can end up with fewer columns than specified in `num_fold_cols`, see `max_iters`.

N.B. Only used when `num_fold_cols` > 1 or `data` has existing fold columns.

max_iters

Maximum number of attempts at reaching `num_fold_cols` unique fold columns.

When only keeping unique fold columns, we risk having fewer columns than expected. Hence, we repeatedly create the missing columns and remove those that are not unique. This is done until we have `num_fold_cols` unique fold columns or we have attempted `max_iters` times.

In some cases, it is not possible to create `num_fold_cols` unique combinations of the dataset, e.g. when specifying `cat_col`, `id_col` and `num_col`. `max_iters` specifies when to stop trying. Note that we can end up with fewer columns than specified in `num_fold_cols`.

N.B. Only used when `num_fold_cols` > 1.
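To make the uniqueness check and the retry loop concrete, here is a minimal base R sketch. The names `same_partition()` and `make_unique_fold_cols()` are hypothetical, and `make_col` stands in for whatever creates a single fold column. The sketch treats two fold columns as identical when their fold labels correspond one-to-one (the same partition under different labels); this definition is an assumption, not taken from groupdata2.

```r
# Hypothetical sketch, not the groupdata2 code.
# Two fold columns are duplicates if their fold labels map one-to-one.
same_partition <- function(a, b) {
  n_pairs <- nrow(unique(data.frame(a = a, b = b)))
  n_pairs == length(unique(a)) && n_pairs == length(unique(b))
}

# Create the missing fold columns and drop duplicates, until `num_fold_cols`
# unique columns exist or `max_iters` attempts have been made.
make_unique_fold_cols <- function(make_col, num_fold_cols, max_iters) {
  cols <- list()
  for (iter in seq_len(max_iters)) {
    n_missing <- num_fold_cols - length(cols)
    if (n_missing == 0) break
    for (j in seq_len(n_missing)) {
      candidate <- make_col()
      is_dup <- any(vapply(cols, same_partition, logical(1), b = candidate))
      if (!is_dup) cols[[length(cols) + 1]] <- candidate
    }
  }
  cols  # may hold fewer than num_fold_cols columns
}
```

For example, four rows split into two folds of two admit only three distinct partitions, so asking for five unique columns returns at most three, which mirrors the "fewer columns than specified" caveat above.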

use_of_triplets

"fill", "instead" or "never".

When to use extreme triplet grouping in numerical balancing (when `num_col` is specified).

fill (default)

When extreme pairing cannot create enough unique fold columns, use extreme triplet grouping to create additional unique fold columns.

instead

Use extreme triplet grouping instead of extreme pairing. For some datasets, grouping in triplets gives better balancing than grouping in pairs. This can be worth exploring when numerical balancing is important.

Tip: Compare the balances with summarize_balances() and ranked_balances().

never

Never use extreme triplet grouping.

Extreme triplet grouping

Similar to extreme pairing (see Details >> num_col), extreme triplet grouping orders the rows as smallest, closest to the median, largest, second smallest, second closest to the median, second largest, etc. Each triplet gets a group identifier and we either perform recursive extreme triplet grouping on the identifiers or fold the identifiers and transfer the fold IDs to the original rows.

For some datasets, this can give more balanced groups than extreme pairing, but on average, extreme pairing works better. Because triplets and pairs tend to produce different groupings, extreme triplet grouping is useful when creating many fold columns: if extreme pairing cannot create enough unique fold columns, the remaining ones (or at least some additional ones) can be created with extreme triplet grouping.

Extreme triplet grouping is implemented in rearrr::triplet_extremes().
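A minimal base R sketch of one level of extreme triplet grouping, assuming the number of values is a multiple of 3. The helper name `extreme_triplet_ids()` is hypothetical, and matching the middle third (ordered by distance to the median) against the outer thirds is our reading of the description above; rearrr::triplet_extremes() is the actual implementation and also handles other lengths.

```r
# Hypothetical sketch; assumes length(x) is a multiple of 3.
extreme_triplet_ids <- function(x) {
  n <- length(x)
  stopifnot(n %% 3 == 0)
  m <- n / 3
  ord <- order(x)
  low <- ord[seq_len(m)]                       # smallest third, ascending
  high <- rev(ord[(2 * m + 1):n])              # largest third, descending
  mid <- ord[(m + 1):(2 * m)]                  # middle third ...
  mid <- mid[order(abs(x[mid] - median(x)))]   # ... closest to the median first
  ids <- integer(n)
  ids[low] <- seq_len(m)   # triplet i: i-th smallest,
  ids[mid] <- seq_len(m)   #            i-th closest to the median,
  ids[high] <- seq_len(m)  #            i-th largest
  ids
}

extreme_triplet_ids(1:9)  # triplets (1, 5, 9), (2, 4, 8), (3, 6, 7)
```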

handle_existing_fold_cols

How to handle existing fold columns. Either "keep_warn", "keep", or "remove".

To add extra fold columns, use "keep" or "keep_warn". Note that existing fold columns might be renamed.

To replace the existing fold columns, use "remove".

parallel

Whether to parallelize the fold column comparisons, when `unique_fold_cols_only` is TRUE.

Requires a registered parallel backend, e.g. via doParallel::registerDoParallel().

Details

cat_col

  1. `data` is subset by `cat_col`.

  2. Subsets are grouped and merged.
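As a rough illustration of these two steps, the hypothetical base R function below assigns folds within each category separately and returns them in the original row order, so every category is spread across all `k` folds. Dealing the rows out in near-equal numbers per fold is an assumption made for the sketch; fold() itself uses the chosen `method`.

```r
# Hypothetical sketch of cat_col balancing, not the groupdata2 code.
sketch_cat_fold <- function(category, k) {
  folds <- integer(length(category))
  for (lvl in unique(category)) {
    rows <- which(category == lvl)                 # 1. subset by category
    labels <- rep_len(seq_len(k), length(rows))    # near-equal fold sizes
    folds[rows] <- labels[sample(length(rows))]    # 2. group the subset randomly
  }
  folds                                            # merged, original row order
}
```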

id_col

  1. Groups are created from unique IDs.
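The same idea for IDs, again as a sketch with a hypothetical helper: the unique IDs are folded, and each row inherits the fold of its ID, so no ID is split across folds.

```r
# Hypothetical sketch of id_col grouping, not the groupdata2 code.
sketch_id_fold <- function(ids, k) {
  uid <- unique(ids)                           # one entry per ID
  labels <- rep_len(seq_len(k), length(uid))
  fold_of_id <- labels[sample(length(uid))]    # randomly fold the IDs
  fold_of_id[match(ids, uid)]                  # transfer to all rows of each ID
}
```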

num_col

  1. Rows are shuffled. Note that this will only affect rows with the same value in `num_col`.

  2. Extreme pairing 1: Rows are ordered as smallest, largest, second smallest, second largest, etc. Each pair gets a group identifier. (See rearrr::pair_extremes())

  3. If `extreme_pairing_levels` > 1: These group identifiers are reordered as smallest, largest, second smallest, second largest, etc., by the sum of `num_col` in the represented rows. These pairs (of pairs) get a new set of group identifiers, and the process is repeated `extreme_pairing_levels`-2 times. Note that the group identifiers at the last level will represent 2^`extreme_pairing_levels` rows, which is why you should be careful when choosing that setting.

  4. The group identifiers from the last pairing are folded (randomly divided into groups), and the fold identifiers are transferred to the original rows.

N.B. When doing extreme pairing of an odd number of rows, the row with the smallest value is placed in a group by itself, and the order is instead: smallest, second smallest, largest, third smallest, second largest, etc.

N.B. When `num_fold_cols` > 1 and fewer than `num_fold_cols` fold columns have been created after `max_iters` attempts, we try with extreme triplets instead (see rearrr::triplet_extremes()). It groups the elements as smallest, closest to the median, largest, second smallest, second closest to the median, second largest, etc. We can also choose to never/only use extreme triplets via `use_of_triplets`.
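The num_col steps above can be sketched in base R as follows. The helpers `extreme_pair_ids()` and `sketch_num_fold()` are hypothetical and only approximate the description: rearrr::pair_extremes() is the real pairing implementation, and dealing the shuffled final group identifiers out to folds round-robin is an assumption standing in for step 4.

```r
# Hypothetical sketch of the num_col steps, not the groupdata2/rearrr code.

# Extreme pairing: smallest + largest, second smallest + second largest, etc.
# With an odd count, the smallest value forms a group of its own (ID 1).
extreme_pair_ids <- function(x) {
  n <- length(x)
  perm <- sample(n)                 # step 1: shuffle; only changes the order of ties
  ord <- perm[order(x[perm])]       # row indices from smallest to largest value
  ids <- integer(n)
  offset <- 0L
  if (n %% 2 == 1) {
    ids[ord[1]] <- 1L
    ord <- ord[-1]
    offset <- 1L
  }
  m <- length(ord) %/% 2
  for (i in seq_len(m)) {
    ids[ord[i]] <- offset + i                    # i-th smallest remaining value
    ids[ord[length(ord) - i + 1]] <- offset + i  # i-th largest value
  }
  ids
}

sketch_num_fold <- function(x, k, levels = 1) {
  group <- extreme_pair_ids(x)                   # step 2: pairs of rows
  for (lvl in seq_len(levels - 1)) {             # step 3: pairs of pairs, ...
    sums <- tapply(x, group, sum)                # summed x per current group
    new_ids <- extreme_pair_ids(as.vector(sums))
    group <- new_ids[match(group, as.integer(names(sums)))]
  }
  # Step 4: shuffle the final group IDs, deal them out to k folds,
  # and transfer the fold IDs back to the rows.
  uid <- unique(group)
  uid <- uid[sample(length(uid))]
  fold_of_uid <- rep_len(seq_len(k), length(uid))
  fold_of_uid[match(group, uid)]
}

extreme_pair_ids(c(5, 1, 9, 3, 7, 2))  # pairs (1, 9), (2, 7), (3, 5)
```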

cat_col AND id_col

  1. `data` is subset by `cat_col`.

  2. Groups are created from unique IDs in each subset.

  3. Subsets are merged.
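A minimal sketch of these steps in base R, with a hypothetical helper name. It assumes, as noted under `cat_col`, that the category is constant within each ID, so each ID is folded exactly once.

```r
# Hypothetical sketch of cat_col + id_col, not the groupdata2 code.
sketch_cat_id_fold <- function(category, ids, k) {
  folds <- integer(length(ids))
  for (lvl in unique(category)) {
    rows <- which(category == lvl)                   # 1. subset by category
    uid <- unique(ids[rows])                         # 2. unique IDs in the subset
    labels <- rep_len(seq_len(k), length(uid))
    fold_of_id <- labels[sample(length(uid))]        #    randomly fold the IDs
    folds[rows] <- fold_of_id[match(ids[rows], uid)]
  }
  folds                                              # 3. subsets merged
}
```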

cat_col AND num_col

  1. `data` is subset by `cat_col`.

  2. Subsets are grouped by `num_col`.

  3. Subsets are merged such that the largest group (by sum of `num_col`) from the first category is merged with the smallest group from the second category, etc.
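Step 3 can be illustrated with made-up fold sums for two categories. The hypothetical helper below merges the fold with the largest sum in the first category with the fold with the smallest sum in the second, and so on. It only handles two categories.

```r
# Hypothetical sketch (two categories only), not the groupdata2 code.
merge_by_opposite_rank <- function(sums_a, sums_b) {
  stopifnot(length(sums_a) == length(sums_b))
  a_desc <- names(sort(sums_a, decreasing = TRUE))  # largest sum first
  b_asc <- names(sort(sums_b))                      # smallest sum first
  data.frame(
    fold = seq_along(a_desc),
    from_a = a_desc,
    from_b = b_asc,
    total = unname(sums_a[a_desc] + sums_b[b_asc]),
    stringsAsFactors = FALSE
  )
}

# Made-up sums of num_col per fold within each category
merge_by_opposite_rank(c(a1 = 30, a2 = 10, a3 = 20), c(b1 = 5, b2 = 25, b3 = 15))
# every merged fold totals 35
```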

num_col AND id_col

  1. Values in `num_col` are aggregated for each ID, using `id_aggregation_fn`.

  2. The IDs are grouped, using the aggregated values as "num_col".

  3. The groups of the IDs are transferred to the rows.
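A base R sketch of these three steps, with a hypothetical helper name and `sum` in the role of the default `id_aggregation_fn`. For brevity it assumes an even number of IDs and does a single level of extreme pairing on the aggregated values; the package also handles odd counts and additional levels.

```r
# Hypothetical sketch of num_col + id_col, not the groupdata2 code.
# Assumes an even number of IDs.
sketch_num_id_fold <- function(x, ids, k, agg_fn = sum) {
  ids <- as.character(ids)
  agg <- tapply(x, ids, agg_fn)                      # 1. one aggregated value per ID
  uid <- names(agg)
  n_ids <- length(uid)
  stopifnot(n_ids %% 2 == 0)
  m <- n_ids %/% 2
  ord <- order(as.vector(agg))                       # IDs from smallest to largest
  pair <- integer(n_ids)
  pair[ord[seq_len(m)]] <- seq_len(m)                # 2. i-th smallest ID ...
  pair[ord[n_ids - seq_len(m) + 1]] <- seq_len(m)    #    ... paired with i-th largest
  labels <- rep_len(seq_len(k), m)
  pair_fold <- labels[sample(m)]                     #    randomly fold the pairs
  id_fold <- pair_fold[pair]
  id_fold[match(ids, uid)]                           # 3. transfer folds to the rows
}
```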

cat_col AND num_col AND id_col

  1. Values in `num_col` are aggregated for each ID, using `id_aggregation_fn`.

  2. IDs are subset by `cat_col`.

  3. The IDs in each subset are grouped using the aggregated values as "num_col".

  4. The subsets are merged such that the largest group (by sum of the aggregated values) from the first category is merged with the smallest group from the second category, etc.

  5. The groups of the IDs are transferred to the rows.

Value

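data.frame with grouping factor(s) for subsetting in cross-validation (".folds", or ".folds_1", ".folds_2", etc. when multiple fold columns are created; see `num_fold_cols`).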

Author(s)

Ludvig Renbo Olsen, r-pkgs@ludvigolsen.dk

See Also

partition() for balanced partitions

Other grouping functions: all_groups_identical(), collapse_groups(), collapse_groups_by, group(), group_factor(), partition(), splt()

Examples

# Attach packages
library(groupdata2)
library(dplyr)

# Create data frame
df <- data.frame(
  "participant" = factor(rep(c("1", "2", "3", "4", "5", "6"), 3)),
  "age" = rep(sample(c(1:100), 6), 3),
  "diagnosis" = factor(rep(c("a", "b", "a", "a", "b", "b"), 3)),
  "score" = sample(c(1:100), 3 * 6)
)
df <- df %>% arrange(participant)
df$session <- rep(c("1", "2", "3"), 6)

# Using fold()

## Without balancing
df_folded <- fold(data = df, k = 3, method = "n_dist")

## With cat_col
df_folded <- fold(
  data = df,
  k = 3,
  cat_col = "diagnosis",
  method = "n_dist"
)

## With id_col
df_folded <- fold(
  data = df,
  k = 3,
  id_col = "participant",
  method = "n_dist"
)

## With num_col
# Note: 'method' would not be used in this case
df_folded <- fold(data = df, k = 3, num_col = "score")

## With cat_col and id_col
df_folded <- fold(
  data = df,
  k = 3,
  cat_col = "diagnosis",
  id_col = "participant",
  method = "n_dist"
)

## With cat_col, id_col and num_col
df_folded <- fold(
  data = df,
  k = 3,
  cat_col = "diagnosis",
  id_col = "participant",
  num_col = "score"
)

# Order by folds
df_folded <- df_folded %>% arrange(.folds)

## Multiple fold columns
# Useful for repeated cross-validation
# Note: Consider running in parallel
df_folded <- fold(
  data = df,
  k = 3,
  cat_col = "diagnosis",
  id_col = "participant",
  num_fold_cols = 5,
  unique_fold_cols_only = TRUE,
  max_iters = 4
)

# Different `k` per fold column
# Note: `length(k) == num_fold_cols`
df_folded <- fold(
  data = df,
  k = c(2, 3),
  cat_col = "diagnosis",
  id_col = "participant",
  num_fold_cols = 2,
  unique_fold_cols_only = TRUE,
  max_iters = 4
)

# Check the generated columns
# with `summarize_group_cols()`
summarize_group_cols(
  data = df_folded,
  group_cols = paste0('.folds_', 1:2)
)

## Check if additional `extreme_pairing_levels`
## improve the numerical balance
set.seed(2)  # Try with seed 1 as well
df_folded_1 <- fold(
  data = df,
  k = 3,
  num_col = "score",
  extreme_pairing_levels = 1
)
df_folded_1 %>%
  dplyr::ungroup() %>%
  summarize_balances(group_cols = '.folds', num_cols = 'score')

set.seed(2)  # Try with seed 1 as well
df_folded_2 <- fold(
  data = df,
  k = 3,
  num_col = "score",
  extreme_pairing_levels = 2
)
df_folded_2 %>%
  dplyr::ungroup() %>%
  summarize_balances(group_cols = '.folds', num_cols = 'score')

# We can directly compare how balanced the 'score' is
# in the two fold columns using a combination of
# `summarize_balances()` and `ranked_balances()`
# We see that the second fold column (made with `extreme_pairing_levels = 2`)
# has a lower standard deviation of its mean scores - meaning that they
# are more similar and thus more balanced
df_folded_1$.folds_2 <- df_folded_2$.folds
df_folded_1 %>%
  dplyr::ungroup() %>%
  summarize_balances(group_cols = c('.folds', '.folds_2'), num_cols = 'score') %>%
  ranked_balances()


LudvigOlsen/groupdata2 documentation built on March 7, 2024, 12:57 p.m.