# R/PseudotimeDE.R
#
# Defines functions zinbgam dzinb.log fit_gam fit_qgam pseudotimeDE
#
# Documented in pseudotimeDE

#' @title Perform Differential Expression Test on One Gene
#'
#' @description
#' Test if one gene is differentially expressed along pseudotime.
#' @param gene A string of gene name. It should be one of the row names in \code{mat}.
#' @param ori.tbl A tibble or dataframe which contains the original cells and pseudotime as two columns (\code{cell} and \code{pseudotime}).
#' @param sub.tbl A list of tibbles or dataframes where each is the fit of a subsample. Each element is the same format as \code{ori.tbl}.
#' If \code{NULL}, only the p-value assuming fixed pseudotime (\code{fix.pv}) is computed.
#' @param mat The input expression data. It can be:
#' (1) A SingleCellExperiment object which contains the expression data;
#' (2) A matrix;
#' (3) A Seurat object which contains the expression data.
#' Its row names should be genes and col names should be cells.
#' @param assay.use The \code{assay} used in SingleCellExperiment or \code{slot} used in Seurat. Default is \code{counts}.
#' For a Seurat object, \code{"logcounts"} is mapped to the \code{"data"} slot.
#' @param model A string of the model name. One of \code{nb}, \code{zinb}, \code{gaussian}, \code{auto} and \code{qgam}.
#' If a vector is supplied (such as the default), its first element is used.
#' @param k An integer of the basis dimension. Default is 6. The results are usually robust to different k; we recommend to use k from 5 to 10.
#' @param knots A numeric vector of the location of knots; its length must equal \code{k}. Default is \code{k} knots evenly distributed between 0 and 1.
#' For instance, if your k = 6 and your pseudotime range is [0, 10], the knots could be \code{c(0:5)/5*(10-0)}.
#' @param fix.weight A logical variable indicating if the ZINB-GAM will use the zero weights from the original model. Used for saving time since ZINB-GAM is computationally intense.
#' @param aicdiff A numeric threshold for model selection. Only used when \code{model = "auto"}: ZINB is chosen when the NB AIC exceeds the ZINB AIC by more than \code{aicdiff}.
#' @param seed A numeric variable of the random seed. It affects the random permutation of pseudotime within each subsample and the parametric fitting of the null distribution.
#' @param quant The quantile of interest for quantile regression (qgam), range from 0 to 1, default as 0.5.
#' @param usebam A logical variable. If use \code{mgcv::bam}, which may be faster with large sample size (e.g., > 10'000 cells).
#' @param formula An (optional) custom formula to be passed to \code{mgcv::gam}, \code{mgcv::bam}, or \code{qgam::qgam}.
#' @param seurat.assay The \code{assay} used in Seurat. Default is \code{'RNA'}.
#'
#' @return A list with the components:
#' \describe{
#'   \item{\code{fix.pv}}{The p-value assuming the pseudotime is fixed}
#'   \item{\code{emp.pv}}{The empirical permutation p-value (subsamples whose fit failed are excluded)}
#'   \item{\code{para.pv}}{The permutation p-value by fitting a parametric null distribution}
#'   \item{\code{ad.pv}}{P-value of Anderson-Darling test on comparing null distribution and its parametric fit}
#'   \item{\code{rank}}{The estimated effective degrees of freedom of the smooth term in the original model}
#'   \item{\code{test.statistic} / \code{test.statistics}}{The test statistic of the original model. It is named \code{test.statistic} when no subsampling is done (\code{sub.tbl = NULL} or \code{model = "qgam"}) and \code{test.statistics} otherwise}
#'   \item{\code{gam.fit}}{The fitted gam model on original data}
#'   \item{\code{zinf}}{Whether the model is zero inflated}
#'   \item{\code{aic}}{The AIC of the original model}
#'   \item{\code{expv.quantile}}{Quantiles of the expression values (as stored in \code{mat}) of the cells in \code{ori.tbl}}
#'   \item{\code{expv.mean}}{Mean of those expression values}
#'   \item{\code{expv.zero}}{Zero proportion of those expression values}
#' }
#'
#' @importFrom magrittr %>%
#' @importFrom stats fitted
#'
#' @export pseudotimeDE
#'
#' @author Dongyuan Song, Shiyu Ma


pseudotimeDE <- function(gene,
                         ori.tbl,
                         sub.tbl,
                         mat,
                         assay.use = "counts",
                         model = c("nb", "zinb", "gaussian" ,"auto", "qgam"),
                         k = 6,
                         knots = c(0:5/5),
                         fix.weight = TRUE,
                         aicdiff = 10,
                         seed = 123,
                         quant = 0.5,
                         usebam = FALSE,
                         formula = NULL,
                         seurat.assay = 'RNA') {

  ## Set seed
  set.seed(seed)

  ## Avoid R CMD check notes for column names used inside dplyr verbs
  splits <- time.res <- NULL

  ## The default `model` is the vector of all choices; comparing it directly
  ## in `if()` is an error, so validate and keep the first element.
  valid.models <- c("nb", "zinb", "gaussian", "auto", "qgam")
  if (!is.character(model) || length(model) == 0L || !model[1] %in% valid.models) {
    stop("Specified 'model=' parameter is invalid. Must be one of 'nb', 'zinb', 'gaussian' ,'auto', 'qgam'.")
  }
  model <- model[1]

  ## The default knots only fit k = 6; when `knots` is not supplied, spread
  ## k knots evenly over [0, 1] (identical to the default when k = 6).
  if (missing(knots)) {
    knots <- (seq_len(k) - 1) / (k - 1)
  }
  if (length(knots) != k) {
    stop("The number of knot positions given (", length(knots), ") does not match k.")
  }

  ## Construct formula
  if (is.null(formula)) {
    fit.formula <- stats::as.formula(paste0("expv ~ s(pseudotime, k = ", k, ",bs = 'cr')"))
  } else if (inherits(formula, "formula")) {
    fit.formula <- formula
  } else {
    stop("Invalid custom formula")
  }

  ## ori.tbl holds the cells (column `cell`) and pseudotime of the original fit
  pseudotime <- ori.tbl$pseudotime
  expv <- extract_expression(mat, gene, cells = ori.tbl$cell,
                             assay.use = assay.use, seurat.assay = seurat.assay)
  count.v <- expv

  dat <- cbind(pseudotime, expv) %>% tibble::as_tibble()

  ## Summary statistics of the expression values as stored (no log transform;
  ## the assay name is user-defined, so no transformation is assumed)
  expv.quantile <- stats::quantile(count.v)
  expv.mean <- mean(count.v)
  expv.zero <- sum(count.v == 0)/length(count.v)

  ## Fit the original model. For zinb only the mu part is used downstream.
  zinf <- FALSE
  fit.zinb <- NULL
  if (model == "qgam") {
    fit <- fit_qgam(dat, k = k, quant = quant, fit.formula = fit.formula)
    if (identical(fit, FALSE)) stop("Fit failed")
    distri <- "qgam"
    aic <- stats::AIC(fit)
  } else if (model %in% c("nb", "gaussian")) {
    fit <- fit_gam(dat, distribution = model, use_weights = FALSE, k = k, knots = knots, usebam = usebam, fit.formula = fit.formula)
    if (identical(fit, FALSE)) stop("Fit failed")
    distri <- model
    aic <- stats::AIC(fit)
  } else {
    ## "zinb" or "auto": both need the ZINB-GAM; "auto" also fits an NB-GAM
    if (model == "auto") {
      fit.nb <- fit_gam(dat, distribution = "nb", use_weights = FALSE, k = k, knots = knots, usebam = usebam, fit.formula = fit.formula)
      if (identical(fit.nb, FALSE)) stop("Fit failed")
      aic.nb <- stats::AIC(fit.nb)
    }

    zinb.res <- zinbgam(fit.formula, ~ logmu, data = dat, knots = knots, k = k, usebam = usebam)
    if (identical(zinb.res, FALSE)) stop("Fit failed")
    aic.zinb <- zinb.res$aic
    fit.zinb <- zinb.res$fit.mu

    ## "auto" picks ZINB only when it improves the AIC by more than aicdiff
    if (model == "zinb" || aic.nb - aic.zinb > aicdiff) {
      fit <- fit.zinb
      zinf <- TRUE
      distri <- "zinb"
      aic <- aic.zinb
    } else {
      fit <- fit.nb
      distri <- "nb"
      aic <- aic.nb
    }
  }

  ## Fixed-pseudotime p-value(s) of the smooth term(s)
  s.pv <- summary(fit)$s.pv

  ## Wood (2013) test statistic of the original fit (testStat, copied from
  ## mgcv, without the subsampling part); the intercept column is excluded.
  rank <- sum(fit$edf1) - 1
  Xp <- stats::model.matrix(fit)
  k_col <- ncol(Xp)
  Tr <- gam_test_stat(fit, k_col, Xp)

  ## Matteo 2020 qgam JASA: no permutation null for quantile regression
  if (is.null(sub.tbl) || model == "qgam") {
    return(list(fix.pv = s.pv,
                emp.pv = NA,
                para.pv = NA,
                ad.pv = NA,
                rank = rank,
                test.statistic = Tr,
                gam.fit = fit,
                zinf = zinf,
                aic = aic,
                expv.quantile = expv.quantile,
                expv.mean = expv.mean,
                expv.zero = expv.zero
    ))
  }

  n.boot <- length(sub.tbl)

  ## Re-extract the gene for all cells: subsamples may contain cells that the
  ## original fit does not include
  count.v <- extract_expression(mat, gene, cells = NULL,
                                assay.use = assay.use, seurat.assay = seurat.assay)

  if (fix.weight && zinf) {
    ## Reuse the prior weights (1 - z) of the original ZINB mu-fit so each
    ## subsample needs only a weighted NB-GAM instead of a full EM fit
    cell_weights <- fit.zinb$prior.weights
    names(cell_weights) <- ori.tbl$cell
  } else {
    ## All weights are 1
    cell_weights <- base::rep(1, dim(mat)[2])
    names(cell_weights) <- colnames(mat)
  }

  ## For every subsample, permute pseudotime (breaking any gene-time
  ## association), refit the model and compute the null test statistic
  boot_tbl <- tibble::tibble(id = seq_len(n.boot), time.res = sub.tbl)

  boot_models <- boot_tbl %>%
    dplyr::mutate(splits = purrr::map(time.res, function(x) {
      as.data.frame(cbind(expv = count.v[x$cell],
                          pseudotime = base::sample(x$pseudotime),
                          cellWeights = cell_weights[x$cell]))
    })) %>%
    dplyr::mutate(model = lapply(X = splits, FUN = fit_gam, distribution = distri,
                                 use_weights = fix.weight, k = k, knots = knots,
                                 usebam = usebam, fit.formula = fit.formula),
                  stat = vapply(model, gam_test_stat, numeric(1), k_col = k_col))

  ## Failed subsample fits give NA statistics; exclude them so that a few
  ## failures do not turn both p-values into NA (or crash the gamma fit)
  null.stat <- boot_models$stat[!is.na(boot_models$stat)]
  n.valid <- length(null.stat)

  ## empirical permutation p-value
  s.pv.p <- if (n.valid > 0L) (sum(Tr <= null.stat) + 1) / (n.valid + 1) else NA

  ## parametric permutation p-value (needs at least two null statistics)
  if (n.valid > 1L) {
    smooth_pv <- suppressMessages(cal_pvalue(y = Tr, x = null.stat, plot.fit = FALSE))
  } else {
    smooth_pv <- list(p = NA, ad.pv = NA)
  }

  return(list(fix.pv = s.pv,
              emp.pv = s.pv.p,
              para.pv = smooth_pv$p,
              ad.pv = smooth_pv$ad.pv,
              rank = rank,
              test.statistics = Tr,
              gam.fit = fit,
              zinf = zinf,
              aic = aic,
              expv.quantile = expv.quantile,
              expv.mean = expv.mean,
              expv.zero = expv.zero
  ))
}

## Extract one gene's expression vector from a SingleCellExperiment, a Seurat
## object, or a matrix-like object. `cells = NULL` returns all cells.
extract_expression <- function(mat, gene, cells = NULL, assay.use = "counts", seurat.assay = "RNA") {
  if (inherits(mat, "SingleCellExperiment")) {
    expr.mat <- SummarizedExperiment::assay(mat, assay.use)
  } else if (inherits(mat, "Seurat")) {
    ## "logcounts" is mapped to Seurat's "data" slot
    slot.use <- if (identical(assay.use, "logcounts")) "data" else assay.use
    expr.mat <- Seurat::GetAssayData(object = mat, slot = slot.use, assay = seurat.assay)
  } else {
    expr.mat <- mat
  }
  if (is.null(cells)) expr.mat[gene, ] else expr.mat[gene, cells]
}

## Test statistic (via testStat) of the non-intercept coefficients 2..k_col
## of a fitted GAM. Returns NA when the fit failed (fit_gam returns FALSE).
gam_test_stat <- function(fit, k_col, Xp = stats::model.matrix(fit)) {
  if (is.logical(fit)) return(NA_real_)
  idx <- 2:k_col
  res <- testStat(p = fit$coefficients[idx],
                  X = Xp[, idx, drop = FALSE],
                  V = fit$Vp[idx, idx, drop = FALSE],
                  rank = sum(fit$edf1) - 1,
                  res.df = -1, type = 0)
  as.numeric(res$stat)
}

# added qgam model
## Fit a quantile GAM (qgam) of the formula at quantile `quant`.
## Returns the fitted qgam object, or FALSE if qgam raises an error.
## `k` is accepted for interface symmetry with fit_gam but is not used here;
## the basis dimension comes from `fit.formula`.
fit_qgam <- function(dat, quant, k, fit.formula) {
  tryCatch(
    qgam::qgam(fit.formula, data = dat, qu = quant),
    error = function(e) FALSE
  )
}

## Define the fit_gam function

## set distribution method
# zinf
## Fit one GAM of the formula on `dat` for the requested distribution:
##   "nb"       : NB-GAM with log link;
##   "gaussian" : Gaussian GAM with identity link;
##   "zinb"     : with use_weights = FALSE, the mu-part of a full ZINB-GAM
##                (zinbgam); with use_weights = TRUE, an NB-GAM weighted by
##                the `cellWeights` column of `dat`.
## usebam = TRUE switches to mgcv::bam(discrete = TRUE).
## Returns the fitted model, or FALSE if anything inside the fit errors
## (including an unknown `distribution`).
fit_gam <- function(dat, nthreads = 1, distribution, use_weights, knots, k, usebam, fit.formula) {
  ## Evaluated outside tryCatch, so an invalid `use_weights` still errors.
  ## mgcv resolves `weights = cellWeights` below through model.frame, which
  ## looks in `dat` first, so the weighted fits read the `cellWeights` column.
  cellWeights <- if (use_weights) dat$cellWeights else rep(1, dim(dat)[1])

  tryCatch(expr = {
    if (!distribution %in% c("nb", "gaussian", "zinb")) {
      stop("Specified 'distribution=' parameter is invalid. Must be one of 'nb', 'zinb', 'gaussian'.")
    }

    if (distribution == "zinb" && !use_weights) {
      ## Full EM fit; only its mu-part is returned
      zinbgam(fit.formula, ~ logmu,
              data = dat, knots = knots, k = k, usebam = usebam)$fit.mu
    } else {
      gam.family <- if (distribution == "gaussian") {
        stats::gaussian(link = "identity")
      } else {
        mgcv::nb(link = "log")
      }
      ## With fixed zero weights, the ZINB is approximated by a weighted NB
      weighted <- distribution == "zinb"

      if (usebam && weighted) {
        mgcv::bam(fit.formula, family = gam.family,
                  data = dat,
                  knots = list(pseudotime = knots), control = list(nthreads = nthreads), weights = cellWeights, discrete = TRUE)
      } else if (usebam) {
        mgcv::bam(fit.formula, family = gam.family,
                  data = dat,
                  knots = list(pseudotime = knots), control = list(nthreads = nthreads), discrete = TRUE)
      } else if (weighted) {
        mgcv::gam(fit.formula, family = gam.family,
                  data = dat,
                  knots = list(pseudotime = knots), control = list(nthreads = nthreads), weights = cellWeights)
      } else {
        mgcv::gam(fit.formula, family = gam.family,
                  data = dat,
                  knots = list(pseudotime = knots), control = list(nthreads = nthreads))
      }
    }
  },
  error = function(e) FALSE)
}




## Zero Inflated NB GAM. Refer to https://github.com/AustralianAntarcticDataCentre/zigam.
## Log ZINB density, elementwise over x:
##   x > 0 : log(1 - pi) + log NB(x; mu, shape)
##   x = 0 : log(pi + (1 - pi) * NB(0; mu, shape))
## `pi` may be a scalar or a per-observation vector; it is recycled to
## length(x) so that indexing it by the zero positions is always valid.
dzinb.log <- function(x,mu,pi,shape) {
  pi <- rep_len(pi, length(x))
  logp <- log(1 - pi) + stats::dnbinom(x, size = shape, mu = mu, log = TRUE)
  ## which() skips NA positions instead of failing the assignment
  zero <- which(x == 0)
  logp[zero] <- log(exp(logp[zero]) + pi[zero])
  logp
}

## Main function
## Fit a zero-inflated negative-binomial GAM by EM.
## Starting from an NB-GAM fit, each iteration:
##   1. computes z, the posterior probability that a zero count is a
##      structural zero, from the current pi, mu and theta;
##   2. fits `pi.formula` (response z, binomial) for the zero-inflation part;
##   3. refits `mu.formula` as an NB-GAM with prior weights 1 - z;
##   4. evaluates the ZINB log-likelihood with dzinb.log().
## Iteration stops once the log-likelihood changes by less than `tol` between
## consecutive iterations, or after `max.em` iterations.
## Arguments:
##   mu.formula : GAM formula for the NB mean (its response is the count y).
##   pi.formula : one-sided formula for the zero-inflation probability; its
##                response is replaced by z. It may use the columns mu and
##                logmu that are added to `data` each iteration.
##   mu, pi     : optional starting values (default: NB fitted values, mean(y > 0)).
##   knots      : knot positions passed to mgcv as list(pseudotime = knots).
##   usebam     : use mgcv::bam(discrete = TRUE) instead of mgcv::gam.
## Returns a list of class "zinbgam" with fit.mu, fit.pi, mu, pi, z, aic,
## logL (log-likelihood per iteration) and theta.
## NOTE(review): `min.em` is not referenced, so EM may stop after 2 iterations;
## the loop variable `k` overwrites the `k` argument (which is otherwise
## unused here); `N` is unused. Left unchanged to preserve current results.
zinbgam <- function(mu.formula,
                    pi.formula,
                    data,
                    mu=NULL,
                    pi=NULL,
                    min.em=5,
                    max.em=50,
                    tol=1.0e-4,
                    k,
                    knots,
                    usebam) {
  ## As matrix
  #data <- data.frame(data)
  ## Extract the response y
  mf <- stats::model.frame(stats::update(mu.formula,.~1),data=data)
  y <- stats::model.response(mf)
  N <- length(y)

  ## Response for pi component is the weights
  pi.formula <- stats::update(pi.formula, z ~ .)

  ## Get initial NB fit
  if(usebam) {
    fit.gam <- mgcv::bam(formula = mu.formula,
                         data = data, family = mgcv::nb(link = "log"), knots = list(pseudotime = knots), discrete = TRUE)

  } else {
    fit.gam <- mgcv::gam(formula = mu.formula,
                         data = data, family = mgcv::nb(link = "log"), knots = list(pseudotime = knots))

  }
  ## Set initial pi, mu
  ## NOTE(review): the initial pi is the fraction of NON-zero counts; confirm
  ## this is the intended starting value (it only seeds the first E-step).
  if(is.null(mu)) mu <- stats::fitted(fit.gam)
  if(is.null(pi)) pi <- mean(y>0)

  logL <- double(max.em)
  theta <- fit.gam$family$getTheta(TRUE)
  for(k in 1:max.em) {
    ## Evaluate weights for current iteration (E-step): only zeros can be
    ## structural, so non-zero counts get z = 0
    z <- ifelse(y==0,pi/(pi+ (1-pi)*stats::dnbinom(0,size=theta,mu=mu)),0)
    ## Update the data (with mu and z); mgcv looks up z (response of
    ## pi.formula, and weights = 1 - z) and logmu in `data`
    data$z <- z
    data$mu <- mu
    data$logmu <- log(mu)
    ## Update models for current iteration (M-step)
    if(usebam) {
      fit.pi <- suppressWarnings(mgcv::bam(formula = pi.formula,family=stats::binomial(),data=data))
      fit.mu <- suppressWarnings(mgcv::bam(formula = mu.formula,weights=1-z,family=mgcv::nb(link = log), data=data, knots = list(pseudotime = knots), discrete = TRUE))
    } else {
      fit.pi <- suppressWarnings(mgcv::gam(formula = pi.formula,family=stats::binomial(),data=data))
      fit.mu <- suppressWarnings(mgcv::gam(formula = mu.formula,weights=1-z,family=mgcv::nb(link = log), data=data, knots = list(pseudotime = knots)))
    }

    pi <- stats::predict(fit.pi,type="response")
    mu <- stats::predict(fit.mu,type="response")
    theta <- fit.mu$family$getTheta(TRUE)
    ## Evaluate likelihood
    logL[k] <- sum(dzinb.log(y,mu,pi,theta))
    ## Converged: truncate the log-likelihood trace and stop
    if(k >= 2) {
      if(abs(logL[k]-logL[k-1]) < tol) {
        logL <- logL[1:k]
        break
      }
    }
  }

  ## Calculate degrees of freedom and aic (df of both component models;
  ## likelihood from the last EM iteration)
  df <- attr(stats::logLik(fit.pi),"df")+attr(stats::logLik(fit.mu),"df")
  aic <- 2*(df-logL[length(logL)])
  ## Return results (z is the weight vector used in the final M-step)
  fit <- list(fit.mu=fit.mu,
              fit.pi=fit.pi,
              mu=mu,
              pi=pi,
              z=z,
              aic=aic,
              logL=logL,
              theta=theta)
  class(fit) <- "zinbgam"
  fit
}




###### mgcv function #######
#### These functions are copied from Simon Wood's mgcv package.
##  testStat
## Wood (2013) test statistic and p-value for H0: the smooth is zero.
## Arguments (as used by pseudotimeDE in this file):
##   p       : coefficients of the smooth (intercept excluded);
##   X       : the matching columns of the model matrix;
##   V       : the matching block of the Bayesian covariance matrix (Vp);
##   rank    : an EDF estimate of the smooth;
##   type    : truncation rule (see below);
##   res.df  : residual df used to estimate scale; <= 0 means fixed scale.
## Returns list(stat, pval, rank). pseudotimeDE only uses `stat`.
testStat <- function(p,X,V,rank=NULL,type=0,res.df= -1) {
  ## Implements Wood (2013) Biometrika 100(1), 221-228
  ## The type argument specifies the type of truncation to use.
  ## on entry `rank' should be an edf estimate
  ## 0. Default using the fractionally truncated pinv.
  ## 1. Round down to k if k<= rank < k+0.05, otherwise up.
  ## res.df is residual dof used to estimate scale. <=0 implies
  ## fixed scale.

  ## Transform V into the space of R (from the QR of X) and symmetrize it
  qrx <- qr(X,tol=0)
  R <- qr.R(qrx)
  V <- R%*%V[qrx$pivot,qrx$pivot,drop=FALSE]%*%t(R)
  V <- (V + t(V))/2
  ed <- eigen(V,symmetric=TRUE)
  ## remove possible ambiguity from statistic...
  siv <- sign(ed$vectors[1,]);siv[siv==0] <- 1
  ed$vectors <- sweep(ed$vectors,2,siv,"*")

  ## Split the supplied rank into an integer part k and a fractional part nu
  k <- max(0,floor(rank))
  nu <- abs(rank - k)     ## fractional part of supplied edf
  if (type==1) { ## round up is more than .05 above lower
    if (rank > k + .05||k==0) k <- k + 1
    nu <- 0;rank <- k
  }

  if (nu>0) k1 <- k+1 else k1 <- k

  ## check that actual rank is not below supplied rank+1
  r.est <- sum(ed$values > max(ed$values)*.Machine$double.eps^.9)
  if (r.est<k1) {k1 <- k <- r.est;nu <- 0;rank <- r.est}

  ## Get the eigenvectors...
  # vec <- qr.qy(qrx,rbind(ed$vectors,matrix(0,nrow(X)-ncol(X),ncol(X))))
  vec <- ed$vectors
  if (k1<ncol(vec)) vec <- vec[,1:k1,drop=FALSE]

  ## deal with the fractional part of the pinv...
  if (nu>0&&k>0) {
    if (k>1) vec[,1:(k-1)] <- t(t(vec[,1:(k-1)])/sqrt(ed$val[1:(k-1)]))
    b12 <- .5*nu*(1-nu)
    if (b12<0) b12 <- 0
    b12 <- sqrt(b12)
    B <- matrix(c(1,b12,b12,nu),2,2)
    ev <- diag(ed$values[k:k1]^-.5,nrow=k1-k+1)
    B <- ev%*%B%*%ev
    eb <- eigen(B,symmetric=TRUE)
    rB <- eb$vectors%*%diag(sqrt(eb$values))%*%t(eb$vectors)
    vec1 <- vec
    vec1[,k:k1] <- t(rB%*%diag(c(-1,1))%*%t(vec[,k:k1]))
    vec[,k:k1] <- t(rB%*%t(vec[,k:k1]))
  } else {
    vec1 <- vec <- if (k==0) t(t(vec)*sqrt(1/ed$val[1])) else
      t(t(vec)/sqrt(ed$val[1:k]))
    if (k==1) rank <- 1
  }
  ## there is an ambiguity in the choice of test statistic, leading to slight
  ## differences in the p-value computation depending on which of 2 alternatives
  ## is arbitrarily selected. Following allows both to be computed and p-values
  ## averaged (can't average test stat as dist then unknown)
  d <- t(vec)%*%(R%*%p)
  d <- sum(d^2)
  d1 <- t(vec1)%*%(R%*%p)
  d1 <- sum(d1^2)
  ##d <- d1 ## uncomment to avoid averaging

  rank1 <- rank ## rank for lower tail pval computation below

  ## note that for <1 edf then d is not weighted by EDF, and instead is
  ## simply referred to a chi-squared 1

  if (nu>0) { ## mixture of chi^2 ref dist
    if (k1==1) rank1 <- val <- 1 else {
      val <- rep(1,k1) ##ed$val[1:k1]
      rp <- nu+1
      val[k] <- (rp + sqrt(rp*(2-rp)))/2
      val[k1] <- (rp - val[k])
    }

    if (res.df <= 0) pval <- (liu2(d,val) + liu2(d1,val))/2 else ##  pval <- davies(d,val)$Qq else
      pval <- (simf(d,val,res.df) + simf(d1,val,res.df))/2
  } else { pval <- 2 }
  ## integer case still needs computing, also liu/pearson approx only good in
  ## upper tail. In lower tail, 2 moment approximation is better (Can check this
  ## by simply plotting the whole interesting range as a contour plot!)
  if (pval > .5) {
    if (res.df <= 0) pval <- (stats::pchisq(d,df=rank1,lower.tail=FALSE)+stats::pchisq(d1,df=rank1,lower.tail=FALSE))/2 else
      pval <- (stats::pf(d/rank1,rank1,res.df,lower.tail=FALSE)+stats::pf(d1/rank1,rank1,res.df,lower.tail=FALSE))/2
  }
  list(stat=d,pval=min(1,pval),rank=rank)

} ## end of testStat

## Other functions within testStat()
## Approximate tail probability of Q = sum_i lambda_i * chi^2_{h_i} at x
## (upper tail by default; lower.tail = TRUE gives Pr[Q < x]).
## Pearson (1959) / Liu, Tang & Zhang (2009, CSDA 53, 853-856) moment-matching
## approximation, as adapted in mgcv from the CompQuadForm package. It can be
## poor in the lower tail (e.g. lambda = c(1.2, .3), x = .15).
liu2 <- function(x, lambda, h = rep(1,length(lambda)),lower.tail=FALSE) {
  if (length(h) != length(lambda)) stop("lambda and h should have the same length!")

  ## h * lambda^j for j = 1..4, built by successive multiplication
  pow1 <- lambda * h
  pow2 <- pow1 * lambda
  pow3 <- pow2 * lambda
  pow4 <- pow3 * lambda

  kappa1 <- sum(pow1)
  kappa2 <- sum(pow2)
  kappa3 <- sum(pow3)

  skew <- kappa3 / kappa2^1.5
  kurt <- sum(pow4) / kappa2^2

  ## Standardize x with the mean and sd of Q
  z <- (x - kappa1) / sqrt(2 * kappa2)

  ## Match to a (possibly non-central) chi-squared with `dof` degrees of freedom
  if (!(skew^2 > kurt)) {
    a <- 1 / skew
    ncp <- 0
    dof <- kappa2^3 / kappa3^2
  } else {
    a <- 1 / (skew - sqrt(skew^2 - kurt))
    ncp <- skew * a^3 - a^2
    dof <- a^2 - 2 * ncp
  }

  stats::pchisq(z * (sqrt(2) * a) + (dof + ncp), df = dof, ncp = ncp, lower.tail = lower.tail)
}

simf <- function(x,a,df,nq=50) {
  ## Pr[T > x] for T = sum(a_i chi^2_1) / (chi^2_df / df), i.e.
  ## Pr[sum(a_i chi^2_1) > x * chi^2_df / df], by midpoint quadrature over
  ## nq quantiles of chi^2_df; each inner probability uses the liu2()
  ## Pearson/Liu approximation to the chi^2 mixture.
  ## e.g. 1-pf(4/3,3,40); simf(4,rep(1,3),40); 1-pchisq(4,3)
  chi.quantiles <- stats::qchisq((1:nq - .5) / nq, df)
  thresholds <- x * chi.quantiles / df
  sum(liu2(thresholds, a)) / nq
}


## Gamma Mixture Fit of Permutation p-value
# Gamma / Two Component Gamma Mixture Permutation p-value (parametric p-value)

### Mix gamma pdf
## Density at x of a k-component gamma mixture:
##   sum_j lambda[j] * dgamma(x, shape = gamma.pars[1, j], scale = gamma.pars[2, j])
## gamma.pars has shapes in row 1 and scales in row 2, one column per component.
## Any whole k >= 1 is supported; an invalid k returns NA.
dgammamix <- function(x, lambda, gamma.pars, k = dim(gamma.pars)[2]) {
  if (length(k) != 1L || !is.finite(k) || k < 1 || k != round(k)) return(NA)
  components <- lapply(seq_len(k), function(j) {
    lambda[j] * stats::dgamma(x, shape = gamma.pars[1, j], scale = gamma.pars[2, j])
  })
  ## Left-to-right sum, same order as the explicit 2/3-component formulas
  Reduce(`+`, components)
}

### Mix gamma cdf
## Tail probability at x of a k-component gamma mixture (upper tail by
## default; lower.tail = TRUE gives the CDF):
##   sum_j lambda[j] * pgamma(x, shape = gamma.pars[1, j], scale = gamma.pars[2, j])
## gamma.pars has shapes in row 1 and scales in row 2, one column per component.
## Any whole k >= 1 is supported; an invalid k returns NA.
pgammamix <- function(x, lambda, gamma.pars, k = dim(gamma.pars)[2], lower.tail = FALSE) {
  if (length(k) != 1L || !is.finite(k) || k < 1 || k != round(k)) return(NA)
  components <- lapply(seq_len(k), function(j) {
    lambda[j] * stats::pgamma(x, shape = gamma.pars[1, j], scale = gamma.pars[2, j], lower.tail = lower.tail)
  })
  ## Left-to-right sum, same order as the explicit 2/3-component formulas
  Reduce(`+`, components)
}

### Mix gamma rv
## Draw n values from a k-component gamma mixture (shapes in row 1 and
## scales in row 2 of gamma.pars). For k = 2 a Bernoulli(lambda[1]) indicator
## selects the component; for other k a multinomial indicator does. The RNG
## draw order for k = 2 and k = 3 matches the original explicit formulas.
## Any whole k >= 1 is supported; an invalid k returns NA.
rgammamix <- function(n = 1, lambda, gamma.pars, k = dim(gamma.pars)[2]) {
  if (length(k) != 1L || !is.finite(k) || k < 1 || k != round(k)) return(NA)

  if (k == 2) {
    z <- stats::rbinom(n, size = 1, lambda[1])
    return(z*stats::rgamma(n, shape = gamma.pars[1, 1], scale = gamma.pars[2, 1]) + (1 - z)*stats::rgamma(n, shape = gamma.pars[1, 2], scale = gamma.pars[2, 2]))
  }

  ## k x n one-hot component indicators, then k x n candidate draws
  z <- stats::rmultinom(n, size = 1, prob = lambda[seq_len(k)])
  draws <- do.call(rbind, lapply(seq_len(k), function(j) {
    stats::rgamma(n, shape = gamma.pars[1, j], scale = gamma.pars[2, j])
  }))
  colSums(z * draws)
}

### Suppress print in mixtools
## Evaluate `x` with standard output diverted to a temporary file, and return
## its value invisibly. The sink is removed and the temporary file deleted on
## exit, even if evaluating `x` raises an error. Messages and warnings
## (stderr) are not suppressed.
quiet <- function(x) {
  sink.file <- tempfile()
  sink(sink.file)
  on.exit({
    sink()
    unlink(sink.file)
  }, add = TRUE)
  invisible(force(x))
}

### Calculate p-value
## Parametric p-value of an observed statistic `y` against null statistics `x`.
## A single gamma is fitted to x. If its Anderson-Darling p-value is below 1,
## a two-component gamma mixture is also fitted (the best of a
## method-of-moments start and three random starts). The mixture replaces the
## single gamma when the likelihood-ratio test (df = 3) gives p < p.thresh.
## Returns list(p = upper-tail p-value of y, ad.pv = AD p-value of the chosen fit).
## Note: set.seed() is called inside, so the global RNG state is changed.
## `plot.fit` is currently unused.
cal_pvalue <- function(x, y, epsilon = 1e-4, p.thresh = 0.01, plot.fit = FALSE) {
  x <- as.vector(x)
  ## Failed subsample fits show up as NA; they carry no information about the
  ## null distribution, so drop non-finite values before fitting.
  x <- x[is.finite(x)]

  ## gamma fit
  fit1 <- suppressWarnings(fitdistrplus::fitdist(x, "gamma"))
  test.res <- goftest::ad.test(x, null = "pgamma", shape = fit1$estimate[1], rate = fit1$estimate[2], estimated = TRUE)

  p <- stats::pgamma(y, shape = fit1$estimate[1], rate = fit1$estimate[2], lower.tail = FALSE)

  ## gamma mix fit
  if (test.res$p.value < 1) {
    fit2 <- quiet(mixtools::gammamixEM(x, maxit = 100, k = 2, epsilon = epsilon, mom.start = TRUE, maxrestarts = 20))

    ## three random starts; keep the fit with the highest log-likelihood
    for (i in 1:3) {
      set.seed(i)
      fit.temp <- quiet(mixtools::gammamixEM(x, maxit = 100, k = 2, epsilon = epsilon, maxrestarts = 20, lambda = c(i/10, 1-i/10)))
      if (fit2$loglik < fit.temp$loglik) fit2 <- fit.temp
    }

    test.res2 <- goftest::ad.test(x, null = pgammamix, lambda = fit2$lambda, gamma.pars = fit2$gamma.pars, lower.tail = TRUE, estimated = TRUE)

    ## LRT gamma vs gamma mix (5 vs 2 free parameters)
    LRT.p <- stats::pchisq(2*(fit2$loglik - fit1$loglik), df = 3, lower.tail = FALSE)

    if (LRT.p < p.thresh) {
      p <- pgammamix(y, lambda = fit2$lambda, gamma.pars = fit2$gamma.pars)
      test.res <- test.res2
    }
  }

  list(p = p, ad.pv = test.res$p.value)
}
# SONGDONGYUAN1994/PseudotimeDE documentation built on Nov. 2, 2024, 7:55 p.m.