# These helpers assume purrr and dplyr are available (attached or imported
# via the package NAMESPACE); throughout, `burnin` is a vector of posterior
# sample indices to discard.

# Get the most likely (MAP) class label for each observation from the
# posterior samples
get_MAP_z <- function(chain, burnin) {
# rows of z_chain are observations, columns are posterior samples; take
# the modal label across the retained samples for each observation
rles <- apply(chain$z_chain[, -burnin], 1, function(X) rle(sort(X)))
map_int(rles, ~ .x$values[which.max(.x$lengths)])
}
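# A minimal sketch of get_MAP_z on a fabricated chain fragment: `toy_z` is a
# hypothetical object (not real sampler output) whose z_chain has one row per
# observation and one column per posterior sample. Assumes purrr is attached.
toy_z <- list(z_chain = rbind(c(1L, 1L, 2L, 1L, 1L),
                              c(2L, 3L, 3L, 3L, 3L)))
get_MAP_z(toy_z, burnin = 1:2)  # modal labels over the last 3 samples: 1, 3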
# Compute the symmetrized Kullback-Leibler divergence between two
# multivariate normals (the average of the two directed KL divergences)
get_KL_distance <- function(mu1, mu2, Sigma1, Sigma2) {
k <- nrow(Sigma1)
d <- mu1 - mu2
KL1 <- 0.5 * (sum(diag(solve(Sigma1) %*% Sigma2)) +
c(t(d) %*% solve(Sigma1) %*% d) - k +
log(det(Sigma1) / det(Sigma2)))
KL2 <- 0.5 * (sum(diag(solve(Sigma2) %*% Sigma1)) +
c(t(d) %*% solve(Sigma2) %*% d) - k +
log(det(Sigma2) / det(Sigma1)))
0.5 * (KL1 + KL2)
}
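# A minimal sketch of get_KL_distance on two bivariate normals; every number
# below is made up for illustration
mu_a <- c(0, 0); mu_b <- c(1, -1)
Sig_a <- diag(2)
Sig_b <- matrix(c(2, 0.5, 0.5, 1), 2, 2)
get_KL_distance(mu_a, mu_b, Sig_a, Sig_b)  # larger values = more dissimilar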
merge_classes <- function(n_groups, chain, burnin, multichain = FALSE, method = "average", ...) {
if(multichain) {
merge_classes_multichain(n_groups, chain, burnin, method = method, ...)
} else {
# posterior mean vectors: mu is a dm x nm matrix (columns index classes)
mu <- map(chain$mu_chains, ~ colMeans(.x[-burnin, ])) %>%
bind_cols() %>%
as.matrix() %>%
unname()
# posterior mean mixing weights and covariance matrices
prop <- colMeans(chain$prop_chain[-burnin, ])
sig_ests <-
map(chain$Sigma_chains, ~ apply(.x[, , -burnin], c(1, 2), mean)) %>%
simplify2array()
# MAP class label for each observation
z <- get_MAP_z(chain, burnin)
nm <- ncol(mu)
dm <- nrow(mu)
cluster_dist <- matrix(0, nm, nm)
combos <- combn(unique(z), 2)
for (i in 1:ncol(combos)) {
cluster_dist[combos[1, i], combos[2, i]] <-
cluster_dist[combos[2, i], combos[1, i]] <-
get_KL_distance(
mu1 = mu[, combos[1, i]],
mu2 = mu[, combos[2, i]],
Sigma1 = sig_ests[, , combos[1, i]],
Sigma2 = sig_ests[, , combos[2, i]]
)
}
rownames(cluster_dist) <- colnames(cluster_dist) <- 1:nm
# drop empty clusters (rows that never entered a pairwise comparison)
rmidx <- rowSums(cluster_dist) == 0
if (sum(!rmidx) < n_groups) {
n_groups <- sum(!rmidx)
warning(paste0("n_groups should be less than or equal to the number of non-empty clusters.\nYou have ",
sum(!rmidx), " non-empty clusters, so merging into that number instead."))
}
cluster_dist <- cluster_dist[!rmidx, !rmidx]
cl <- hclust(as.dist(cluster_dist), method = method, ...)
merge_idx <- cutree(cl, k = n_groups)
# allocate containers for the merged weights, means, and covariances
merge_prop <- rep(0, length(unique(merge_idx)))
merge_mu <- matrix(NA, length(unique(merge_idx)), dm)
merge_sigma <- array(NA, dim = c(dm, dm, length(unique(merge_idx))))
for (i in sort(unique(merge_idx))) {
subidx <- as.numeric(names(merge_idx[merge_idx == i]))
sub_prop <- prop[subidx]
sub_mu <- mu[, subidx]
sub_sigma <- sig_ests[,,subidx]
merge_prop[i] <- sum(sub_prop)
if (length(sub_prop) == 1) {
merge_mu[i, ] <- sub_mu
merge_sigma[,,i] <- sub_sigma
} else {
# merged mean: component means averaged with weights prop_j / merged prop
merge_mu[i, ] <- sub_mu %*% (sub_prop / merge_prop[i])
# merged covariance: the same proportion-weighted average of covariances
for (j in seq_along(sub_prop)) {
sub_sigma[, , j] <- sub_sigma[, , j] * (sub_prop[j] / merge_prop[i])
}
merge_sigma[, , i] <- apply(sub_sigma, c(1, 2), sum)
}
}
merge_prop <- merge_prop / sum(merge_prop)
outz <- z
for (i in sort(unique(merge_idx))) {
outz[z %in% as.numeric(names(merge_idx[merge_idx == i]))] <- i
}
list(
"merged_z" = outz,
"merged_mu" = merge_mu,
"merged_sigma" = merge_sigma,
"merged_prop" = merge_prop,
"clustering" = cl,
"distmat" = cluster_dist
)
}
}
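# ---- Hedged usage sketch ----
# `make_toy_chain()` fabricates an object with the element names the functions
# above expect (z_chain, mu_chains, prop_chain, Sigma_chains); it stands in
# for real sampler output and is illustration only. Assumes purrr is attached.
make_toy_chain <- function(n_obs = 30, n_samples = 40, dm = 2) {
  set.seed(1)
  mus <- list(c(-2, -2), c(2, -2), c(2, 2))  # one true mean per class
  n_classes <- length(mus)
  true_z <- rep(seq_len(n_classes), length.out = n_obs)
  # labels stay on their true class ~90% of the time, so MAP labels are stable
  z_chain <- sapply(seq_len(n_samples), function(s) {
    flip <- runif(n_obs) < 0.1
    ifelse(flip, sample(seq_len(n_classes), n_obs, replace = TRUE), true_z)
  })
  list(
    z_chain = z_chain,
    mu_chains = map(mus, ~ matrix(rep(.x, each = n_samples), n_samples, dm) +
                      rnorm(n_samples * dm, sd = 0.05)),
    prop_chain = matrix(1 / n_classes, n_samples, n_classes),
    Sigma_chains = replicate(n_classes,
                             array(diag(dm), c(dm, dm, n_samples)),
                             simplify = FALSE)
  )
}
toy <- make_toy_chain()
merged <- merge_classes(n_groups = 2, chain = toy, burnin = 1:10)
table(merged$merged_z)   # sizes of the two merged classes
plot(merged$clustering)  # dendrogram of the KL-based merge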
# Clustering the columns: correlation-based distances between conditions
compute_distances_between_conditions <- function(chain, burnin, multichain = FALSE) {
if(multichain) {
compute_distances_between_conditions_multichain(chain, burnin)
} else {
# Get parameter estimates after burn-in
prop <- colMeans(chain$prop_chain[-burnin, ])
sig_ests <-
map(chain$Sigma_chains, ~ apply(.x[, , -burnin], c(1, 2), mean)) %>%
simplify2array
# get correlations from covariances
correlations <- array(apply(sig_ests, 3, cov2cor), dim(sig_ests))
# weight correlations by mixing weights
weighted_corrs <- imap(prop, ~ .x * correlations[, , .y]) %>%
Reduce(`+`, x = .)
# Convert to distance
sqrt(1-(weighted_corrs) ^ 2)
}
}
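# Hedged sketch, reusing the fabricated `toy` chain from the merge_classes
# example above: the result is a dm x dm distance matrix between conditions
# (the columns of the data), ready for hclust()
cond_dist <- compute_distances_between_conditions(toy, burnin = 1:10)
col_clust <- hclust(as.dist(cond_dist))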
# Clustering the rows: symmetrized-KL distances between non-empty clusters
compute_distances_between_clusters <- function(chain, burnin, multichain = FALSE) {
if(multichain) {
compute_distances_between_clusters_multichain(chain, burnin)
} else {
# Get parameter estimates after burn-in
mu <- map(chain$mu_chains, ~ colMeans(.x[-burnin, ])) %>%
bind_cols() %>%
as.matrix %>%
unname
prop <- colMeans(chain$prop_chain[-burnin, ])
z_chain <- chain$z_chain[, -burnin]
sig_ests <-
map(chain$Sigma_chains, ~ apply(.x[, , -burnin], c(1, 2), mean)) %>%
simplify2array
nm <- ncol(mu)
dm <- nrow(mu)
rles <- apply(z_chain, 1, function(X) rle(sort(X)))
# keep an observation's label only when it is held in at least half of the
# retained samples (a stricter rule than the MAP label used elsewhere)
z <- unlist(map(rles, ~ .x$values[.x$lengths / sum(.x$lengths) >= 0.5]))
cluster_dist <- matrix(0, nm, nm)
combos <- combn(unique(z), 2)
for (i in 1:ncol(combos)) {
cluster_dist[combos[1, i], combos[2, i]] <-
cluster_dist[combos[2, i], combos[1, i]] <-
get_KL_distance(
mu1 = mu[, combos[1, i]],
mu2 = mu[, combos[2, i]],
Sigma1 = sig_ests[, , combos[1, i]],
Sigma2 = sig_ests[, , combos[2, i]]
)
}
rmidx <- rowSums(cluster_dist) == 0
cluster_dist <- cluster_dist[!rmidx, !rmidx]
# label each remaining cluster by the sign pattern of its mean, e.g. "(1,-1)"
labels <- apply(sign(mu[, !rmidx]), 2, function(X)
paste0("(", paste0(X, collapse = ","), ")"))
rownames(cluster_dist) <- colnames(cluster_dist) <- labels
cluster_dist
}
}
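# Hedged sketch, again with the fabricated `toy` chain: the rows are the
# non-empty clusters, labelled by their sign patterns, and the symmetrized-KL
# distances feed hclust() directly
clust_dist <- compute_distances_between_clusters(toy, burnin = 1:10)
row_clust <- hclust(as.dist(clust_dist))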
# Get row reordering (for bi-clustering heatmaps)
get_row_reordering <- function(row_clustering, chain, burnin, dat, multichain = FALSE) {
if(multichain) {
get_row_reordering_multichain(row_clustering, chain, burnin, dat)
} else {
# Get MAP class labels based on posterior samples
z <- get_MAP_z(chain, burnin)
# number of clusters in the row ordering
nm <- length(row_clustering$order)
# Get a row label based on row clustering
newz <- z
count <- 1
for (i in seq(nm)) {
idx <- z == row_clustering$order[i]
if (sum(idx) > 0) {
if (sum(idx) > 1) {
# order observations within the cluster by a correlation-based distance
reord <-
hclust(as.dist(sqrt(1 - cor(
t(dat[idx, ]), method = "pearson"
) ^ 2)))$order
newz[idx] <- (count:(count + sum(idx) - 1))[reord]
} else {
newz[idx] <- count:(count + sum(idx) - 1)
}
count <- count + sum(idx)
}
}
newz
}
}
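# Hedged sketch: `dat` is a fabricated n_obs x dm data matrix aligned with the
# rows of toy$z_chain, and `row_clust` comes from the sketch above; the result
# is one plotting position per observation for a bi-clustered heatmap
dat <- matrix(rnorm(30 * 2), 30, 2)
row_order <- get_row_reordering(row_clust, toy, burnin = 1:10, dat = dat)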
test_consistency <- function(chain,
burnin,
u = NULL,
b = 0.5,
with_zero = FALSE,
agnostic_to_sign = FALSE) {
z_chain <- chain$z_chain[, -burnin]
# sign pattern of each class, read off the first posterior draw of each mean
# chain (one row per class)
labs <- map(chain$mu_chains, ~ sign(.x[1, ])) %>%
do.call(`rbind`, .)
n <- nrow(z_chain)
D <- ncol(labs)
if (with_zero && !is.null(u)) {
warning("Testing consistency and including the null class, but a threshold u was specified. Ignoring u and setting u = D.")
}
if (!with_zero && is.null(u)) {
stop("A threshold u must be supplied when with_zero = FALSE.")
}
if (!is.null(u) && u > D) {
stop("u cannot be greater than the dimension of the data.")
}
if(agnostic_to_sign) {
labs <- abs(labs)
}
# Testing for consistency across all dimensions
if(with_zero) {
u <- D
# Find labels where all entries are equal
consistent_lab_idx <- apply(labs, 1, function(X) diff(range(X)) == 0)
# If there are no consistent labels, there are no consistent observations
if(sum(consistent_lab_idx) == 0) {
return(rep(FALSE, n))
}
consistent_lab_idx <- which(consistent_lab_idx)
} else { # Testing for replicability of a signal
consistent_pos_lab_idx <- apply(labs, 1, function(X) sum(X == 1) >= u)
consistent_neg_lab_idx <- apply(labs, 1, function(X) sum(X == -1) >= u)
# If there are no consistent labels, there are no consistent observations
if (sum(consistent_pos_lab_idx) == 0 && sum(consistent_neg_lab_idx) == 0) {
return(rep(FALSE, n))
}
consistent_lab_idx <-
sort(unique(c(
which(consistent_neg_lab_idx),
which(consistent_pos_lab_idx)
)))
}
# Find probabilities that each observation is assigned a consistent label
consistent_prob <-
colMeans(apply(z_chain, 1, function(X)
X %in% consistent_lab_idx))
if (length(b) == 1) {
return(consistent_prob > b)
} else {
# b may be a vector of probability thresholds; return one logical vector
# per threshold, named by the threshold value
replicable <- map(b, ~ consistent_prob > .x) %>%
set_names(paste(b))
return(replicable)
}
}
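# Hedged sketch with the fabricated `toy` chain: flag observations whose class
# keeps a consistent sign in at least u = 2 of the D = 2 dimensions with
# posterior probability above b = 0.5
consistent <- test_consistency(toy, burnin = 1:10, u = 2, b = 0.5)
table(consistent)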