#' Correlation Bi-community Extraction method
#'
#' Given two groups of variables, find correlation bi-communities between them. In such a community, the nodes from the first group are highly correlated to the community nodes from the second group, and vice versa.
#'
#' \code{cbce} applies an update function (mapping subsets of variables to subsets of variables) iteratively until a fixed point is found. These fixed points are reported as communities. The update function uses a multiple-testing procedure to find the variables correlated to the current set of variables.
#'
#' The update starts from a single node (via the initialization step) and is repeated until either a fixed point is found or some set repeats. Each such run is called an extraction. Since extractions start only from singleton nodes, there are \code{ncol(X)+ncol(Y)} possible extractions.
#'
#' @param X,Y Numeric matrices. These represent the two groups of variables.
#' @param alpha \eqn{\in (0,1)}. Controls the type-I error per update. This is the type-I error to use for the Benjamini-Hochberg (BH) procedure.
#' @param exhaustive Boolean. If FALSE, new extractions are not started from nodes within previously found communities. Otherwise, an attempt is made to start from every node.
#' @param OL_thres \eqn{\in (0,1)}. Threshold used to conclude significant overlap. Two sets have significant overlap if their Jaccard similarity is >= OL_thres.
#' @param OL_tol If more than OL_tol found communities have overlap above OL_thres, stop the method.
#' @param Dud_tol If more than Dud_tol initializations end in a dud, stop the method.
#' @param time_limit Stop the method after time_limit seconds.
#' @param updateMethod Use the 1-step (1) or the 2-step (2) update.
#' @param init_method The initialization procedure to use. Must be one of "conservative-BH", "non-conservative-BH", "BH-0.5", "no-multiple-testing".
#' @param inv.length Logical. Use the inverse community length as the score when selecting the smallest community from among communities with significant overlap.
#' @param start_nodes The initial set of nodes to start from. If NULL, start from all nodes (nodes within found communities may still be excluded if exhaustive = FALSE).
#' @param parallel Use parallel processing.
#' @param calc_full_cor Calculate the full \code{ncol(X)} by \code{ncol(Y)} correlation matrix. This makes the computation faster but requires more memory.
#' @param mask_extracted Boolean. If TRUE, then once a stable community is found, its nodes are never used in future extractions.
#' @param backend The engine to use for p-value computation. Currently must be one of "normal", "normal_two_sided", "chisq", "chisq_fast", "indepChiSq", "indepNormal".
#' @param diagnostics A function that is called whenever internal events happen. It can collect useful metadata, which is added to the final results.
#' @return A list with details of the extraction and a list of indices representing the communities. See the example below (finding communities in noise). Note that the variables from the X and Y sets share a single numbering: the nodes in X are denoted by \code{1:dx} and the nodes in Y by the numbers following dx (i.e. \code{dx + (1:dy)}).
#' @export
#' @examples
#' \dontrun{
#' n <- 100
#' dx <- 50
#' dy <- 70
#'
#' X <- matrix(rnorm(n*dx), ncol=dx)
#' Y <- matrix(rnorm(n*dy), ncol=dy)
#' res <- cbce(X, Y)
#' finalComms <- lapply(res$extract_res[res$finalIndxs], function(r) r$StableComm)
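#'
#' # A minimal follow-up sketch: map a community's global indices back to
#' # columns of X and Y. (Assumes at least one community was found; with
#' # pure noise, finalComms will usually be empty.)
#' if (length(finalComms) > 0) {
#'   comm <- finalComms[[1]]
#'   x_cols <- comm[comm <= dx]       # columns of X
#'   y_cols <- comm[comm > dx] - dx   # columns of Y
#' }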
#'}
cbce <- function(X, Y,
alpha = 0.05,
OL_thres = 0.9,
exhaustive = FALSE,
OL_tol = Inf,
Dud_tol = Inf,
time_limit = 18000,
updateMethod = 1,
init_method = "conservative-BH",
inv.length = TRUE,
start_nodes = NULL,
parallel = FALSE,
calc_full_cor=FALSE,
backend = "chisq",
mask_extracted = FALSE,
diagnostics=diagnostics1
) {
#-----------------------------------------------------------------------------
# Setup
if(mask_extracted && exhaustive) {
stop("mask_extracted and exhaustive are incompatible.")
}
start_second <- proc.time()[3]
cat("#-------------------\n")
cat("Setting up backend\n")
#Initialize the backend method specified by \code{backend}. This involves doing some precomputation. The variable bk is an S3 object which stores the precomputation.
bk <- switch(backend,
chisq = backend.chisq(X, Y, parallel, calc_full_cor),
chisq_fast = backend.chisq(X, Y, parallel, calc_full_cor, fast_approx=TRUE),
normal = backend.normal(X, Y, calc_full_cor),
normal_two_sided = backend.normal_two_sided(X, Y, calc_full_cor),
indepChiSq = backend.indepChiSq(X, Y, calc_full_cor),
indepNormal = backend.indepNormal(X, Y, calc_full_cor),
stop(paste("Unknown backend:", backend))
)
dx <- ncol(X)
dy <- ncol(Y)
#Global numbering scheme for the X and Y nodes.
#Remember how X indices were numbered from 1:dx and Y from (dx+1):(dx+dy)?
Xindx <- 1:dx
Yindx <- (1:dy) + dx
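#E.g., with dx = 50 and dy = 70, Xindx is 1:50 and Yindx is 51:120.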
cat("Beginning method.\n\n")
# Getting node orders. We know which nodes to start with, but what is a good order to start in?
if (exhaustive) {
# Random ordering
extractord <- sample(c(Xindx, Yindx))
} else {
# Sort the nodes in decreasing order of their correlation sums to the opposite side.
# This can be calculated efficiently. Note that bk$X and bk$Y are scaled versions of X and Y.
Ysum <- bk$Y %*% rep(1,dy) / dy
Xsum <- bk$X %*% rep(1,dx) / dx
cor_X_to_Ysums <- abs(as.vector(t(Ysum) %*% bk$X))
cor_Y_to_Xsums <- abs(as.vector(t(Xsum) %*% bk$Y))
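# (Assuming the columns of bk$X and bk$Y are scaled to unit norm,
# cor_X_to_Ysums[i] equals |mean over j of cor(X[, i], Y[, j])|: the
# absolute average correlation of the i-th X variable to the Y side.)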
extractord <- c(Xindx, Yindx)[order(c(cor_X_to_Ysums, cor_Y_to_Xsums),
decreasing = TRUE)]
}
if (!is.null(start_nodes))
extractord <- extractord[extractord %in% start_nodes]
# Initializing control variables
stop_extracting <- FALSE
# The Dud and Overlap counts
OL_count <- 0
Dud_count <- 0
# The collection of nodes already clustered.
clustered <- integer(0)
comms <- list()
#-------------------------------------------------------------------------------
# The core functions
# Initialize the extraction.
# Use the backend to do most of the work, but correct for global indices
# @param indx The (global) index of the node (variable) to initialize from.
# @return list(x=integer-vector, y=integer-vector): The initialized x, y sets.
initialize <- function(indx) {
if (indx <= dx) {
#indx on the X side, so only need to correct the init-step.
B01 <- init(bk, indx, alpha, init_method) + dx
B02 <- bh_reject(pvals(bk, B01), alpha)
return(list(x = B02, y = B01))
} else {
#indx on the Y side, so only need to correct the half update following the init step.
B01 <- init(bk, indx, alpha, init_method)
B02 <- bh_reject(pvals(bk, B01), alpha) + dx
return(list(x = B01, y = B02))
}
}
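# (Reading of the +dx corrections above: when indx is on the X side,
# init() appears to return Y-side positions in 1:dy, so dx is added to
# lift them to the global numbering; the subsequent bh_reject() over
# pvals(bk, B01) returns X positions in 1:dx, which are already global.)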
#Split a single set into X, Y each with the global numbering.
split <- function(set){
set_x <- Filter(function(x) x <= dx, set)
set_y <- Filter(function(x) x > dx, set)
return(list(x=set_x, y=set_y))
}
# Store the X and Y sets of B in a single vector.
# This is an inverse to split.
# Note that B$y already uses global numbering, so no information is lost.
merge <- function(B) {
c(B$x, B$y)
}
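# Illustration (with dx = 50, say): split(c(3, 7, 62)) gives
# list(x = c(3, 7), y = 62), and merge() on that list recovers
# c(3, 7, 62); Y entries keep their global (dx-shifted) numbers.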
jac <- function(B1, B2) {
jaccard(merge(B1), merge(B2))
}
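# Note: jaccard() acts as a distance here; jac(B1, B2) == 0 means the
# two communities are identical (this is the fixed-point test below).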
conglomerate <- function(chain) {
#Construct the union of the sets in chain
split(unique(unlist(chain)))
}
update <- function(B0, startX=TRUE) {
# Do the update starting from B0, using either the 1-step or the 2-step update.
# @param B0 list(x, y): x is a subset of X nodes, y is a subset of Y nodes (using the global index).
# @param startX If doing the 2-step update, start at the X side; if FALSE,
# start at the Y side. Has no effect on the 1-step update.
# @return list(x, y): The X and Y subsets corresponding to the updated set (again using global numbering).
if (updateMethod == 2) {
if (startX) {
B1y <- bh_reject(pvals(bk, B0$x), alpha) + dx
B1x <- bh_reject(pvals(bk, B1y), alpha)
} else {
B1x <- bh_reject(pvals(bk, B0$y), alpha)
B1y <- bh_reject(pvals(bk, B1x), alpha) + dx
}
list(x=B1x, y=B1y)
} else {
px <- pvals(bk, B0$y) # size dx vector
py <- pvals(bk, B0$x) # size dy vector
#Note that the positions returned by bh_reject are already the global
#numbers.
return(split(bh_reject(c(px, py), alpha)))
}
}
extract <- function(indx) {
# Start the extraction from indx
#
# First do the initialization at indx, then repeatedly apply update until a fixed point is found or one of the sets repeats (forming a non-trivial cycle).
# In case of a cycle, if some pair of consecutive sets in the cycle is sufficiently different (Jaccard distance > 1 - OL_thres), we say a break is found and the extraction terminates (without any result). Otherwise the union of the cycle's sets is formed and the extraction continues from this conglomerate set.
#
#@return The return value is the extract_res field of the final method results.
# Getting start time
current_time <- proc.time()[3] - start_second
# Doing checks
if (current_time > time_limit || Dud_count > Dud_tol || OL_count > OL_tol) {
stop_extracting <<- TRUE
}
if (stop_extracting) return(list(report = "stop_extracting"))
# Checking for stops
if (indx %in% clustered && !exhaustive)
return(list(report = "indx_clustered"))
# Print some output
cat("\n#-----------------------------------------\n\n")
cat("extraction", which(extractord == indx), "of", length(extractord), "\n\n")
cat("indx", indx, "has not been clustered\n")
cat("OL_count = ", OL_count, "\n")
cat("Dud_count = ", Dud_count, "\n")
cat(paste0(sum(clustered <= dx), " X vertices in communities.\n"))
cat(paste0(sum(clustered > dx), " Y vertices in communities.\n"))
#Is indx an X or a Y node? This is important for the two-step update to
# tell \code{update()} which side to start the update from.
startsAtX <- (indx <= dx)
B0 <- initialize(indx)
# Check for dud
if (min(length(B0$x),length(B0$y)) <= 1) {
Dud_count <<- Dud_count + 1
cat("--initialize was a dud\n")
return(list("indx" = indx, "report" = "dud"))
}
cat(paste0("--obtained initial set of size (", length(B0$x), ", ", length(B0$y), ")\n"))
# Initializing extraction loop
chain <- list(B0) #The collection of all sets visited during the extraction
B1 <- split(integer(0)) #Initial value for B1
did_it_cycle <- FALSE
break_or_collapsed <- FALSE
itCount <- 0
diagnostics("ExtracionLoopBegins")
# Extraction loop
repeat {
itCount <- itCount + 1
cat(paste0("--Performing update ", itCount, "...\n"))
B1 <- update(B0, startX = startsAtX)
if (length(B1$y) * length(B1$x) == 0) {
break_or_collapsed <- TRUE
diagnostics("Collapsed")
cat("----collapse on at least one side\n")
break
}
jaccards <- rlist::list.mapv(chain, jac(., B1))
diagnostics("AfterUpdate")
#End loop if fixed point found.
if (jac(B1, B0) == 0) break
# Checking for cycles (4.4.1 in CCME paper)
if(sum(jaccards == 0) > 0) {
diagnostics("FoundCycle")
did_it_cycle <- TRUE
cat("----cycle found")
Start <- max(which(jaccards == 0))
cycle_chain <- chain[Start:length(chain)]
# Checking for cycle break (4.4.1a)
seq_pair_jaccards <- rep(0, length(cycle_chain) - 1)
for (j in seq_along(seq_pair_jaccards)) {
seq_pair_jaccards[j] <- jac(chain[[Start + j - 1]], chain[[Start + j]])
}
if (sum(seq_pair_jaccards > 1 - OL_thres) > 0) {# then break needed
cat("------break found\n")
diagnostics("FoundBreak")
break_or_collapsed <- TRUE
break
}
# Create conglomerate set (and check, 4.4.1b)
B_J <- conglomerate(cycle_chain)
B_J_check <- rlist::list.mapv(chain, jaccard(B_J, .))
if (sum(B_J_check == 0) > 0) {
cat("------old cycle\n")
break_or_collapsed <- TRUE
break
} else {
cat("------new cycle\n")
B1 <- B_J
}
}
cat(paste0("----updated to set of size (", length(B1$x), ", ", length(B1$y), ")\n"))
cat(paste0("----jaccard to previous = ", round(jaccards[length(jaccards)], 3), "\n"))
chain <- rlist::list.append(chain, B1)
B0 <- B1
}
diagnostic_info <- diagnostics("EndOfExtract")
# Storing the stable community and collecting update info
if (break_or_collapsed) {
Dud_count <<- Dud_count + 1
return(c(list("indx" = indx,
"StableComm" = integer(0),
"itCount" = itCount,
"did_it_cycle" = did_it_cycle,
"report" = "break_or_collapse"), diagnostic_info))
}
stableComm <- merge(B1)
if (mask_extracted) {
mask(bk, B1$x, B1$y)
}
# Checking overlap with previously found communities (done before
# appending B1, so that B1 is not compared with itself)
if (length(comms) > 0) {
OL_check <- rlist::list.mapv(comms, jac(., B1))
if (sum(OL_check < 1 - OL_thres) > 0) {
OL_count <<- OL_count + 1
}
}
clustered <<- union(clustered, stableComm)
comms <<- rlist::list.append(comms, B1)
return(c(list("indx" = indx,
"StableComm" = stableComm,
"itCount" = itCount, "did_it_cycle" = did_it_cycle,
"report" = "complete_extraction"), diagnostic_info))
}
#-------------------------------------------------------------------------------
# Extractions
# Extracting
extract_res <- lapply(extractord, extract)
stopped_at <- rlist::list.first(extract_res, report == "stop_extracting")
extract_res <- extract_res[order(extractord)]
#-----------------------------------------------------------------------------
# Clean-up and return -------------------------------------------------------
cat("Cleaning up.\n")
# Getting final sets and making report
final.sets <- lapply(extract_res, function(R) R$StableComm)
endtime <- proc.time()[3]
report <- list(OL_count = OL_count,
Dud_count = Dud_count,
Total_time = endtime - start_second,
Extract_ord = extractord,
Stopped_at = stopped_at)
# Removing blanks and trivial sets
nonNullIndxs <- rlist::list.which(final.sets, length(.) > 0)
if (length(nonNullIndxs) == 0) {
returnList <- list("communities" = list("X_sets" = NULL,
"Y_sets" = NULL),
"background" = 1:(dx + dy),
"extract_res" = extract_res,
"finalIndxs" = integer(0),
"final.sets" = final.sets,
"report" = report)
return(returnList)
}
nonNullComms <- final.sets[nonNullIndxs]
# Keep a single representative for communities with overlap > OL_thres
OLfilt <- filter_overlap(nonNullComms, tau = OL_thres, inv.length = inv.length)
#finalComms <- OLfilt$final_comms
finalIndxs <- nonNullIndxs[OLfilt$kept_comms]
returnList <- list("extract_res" = extract_res,
"finalIndxs" = finalIndxs,
"report" = report)
cat("#-------------------\n\n\n")
return(returnList)
}