R/explore_clustering.R

Defines functions explore_clustering

Documented in explore_clustering

#' Fit and plot K-Means and DBSCAN clustering on a dataset
#'
#' Runs K-Means and DBSCAN clustering on `df` by delegating to
#' `explore_KMeans_clustering()` and `explore_DBSCAN_clustering()`, using
#' either user-supplied or default hyperparameters.
#'
#' @param df A data frame to cluster.
#' @param hyperparameter_dict A named list of hyperparameters, or `NULL`
#'   (the default) to use the built-in defaults. When supplied, it must contain
#'   an element `KMeans` with entries `centers`, `iter.max` and `algorithm`,
#'   and an element `DBSCAN` with entries `eps` and `minPts`. The defaults are
#'   `KMeans = list(centers = 2:10, iter.max = 10, algorithm = "Lloyd")` and
#'   `DBSCAN = list(eps = 1, minPts = 5)`.
#'
#' @return A named list with elements `KMeans` and `DBSCAN`, holding the
#'   output (the plots generated for each model) of
#'   `explore_KMeans_clustering()` and `explore_DBSCAN_clustering()`
#'   respectively. Both names are always present, even if a helper returns
#'   `NULL`.
#' @export
#'
#' @examples
#' library(palmerpenguins)
#' explore_clustering(penguins)
explore_clustering <- function(df, hyperparameter_dict = NULL) {
  kmeans_keys <- c("centers", "iter.max", "algorithm")
  dbscan_keys <- c("eps", "minPts")

  # Errors on the first entry of `keys` (in `keys` order) that is not among
  # the names of `hparams`.
  require_keys <- function(hparams, keys, model) {
    missing_keys <- setdiff(keys, names(hparams))
    if (length(missing_keys) > 0) {
      stop(
        "Please specify ", missing_keys[[1]], " parameter for ", model, ".",
        call. = FALSE
      )
    }
  }

  if (is.null(hyperparameter_dict)) {
    hyperparameter_dict <- list(
      KMeans = list(centers = 2:10, iter.max = 10, algorithm = "Lloyd"),
      DBSCAN = list(eps = 1, minPts = 5)
    )
  } else if (!is.list(hyperparameter_dict)) {
    stop(
      "Invalid type for hyperparameter_dict, it must be a list.",
      call. = FALSE
    )
  } else {
    if (!("KMeans" %in% names(hyperparameter_dict))) {
      stop("Please specify hyperparameters for KMeans.", call. = FALSE)
    }
    if (!("DBSCAN" %in% names(hyperparameter_dict))) {
      stop("Please specify hyperparameters for DBSCAN.", call. = FALSE)
    }
    require_keys(hyperparameter_dict[["KMeans"]], kmeans_keys, "KMeans")
    require_keys(hyperparameter_dict[["DBSCAN"]], dbscan_keys, "DBSCAN")
  }

  # `[[` matches names exactly (no `$` partial matching).
  kmeans_hparams <- hyperparameter_dict[["KMeans"]]
  dbscan_hparams <- hyperparameter_dict[["DBSCAN"]]

  # Build the result in a single list() call: assigning NULL through `$<-`
  # would delete the element, whereas list() keeps NULL-valued entries.
  # Arguments are evaluated left to right, so K-Means still runs first.
  list(
    KMeans = explore_KMeans_clustering(
      df,
      centers = kmeans_hparams[["centers"]],
      iter.max = kmeans_hparams[["iter.max"]],
      algorithm = kmeans_hparams[["algorithm"]]
    ),
    DBSCAN = explore_DBSCAN_clustering(
      df,
      eps = dbscan_hparams[["eps"]],
      minPts = dbscan_hparams[["minPts"]]
    )
  )
}
UBC-MDS/datascience.eda.R documentation built on March 24, 2021, 2:22 a.m.