R/extract_cluster_assignment.R

Defines functions cluster_assignment_tibble extract_cluster_assignment.hclust extract_cluster_assignment.kmodes extract_cluster_assignment.kproto extract_cluster_assignment.KMeansCluster extract_cluster_assignment.kmeans extract_cluster_assignment.workflow extract_cluster_assignment.cluster_fit extract_cluster_assignment.cluster_spec extract_cluster_assignment

Documented in extract_cluster_assignment

#' Extract cluster assignments from model
#'
#' When applied to a fitted cluster specification, returns a tibble with cluster
#' assignments of the data used to train the model.
#'
#' @param object A fitted [`cluster_spec`] object.
#' @param ... Other arguments passed to methods. Use `prefix` to change the
#'   prefix used in the factor levels (the default is `"Cluster_"`). For models
#'   such as [hier_clust()], `num_clusters` or `cut_height` can also be passed
#'   to choose how the fitted tree is cut.
#'
#' @details
#'
#' Some model types, such as K-means as seen in [k_means()], store the cluster
#' assignments in the object itself, so this function acts as a simple
#' extractor. Other model types, such as Hierarchical (Agglomerative)
#' Clustering as seen in [hier_clust()], are fit in such a way that the number
#' of clusters can be determined at any time after the fit. For these models,
#' setting `num_clusters` or `cut_height` in this function determines the
#' clustering that is reported.
#'
#' The ordering of the clusters is such that the first observation in the
#' training data set will be in cluster 1, the next observation that doesn't
#' belong to cluster 1 will be in cluster 2, and so on. As the ordering of
#' clusters doesn't matter, this is done to avoid identical sets of clusters
#' having different labels if fit multiple times.
#'
#' ## Related functions
#'
#' `extract_cluster_assignment()` is a part of a trio of functions doing
#' similar things:
#'
#' - [extract_cluster_assignment()] returns the cluster assignments of the
#'   training observations
#' - [extract_centroids()] returns the location of the centroids
#' - \code{\link[=predict.cluster_fit]{predict()}} returns the cluster a new
#'   observation belongs to
#'
#' @return A `tibble::tibble()` with 1 column named `.cluster`, a factor. This
#'   tibble will correspond to the training data set.
#'
#' @seealso [extract_centroids()] [predict.cluster_fit()]
#' @examples
#' kmeans_spec <- k_means(num_clusters = 5) %>%
#'   set_engine("stats")
#'
#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
#'
#' kmeans_fit %>%
#'   extract_cluster_assignment()
#'
#' kmeans_fit %>%
#'   extract_cluster_assignment(prefix = "C_")
#'
#' # Some models such as `hier_clust()` fit in such a way that you can specify
#' # the number of clusters after the model is fit
#' hclust_spec <- hier_clust() %>%
#'   set_engine("stats")
#'
#' hclust_fit <- fit(hclust_spec, ~., mtcars)
#'
#' hclust_fit %>%
#'   extract_cluster_assignment(num_clusters = 2)
#'
#' hclust_fit %>%
#'   extract_cluster_assignment(cut_height = 250)
#' @export
extract_cluster_assignment <- function(object, ...) {
  UseMethod("extract_cluster_assignment")
}

# An unfitted specification has no assignments to report, so this method
# always signals an error that points the user at `fit()`.
#' @export
extract_cluster_assignment.cluster_spec <- function(object, ...) {
  msg <- paste(
    "This function requires a fitted model.",
    "Please use `fit()` on your cluster specification."
  )
  rlang::abort(msg)
}

# Fitted tidyclust model: pull out the engine object stored in `object$fit`
# and re-dispatch on it, forwarding any extra arguments unchanged.
#' @export
extract_cluster_assignment.cluster_fit <- function(object, ...) {
  engine_fit <- object$fit
  extract_cluster_assignment(engine_fit, ...)
}

# Workflow method: the engine-level fit is read from `object$fit$fit$fit` and
# extraction is delegated to that object's method, forwarding `...`.
# When nothing is stored there (presumably because the workflow has not been
# fit), fail with an informative error instead of letting S3 dispatch on
# `NULL` produce an opaque "no applicable method" error.
#' @export
extract_cluster_assignment.workflow <- function(object, ...) {
  engine_fit <- object$fit$fit$fit
  if (is.null(engine_fit)) {
    rlang::abort(
      paste(
        "This function requires a fitted workflow.",
        "Please use `fit()` on your workflow."
      )
    )
  }
  extract_cluster_assignment(engine_fit, ...)
}

# Objects of class `kmeans`: assignments are read from `object$cluster` and
# the number of clusters is the length of `object$size`.
#' @export
extract_cluster_assignment.kmeans <- function(object, ...) {
  n_clusters <- length(object$size)
  cluster_assignment_tibble(object$cluster, n_clusters, ...)
}

# Objects of class `KMeansCluster`: assignments live in `object$clusters` and
# the cluster count is the length of `object$obs_per_cluster`.
#' @export
extract_cluster_assignment.KMeansCluster <- function(object, ...) {
  cluster_assignment_tibble(
    object$clusters,
    length(object$obs_per_cluster),
    ...
  )
}

# Objects of class `kproto`: assignments come from `object$cluster`; the
# number of clusters is taken as the length of `object$size`.
#' @export
extract_cluster_assignment.kproto <- function(object, ...) {
  cluster_assignment_tibble(object$cluster, length(object$size), ...)
}

# Objects of class `kmodes`: same layout as used here for `kproto`, i.e.
# assignments in `object$cluster` and one `object$size` entry per cluster.
#' @export
extract_cluster_assignment.kmodes <- function(object, ...) {
  assignments <- object$cluster
  cluster_count <- length(object$size)
  cluster_assignment_tibble(assignments, cluster_count, ...)
}

# Objects of class `hclust`: the number of clusters is chosen here by cutting
# the tree with `stats::cutree()`. If either `num_clusters` or `cut_height` is
# supplied through `...`, both are read from `...`; otherwise both are read
# from the attributes of the same names on `object`. cutree's own `k`/`h`
# argument names are rejected so users go through the tidyclust names.
#' @export
extract_cluster_assignment.hclust <- function(object,
                                              ...,
                                              call = rlang::caller_env(0)) {
  dots <- list(...)

  # Point users of stats' argument names at the tidyclust equivalents.
  if (!is.null(dots[["h"]])) {
    rlang::abort(
      paste(
        "Using `h` argument is not supported.",
        "Please use `cut_height` instead."
      ),
      call = call
    )
  }
  if (!is.null(dots[["k"]])) {
    rlang::abort(
      paste(
        "Using `k` argument is not supported.",
        "Please use `num_clusters` instead."
      ),
      call = call
    )
  }

  # Extraction-time arguments replace the values recorded on the fit.
  overridden <- any(c("num_clusters", "cut_height") %in% names(dots))
  if (overridden) {
    num_clusters <- dots[["num_clusters"]]
    cut_height <- dots[["cut_height"]]
  } else {
    num_clusters <- attr(object, "num_clusters")
    cut_height <- attr(object, "cut_height")
  }

  if (is.null(num_clusters) && is.null(cut_height)) {
    rlang::abort(
      "Please specify either `num_clusters` or `cut_height`.",
      call = call
    )
  }

  assignments <- stats::cutree(object, k = num_clusters, h = cut_height)
  n_found <- length(unique(assignments))
  cluster_assignment_tibble(assignments, n_found, ...)
}

# ------------------------------------------------------------------------------

# Build the tibble returned by the extract_cluster_assignment() methods.
#
# `clusters` holds one cluster id per training observation (ids are assumed
# to lie in 1..n_clusters) and `n_clusters` is the number of clusters in the
# fit. Ids are relabelled by order of first appearance: the first observation
# is always in `<prefix>1`, the next previously unseen cluster becomes
# `<prefix>2`, and so on. Other arguments in `...` (e.g. `num_clusters`
# forwarded by the hclust method) are ignored; only `prefix` is used.
cluster_assignment_tibble <- function(clusters,
                                      n_clusters,
                                      ...,
                                      prefix = "Cluster_") {
  # Old ids in order of first appearance, followed by any id in
  # 1..n_clusters that no observation was assigned to.
  first_seen <- union(unique(clusters), seq_len(n_clusters))
  # `order()` inverts that permutation: `new_id[old_id]` is the position of
  # `old_id` in `first_seen`, i.e. its new label.
  new_id <- order(first_seen)
  cluster_names <- paste0(prefix, seq_len(n_clusters))
  res <- cluster_names[new_id[clusters]]

  # Set levels explicitly: plain `factor()` sorts them as strings, which puts
  # e.g. "Cluster_10" before "Cluster_2" once there are 10+ clusters. Only
  # labels that actually occur are kept as levels.
  observed <- cluster_names[cluster_names %in% res]
  tibble::tibble(.cluster = factor(res, levels = observed))
}

Try the tidyclust package in your browser

Any scripts or data that you put into this service are public.

tidyclust documentation built on Sept. 26, 2023, 1:08 a.m.