R/nebula.R

#' Association analysis of a multi-subject single-cell data set using a fast negative binomial mixed model
#' 
#' @param count A raw count matrix of the single-cell data. The rows are the genes, and the columns are the cells. The matrix can be a matrix object or a sparse dgCMatrix object.
#' @param id A vector of subject IDs. The length should be the same as the number of columns of the count matrix.
#' @param pred A design matrix of the predictors. The rows are the cells and the columns are the predictors. If not specified, an intercept column will be generated by default.
#' @param offset A vector of the scaling factor. The values must be strictly positive. If not specified, a vector of all ones will be generated by default. 
#' @param model 'NBGMM', 'PMM' or 'NBLMM'. 'NBGMM' is for fitting a negative binomial gamma mixed model. 'PMM' is for fitting a Poisson gamma mixed model. 'NBLMM' is for fitting a negative binomial lognormal mixed model (the same model as that in the lme4 package). The default is 'NBGMM'.
#' @param method 'LN' or 'HL'. 'LN' is to use NEBULA-LN and 'HL' is to use NEBULA-HL. The default is 'LN'. If the average number of cells per subject is less than 30, 'HL' is used regardless of this argument.
#' @param min Minimum values for the overdispersion parameters \eqn{\sigma^2} and \eqn{\phi}. Must be positive. The default is c(1e-4,1e-4).
#' @param max Maximum values for the overdispersion parameters \eqn{\sigma^2} and \eqn{\phi}. Must be positive. The default is c(10,1000).
#' @param kappa Please see the vignettes for more details. The default is 800.
#' @param cpc A non-negative threshold for filtering low-expressed genes. Genes with counts per cell smaller than the specified value will not be analyzed. The default is 0.005.
#' @param cutoff_cell The data will be refit using NEBULA-HL to estimate both overdispersions if the product of the average number of cells per subject and the estimated cell-level overdispersion parameter \eqn{\phi} is smaller than cutoff_cell. The default is 20.
#' @param opt 'lbfgs' or 'trust'. Specifying the optimization algorithm used in NEBULA-LN. The default is 'lbfgs'. If it is 'trust', a trust region algorithm based on the Hessian matrix will be used for optimization.
#' @param verbose An optional logical scalar indicating whether to print additional messages. The default is TRUE.
#' @param covariance If TRUE, nebula will output the covariance matrix for the estimated log(FC), which can be used for testing contrasts. The default is FALSE.
#' @return summary: The estimated coefficient, standard error and p-value for each predictor.
#' @return overdispersion: The estimated subject-level and cell-level overdispersions \eqn{\sigma^2} and \eqn{\phi^{-1}}.
#' @return convergence: More information about the convergence of the algorithm for each gene. A value of -20 or -30 indicates a potential failure of the convergence.    
#' @return algorithm: The algorithm used for analyzing the gene. More information can be found in the vignettes.

#' @keywords 
#' @export
#' @examples
#' library(nebula)
#' data(sample_data)
#' pred = model.matrix(~X1+X2+cc,data=sample_data$pred)
#' re = nebula(count=sample_data$count,id=sample_data$sid,pred=pred)
#' 


nebula = function (count, id, pred = NULL, offset = NULL, min = c(1e-4, 1e-4), max = c(10, 1000),
                   model = 'NBGMM', method = "LN", cutoff_cell = 20, kappa = 800, opt = 'lbfgs',
                   verbose = TRUE, cpc = 0.005, covariance = FALSE)
{
  # Fits one mixed model per gene that passes the expression filter and
  # returns a list with per-gene fixed effects / SEs / p-values (summary),
  # variance parameters (overdispersion), convergence codes, the algorithm
  # used, and optionally the fixed-effect covariances.
  reml = 0
  eps = 1e-06

  # Negative thresholds are clamped to 0.
  if (cpc < 0) {
    cpc = 0
  }

  # Validate enumerated arguments before any work is done.
  if (!(method %in% c('LN', 'HL'))) {
    stop("The method argument should be \'LN\' or \'HL\'.")
  }
  # `opt` must be checked here: the switch() on `opt` further down runs inside
  # a per-gene tryCatch(), whose error handler would swallow an invalid value
  # and silently fall back to nlminb for every gene.
  if (!(length(opt) == 1 && opt %in% c('lbfgs', 'trust'))) {
    stop("The argument opt should be lbfgs, or trust.")
  }

  if (is.vector(count)) {
    count = matrix(count, nrow = 1)
  }
  count = as(count, "dgCMatrix")
  nind = ncol(count)
  if (nind < 2) {
    stop("There is no more than one cell in the count matrix.")
  }
  gname = rownames(count)
  predn = sds = NULL
  intcol = NULL
  if (is.null(pred)) {
    # Intercept-only design.
    pred = matrix(1, ncol = 1, nrow = nind)
    sds = 1
    intcol = 1
  } else {
    predn = colnames(pred)
    pred = center_m(as.matrix(pred))
    sds = pred$sds
    # A column with sds == 0 is treated as the intercept (intcol); more than
    # one such column, or any sds < 0, is rejected.
    # NOTE(review): sds < 0 presumably marks an all-zero column - confirm
    # against center_m().
    if ((sum(sds == 0) > 1) || (sum(sds < 0) > 0)) {
      stop("Some predictors have zero variation or a zero vector.")
    }
    intcol = which(sds == 0)
    pred = as.matrix(pred$pred)
    if (nrow(pred) != nind) {
      stop('The number of rows of the design matrix should be equal to the number of columns of the count matrix.')
    }
  }

  nb = ncol(pred)
  ngene = nrow(count)

  if (is.null(offset)) {
    of_re = cv_offset(0, 0, nind)
  } else {
    if (length(offset) != nind) {
      stop('The length of offset should be equal to the number of columns of the count matrix.')
    }
    of_re = cv_offset(as.double(offset), 1, nind)
  }
  offset = of_re$offset
  moffset = of_re$moffset
  cv = of_re$cv
  cv2 = cv * cv

  # Number of variance parameters: the subject-level sigma^2, plus the
  # cell-level phi for the negative binomial models.
  ntau = switch(model, 'NBGMM' = 2, 'PMM' = 1, 'NBLMM' = 2,
                stop("The argument model should be NBGMM, PMM or NBLMM."))

  npar = nb + ntau
  # Box constraints for the joint fit: fixed effects in [-100, 100],
  # variance parameters in [min, max].
  lower = c(rep(-100, nb), min[1:ntau])
  upper = c(rep(100, nb), max[1:ntau])

  if (length(id) != nind) {
    stop('The length of subject IDs should be equal to the number of columns of the count matrix.')
  }

  # Recode subject IDs as 1, 2, ... in order of first appearance. The result
  # is non-decreasing only when each subject's cells are contiguous.
  id = as.character(id)
  levels = unique(id)
  id = as.numeric(factor(id, levels = levels))
  if (is.unsorted(id)) {
    stop("The cells of the same subject have to be grouped.")
  }
  k = length(levels)
  # fid: 1-based index of each subject's first cell, plus the sentinel nind + 1.
  fid = which(c(1, diff(id)) == 1)
  fid = as.integer(c(fid, nind + 1))

  cell_ind = get_cell(pred, fid - 1, nb, k)
  cell_ind = which(cell_ind > 0)
  ncell = length(cell_ind)

  # Average number of cells per subject. Below 30, NEBULA-HL is used
  # regardless of the requested method.
  mfs = nind / k
  if (mfs < 30) {
    method = 'HL'
    if (verbose == TRUE) {
      cat("The average number of cells per subject (", round(mfs, digits = 2), ") is less than 30. The 'method' is set as 'HL'.\n")
    }
  }

  cumsumy = call_cumsumy(count, fid - 1, k, ngene)
  if (ngene == 1) {
    cumsumy = matrix(cumsumy, ncol = k)
  }
  # posind[[g]]: the subject columns (one per subject) of cumsumy that are
  # positive for gene g.
  posind = lapply(seq_len(ngene), function(x) which(cumsumy[x, ] > 0))
  count = Matrix::t(count)
  # Expression filter on counts per cell.
  gid = which((rowSums(cumsumy) / nind) > cpc)
  lgid = length(gid)
  if (verbose == TRUE) {
    cat("Remove ", ngene - lgid, " genes having low expression.\n")
  }
  if (lgid == 0) {
    stop("No gene passed the filtering.")
  }
  if (verbose == TRUE) {
    cat("Analyzing ", lgid, " genes with ",
        k, " subjects and ", nind, " cells.\n")
  }

  cell_init = 1

  if (model %in% c('NBGMM', 'NBLMM')) {

    re = sapply(gid, function(x) {
      posv = call_posindy(count, x - 1, nind)

      # `ord` is forwarded to the PQL objectives; 3 is used for very sparse
      # genes (mct * mfs < 3), 1 otherwise.
      if ((posv$mct * mfs) < 3) {
        ord = 3
      } else {
        ord = 1
      }

      # Step 1: joint fit of fixed effects and both variance parameters.
      # The intercept starts at log(mct) - moffset, other coefficients at 0.
      para = c(log(posv$mct) - moffset, rep(0, nb - 1))
      re_t = tryCatch({
        ref = switch(opt,
                     'lbfgs' = lbfgs(c(para, 1, cell_init), ptmg_ll_der,
                                     posindy = posv$posindy, X = pred, offset = offset,
                                     Y = posv$Y, n_one = posv$n_onetwo, ytwo = posv$ytwo,
                                     fid = fid, cumsumy = cumsumy[x, ], posind = posind[[x]],
                                     nb = nb, k = k, nind = nind, lower = lower, upper = upper,
                                     control = list(ftol_abs = eps)),
                     'trust' = trust(objfun = ptmg_ll_der_hes3, c(para, 0, 0), 2, 100,
                                     posindy = posv$posindy, X = pred, offset = offset,
                                     Y = posv$Y, n_one = posv$n_onetwo, ytwo = posv$ytwo,
                                     fid = fid, cumsumy = cumsumy[x, ], posind = posind[[x]],
                                     nb = nb, k = k, nind = nind),
                     stop("The argument opt should be lbfgs, or trust.")
        )
        if (opt == 'lbfgs') {
          ref = ref$par
        } else {
          # trust() is unconstrained and works on the log scale for the
          # variance parameters: map back and clamp into [min, max].
          ref = ref$argument
          ref[nb + 1] = median(c(exp(ref[nb + 1]), min[1], max[1]))
          ref[nb + 2] = median(c(exp(ref[nb + 2]), min[2], max[2]))
        }
        # Trailing 1: the primary optimizer succeeded.
        c(ref, 1)
      }, error = function(e) {
        # Fallback optimizer; trailing 0 records that it was used.
        ref = nlminb(start = c(para, 1, cell_init), objective = ptmg_ll, gradient = ptmg_der,
                     posindy = posv$posindy, X = pred, offset = offset, Y = posv$Y,
                     n_one = posv$n_onetwo, ytwo = posv$ytwo, fam = id, fid = fid,
                     cumsumy = cumsumy[x, ], posind = posind[[x]], nb = nb, k = k,
                     nind = nind, lower = lower, upper = upper)
        c(ref$par, 0)
      })

      conv = re_t[length(re_t)]
      fit = 1
      # betae restarts from the log mean count; the step-1 variance estimates
      # seed `para` for any refit below.
      lmct = log(posv$mct)
      betae = c(lmct - moffset, rep(0, nb - 1))
      vare = re_t[(nb + 1):(nb + 2)]
      para = vare

      if (ncell > 0) {
        cv2 = get_cv(offset, pred, re_t[1:nb], cell_ind, ncell, nind)
      }

      # Average cells per subject times the estimated cell-level parameter.
      gni = mfs * vare[2]

      if (method == "LN") {
        if (model == "NBGMM") {
          if (gni < cutoff_cell) {
            # Refit both overdispersions with NEBULA-HL; retry with ord = 1
            # if the first attempt errors.
            re_t = tryCatch({
              bobyqa(para, pql_ll, reml = reml, eps = eps, ord = ord,
                     betas = betae, intcol = intcol, posindy = posv$posindy, X = pred,
                     offset = offset, Y = posv$Y, n_one = posv$n_onetwo,
                     ytwo = posv$ytwo, fid = fid, cumsumy = cumsumy[x, ], posind = posind[[x]],
                     nb = nb, k = k, nind = nind, lower = min, upper = max)
            }, error = function(e) {
              bobyqa(para, pql_ll, reml = reml, eps = eps, ord = 1,
                     betas = betae, intcol = intcol, posindy = posv$posindy, X = pred,
                     offset = offset, Y = posv$Y, n_one = posv$n_onetwo,
                     ytwo = posv$ytwo, fid = fid, cumsumy = cumsumy[x, ], posind = posind[[x]],
                     nb = nb, k = k, nind = nind, lower = min, upper = max)
            })
            vare = re_t$par[1:2]
            fit = 2
          } else {
            kappa_obs = gni / (1 + cv2)
            # Refit sigma^2 alone (phi held at para[2]) when kappa_obs is
            # small, or below kappa while sigma^2 < 8 / kappa_obs.
            if ((kappa_obs < 20) | ((kappa_obs < kappa) & (vare[1] < (8 / kappa_obs)))) {
              varsig = tryCatch({
                nlminb(para[1], pql_gamma_ll, gamma = para[2], betas = betae, intcol = intcol,
                       reml = reml, eps = eps, ord = ord,
                       posindy = posv$posindy, X = pred, offset = offset, Y = posv$Y,
                       n_one = posv$n_onetwo, ytwo = posv$ytwo,
                       fid = fid, cumsumy = cumsumy[x, ], posind = posind[[x]], nb = nb, k = k,
                       nind = nind, lower = min[1], upper = max[1])
              }, error = function(e) {
                nlminb(para[1], pql_gamma_ll, gamma = para[2], betas = betae, intcol = intcol,
                       reml = reml, eps = eps, ord = 1,
                       posindy = posv$posindy, X = pred, offset = offset, Y = posv$Y,
                       n_one = posv$n_onetwo, ytwo = posv$ytwo,
                       fid = fid, cumsumy = cumsumy[x, ], posind = posind[[x]], nb = nb, k = k,
                       nind = nind, lower = min[1], upper = max[1])
              })
              vare[1] = varsig$par[1]
              fit = 3
            }
          }
        } else {
          if (gni < cutoff_cell) {
            re_t = bobyqa(para, pql_nbm_ll, reml = reml,
                          eps = eps, ord = 1, betas = betae, intcol = intcol, posindy = posv$posindy,
                          X = pred, offset = offset, Y = posv$Y, n_one = posv$n_onetwo,
                          ytwo = posv$ytwo, fid = fid, cumsumy = cumsumy[x, ], posind = posind[[x]],
                          nb = nb, k = k, nind = nind, lower = min, upper = max)
            vare = re_t$par[1:2]
            fit = 4
          } else {
            varsig = nlminb(para[1], pql_nbm_gamma_ll, gamma = para[2], betas = betae, intcol = intcol,
                            reml = reml, eps = eps, ord = 1,
                            posindy = posv$posindy, X = pred, offset = offset, Y = posv$Y,
                            n_one = posv$n_onetwo, ytwo = posv$ytwo,
                            fid = fid, cumsumy = cumsumy[x, ], posind = posind[[x]], nb = nb, k = k,
                            nind = nind, lower = min[1], upper = max[1])
            vare[1] = varsig$par[1]
            fit = 6
          }
        }
      } else {
        # NEBULA-HL: always refit both variance parameters.
        if (model == "NBGMM") {
          re_t = tryCatch({
            bobyqa(para, pql_ll, reml = reml, eps = eps, ord = ord,
                   betas = betae, intcol = intcol, posindy = posv$posindy, X = pred,
                   offset = offset, Y = posv$Y, n_one = posv$n_onetwo,
                   ytwo = posv$ytwo, fid = fid, cumsumy = cumsumy[x, ], posind = posind[[x]],
                   nb = nb, k = k, nind = nind, lower = min, upper = max)
          }, error = function(e) {
            bobyqa(para, pql_ll, reml = reml, eps = eps, ord = 1,
                   betas = betae, intcol = intcol, posindy = posv$posindy, X = pred,
                   offset = offset, Y = posv$Y, n_one = posv$n_onetwo,
                   ytwo = posv$ytwo, fid = fid, cumsumy = cumsumy[x, ], posind = posind[[x]],
                   nb = nb, k = k, nind = nind, lower = min, upper = max)
          })
          fit = 2
        } else {
          re_t = bobyqa(para, pql_nbm_ll, reml = reml,
                        eps = eps, ord = 1, betas = betae, intcol = intcol, posindy = posv$posindy,
                        X = pred, offset = offset, Y = posv$Y, n_one = posv$n_onetwo,
                        ytwo = posv$ytwo, fid = fid, cumsumy = cumsumy[x, ], posind = posind[[x]],
                        nb = nb, k = k, nind = nind, lower = min, upper = max)
          fit = 4
        }
        vare = re_t$par[1:2]
      }

      # Final step: estimate the fixed effects given vare, starting from an
      # intercept shifted by -sigma^2 / 2.
      betae[intcol] = betae[intcol] - vare[1] / 2
      if (model == "NBGMM") {
        repml = opt_pml(pred, offset, posv$Y, fid - 1, cumsumy[x, ], posind[[x]] - 1, posv$posindy,
                        nb, nind, k, betae, vare, reml, 1e-06, 1)
      } else {
        repml = opt_pml_nbm(pred, offset, posv$Y, fid - 1, cumsumy[x, ], posind[[x]] - 1, posv$posindy,
                            nb, nind, k, betae, vare, reml, 1e-06, 1)
      }

      # Convergence codes: -30 NaN log-likelihood; -20 the fixed-effect step
      # stopped at 50 iterations; -10 / -40 when it reports damp 11 / 12;
      # otherwise the step-1 flag (1 primary optimizer, 0 fallback) is kept.
      if (is.nan(repml$loglik)) {
        conv = -30
      } else if (repml$iter == 50) {
        conv = -20
      } else if (repml$damp == 11) {
        conv = -10
      } else if (repml$damp == 12) {
        conv = -40
      }

      # Row layout: beta (nb), sigma^2, 1/phi, diag of beta covariance (nb),
      # conv, fit [, lower triangle of the beta covariance incl. diagonal].
      if (covariance == FALSE) {
        c(repml$beta, vare[1], 1 / vare[2], diag(repml$var), conv, fit)
      } else {
        c(repml$beta, vare[1], 1 / vare[2], diag(repml$var), conv, fit,
          repml$var[lower.tri(repml$var, diag = TRUE)])
      }
    })
  } else {
    # PMM: fixed effects and the single variance parameter fitted jointly.
    re = sapply(gid, function(x) {
      posv = call_posindy(count, x - 1, nind)
      para = c(log(posv$mct) - moffset, rep(0, nb - 1), 1)
      re_t = tryCatch({
        nlminb(para, pmg_ll, pmg_der, posindy = posv$posindy, X = pred, offset = offset,
               Y = posv$Y, fid = fid, cumsumy = cumsumy[x, ], posind = posind[[x]],
               nb = nb, k = k, nind = nind, lower = lower, upper = upper)
      }, error = function(e) {
        bobyqa(para, pmg_ll, posindy = posv$posindy, X = pred, offset = offset, Y = posv$Y,
               fid = fid, cumsumy = cumsumy[x, ], posind = posind[[x]],
               nb = nb, k = k, nind = nind, lower = lower, upper = upper)
      })

      hes = pmg_hes(re_t$par, posindy = posv$posindy, X = pred, offset = offset, Y = posv$Y,
                    fid = fid, cumsumy = cumsumy[x, ], posind = posind[[x]],
                    nb = nb, k = k, nind = nind)

      # Parameters with a missing Hessian diagonal, or sitting exactly on a
      # box bound, are left out of the inversion; their variances stay NA.
      var_re = rep(NA, npar)
      nnaind = (!is.na(diag(hes))) & (re_t$par != lower) & (re_t$par != upper)
      covfix = solve(-hes[nnaind, nnaind, drop = FALSE])
      var_re[nnaind] = diag(covfix)
      if (covariance == FALSE) {
        c(re_t$par, var_re[1:nb], 1, 5)
      } else {
        # nnaind spans all npar = nb + 1 parameters, but only the nb x nb
        # fixed-effect block is reported. The fixed effects come first, so
        # their kept entries occupy the leading rows/columns of covfix.
        bind = nnaind[1:nb]
        nbb = sum(bind)
        tempv = matrix(NA, nb, nb)
        tempv[bind, bind] = covfix[seq_len(nbb), seq_len(nbb)]
        c(re_t$par, var_re[1:nb], 1, 5, tempv[lower.tri(tempv, diag = TRUE)])
      }
    })
  }

  re_all = as.data.frame(t(re))
  # Wald p-values. Computed before rescaling; the ratio beta^2 / var is
  # unchanged by dividing beta and its SE by the same sds.
  p = sapply(seq_len(nb), function(x) pchisq(re_all[, x]^2 / re_all[, x + npar], 1, lower.tail = FALSE))
  se_cols = npar + seq_len(nb)
  for (i in se_cols) {
    re_all[, i] = sqrt(re_all[, i])
  }

  # Convert coefficients and SEs back to the original predictor scale; the
  # intercept column is not rescaled.
  sds[intcol] = 1
  re_all[, c(1:nb, se_cols)] = t(t(re_all[, c(1:nb, se_cols)]) / c(sds, sds))

  if (covariance == TRUE) {
    sdssc = sds %*% t(sds)
    sdssc = sdssc[lower.tri(sdssc, diag = TRUE)]
    cov_cols = (npar + nb + 3):ncol(re_all)
    re_all[, cov_cols] = t(t(re_all[, cov_cols]) / sdssc)
  }

  re_all = cbind(re_all, matrix(p, ncol = nb))
  re_all$gene_id = gid
  re_all$gene = gname[gid]

  rl = c("Subject", "Cell")
  if (length(predn) == 0) {
    indp = 1:nb
  } else {
    indp = predn
  }

  ncov = 0
  if (covariance == FALSE) {
    colnames(re_all)[1:(3 * nb + ntau + 2)] = c(paste0("logFC_", indp), rl[1:ntau], paste0("se_", indp),
                                                "convergence", "algorithm", paste0("p_", indp))
  } else {
    ncov = (nb + 1) * nb / 2
    colnames(re_all)[1:(3 * nb + ntau + 2 + ncov)] = c(paste0("logFC_", indp), rl[1:ntau], paste0("se_", indp),
                                                       "convergence", "algorithm", paste0("cov_", 1:ncov),
                                                       paste0("p_", indp))
  }

  # `fit` codes index into this vector.
  algorithms = c('NBGMM (LN)','NBGMM (HL)','NBGMM (LN+HL)','NBLMM (HL)','PGMM','NBLMM (LN+HL)')
  list(summary = re_all[, c(paste0("logFC_", indp), paste0("se_", indp), paste0("p_", indp),
                            colnames(re_all)[(3 * nb + ntau + 3 + ncov):ncol(re_all)])],
       overdispersion = re_all[, c(rl[1:ntau])],
       convergence = re_all[, 'convergence'],
       algorithm = algorithms[re_all[, 'algorithm']],
       covariance = if (covariance == FALSE) {NULL} else {re_all[, paste0("cov_", 1:ncov)]})

}

Try the nebula package in your browser

Any scripts or data that you put into this service are public.

nebula documentation built on Sept. 24, 2021, 3 a.m.