View source: R/trainSplitPermute.R
trainSplitPermute (R Documentation)
Find a desired train/val/test split of a dataset through random permutation. The function randomly splits your data by a grouping variable in your dataset (for example, the location of different sites, or different months of data), then tries to find the split that most closely matches your desired distribution of data for a set of labels. A good split can be difficult to find when the distribution of your labels is not consistent across sites, so this function tries many random splits, scores each one, and returns the best-scoring ones.
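To make "a score for how closely a split matches the desired distribution" concrete, here is a minimal sketch of one possible score. The function name scoreSplitSketch and the squared-difference formula are assumptions for illustration only; this page does not say which score trainSplitPermute actually uses. For each label, it computes the fraction of that label's rows that land in train, val, and test, and sums the squared gaps from probs, so 0 means a perfect match.

```r
# Hypothetical scoring function, an illustration only: trainSplitPermute's
# actual score is not described on this page.
scoreSplitSketch <- function(labels, splitVec, probs = c(0.7, 0.15, 0.15)) {
  # keep all three categories (1 = train, 2 = val, 3 = test) as columns
  tab <- table(labels, factor(splitVec, levels = 1:3))
  # row proportions: the fraction of each label's rows in each category
  props <- prop.table(tab, margin = 1)
  # desired proportions repeated for every label (filled row by row)
  target <- matrix(probs, nrow = nrow(props), ncol = 3, byrow = TRUE)
  # sum of squared gaps; 0 means every label matches probs exactly
  sum((props - target)^2)
}

labels <- rep(c("a", "b"), each = 4)
scoreSplitSketch(labels, c(1, 1, 2, 3, 1, 1, 2, 3), probs = c(0.5, 0.25, 0.25))
# 0
scoreSplitSketch(labels, rep(1, 8), probs = c(0.5, 0.25, 0.25))
# 0.75
```

Because the score is computed per label, a split can match probs overall and still score poorly if one label is concentrated in a single category.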
trainSplitPermute(
  x,
  probs = c(0.7, 0.15, 0.15),
  n = 1000,
  splitBy = "drift",
  label = "species",
  countCol = NULL,
  minCount = c(1, 1, 1),
  top = 3,
  seed = 112188
)
x: a dataframe containing the data you want to find splits for
probs: a vector of 3 values that sum to 1, defining the proportion of data that should go in the training, validation, and test sets (respectively)
n: number of random splits to try. This can be smaller if your labels are fairly evenly distributed, but needs to be larger for more uneven distributions
splitBy: name of the column containing the variable you want to split by
label: name of the column containing your dataset labels
countCol: names of any additional columns in your dataset defining the quantities you want to count (see the examples for why this is useful)
minCount: minimum count for each split category; it is usually safe to leave this at the default of 1 for all splits
top: the number of results to return. Usually you want only the best-scoring result, but by random chance that split can occasionally be distributed in an undesirable way (e.g., all sites in your validation data happen to be clustered together)
seed: random seed to set for reproducibility
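The roles of n, top, and seed can be illustrated with a small random search. This is a hypothetical sketch (permuteSplitSketch is not a package function), assuming each candidate assigns every level of splitBy as a whole group to train (1), val (2), or test (3); the package describes its method as random permutation but this page does not document the exact scheme. The guarantee of at least one group per category is only a loose stand-in for minCount = c(1, 1, 1), and the score is the same kind of illustrative per-label squared gap from probs.

```r
# Hypothetical sketch of a random search over group-level splits; this is NOT
# the trainSplitPermute source, and its score is an illustrative choice.
permuteSplitSketch <- function(x, splitBy, label, probs = c(0.7, 0.15, 0.15),
                               n = 1000, top = 3, seed = 112188) {
  set.seed(seed)  # same seed, same candidates
  rowLevel <- as.character(x[[splitBy]])
  lvls <- unique(rowLevel)
  if (length(lvls) < 3) stop("need at least 3 levels of splitBy")
  results <- vector("list", n)
  for (i in seq_len(n)) {
    # guarantee one group per category, assign the rest at random with probs,
    # then shuffle so the guaranteed groups are not always the first three
    codes <- c(1:3, sample(1:3, length(lvls) - 3, replace = TRUE, prob = probs))
    splitMap <- setNames(codes[sample(length(lvls))], lvls)
    # every row inherits its group's category, so no group spans two splits
    splitVec <- unname(splitMap[rowLevel])
    # per-label fraction of rows in train/val/test, compared with probs
    props <- prop.table(table(x[[label]], factor(splitVec, levels = 1:3)), 1)
    target <- matrix(probs, nrow = nrow(props), ncol = 3, byrow = TRUE)
    results[[i]] <- list(splitMap = splitMap, splitVec = splitVec,
                         score = sum((props - target)^2))
  }
  scores <- vapply(results, function(r) r$score, numeric(1))
  # lower score = closer to probs; keep the best `top` candidates
  results[order(scores)[seq_len(min(top, n))]]
}
```

Assigning whole groups rather than individual rows is what keeps all data from one site (or month) inside a single split; a larger n simply gives the search more candidates to pick the lowest scores from.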
A list of the best-scoring results, where the number of results is set by top. Each individual result contains $splitMap, the random split marked as integers 1, 2, 3 corresponding to train, val, test; $splitVec, a vector marking each row of x with its category (these two results are named by the levels of splitBy); $distribution, a table of the distribution of label in the split; and $score, the split score (lower is closer to the desired probs).
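The relationship between $splitMap and $splitVec can be sketched by hand. This assumes, for illustration only, that splitVec is what you get by looking up each row's splitBy value in splitMap; the mock values below are hypothetical, and the package may store category names rather than integer codes in splitVec.

```r
# Hypothetical mock of one result's splitMap: one code per level of splitBy
splitMap <- c(A = 1, B = 1, C = 2, D = 3)
# the splitBy column of x, one value per row
site <- c("A", "C", "B", "D", "A")
# look up each row's site in splitMap to get the per-row category
splitVec <- splitMap[site]
# translate the integer codes into readable category names
splitNames <- c("train", "val", "test")[splitVec]
splitNames
# "train" "val" "train" "test" "train"
```

Because the lookup is by name, every row from the same site always receives the same category.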
Author: Taiki Sakai <taiki.sakai@noaa.gov>
# make some dummy data (set a seed so the data are reproducible)
set.seed(1)
df <- data.frame(
  species = sample(letters[1:5], 1e3, replace=TRUE, prob=c(.4, .2, .1, .1, .2)),
  site = sample(LETTERS[1:12], 1e3, replace=TRUE),
  event = 1:1e3
)
# try a split with n=3
split <- trainSplitPermute(df, probs=c(.7, .15, .15), n=3, label='species', splitBy='site')
# assign the best split as the split category
df$split <- split[[1]]$splitVec
# distribution is not close to our desired .7, .15, .15 split because n is too low
round(table(df$species, df$split) /
matrix(rep(table(df$species), 3), nrow=5), 2)
# rerun with higher n to get closer to desired distribution
split <- trainSplitPermute(df, probs=c(.7, .15, .15), n=1e3, label='species', splitBy='site')
df$split <- split[[1]]$splitVec
round(table(df$species, df$split) /
matrix(rep(table(df$species), 3), nrow=5), 2)
# adding a new site that has significantly more detections than others
addSite <- data.frame(
species = sample(letters[1:5], 500, replace=TRUE),
site = rep(LETTERS[13], 500),
event = 1001:1500)
df$split <- NULL
df <- rbind(df, addSite)
# splitting by site alone no longer gives balanced species proportions: the sites
# are split approximately .7, .15, .15, but site M alone holds 500 of the 1500
# detections, so the detections of each species are not split in those proportions
split <- trainSplitPermute(df, probs=c(.7, .15, .15), n=1e3, label='species', splitBy='site')
df$split <- split[[1]]$splitVec
round(table(df$species, df$split) /
matrix(rep(table(df$species), 3), nrow=5), 2)
# adding 'event' as a countCol fixes this
split <- trainSplitPermute(df, probs=c(.7, .15, .15), n=1e3, label='species',
splitBy='site', countCol='event')
df$split <- split[[1]]$splitVec
round(table(df$species, df$split) /
matrix(rep(table(df$species), 3), nrow=5), 2)
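The examples attribute the imbalance to one site holding far more detections than the others. The hypothetical helper below (splitProportions is not part of the package) shows why counting sites and counting events can disagree for the same split. It does not claim to reproduce how trainSplitPermute counts internally; it only computes, for a finished split, the fraction of each label in each category, counting either rows or distinct values of a chosen column.

```r
# Hypothetical helper for checking a finished split (not part of the package):
# fraction of each label that falls in each split category, counting either
# rows (countBy = NULL) or distinct values of the countBy column.
splitProportions <- function(df, label, split, countBy = NULL) {
  d <- df[c(label, split, countBy)]
  # when counting distinct values (e.g. sites), drop repeated combinations
  if (!is.null(countBy)) d <- unique(d)
  prop.table(table(d[[label]], d[[split]]), margin = 1)
}

# species "a" is seen at three sites, one per category, but S3 has 10 events
toy <- data.frame(
  species = "a",
  site = c("S1", "S2", rep("S3", 10)),
  split = c("train", "val", rep("test", 10))
)
splitProportions(toy, "species", "split", countBy = "site")  # 1/3 in each category
splitProportions(toy, "species", "split")                    # test holds 10/12 of events
```

Called with countBy = NULL on the example df after df$split is assigned, this gives the same row proportions as the round(table(...) / matrix(...), 2) lines above, before rounding.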