R/clustering.R

Defines functions RunModularityClustering RunLeiden NNHelper NNdist MultiModalNN GroupSingletons FindModalityWeights CreateAnn ComputeSNNwidth AnnoySearch AnnoyBuildIndex AnnoyNN FindNeighbors.Seurat FindNeighbors.dist FindNeighbors.Assay FindNeighbors.default FindClusters.Seurat FindClusters.default PredictAssay FindSubCluster FindMultiModalNeighbors

Documented in FindClusters.default FindClusters.Seurat FindMultiModalNeighbors FindNeighbors.Assay FindNeighbors.default FindNeighbors.dist FindNeighbors.Seurat FindSubCluster PredictAssay

#' @include generics.R
#'
NULL

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Functions
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
#' Construct weighted nearest neighbor graph
#'
#' This function will construct a weighted nearest neighbor (WNN) graph. For
#' each cell, we identify the nearest neighbors based on a weighted combination
#' of two modalities. Takes as input two dimensional reductions, one computed
#' for each modality. Other parameters are listed for debugging, but can be left
#' as default values.
#'
#' @param object A Seurat object
#' @param reduction.list A list of two dimensional reductions, one for each of
#' the modalities to be integrated
#' @param dims.list A list containing the dimensions for each reduction to use
#' @param k.nn the number of multimodal neighbors to compute. 20 by default
#' @param l2.norm Perform L2 normalization on the cell embeddings after
#' dimensional reduction. TRUE by default.
#' @param knn.graph.name Multimodal knn graph name
#' @param snn.graph.name Multimodal snn graph name
#' @param weighted.nn.name Multimodal neighbor object name
#' @param modality.weight.name Variable name to store modality weight in object
#' meta data
#' @param knn.range The number of approximate neighbors to compute
#' @param prune.SNN Cutoff used to prune the SNN graph; edges with values at
#' or below this cutoff are discarded
#' @param sd.scale  The scaling factor for kernel width. 1 by default
#' @param cross.contant.list Constant used to avoid divide-by-zero errors. 1e-4
#' by default
#' @param smooth Smoothing modality score across each individual modality
#' neighbors. FALSE by default
#' @param return.intermediate Store intermediate results in misc
#' @param modality.weight A \code{\link{ModalityWeights}} object generated by
#' \code{FindModalityWeights}
#' @param verbose Print progress bars and output
#'
#' @return Seurat object containing a nearest-neighbor object, KNN graph, and
#' SNN graph - each based on a weighted combination of modalities.
#' @concept clustering
#' @export
#'
FindMultiModalNeighbors <- function(
  object,
  reduction.list,
  dims.list,
  k.nn = 20,
  l2.norm = TRUE,
  knn.graph.name = "wknn",
  snn.graph.name = "wsnn",
  weighted.nn.name = "weighted.nn",
  modality.weight.name = NULL,
  knn.range = 200,
  prune.SNN = 1/15,
  sd.scale = 1,
  cross.contant.list = NULL,
  smooth = FALSE,
  return.intermediate = FALSE,
  modality.weight = NULL,
  verbose = TRUE
) {
  # Default to one small constant (1e-4) per entry of reduction.list
  cross.contant.list <- cross.contant.list %||%
    as.list(x = rep(x = 1e-4, times = length(x = reduction.list)))
  # Only compute modality weights when the caller did not supply them
  if (is.null(x = modality.weight)) {
    if (verbose) {
      message("Calculating cell-specific modality weights")
    }
    modality.weight <- FindModalityWeights(
      object = object,
      reduction.list = reduction.list,
      dims.list = dims.list,
      k.nn = k.nn,
      sd.scale = sd.scale,
      l2.norm = l2.norm,
      cross.contant.list = cross.contant.list,
      smooth = smooth,
      verbose = verbose
    )
  }
  modality.assay <- slot(object = modality.weight, name = "modality.assay")
  # Default meta data names are "<assay>.weight", one per modality
  default.weight.name <- paste0(modality.assay, ".weight")
  modality.weight.name <- modality.weight.name %||% default.weight.name
  if (length(x = modality.weight.name) != length(x = reduction.list)) {
    warning("The number of provided modality.weight.name is not equal to the number of modalities. ",
            paste(default.weight.name, collapse = " "),
            " are used to store the modality weights" )
    modality.weight.name <- default.weight.name
  }
  # Graphs and command logs are attributed to the first modality's assay
  first.assay <- modality.assay[1]
  weighted.nn <- MultiModalNN(
    object = object,
    k.nn = k.nn,
    modality.weight = modality.weight,
    knn.range = knn.range,
    verbose = verbose
  )
  nn.indices <- Indices(object = weighted.nn)
  n.cells <- ncol(x = object)

  # compute KNN graph
  if (verbose) {
    message("Constructing multimodal KNN graph")
  }
  # Flatten the neighbor matrix row by row: j holds neighbor (column) indices,
  # i the matching query (row) index, k.nn consecutive entries per cell
  j <- as.numeric(x = t(x = nn.indices))
  i <- (seq_along(along.with = j) - 1) %/% k.nn + 1
  nn.matrix <- sparseMatrix(
    i = i,
    j = j,
    x = 1,
    dims = c(n.cells, n.cells)
  )
  diag(x = nn.matrix) <- 1
  rownames(x = nn.matrix) <- colnames(x = nn.matrix) <- colnames(x = object)
  # Symmetrize: for 0/1 entries A + t(A) - t(A) * A is the element-wise OR
  nn.matrix <- nn.matrix + t(x = nn.matrix) - t(x = nn.matrix) * nn.matrix
  nn.matrix <- as.Graph(x = nn.matrix)
  slot(object = nn.matrix, name = "assay.used") <- first.assay
  object[[knn.graph.name]] <- nn.matrix

  # compute SNN graph
  if (verbose) {
    message("Constructing multimodal SNN graph")
  }
  snn.matrix <- ComputeSNN(nn_ranked = nn.indices, prune = prune.SNN)
  rownames(x = snn.matrix) <- colnames(x = snn.matrix) <- Cells(x = object)
  snn.matrix <- as.Graph(x = snn.matrix)
  slot(object = snn.matrix, name = "assay.used") <- first.assay
  object[[snn.graph.name]] <- snn.matrix

  # add neighbors and modality weights
  object[[weighted.nn.name]] <- weighted.nn
  modality.weight.list <- slot(
    object = modality.weight,
    name = "modality.weight.list"
  )
  for (m in seq_along(along.with = modality.weight.name)) {
    object[[modality.weight.name[[m]]]] <- modality.weight.list[[m]]
  }

  # add command log: first the command stored with the modality weights,
  # then the command for this call
  modality.weight.command <- slot(object = modality.weight, name = "command")
  slot(object = modality.weight.command, name = "assay.used") <- first.assay
  modality.weight.command.name <- slot(object = modality.weight.command, name = "name")
  object[[modality.weight.command.name]] <- modality.weight.command
  command <- LogSeuratCommand(object = object, return.command = TRUE)
  # Do not keep the modality weight object inside the logged parameters
  slot(object = command, name = "params")$modality.weight <- NULL
  slot(object = command, name = "assay.used") <- first.assay
  command.name <- slot(object = command, name = "name")
  object[[command.name]] <- command

  if (return.intermediate) {
    Misc(object = object, slot = "modality.weight") <- modality.weight
  }
  return(object)
}

#' Find subclusters under one cluster
#'
#' Re-clusters the cells of a single identity class on the sub-graph of
#' \code{graph.name} restricted to those cells, and stores the resulting
#' labels in the object meta data.
#'
#' @inheritParams FindClusters
#' @param cluster the cluster to be sub-clustered
#' @param subcluster.name the name of sub cluster added in the meta.data
#'
#' @return The object with sub-cluster labels stored in the
#' \code{subcluster.name} meta data variable. Cells in \code{cluster} are
#' labelled \code{<cluster>_<subcluster>}; all other cells keep their current
#' identity as label.
#' @concept clustering
#' @export
#'
FindSubCluster <- function(
  object,
  cluster,
  graph.name,
  subcluster.name = "sub.cluster",
  resolution = 0.5,
  algorithm = 1
) {
  # Cells belonging to the identity class that should be split further
  cells.in.cluster <- WhichCells(object = object, idents = cluster)
  # Restrict the stored graph to those cells only
  full.graph <- object[[graph.name]]
  cluster.graph <- as.Graph(x = full.graph[cells.in.cluster, cells.in.cluster])
  clustering <- FindClusters(
    object = cluster.graph,
    resolution = resolution,
    algorithm = algorithm
  )
  # Prefix the new labels with the parent cluster name
  sub.labels <- paste(cluster, clustering[, 1], sep = "_")
  # Start from the current identities, then overwrite the re-clustered cells
  object[[subcluster.name]] <- as.character(x = Idents(object = object))
  object[[subcluster.name]][cells.in.cluster, ] <- sub.labels
  return(object)
}

#' Predict value from nearest neighbors
#'
#' This function will predict expression or cell embeddings from its k nearest
#' neighbors index. For each cell, it will average its k neighbors value to get
#' its new imputed value. It can average expression value in assays and cell
#' embeddings from dimensional reductions.
#'
#' @param object The object used to calculate knn
#' @param nn.idx k near neighbour indices. A cells x k matrix.
#' @param assay Assay used for prediction
#' @param reduction Cell embedding of the reduction used for prediction. When
#' \code{NULL}, assay data is used instead
#' @param dims Dimensions of the cell embedding to use; required when
#' \code{reduction} is set
#' @param return.assay Return an assay or a predicted matrix
#' @param slot slot used for prediction
#' @param features features used for prediction. Defaults to the variable
#' features of \code{assay}, or to all features in \code{slot} when there are
#' no variable features
#' @param mean.function the function used to calculate row mean
#' @param seed Sets the random seed to check if the nearest neighbor is query
#' cell
#' @param verbose Print progress
#'
#' @return An assay containing the predicted values in the data slot, or the
#' predicted features x cells matrix when \code{return.assay = FALSE}
#' @concept integration
#' @export
#'
PredictAssay <- function(
  object,
  nn.idx,
  assay,
  reduction = NULL,
  dims = NULL,
  return.assay = TRUE,
  slot = "scale.data",
  features = NULL,
  mean.function = rowMeans,
  seed = 4273,
  verbose = TRUE
){
  if (!inherits(x = mean.function, what = 'function')) {
    stop("'mean.function' must be a function")
  }
  # reference.data is always features (or dimensions) x cells
  if (is.null(x = reduction)) {
    reference.data <- GetAssayData(
      object = object,
      assay = assay,
      slot = slot
    )
    features <- features %||% VariableFeatures(object = object[[assay]])
    if (length(x = features) == 0) {
      features <- rownames(x = reference.data)
      if (verbose) {
        message("VariableFeatures are empty in the ", assay,
                " assay, features in the ", slot, " slot will be used" )
      }
    }
    reference.data <- reference.data[features, , drop = FALSE]
  } else {
    if (is.null(x = dims)) {
      stop("dims is empty")
    }
    # drop = FALSE keeps a matrix (and the dimension name) for a single dim
    reference.data <- t(
      x = Embeddings(object = object, reduction = reduction)[, dims, drop = FALSE]
    )
  }
  # Spot-check up to 50 random rows: if each is its own first neighbor, the
  # first column is treated as the query cell itself and removed
  set.seed(seed = seed)
  n.query <- nrow(x = nn.idx)
  nn.check <- sample(x = seq_len(length.out = n.query), size = min(50, n.query))
  if (all(nn.idx[nn.check, 1] == nn.check)) {
    if(verbose){
      message("The nearest neighbor is the query cell itself, and it will not be used for prediction")
    }
    # drop = FALSE so that a single remaining neighbor column stays a matrix
    nn.idx <- nn.idx[, -1, drop = FALSE]
  }
  predicted <- apply(
    X = nn.idx,
    MARGIN = 1,
    # drop = FALSE so mean.function always receives a matrix, even with one
    # neighbor or one feature
    FUN = function(x) mean.function(reference.data[, x, drop = FALSE])
  )
  # apply() returns a plain vector when each call yields a single value;
  # restore the 1 x cells matrix shape expected below
  if (is.null(x = dim(x = predicted))) {
    predicted <- matrix(data = predicted, nrow = 1)
    if (nrow(x = reference.data) == 1) {
      rownames(x = predicted) <- rownames(x = reference.data)
    }
  }
  colnames(x = predicted) <- Cells(x = object)
  if (return.assay) {
    predicted.assay <- CreateAssayObject(data = predicted, check.matrix = FALSE)
    return(predicted.assay)
  }
  return(predicted)
}

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Methods for Seurat-defined generics
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

#' @importFrom pbapply pblapply
#' @importFrom future.apply future_lapply
#' @importFrom future nbrOfWorkers
#'
#' @param modularity.fxn Modularity function (1 = standard; 2 = alternative).
#' @param initial.membership,node.sizes Parameters to pass to the Python leidenalg function.
#' @param resolution Value of the resolution parameter, use a value above
#' (below) 1.0 if you want to obtain a larger (smaller) number of communities.
#' @param algorithm Algorithm for modularity optimization (1 = original Louvain
#' algorithm; 2 = Louvain algorithm with multilevel refinement; 3 = SLM
#' algorithm; 4 = Leiden algorithm). Leiden requires the leidenalg python module.
#' @param method Method for running leiden (defaults to matrix which is fast for small datasets).
#' Enable method = "igraph" to avoid casting large data to a dense matrix.
#' @param n.start Number of random starts.
#' @param n.iter Maximal number of iterations per random start.
#' @param random.seed Seed of the random number generator.
#' @param group.singletons Group singletons into nearest cluster. If FALSE, assign all singletons to
#' a "singleton" group
#' @param temp.file.location Directory where intermediate files will be written.
#' Specify the ABSOLUTE path.
#' @param edge.file.name Edge file to use as input for modularity optimizer jar.
#' @param verbose Print output
#'
#' @rdname FindClusters
#' @concept clustering
#' @export
#'
FindClusters.default <- function(
  object,
  modularity.fxn = 1,
  initial.membership = NULL,
  node.sizes = NULL,
  resolution = 0.8,
  method = "matrix",
  algorithm = 1,
  n.start = 10,
  n.iter = 10,
  random.seed = 0,
  group.singletons = TRUE,
  temp.file.location = NULL,
  edge.file.name = NULL,
  verbose = TRUE,
  ...
) {
  CheckDots(...)
  if (is.null(x = object)) {
    stop("Please provide an SNN graph")
  }
  # Accept algorithm names as well as their numeric codes
  if (tolower(x = algorithm) == "louvain") {
    algorithm <- 1
  }
  if (tolower(x = algorithm) == "leiden") {
    algorithm <- 4
  }
  # Cluster the graph at a single resolution and return the cluster ids as a
  # factor named by cell. Shared by the parallel and sequential paths so both
  # honour the same settings (including group.singletons).
  ClusterAtResolution <- function(r) {
    if (algorithm %in% c(1:3)) {
      ids <- RunModularityClustering(
        SNN = object,
        modularity = modularity.fxn,
        resolution = r,
        algorithm = algorithm,
        n.start = n.start,
        n.iter = n.iter,
        random.seed = random.seed,
        print.output = verbose,
        temp.file.location = temp.file.location,
        edge.file.name = edge.file.name
      )
    } else if (algorithm == 4) {
      ids <- RunLeiden(
        object = object,
        method = method,
        partition.type = "RBConfigurationVertexPartition",
        initial.membership = initial.membership,
        node.sizes = node.sizes,
        resolution.parameter = r,
        random.seed = random.seed,
        n.iter = n.iter
      )
    } else {
      stop("algorithm not recognised, please specify as an integer or string")
    }
    names(x = ids) <- colnames(x = object)
    ids <- GroupSingletons(
      ids = ids,
      SNN = object,
      group.singletons = group.singletons,
      verbose = verbose
    )
    return(factor(x = ids))
  }
  if (nbrOfWorkers() > 1) {
    # One future per resolution; each returns a one-element named list so the
    # combined result converts to a data frame with one column per resolution
    clustering.results <- future_lapply(
      X = resolution,
      FUN = function(r) {
        results <- list(ClusterAtResolution(r = r))
        names(x = results) <- paste0('res.', r)
        return(results)
      }
    )
    clustering.results <- as.data.frame(x = clustering.results)
  } else {
    clustering.results <- data.frame(row.names = colnames(x = object))
    for (r in resolution) {
      clustering.results[, paste0("res.", r)] <- ClusterAtResolution(r = r)
    }
  }
  return(clustering.results)
}

#' @importFrom methods is
#'
#' @param graph.name Name of graph to use for the clustering algorithm
#' @param cluster.name Name of output clusters
#'
#' @rdname FindClusters
#' @export
#' @concept clustering
#' @method FindClusters Seurat
#'
FindClusters.Seurat <- function(
  object,
  graph.name = NULL,
  cluster.name = NULL,
  modularity.fxn = 1,
  initial.membership = NULL,
  node.sizes = NULL,
  resolution = 0.8,
  method = "matrix",
  algorithm = 1,
  n.start = 10,
  n.iter = 10,
  random.seed = 0,
  group.singletons = TRUE,
  temp.file.location = NULL,
  edge.file.name = NULL,
  verbose = TRUE,
  ...
) {
  CheckDots(...)
  # Default graph is "<default assay>_snn"
  if (is.null(x = graph.name)) {
    graph.name <- paste0(DefaultAssay(object = object), "_snn")
  }
  if (!(graph.name %in% names(x = object))) {
    stop("Provided graph.name not present in Seurat object")
  }
  snn.graph <- object[[graph.name]]
  if (!is(object = snn.graph, class2 = "Graph")) {
    stop("Provided graph.name does not correspond to a graph object.")
  }
  # Delegate the clustering itself to the method for the graph object
  clustering.results <- FindClusters(
    object = snn.graph,
    modularity.fxn = modularity.fxn,
    initial.membership = initial.membership,
    node.sizes = node.sizes,
    resolution = resolution,
    method = method,
    algorithm = algorithm,
    n.start = n.start,
    n.iter = n.iter,
    random.seed = random.seed,
    group.singletons = group.singletons,
    temp.file.location = temp.file.location,
    edge.file.name = edge.file.name,
    verbose = verbose,
    ...
  )
  # Default column names are "<graph.name>_<result name>"
  if (is.null(x = cluster.name)) {
    cluster.name <- paste(
      graph.name,
      names(x = clustering.results),
      sep = "_"
    )
  }
  names(x = clustering.results) <- cluster.name
  # The last result column becomes the active identity
  idents.use <- names(x = clustering.results)[ncol(x = clustering.results)]
  object[[]] <- clustering.results
  Idents(object = object, replace = TRUE) <- object[[idents.use, drop = TRUE]]
  # Order levels numerically when every level converts cleanly to a number;
  # otherwise fall back to the original (character) levels
  cluster.levels <- levels(x = object)
  KeepOriginal <- function(...) {
    return(cluster.levels)
  }
  cluster.levels <- tryCatch(
    expr = as.numeric(x = cluster.levels),
    warning = KeepOriginal,
    error = KeepOriginal
  )
  Idents(object = object) <- factor(
    x = Idents(object = object),
    levels = sort(x = cluster.levels)
  )
  object[["seurat_clusters"]] <- Idents(object = object)
  # Log the command against the assay associated with the graph
  cmd <- LogSeuratCommand(object = object, return.command = TRUE)
  slot(object = cmd, name = "assay.used") <- DefaultAssay(object = snn.graph)
  cmd.name <- slot(object = cmd, name = "name")
  object[[cmd.name]] <- cmd
  return(object)
}

#' @param query Matrix of data to query against object. If missing, defaults to
#' object.
#' @param distance.matrix Boolean value of whether the provided matrix is a
#' distance matrix; note, for objects of class \code{dist}, this parameter will
#' be set automatically
#' @param k.param Defines k for the k-nearest neighbor algorithm
#' @param return.neighbor Return result as \code{\link{Neighbor}} object. Not
#' used with distance matrix input.
#' @param compute.SNN also compute the shared nearest neighbor graph
#' @param prune.SNN Sets the cutoff for acceptable Jaccard index when
#' computing the neighborhood overlap for the SNN construction. Any edges with
#' values less than or equal to this will be set to 0 and removed from the SNN
#' graph. Essentially sets the stringency of pruning (0 --- no pruning, 1 ---
#' prune everything).
#' @param nn.method Method for nearest neighbor finding. Options include: rann,
#' annoy
#' @param annoy.metric Distance metric for annoy. Options include: euclidean,
#' cosine, manhattan, and hamming
#' @param n.trees More trees gives higher precision when using annoy approximate
#' nearest neighbor search
#' @param nn.eps Error bound when performing nearest neighbor search using RANN;
#' default of 0.0 implies exact nearest neighbor search
#' @param verbose Whether or not to print output to the console
#' @param l2.norm Take L2Norm of the data
#' @param cache.index Include cached index in returned Neighbor object
#' (only relevant if return.neighbor = TRUE)
#' @param index Precomputed index. Useful if querying new data against existing
#' index to avoid recomputing.
#'
#' @importFrom RANN nn2
#' @importFrom methods as
#'
#' @rdname FindNeighbors
#' @export
#' @concept clustering
#' @method FindNeighbors default
#'
FindNeighbors.default <- function(
  object,
  query = NULL,
  distance.matrix = FALSE,
  k.param = 20,
  return.neighbor = FALSE,
  compute.SNN = !return.neighbor,
  prune.SNN = 1/15,
  nn.method = "annoy",
  n.trees = 50,
  annoy.metric = "euclidean",
  nn.eps = 0,
  verbose = TRUE,
  l2.norm = FALSE,
  cache.index = FALSE,
  index = NULL,
  ...
) {
  CheckDots(...)
  if (is.null(x = dim(x = object))) {
    warning(
      "Object should have two dimensions, attempting to coerce to matrix",
      call. = FALSE
    )
    object <- as.matrix(x = object)
  }
  if (is.null(rownames(x = object))) {
    stop("Please provide rownames (cell names) with the input object")
  }
  # Rows of the input are cells
  n.cells <- nrow(x = object)
  if (n.cells < k.param) {
    warning(
      "k.param set larger than number of cells. Setting k.param to number of cells - 1.",
      call. = FALSE
    )
    k.param <- n.cells - 1
  }
  if (l2.norm) {
    object <- L2Norm(mat = object)
    query <- query %iff% L2Norm(mat = query)
  }
  # Without an explicit query, the data is searched against itself
  query <- query %||% object
  # find the k-nearest neighbors for each single cell
  if (!distance.matrix) {
    if (verbose) {
      if (return.neighbor) {
        message("Computing nearest neighbors")
      } else {
        message("Computing nearest neighbor graph")
      }
    }
    nn.ranked <- NNHelper(
      data = object,
      query = query,
      k = k.param,
      method = nn.method,
      n.trees = n.trees,
      searchtype = "standard",
      eps = nn.eps,
      metric = annoy.metric,
      cache.index = cache.index,
      index = index
    )
    if (return.neighbor) {
      if (compute.SNN) {
        warning("The SNN graph is not computed if return.neighbor is TRUE.", call. = FALSE)
      }
      return(nn.ranked)
    }
    nn.ranked <- Indices(object = nn.ranked)
  } else {
    if (verbose) {
      message("Building SNN based on a provided distance matrix")
    }
    # Rank neighbors directly from the distances: for each row keep the
    # indices of the k.param smallest entries. The result is kept as a
    # matrix even when k.param is 1.
    nn.ranked <- matrix(data = 0, ncol = k.param, nrow = n.cells)
    for (i in seq_len(length.out = n.cells)) {
      nn.ranked[i, ] <- order(object[i, ])[seq_len(length.out = k.param)]
    }
  }
  # convert nn.ranked into a Graph
  # j lists neighbor indices row by row; i is the matching query row, with
  # k.param consecutive entries per row
  j <- as.numeric(x = t(x = nn.ranked))
  i <- (seq_along(along.with = j) - 1) %/% k.param + 1
  nn.matrix <- as(
    object = sparseMatrix(
      i = i,
      j = j,
      x = 1,
      dims = c(nrow(x = object), nrow(x = object))
    ),
    Class = "Graph"
  )
  rownames(x = nn.matrix) <- rownames(x = object)
  colnames(x = nn.matrix) <- rownames(x = object)
  neighbor.graphs <- list(nn = nn.matrix)
  if (compute.SNN) {
    if (verbose) {
      message("Computing SNN")
    }
    snn.matrix <- ComputeSNN(
      nn_ranked = nn.ranked,
      prune = prune.SNN
    )
    rownames(x = snn.matrix) <- rownames(x = object)
    colnames(x = snn.matrix) <- rownames(x = object)
    snn.matrix <- as.Graph(x = snn.matrix)
    neighbor.graphs[["snn"]] <- snn.matrix
  }
  return(neighbor.graphs)
}

#' @rdname FindNeighbors
#' @export
#' @concept clustering
#' @method FindNeighbors Assay
#'
FindNeighbors.Assay <- function(
  object,
  features = NULL,
  k.param = 20,
  return.neighbor = FALSE,
  compute.SNN = !return.neighbor,
  prune.SNN = 1/15,
  nn.method = "annoy",
  n.trees = 50,
  annoy.metric = "euclidean",
  nn.eps = 0,
  verbose = TRUE,
  l2.norm = FALSE,
  cache.index = FALSE,
  ...
) {
  CheckDots(...)
  # Default to the assay's variable features
  features <- features %||% VariableFeatures(object = object)
  # Transpose the selected rows of the "data" slot so that cells are rows.
  # drop = FALSE keeps a matrix when only one feature is selected; otherwise
  # the transpose would yield a single row instead of one row per cell.
  data.use <- t(
    x = GetAssayData(object = object, slot = "data")[features, , drop = FALSE]
  )
  neighbor.graphs <- FindNeighbors(
    object = data.use,
    k.param = k.param,
    compute.SNN = compute.SNN,
    prune.SNN = prune.SNN,
    nn.method = nn.method,
    n.trees = n.trees,
    annoy.metric = annoy.metric,
    nn.eps = nn.eps,
    verbose = verbose,
    l2.norm = l2.norm,
    return.neighbor = return.neighbor,
    cache.index = cache.index,
    ...
  )
  return(neighbor.graphs)
}

#' @rdname FindNeighbors
#' @export
#' @concept clustering
#' @method FindNeighbors dist
#'
FindNeighbors.dist <- function(
  object,
  k.param = 20,
  return.neighbor = FALSE,
  compute.SNN = !return.neighbor,
  prune.SNN = 1/15,
  nn.method = "annoy",
  n.trees = 50,
  annoy.metric = "euclidean",
  nn.eps = 0,
  verbose = TRUE,
  l2.norm = FALSE,
  cache.index = FALSE,
  ...
) {
  CheckDots(...)
  # Convert the dist object to a matrix and flag it as a distance matrix
  distances <- as.matrix(x = object)
  neighbor.graphs <- FindNeighbors(
    object = distances,
    distance.matrix = TRUE,
    k.param = k.param,
    return.neighbor = return.neighbor,
    compute.SNN = compute.SNN,
    prune.SNN = prune.SNN,
    nn.method = nn.method,
    n.trees = n.trees,
    annoy.metric = annoy.metric,
    nn.eps = nn.eps,
    verbose = verbose,
    l2.norm = l2.norm,
    cache.index = cache.index,
    ...
  )
  return(neighbor.graphs)
}

#' @param assay Assay to use in construction of (S)NN; used only when \code{dims}
#' is \code{NULL}
#' @param features Features to use as input for building the (S)NN; used only when
#' \code{dims} is \code{NULL}
#' @param reduction Reduction to use as input for building the (S)NN
#' @param dims Dimensions of reduction to use as input
#' @param do.plot Plot SNN graph on tSNE coordinates
#' @param graph.name Optional naming parameter for stored (S)NN graph
#' (or Neighbor object, if return.neighbor = TRUE). Default is assay.name_(s)nn.
#' To store both the neighbor graph and the shared nearest neighbor (SNN) graph,
#' you must supply a vector containing two names to the \code{graph.name}
#' parameter. The first element in the vector will be used to store the nearest
#' neighbor (NN) graph, and the second element used to store the SNN graph. If
#' only one name is supplied, only the NN graph is stored.
#'
#' @importFrom igraph graph.adjacency plot.igraph E
#'
#' @rdname FindNeighbors
#' @export
#' @concept clustering
#' @method FindNeighbors Seurat
#'
FindNeighbors.Seurat <- function(
  object,
  reduction = "pca",
  dims = 1:10,
  assay = NULL,
  features = NULL,
  k.param = 20,
  return.neighbor = FALSE,
  compute.SNN = !return.neighbor,
  prune.SNN = 1/15,
  nn.method = "annoy",
  n.trees = 50,
  annoy.metric = "euclidean",
  nn.eps = 0,
  verbose = TRUE,
  do.plot = FALSE,
  graph.name = NULL,
  l2.norm = FALSE,
  cache.index = FALSE,
  ...
) {
  CheckDots(...)
  if (!is.null(x = dims)) {
    # Neighbors from a dimensional reduction; the assay is taken from the
    # reduction itself
    assay <- DefaultAssay(object = object[[reduction]])
    data.use <- Embeddings(object = object[[reduction]])
    if (max(dims) > ncol(x = data.use)) {
      stop("More dimensions specified in dims than have been computed")
    }
    # drop = FALSE keeps a cells x 1 matrix when a single dimension is used
    data.use <- data.use[, dims, drop = FALSE]
    neighbor.graphs <- FindNeighbors(
      object = data.use,
      k.param = k.param,
      compute.SNN = compute.SNN,
      prune.SNN = prune.SNN,
      nn.method = nn.method,
      n.trees = n.trees,
      annoy.metric = annoy.metric,
      nn.eps = nn.eps,
      verbose = verbose,
      l2.norm = l2.norm,
      return.neighbor = return.neighbor,
      cache.index = cache.index,
      ...
    )
  } else {
    # Neighbors from assay data (dims = NULL)
    assay <- assay %||% DefaultAssay(object = object)
    neighbor.graphs <- FindNeighbors(
      object = object[[assay]],
      features = features,
      k.param = k.param,
      compute.SNN = compute.SNN,
      prune.SNN = prune.SNN,
      nn.method = nn.method,
      n.trees = n.trees,
      annoy.metric = annoy.metric,
      nn.eps = nn.eps,
      verbose = verbose,
      l2.norm = l2.norm,
      return.neighbor = return.neighbor,
      cache.index = cache.index,
      ...
    )
  }
  # A bare result (not a list of graphs) is wrapped so the code below can
  # treat every case as a named list. A list holding a single graph
  # (compute.SNN = FALSE) must not be wrapped a second time.
  if (!is.list(x = neighbor.graphs)) {
    neighbor.graphs <- list(nn = neighbor.graphs)
  }
  graph.name <- graph.name %||%
    if (return.neighbor) {
      paste0(assay, ".", names(x = neighbor.graphs))
    } else {
      paste0(assay, "_", names(x = neighbor.graphs))
    }
  if (length(x = graph.name) > length(x = neighbor.graphs)) {
    stop(
      "More graph names supplied (", length(x = graph.name),
      ") than graphs computed (", length(x = neighbor.graphs), ")",
      call. = FALSE
    )
  }
  # Only report dropped graphs when there really is more than one graph
  if (length(x = graph.name) == 1 && length(x = neighbor.graphs) > 1) {
    message("Only one graph name supplied, storing nearest-neighbor graph only")
  }
  for (ii in seq_along(along.with = graph.name)) {
    if (inherits(x = neighbor.graphs[[ii]], what = "Graph")) {
      DefaultAssay(object = neighbor.graphs[[ii]]) <- assay
    }
    object[[graph.name[[ii]]]] <- neighbor.graphs[[ii]]
  }
  if (do.plot) {
    if (!"tsne" %in% names(x = object@reductions)) {
      warning("Please compute a tSNE for SNN visualization. See RunTSNE().")
    } else if (nrow(x = Embeddings(object = object[["tsne"]])) != ncol(x = object)) {
      warning("Please compute a tSNE for SNN visualization. See RunTSNE().")
    } else if (length(x = neighbor.graphs) < 2) {
      # The plot uses the second graph (the SNN), which is absent here
      warning("No SNN graph was computed, skipping SNN visualization.")
    } else {
      net <- graph.adjacency(
        adjmatrix = as.matrix(x = neighbor.graphs[[2]]),
        mode = "undirected",
        weighted = TRUE,
        diag = FALSE
      )
      plot.igraph(
        x = net,
        layout = as.matrix(x = Embeddings(object = object[["tsne"]])),
        edge.width = E(graph = net)$weight,
        vertex.label = NA,
        vertex.size = 0
      )
    }
  }
  object <- LogSeuratCommand(object = object)
  return(object)
}

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Internal
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

# Run annoy
#
# @param data Data to build the index with
# @param query A set of data to be queried against data
# @param metric Distance metric; can be one of "euclidean", "cosine", "manhattan",
# "hamming"
# @param n.trees More trees gives higher precision when querying
# @param k Number of neighbors
# @param search.k During the query it will inspect up to search_k nodes which
# gives you a run-time tradeoff between better accuracy and speed.
# @param include.distance Include the corresponding distances
# @param index optional index object, will be recomputed if not provided
#
AnnoyNN <- function(data,
                    query = data,
                    metric = "euclidean",
                    n.trees = 50,
                    k,
                    search.k = -1,
                    include.distance = TRUE,
                    index = NULL
) {
  # Build an index from data only when the caller did not pass one in
  if (is.null(x = index)) {
    index <- AnnoyBuildIndex(
      data = data,
      metric = metric,
      n.trees = n.trees
    )
  }
  result <- AnnoySearch(
    index = index,
    query = query,
    k = k,
    search.k = search.k,
    include.distance = include.distance
  )
  # Return the index and how it was configured alongside the search result
  result$idx <- index
  result$alg.info <- list(metric = metric, ndim = ncol(x = data))
  return(result)
}

# Build the annoy index
#
# @param data Data to build the index with
# @param metric Distance metric; can be one of "euclidean", "cosine", "manhattan",
# "hamming"
# @param n.trees More trees gives higher precision when querying
#
#' @importFrom RcppAnnoy AnnoyEuclidean AnnoyAngular AnnoyManhattan AnnoyHamming
#
AnnoyBuildIndex <- function(data, metric = "euclidean", n.trees = 50) {
  # The index dimensionality is the number of columns of data
  n.dims <- ncol(x = data)
  annoy.index <- switch(
    EXPR = metric,
    "euclidean" =  new(Class = RcppAnnoy::AnnoyEuclidean, n.dims),
    "cosine" = new(Class = RcppAnnoy::AnnoyAngular, n.dims),
    "manhattan" = new(Class = RcppAnnoy::AnnoyManhattan, n.dims),
    "hamming" = new(Class = RcppAnnoy::AnnoyHamming, n.dims),
    stop ("Invalid metric")
  )
  # Add one item per row under the id (row - 1). seq_len() keeps the loop
  # empty for zero-row input, where seq(0) would iterate over c(1, 0).
  for (ii in seq_len(length.out = nrow(x = data))) {
    annoy.index$addItem(ii - 1, data[ii, ])
  }
  annoy.index$build(n.trees)
  return(annoy.index)
}

# Search an Annoy approximate nearest neighbor index
#
# @param Annoy index, built with AnnoyBuildIndex
# @param query A set of data to be queried against the index
# @param k Number of neighbors
# @param search.k During the query it will inspect up to search_k nodes which
# gives you a run-time tradeoff between better accuracy and speed.
# @param include.distance Include the corresponding distances in the result
#
# @return A list with 'nn.idx' (for each element in 'query', the index of the
# nearest k elements in the index) and 'nn.dists' (the distances of the nearest
# k elements)
#
#' @importFrom future plan
#' @importFrom future.apply future_lapply
#
AnnoySearch <- function(index, query, k, search.k = -1, include.distance = TRUE) {
  n <- nrow(x = query)
  idx <- matrix(nrow = n,  ncol = k)
  dist <- matrix(nrow = n, ncol = k)
  # Angular indices report angular distance; convert it to cosine distance
  convert <- methods::is(index, "Rcpp_AnnoyAngular")
  # Unless the active plan is multicore, switch to a sequential plan for the
  # duration of the search and restore the previous plan on exit
  # NOTE(review): presumably because the index cannot be shipped to separate
  # worker processes - confirm
  if (!inherits(x = plan(), what = "multicore")) {
    oplan <- plan(strategy = "sequential")
    on.exit(plan(oplan), add = TRUE)
  }
  # seq_len() keeps the search empty for a zero-row query
  res <- future_lapply(X = seq_len(length.out = n), FUN = function(x) {
    res <- index$getNNsByVectorList(query[x, ], k, search.k, include.distance)
    # Convert from Angular to Cosine distance. The element must be addressed
    # by its full name: `$<-` does not partial-match, so assigning to
    # `res$dist` would create a new element and leave `res$distance` as is.
    if (convert && include.distance) {
      res$distance <- 0.5 * (res$distance * res$distance)
    }
    # Returned item ids are shifted by one to give 1-based row indices
    list(res$item + 1, res$distance)
  })
  for (i in seq_len(length.out = n)) {
    idx[i, ] <- res[[i]][[1]]
    if (include.distance) {
      dist[i, ] <- res[[i]][[2]]
    }
  }
  return(list(nn.idx = idx, nn.dists = dist))
}

# Calculate mean distance of the farthest neighbors from SNN graph
#
# Thin wrapper around SNN_SmallestNonzero_Dist. As described for that routine:
# for each cell, take the k.nn neighbors with the lowest nonzero SNN edge
# weight (keeping all ties at the k.nn-th position), compute the euclidean
# distance to them in the space of the embeddings matrix, and return the
# average distance to the farthest k.nn cells.
#
# @param snn.graph An SNN graph
# @param embeddings The cell embeddings used to calculate neighbor distances
# @param k.nn The number of neighbors to calculate
# @param l2.norm Perform L2 normalization on the cell embeddings before the
# distances are computed
# @param nearest.dist The vector of distance to the nearest neighbors to
# subtract off from distance calculations; one value per column of snn.graph.
# Defaults to all zeros.
#
# @return The value returned by SNN_SmallestNonzero_Dist
#
ComputeSNNwidth <- function(
  snn.graph,
  embeddings,
  k.nn,
  l2.norm = TRUE,
  nearest.dist = NULL
) {
  if (l2.norm) {
    embeddings <- L2Norm(mat = embeddings)
  }
  n.graph.cols <- ncol(x = snn.graph)
  # No offsets supplied: subtract nothing from any cell
  nearest.dist <- nearest.dist %||% rep(x = 0, times = n.graph.cols)
  offsets.ok <- length(x = nearest.dist) == n.graph.cols
  if (!offsets.ok) {
    stop("Please provide a vector for nearest.dist that has as many elements as",
         " there are columns in the snn.graph (", ncol(x = snn.graph), ").")
  }
  return(SNN_SmallestNonzero_Dist(
    snn = snn.graph,
    mat = embeddings,
    n = k.nn,
    nearest_dist = nearest.dist
  ))
}

# Create an empty Annoy index for a given metric
#
# @note Function exists because it's not exported from \pkg{uwot}
#
# @param name Distance metric name; one of "cosine", "manhattan", "euclidean"
# or "hamming"
# @param ndim Number of dimensions, passed to the index constructor
#
# @return An nn index object
#
#' @importFrom methods new
#' @importFrom RcppAnnoy AnnoyAngular AnnoyManhattan AnnoyEuclidean AnnoyHamming
#
CreateAnn <- function(name, ndim) {
  # Pick the RcppAnnoy class for the metric first; any other name is treated
  # as a programming error
  ann.class <- switch(
    EXPR = name,
    cosine = AnnoyAngular,
    manhattan = AnnoyManhattan,
    euclidean = AnnoyEuclidean,
    hamming = AnnoyHamming,
    stop("BUG: unknown Annoy metric '", name, "'")
  )
  ann <- new(Class = ann.class, ndim)
  return(ann)
}

# Calculate modality weights
#
# This function calculates the cell-specific modality weights which are used
# in WNN analysis. For every reduction it finds nearest neighbors, imputes each
# cell's embedding from its own-modality neighbors ("within") and from the
# neighbors found in every other modality ("cross"), converts the resulting
# distances to exponential kernels, scores each modality as
# within / (cross + constant), and normalizes exp(score) across modalities.
#' @inheritParams FindMultiModalNeighbors
# @param object A Seurat object
# @param snn.far.nn Use SNN farthest neighbors to calculate the kernel width
# @param s.nn How many SNN neighbors to use in kernel width
# @param query Optional second object; when given, its embeddings are queried
# against those of object. Cannot be combined with snn.far.nn = TRUE.
# @param cross.contant.list List of constants (one per reduction) added to the
# cross-modality kernel in the score denominator; defaults to 1e-4 each
# @param sigma.idx Neighbor index used to calculate kernel width if snn.far.nn = FALSE
# @param smooth If TRUE, replace each cell's modality score with the mean score
# over its neighbors (first neighbor column excluded)
# @param verbose Print messages and progress bars
# @importFrom pbapply pblapply
# @return Returns a \code{ModalityWeights} object that can be used as input to
# \code{\link{FindMultiModalNeighbors}}
#
#' @importFrom pbapply pblapply
#
FindModalityWeights  <- function(
  object,
  reduction.list,
  dims.list,
  k.nn = 20,
  snn.far.nn = TRUE,
  s.nn = k.nn,
  prune.SNN = 0,
  l2.norm = TRUE,
  sd.scale = 1,
  query = NULL,
  cross.contant.list = NULL,
  sigma.idx = k.nn,
  smooth = FALSE,
  verbose = TRUE
) {
  # Use the progress-bar lapply only when verbose
  my.lapply <- ifelse(
    test = verbose,
    yes = pblapply,
    no = lapply
  )
  cross.contant.list <- cross.contant.list %||%
    as.list(x = rep(x = 1e-4, times = length(x = reduction.list)))
  reduction.set <- unlist(x = reduction.list)
  # Name all per-reduction inputs by reduction so they can be indexed by name
  names(x = reduction.list) <- names(x = dims.list) <-
    names(x = cross.contant.list) <- reduction.set
  # Pull the requested dimensions of each reduction
  embeddings.list <- lapply(
    X = reduction.list,
    FUN = function(r) Embeddings(object = object, reduction = r)[, dims.list[[r]]]
  )
  if (l2.norm) {
    embeddings.list.norm <- lapply(
      X = embeddings.list,
      FUN = function(embeddings) L2Norm(mat = embeddings)
    )
  } else {
    embeddings.list.norm <- embeddings.list
  }
  # Without a query, the object is queried against itself
  if (is.null(x = query)) {
    query.embeddings.list.norm <- embeddings.list.norm
    query <- object
  } else {
    if (snn.far.nn) {
      stop("query does not support using snn to find distant neighbors")
    }
    query.embeddings.list <- lapply(
      X = reduction.list,
      FUN = function(r) {
        Embeddings(object = query, reduction = r)[, dims.list[[r]]]
      }
    )
    if (l2.norm) {
      query.embeddings.list <- lapply(
        X = query.embeddings.list,
        FUN = function(embeddings) L2Norm(mat = embeddings)
      )
    }
    query.embeddings.list.norm <- query.embeddings.list
  }
  if (verbose) {
    message("Finding ", k.nn, " nearest neighbors for each modality.")
  }
  # Search once per reduction with enough neighbors to cover k.nn, sigma.idx
  # and s.nn
  nn.list <- my.lapply(
    X = reduction.list,
    FUN = function(r) {
      nn.r <- NNHelper(
        data = embeddings.list.norm[[r]],
        query = query.embeddings.list.norm[[r]],
        k = max(k.nn, sigma.idx, s.nn),
        method = "annoy",
        metric = "euclidean"
      )
      return(nn.r)
    }
  )
  # Keep the untrimmed neighbors for the kernel-width step, then cut nn.list
  # back to the first k.nn columns if the search was wider than k.nn
  sigma.nn.list <- nn.list
  if (sigma.idx > k.nn || s.nn > k.nn) {
    nn.list <- lapply(
      X = nn.list,
      FUN = function(nn){
        slot(object = nn, name = "nn.idx") <- Indices(object = nn)[, 1:k.nn]
        slot(object = nn, name = "nn.dists") <- Distances(object = nn)[, 1:k.nn]
        return(nn)
      }
    )
  }
  # Second distance column of each neighbor set; presumably the nearest
  # non-self neighbor when the data is queried against itself -- TODO confirm
  nearest_dist <- lapply(X = reduction.list, FUN = function(r) Distances(object = nn.list[[r]])[, 2])
  within_impute <- list()
  cross_impute <- list()
  # Calculating within and cross modality distance
  for (r in reduction.set) {
    # Store the (normalized) embeddings on the object as reduction "<r>.norm"
    # so that PredictAssay can read them
    reduction.norm <- paste0(r, ".norm")
    object[[ reduction.norm ]] <- CreateDimReducObject(
      embeddings = embeddings.list.norm[[r]],
      key = paste0("norm", Key(object = object[[r]])),
      assay = DefaultAssay(object = object[[r]])
    )
    # Impute modality r from its own neighbors
    within_impute[[r]] <- PredictAssay(
      object = object,
      nn.idx =  Indices(object = nn.list[[r]]),
      reduction = reduction.norm,
      dims = 1:ncol(x = embeddings.list.norm[[r]]),
      verbose = FALSE,
      return.assay = FALSE
    )
    # Impute modality r from the neighbors found in each other modality r2
    cross.modality <- setdiff(x = reduction.set, y = r)
    cross_impute[[r]] <- lapply(X = cross.modality, FUN = function(r2) {
      PredictAssay(
        object = object,
        nn.idx = Indices(object = nn.list[[r2]]),
        reduction = reduction.norm,
        dims = 1:ncol(x = embeddings.list.norm[[r]]),
        verbose = FALSE,
        return.assay = FALSE
      )
    }
    )
    names(x = cross_impute[[r]]) <- cross.modality
  }

  # Distance (via impute_dist) between the query embeddings and the transposed
  # within-modality imputation, with nearest_dist passed as the offset
  within_impute_dist <- lapply(
    X = reduction.list,
    FUN = function(r) {
      r_dist <- impute_dist(
        x = query.embeddings.list.norm[[r]],
        y = t(x = within_impute[[r]]),
        nearest.dist = nearest_dist[[r]]
      )
      return(r_dist)
    }
  )
  # Same for every cross-modality imputation; sapply collects one result per
  # other modality r2, later indexed as columns named by r2
  cross_impute_dist <- lapply(
    X = reduction.list,
    FUN = function(r) {
      r_dist <- sapply(setdiff(x = reduction.set, y = r), FUN = function(r2) {
        r2_dist <- impute_dist(
          x = query.embeddings.list.norm[[r]],
          y = t(x = cross_impute[[r]][[r2]]),
          nearest.dist = nearest_dist[[r]]
        )
        return( r2_dist)
      })
      return(r_dist)
    }
  )
  # calculate kernel width
  if (snn.far.nn) {
    if (verbose) {
      message("Calculating kernel bandwidths")
    }
    # Build one SNN graph per reduction from the first s.nn neighbor columns
    snn.graph.list <- lapply(
      X = sigma.nn.list,
      FUN = function(nn) {
        snn.matrix <- ComputeSNN(
          nn_ranked =  Indices(object = nn)[, 1:s.nn],
          prune = prune.SNN
        )
        colnames(x = snn.matrix) <- rownames(x = snn.matrix) <- Cells(x = object)
        return (snn.matrix)
      }
    )
    # Embeddings are already normalized (if requested), hence l2.norm = FALSE
    farthest_nn_dist <- my.lapply(
      X = 1:length(x = snn.graph.list),
      FUN = function(s) {
        distant_nn <- ComputeSNNwidth(
          snn.graph = snn.graph.list[[s]],
          k.nn = k.nn,
          l2.norm = FALSE,
          embeddings =  embeddings.list.norm[[s]],
          nearest.dist = nearest_dist[[s]]
        )
        return (distant_nn)
      }
    )
    names(x = farthest_nn_dist) <- unlist(x = reduction.list)
    modality_sd.list <- lapply(
      X = farthest_nn_dist,
      FUN =  function(sd)  sd * sd.scale
    )
  } else {
    if (verbose) {
      message("Calculating sigma by ", sigma.idx, "th neighbor")
    }
    # Kernel width = distance to the sigma.idx-th neighbor minus nearest_dist,
    # scaled by sd.scale
    modality_sd.list <- lapply(
      X = reduction.list ,
      FUN =  function(r) {
        rdist <- Distances(object = sigma.nn.list[[r]])[, sigma.idx] - nearest_dist[[r]]
        rdist <- rdist * sd.scale
        return (rdist)
      }
    )
  }
  # Calculating within and cross modality kernel, and modality weights
  within_impute_kernel <- lapply(
    X = reduction.list,
    FUN = function(r) {
      exp(-1 * (within_impute_dist[[r]] / modality_sd.list[[r]]) )
    }
  )
  cross_impute_kernel <- lapply(
    X = reduction.list,
    FUN = function(r) {
      exp(-1 * (cross_impute_dist[[r]] / modality_sd.list[[r]]) )
    }
  )
  # Parameters stored on the returned ModalityWeights object
  params <- list(
    "reduction.list" = reduction.list,
    "dims.list" = dims.list,
    "l2.norm" = l2.norm,
    "k.nn" = k.nn,
    "sigma.idx" = sigma.idx,
    "snn.far.nn" = snn.far.nn ,
    "sigma.list" = modality_sd.list,
    "nearest.dist" = nearest_dist
  )
  # Score of modality r against each other modality r2:
  # within kernel / (cross kernel + constant), passed through MinMax with
  # min = 0 and max = 200
  modality_score <-  lapply(
    X = reduction.list,
    FUN = function(r) {
      score.r <- sapply(
        X = setdiff(x = reduction.set, y = r),
        FUN = function(r2) {
          score <- within_impute_kernel[[r]] / (cross_impute_kernel[[r]][, r2] + cross.contant.list[[r]])
          score <- MinMax(data = score, min = 0, max = 200)
          return(score)
        }
      )
      return(score.r)
    }
  )
  # Optional smoothing: average the score over each cell's neighbors, dropping
  # the first neighbor column
  if (smooth) {
    modality_score <- lapply(
      X = reduction.list,
      FUN = function(r) {
        apply(
          X = Indices(object = nn.list[[r]]),
          MARGIN = 1,
          FUN = function(nn)  mean(x = modality_score[[r]][nn[-1]])
        )
      }
    )
  }
  # Normalize exp(score) across all modalities so the weights of each cell sum
  # to one
  all_modality_score <- rowSums(x = exp(x = Reduce(f = cbind, x = modality_score)))
  modality.weight <- lapply(
    X = modality_score,
    FUN = function(score_m) {
      rowSums(x = exp(x = score_m))/all_modality_score
    }
  )
  # Intermediate matrices kept on the result
  score.mat <- list(
    within_impute_dist = within_impute_dist,
    cross_impute_dist = cross_impute_dist,
    within_impute_kernel = within_impute_kernel,
    cross_impute_kernel = cross_impute_kernel,
    modality_score = modality_score
  )

  # unlist the input parameters
  command <- LogSeuratCommand(object = object, return.command = TRUE)
  command@params <- lapply(X = command@params, FUN = function(l) unlist(x = l))
  # Assay recorded on each input reduction
  modality.assay <- sapply(
    X = reduction.list ,
    FUN = function (r) slot(object[[r]], name = "assay.used")
  )
  modality.weights.obj <- new(
    Class = "ModalityWeights",
    modality.weight.list = modality.weight,
    modality.assay = modality.assay,
    params = params,
    score.matrix = score.mat,
    command = command
  )
  return(modality.weights.obj)
}

# Group single cells that make up their own cluster in with the cluster they are
# most connected to.
#
# @param ids Named vector of cluster ids (names are cell names)
# @param SNN SNN graph used in clustering, indexable by cell name
# @param group.singletons Group singletons into nearest cluster. If FALSE, assign all singletons to
# a "singleton" group
# @param verbose Print a summary message when singletons were merged
#
# @return Returns the vector of cluster ids with every singleton reassigned to
# its most connected cluster (or to "singleton" if group.singletons is FALSE).
# Reassignment stores the cluster name as a string, so a numeric ids vector
# comes back as character once any singleton has been handled.
#
GroupSingletons <- function(ids, SNN, group.singletons = TRUE, verbose = TRUE) {
  # identify singletons: cluster ids that occur exactly once
  singletons <- names(x = which(x = table(ids) == 1))
  singletons <- intersect(x = unique(x = ids), y = singletons)
  if (!group.singletons) {
    ids[which(ids %in% singletons)] <- "singleton"
    return(ids)
  }
  # calculate connectivity of singletons to other clusters, add singleton
  # to cluster it is most connected to
  cluster_names <- as.character(x = unique(x = ids))
  cluster_names <- setdiff(x = cluster_names, y = singletons)
  connectivity <- vector(mode = "numeric", length = length(x = cluster_names))
  names(x = connectivity) <- cluster_names
  for (i in singletons) {
    i.cells <- names(x = which(x = ids == i))
    # ids is updated in place below, so clusters seen here already include
    # singletons merged on earlier iterations
    for (j in cluster_names) {
      j.cells <- names(x = which(x = ids == j))
      subSNN <- SNN[i.cells, j.cells]
      if (is.object(x = subSNN)) {
        connectivity[j] <- sum(subSNN) / (nrow(x = subSNN) * ncol(x = subSNN))
      } else {
        connectivity[j] <- mean(x = subSNN)
      }
    }
    best <- max(connectivity, na.rm = TRUE)
    candidates <- names(x = connectivity)[which(x = connectivity == best)]
    # to match previous behavior, random seed being set in WhichCells; the
    # seed makes the tie-break below deterministic (and resets the global RNG)
    set.seed(1)
    closest_cluster <- sample(x = candidates, size = 1)
    ids[i.cells] <- closest_cluster
  }
  if (length(x = singletons) > 0 && verbose) {
    message(paste(
      length(x = singletons),
      "singletons identified.",
      length(x = unique(x = ids)),
      "final clusters."
    ))
  }
  return(ids)
}

# Find multimodal neighbors
#
# For every query cell, candidate neighbors are the union of its per-modality
# nearest neighbors. Each candidate gets a weighted affinity: the sum over
# modalities of exp(-(distance / sigma) ^ kernel.power) times the cell's
# modality weight. The k.nn candidates with the highest affinity are kept.
#
# @param object The object used to calculate knn; either a Seurat object or a
# list of embedding matrices (one per reduction)
# @param query The query object when query and reference are different
# @param modality.weight A \code{\link{ModalityWeights}} object generated by
# \code{\link{FindModalityWeights}}; supplies the default for every argument
# below that is left NULL
# @param modality.weight.list A list of modality weight value
# @param k.nn Number of nearest multimodal neighbors to compute
# @param reduction.list A list of reduction name
# @param dims.list A list of dimensions used for the reduction
# @param knn.range The number of approximate neighbors to compute
# @param kernel.power The power for the exponential kernel
# @param nearest.dist The list of distance to the nearest neighbors
# @param sigma.list The list of kernel width; each entry is either one value
# per query cell or a single value that is recycled for all query cells
# @param l2.norm Perform L2 normalization on the cell embeddings after
# dimensional reduction
# @param verbose Print output to the console
# @importFrom pbapply pblapply
# @return A \code{Neighbor} object holding the selected neighbor indices and
# the multimodal distances sqrt((1 - affinity) / 2), bounded to [0, 1] before
# the square root
#
#' @importFrom methods new
#' @importClassesFrom SeuratObject Neighbor
#
MultiModalNN <- function(
  object,
  query = NULL,
  modality.weight = NULL,
  modality.weight.list = NULL,
  k.nn = NULL,
  reduction.list = NULL,
  dims.list = NULL,
  knn.range = 200,
  kernel.power = 1,
  nearest.dist = NULL,
  sigma.list = NULL,
  l2.norm =  NULL,
  verbose = TRUE
){
  my.lapply <- if (verbose) pblapply else lapply
  # Every unset argument falls back to what was recorded on modality.weight;
  # the slot is only touched when a fallback is actually needed
  k.nn <- k.nn %||% slot(object = modality.weight, name = "params")$k.nn
  reduction.list <- reduction.list %||%
    slot(object = modality.weight, name = "params")$reduction.list
  dims.list <- dims.list %||%
    slot(object = modality.weight, name = "params")$dims.list
  nearest.dist <- nearest.dist %||%
    slot(object = modality.weight, name = "params")$nearest.dist
  sigma.list <- sigma.list %||%
    slot(object = modality.weight, name = "params")$sigma.list
  l2.norm <- l2.norm %||%
    slot(object = modality.weight, name = "params")$l2.norm
  modality.weight.value <- modality.weight.list %||%
    slot(object = modality.weight, name = "modality.weight.list")
  names(x = modality.weight.value) <- unlist(x = reduction.list)
  if (inherits(x = object, what = "Seurat")) {
    reduction_embedding <- lapply(
      X = seq_along(along.with = reduction.list),
      FUN = function(x) {
        Embeddings(object = object, reduction = reduction.list[[x]])[, dims.list[[x]]]
      }
    )
  } else {
    reduction_embedding <- object
  }
  if (is.null(x = query)) {
    query.reduction_embedding <- reduction_embedding
    query <- object
  } else {
    if (inherits(x = object, what = "Seurat")) {
      query.reduction_embedding <- lapply(
        X = seq_along(along.with = reduction.list),
        FUN = function(x) {
          Embeddings(object = query, reduction = reduction.list[[x]])[, dims.list[[x]]]
        }
      )
    } else {
      query.reduction_embedding <- query
    }
  }
  if (l2.norm) {
    query.reduction_embedding <- lapply(
      X = query.reduction_embedding,
      FUN = function(x) L2Norm(mat = x)
    )
    reduction_embedding <- lapply(
      X = reduction_embedding,
      FUN = function(x) L2Norm(mat = x)
    )
  }
  query.cell.num <- nrow(x = query.reduction_embedding[[1]])
  reduction.num <- length(x = query.reduction_embedding)
  query.cells <- seq_len(length.out = query.cell.num)
  reductions <- seq_len(length.out = reduction.num)
  if (verbose) {
    message("Finding multimodal neighbors")
  }
  reduction_nn <- my.lapply(
    X = reductions,
    FUN = function(x) {
      nn_x <- NNHelper(
        data = reduction_embedding[[x]],
        query = query.reduction_embedding[[x]],
        k = knn.range,
        method = 'annoy',
        metric = "euclidean"
      )
      return(nn_x)
    }
  )
  # union of rna and adt nn, remove itself from neighbors (first column);
  # drop = FALSE keeps a matrix even for a single query cell
  reduction_nn <- lapply(
    X = reduction_nn,
    FUN = function(x) Indices(object = x)[, -1, drop = FALSE]
  )
  nn_idx <- lapply(
    X = query.cells,
    FUN = function(x) {
      Reduce(
        f = union,
        x = lapply(
          X = reduction_nn,
          FUN = function(y) y[x, ]
        )
      )
    }
  )
  # calculate euclidean distance of all neighbors
  nn_dist <- my.lapply(
    X = reductions,
    FUN = function(r) {
      nndist <- NNdist(
        nn.idx = nn_idx,
        embeddings = reduction_embedding[[r]],
        query.embeddings = query.reduction_embedding[[r]],
        nearest.dist = nearest.dist[[r]]
      )
      return(nndist)
    }
  )
  # modality weighted distance
  # A scalar sigma is recycled to one value per *query* cell; sigma is indexed
  # by query cell below, so the reference size would be the wrong length
  if (length(x = sigma.list[[1]]) == 1) {
    sigma.list <- lapply(
      X = sigma.list,
      FUN = function(x) rep(x = x, times = query.cell.num)
    )
  }
  nn_weighted_dist <- lapply(
    X = reductions,
    FUN = function(r) {
      lapply(
        X = query.cells,
        FUN = function(x) {
          exp(-1 * (nn_dist[[r]][[x]] / sigma.list[[r]][x]) ^ kernel.power) *
            modality.weight.value[[r]][x]
        }
      )
    }
  )
  # Sum the per-modality affinities for each cell. lapply (not sapply) keeps
  # this a list of per-cell vectors even when every cell happens to have the
  # same number of candidates, which sapply would collapse into a matrix.
  nn_weighted_dist <- lapply(
    X = query.cells,
    FUN = function(x) {
      Reduce(
        f = "+",
        x = lapply(
          X = reductions,
          FUN = function(r) nn_weighted_dist[[r]][[x]]
        )
      )
    }
  )
  # select k nearest joint neighbors (highest affinity first)
  select_order <- lapply(
    X = nn_weighted_dist,
    FUN = function(dist) {
      order(dist, decreasing = TRUE)
    }
  )
  keep <- seq_len(length.out = k.nn)
  # rbind of per-cell rows always yields a cells x k.nn matrix, including the
  # k.nn = 1 case where t(sapply(...)) would come out transposed
  select_nn <- do.call(
    what = rbind,
    args = lapply(
      X = query.cells,
      FUN = function(x) nn_idx[[x]][select_order[[x]]][keep]
    )
  )
  select_dist <- do.call(
    what = rbind,
    args = lapply(
      X = query.cells,
      FUN = function(x) nn_weighted_dist[[x]][select_order[[x]]][keep]
    )
  )
  select_dist <- sqrt(x = MinMax(data = (1 - select_dist) / 2, min = 0, max = 1))
  weighted.nn <- new(
    Class = 'Neighbor',
    nn.idx = select_nn,
    nn.dist = select_dist,
    alg.info = list(),
    cell.names = Cells(x = query)
  )
  return(weighted.nn)
}

# Calculate NN distance for the given nn.idx
#
# @param nn.idx The nearest neighbors position index; either a list with one
# index vector per query cell, or a matrix with one row per query cell
# @param embeddings cell embeddings
# @param metric distance metric; accepted for interface compatibility but not
# used by this function (distances come from fast_dist)
# @param query.embeddings query cell embeddings; defaults to embeddings
# @param nearest.dist The vector of distance to the nearest neighbors, one
# value per query cell. When given, it is subtracted from that cell's
# distances and negative results are set to 0.
#
# @return A list with one vector of distances per query cell
#
NNdist <- function(
  nn.idx,
  embeddings,
  metric = "euclidean",
  query.embeddings = NULL,
  nearest.dist = NULL
) {
  # Normalize a matrix of indices to a list of per-row index vectors
  if (!is.list(x = nn.idx)) {
    nn.idx <- lapply(
      X = seq_len(length.out = nrow(x = nn.idx)),
      FUN = function(x) nn.idx[x, ]
    )
  }
  query.embeddings <- query.embeddings %||% embeddings
  nn.dist <- fast_dist(
    x = query.embeddings,
    y = embeddings,
    n = nn.idx
  )
  if (!is.null(x = nearest.dist)) {
    # seq_len() keeps an empty query from being treated as rows 1 and 0
    nn.dist <- lapply(
      X = seq_len(length.out = nrow(x = query.embeddings)),
      FUN = function(x) {
        r_dist <- nn.dist[[x]] - nearest.dist[x]
        r_dist[r_dist < 0] <- 0
        return(r_dist)
      }
    )
  }
  return(nn.dist)
}

# Internal helper function to dispatch to various neighbor finding methods
#
# The arguments of this function plus anything in ... are pooled, and only
# those whose names match a formal argument of the chosen backend are passed
# on to it.
#
# @param data Input data
# @param query Data to query against data
# @param k Number of nearest neighbors to compute
# @param method Nearest neighbor method to use: "rann", "annoy", "hnsw"
# @param cache.index Store algorithm index with results for reuse
# @param ... additional parameters to specific neighbor finding method
#
# @return A Neighbor object built from the backend's nn.idx / nn.dists (and
# alg.info, if any), with cell names taken from the row names of query
#
#' @importFrom methods new
#' @importClassesFrom SeuratObject Neighbor
#
NNHelper <- function(data, query = data, k, method, cache.index = FALSE, ...) {
  # Snapshot this call's own arguments before any local variable is created
  args <- as.list(x = environment())
  args <- c(args, list(...))
  # Name of the backend function implementing the requested method
  backend <- switch(
    EXPR = method,
    "rann" = "nn2",
    "annoy" = "AnnoyNN",
    "hnsw" = "HnswNN",
    stop("Invalid method. Please choose one of 'rann', 'annoy', 'hnsw'")
  )
  # Keep only the arguments the backend actually declares
  backend.formals <- names(x = formals(fun = get(x = backend, mode = "function")))
  args <- args[intersect(x = names(x = args), y = backend.formals)]
  results <- do.call(what = backend, args = args)
  n.ob <- new(
    Class = 'Neighbor',
    nn.idx = results$nn.idx,
    nn.dist = results$nn.dists,
    alg.info = results$alg.info %||% list(),
    cell.names = rownames(x = query)
  )
  if (isTRUE(x = cache.index) && !is.null(x = results$idx)) {
    slot(object = n.ob, name = "alg.idx") <- results$idx
  }
  return(n.ob)
}

# Run Leiden clustering algorithm
#
# Implements the Leiden clustering algorithm in R using reticulate
# to run the Python version. Requires the python "leidenalg" and "igraph" modules
# to be installed. Returns a vector of partition indices.
#
# @param object An adjacency matrix or SNN matrix; with method = "igraph" an
# adjacency list or an igraph object is also accepted
# @param method How the input is handed to leiden: "matrix" (the default)
# converts object to a dense matrix, "igraph" converts it to an igraph graph
# @param partition.type Type of partition to use for Leiden algorithm.
# Defaults to RBConfigurationVertexPartition. Options include: ModularityVertexPartition,
# RBERVertexPartition, CPMVertexPartition, MutableVertexPartition,
# SignificanceVertexPartition, SurpriseVertexPartition (see the Leiden python
# module documentation for more details)
# @param initial.membership,node.sizes Parameters to pass to the Python leidenalg function.
# @param resolution.parameter A parameter controlling the coarseness of the clusters
# for Leiden algorithm. Higher values lead to more clusters. (defaults to 1.0 for
# partition types that accept a resolution parameter)
# @param random.seed Seed of the random number generator
# @param n.iter Maximal number of iterations per random start
#
# @keywords graph network igraph mvtnorm simulation
#
#' @importFrom leiden leiden
#' @importFrom reticulate py_module_available
#' @importFrom igraph graph_from_adjacency_matrix graph_from_adj_list
#
# @author Tom Kelly
#
# @export
#
RunLeiden <- function(
  object,
  method = c("matrix", "igraph"),
  partition.type = c(
    'RBConfigurationVertexPartition',
    'ModularityVertexPartition',
    'RBERVertexPartition',
    'CPMVertexPartition',
    'MutableVertexPartition',
    'SignificanceVertexPartition',
    'SurpriseVertexPartition'
  ),
  initial.membership = NULL,
  node.sizes = NULL,
  resolution.parameter = 1,
  random.seed = 0,
  n.iter = 10
) {
  if (!py_module_available(module = 'leidenalg')) {
    stop(
      "Cannot find Leiden algorithm, please install through pip (e.g. pip install leidenalg).",
      call. = FALSE
    )
  }
  # Collapse the default choice vector to a single value; switch() requires a
  # length-one EXPR, so calling without `method` would otherwise fail
  method <- match.arg(arg = method)
  input <- switch(
    EXPR = method,
    "matrix" = as(object = object, Class = "matrix"),
    "igraph" = if (inherits(x = object, what = 'list')) {
      graph_from_adj_list(adjlist = object)
    } else if (inherits(x = object, what = c('dgCMatrix', 'matrix', 'Matrix'))) {
      if (inherits(x = object, what = 'Graph')) {
        object <- as.sparse(x = object)
      }
      graph_from_adjacency_matrix(adjmatrix = object, weighted = TRUE)
    } else if (inherits(x = object, what = 'igraph')) {
      object
    } else {
      stop(
        "Method for Leiden not found for class ",
        paste(class(x = object), collapse = "/"),
        call. = FALSE
      )
    }
  )
  #run leiden from CRAN package (calls python with reticulate)
  # partition.type is forwarded as given, including the default choice vector
  partition <- leiden(
    object = input,
    partition_type = partition.type,
    initial_membership = initial.membership,
    weights = NULL,
    node_sizes = node.sizes,
    resolution_parameter = resolution.parameter,
    seed = random.seed,
    n_iterations = n.iter
  )
  return(partition)
}

# Runs the modularity optimizer (C++ port of java program ModularityOptimizer.jar)
#
# Forwards its arguments, in order, to RunModularityClusteringCpp.
#
# @param SNN SNN matrix to use as input for the clustering algorithms
# @param modularity Modularity function to use in clustering (1 = standard; 2 = alternative)
# @param resolution Value of the resolution parameter, use a value above (below) 1.0 if you want to obtain a larger (smaller) number of communities
# @param algorithm Algorithm for modularity optimization (1 = original Louvain algorithm; 2 = Louvain algorithm with multilevel refinement; 3 = SLM algorithm; 4 = Leiden algorithm). Leiden requires the leidenalg python module.
# @param n.start Number of random starts
# @param n.iter Maximal number of iterations per random start
# @param random.seed Seed of the random number generator
# @param print.output Whether or not to print output to the console
# @param temp.file.location Deprecated and no longer used
# @param edge.file.name Path to edge file to use; passed on as an empty string
# when NULL
#
# @return The clustering result returned by RunModularityClusteringCpp
#
#' @importFrom utils read.table write.table
#
RunModularityClustering <- function(
  SNN = matrix(),
  modularity = 1,
  resolution = 0.8,
  algorithm = 1,
  n.start = 10,
  n.iter = 10,
  random.seed = 0,
  print.output = TRUE,
  temp.file.location = NULL,
  edge.file.name = NULL
) {
  # Arguments are passed positionally; the last one is the edge file path,
  # with '' standing in for "no edge file"
  return(RunModularityClusteringCpp(
    SNN,
    modularity,
    resolution,
    algorithm,
    n.start,
    n.iter,
    random.seed,
    print.output,
    edge.file.name %||% ''
  ))
}
# satijalab/seurat documentation built on May 11, 2024, 4:04 a.m.