# TODO Work on in branch so it doesn't ruin the code coverage statistic for the user-available code
# nested_fold <- function(data,
# ks = c(5, 10),
# cat_col = NULL,
# num_col = NULL,
# id_col = NULL,
# group_method = "n_dist",
# nesting_method = "x",
# id_aggregation_fn = sum,
# extreme_pairing_levels = 1,
# num_fold_cols = 1, # Currently only used at last level
# unique_fold_cols_only = TRUE,
# max_iters = 5,
# handle_existing_fold_cols = "keep_warn", # TODO Make sure this is meaningful and works!!
# parallel = FALSE) {
# original_colnames <- colnames(data)
#
# # Method x (rename):
# ## 'fold col 1': fold dataset
# ## 'fold col 2': group by 'fold col 1' and fold() each group separately
# ## 'fold col 3': group by 'fold col 1' and 'fold col 2' and fold() each group separately
# ## ...
# ## The last level may have multiple, unique fold columns
#
# # Method cv
# ## 'fold col 1' (outer): fold dataset
# ## 'fold col 2' (inner): for each k-1 training set, create folds, rest is test set
# ## 'fold col 3' (inner lv2): for each k-1 training set, create folds, rest is test set
# ## ...
# ## The last level may have multiple, unique fold columns
#
# if (nesting_method == "x") {
# nested <- internal_nested_fold_method_x(
# data = data,
# ks = ks,
# by = NULL,
# cat_col = cat_col,
# num_col = num_col,
# id_col = id_col,
# method = group_method,
# id_aggregation_fn = id_aggregation_fn,
# extreme_pairing_levels = extreme_pairing_levels,
# num_fold_cols = num_fold_cols,
# unique_fold_cols_only = unique_fold_cols_only,
# max_iters = max_iters,
# handle_existing_fold_cols = "remove",
# parallel = parallel
# )
#
# folded_data <- nested[["data"]]
# fold_col_names <- nested[["col_names"]]
#
# # Do check
# new_colnames <- setdiff(
# colnames(folded_data),
# original_colnames
# )
# if (length(setdiff(new_colnames, fold_col_names)) != 0 ||
# length(setdiff(fold_col_names, new_colnames)) != 0) {
# stop("something went wrong when creating nested folds")
# }
#
# fold_columns <- folded_data %>%
# base_select(cols = fold_col_names)
#
# # Rename new columns meaningfully
# meaningful_names <- create_nested_fold_cols_names(
# num_fold_cols = num_fold_cols,
# ks = ks,
# fold_col_names = fold_col_names
# )
#
# if (length(meaningful_names) != length(fold_col_names)) {
# stop("something went wrong when creating names for the nested folds")
# }
#
# fold_columns <- fold_columns %>%
# dplyr::rename_at(dplyr::vars(fold_col_names), ~meaningful_names)
#
# data %>%
# dplyr::bind_cols(fold_columns)
# }
# }
#
#
# # TODO Add tests:
# # create_nested_fold_cols_names(c(3,1,3), c(2,2,4), NULL)
# # create_nested_fold_cols_names(3, c(2,2), NULL)
# # create_nested_fold_cols_names(1, NULL, c("a","t","f","g"))
# create_nested_fold_cols_names <- function(num_fold_cols, ks, fold_col_names, nesting_method = "x") {
# if (nesting_method == "x") {
#
# # Rename new columns meaningfully
# if (length(num_fold_cols) > 1) {
# # Note: This part should currently not be used,
# # as we only allow multiple fold cols in the last level
# meaningful_names <-
# plyr::llply(
# seq_along(num_fold_cols),
# function(nm_ind) {
# paste0(".nested_folds_", nm_ind, "_", seq_len(num_fold_cols[[nm_ind]]))
# }
# ) %>% unlist()
# } else if (num_fold_cols > 1) {
# first_levels_names <- paste0(
# ".nested_folds_",
# seq_len(length(ks) - 1)
# )
# last_level_names <- paste0(
# ".nested_folds_",
# length(ks), "_",
# seq_len(num_fold_cols)
# )
# meaningful_names <- c(first_levels_names, last_level_names)
# } else {
# meaningful_names <- paste0(".nested_folds_", seq_along(fold_col_names))
# }
# }
# meaningful_names
# }
#
# internal_nested_fold_method_x <- function(data,
# ks = c(5, 10),
# by = NULL,
# cat_col = NULL,
# num_col = NULL,
# id_col = NULL,
# method = "n_dist",
# id_aggregation_fn = sum,
# extreme_pairing_levels = 1,
# num_fold_cols = 1,
# unique_fold_cols_only = TRUE,
# max_iters = 5,
# handle_existing_fold_cols = "keep_warn",
# parallel = FALSE) {
#
# group_col <- by
#
# # Get current number of fold cols
# # and update num_fold_cols for next iteration if necessary
# if (length(num_fold_cols) == 1) {
# if (length(ks) == 1) {
# current_num_fold_cols <- num_fold_cols
# } else {
# current_num_fold_cols <- 1
# }
# } else {
# current_num_fold_cols <- num_fold_cols[[1]]
# num_fold_cols <- num_fold_cols[-1]
# }
#
# # Get the current number of folds to create
# # and update ks if necessary
# if (length(ks) > 1) {
# current_k <- ks[[1]]
# ks <- ks[-1]
# last_level <- FALSE
# } else {
# current_k <- ks
# last_level <- TRUE
# }
#
# # Get original column names
# original_cols <- colnames(data)
#
# if (!is.null(group_col)) {
# data <- data %>%
# dplyr::group_by(!!!rlang::syms(group_col)) %>%
# dplyr::group_modify(~ fold_rename_wrapper(
# .x,
# k = current_k,
# cat_col = cat_col,
# num_col = num_col,
# id_col = id_col,
# method = method,
# id_aggregation_fn = id_aggregation_fn,
# extreme_pairing_levels = extreme_pairing_levels,
# num_fold_cols = current_num_fold_cols,
# unique_fold_cols_only = unique_fold_cols_only,
# max_iters = max_iters,
# handle_existing_fold_cols = handle_existing_fold_cols, # TODO think this arg through in nested
# parallel = parallel,
# cols_to_remove_post_fold = group_col
# ), keep = TRUE) %>%
# dplyr::ungroup()
# } else {
# data <- fold_rename_wrapper(
# data,
# k = current_k,
# cat_col = cat_col,
# num_col = num_col,
# id_col = id_col,
# method = method,
# id_aggregation_fn = id_aggregation_fn,
# extreme_pairing_levels = extreme_pairing_levels,
# num_fold_cols = current_num_fold_cols,
# unique_fold_cols_only = unique_fold_cols_only,
# max_iters = max_iters,
# handle_existing_fold_cols = handle_existing_fold_cols, # TODO think this arg through in nested
# parallel = parallel
# ) %>%
# dplyr::ungroup()
# }
#
# # Get name of new fold column
# new_col <- setdiff(colnames(data), original_cols)
#
# # Append new fold col to group variables
# group_col <- c(group_col, new_col)
#
# # If .folds was renamed to .folds_1, we need to remove the
# # .folds from group_col
# if (group_col[[1]] == ".folds" && length(group_col) > 1) {
# group_col <- group_col[-1]
# }
#
# if (!isTRUE(last_level)) {
# return(
# internal_nested_fold_method_x(
# data = data,
# ks = ks,
# by = group_col,
# cat_col = cat_col,
# num_col = num_col,
# id_col = id_col,
# method = method,
# id_aggregation_fn = id_aggregation_fn,
# extreme_pairing_levels = extreme_pairing_levels,
# num_fold_cols = num_fold_cols,
# unique_fold_cols_only = unique_fold_cols_only,
# max_iters = max_iters,
# handle_existing_fold_cols = "keep",
# parallel = parallel
# )
# )
# } else {
# return(
# list(
# "data" = data,
# "col_names" = group_col
# )
# )
# }
# }
#
# fold_rename_wrapper <- function(data, k, cat_col, num_col, id_col,
# method, id_aggregation_fn,
# extreme_pairing_levels, num_fold_cols,
# unique_fold_cols_only, max_iters,
# handle_existing_fold_cols, parallel,
# cols_to_remove_post_fold = NULL) {
#
# # Extract original column names
# original_cols <- colnames(data)
#
# # Fold the dataset
# data <-
# fold(
# data = data,
# k = k,
# cat_col = cat_col,
# num_col = num_col,
# id_col = id_col,
# method = method,
# id_aggregation_fn = id_aggregation_fn,
# extreme_pairing_levels = extreme_pairing_levels,
# num_fold_cols = num_fold_cols,
# unique_fold_cols_only = unique_fold_cols_only,
# max_iters = max_iters,
# handle_existing_fold_cols = handle_existing_fold_cols,
# parallel = parallel
# )
#
# # Extract new fold column names
# new_cols <- setdiff(colnames(data), original_cols)
#
# # Create new unique temporary column names
# # Needed when called in dplyr::group_modify
# new_tmp_names <- plyr::llply(new_cols, function(nc) {
# create_tmp_var(data, paste0("tmp_", nc, "_var_", num_fold_cols > 1),
# disallowed = cols_to_remove_post_fold
# )
# }) %>% unlist()
#
# # Remove specified columns
# # As dplyr::group_modify does not allow us to
# # return the original grouping variables
# if (!is.null(cols_to_remove_post_fold)) {
# data <- data %>%
# base_deselect(cols = cols_to_remove_post_fold)
# }
#
# data %>%
# dplyr::rename_at(dplyr::vars(new_cols), ~new_tmp_names)
# }
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.