# R/dpmeans.R

# Defines functions print.dpmeans dpmeans

# Documented in dpmeans print.dpmeans

#' Dirichlet Process K-Means Clustering
#'
#' This function uses the DP-means algorithm presented by Kulis & Jordan (2011)
#' to perform K-means style clustering. Rather than setting a fixed number of
#' clusters as in K-means clustering, the user specifies a parameter `tau`: an
#' observation whose squared Euclidean distance to every current cluster center
#' exceeds `tau` starts a new cluster. Higher values of `tau` lead to a smaller
#' number of clusters, and smaller values lead to a larger number of clusters.
#'
#' Each iteration reassigns every observation (possibly opening new clusters),
#' drops clusters that ended up empty, recomputes the centers as cluster means,
#' and stops once the decrease in the total sum of squared errors falls below
#' `tolerance` or `max.iter` iterations have run.
#'
#' @param data a data frame or matrix of numeric variables. Columns are
#'   converted with `as.numeric()`; missing values are not allowed.
#' @param tau the threshold on squared Euclidean distance above which a new
#'   cluster is created. Set to higher values to get fewer clusters. The
#'   default is 2.
#' @param prior.labels a custom vector (character or numeric) or factor with
#'   prior cluster labels, one per row of `data`. This can be manually created,
#'   or can be the output of another clustering algorithm. The initial centers
#'   are the means of the labelled groups. If left as NULL, all observations
#'   are initialized in one cluster centered at the column means.
#' @param max.iter maximum number of iterations. Defaults to 500.
#' @param tolerance tolerance for convergence. Defaults to 1e-6.
#' @return A list of class `c("dpmeans", "kmeans")` with elements `centers`
#'   (data frame, one row per cluster), `cluster` (factor of assignments),
#'   `k` (number of clusters), `iterations` (iterations run), `sse` (total sum
#'   of squared errors), `ss.within` (sum of squares within each cluster),
#'   `fpe` and `bic` (penalized fit statistics computed from `sse`, `k` and
#'   the number of rows).
#' @export
#' @examples
#' dpmeans(iris[, 1:4])
#'
#' @references Kulis, B.; Jordan, M. (2011) Revisiting k-means: New Algorithms via Bayesian Nonparametrics. Proceedings of the 29th International Conference on Machine Learning
#'
dpmeans <- function(data, tau = 2, prior.labels = NULL, max.iter = 500,
                    tolerance = 1e-6) {
  stopifnot(is.numeric(tau), length(tau) == 1L, !is.na(tau))

  orig.data <- as.data.frame(data)
  n <- nrow(orig.data)
  p <- ncol(orig.data)
  if (n < 1L || p < 1L) {
    stop("`data` must have at least one row and one column.", call. = FALSE)
  }

  # Build the numeric matrix column by column so a single-row input keeps its
  # n x p shape (sapply() would collapse it to a plain vector).
  data <- matrix(
    unlist(lapply(orig.data, as.numeric), use.names = FALSE),
    nrow = n,
    dimnames = list(NULL, names(orig.data))
  )
  if (anyNA(data)) {
    stop("`data` must not contain missing or non-numeric values.",
         call. = FALSE)
  }

  # Mean of every cluster 1..k; callers guarantee that no cluster is empty.
  cluster.centers <- function(labels, k) {
    centers <- matrix(NA_real_, nrow = k, ncol = p,
                      dimnames = list(NULL, colnames(data)))
    for (j in seq_len(k)) {
      centers[j, ] <- colMeans(data[labels == j, , drop = FALSE])
    }
    centers
  }

  if (is.null(prior.labels)) {
    # Documented default: everything starts in one cluster at the global mean.
    pick.clusters <- rep(1L, n)
    k <- 1L
  } else {
    if (length(prior.labels) != n || anyNA(prior.labels)) {
      stop("`prior.labels` must supply one non-missing label per row of `data`.",
           call. = FALSE)
    }
    # factor() maps arbitrary labels (character, numeric, factor with unused
    # levels) onto contiguous integer codes 1..k.
    pick.clusters <- as.integer(factor(prior.labels))
    k <- max(pick.clusters)
  }
  mu <- cluster.centers(pick.clusters, k)

  is.converged <- FALSE
  iteration <- 0
  ss.old <- Inf
  ss.curr <- Inf

  while (!is.converged && iteration < max.iter) {
    iteration <- iteration + 1

    # Assignment step: nearest center, or a new cluster seeded at the point
    # itself when even the nearest center is farther than tau.
    for (i in seq_len(n)) {
      # t(mu) is p x k, so the length-p observation recycles down each column.
      distances <- colSums((t(mu) - data[i, ])^2)
      nearest <- which.min(distances)  # first minimum on ties
      if (distances[nearest] > tau) {
        k <- k + 1L
        pick.clusters[i] <- k
        mu <- rbind(mu, data[i, ])
      } else {
        pick.clusters[i] <- nearest
      }
    }

    # Clusters that lost all their members would yield NaN means and break the
    # next distance comparison, so drop them and renumber the labels 1..k.
    occupied <- sort(unique(pick.clusters))
    if (length(occupied) < k) {
      pick.clusters <- match(pick.clusters, occupied)
      k <- length(occupied)
    }

    # Update step.
    mu <- cluster.centers(pick.clusters, k)

    ss.curr <- sum((data - mu[pick.clusters, , drop = FALSE])^2)
    ss.diff <- ss.old - ss.curr
    ss.old <- ss.curr
    if (!is.nan(ss.diff) && ss.diff < tolerance) {
      is.converged <- TRUE
    }
  }

  # Both statistics divide by (n - k - 1) and are therefore not meaningful
  # when k >= n - 1.
  fpe <- ss.curr * ((k + n + 1) / (n - k - 1))
  bic <- log(ss.curr) + (k * log(n) * n) / (n - k - 1)
  centers <- data.frame(mu)

  ss.within <- vapply(
    seq_len(k),
    function(j) {
      members <- data[pick.clusters == j, , drop = FALSE]
      sum(scale(members, center = TRUE, scale = FALSE)^2)
    },
    numeric(1)
  )

  structure(
    list("centers" = centers, "cluster" = factor(pick.clusters),
         "k" = k, "iterations" = iteration, "sse" = ss.curr,
         "ss.within" = ss.within, "fpe" = fpe, "bic" = bic),
    class = c("dpmeans", "kmeans")
  )
}


#' Print method for dpmeans objects
#'
#' Writes the cluster centers, the cluster assignments, the within-cluster
#' sums of squares and the summary statistics (total SSE, FPE, BIC) to the
#' console, with headings coloured through the crayon package.
#'
#' @param fit the model fit
#' @param ... ignored; accepted so that extra arguments forwarded by `print()`
#'   do not cause an error
#' @return `fit`, invisibly.
#' @export
#' @method print dpmeans
#'
# NOTE(review): the print() generic names its first argument `x`; `fit` is kept
# here so existing calls by name keep working.
print.dpmeans <- function(fit, ...) {
  lightred <- crayon::make_style(grDevices::rgb(1, 0.024, 0.362))
  bluegreen <- crayon::make_style(grDevices::rgb(0.000, 0.486, 0.404))

  cat(crayon::blue("Centers : "), "\n\n")
  print(fit$centers)
  cat("\n")

  cat(lightred("Clusters : "), "\n\n")
  print(fit$cluster)
  cat("\n")

  cat(crayon::magenta("Within cluster sum of squares by cluster:\n"))
  print(fit$ss.within)
  cat("\n")

  # End with a newline so the console prompt does not land on the summary line.
  cat(bluegreen("Total SSE ="), fit$sse, bluegreen("FPE ="), fit$fpe,
      bluegreen("BICc ="), fit$bic, "\n")

  # Print methods conventionally return their argument invisibly.
  invisible(fit)
}
# abnormally-distributed/cvreg documentation built on May 3, 2020, 3:45 p.m.