# R/sc_cluster_stability.R
#
# Defines functions: cluster_res_stability_heatmap, sc_cluster_res_stability, cluster_stability_barplot, sc_cluster_stability
#
# Documented in: cluster_res_stability_heatmap, cluster_stability_barplot, sc_cluster_res_stability, sc_cluster_stability

#' sc_cluster_stability
#'
#' Compute the stability of your single cell clusters. This is done by sampling
#' a number of cells, reclustering, and computing the intersection of the
#' sampled clusters and the original clusters.
#'
#' For every original cluster the score of one iteration is the largest
#' fraction, over all resampled clusters, of a resampled cluster's cells that
#' belong to the original cluster. Scores are averaged over all iterations.
#' @param seurat_obj A Seurat object where FindClusters has been called.
#' @param n_cells The number of cells to sample (without replacement). Must not
#' exceed the number of cells in \code{seurat_obj}.
#' @param n_bootstrap The number of iterations to sample. Default to 10
#' @param cluster_res The resolution of clustering that will be computed for
#' sampled data. If not provided (\code{FALSE} or \code{NULL}), the resolution
#' used in FindClusters will be used.
#' @param reduction The dimensionality reduction for FindNeighbors. Default to pca
#' @param verbose Show more information. Default to FALSE.
#' @return An unnamed numeric vector with one mean stability score per level of
#' \code{seurat_obj$seurat_clusters}, in level order.
#'
#' @import Seurat
#' @import ggplot2
#' @import tidyverse
#' @keywords clustering
#' @export
#' @examples
#' \dontrun{
#' sc_cluster_stability(my_seurat, 10000)
#' }

sc_cluster_stability <- function(seurat_obj, n_cells, n_bootstrap=10, cluster_res=FALSE, reduction="pca", verbose=FALSE){
  # set cluster resolution if not provided by user; identical() is used so that
  # a genuine numeric resolution of 0 is not mistaken for "not provided"
  if (is.null(cluster_res) || identical(cluster_res, FALSE)) {
    cluster_res <- seurat_obj@commands$FindClusters$resolution
    if (is.null(cluster_res)) {
      stop(
        "cluster_res was not provided and no FindClusters resolution is stored in seurat_obj",
        call. = FALSE
      )
    }
  }

  all_cells <- colnames(seurat_obj)
  if (length(n_bootstrap) != 1 || is.na(n_bootstrap) || n_bootstrap < 1) {
    stop("n_bootstrap must be a single number >= 1", call. = FALSE)
  }
  if (length(n_cells) != 1 || is.na(n_cells) || n_cells < 1 || n_cells > length(all_cells)) {
    stop(
      "n_cells must be a single number between 1 and the number of cells in seurat_obj",
      call. = FALSE
    )
  }

  # startup message
  if (verbose) {
    cat(
      paste(
        "Performing cluster stability analysis with:",
        "\nresolution =", cluster_res,
        "\nreduction =", reduction,
        "\nn_bootstrap = ", n_bootstrap,
        "\nn_cells = ", n_cells, "\n"))
  }

  # Work with cluster labels as character so non-numeric cluster names are
  # handled; factor levels (including unused ones) define the output order.
  orig_labels <- seurat_obj$seurat_clusters
  clusters <- if (is.factor(orig_labels)) {
    levels(orig_labels)
  } else {
    sort(unique(as.character(orig_labels)))
  }
  orig_labels <- as.character(orig_labels)
  # cells of each original cluster, computed once instead of once per iteration
  cells_by_cluster <- lapply(clusters, function(cl) all_cells[which(orig_labels == cl)])

  bootstrap_list <- vector("list", n_bootstrap)
  for (n in seq_len(n_bootstrap)) {
    if (verbose) {cat(paste("iteration:", n, "\n"))}
    seurat_sample <- subset(seurat_obj, cells = sample(all_cells, size = n_cells, replace = FALSE))
    seurat_sample <- FindNeighbors(
      seurat_sample,
      reduction = reduction,
      dims = seq_len(dim(seurat_sample[[reduction]])[2]),
      nn.eps = 0.5,
      verbose = verbose
    )
    seurat_sample <- FindClusters(seurat_sample, resolution = cluster_res, n.start = 10, verbose = verbose)

    # cells of each resampled cluster, computed once per iteration
    sample_cells <- split(colnames(seurat_sample), seurat_sample$seurat_clusters)

    bootstrap_list[[n]] <- vapply(cells_by_cluster, function(orig_cells) {
      # fraction of each resampled cluster that falls inside this original cluster
      overlap <- vapply(
        sample_cells,
        function(cells) length(intersect(cells, orig_cells)) / length(cells),
        numeric(1)
      )
      max(overlap)
    }, numeric(1))
  }

  # cbind always yields a matrix, so this also works with a single cluster
  rowMeans(do.call(cbind, bootstrap_list))
}


#' cluster_stability_barplot
#'
#' Plot cluster stability scores as a barplot. Returns a ggplot object. Bars are
#' ordered by decreasing score and dashed reference lines are drawn at 0.60 and
#' 0.85.
#' @param cluster_stabilities The output of the sc_cluster_stability function
#' @param stability Optional vector, one value per cluster, used for the bar
#' fill. If \code{NULL} (the default) each cluster is labelled "unstable"
#' (score < 0.60), "stable" (score > 0.85) or "intermediate" (otherwise).
#' @param clusters Labels for the clusters, one per score. Defaults to
#' \code{1, 2, ..., length(cluster_stabilities)}.
#' @return A ggplot object.
#'
#' @import tidyverse
#' @keywords plotting
#' @export
#' @examples
#' \dontrun{
#' cluster_stability_barplot(sc_cluster_stability(my_seurat, 10000))
#' }
cluster_stability_barplot <- function(cluster_stabilities, stability=NULL, clusters=seq_along(cluster_stabilities)){
  # Derive the fill categories from the same thresholds that are drawn as
  # reference lines when the caller does not supply them.
  if (is.null(stability)) {
    stability <- factor(
      ifelse(
        cluster_stabilities < 0.60, "unstable",
        ifelse(cluster_stabilities > 0.85, "stable", "intermediate")
      ),
      levels = c("unstable", "intermediate", "stable")
    )
  }

  stability_df <- data.frame(
    clusters = clusters,
    scaled_intersection = cluster_stabilities,
    stability = stability
  )

  p <- ggplot(
    data = stability_df,
    aes(x = reorder(clusters, -scaled_intersection), y = scaled_intersection, fill = stability)
  ) +
    geom_bar(stat = "identity") +
    scale_y_continuous(expand = c(0, 0), limits = c(0, 1)) +
    geom_hline(yintercept = 0.85, color = "black", linetype = "dashed") +
    geom_hline(yintercept = 0.60, color = "black", linetype = "dashed") +
    theme(
      axis.text.x = element_text(angle = 45, hjust = 1),
      axis.title.x = element_blank(),
      panel.background = element_blank()
    )

  return(p)
}


#' sc_cluster_res_stability
#'
#' For a vector of cluster resolutions, compute clusters on a seurat object.
#' Sample the seurat object for a certain number of cells, compute clusters again,
#' and compare the overlapping cells. This intersection is a proxy for stability
#' of the cluster.
#' @param seurat_obj A Seurat object where FindClusters has been called.
#' @param cluster_resolutions A non-empty numeric vector of cluster resolutions to test.
#' @param verbose Print the current resolution and pass \code{verbose} to
#' FindClusters. Because it is a named argument of this function it is not
#' forwarded to sc_cluster_stability through \code{...}.
#' @param ... parameters to pass onto sc_cluster_stability
#' @return A numeric matrix with one column per resolution (column names are
#' the resolutions) and one row per cluster. Columns with fewer clusters than
#' the maximum are padded with \code{NA}.
#' @import Seurat
#' @import tidyverse
#' @keywords clustering
#' @export
#' @examples
#' \dontrun{
#' resm <- sc_cluster_res_stability(NucSeq.oligo, cluster_resolutions, n_cells=n_cells, n_bootstrap=n_bootstrap, reduction="inmf", verbose=T)
#' }
sc_cluster_res_stability <- function(seurat_obj, cluster_resolutions, verbose=TRUE, ...){
  if (length(cluster_resolutions) == 0) {
    stop("cluster_resolutions must contain at least one resolution", call. = FALSE)
  }

  resolution_list <- vector("list", length(cluster_resolutions))
  for (i in seq_along(cluster_resolutions)) {
    if (verbose) {print(paste("resolution:", cluster_resolutions[i]))}
    # recluster the full object so seurat_clusters matches the tested resolution
    seurat_obj <- FindClusters(seurat_obj, resolution = cluster_resolutions[i], n.start = 10, verbose = verbose)
    resolution_list[[i]] <- sc_cluster_stability(seurat_obj, cluster_res = cluster_resolutions[i], ...)
  }

  # pad every score vector with NAs so all columns have the same length
  max_clusters <- max(lengths(resolution_list))
  padded <- lapply(resolution_list, function(scores) {
    length(scores) <- max_clusters
    scores
  })

  # cbind always returns a matrix, even when every resolution gave one cluster
  res_matrix <- do.call(cbind, padded)
  colnames(res_matrix) <- cluster_resolutions
  return(res_matrix)
}


#' cluster_res_stability_heatmap
#'
#' Plot a heatmap summarizing the output of sc_cluster_res_stability. Clusters
#' with stability values below 0.6 are considered unstable. Values above 0.85 are
#' stable enough to be considered real clusters.
#' @param res_matrix The output of sc_cluster_res_stability
#' @param midpoint The stability value mapped to white in the blue-white-red
#' fill gradient. Default to 0.65.
#' @return A ggplot object.
#' @import tidyverse
#' @import scales
#' @importFrom reshape2 melt
#' @keywords plotting
#' @export
#' @examples
#' \dontrun{
#' resm <- sc_cluster_res_stability(NucSeq.oligo, cluster_resolutions, n_cells=n_cells, n_bootstrap=n_bootstrap, reduction="inmf", verbose=T)
#' cluster_res_stability_heatmap(resm)
#' }
cluster_res_stability_heatmap <- function(res_matrix, midpoint=0.65){
  plot_df <- melt(res_matrix)

  # Put one tick on every tile column. Without row names melt() numbers the rows
  # 1..n, so the labels are shifted by one to show 0-based cluster ids; with row
  # names the (converted) names themselves are used as labels.
  cluster_breaks <- sort(unique(plot_df$Var1))
  cluster_labels <- if (is.null(rownames(res_matrix))) {
    cluster_breaks - 1
  } else {
    cluster_breaks
  }

  p <- ggplot(data = plot_df, aes(x = Var1, y = Var2, fill = value)) +
    geom_tile() +
    # NA cells come from padding in sc_cluster_res_stability and carry no label
    geom_text(aes(label = round(value, 2)), na.rm = TRUE) +
    scale_fill_gradientn(
      colors = c("blue", "white", "red"),
      values = rescale(c(0, midpoint, 1)),
      limits = c(0, 1), na.value = "white"
    ) +
    scale_y_continuous(breaks = as.numeric(colnames(res_matrix))) +
    scale_x_continuous(breaks = cluster_breaks, labels = cluster_labels) +
    labs(x = "cluster", y = "resolution") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)) +
    theme(
      panel.border = element_blank(), panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(), axis.line = element_blank()
    )

  return(p)
}
# smorabit/scicat documentation built on July 23, 2020, 3:57 a.m.