trainSplitPermute: trainSplitPermute

trainSplitPermute

Description

Find a desired train/validation/test split of a dataset through random permutation. The split is made on a grouping variable in your dataset (for example, the recording site or the month of data), so all rows sharing a value of that variable land in the same split. The function then looks for the split whose distribution of a set of labels most closely matches your desired proportions. A good split can be hard to find when the distribution of your labels is not consistent across sites, so this function tries many random splits and uses a score to select the best one.
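The search described above can be sketched in base R. This is a minimal illustration under stated assumptions, not the PAMmisc implementation: the helper `sketchSplitSearch` is hypothetical, it counts labels per row (ignoring `countCol` and `minCount`), and its score (the summed absolute difference between the achieved and desired proportion of each label in each split) is an assumed stand-in for the package's own score.

```r
# Minimal sketch of a grouped, permutation-based train/val/test search.
# ASSUMPTION: hypothetical helper, not the PAMmisc implementation. Labels are
# counted per row and the score below is an illustrative stand-in.
sketchSplitSearch <- function(x, probs = c(0.7, 0.15, 0.15), n = 100,
                              splitBy = "site", label = "species", top = 3) {
    groupOf <- as.character(x[[splitBy]])
    groups <- unique(groupOf)
    labs <- factor(x[[label]])
    results <- vector("list", n)
    for (i in seq_len(n)) {
        # randomly assign every group (e.g. every site) to split 1, 2 or 3
        splitMap <- sample(1:3, length(groups), replace = TRUE, prob = probs)
        names(splitMap) <- groups
        # every row inherits the split assigned to its group
        splitVec <- splitMap[groupOf]
        # counts of each label in each split; keep empty splits as columns
        tab <- table(labs, factor(splitVec, levels = 1:3))
        # fraction of each label that landed in each split (rows sum to 1)
        props <- prop.table(tab, margin = 1)
        # assumed score: total absolute deviation from the desired probs
        score <- sum(abs(sweep(props, 2, probs)))
        results[[i]] <- list(splitMap = splitMap, splitVec = splitVec,
                             distribution = tab, score = score)
    }
    scores <- vapply(results, function(r) r$score, numeric(1))
    # lowest score first, keep only the `top` best splits
    results[order(scores)][seq_len(min(top, n))]
}

# usage with dummy data similar to the Examples section
set.seed(112188)
df <- data.frame(species = sample(letters[1:5], 1e3, replace = TRUE),
                 site = sample(LETTERS[1:12], 1e3, replace = TRUE))
best <- sketchSplitSearch(df, n = 200)
best[[1]]$score
```

Assigning whole groups rather than individual rows is what keeps every row from one site in the same split; the cost is that the achievable proportions are limited by how the groups happen to combine, which is why many random draws are compared.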

Usage

trainSplitPermute(
  x,
  probs = c(0.7, 0.15, 0.15),
  n = 1000,
  splitBy = "drift",
  label = "species",
  countCol = NULL,
  minCount = c(1, 1, 1),
  top = 3,
  seed = 112188
)

Arguments

x

a dataframe of data you want to find splits for

probs

a vector of 3 values that sum to one, defining the proportion of data that should be in your training, validation, and test sets (respectively)
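Because `probs` must contain three proportions that sum to one, a quick check before a long search can catch typos. The helper `checkProbs` below is hypothetical (not part of PAMmisc); it uses `all.equal()` so that floating-point error in sums such as 0.7 + 0.15 + 0.15 is tolerated.

```r
# Hypothetical helper (not part of PAMmisc): validate a train/val/test probs vector
checkProbs <- function(probs) {
    # exactly three non-negative numeric proportions
    if (!is.numeric(probs) || length(probs) != 3 || any(probs < 0)) {
        return(FALSE)
    }
    # all.equal() tolerates floating-point rounding in the sum
    isTRUE(all.equal(sum(probs), 1))
}

checkProbs(c(0.7, 0.15, 0.15))  # TRUE
checkProbs(c(0.7, 0.2, 0.2))    # FALSE, sums to 1.1
# ideal number of rows per split for a 1000-row dataset
round(1000 * c(0.7, 0.15, 0.15))
```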

n

number of random splits to try. This can be smaller if your labels are fairly evenly distributed, but it needs to be larger for more uneven distributions

splitBy

name of column containing the variable you want to split by

label

name of the column containing your dataset labels

countCol

the names of any additional columns in your dataset defining the quantities you want to count (see example for why this is useful)

minCount

minimum count for each split category; it is usually safe to leave this at the default of 1 for all splits

top

the number of results to return. Usually you want just the best-scoring result, but occasionally that split is distributed in an undesirable way by random chance (e.g. all sites in your validation data might be clustered together), so it can help to compare a few.
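When `top` is greater than 1, one way to compare the returned results is to list which levels of `splitBy` were assigned to a given split in each one, for example to spot validation sites that are all clustered together. The helper `groupsInSplit` and the mock results below are hypothetical; the mock objects only imitate the `$splitMap` element described under Value.

```r
# Hypothetical helper (not part of PAMmisc): for each result, return the
# levels of splitBy assigned to `split` (1 = train, 2 = val, 3 = test)
groupsInSplit <- function(results, split = 2) {
    lapply(results, function(r) names(r$splitMap)[r$splitMap == split])
}

# mock results shaped like the documented return value (not real output)
mockResults <- list(
    list(splitMap = c(A = 1L, B = 2L, C = 2L, D = 3L)),
    list(splitMap = c(A = 2L, B = 1L, C = 3L, D = 1L))
)
groupsInSplit(mockResults, split = 2)
```

Each list element can then be checked against site metadata (for example, coordinates) before choosing which result to use.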

seed

random seed to set for reproducibility

Value

a list of the top results; the first element is the best-scoring split. Each individual result contains $splitMap, the random split with each level of splitBy marked as integer 1, 2, or 3 (corresponding to train, val, and test); $splitVec, a vector marking each row of x with its category; $distribution, a table of the distribution of label in the split; and $score, the split score (lower is closer to the desired probs). $splitMap and $splitVec are both named by the levels of splitBy.
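The relationship between `$splitMap` and `$splitVec` can be illustrated with mock objects, assuming that each row simply inherits the integer code assigned to its level of `splitBy`. The values below are hypothetical, not output of `trainSplitPermute()`; the last step converts the integer codes into readable train/val/test labels.

```r
# mock splitMap: levels of splitBy mapped to 1 (train), 2 (val), 3 (test)
splitMap <- c(A = 1L, B = 1L, C = 2L, D = 3L)
# splitBy column of a toy dataset x
site <- c("A", "C", "B", "D", "A")
# each row takes the split of its level; names are the levels of splitBy
splitVec <- splitMap[site]
# convert integer codes to labels for easier reading
splitLabel <- factor(splitVec, levels = 1:3, labels = c("train", "val", "test"))
table(splitLabel)
```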

Author(s)

Taiki Sakai taiki.sakai@noaa.gov

Examples


library(PAMmisc)
# make some dummy data
df <- data.frame(
    species = sample(letters[1:5], prob=c(.4, .2, .1, .1, .2), 1e3, replace=TRUE),
    site = sample(LETTERS[1:12], 1e3, replace=TRUE),
    event = 1:1e3
)
# try a split with n=3
split <- trainSplitPermute(df, probs=c(.7, .15, .15), n=3, label='species', splitBy='site')
# assign the best split as the split category
df$split <- split[[1]]$splitVec
# distribution is not close to our desired .7, .15, .15 split because n is too low
round(table(df$species, df$split) /
    matrix(rep(table(df$species), 3), nrow=5), 2)

# rerun with higher n to get closer to desired distribution
split <- trainSplitPermute(df, probs=c(.7, .15, .15), n=1e3, label='species', splitBy='site')
df$split <- split[[1]]$splitVec
round(table(df$species, df$split) /
    matrix(rep(table(df$species), 3), nrow=5), 2)
    
# adding a new site that has significantly more detections than others
addSite <- data.frame(
    species = sample(letters[1:5], 500, replace=TRUE),
    site = rep(LETTERS[13], 500),
    event = 1001:1500)
df$split <- NULL
df <- rbind(df, addSite)

# now splitting by site alone does not give a balanced split of the species:
# the sites are split to approximately .7, .15, .15, but the species are not
split <- trainSplitPermute(df, probs=c(.7, .15, .15), n=1e3, label='species', splitBy='site')
df$split <- split[[1]]$splitVec
round(table(df$species, df$split) /
    matrix(rep(table(df$species), 3), nrow=5), 2)

# adding 'event' as a countCol fixes this
split <- trainSplitPermute(df, probs=c(.7, .15, .15), n=1e3, label='species', 
    splitBy='site', countCol='event')
df$split <- split[[1]]$splitVec
round(table(df$species, df$split) /
    matrix(rep(table(df$species), 3), nrow=5), 2)
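The proportion tables above can also be computed with base R's `prop.table()`, which avoids hard-coding the number of species (`nrow=5`) and splits. The helper `splitProportions` is hypothetical and not part of PAMmisc; wrapping the split in `factor(..., levels = 1:3)` keeps a column for a split that happens to receive no rows.

```r
# Hypothetical helper (not part of PAMmisc): fraction of each label that
# falls in each split, the same quantity as the table / matrix(rep(...))
# expressions above but without hard-coding the number of labels
splitProportions <- function(labels, split, digits = 2) {
    # factor() keeps all three splits as columns even if one is empty
    tab <- table(labels, factor(split, levels = 1:3))
    round(prop.table(tab, margin = 1), digits)
}

# toy inputs: label "a" is 3/4 train, 1/4 val; label "b" is half train, half test
labels <- c("a", "a", "a", "a", "b", "b")
split  <- c(1, 1, 1, 2, 1, 3)
splitProportions(labels, split)
```

With the data from the examples above this would be called as `splitProportions(df$species, df$split)`.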



TaikiSan21/PAMmisc documentation built on April 27, 2024, 2:04 p.m.