#' Clusternomics: Context-dependent clustering
#'
#' This function fits the context-dependent clustering model
#' to the data using Gibbs sampling. It allows the user to specify a different
#' number of clusters on the global level, as well as on the local level.
#' @param datasets List of data matrices where each matrix represents a
#' context-specific dataset. Each data matrix has the size \emph{N} times \emph{M}, where
#' \emph{N} is the number of data points and \emph{M} is the dimensionality of the data.
#' The full list of matrices has length \emph{C}. The number of data points \emph{N} must
#' be the same for all data matrices.
#' @param clusterCounts Number of clusters on the global level and in each context.
#' List with the following structure: \code{clusterCounts = list(global=global, context=context)} where
#' \code{global} is the number of global clusters, and \code{context} is the list of
#' numbers of clusters in the individual contexts (datasets) of length \emph{C} where
#' \code{context[c]} is the number of clusters in dataset \emph{c}.
#' @param dataDistributions Distribution of data in each dataset. Can be either a list of
#' length \emph{C} where \code{dataDistributions[c]} is the distribution of dataset \emph{c},
#' or a single string when all datasets have the same distribution. The only
#' distribution currently implemented is \code{'diagNormal'}, a multivariate
#' Normal distribution with a diagonal covariance matrix.
#' @param prior Prior distribution. If \code{NULL} then the prior is estimated using
#' the datasets. The \code{'diagNormal'} distribution uses the Normal-Gamma distribution
#' as a prior for each dimension.
#' @param maxIter Number of iterations of the Gibbs sampling algorithm.
#' @param burnin Number of burn-in iterations that will be discarded. If not specified,
#' the algorithm discards the first half of the \code{maxIter} samples.
#' @param lag Thinning interval: after burn-in, only every \code{lag}-th sample is kept.
#' @param verbose Print progress, by default \code{FALSE}.
#'
#' @return Returns a list containing the sequence of MCMC states and the log likelihoods of
#' the individual states.
#' \item{samples}{List of assignments sampled from the posterior,
#' each state \code{samples[[i]]} is a data frame with local and global assignments
#' for each data point.}
#' \item{logliks}{Log likelihoods during MCMC iterations.}
#' \item{DIC}{Deviance information criterion to help select the number of clusters. Lower
#' values of DIC correspond to better-fitting models.}
#'
#' @examples
#' # Example with simulated data (see vignette for details)
#' # Number of elements in each cluster
#' groupCounts <- c(50, 10, 40, 60)
#' # Centers of clusters
#' means <- c(-1.5,1.5)
#' testData <- generateTestData_2D(groupCounts, means)
#' datasets <- testData$data
#'
#' # Fit the model
#' # 1. specify number of clusters
#' clusterCounts <- list(global=10, context=c(3,3))
#' # 2. Run inference
#' # Number of iterations is just for demonstration purposes, use
#' # a larger number of iterations in practice!
#' results <- contextCluster(datasets, clusterCounts,
#' maxIter = 10, burnin = 5, lag = 1,
#' dataDistributions = 'diagNormal',
#' verbose = TRUE)
#'
#' # Extract results from the samples
#' # Final state:
#' state <- results$samples[[length(results$samples)]]
#' # 1) assignment to global clusters
#' globalAssgn <- state$Global
#' # 2) context-specific assignments: assignment in a specific dataset (context)
#' contextAssgn <- state[,"Context 1"]
#' # Assess the fit of the model with DIC
#' results$DIC
#'
#' @importFrom magrittr %>%
#' @importFrom plyr aaply
#' @importFrom plyr laply
#' @importFrom plyr llply
#' @importFrom stats runif
#' @export
contextCluster <- function(datasets, clusterCounts,
dataDistributions="diagNormal", prior=NULL,
maxIter=1000, burnin=NULL, lag=3,
verbose = FALSE){
nDataPoints <- dim(datasets[[1]])[1]
nContexts <- length(datasets)
# check if the number of contexts = length(clusterCounts$context)
if (nContexts != length(clusterCounts$context)) stop("Number of datasets is different from number of contexts.")
# check if the number of data points is the same in all contexts
if (!all(laply(datasets, nrow) == nDataPoints)) {
stop("Number of data points is different in individual contexts.")
}
if(verbose) message('Running initialisation')
# Distributions for every dataset (context)
if (length(dataDistributions) == 1) {
fullDataDistributions <- rep(dataDistributions, nContexts)
} else {
fullDataDistributions <- dataDistributions
}
# Default burnin interval
if (is.null(burnin)) {
burnin <- floor(maxIter/2)
}
# Create an initial state
if (is.null(prior)) {
prior <- empiricalBayesPrior(datasets, fullDataDistributions)
}
state <- createGibbsState(datasets, clusterCounts, fullDataDistributions)
dataStats <- createDataStats(datasets, prior, fullDataDistributions)
#******************
# Run Gibbs sampling
#samples <- vector("list", length = maxIter)
logliks <- rep(0, maxIter)
assignSamples <- vector("list", length=maxIter)
for (iter in 1:maxIter) {
if (verbose && (iter %% 10 == 0)) { message(paste("iter", iter)) }
state <- state %>%
gibbsSampleZ(dataStats, prior, clusterCounts) %>%
gibbsSampleContextK(prior, clusterCounts)
logliks[iter] <- logJoint(state, prior, clusterCounts)
#samples[[iter]] <- state
assignSamples[[iter]]$dataAssignments <- state$Z
assignSamples[[iter]]$contextAssignments <- state$contextK
}
#****************
thinned_logliks <- getMCMCSamples(logliks, burnin, lag)
thinned_samples <- getMCMCSamples(assignSamples, burnin, lag)
DIC <- computeDIC(thinned_logliks, thinned_samples, nDataPoints,
state$distributions, clusterCounts, prior, datasets)
assignments <-
thinned_samples %>%
llply(getClustersAssignments)
return(list(samples=assignments, logliks=logliks, DIC=DIC))
}
#========= Helper functions ========================
logsumexp <- function(xs){
xMax <- max(xs)
values <- (xs - xMax) %>% exp %>% sum %>% log
values + xMax
}
normalizeProbs <- function(logliks) {
normalizer <- logsumexp(logliks)
probs <- exp(logliks - normalizer)
probs
}
sampleCategorical <- function(logliks) {
normalizer <- logsumexp(logliks)
probs <- exp(logliks - normalizer)
cdf <- cumsum(probs)
min(which(runif(1) < cdf))
}
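# A minimal sketch of the log-sum-exp trick the helpers above implement
# (wrapped in if (FALSE) so it never runs when this file is sourced;
# values in the comments are approximate):
if (FALSE) {
xs <- c(-1000, -1001, -1002)
log(sum(exp(xs)))     # -Inf: exp() underflows to zero
logsumexp(xs)         # ~ -999.59, computed stably
normalizeProbs(xs)    # ~ c(0.665, 0.245, 0.090), sums to 1
sampleCategorical(xs) # draws 1, 2 or 3 with the probabilities above
}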
rep.row <- function(x, n) {
matrix(rep(x, each=n), nrow=n)
}
rep.col <- function(x, n) {
matrix(rep(x, each=n), ncol=n, byrow=TRUE)
}
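# Quick illustration of rep.row and rep.col (not executed on load):
if (FALSE) {
rep.row(c(1, 2, 3), 2) # 2 x 3 matrix, both rows equal to c(1, 2, 3)
rep.col(c(1, 2), 3)    # 2 x 3 matrix, all columns equal to c(1, 2)
}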
# Function to combine cluster assignments into global clusters
generateMapping <- function(clusterCounts) {
nGlobalClusters <- prod(clusterCounts)
nContexts <- length(clusterCounts)
mapping <- matrix(nrow=nGlobalClusters,ncol=nContexts+1)
# indices of global clusters
mapping[,nContexts+1] <- 1:nGlobalClusters
# indices of local clusters
localClusters <- list()
for(c in 1:nContexts) localClusters[[c]] <- 1:clusterCounts[c]
grid <- expand.grid(localClusters)
for(c in 1:nContexts) mapping[,c] <- grid[[c]]
mapping
}
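# Illustration of the mapping for hypothetical cluster counts (not executed
# on load): two contexts with 2 and 3 local clusters give 2 * 3 = 6 global
# clusters; each row pairs one local cluster per context with a global index.
if (FALSE) {
generateMapping(c(2, 3))
#      [,1] [,2] [,3]
# [1,]    1    1    1
# [2,]    2    1    2
# [3,]    1    2    3
# [4,]    2    2    4
# [5,]    1    3    5
# [6,]    2    3    6
}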
getGlobalClusters <- function(state, mapping) {
nDataPoints <- length(state$Z)
nContexts <- length(state$distributions)
clusterAssgn <- matrix(nrow=nDataPoints,ncol=nContexts)
for (context in 1:nContexts) {
clusterAssgn[,context] <- state$contextK[state$Z, context]
}
# Map to global clusters
aaply(clusterAssgn,1,function(row)
as.character(row) %>% paste(collapse=''))
}
saveSample <- function(state, mapping, iter, filename) {
nDataPoints <- length(state$Z)
nContexts <- length(state$distributions)
if (!file.exists(filename)) {
# Write header
contextHeaders <-
laply(1:nContexts, function(context) {
laply(1:nDataPoints, function(n) paste("Context", context, " x", n, sep='')) %>% paste(collapse=',')
}) %>% paste(collapse=',')
# Global assignments
globalHeaders1 <-
laply(1:nDataPoints, function(n) paste("Global_assgn x", n, sep='')) %>% paste(collapse=',')
# Underlying global clusters
globalHeaders2 <-
laply(1:nDataPoints, function(n) paste("Global x", n, sep='')) %>% paste(collapse=',')
# Put all headers together
header <- paste('Sample', contextHeaders, globalHeaders1, globalHeaders2, sep=',')
fileConn <- file(filename, open = "a")
writeLines(header, fileConn)
} else {
fileConn <- file(filename, open = "a")
}
# 1) context-specific clusters
clusterAssgn <- matrix(nrow=nDataPoints, ncol=nContexts)
for (context in 1:nContexts) {
clusterAssgn[,context] <- state$contextK[state$Z, context]
}
contextClusters <- clusterAssgn %>% as.vector %>% as.matrix
# 2) global clusters
assignments <- state$Z %>% as.matrix
globalClusters <- getGlobalClusters(state, mapping)
writeLines(c(iter, t(contextClusters), t(assignments), t(globalClusters)) %>% paste(collapse=','), fileConn)
close(fileConn)
}
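# computeDIC follows the standard decomposition of the deviance information
# criterion: with mean posterior deviance D_bar = -2 * E[log L] and deviance
# at a point estimate (here the posterior mode of the assignments)
# D_hat = -2 * log L(theta_hat), the effective number of parameters is
# p_D = D_bar - D_hat and DIC = D_bar + p_D = 2 * D_bar - D_hat.
# Toy arithmetic with made-up numbers: D_bar = 250 and D_hat = 230 give
# p_D = 20 and DIC = 270.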
computeDIC <- function(thinned_logliks, thinned_samples, nDataPoints, distributions, clusterCounts, prior, datasets) {
D_bar <- -2 * mean(thinned_logliks) # -2 * average posterior log likelihoods
# Point estimate of the parameters: the posterior mode, i.e. the most
# common global assignment of each data point across the thinned samples
modeDataAssignments <-
1:nDataPoints %>%
laply(function(id) {thinned_samples %>% laply(function(s) s$dataAssignments[id])}) %>%
aaply(1, function(column) which.max(tabulate(column)))
idx2d <- # helper variable to get indices into 2d array
dim(thinned_samples[[1]]$contextAssignments) %>%
lapply(seq) %>%
expand.grid() %>%
as.matrix
modeContextAssignmentsTmp <-
idx2d %>%
aaply(1, function(x) {
i <- x[1]; j <- x[2]
thinned_samples %>%
laply(function(s) s$contextAssignments[i,j]) %>%
tabulate %>%
which.max
})
# reshape results
modeContextAssignments <-
matrix(nrow=nrow(thinned_samples[[1]]$contextAssignments), ncol=ncol(thinned_samples[[1]]$contextAssignments))
for (i in 1:nrow(modeContextAssignments)) {
for (j in 1:ncol(modeContextAssignments)) {
modeContextAssignments[i,j] <- modeContextAssignmentsTmp[idx2d[,1] == i & idx2d[,2] == j]
}
}
modeState <- list()
modeState$Z <- modeDataAssignments
modeState$contextK <- modeContextAssignments
modeState$distributions <- distributions
modeState$clusterStats <-
precomputeClusterStatistics(datasets, clusterCounts, modeState$Z, modeState$contextK, modeState$distributions)
# Compute log likelihood of the resulting state
D_hat <- -2 * logJoint(modeState, prior, clusterCounts)
p_D <- D_bar - D_hat
DIC <- p_D + D_bar
DIC
}
#========= Initialization ================
precomputeClusterStatistics <- function(datasets, clusterCounts, Z, contextAssgn, fullDataDistributions) {
N <- laply(1:clusterCounts$global, function(s) sum(Z == s))
globalStats <- llply(1:length(datasets), function(context) {
initGlobalClusterStatistics(Z, clusterCounts$global, datasets[[context]], fullDataDistributions[[context]])
})
globalStats$N <- N
contextStats <- llply(1:length(datasets), function(context) {
initContextClusterStatistics(Z, contextAssgn[,context], clusterCounts$context[context],
datasets[[context]], fullDataDistributions[[context]])
})
list(globalStats = globalStats, contextStats = contextStats)
}
createDataStats <- function(datasets, prior, fullDataDistributions) {
dataStats <- precomputeDataStatistics(datasets, fullDataDistributions)
dataMarginals <- precomputeDataMarginals(dataStats, prior, fullDataDistributions)
list(Stats = dataStats, Marginals = dataMarginals)
}
precomputeDataStatistics <- function(datasets, fullDataDistributions) {
nContexts <- length(datasets)
N <- nrow(datasets[[1]])
llply(1:N, function(n) {
llply(1:nContexts, function(context) {
getDataStatistics(datasets[[context]][n,, drop=F], fullDataDistributions[[context]])
})
})
}
precomputeDataMarginals <- function(dataStatistics, prior, fullDataDistributions) {
N <- length(dataStatistics)
nContexts <- length(fullDataDistributions)
llply(1:N, function(n) {
llply(1:nContexts, function(context) {
getMarginal(dataStatistics[[n]][[context]], 1, prior[[context]], fullDataDistributions[[context]])
})
})
}
# Generate a random initial state
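# Illustrative shape of the state for hypothetical sizes: with
# clusterCounts = list(global=4, context=c(2, 3)) and two datasets of N rows,
# Z is a length-N vector of global cluster indices (initialised to all 1s)
# and contextK is a 4 x 2 matrix whose entry [s, c] gives the local cluster
# in context c that global cluster s maps to, drawn uniformly from
# 1:clusterCounts$context[c].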
createGibbsState <- function(datasets, clusterCounts, fullDataDistributions) {
nContexts <- length(datasets)
# Z ... assignment of data points to global clusters, with values 1..S
N <- nrow(datasets[[1]])
#Z <- sample(1:clusterCounts$global, N, replace=T)
Z <- rep(1, N)
# contextK ... assignment of global clusters to context clusters
# => matrix of size [globalClusterCount, nContexts]
# contextK[s,c] = index of cluster in context c for global cluster s
contextK <- matrix(nrow=clusterCounts$global, ncol=nContexts)
for (ic in 1:nContexts) {
contextK[,ic] <- sample(1:clusterCounts$context[ic], clusterCounts$global, replace=T)
#contextK[,ic] <- 1
}
# Sufficient statistics for clusters
clusterStats <- precomputeClusterStatistics(datasets, clusterCounts, Z, contextK, fullDataDistributions)
state <- list(Z = Z, contextK = contextK, distributions=fullDataDistributions,
clusterStats = clusterStats)
state
}
#========= Sampling Z =============================
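# The conditional sampled in this section combines a finite symmetric
# Dirichlet prior over global assignments with the context-specific
# predictive likelihoods; in log space,
#   log P(z_n = s | ...) = log(gamma/S + N_s^(-n)) - log(gamma + N - 1)
#                          + sum_c log p(x_nc | context cluster of s),
# where N_s^(-n) is the size of global cluster s with point n removed.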
getLogliksForGlobalCluster <- function(n, N, state, dataStats, prior) {
S <- nrow(state$contextK)
nContexts <- length(state$distributions)
N_s <- state$clusterStats$globalStats$N
N_s[state$Z[n]] <- N_s[state$Z[n]] - 1
globalCurrentCluster <- state$Z[n]
contextLogliks <- matrix(nrow=nContexts, ncol=S)
for (context in 1:nContexts) {
currentCluster <- state$contextK[globalCurrentCluster, context]
clusterStats <- state$clusterStats$contextStats[[context]]
cl <-
predictiveLoglik(dataStats$Stats[[n]][[context]],
dataStats$Marginals[[n]][[context]], 1,
currentCluster, clusterStats,
prior[[context]], state$distributions[[context]])
contextLogliks[context, ] <- cl[state$contextK[,context]]
}
logliks <- log(prior$gamma/S + N_s) - log(prior$gamma + N - 1) + colSums(contextLogliks)
logliks
}
gibbsSampleZ <- function(state, dataStats, prior, clusterCounts) {
N <- length(state$Z)
nContexts <- length(state$distributions)
#for (n in 1:N) {
for (n in (sample(1:N,N,replace = F))) { # randomized
logliks <- getLogliksForGlobalCluster(n, N, state, dataStats, prior)
probs <- normalizeProbs(logliks)
newZ <- sample(length(probs), 1, prob=probs)
oldZ <- state$Z[n]
# update parameters
if (newZ != oldZ) {
# Update global cluster counts
state$clusterStats$globalStats$N[newZ] <- state$clusterStats$globalStats$N[newZ] + 1
state$clusterStats$globalStats$N[oldZ] <- state$clusterStats$globalStats$N[oldZ] - 1
for (context in 1:nContexts) {
state$clusterStats$globalStats[[context]] <-
updateGlobalClusterStats(oldZ, newZ, state$clusterStats$globalStats[[context]], state$clusterStats$globalStats$N,
dataStats$Stats[[n]][[context]], 1, state$distributions[[context]])
oldK <- state$contextK[oldZ, context]
newK <- state$contextK[newZ, context]
if(oldK != newK) {
state$clusterStats$contextStats[[context]] <-
updateContextClusterStats(oldK, newK, state$clusterStats$contextStats[[context]],
dataStats$Stats[[n]][[context]], 1, state$distributions[[context]])
}
}
state$Z[n] <- newZ
}
}
state
}
#========= Sampling K ===================================
getLogliksForContextCluster <- function(s, S, context, K, state, prior) {
xStats <- getGlobalClusterStats(s, state$clusterStats$globalStats[[context]],
state$distributions[[context]])
clusterMarginal <- getMarginal(xStats, state$clusterStats$globalStats$N[s],
prior[[context]], state$distributions[[context]])
# Number of global clusters assigned to each local cluster
m_l <- rep(0,K)
for (l in 1:K) {
kIdxs <- which(state$contextK[,context] == l)
m_l[l] <- if (s %in% kIdxs) {length(kIdxs) - 1} else {length(kIdxs)}
}
clusterLoglik <-
predictiveLoglik(xStats, clusterMarginal, state$clusterStats$globalStats$N[s],
state$contextK[s,context], state$clusterStats$contextStats[[context]],
prior[[context]], state$distributions[[context]])
log(prior[[context]]$alpha/K + m_l) -
log(prior[[context]]$alpha + S - 1) +
clusterLoglik
}
gibbsSampleContextK <- function(state, prior, clusterCounts) {
nContexts <- length(clusterCounts$context)
S <- clusterCounts$global
for (context in 1:nContexts) {
K <- clusterCounts$context[context]
#for (s in 1:S) {
for (s in (sample(1:S, S, replace=F))) {
logliks <- getLogliksForContextCluster(s, S, context, K, state, prior)
probs <- normalizeProbs(logliks)
oldK <- state$contextK[s,context]
newK <- sample(length(probs), 1, prob=probs)
# Update parameters
if (oldK != newK) {
globalClusterStats <- getGlobalClusterStats(s, state$clusterStats$globalStats[[context]], state$distributions[[context]])
xN <- state$clusterStats$globalStats$N[s]
contextClusterStats <- state$clusterStats$contextStats[[context]]
state$clusterStats$contextStats[[context]] <-
updateContextClusterStats(oldK, newK, contextClusterStats,
globalClusterStats, xN, state$distributions[[context]])
}
state$contextK[s,context] <- newK
}
}
state
}
#========= Helper functions for different distributions =========================
# - all distributions must implement the following functions
predictiveLoglik <- function(x, xMarginals, xN, currentCluster, clusterStats,
priorParams, distribution) {
switch(distribution,
diagNormal=predictiveLogLik_diagNormal(x, xMarginals, xN, currentCluster, clusterStats, priorParams))
}
getMarginal <- function(dataStats, dataN, priorParams, distribution) {
switch(distribution,
diagNormal=marginal_diagNormal(dataStats, dataN, priorParams))
}
getDataStatistics <- function(xs, distribution) {
switch(distribution,
diagNormal=getDataStatistics_diagNormal(xs))
}
initContextClusterStatistics <- function(Z, contextK, K, contextData, distribution) {
switch(distribution,
diagNormal = initContextClusterStatistics_diagNormal(Z, contextK, K, contextData))
}
initGlobalClusterStatistics <- function(Z, S, contextData, distribution) {
switch(distribution,
diagNormal = initGlobalClusterStatistics_diagNormal(Z, S, contextData))
}
updateContextClusterStats <- function(oldK, newK, clusterStats, dataStats, xN, distribution) {
switch(distribution,
diagNormal=updateContextClusterStats_diagNormal(oldK, newK, clusterStats, dataStats, xN))
}
updateGlobalClusterStats <- function(oldZ, newZ, clusterStats, clusterN, dataStats, xN, distribution) {
switch(distribution,
diagNormal=updateGlobalClusterStats_diagNormal(oldZ, newZ, clusterStats, clusterN, dataStats, xN))
}
getGlobalClusterStats <- function(globalClusterIdxs, clusterStats, distribution) {
switch(distribution,
diagNormal=getGlobalClusterStats_diagNormal(globalClusterIdxs, clusterStats))
}
getContextClusterMarginals <- function(clusterStats, prior, distribution) {
switch(distribution,
diagNormal = getContextClusterMarginals_diagNormal(clusterStats, prior))
}
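# Hypothetical sketch of the extension point: supporting a new data
# distribution means adding a branch to each switch() above, e.g.
#   switch(distribution,
#          diagNormal=predictiveLogLik_diagNormal(...),
#          myDist=predictiveLogLik_myDist(...))
# where 'myDist' and its helper functions are placeholders, not part of
# this package.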
#========= Normal distribution with diagonal covariance ====================
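# getParams_diagNormal applies the standard per-dimension Normal-Gamma
# conjugate update (our reading of the code below, stated for reference):
#   beta_n = beta_0 + N
#   a_n    = a_0 + N/2
#   b_n    = b_0 + (1/2) * sum_i (x_i - xbar)^2
#            + beta_0 * N * (xbar - m_0)^2 / (2 * (beta_0 + N)),
# expanded in terms of the sufficient statistics sumX and sumX2, with the
# m_0 = 0 branch handled as a simplified special case.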
getParams_diagNormal <- function(sumX, sumX2, N, priorParams) {
D <- length(priorParams$theta$m)
S <- nrow(sumX)
sXX <- sumX*sumX
beta <- priorParams$theta$beta + N
a <- priorParams$theta$a + N/2
invN <- 1/N
invN[is.infinite(invN)] <- 0
if (all(priorParams$theta$m == 0)) {
dataContributions <- 0.5 * sumX2 - sXX * (1/2*invN) +
(priorParams$theta$beta * invN * sXX ) * 1/(2*(priorParams$theta$beta + N))
b <- sweep(dataContributions, 2, priorParams$theta$b, '+')
} else {
dataContributions <- 0.5 * sumX2 - rep.col(1/(2*N), D) * sXX +
(priorParams$theta$beta * (
rep.col(1/N, D) * sXX - 2*rep.row(priorParams$theta$m,S) * sumX +
rep.col(N, D) * rep.row(priorParams$theta$m^2, S)))/rep.col(2*(priorParams$theta$beta + N), D)
b <- sweep(dataContributions, 2, priorParams$theta$b, '+')
}
list(a = a, b = b, beta = beta)
}
predictiveLogLik_diagNormal <- function(xStats, xMarginals, xN, currentCluster,
clusterStats, priorParams) {
# Remove x from its current cluster
clusterStats$xN[currentCluster] <- clusterStats$xN[currentCluster] - xN
clusterStats$sumX[currentCluster,] <- clusterStats$sumX[currentCluster,] - xStats$sumX
clusterStats$sumX2[currentCluster,] <- clusterStats$sumX2[currentCluster,] - xStats$sumX2
paramsOther <- getParams_diagNormal(clusterStats$sumX, clusterStats$sumX2,
clusterStats$xN, priorParams)
paramsAll <- getParams_diagNormal(sweep(clusterStats$sumX, 2, xStats$sumX, '+'),
sweep(clusterStats$sumX2, 2, xStats$sumX2, '+'),
clusterStats$xN + xN, priorParams)
D <- length(priorParams$theta$m)
logliks <-
D * (lgamma(paramsAll$a) - lgamma(paramsOther$a)) +
paramsOther$a * rowSums(log(paramsOther$b)) -
paramsAll$a * rowSums(log(paramsAll$b)) +
D/2 * (log(paramsOther$beta) - log(paramsAll$beta)) -
(xN * D)/2 * log(2*pi)
logliks[clusterStats$xN == 0] <- xMarginals
logliks
}
marginal_diagNormal <- function(xStats, xN, priorParams) {
params <- getParams_diagNormal(xStats$sumX %>% as.matrix, xStats$sumX2 %>% as.matrix, xN, priorParams)
D <- length(priorParams$theta$m)
loglik <-
D* (lgamma(params$a) - lgamma(priorParams$theta$a)) +
D/2 * (log(priorParams$theta$beta) - log(params$beta)) -
(xN * D)/2 * log(2*pi) +
sum( priorParams$theta$a * log(priorParams$theta$b) - params$a * log(params$b))
loglik[xN == 0] <- 0
loglik
}
getDataStatistics_diagNormal <- function(xs) {
list(sumX = t(as.matrix(colSums(xs))), sumX2 = t(as.matrix(colSums(xs*xs))))
}
initGlobalClusterStatistics_diagNormal <- function(Z, S, contextData) {
sumX <- laply(1:S, function(s) contextData[Z == s,,drop=F] %>% colSums, .drop=F)
sumX2 <- laply(1:S, function(s) {
xs <- contextData[Z == s,,drop=F]
(xs * xs) %>% colSums }, .drop=F)
list(sumX = sumX, sumX2 = sumX2)
}
# Compute statistics for local clusters in a specific context
# inefficient but it's called only during initialisation
initContextClusterStatistics_diagNormal <- function(Z, contextK, K, contextData) {
sumX <- laply(1:K, function(k) {
globalIdxs <- which(contextK == k)
contextData[Z %in% globalIdxs,,drop=F] %>% colSums
}, .drop = F)
sumX2 <- laply(1:K, function(k) {
globalIdxs <- which(contextK == k)
xs <- contextData[Z %in% globalIdxs,,drop=F]
(xs * xs) %>% colSums
}, .drop=F)
xN <- laply(1:K, function(k) {
globalIdxs <- which(contextK == k)
sum(Z %in% globalIdxs)
})
list(sumX = sumX, sumX2 = sumX2, xN = xN)
}
updateContextClusterStats_diagNormal <- function(oldK, newK, clusterStats, dataStats, xN) {
# 1) Add x to new cluster
clusterStats$xN[newK] <- clusterStats$xN[newK] + xN
clusterStats$sumX[newK,] <- clusterStats$sumX[newK,] + dataStats$sumX
clusterStats$sumX2[newK,] <- clusterStats$sumX2[newK,] + dataStats$sumX2
# 2) remove x from the previous cluster
clusterStats$xN[oldK] <- clusterStats$xN[oldK] - xN
if (clusterStats$xN[oldK] == 0) {
clusterStats$sumX[oldK,] <- 0
clusterStats$sumX2[oldK,] <- 0
} else {
clusterStats$sumX[oldK,] <- clusterStats$sumX[oldK,] - dataStats$sumX
clusterStats$sumX2[oldK,] <- clusterStats$sumX2[oldK,] - dataStats$sumX2
}
clusterStats
}
updateGlobalClusterStats_diagNormal <- function(oldZ, newZ, clusterStats, clusterN, dataStats, xN) {
clusterStats$sumX[newZ,] <- clusterStats$sumX[newZ,] + dataStats$sumX
clusterStats$sumX2[newZ,] <- clusterStats$sumX2[newZ,] + dataStats$sumX2
if (clusterN[oldZ] == 0) {
clusterStats$sumX[oldZ,] <- 0
clusterStats$sumX2[oldZ,] <- 0
} else {
clusterStats$sumX[oldZ,] <- clusterStats$sumX[oldZ,] - dataStats$sumX
clusterStats$sumX2[oldZ,] <- clusterStats$sumX2[oldZ,] - dataStats$sumX2
}
clusterStats
}
getGlobalClusterStats_diagNormal <- function(s, clusterStats) {
sumX <- clusterStats$sumX[s,,drop=F]
sumX2 <- clusterStats$sumX2[s,,drop=F]
list(sumX = sumX, sumX2 = sumX2)
}
#========= Joint ==============
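# logZ and logK are log marginals of symmetric Dirichlet-multinomial models:
# for concentration gamma split over S global clusters,
#   log P(Z) = log Gamma(gamma) - S * log Gamma(gamma/S)
#            + sum_s log Gamma(gamma/S + n_s) - log Gamma(gamma + N),
# and analogously for the per-context cluster assignments with alpha, K,
# m_l and M = number of global clusters.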
logZ <- function(state, S, N, prior) {
n_s <- state$clusterStats$globalStats$N
lgamma(prior$gamma) - (S * lgamma(prior$gamma/S)) +
sum(lgamma(prior$gamma/S + n_s)) - lgamma(prior$gamma + N)
}
logK <- function(state, context, K, prior) {
m_l <- rep(0,K)
for (l in 1:K) {
m_l[l] <- sum(state$contextK[,context]==l)
}
M <- state$contextK %>% nrow
lgamma(prior[[context]]$alpha) - K * lgamma(prior[[context]]$alpha/K) +
sum(lgamma(prior[[context]]$alpha/K + m_l)) - lgamma(prior[[context]]$alpha + M)
}
logJoint <- function(state, prior, clusterCounts) {
S <- clusterCounts$global
N <- state$Z %>% length
nContexts <- length(clusterCounts$context)
# Data log likelihood
dataLoglik <- 0
loglikK <- 0
for (context in 1:nContexts) {
dataLoglik <-
dataLoglik + getContextClusterMarginals(state$clusterStats$contextStats[[context]], prior[[context]],
state$distributions[[context]]) %>%
sum
loglikK <- loglikK + logK(state, context, clusterCounts$context[context], prior)
}
logZ(state, S, N, prior) + loglikK +
dataLoglik
}
getContextClusterMarginals_diagNormal <- function(clusterStats, priorParams) {
params <- getParams_diagNormal(clusterStats$sumX, clusterStats$sumX2,
clusterStats$xN, priorParams)
D <- length(priorParams$theta$m)
loglik <-
D * (lgamma(params$a) - lgamma(priorParams$theta$a)) +
D/2 * (log(priorParams$theta$beta) - log(params$beta)) -
(clusterStats$xN * D)/2 * log(2*pi) +
rowSums( priorParams$theta$a * log(priorParams$theta$b) - params$a * log(params$b))
loglik[clusterStats$xN == 0] <- 0
loglik
}