R/trajectories.R

Defines functions mincostflow prune distmat

Documented in distmat mincostflow prune

#' @import lpSolve
NULL
#> NULL

#' Calculate scSLAM-seq based distance matrices
#'
#' Calculates a distance matrix between the cells of two time points based on
#' metabolic labeling RNA profiles. Both objects are normalized, integrated on
#' their shared integration features and projected by PCA; the distance
#' between two cells is the squared Euclidean distance over the first 15
#' principal components of the integrated data.
#' @param prev.t Cells to be used from the previous time point in distance matrix calculation.
#' @param next.t Cells to be used from the next time point in distance matrix calculation.
#' @param prevAssay Name of the expression assay of cells from the previous time point.
#' @param nextAssay Name of the expression assay of cells from the next time point.
#' @param gene_subset Set a subset of genes on which trajectories should be calculated.
#'   Other genes will be disregarded. Defaults to all row names of both objects.
#' @return Distance matrix between cells from two time points, with one row
#'   per cell of \code{prev.t} and one column per cell of \code{next.t}.
#' @examples
#' \donttest{
#'
#' # Full vignette available on https://grandr.erhard-lab.de/articles/web/hetseq.html
#'
#'   obj.list <- SplitObject(seuratObject, split.by = "time")
#'   D.list <- list(
#'    distmat(obj.list[["0h"]], obj.list[["2h"]], "RNA", "prevRNA"),
#'    distmat(obj.list[["2h"]], obj.list[["4h"]], "RNA", "prevRNA")
#'   )
#' }
#' @export
distmat <- function(prev.t, next.t, prevAssay, nextAssay, gene_subset = NULL) {
  # Without an explicit subset, every feature name of either object is allowed.
  if (is.null(gene_subset)) {
    gene_subset <- c(rownames(prev.t), rownames(next.t))
  }
  DefaultAssay(prev.t) <- prevAssay
  DefaultAssay(next.t) <- nextAssay
  prev.t <- FindVariableFeatures(NormalizeData(prev.t))
  next.t <- FindVariableFeatures(NormalizeData(next.t))

  # Integrate both time points on the selected features restricted to gene_subset.
  both <- list(prev.t, next.t)
  features <- intersect(SelectIntegrationFeatures(both), gene_subset)
  anchors <- FindIntegrationAnchors(both, anchor.features = features)
  combined <- IntegrateData(anchors)

  DefaultAssay(combined) <- "integrated"
  combined <- ScaleData(combined, verbose = FALSE)
  combined <- RunPCA(combined, npcs = 30, verbose = FALSE)
  # NOTE(review): the UMAP and neighbor graph are not read by the distance
  # computation below; they are kept so this function's side effects are unchanged.
  combined <- RunUMAP(combined, reduction = "pca", dims = 1:15)
  combined <- FindNeighbors(combined, reduction = "pca", dims = 1:15)

  # Components in rows, cells in columns.
  m <- t(combined@reductions$pca@cell.embeddings[, 1:15])

  # NOTE(review): assumes the integrated object lists the cells of prev.t
  # first, followed by the cells of next.t -- confirm against IntegrateData.
  n.prev <- ncol(prev.t)
  # drop = FALSE keeps the matrix shape (and the cell names) when a time
  # point consists of a single cell.
  t1 <- m[, seq_len(n.prev), drop = FALSE]
  t2 <- m[, (n.prev + 1):ncol(m), drop = FALSE]

  # Squared Euclidean distance of every prev.t cell to one next.t cell at a
  # time; t2[, j] is recycled down the columns of t1.
  sq <- vapply(
    seq_len(ncol(t2)),
    function(j) colSums((t1 - t2[, j])^2),
    numeric(ncol(t1))
  )
  matrix(sq, nrow = ncol(t1), dimnames = list(colnames(t1), colnames(t2)))
}

#' Prune trajectories
#'
#' Prune trajectories down to top.n candidates to reduce runtime of subsequent mincostflow function.
#'
#' An entry of a distance matrix is set to \code{Inf} when its rank exceeds
#' \code{top.n} both within its row and within its column. Afterwards, cells
#' without any finite connection are removed: rows of the first matrix, the
#' cells shared by consecutive matrices, and (when more than one matrix is
#' given) columns of the last matrix.
#' @param D.list List of distance matrices. Can be generated by the distmat function.
#' @param top.n Prune trajectories to only top.n possible connections to optimize subsequent application of the mincostflow function.
#' @examples
#' \donttest{
#'
#' # Full vignette available on https://grandr.erhard-lab.de/articles/web/hetseq.html
#'
#'   prune(D.list, top.n = 10)
#' }
#' @return List of pruned distance matrices between cells from multiple timepoints.
#' @export
prune <- function(D.list, top.n = 10) {
  n.mat <- length(D.list)
  if (n.mat == 0) {
    stop("D.list must contain at least one distance matrix")
  }

  # Mark every entry ranked worse than top.n in its row and in its column;
  # only entries marked from both directions are pruned.
  prune_one <- function(D) {
    votes <- matrix(0, nrow = nrow(D), ncol = ncol(D))
    for (i in seq_len(nrow(D))) {
      far <- rank(D[i, ]) > top.n
      votes[i, far] <- votes[i, far] + 1
    }
    for (j in seq_len(ncol(D))) {
      far <- rank(D[, j]) > top.n
      votes[far, j] <- votes[far, j] + 1
    }
    D[votes == 2] <- Inf
    D
  }
  has_finite_row <- function(M) rowSums(is.finite(M)) > 0
  has_finite_col <- function(M) colSums(is.finite(M)) > 0

  re <- lapply(D.list, prune_one)

  # drop = FALSE throughout: a matrix left with a single row or column must
  # stay a matrix, otherwise later steps (and mincostflow) cannot use it.
  re[[1]] <- re[[1]][has_finite_row(re[[1]]), , drop = FALSE]

  # NOTE(review): for a single matrix only rows are filtered (columns without
  # finite entries are kept); this mirrors the previous behavior.
  if (n.mat > 1) {
    for (i in 2:n.mat) {
      # A shared cell survives only if it is reachable from the previous time
      # point and has a connection to the next one.
      keep <- has_finite_col(re[[i - 1]]) & has_finite_row(re[[i]])
      re[[i - 1]] <- re[[i - 1]][, keep, drop = FALSE]
      re[[i]] <- re[[i]][keep, , drop = FALSE]
    }
    re[[n.mat]] <- re[[n.mat]][, has_finite_col(re[[n.mat]]), drop = FALSE]
  }
  re
}


#' Min-Cost-Max-Flow for cellular trajectories
#'
#' Applies Min-Cost-Max-Flow to calculate optimal trajectories from distance matrices.
#'
#' First the maximum number of trajectories is determined as the max flow
#' through a unit-capacity graph (source -> cells of the first time point ->
#' ... -> cells of the last time point -> sink) whose edges are the finite
#' entries of the distance matrices. Then a linear program selects, among all
#' flows of that size, the one with minimal total distance.
#' @param D.list List of (pruned) distance matrices. Can be generated by the distmat function.
#'   The column names of each matrix must equal the row names of the next one.
#' @param verbose Show verbose output.
#' @examples
#' \donttest{
#'
#' # Full vignette available on https://grandr.erhard-lab.de/articles/web/hetseq.html
#'
#'   mincostflow(D.list)
#' }
#' @return Data frame of cell-cell trajectories spanning all given timepoints,
#'   one row per trajectory and one column (G1, G2, ...) per time point.
#' @export
mincostflow <- function(D.list, verbose = TRUE) {
  n.mat <- length(D.list)
  if (n.mat == 0) {
    stop("D.list must contain at least one distance matrix")
  }

  # Consecutive matrices must describe the same cells at their shared time point.
  for (i in seq_len(n.mat)[-1]) {
    if (ncol(D.list[[i - 1]]) != nrow(D.list[[i]]) ||
        !all(colnames(D.list[[i - 1]]) == rownames(D.list[[i]]))) {
      stop("Dimensions do not match!")
    }
  }

  # Number of cells per time point and number of usable (finite) edges per matrix.
  groups <- c(nrow(D.list[[1]]), vapply(D.list, ncol, integer(1)))
  dist.edges <- vapply(D.list, function(D) sum(is.finite(D)), integer(1))
  # LP variables: one per finite distance edge, followed by one per cell.
  num.edges <- sum(dist.edges) + sum(groups)

  if (verbose) cat(sprintf("Computing max flow...\n"))

  # Prefix node names with the matrix index so that cells of different
  # matrices cannot collide in the graph.
  ren <- function(D, i) {
    rownames(D) <- paste(i, rownames(D))
    colnames(D) <- paste(i, colnames(D))
    D
  }
  g <- do.call("rbind", lapply(seq_len(n.mat), function(di) reshape2::melt(ren(D.list[[di]], di))))
  g <- g[is.finite(g$value), ]
  g$value <- 1
  colnames(g) <- c("from", "to", "capacity")
  g <- rbind(g, data.frame(from = "source", to = paste(1, rownames(D.list[[1]])), capacity = 1))
  # Link the column node of matrix i-1 to the row node of matrix i for the
  # same cell; the unit capacity limits every cell to one trajectory.
  for (i in seq_len(n.mat)[-1]) {
    g <- rbind(g, data.frame(from = paste(i - 1, rownames(D.list[[i]])), to = paste(i, rownames(D.list[[i]])), capacity = 1))
  }
  g <- rbind(g, data.frame(from = paste(n.mat, colnames(D.list[[n.mat]])), to = "sink", capacity = 1))
  g <- igraph::graph_from_data_frame(g)
  total <- igraph::max_flow(g, source = "source", target = "sink")$value

  if (verbose) cat(sprintf("max flow=%d, bottleneck=%d\n", total, min(groups)))

  num.constraints <- sum(groups) + 2 + sum(vapply(D.list, function(D) nrow(D) + ncol(D), integer(1)))
  # the first edges are the finite dist edges of D1, then D2, ... , Dn

  if (verbose) cat(sprintf("Allocate %dx%d matrix...\n", num.constraints, num.edges))
  lhs <- matrix(NA, ncol = num.edges, nrow = num.constraints)

  # construct constraints enforcing capacity of 1 for each cell
  capacities <- sum(groups)
  if (verbose) cat(sprintf("Constructing capacity constraints for %d cells...\n", capacities))
  lhs[1:capacities, ] <- cbind(matrix(0, nrow = capacities, ncol = sum(dist.edges)), diag(capacities))
  dir <- rep("<=", capacities)
  rhs <- rep(1, capacities)

  cpad <- rep(0, sum(dist.edges))
  # add source constraint: the cells of the first time point carry `total` flow
  lhs[capacities + 1, ] <- c(cpad, rep(1, groups[1]), rep(0, capacities - groups[1]))
  dir <- c(dir, "==")
  rhs <- c(rhs, total)
  # add sink constraint: the cells of the last time point carry `total` flow
  lhs[capacities + 2, ] <- c(cpad, rep(0, capacities - groups[length(groups)]), rep(1, groups[length(groups)]))
  dir <- c(dir, "==")
  rhs <- c(rhs, total)

  # add flow constraints: for every cell, the flow over its incident distance
  # edges equals the flow through the cell itself
  if (verbose) cat(sprintf("Constructing flow constraints and costs for %d edges...\n", sum(dist.edges)))
  bpad <- c()      # zero padding for the edge variables of preceding matrices
  cpos <- 1        # index of the current cell among the cell variables
  ncstr <- capacities + 3
  for (D in D.list) {
    fin <- is.finite(D)
    # Row/column index of every finite edge, in the column-major order in
    # which the edge variables of this matrix are laid out.
    edge.row <- row(D)[fin]
    edge.col <- col(D)[fin]
    # add constraints for left group
    for (r in seq_len(nrow(D))) {
      coef <- c(bpad, as.numeric(edge.row == r))
      coef <- c(coef, rep(0, sum(dist.edges) - length(coef)))
      rest <- rep(0, num.edges - length(coef))
      rest[cpos] <- -1
      lhs[ncstr, ] <- c(coef, rest)
      ncstr <- ncstr + 1
      cpos <- cpos + 1
    }
    # add constraints for right group
    for (j in seq_len(ncol(D))) {
      coef <- c(bpad, as.numeric(edge.col == j))
      coef <- c(coef, rep(0, sum(dist.edges) - length(coef)))
      rest <- rep(0, num.edges - length(coef))
      rest[cpos] <- -1
      lhs[ncstr, ] <- c(coef, rest)
      ncstr <- ncstr + 1
      cpos <- cpos + 1
    }
    # The right group of this matrix is the left group of the next one.
    cpos <- cpos - ncol(D)
    bpad <- c(bpad, rep(0, sum(fin)))
  }
  dir <- c(dir, rep("==", nrow(lhs) - length(dir)))
  rhs <- c(rhs, rep(0, nrow(lhs) - length(rhs)))

  # Costs: the finite distances for the edge variables, zero for the cell variables.
  obj <- c(do.call("c", lapply(D.list, function(D) as.vector(D))), rep(0, sum(groups)))
  obj <- obj[is.finite(obj)]

  if (verbose) cat(sprintf("Optimizing %d variables with %d constraints...\n", ncol(lhs), nrow(lhs)))
  solution <- lpSolve::lp(
    direction = "min",
    objective.in = obj,
    const.mat = lhs,
    const.dir = dir,
    const.rhs = rhs)
  # A non-zero status means the solver did not return an optimal solution;
  # building trajectories from it would silently produce meaningless output.
  if (solution$status != 0) {
    stop(sprintf("lpSolve did not find an optimal solution (status %s)", solution$status))
  }

  # Join the trajectories built so far (last column named "C") with the edges
  # of the next matrix (first column named "C") and restore the time order.
  mergeorder <- function(a, b) {
    re <- merge(a, b, by = "C")
    re <- re[, c(2:ncol(a), 1, (ncol(a) + 1):ncol(re))]
    re
  }
  if (verbose) cat(sprintf("Constructing %d trajectories...\n", total))
  df <- NULL
  sol <- solution$solution
  pos <- 1
  for (di in seq_len(n.mat)) {
    D <- D.list[[di]]
    edges <- dist.edges[di]
    # Scatter the solved edge flows back into the shape of D.
    mm <- rep(0, nrow(D) * ncol(D))
    mm[as.vector(is.finite(D))] <- sol[pos:(pos + edges - 1)]
    pos <- pos + edges

    mm <- matrix(mm, ncol = ncol(D))
    rownames(mm) <- rownames(D)
    colnames(mm) <- colnames(D)
    mm <- reshape2::melt(mm)
    mm <- mm[mm$value != 0, ]
    dfn <- setNames(data.frame(as.character(mm$Var1), as.character(mm$Var2)), c("C", ""))
    df <- if (is.null(df)) dfn else mergeorder(df, dfn)
    # Only the most recent time point is named "C" so the next merge joins on it.
    names(df) <- c(rep("", ncol(df) - 1), "C")
  }
  names(df) <- paste0("G", 1:(n.mat + 1))
  df
}

Try the HetSeq package in your browser

Any scripts or data that you put into this service are public.

HetSeq documentation built on April 4, 2025, 2:03 a.m.