#' Calculates the spatial distance between cells and each cluster. Distance to cluster is determined by metric.
#' @param cell.dists An n x n matrix of pairwise physical distances between all cells, scaled to [0,1]
#' @param clust A vector of length n with integer cluster labels (1, 2, ...) indicating cluster membership for each cell
#' @param metric A column-wise quantile function to determine distance to cluster (eg. colMins, colMedians, colMaxs from matrixStats).
#'   It is always called with a matrix (one row per cluster member), including for single-member clusters.
#' @param nclust Number of clusters, i.e. the number of rows in the result. Defaults to max(clust);
#'   pass a larger value to also get rows for clusters that currently have no members.
#' @return An nclust X n matrix with the physical distance to each cluster for each cell.
#'   Rows for clusters with no members are filled with Inf.
#' @export
spatial.distance <- function(cell.dists, clust, metric = matrixStats::colMins,
                             nclust = max(clust)) {
  stopifnot(length(clust) == nrow(cell.dists))
  n.cells <- ncol(cell.dists)
  # Rows start at Inf so that a cluster with no members is "infinitely far"
  # from every cell instead of depending on what `metric` returns for a
  # 0-row matrix (Inf, NA or -Inf depending on the metric).
  d <- matrix(Inf, nrow = nclust, ncol = n.cells,
              dimnames = list(NULL, colnames(cell.dists)))
  for (i in seq_len(nclust)) {
    members <- which(clust == i)
    if (length(members) > 0) {
      # drop = FALSE keeps a 1 x n matrix for singleton clusters, so `metric`
      # always receives a matrix.
      d[i, ] <- metric(cell.dists[members, , drop = FALSE])
    }
  }
  d
}
#' Clusters spatial omics data with spatial regularization
#'
#' Runs a k-means style loop on the standardized features of X. Each cell's
#' feature-space distance to a centroid is augmented by a spatial penalty: its
#' physical distance to the cluster's current members (via spatial.distance),
#' multiplied by lambda and a scale factor derived from X.
#' @param X An n x p matrix with n cells and p features. Columns are standardized with scale() before clustering.
#' @param cell.dists An n x n matrix with pairwise distances between all cells, scaled to [0,1]
#' @param nclust A scalar indicating the number of clusters
#' @param lambda The weight that spatial penalty takes
#' @param metric A column-wise quantile function to determine distance to cluster (eg. colMins, colMedians, colMaxs from matrixStats)
#' @param max.iter Maximum number of iterations before stopping clustering
#' @returns a list containing mus (a p x nclust matrix whose columns are the cluster centroids in standardized feature space),
#'   est (factor of cluster membership for each cell), num_it (number of iterations taken to converge)
#' @export
k.means.spatial <- function(X, cell.dists, nclust = 3, lambda = 1,
                            metric = matrixStats::colMins, max.iter = 100) {
  stopifnot(max(cell.dists) <= 1)
  N <- nrow(X)
  # NOTE(review): the spatial penalty scale is taken from the *unscaled* X
  # (largest column-wise mean absolute value); confirm this is intended.
  max.scale <- max(colMeans(abs(X)))
  X <- scale(X)
  # Initial centroids are nclust randomly chosen cells, taken from the
  # standardized X so they live in the same space as the data they are compared
  # against. drop = FALSE keeps a matrix when nclust == 1.
  mus <- t(X[sample(seq_len(N), nclust), , drop = FALSE])
  # Round-robin labels seed the spatial term for the first iteration.
  clust <- rep(seq_len(nclust), length.out = N)
  last.clust <- clust
  it <- 0
  assign.changed <- TRUE
  while (it < max.iter && assign.changed) {
    feat.dists <- rdist::cdist(t(mus), X)
    phys.dists <- spatial.distance(cell.dists, clust, metric = metric)
    # spatial.distance yields max(clust) rows; if the highest-numbered clusters
    # have emptied out, pad with Inf so the matrices conform (nclust x N).
    n.missing <- nclust - nrow(phys.dists)
    if (n.missing > 0) {
      phys.dists <- rbind(phys.dists,
                          matrix(Inf, nrow = n.missing, ncol = ncol(phys.dists)))
    }
    dists <- feat.dists + lambda * max.scale * phys.dists
    clust <- apply(dists, 2, which.min)
    # Column i is the mean of cluster i's cells; an empty cluster yields NaN.
    mus <- vapply(seq_len(nclust), function(i) {
      colMeans(X[clust == i, , drop = FALSE])
    }, numeric(ncol(X)))
    assign.changed <- !all(last.clust == clust)
    last.clust <- clust
    it <- it + 1
    if (it %% 10 == 0) {
      print(it)
    }
  }
  list(mus = mus, est = factor(clust), num_it = it)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.