R/SCATE.R

Defines functions SCATE

Documented in SCATE

#' Perform SCATE
#'
#' Single-cell ATAC-seq Signal Extraction and Enhancement
#'
#' This function takes as input the scATAC-seq reads (or peaks) and generates enhanced signals. Users can either perform SCATE on clusters of cells or on a single group of cells.
#' @param satac If type='reads', satac should be a GRanges object or list of GRanges objects of scATAC-seq reads. Each element corresponds to one single cell. The GRanges should be the middle point of the reads with a length of 1 base pair. Use 'satacprocess' to preprocess raw reads. If type='peaks', satac should be a data frame or list of data frames of scATAC-seq peaks. For each data frame, the first column is the chromosome name, the second column is the start site, the third column is the end site, and the fourth column is the number of reads of the peak.
#' @param type Character variable of either 'reads' or 'peaks'. Default is 'reads'.
#' @param peakOverlapMethod Character variable of either 'full' or 'middle'. Only effective when type = 'peaks'. If peakOverlapMethod='full', then the full range of the peak will be used to find overlap with bins, and all bins overlapping with this peak will be assigned the read counts of this peak. If peakOverlapMethod='middle', only the middle base pair of the peak will be used to find overlap with bins. If several peaks overlap the same bin, their counts are not summed: the bin keeps the count of a single overlapping peak.
#' @param genome Character variable of either "hg19" or "mm10". Default is 'hg19'. The corresponding database is read from the 'SCATEData' package.
#' @param cluster Numeric vector specifying the cluster of cells. Needs to be named and include all cells in satac. If NULL, SCATE will be run on all cells in satac.
#' @param clusterid Numeric value(s) specifying the cluster(s) on which to run SCATE. If NULL, SCATE will be run on all clusters. Ignored if cluster is NULL. The cluster ids must be included in variable 'cluster'.
#' @param clunum Numeric value specifying the number of CRE clusters. It must be one of the CRE cluster numbers stored in the database, or the total number of CREs. If NULL, SCATE automatically chooses the number of CRE clusters.
#' @param datapath Character variable of the path to the customized database (e.g. myfolder/database.rds). The database can be made using the 'makedatabase' function. If not NULL, 'genome' is ignored.
#' @param verbose Either TRUE or FALSE. If TRUE, progress will be displayed.
#' @param ncores Numeric variable of the number of cores to use. Default is 1. If NULL, all cores reported by parallel::detectCores() are used. Parallel computing is disabled on Windows.
#' @return A numeric matrix of values generated by SCATE with one column per cell cluster. The number of rows is the same as the number of bins in the genome, and row names have the form chr_start_end. Column names indicate the cluster id, or 'combine' if cluster is NULL.
#' @export
#' @import GenomicAlignments GenomicRanges parallel splines2 xgboost
#' @author Zhicheng Ji, Weiqiang Zhou, Wenpin Hou, Hongkai Ji* <whou10@@jhu.edu>
#' @examples
#' #Reads as input, setting CRE cluster number as 156 to increase speed. Users need to set it to be NULL in real applications.
#' gr <- GRanges(seqnames="chr1",IRanges(start=seq_len(100)+1e6,width=1))
#' SCATE(gr,clunum=156,type='reads',genome="mm10")
#' \dontrun{
#' peak <- data.frame(seqnames="chr1",start=seq_len(100)+1e6,end=seq_len(100)+1e8,count=1)
#' #Peak as input, peakOverlapMethod=full
#' SCATE(satac=peak,clunum=156,type='peaks',genome="mm10")
#' #Peak as input, peakOverlapMethod=middle
#' SCATE(satac=peak,clunum=156,type='peaks',peakOverlapMethod='middle',genome="mm10")
#' }

SCATE <- function(satac,type='reads',peakOverlapMethod = 'full',genome='hg19',cluster=NULL,clusterid=NULL,clunum=NULL,datapath=NULL,verbose=TRUE,ncores = 1) {
      # Validate the input type before paying for the (large) database load.
      if (length(type) != 1 || !(type %in% c('reads','peaks'))) {
            stop('Wrong type')
      }
      # Documented behaviour: NULL means "use every detected core".
      if (is.null(ncores)) {
            ncores <- detectCores()
            # detectCores() returns NA when the core count cannot be determined.
            if (is.na(ncores)) {
                  ncores <- 1
            }
      }
      # mclapply relies on forking, which is unavailable on Windows.
      if (ncores > 1 && Sys.info()[['sysname']]=='Windows') {
            message('Parallel is disabled for Windows. Running with one core')
            ncores <- 1
      }
      # Avoid scientific notation while SCATE runs, but restore the caller's
      # options when the function exits instead of changing them globally.
      oldopt <- options(scipen=999)
      on.exit(options(oldopt), add = TRUE)

      if (!is.null(datapath)) {
            loaddata <- readRDS(datapath)
      } else {
            # system.file() returns "" when the package or file is missing.
            dbfile <- system.file("extdata",paste0(genome,".rds"),package="SCATEData")
            if (dbfile == "") {
                  stop("Database for genome '",genome,"' not found. Install the 'SCATEData' package or supply 'datapath'.")
            }
            loaddata <- readRDS(dbfile)
      }
      gr <- loaddata$gr

      # Apply FUN over X, sequentially or with forked workers when ncores > 1.
      applyfunc <- function(X, FUN) {
            if (ncores == 1) {
                  lapply(X, FUN)
            } else {
                  mclapply(X, FUN, mc.cores = ncores)
            }
      }

      # Run SCATE on one group of cells and return one value per bin of gr.
      SCATEsingle <- function(satac) {
            if (verbose) {
                  message('Preparing data')
            }
            allclunum <- loaddata$allclunum
            m <- loaddata$ms$m
            s <- loaddata$ms$s
            id <- loaddata$id
            if (type=='reads') {
                  satac <- unlist(GRangesList(satac))
                  count <- countOverlaps(gr[id],satac)
            } else {
                  if (!is.data.frame(satac)) {
                        for (i in seq_along(satac)) {
                              colnames(satac[[i]]) <- c('chr','start','end','count')
                        }
                        satac <- do.call(rbind,satac)
                  }
                  if (peakOverlapMethod == 'full') {
                        peak <- GRanges(seqnames=satac[,1],IRanges(start=satac[,2],end=satac[,3]))
                  } else {
                        # Middle base pair of each peak, as an integer coordinate.
                        tmp <- floor((satac[,2]+satac[,3])/2)
                        peak <- GRanges(seqnames=satac[,1],IRanges(start=tmp,end=tmp))
                  }
                  o <- as.matrix(findOverlaps(gr[id],peak))
                  count <- rep(0,length(id))
                  # NOTE(review): plain assignment, so when several peaks hit the
                  # same bin only one peak's count is kept (counts are not summed).
                  count[o[,1]] <- satac[o[,2],4]
            }

            # Split bins into deciles of m; within each decile keep the (up to)
            # 1000 bins with the smallest s for fitting the mean curve. The min()
            # guard keeps small deciles (e.g. custom databases) from yielding an
            # NA cutoff that would select nothing.
            cutm <- as.numeric(cut(m,quantile(m,seq(0,10)/10),include.lowest = TRUE))
            hkid <- unlist(lapply(unique(cutm),function(i) {
                  tmpid <- which(cutm==i)
                  cutoff <- sort(s[tmpid])[min(1000,length(tmpid))]
                  tmpid[which(s[tmpid] <= cutoff)]
            }))

            # Poisson fit of log E[y] on an I-spline basis of x with non-negative
            # spline coefficients; the number of interior knots (0-5) is picked by
            # BIC. Returns the fitted curve on a 0.005 grid plus intercept/slope
            # pairs used to extrapolate linearly beyond both ends of the grid.
            mfit <- function(x,y) {
                  tryid <- seq(0,5)
                  BIC <- rep(0,length(tryid))
                  fit <- vector('list',length(tryid))
                  allkn <- vector('list',length(tryid))
                  for (knotid in seq_along(tryid)) {
                        knotnum <- tryid[knotid]
                        if (knotnum==0) {
                              knots <- NULL
                        } else {
                              knots <- quantile(x,seq(0,1,length.out=knotnum+2)[-c(1,knotnum+2)])
                        }
                        base <- cbind(1,iSpline(x,knots=knots,Boundary.knots=range(x)))
                        # list() wrapper so a NULL knot set is stored, not deleted.
                        allkn[knotid] <- list(knots)
                        optimfunc <- function(k) {
                              logmu <- (base %*% matrix(k,ncol=1))[,1]
                              -sum(y*logmu-exp(logmu))
                        }
                        fit[[knotid]] <- optim(rep(1,ncol(base)),optimfunc,lower=c(-Inf,rep(0,ncol(base)-1)),method="L-BFGS-B")
                        logmu <- (base %*% matrix(fit[[knotid]]$par,ncol=1))[,1]
                        BIC[knotid] <- log(length(y))*length(fit[[knotid]]$par)-2*sum(y*logmu-exp(logmu))
                  }
                  best <- which.min(BIC)
                  par <- fit[[best]]$par
                  knots <- allkn[[best]]
                  xseq <- seq(min(x),max(x),0.005)
                  base <- cbind(1,iSpline(xseq,knots=knots,Boundary.knots=range(x)))
                  logmu <- (base %*% matrix(par,ncol=1))[,1]

                  # Tail slopes from the first/last 10% of the grid; clamp the
                  # indices so short grids do not give a 0/0 or empty slope.
                  nx <- length(xseq)
                  lid <- max(2,floor(nx*0.1))
                  rid <- min(nx-1,floor(nx*0.9))
                  slopeleft <- (logmu[lid] - logmu[1])/(xseq[lid] - xseq[1])
                  sloperight <- (logmu[nx] - logmu[rid])/(xseq[nx] - xseq[rid])
                  interleft <- logmu[1] - slopeleft*xseq[1]
                  interright <- logmu[nx] - sloperight*xseq[nx]
                  list(xseq = xseq, logmu = logmu, parleft = c(interleft,slopeleft), parright = c(interright,sloperight))
            }

            mfitres <- mfit(m[hkid],count[hkid])

            # Evaluate the fitted curve at x: nearest grid point inside the grid
            # (step 0.005, hence *200), linear extrapolation outside it.
            logmufunc <- function(x) {
                  gid <- round((x-min(mfitres$xseq))*200) + 1
                  res <- rep(0,length(gid))
                  sid <- which(gid >= 1 & gid <= length(mfitres$xseq))
                  res[sid] <- mfitres$logmu[gid[sid]]
                  sid <- which(gid < 1)
                  res[sid] <- mfitres$parleft[1] + mfitres$parleft[2] * x[sid]
                  sid <- which(gid > length(mfitres$logmu))
                  res[sid] <- mfitres$parright[1] + mfitres$parright[2] * x[sid]
                  res
            }

            # Penalised Poisson estimate of a shared shift delta for a set of bins:
            # minimise -loglik(count | mu = exp(logmufunc(m + s*delta))) + delta^2/2.
            deltafitfunc <- function(count,m,s) {
                  if (sum(count) == 0) {
                        # With all-zero counts the log-likelihood reduces to -sum(mu).
                        optimfunc <- function(delta) {
                              sum(exp(logmufunc(m+s*delta))) + delta^2/2
                        }
                  } else {
                        optimfunc <- function(delta) {
                              logmu <- logmufunc(m+s*delta)
                              -sum(count*logmu-exp(logmu)) + delta^2/2
                        }
                  }
                  optimise(optimfunc,c(-100,100))$minimum
            }

            if (verbose) {
                  message('Fitting model')
            }
            nclu <- clunum
            if (is.null(nclu)) {
                  # Sample up to 10000 bins belonging to clusters with >= 10 members
                  # (last clustering column of the database) as held-out test bins.
                  lastclu <- loaddata$cluster[,ncol(loaddata$cluster)]
                  spclu <- split(seq_along(lastclu),lastclu)
                  tabclu <- lengths(spclu)
                  targetid <- unlist(spclu[which(tabclu >= 10)])
                  if (length(targetid) > 10000) {
                        targetid <- sample(targetid,10000)
                  }
                  # Leave-one-out predictive log-likelihood of each test bin for every
                  # candidate clustering: rows = test bins, columns = allclunum.
                  loglike <- do.call(cbind,lapply(allclunum,function(k) {
                        kclu <- loaddata$cluster[,which(allclunum==k)]
                        spk <- split(seq_along(kclu),kclu)
                        names(spk) <- NULL
                        unlist(applyfunc(targetid,function(testid) {
                              trainid <- setdiff(spk[[kclu[testid]]],testid)
                              delta <- deltafitfunc(count[trainid],m[trainid],s[trainid])
                              logmu <- logmufunc(m[testid]+s[testid]*delta)
                              count[testid]*logmu-exp(logmu)
                        }))
                  }))
                  # Smooth the median log-likelihood against log2(cluster number) and
                  # pick the best candidate; "one cluster per bin" is scored by the
                  # loess extrapolation (surface="direct").
                  candidates <- c(allclunum,length(id))
                  logclu <- log2(candidates)
                  medll <- apply(loglike,2,median)
                  loessmod <- loess(medll~logclu,data.frame(medll=medll,logclu=logclu[-length(logclu)]),surface="direct")
                  nclu <- candidates[which.max(predict(loessmod,logclu))]
            }
            if (!(nclu %in% c(allclunum,length(id)))) {
                  stop('clunum must be one of the CRE cluster numbers stored in the database or the number of CREs')
            }
            if (nclu == length(id)) {
                  # One cluster per bin: fit delta bin by bin.
                  crefeature <- unlist(applyfunc(seq_along(id),function(i) {
                        m[i]+s[i]*deltafitfunc(count[i],m[i],s[i])
                  }))
            } else {
                  creclu <- loaddata$cluster[,which(allclunum==nclu)]
                  spclu <- split(seq_along(creclu),creclu)
                  delta <- unlist(applyfunc(seq_len(nclu),function(cluid) {
                        binid <- spclu[[cluid]]
                        deltafitfunc(count[binid],m[binid],s[binid])
                  }))
                  crefeature <- m+s*delta[creclu]
            }

            # Bins outside id ("excid") are predicted with xgboost trained on the
            # fitted bins, using m, s and log2(count + 1) as features.
            excid <- loaddata$excid
            excm <- loaddata$excms$m
            excs <- loaddata$excms$s
            if (type=='reads') {
                  exccount <- countOverlaps(gr[excid],satac)
            } else {
                  o <- as.matrix(findOverlaps(gr[excid],peak))
                  exccount <- rep(0,length(excid))
                  exccount[o[,1]] <- satac[o[,2],4]
            }
            mod <- xgboost(cbind(m=m,logcount=log2(count + 1),s=s),crefeature,nrounds=50,verbose=FALSE)
            predfeature <- predict(mod,cbind(m=excm,logcount=log2(exccount + 1),s=excs))
            fullfeature <- rep(0,length(gr))
            fullfeature[id] <- crefeature
            fullfeature[excid] <- predfeature
            pmax(0,fullfeature)
      }

      if (is.null(cluster)) {
            res <- matrix(SCATEsingle(satac),ncol=1)
            colnames(res) <- 'combine'
      } else {
            cluster <- cluster[names(satac)]
            if (length(cluster) == 0 || anyNA(cluster)) {
                  stop("'cluster' must be named and include all cells in 'satac'")
            }
            if (is.null(clusterid)) {
                  target <- sort(unique(cluster))
            } else {
                  target <- sort(intersect(clusterid,cluster))
                  if (length(target) == 0) {
                        stop("None of 'clusterid' is found in 'cluster'")
                  }
            }
            res <- vapply(target,function(i) {
                  SCATEsingle(satac[cluster==i])
            },numeric(length(gr)))
            colnames(res) <- target
      }
      if (verbose) {
            message('Generating results')
      }
      row.names(res) <- sprintf('%s_%s_%s',as.character(seqnames(gr)),start(gr),end(gr))
      res
}
Winnie09/SCATE documentation built on May 10, 2023, 8:10 a.m.