R/kmeans.R

Defines functions simple_kmeans_db

Documented in simple_kmeans_db

#' Simple kmeans routine that works in-database
#'
#' It uses 'tidyeval' and 'dplyr' to run multiple cycles of kmean
#' calculations, expressed in dplyr formulas until an the optimal
#' centers are found.
#'
#' @param df A Local or remote data frame
#' @param ... A list of variables to be used in the kmeans algorithm
#' @param centers The number of centers. Defaults to 3.
#' @param max_repeats The maximum number of cycles to run. Defaults to 100.
#' @param initial_kmeans A local dataframe with initial centroid values. Defaults to NULL.
#' @param safeguard_file Each cycle will update a file specified in this argument
#' with the current centers.  Defaults to 'kmeans.csv'. Pass NULL if no file is
#' desired.
#' @param verbose Indicates if the progress bar will be displayed during the model's fitting.
#'
#' @details
#' Because each cycle is an independent 'dplyr' operation, or SQL operation if using a remote source,
#' the latest centroid data frame is saved to the parent environment in case the process needs to be
#' canceled and then restarted at a later point.  Passing the `current_kmeans` as the `initial_kmeans`
#' will allow the operation to pick up where it left off.
#'
#' @examples
#' library(dplyr)
#'
#' mtcars %>%
#'   simple_kmeans_db(mpg, qsec, wt) %>%
#'   glimpse()
#'
#' @export
simple_kmeans_db <- function(df,
                             ...,
                             centers = 3,
                             max_repeats = 100,
                             initial_kmeans = NULL,
                             safeguard_file = "kmeans.csv",
                             verbose = TRUE) {
  vars <- enquos(...)

  if (length(vars) > 0) {
    f_df <- select(df, !!!vars)
  } else {
    vars <- syms(colnames(df))
    f_df <- df
  }

  f_df <- filter_all(f_df, all_vars(!is.na(.)))

  if (!is.null(initial_kmeans)) {
    centroids <- initial_kmeans
  } else {
    centroids <- head(f_df, centers)
    centroids <- collect(centroids)
  }

  if (verbose) {
    pb <- progress::progress_bar$new(
      format = paste0(
        " Cycle :current of ", max_repeats, " max. [:bar] [:var][:elapsed]"
      ),
      total = max_repeats, clear = TRUE, width = 80
    )
  }

  for (iteration in 1:max_repeats) {
    prev_centroids <- centroids
    new_centroids <- calculate_centers(df, centroids, centers, vars)

    centroids_db <- select(new_centroids, !!!vars, center)
    centroids_db <- group_by(centroids_db, center)
    centroids_db <- summarise_all(centroids_db, "mean", na.rm = TRUE)

    centroids <- select(centroids_db, -center)
    centroids <- collect(centroids)

    if (!is.null(safeguard_file)) {
      sfg <- file.path(tempdir(), safeguard_file)
      write.csv(centroids, sfg, row.names = FALSE)
    }
    variance <- (
      round(
        abs(sum(prev_centroids) - sum(centroids)) / sum(prev_centroids),
        digits = 4
      ) * 100
    )
    if (verbose) pb$tick(tokens = list(var = variance))
    if (all(prev_centroids == centroids)) break()
  }
  centroids_db <- rename_all(centroids_db, ~ paste0("k_", .))
  joined <- left_join(
    rename(new_centroids, k_center = center),
    centroids_db,
    by = "k_center"
  )
  select(joined, contains("k_"), everything())
}

calculate_centers <- function(df, center_df, centers, vars) {
  center_names <- paste0("center_", 1:centers)

  fields <- length(vars)

  f_dist <- imap(
    center_df, ~ {
      map2(
        .x, .y,
        function(x, y) expr((!!x) - (!!sym(y)))
      )
    }
  )

  f_inside <- function(curr_center) {
    fi <- map(1:fields, ~ {
      f <- pluck(f_dist, .x, curr_center)
      expr((((!!f)) * ((!!f))))
    })
    reduce(fi, function(l, r) expr((!!l) + (!!r)))
  }

  km <- map(1:centers, ~ expr(sqrt(!!f_inside(.x))))
  km <- set_names(km, center_names)

  all <- map(center_names, ~ {
    comp <- map2(.x, center_names, function(x, y) {
      if (x != y) {
        expr((!!sym(x)) < (!!sym(y)))
      } else {
        expr((!!sym(x)) >= (!!sym(y)))
      }
    })
    comp <- reduce(comp, function(l, r) expr((!!l) & (!!r)))
    reduce(c(comp, .x), function(l, r) expr((!!l) ~ !!(r)))
  })
  all <- flatten(all)

  comp <- expr(case_when(!!!all))

  res <- mutate(df, !!!km)
  res <- mutate(res, center = !!comp)
  res <- filter(res, !is.na(center))
  select(res, -contains("center_"))
}

Try the modeldb package in your browser

Any scripts or data that you put into this service are public.

modeldb documentation built on Nov. 2, 2023, 5:39 p.m.