
Defines functions mlh_outsample_row_indices

Documented in mlh_outsample_row_indices

#' @title mlh_outsample_row_indices
#' @description Machine learning helper function to convert a vector of (in-
#'   sample) row indices of a fold into out-of-sample row indices.
#' @param fold_list A list of integer vectors that describe the (in-sample)
#'   row indices of cross-validation folds. The list must be named.
#' @param dataset_nrows An integer. The number of rows in the dataset. This
#'   parameter is required in order to compute the out-of-sample row indices.
#' @param type A character. To be used if the out-of-sample row indices need to
#'   be formatted in a special manner (default: `NULL`). Currently, the only
#'   allowed value is `type = "glmnet"` in order to format the row indices as
#'   required by `glmnet::cv.glmnet`'s argument `foldid`.
#' @return If `type = NULL`, returns a list of same length as `fold_list` with
#'   each item containing a vector of out-of-sample row indices. If
#'   `type = "glmnet"`, a data.table is returned with two columns and each row
#'   representing one observation of the dataset that is assigned to a specific
#'   test fold. The column "fold_id" should be passed further on to the argument
#'   `foldid` of `glmnet::cv.glmnet`.
#' @details The out-of-sample indices of a fold are computed as
#'   `setdiff(seq_len(dataset_nrows), in_sample_indices)`. With
#'   `type = "glmnet"`, an error is raised unless every row of the dataset
#'   belongs to exactly one test fold.
#' @examples
#' fold_list <- list(
#'   "Fold1" = setdiff(seq_len(100), 1:33),
#'   "Fold2" = setdiff(seq_len(100), 66:100),
#'   "Fold3" = setdiff(seq_len(100), 34:65)
#' )
#' mlh_outsample_row_indices(fold_list, 100)
#' mlh_outsample_row_indices(fold_list, 100, "glmnet")
#' @export
mlh_outsample_row_indices <- function(
    fold_list,
    dataset_nrows,
    type = NULL
  ) {
  # Checks are evaluated in order, so later checks may rely on earlier ones
  # (e.g. the whole-number check assumes a single numeric value).
  stopifnot(
    "`fold_list` needs to be a `list()-object`"  = is.list(fold_list),
    "`fold_list` must contain at least one item"  = length(fold_list) > 0L,
    "`fold_list` must contain at least one named item" =
      length(names(fold_list)) > 0L,
    "`dataset_nrows` must be an integer" = is.numeric(dataset_nrows),
    "Please provide exactly one number for `dataset_nrows`" =
      length(dataset_nrows) == 1L,
    "`dataset_nrows` must be an integer" =
      !is.na(dataset_nrows) &&
        dataset_nrows >= 0 &&
        dataset_nrows <= .Machine$integer.max &&
        dataset_nrows == trunc(dataset_nrows),
    "`type` must be `'glmnet'`" =
      is.null(type) ||
        (is.character(type) && length(type) == 1L && type %in% c("glmnet"))
  )
  dataset_nrows <- as.integer(dataset_nrows)
  all_rows <- seq_len(dataset_nrows)

  # Iterate over the list itself (not over its names) so that duplicated or
  # empty names cannot cause the wrong fold to be looked up; `lapply` keeps
  # the names of `fold_list` on the result.
  fold_ids <- lapply(
    X = fold_list,
    FUN = function(in_sample_ids) {
      setdiff(all_rows, in_sample_ids)
    }
  )

  if (is.null(type)) {
    stopifnot(
      "Length of `fold_ids` is not equal to length of `fold_list`" =
        length(fold_ids) == length(fold_list)
    )
    fold_ids
  } else if (type == "glmnet") {
    # assign each row of the dataset to a specific test fold; the table is
    # built once from flat vectors instead of being grown fold by fold
    row_id <- unlist(fold_ids, use.names = FALSE)
    fold_id <- rep(seq_along(fold_ids), times = lengths(fold_ids))
    row_order <- order(row_id)
    row_id <- row_id[row_order]
    fold_id <- fold_id[row_order]

    stopifnot(
      "Duplicates occurred in `row_id`" = !any(duplicated(row_id)),
      "Lenght of unique `fold_id`s is not equal to number of provided folds" =
        length(unique(fold_id)) == length(fold_list),
      "Number of fold ids needs to match number of observations in the dataset"
      = length(row_id) == dataset_nrows
    )

    data.table::data.table(
      "row_id" = row_id,
      "fold_id" = fold_id
    )
  } else {
    # Defensive guard: only reachable if a new value is added to the `type`
    # validation above without a matching branch here.
    stop(paste0("Type '", type, "' not implemented."))
  }
}
Try the kdry package in your browser

Any scripts or data that you put into this service are public.

kdry documentation built on July 4, 2024, 9:07 a.m.