R/clustering.R

Defines functions clustering_splits clustering_cv

Documented in clustering_cv

#' Cluster Cross-Validation
#'
#' Cluster cross-validation splits the data into V groups of
#'  disjoint sets using clustering (k-means by default) of some variables.
#'  For each resample, the analysis set consists of V-1 of the
#'  folds/clusters while the assessment set contains the remaining
#'  fold/cluster. In basic cross-validation (i.e. no repeats), the number of
#'  resamples is equal to V.
#'
#' @details
#' The variables in the `vars` argument are used for k-means clustering of
#'  the data into disjoint sets or for hierarchical clustering of the data.
#'  These clusters are used as the folds for cross-validation. Depending on how
#'  the data are distributed, there may not be an equal number of points
#'  in each fold.
#'
#' Distances are computed once from `data[vars]` and reused for every repeat,
#'  so repeats only differ when `cluster_function` is stochastic (for example,
#'  the random starting centers used by [stats::kmeans()]).
#'
#' You can optionally provide a custom function to `distance_function`. The
#' function should take a data frame (as created via `data[vars]`) and return
#' a [stats::dist()] object with distances between data points.
#'
#' You can optionally provide a custom function to `cluster_function`. The
#' function must take three arguments:
#' - `dists`, a [stats::dist()] object with distances between data points
#' - `v`, a length-1 numeric for the number of folds to create
#' - `...`, to pass any additional named arguments to your function
#'
#' The function should return a vector of cluster assignments of length
#' `nrow(data)`, with each element of the vector corresponding to the matching
#' row of the data frame.
#'
#' @inheritParams vfold_cv
#' @param vars A vector of bare variable names to use to cluster the data.
#' @param repeats The number of times to repeat the clustered partitioning.
#' @param distance_function Which function should be used for distance
#' calculations? Either a function or the name of a function as a string.
#' Defaults to `"dist"` ([stats::dist()]). You can also provide your own
#' function; see `Details`.
#' @param cluster_function Which function should be used for clustering?
#' Options are either `"kmeans"` (to use [stats::kmeans()])
#' or `"hclust"` (to use [stats::hclust()]). You can also provide your own
#' function; see `Details`.
#' @param ... Extra arguments passed on to `cluster_function`.
#'
#' @return A tibble with classes `rset`, `tbl_df`, `tbl`, and `data.frame`.
#'  The results include a column for the data split objects and
#'  an identification variable `id`.
#'
#' @examplesIf rlang::is_installed("modeldata")
#' data(ames, package = "modeldata")
#' clustering_cv(ames, vars = c(Sale_Price, First_Flr_SF, Second_Flr_SF), v = 2)
#'
#' @rdname clustering_cv
#' @export
clustering_cv <- function(data,
                          vars,
                          v = 10,
                          repeats = 1,
                          distance_function = "dist",
                          cluster_function = c("kmeans", "hclust"),
                          ...) {
  check_repeats(repeats)

  if (!rlang::is_function(cluster_function)) {
    cluster_function <- rlang::arg_match(cluster_function)
  }

  vars <- tidyselect::eval_select(rlang::enquo(vars), data = data)
  if (rlang::is_empty(vars)) {
    rlang::abort("`vars` are required and must be variables in `data`.")
  }
  vars <- data[vars]

  # Validate `v` before the (potentially O(n^2)) distance computation so an
  # invalid `v` fails fast. The error is attributed to this function, exactly
  # as the check inside clustering_splits() would attribute it.
  check_v(v, nrow(data), "rows", call = rlang::current_env())

  # Distances depend only on `vars`, not on the repeat number, so compute
  # them once and reuse them for every repeat.
  dists <- rlang::exec(distance_function, vars)

  if (repeats == 1) {
    split_objs <- clustering_splits(
      data = data,
      dists = dists,
      v = v,
      cluster_function = cluster_function,
      ...
    )
  } else {
    repeat_ids <- names0(repeats, "Repeat")
    # Build each repeat's folds separately and bind them once at the end,
    # rather than growing a data frame with rbind() inside a loop. Repeats
    # are still generated in order, so RNG-driven clusterers (e.g. kmeans)
    # consume the random stream in the same sequence.
    per_repeat <- lapply(seq_len(repeats), function(i) {
      tmp <- clustering_splits(
        data = data,
        dists = dists,
        v = v,
        cluster_function = cluster_function,
        ...
      )
      # The fold id becomes the inner id; the outer id names the repeat.
      tmp$id2 <- tmp$id
      tmp$id <- repeat_ids[i]
      tmp
    })
    split_objs <- do.call(rbind, per_repeat)
  }

  split_objs$splits <- map(split_objs$splits, rm_out)

  ## Save some overall information

  cv_att <- list(
    v = v,
    vars = names(vars),
    repeats = repeats,
    distance_function = distance_function,
    cluster_function = cluster_function
  )

  new_rset(
    splits = split_objs$splits,
    ids = split_objs[, grepl("^id", names(split_objs))],
    attrib = cv_att,
    subclass = c("clustering_cv", "rset")
  )
}

# Partition `data` into `v` folds by clustering a precomputed distance object.
#
# `dists` is passed to the chosen clusterer: `"kmeans"` uses
# stats::kmeans(dists, centers = v, ...), `"hclust"` cuts
# stats::hclust(dists, ...) into `v` groups, and a user-supplied function is
# called as cluster_function(dists = dists, v = v, ...). Each cluster becomes
# one assessment set; its complement is the analysis set.
#
# Returns a tibble with a `splits` list column of "clustering_split" objects
# and an `id` column ("Fold1", "Fold2", ...). Errors are attributed to the
# caller (clustering_cv()).
clustering_splits <- function(data,
                              dists,
                              v = 10,
                              cluster_function = c("kmeans", "hclust"),
                              ...) {
  if (!rlang::is_function(cluster_function)) {
    cluster_function <- rlang::arg_match(cluster_function)
  }

  n <- nrow(data)
  check_v(v, n, "rows", call = rlang::caller_env())

  folds <- if (rlang::is_function(cluster_function)) {
    # Call directly rather than via do.call() so the (possibly large) dist
    # object is not inlined into the call shown in errors and tracebacks.
    cluster_function(dists = dists, v = v, ...)
  } else {
    switch(
      cluster_function,
      kmeans = stats::kmeans(dists, centers = v, ...)$cluster,
      hclust = stats::cutree(stats::hclust(dists, ...), k = v)
    )
  }

  # A custom clusterer may break the documented contract. A wrong length would
  # be silently recycled by split(), and NA assignments would silently drop
  # rows from every assessment set, so reject both explicitly.
  if (length(folds) != n) {
    rlang::abort(
      paste0(
        "`cluster_function` must return one cluster assignment per row of ",
        "`data` (", n, "), not ", length(folds), "."
      ),
      call = rlang::caller_env()
    )
  }
  if (anyNA(folds)) {
    rlang::abort(
      "`cluster_function` must not return missing cluster assignments.",
      call = rlang::caller_env()
    )
  }

  # Each cluster's row indices form an assessment set; the complement over
  # 1..n gives the matching analysis set.
  indices <- split_unnamed(seq_len(n), folds)
  indices <- lapply(indices, default_complement, n = n)

  split_objs <- purrr::map(
    indices,
    make_splits,
    data = data,
    class = "clustering_split"
  )
  tibble::tibble(
    splits = split_objs,
    id = names0(length(split_objs), "Fold")
  )
}

Try the rsample package in your browser

Any scripts or data that you put into this service are public.

rsample documentation built on Aug. 23, 2023, 5:08 p.m.