#' Generate k cross-validation sets
#'
#' Draws \code{cv.k} sets of training-sample indices, either by k-fold
#' partitioning or by Monte Carlo resampling, optionally stratified by the
#' class labels found in \code{Y}.
#'
#' @param cv.k Number of cross validation sets (k-fold) or number of Monte Carlo
#'   resampling rounds.
#' @param Y Dependent variable in vector or dummy matrix format (see Details).
#' @param type Indicating if a regression (\code{R}) or discriminant analysis
#'   (\code{DA}) will be performed. If left at its default, \code{R} is used.
#' @param method Type of cross validation: k-fold or Monte Carlo-Cross
#'   Validation (MCCV), sampling can be performed totally random or group
#'   stratified: 'k-fold', 'k-fold_stratified', 'MC', 'MC_stratified'.
#' @details If input argument \code{Y} is a dummy matrix, it is presumably
#'   expected that \code{center_scale} has been applied beforehand (the
#'   function is not called here; confirm against the caller). The number of
#'   cross validation sets is related to the number of samples in each group:
#'   stratified k-fold sampling requires every group to contain more than
#'   \code{cv.k} samples. For method 'k-fold', setting \code{cv.k=nrow(Y)} or
#'   \code{cv.k=length(Y)} (matrix or vector \code{Y}, respectively) yields
#'   leave-one-out sets. Stratified sampling is not available for
#'   \code{type='R'}; in that case the non-stratified counterpart of
#'   \code{method} is used and a notice is printed.
#' @return A list of length \code{cv.k}. Each element is an integer vector with
#'   the row indices of \code{Y} forming the training set of that round (for
#'   k-fold methods: all samples not in the held-out fold; for MC methods: a
#'   random two-thirds subsample). \code{NULL} is returned when a stratified
#'   method is requested and the smallest group has too few observations.
#' @seealso center_scale
#' @author Torben Kimhofer \email{tkimhofer@@gmail.com}
cv_sets_method <- function(cv.k = 7, Y, type = c('R', 'DA'), method = 'k-fold_stratified') {
  # Collapse the length-2 default to a single value; a vector-valued `if`
  # condition is an error on current R versions.
  type <- match.arg(type)
  # Fail early with a clear message instead of "object 'sets_list' not found".
  method <- match.arg(method, c('k-fold', 'k-fold_stratified', 'MC', 'MC_stratified'))

  if (type == 'R' && method %in% c('k-fold_stratified', 'MC_stratified')) {
    # Fall back to the matching non-stratified scheme (a valid switch label).
    method <- sub('_stratified', '', method, fixed = TRUE)
    cat('Stratified CV sampling not implemented for numeric Y. Changed to ', method, '!\n', sep = '')
  }

  # Plain vectors (and factors) become a one-column matrix so nrow() works.
  if (is.null(ncol(Y))) {
    Y <- as.matrix(Y)
  }
  n <- nrow(Y)
  folds <- seq_len(cv.k)

  sets_list <- switch(method,
    'k-fold' = {
      # n-fold (leave-one-out) case keeps sample i as the held-out unit of
      # fold i; otherwise fold labels are a random permutation of a balanced
      # label vector, so fold sizes differ by at most one.
      fold_id <- if (cv.k == n) seq_len(n) else sample(rep_len(folds, n))
      lapply(folds, function(i) which(fold_id != i))
    },
    'MC' = {
      # each round: random 2/3 of all samples as training set
      n_train <- floor(n * 2 / 3)
      lapply(folds, function(i) sample.int(n, n_train))
    },
    'k-fold_stratified' = {
      # preserves the percentage of each class in each fold
      ct <- table(Y)
      if (min(ct) <= cv.k) {
        cat('Number of observations in group ', names(which.min(ct)), ' is too low (', min(ct), ')!\n', sep = '')
        return(NULL)
      }
      levs <- as.vector(unique(Y))
      # shuffle the indices of every class separately ...
      idx_by_level <- lapply(levs, function(lev) sample(which(Y == lev)))
      idx_all <- unlist(idx_by_level)
      # ... and deal them round-robin into the cv.k folds
      fold_all <- unlist(lapply(idx_by_level, function(idx) rep_len(folds, length(idx))))
      lapply(folds, function(i) idx_all[fold_all != i])
    },
    'MC_stratified' = {
      # 2/3 of the smallest class is drawn from every class, so each training
      # set contains an equal number of samples per class
      ct <- table(Y)
      n_per_class <- floor(min(ct) * 2 / 3)
      if (n_per_class < 1) {
        cat('Number of observations in group ', names(which.min(ct)), ' is too low (', min(ct), ')!\n', sep = '')
        return(NULL)
      }
      lapply(folds, function(i) {
        unlist(lapply(names(ct), function(lev) {
          idx <- which(Y == lev)
          # index via sample.int: sample(idx, ...) would sample from 1:idx
          # if idx ever had length one
          idx[sample.int(length(idx), n_per_class)]
        }))
      })
    }
  )
  sets_list
}
# cv_sets(cv.k=6, Y, type=c('kfold_stratified'))
# sum(unlist(lapply(cv_sets(cv.k=6, Y, type=c('kfold_stratified')), length)))
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.