R/metric-sse.R

Defines functions sse_ratio_impl sse_ratio_vec sse_ratio.cluster_fit sse_ratio.cluster_spec sse_ratio sse_total_impl sse_total_vec sse_total.cluster_fit sse_total.cluster_spec sse_total sse_within_total_impl sse_within_total_vec sse_within_total.cluster_fit sse_within_total.cluster_spec sse_within_total sse_within

Documented in sse_ratio sse_ratio.cluster_fit sse_ratio.cluster_spec sse_ratio_vec sse_total sse_total.cluster_fit sse_total.cluster_spec sse_total_vec sse_within sse_within_total sse_within_total.cluster_fit sse_within_total.cluster_spec sse_within_total_vec

#' Within-cluster sum of squared errors, one row per cluster
#'
#' @param object A fitted kmeans tidyclust model
#' @param new_data A dataset to predict on.  If `NULL`, uses trained clustering.
#' @param dist_fun A function for calculating distances to centroids.  Defaults
#'   to Euclidean distance on processed data.
#'
#' @details [sse_within_total()] is the corresponding cluster metric function
#' that returns the sum of the values given by `sse_within()`.
#'
#' When `new_data` is supplied, every observation is assigned to the centroid
#' at the smallest distance and the squared distances are summed per cluster.
#' Only clusters that receive at least one observation appear in the result.
#'
#' @return A tibble with one row per cluster and three columns: `.cluster`
#'   (the cluster name), `wss` (the SSE within that cluster), and the number
#'   of observations in the cluster. The count column is called `n_members`
#'   when `new_data` is `NULL` and `n_obs` otherwise.
#'
#' @examples
#' kmeans_spec <- k_means(num_clusters = 5) %>%
#'   set_engine("stats")
#'
#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
#'
#' sse_within(kmeans_fit)
#' @export
sse_within <- function(object, new_data = NULL, dist_fun = Rfast::dista) {
  if (inherits(object, "cluster_spec")) {
    not_fitted_msg <- paste(
      "This function requires a fitted model.",
      "Please use `fit()` on your cluster specification."
    )
    rlang::abort(not_fitted_msg)
  }

  # Workflows carry a preprocessor; run new data through it so that distances
  # are measured on the same scale as the centroids.
  if (!is.null(new_data) && inherits(object, "workflow")) {
    new_data <- extract_post_preprocessor(object, new_data)
  }

  fit_summary <- extract_fit_summary(object)

  # No new data: report the values stored in the fit summary.
  if (is.null(new_data)) {
    return(
      tibble::tibble(
        .cluster = factor(fit_summary$cluster_names),
        wss = fit_summary$sse_within_total_total,
        n_members = fit_summary$n_members
      )
    )
  }

  # After conversion to a tibble, each column holds one observation's
  # distances to all centroids.
  centroid_dists <- dist_fun(fit_summary$centroids, new_data)
  per_obs <- tibble::as_tibble(centroid_dists, .name_repair = "minimal")

  # For each observation keep the index of the nearest centroid and the
  # squared distance to it.
  nearest <- map(per_obs, function(d) {
    c(
      .cluster = which.min(d),
      dist = min(d)^2
    )
  })

  assigned <- dplyr::bind_rows(nearest)
  assigned <- dplyr::mutate(
    assigned,
    .cluster = factor(paste0("Cluster_", .cluster))
  )

  by_cluster <- dplyr::group_by(assigned, .cluster)
  dplyr::summarize(
    by_cluster,
    wss = sum(dist),
    n_obs = dplyr::n()
  )
}

#' Compute the sum of within-cluster SSE
#'
#' @param object A fitted kmeans tidyclust model
#' @param new_data A dataset to predict on.  If `NULL`, uses trained clustering.
#' @param dist_fun A function for calculating distances to centroids. Defaults
#'   to Euclidean distance on processed data.
#' @param ... Other arguments passed to methods.
#'
#' @details Not to be confused with [sse_within()] that returns a tibble
#'   with within-cluster SSE, one row for each cluster.
#'
#'   The estimate is the sum of the `wss` column returned by [sse_within()],
#'   ignoring missing values. Calling this function on an unfitted cluster
#'   specification raises an error; use `fit()` first.
#'
#' @return `sse_within_total()` returns a tibble with 3 columns; `.metric`,
#'   `.estimator`, and `.estimate`. `sse_within_total_vec()` returns the
#'   estimate alone, as a single number.
#'
#' @family cluster metric
#'
#' @examples
#' kmeans_spec <- k_means(num_clusters = 5) %>%
#'   set_engine("stats")
#'
#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
#'
#' sse_within_total(kmeans_fit)
#'
#' sse_within_total_vec(kmeans_fit)
#' @export
sse_within_total <- function(object, ...) {
  UseMethod("sse_within_total")
}

# Re-bind the generic to whatever `new_cluster_metric()` returns for it, with
# `direction = "zero"`. This must run after the generic above is defined.
# NOTE(review): presumably this tags the generic as a cluster metric whose
# optimal value is zero -- confirm against `new_cluster_metric()`.
sse_within_total <- new_cluster_metric(
  sse_within_total,
  direction = "zero"
)

#' @export
#' @rdname sse_within_total
sse_within_total.cluster_spec <- function(object, ...) {
  # A bare specification has not been fitted, so there is nothing to measure;
  # always signal an error telling the user to fit first.
  not_fitted_msg <- paste(
    "This function requires a fitted model.",
    "Please use `fit()` on your cluster specification."
  )
  rlang::abort(not_fitted_msg)
}

#' @export
#' @rdname sse_within_total
sse_within_total.cluster_fit <- function(object, new_data = NULL,
                                         dist_fun = NULL, ...) {
  # `NULL` means "use the default distance function".
  distance <- if (is.null(dist_fun)) Rfast::dista else dist_fun

  estimate <- sse_within_total_impl(object, new_data, distance, ...)

  # Package the bare number in the three-column metric layout.
  tibble::tibble(
    .metric = "sse_within_total",
    .estimator = "standard",
    .estimate = estimate
  )
}

# Objects of class `workflow` reuse the `cluster_fit` method as-is.
#' @export
#' @rdname sse_within_total
sse_within_total.workflow <- sse_within_total.cluster_fit

#' @export
#' @rdname sse_within_total
sse_within_total_vec <- function(object, new_data = NULL,
                                 dist_fun = Rfast::dista, ...) {
  # Same computation as `sse_within_total()`, but the estimate is returned
  # bare rather than wrapped in a metric tibble.
  estimate <- sse_within_total_impl(object, new_data, dist_fun, ...)
  estimate
}

# Total within-cluster SSE: the sum of the per-cluster `wss` values from
# `sse_within()`, with missing values ignored.
sse_within_total_impl <- function(object, new_data = NULL,
                                  dist_fun = Rfast::dista, ...) {
  per_cluster <- sse_within(object, new_data, dist_fun, ...)
  sum(per_cluster$wss, na.rm = TRUE)
}

#' Compute the total sum of squares
#'
#' @param object A fitted kmeans tidyclust model
#' @param new_data A dataset to predict on.  If `NULL`, uses trained clustering.
#' @param dist_fun A function for calculating distances to centroids.  Defaults
#'   to Euclidean distance on processed data.
#' @param ... Other arguments passed to methods.
#'
#' @details When `new_data` is `NULL`, the total stored with the fitted model
#'   is returned. Otherwise the result is the sum of squared distances from
#'   each row of `new_data` to the mean of the fitted centroids, weighted by
#'   cluster size. Calling this function on an unfitted cluster specification
#'   raises an error; use `fit()` first.
#'
#' @return `sse_total()` returns a tibble with 3 columns; `.metric`,
#'   `.estimator`, and `.estimate`. `sse_total_vec()` returns the estimate
#'   alone.
#'
#' @family cluster metric
#'
#' @examples
#' kmeans_spec <- k_means(num_clusters = 5) %>%
#'   set_engine("stats")
#'
#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
#'
#' sse_total(kmeans_fit)
#'
#' sse_total_vec(kmeans_fit)
#' @export
sse_total <- function(object, ...) {
  UseMethod("sse_total")
}

# Re-bind the generic to whatever `new_cluster_metric()` returns for it, with
# `direction = "zero"`. This must run after the generic above is defined.
# NOTE(review): presumably this tags the generic as a cluster metric whose
# optimal value is zero -- confirm against `new_cluster_metric()`.
sse_total <- new_cluster_metric(
  sse_total,
  direction = "zero"
)

#' @export
#' @rdname sse_total
sse_total.cluster_spec <- function(object, ...) {
  # A bare specification has not been fitted, so there is nothing to measure;
  # always signal an error telling the user to fit first.
  not_fitted_msg <- paste(
    "This function requires a fitted model.",
    "Please use `fit()` on your cluster specification."
  )
  rlang::abort(not_fitted_msg)
}

#' @export
#' @rdname sse_total
sse_total.cluster_fit <- function(object, new_data = NULL, dist_fun = NULL,
                                  ...) {
  # `NULL` means "use the default distance function".
  distance <- if (is.null(dist_fun)) Rfast::dista else dist_fun

  estimate <- sse_total_impl(object, new_data, distance, ...)

  # Package the bare number in the three-column metric layout.
  tibble::tibble(
    .metric = "sse_total",
    .estimator = "standard",
    .estimate = estimate
  )
}

# Objects of class `workflow` reuse the `cluster_fit` method as-is.
#' @export
#' @rdname sse_total
sse_total.workflow <- sse_total.cluster_fit

#' @export
#' @rdname sse_total
sse_total_vec <- function(object,
                          new_data = NULL,
                          dist_fun = Rfast::dista,
                          ...) {
  # Same computation as `sse_total()`, but the estimate is returned bare
  # rather than wrapped in a metric tibble.
  estimate <- sse_total_impl(object, new_data, dist_fun, ...)
  estimate
}

# Total sum of squares. Without new data this is the value stored in the fit
# summary; with new data it is the sum of squared distances from each row to
# the size-weighted mean of the fitted centroids.
sse_total_impl <- function(object, new_data = NULL, dist_fun = Rfast::dista,
                           ...) {
  # Workflows carry a preprocessor; run new data through it so that distances
  # are measured on the same scale as the centroids.
  if (!is.null(new_data) && inherits(object, "workflow")) {
    new_data <- extract_post_preprocessor(object, new_data)
  }

  fit_summary <- extract_fit_summary(object)

  if (is.null(new_data)) {
    return(fit_summary$sse_total)
  }

  # Weight each centroid by its cluster size to get the overall mean.
  cluster_sizes <- fit_summary$n_members
  weighted_sums <- colSums(fit_summary$centroids * cluster_sizes)
  grand_mean <- weighted_sums / sum(cluster_sizes)

  # `dist_fun` is given the mean as a one-row matrix.
  dists_to_mean <- dist_fun(t(as.matrix(grand_mean)), new_data)
  sum(dists_to_mean^2)
}

#' Compute the ratio of the WSS to the total SSE
#'
#' @param object A fitted kmeans tidyclust model
#' @param new_data A dataset to predict on.  If `NULL`, uses trained clustering.
#' @param dist_fun A function for calculating distances to centroids.  Defaults
#'   to Euclidean distance on processed data.
#' @param ... Other arguments passed to methods.
#'
#' @details The estimate is the value of [sse_within_total_vec()] divided by
#'   the value of [sse_total_vec()], both computed with the same `object`,
#'   `new_data` and `dist_fun`. Calling this function on an unfitted cluster
#'   specification raises an error; use `fit()` first.
#'
#' @return `sse_ratio()` returns a tibble with 3 columns; `.metric`,
#'   `.estimator`, and `.estimate`. `sse_ratio_vec()` returns the estimate
#'   alone.
#'
#' @family cluster metric
#'
#' @examples
#' kmeans_spec <- k_means(num_clusters = 5) %>%
#'   set_engine("stats")
#'
#' kmeans_fit <- fit(kmeans_spec, ~., mtcars)
#'
#' sse_ratio(kmeans_fit)
#'
#' sse_ratio_vec(kmeans_fit)
#' @export
sse_ratio <- function(object, ...) {
  UseMethod("sse_ratio")
}

# Re-bind the generic to whatever `new_cluster_metric()` returns for it, with
# `direction = "zero"`. This must run after the generic above is defined.
# NOTE(review): presumably this tags the generic as a cluster metric whose
# optimal value is zero -- confirm against `new_cluster_metric()`.
sse_ratio <- new_cluster_metric(
  sse_ratio,
  direction = "zero"
)

#' @export
#' @rdname sse_ratio
sse_ratio.cluster_spec <- function(object, ...) {
  # A bare specification has not been fitted, so there is nothing to measure;
  # always signal an error telling the user to fit first.
  not_fitted_msg <- paste(
    "This function requires a fitted model.",
    "Please use `fit()` on your cluster specification."
  )
  rlang::abort(not_fitted_msg)
}

#' @export
#' @rdname sse_ratio
sse_ratio.cluster_fit <- function(object, new_data = NULL,
                                  dist_fun = NULL, ...) {
  # `NULL` means "use the default distance function".
  distance <- if (is.null(dist_fun)) Rfast::dista else dist_fun

  estimate <- sse_ratio_impl(object, new_data, distance, ...)

  # Package the bare number in the three-column metric layout.
  tibble::tibble(
    .metric = "sse_ratio",
    .estimator = "standard",
    .estimate = estimate
  )
}

# Objects of class `workflow` reuse the `cluster_fit` method as-is.
#' @export
#' @rdname sse_ratio
sse_ratio.workflow <- sse_ratio.cluster_fit

#' @export
#' @rdname sse_ratio
sse_ratio_vec <- function(object,
                          new_data = NULL,
                          dist_fun = Rfast::dista,
                          ...) {
  # Same computation as `sse_ratio()`, but the estimate is returned bare
  # rather than wrapped in a metric tibble.
  estimate <- sse_ratio_impl(object, new_data, dist_fun, ...)
  estimate
}

# Ratio of the total within-cluster SSE to the total SSE, both computed from
# the same object, data and distance function. `...` is accepted for
# signature compatibility but is not forwarded to either computation.
sse_ratio_impl <- function(object,
                           new_data = NULL,
                           dist_fun = Rfast::dista,
                           ...) {
  within_sse <- sse_within_total_vec(object, new_data, dist_fun)
  total_sse <- sse_total_vec(object, new_data, dist_fun)
  within_sse / total_sse
}

Try the tidyclust package in your browser

Any scripts or data that you put into this service are public.

tidyclust documentation built on Sept. 26, 2023, 1:08 a.m.