R/functions.R

Defines functions papply calculate.go.enrichment c.view.pathways pagoda.show.pathways t.view.pathways col2hex collapse.aspect.clusters pathway.pc.correlation.distance bh.adjust pgev.upper.log gev.t weightedMatVar weightedMatCenter quick.distribution.summary pairs.panel.smoothScatter pairs.panel.scatter pairs.panel.cor pairs.panel.hist sn glm.nb.fit FLXPmultinomW FLXMRglmCf FLXMRglmC FLXMRnb2gthC FLXMRnb2gth get.corr.theta FLXMRnb2glmC FLXMRnb2glm log.row.sums get.concomitant.prob get.exp.sample get.exp.posterior.samples get.exp.logposterior.matrix get.exp.posterior.matrix get.component.model.loglik get.component.model.lik get.rep.set.general.model.logposteriors get.rep.set.general.model.posteriors get.rep.set.posteriors mc.stepFlexmix plot.nb2.mixture.fit fit.nb2gth.mixture.model fit.nb2.mixture.model .onUnload .onAttach get.fpm.estimates calculate.failure.lfpm.p calculate.failure.p matSlideMult jpmatLogBatchBoot jpmatLogBoot get.ratio.posterior.Z.score calculate.ratio.posterior calculate.batch.joint.posterior.matrix calculate.joint.posterior.matrix sample.posterior calculate.posterior.matrices get.compressed.v1.model get.compressed.v1.models calculate.individual.models estimate.signal.prior estimate.library.sizes calculate.crossfit.models one.sided.test.id make.pagoda.app view.aspects pagoda.view.aspects pagoda.cluster.cells pagoda.reduce.redundancy pagoda.reduce.loading.redundancy pagoda.top.aspects pagoda.gene.clusters pagoda.effective.cells pagoda.pathway.wPCA pagoda.subtract.aspect pagoda.varnorm knn.error.models winsorize.matrix bwpca scde.fit.models.to.reference scde.test.gene.expression.difference scde.failure.probability scde.expression.magnitude scde.posteriors get.scde.server show.app scde.browse.diffexp scde.expression.difference scde.expression.prior scde.error.models clean.counts clean.gos

Documented in bwpca clean.counts clean.gos knn.error.models make.pagoda.app pagoda.cluster.cells pagoda.effective.cells pagoda.gene.clusters pagoda.pathway.wPCA pagoda.reduce.loading.redundancy pagoda.reduce.redundancy pagoda.show.pathways pagoda.subtract.aspect pagoda.top.aspects pagoda.varnorm pagoda.view.aspects papply scde.browse.diffexp scde.error.models scde.expression.difference scde.expression.magnitude scde.expression.prior scde.failure.probability scde.fit.models.to.reference scde.posteriors scde.test.gene.expression.difference show.app view.aspects winsorize.matrix

##' Single-cell Differential Expression (with Pathway And Gene set Overdispersion Analysis)
##'
##' The scde package implements a set of statistical methods for analyzing single-cell RNA-seq data.
##' scde fits individual error models for single-cell RNA-seq measurements. These models can then be used for
##' assessment of differential expression between groups of cells, as well as other types of analysis.
##' The scde package also contains the pagoda framework, which applies pathway and gene set overdispersion analysis
##' to identify and characterize putative cell subpopulations based on transcriptional signatures.
##' See vignette("diffexp") for a brief tutorial on differential expression analysis.
##' See vignette("pagoda") for a brief tutorial on pathway and gene set overdispersion analysis to identify and characterize cell subpopulations.
##' More extensive tutorials are available at \url{http://pklab.med.harvard.edu/scde/index.html}.
##' @name scde
##' @docType package
##' @author Peter Kharchenko \email{Peter_Kharchenko@@hms.harvard.edu}
##' @author Jean Fan \email{jeanfan@@fas.harvard.edu}
NULL

################################# Sample data

##' Sample data
##'
##' A subset of the Saiful et al. 2011 dataset containing the first 20 ES and 20 MEF cells.
##'
##' @name es.mef.small
##' @docType data
##' @references \url{http://www.ncbi.nlm.nih.gov/pubmed/21543516}
##' @export
NULL

##' Sample data
##'
##' Single-cell data from the Pollen et al. 2014 dataset.
##'
##' @name pollen
##' @docType data
##' @references \url{http://www.ncbi.nlm.nih.gov/pubmed/25086649}
##' @export
NULL

##' Sample error model
##'
##' SCDE error model generated (with \code{\link{scde.error.models}}) from a subset of the Saiful et al. 2011 dataset containing the first 20 ES and 20 MEF cells.
##'
##' @name o.ifm
##' @docType data
##' @references \url{http://www.ncbi.nlm.nih.gov/pubmed/21543516}
##' @export
NULL

##' Sample error model
##'
##' SCDE error model generated from the Pollen et al. 2014 dataset.
##'
##' @name knn
##' @docType data
##' @references \url{http://www.ncbi.nlm.nih.gov/pubmed/25086649}
##' @export
NULL

# Internal model data
#
# Numerically-derived correction for NB->chi squared approximation stored as a local regression model
#
# @name scde.edff


################################# Generic methods

##' Filter GOs list
##'
##' Filter GOs list and append GO names when appropriate.
##' A gene set is kept only if its size is strictly greater than \code{min.size}
##' and strictly less than \code{max.size}.
##'
##' @param go.env GO or gene set list (or environment, which is converted with \code{as.list})
##' @param min.size Minimum size for number of genes in a gene set (default: 5). Gene sets of exactly this size are removed.
##' @param max.size Maximum size for number of genes in a gene set (default: 5000). Gene sets of exactly this size are removed.
##' @param annot Whether to append GO annotations for easier interpretation (default: FALSE). Requires the GO.db (and AnnotationDbi) packages; skipped if they are not installed.
##'
##' @return a filtered GO list
##'
##' @examples
##' \donttest{
##' # 10 sample GOs
##' library(org.Hs.eg.db)
##' go.env <- mget(ls(org.Hs.egGO2ALLEGS)[1:10], org.Hs.egGO2ALLEGS)
##' # Filter this list and append names for easier interpretation
##' go.env <- clean.gos(go.env)
##' }
##'
##' @export
clean.gos <- function(go.env, min.size = 5, max.size = 5000, annot = FALSE) {
  go.env <- as.list(go.env)
  size <- lengths(go.env)
  go.env <- go.env[size > min.size & size < max.size]
  # If GO.db is available, append the term name to each GO identifier.
  # The query must go through AnnotationDbi::select: MASS::select is an
  # unrelated generic and cannot query an annotation database. GO.db is
  # referenced via its namespace so it does not need to be attached.
  # requireNamespace() is much cheaper than scanning installed.packages().
  if (annot && length(go.env) > 0 &&
      requireNamespace("GO.db", quietly = TRUE) &&
      requireNamespace("AnnotationDbi", quietly = TRUE)) {
    desc <- AnnotationDbi::select(
      GO.db::GO.db,
      keys = names(go.env),
      columns = c("TERM"),
      multiVals = 'CharacterList'
    )
    # the term table must line up one-to-one with the retained gene sets
    stopifnot(all(names(go.env) == desc$GOID))
    names(go.env) <- paste(names(go.env), desc$TERM)
  }
  return(go.env)
}


##' Filter counts matrix
##'
##' Filter counts matrix based on gene and cell requirements.
##' All thresholds are strict: a cell must have more than \code{min.lib.size} detected genes,
##' and a gene must have more than \code{min.reads} reads and be detected in more than \code{min.detected} cells.
##'
##' @param counts read count matrix. The rows correspond to genes, columns correspond to individual cells
##' @param min.lib.size Minimum number of genes detected in a cell. Cells with fewer genes will be removed (default: 1.8e3)
##' @param min.reads Minimum number of reads per gene. Genes with fewer reads will be removed (default: 10)
##' @param min.detected Minimum number of cells a gene must be seen in. Genes not seen in a sufficient number of cells will be removed (default: 5)
##'
##' @return a filtered read count matrix (always two-dimensional, even if a single gene or cell remains)
##'
##' @examples
##' data(pollen)
##' dim(pollen)
##' cd <- clean.counts(pollen)
##' dim(cd)
##'
##' @export
clean.counts <- function(counts, min.lib.size = 1.8e3, min.reads = 10, min.detected = 5) {
    # drop = FALSE keeps the result two-dimensional when only one cell or gene
    # survives a filter; otherwise the following rowSums() calls would fail.
    # filter out low-gene cells
    counts <- counts[, colSums(counts > 0) > min.lib.size, drop = FALSE]
    # remove genes that don't have many reads
    counts <- counts[rowSums(counts) > min.reads, , drop = FALSE]
    # remove genes that are not seen in a sufficient number of cells
    counts <- counts[rowSums(counts > 0) > min.detected, , drop = FALSE]
    return(counts)
}

################################# SCDE Methods

##' Fit single-cell error/regression models
##'
##' Fit error models given a set of single-cell data (counts) and an optional grouping factor (groups). The cells (within each group) are first cross-compared to determine a subset of genes showing consistent expression. The set of genes is then used to fit a mixture model (Poisson-NB mixture, with expression-dependent concomitant).
##'
##' Note: the default implementation has been changed to use linear-scale fit with expression-dependent NB size (overdispersion) fit. This represents an iterative improvement on the originally published model. Use linear.fit=F to revert back to the original fitting procedure.
##'
##' @param counts read count matrix. The rows correspond to genes (should be named), columns correspond to individual cells. The matrix should contain integer counts
##' @param groups an optional factor describing grouping of different cells. If provided, the cross-fits and the expected expression magnitudes will be determined separately within each group. The factor must have the same length as ncol(counts).
##' @param min.nonfailed minimal number of non-failed observations required for a gene to be used in the final model fitting (default: 3)
##' @param threshold.segmentation use a fast threshold-based segmentation during cross-fit (default: TRUE)
##' @param min.count.threshold the number of reads to use to guess which genes may have "failed" to be detected in a given measurement during cross-cell comparison (default: 4)
##' @param zero.count.threshold threshold to guess the initial value (failed/non-failed) during error model fitting procedure (defaults to the min.count.threshold value)
##' @param zero.lambda the rate of the Poisson (failure) component (default: 0.1)
##' @param save.crossfit.plots whether png files showing cross-fit segmentations should be written out (default: FALSE)
##' @param save.model.plots whether pdf files showing model fits should be written out (default: TRUE)
##' @param n.cores number of cores to use (default: 12)
##' @param min.size.entries minimum number of genes to use when determining expected expression magnitude during model fitting
##' @param max.pairs maximum number of cross-fit comparisons that should be performed per group (default: 5000)
##' @param min.pairs.per.cell minimum number of pairs that each cell should be cross-compared with
##' @param verbose 1 for increased output
##' @param linear.fit Boolean of whether to use a linear fit in the regression (default: TRUE).
##' @param local.theta.fit Boolean of whether to fit the overdispersion parameter theta, i.e. the negative binomial size parameter, based on local regression (default: set to be equal to the linear.fit parameter)
##' @param theta.fit.range Range of valid values for the overdispersion parameter theta, i.e. the negative binomial size parameter (default: c(1e-2, 1e2))
##'
##' @return a model matrix, with rows corresponding to different cells, and columns representing different parameters of the determined models
##'
##' @useDynLib scde
##'
##' @examples
##' data(es.mef.small)
##' cd <- clean.counts(es.mef.small, min.lib.size=1000, min.reads = 1, min.detected = 1)
##' sg <- factor(gsub("(MEF|ESC).*", "\\1", colnames(cd)), levels = c("ESC", "MEF"))
##' names(sg) <- colnames(cd)
##' \donttest{
##' o.ifm <- scde.error.models(counts = cd, groups = sg, n.cores = 10, threshold.segmentation = TRUE)
##' }
##'
##' @export
scde.error.models <- function(counts, groups = NULL, min.nonfailed = 3, threshold.segmentation = TRUE, min.count.threshold = 4, zero.count.threshold = min.count.threshold, zero.lambda = 0.1, save.crossfit.plots = FALSE, save.model.plots = TRUE, n.cores = 12, min.size.entries = 2e3, max.pairs = 5000, min.pairs.per.cell = 10, verbose = 0, linear.fit = TRUE, local.theta.fit = linear.fit, theta.fit.range = c(1e-2, 1e2)) {
    # default same group
    if(is.null(groups)) {
        groups <- as.factor(rep("cell", ncol(counts)))
    }
    # check for integer counts.
    # Every entry of an atomic matrix shares one storage mode, so a single
    # is.integer() call suffices; lists/data frames are checked per column.
    # (lapply() over a matrix would build a list with one element per count.)
    # An empty atomic input passes, as before.
    non.integer <- if(is.list(counts)) {
        !all(vapply(counts, is.integer, logical(1)))
    } else {
        length(counts) > 0 && !is.integer(counts)
    }
    if(non.integer) {
      stop("Some of the supplied counts are not integer values (or stored as non-integer types). Aborting!\nThe method is designed to work on read counts - do not pass normalized read counts (e.g. FPKM values). If matrix contains read counts, but they are stored as numeric values, use counts<-apply(counts,2,function(x) {storage.mode(x) <- 'integer'; x}) to recast.");
    }
    # groups must assign every cell (column) to exactly one group
    if(length(groups) != ncol(counts)) {
        stop("The groups factor must have the same length as ncol(counts) (", length(groups), " vs. ", ncol(counts), ")")
    }

    # crossfit
    if(verbose) {
        cat("cross-fitting cells.\n")
    }
    cfm <- calculate.crossfit.models(counts, groups, n.cores = n.cores, threshold.segmentation = threshold.segmentation, min.count.threshold = min.count.threshold, zero.lambda = zero.lambda, max.pairs = max.pairs, save.plots = save.crossfit.plots, min.pairs.per.cell = min.pairs.per.cell, verbose = verbose)
    # error model for each cell
    if(verbose) {
        cat("building individual error models.\n")
    }
    ifm <- calculate.individual.models(counts, groups, cfm, min.nonfailed = min.nonfailed, zero.count.threshold = zero.count.threshold, n.cores = n.cores, save.plots = save.model.plots, linear.fit = linear.fit, return.compressed.models = TRUE, verbose = verbose, min.size.entries = min.size.entries, local.theta.fit = local.theta.fit, theta.fit.range = theta.fit.range)
    # the cross-fit structures are only needed to build the models; release them
    rm(cfm)
    gc()
    return(ifm)
}


##' Estimate prior distribution for gene expression magnitudes
##'
##' Use existing count data to determine a prior distribution of genes in the dataset
##'
##' @param models models determined by \code{\link{scde.error.models}}
##' @param counts count matrix
##' @param length.out number of points (resolution) of the expression magnitude grid (default: 400). Note: larger numbers will linearly increase memory/CPU demands.
##' @param show.plot whether to plot the estimated prior (default: FALSE)
##' @param pseudo.count pseudo-count value to use (default 1)
##' @param bw smoothing bandwidth to use in estimating the prior (default: 0.1)
##' @param max.quantile determine the maximum expression magnitude based on a quantile (default : 0.999)
##' @param max.value alternatively, specify the exact maximum expression magnitude value
##'
##' @return a data frame describing the expression magnitude grid ($x, on log10 scale), the prior ($y),
##' the log prior ($lp) and the linear-scale width of each grid cell ($grid.weight, used for normalization)
##'
##' @examples
##' data(es.mef.small)
##' cd <- clean.counts(es.mef.small, min.lib.size=1000, min.reads = 1, min.detected = 1)
##' data(o.ifm)  # Load precomputed model. Use ?scde.error.models to see how o.ifm was generated
##' o.prior <- scde.expression.prior(models = o.ifm, counts = cd, length.out = 400, show.plot = FALSE)
##'
##' @export
scde.expression.prior <- function(models, counts, length.out = 400, show.plot = FALSE, pseudo.count = 1, bw = 0.1, max.quantile = 1-1e-3, max.value = NULL) {
    fpkm <- scde.expression.magnitude(models, counts)
    fail <- scde.failure.probability(models, counts = counts)
    # move magnitudes onto the log10(signal+1) scale used by the grid
    fpkm <- log10(exp(as.matrix(fpkm))+1)
    # weight each observation by 1 - failure probability, normalized to sum to 1
    wts <- as.numeric(as.matrix(1-fail[, colnames(fpkm)]))
    wts <- wts/sum(wts)

    if(is.null(max.value)) {
        x <- as.numeric(fpkm)
        max.value <- as.numeric(quantile(x[x<Inf], p = max.quantile))
    }
    # fit density on a mirror image: the data are reflected around 0
    # (presumably to avoid kernel boundary bias at 0), and only the
    # non-negative half of the symmetric grid is kept below
    md <- density(c(-1*as.numeric(fpkm), as.numeric(fpkm)), bw = bw, weights = c(wts/2, wts/2), n = 2*length.out+1, from = -1*max.value, to = max.value)

    # drop the first length.out (negative) grid points, keeping length.out+1 points in [0, max.value]
    gep <- data.frame(x = md$x[-seq_len(length.out)], y = md$y[-seq_len(length.out)])
    gep$y[is.na(gep$y)] <- 0
    gep$y <- gep$y+pseudo.count/nrow(fpkm) # pseudo-count
    gep$y <- gep$y/sum(gep$y)
    if(show.plot) {
        par(mfrow = c(1, 1), mar = c(3.5, 3.5, 3.5, 0.5), mgp = c(2.0, 0.65, 0), cex = 0.9)
        plot(gep$x, gep$y, col = 4, panel.first = abline(h = 0, lty = 2), type = 'l', xlab = "log10( signal+1 )", ylab = "probability density", main = "signal prior")
    }
    gep$lp <- log(gep$y)

    # grid weighting (for normalization): linear-scale width of each grid cell,
    # bounded by the midpoints between neighbouring grid points (clipped at the ends)
    gep$grid.weight <- diff(10^c(gep$x[1], gep$x+c(diff(gep$x)/2, 0))-1)

    return(gep)
}


##' Test for expression differences between two sets of cells
##'
##' Use the individual cell error models to test for differential expression between two groups of cells.
##'
##' @param models models determined by \code{\link{scde.error.models}}
##' @param counts read count matrix
##' @param prior gene expression prior as determined by \code{\link{scde.expression.prior}}
##' @param groups a factor determining the two groups of cells being compared (a character vector is coerced to a factor). The factor entries should correspond to the rows of the model matrix. The factor should have two levels. NAs are allowed (cells will be omitted from comparison). If NULL, the \code{groups} attribute of \code{models} is used.
##' @param batch a factor (corresponding to rows of the model matrix) specifying batch assignment of each cell, to perform batch correction (a character vector is coerced to a factor)
##' @param n.randomizations number of bootstrap randomizations to be performed
##' @param n.cores number of cores to utilize
##' @param batch.models (optional) separate models for the batch data (if generated using batch-specific group argument). Normally the same models are used.
##' @param return.posteriors whether joint posterior matrices should be returned
##' @param verbose integer verbose level (1 for verbose)
##'
##' @return \subsection{default}{
##' a data frame with the following fields:
##' \itemize{
##' \item{lb, mle, ub} {lower bound, maximum likelihood estimate, and upper bound of the 95% confidence interval for the expression fold change on log2 scale.}
##' \item{ce} { conservative estimate of expression-fold change (equals to the min(abs(c(lb, ub))), or 0 if the CI crosses the 0}
##' \item{Z} { uncorrected Z-score of expression difference}
##' \item{cZ} {expression difference Z-score corrected for multiple hypothesis testing using Holm procedure}
##' }
##'  If batch correction has been performed (\code{batch} has been supplied), analogous data frames are returned in slots \code{$batch.adjusted} for batch-corrected results, and \code{$batch.effect} for the differences explained by batch effects alone.
##' }}
##' \subsection{return.posteriors = TRUE}{
##' A list is returned, with the default results data frame given in the \code{$results} slot.
##' \code{difference.posterior} returns a matrix of estimated expression difference posteriors (rows - genes, columns correspond to different magnitudes of fold-change - log2 values are given in the column names)
##' \code{joint.posteriors} a list of two joint posterior matrices (rows - genes, columns correspond to the expression levels, given by prior$x grid)
##' }
##'
##' @examples
##' data(es.mef.small)
##' cd <- clean.counts(es.mef.small, min.lib.size=1000, min.reads = 1, min.detected = 1)
##' sg <- factor(gsub("(MEF|ESC).*", "\\1", colnames(cd)), levels = c("ESC", "MEF"))
##' names(sg) <- colnames(cd)
##' \donttest{
##' o.ifm <- scde.error.models(counts = cd, groups = sg, n.cores = 10, threshold.segmentation = TRUE)
##' o.prior <- scde.expression.prior(models = o.ifm, counts = cd, length.out = 400, show.plot = FALSE)
##' # make sure groups corresponds to the models (o.ifm)
##' groups <- factor(gsub("(MEF|ESC).*", "\\1", rownames(o.ifm)), levels = c("ESC", "MEF"))
##' names(groups) <- row.names(o.ifm)
##' ediff <- scde.expression.difference(o.ifm, cd, o.prior, groups = groups, n.randomizations = 100, n.cores = 10, verbose = 1)
##' }
##'
##' @export
scde.expression.difference <- function(models, counts, prior, groups = NULL, batch = NULL, n.randomizations = 150, n.cores = 10, batch.models = models, return.posteriors = FALSE, verbose = 0) {
    if(!all(rownames(models) %in% colnames(counts))) {
        stop("ERROR: provided count data does not cover all of the cells specified in the model matrix")
    }

    # reorder count columns to match the rows (cells) of the model matrix
    ci <- match(rownames(models), colnames(counts))
    counts <- as.matrix(counts[, ci, drop = FALSE])

    if(is.null(groups)) { # recover groups from models
        groups <- attr(models, "groups")
        # test the raw attribute: as.factor(NULL) is factor(0), never NULL
        if(is.null(groups)) stop("ERROR: groups factor is not provided, and models structure is lacking groups attribute")
        groups <- as.factor(groups)
        names(groups) <- rownames(models)
    }
    # accept character vectors as well as factors (no-op for factors)
    groups <- as.factor(groups)
    if(length(levels(groups)) != 2) {
        stop(paste("ERROR: wrong number of levels in the grouping factor (", paste(levels(groups), collapse = " "), "), but must be two.", sep = ""))
    }

    correct.batch <- FALSE
    if(!is.null(batch)) {
        # coerce before inspecting levels: a character vector has no levels and
        # would otherwise be mistaken for a single-batch design
        batch <- as.factor(batch)
        if(length(levels(batch)) > 1) {
            correct.batch <- TRUE
        } else {
            if(verbose) {
                cat("WARNING: only one batch level detected. Nothing to correct for.\n")
            }
        }
    }

    # batch control
    if(correct.batch) {
        # check batch-group interactions
        bgti <- table(groups, batch)
        bgti.ft <- fisher.test(bgti)
        if(verbose) {
            cat("controlling for batch effects. interaction:\n")
            print(bgti)
        }
        if(bgti.ft$p.value < 1e-3) {
            cat("WARNING: strong interaction between groups and batches! Correction may be ineffective:\n")
            print(bgti.ft)
        }

        # calculate batch posterior: for each group, sample all cells with that
        # group's batch composition
        if(verbose) {
            cat("calculating batch posteriors\n")
        }
        batch.jpl <- tapply(seq_len(nrow(models)), groups, function(ii) {
            scde.posteriors(models = batch.models, counts = counts, prior = prior, batch = batch, composition = table(batch[ii]), n.cores = n.cores, n.randomizations = n.randomizations, return.individual.posteriors = FALSE)
        })
        if(verbose) {
            cat("calculating batch differences\n")
        }
        batch.bdiffp <- calculate.ratio.posterior(batch.jpl[[1]], batch.jpl[[2]], prior, n.cores = n.cores)
        batch.bdiffp.rep <- quick.distribution.summary(batch.bdiffp)
    } else {
        if(verbose) {
            cat("comparing groups:\n")
            print(table(as.character(groups)))
        }
    }


    # fit joint posteriors for each group (cells with NA group are dropped by tapply)
    jpl <- tapply(seq_len(nrow(models)), groups, function(ii) {
        scde.posteriors(models = models[ii, , drop = FALSE], counts = counts[, ii, drop = FALSE], prior = prior, n.cores = n.cores, n.randomizations = n.randomizations)
    })
    if(verbose) {
        cat("calculating difference posterior\n")
    }
    # calculate difference posterior
    bdiffp <- calculate.ratio.posterior(jpl[[1]], jpl[[2]], prior, n.cores = n.cores)

    if(verbose) {
        cat("summarizing differences\n")
    }
    bdiffp.rep <- quick.distribution.summary(bdiffp)

    if(correct.batch) {
        if(verbose) {
            cat("adjusting for batch effects\n")
        }
        # adjust for batch effects (flat prior over the fold-change grid)
        a.bdiffp <- calculate.ratio.posterior(bdiffp, batch.bdiffp, prior = data.frame(x = as.numeric(colnames(bdiffp)), y = rep(1/ncol(bdiffp), ncol(bdiffp))), skip.prior.adjustment = TRUE, n.cores = n.cores)
        a.bdiffp.rep <- quick.distribution.summary(a.bdiffp)

        # return with batch correction info
        if(return.posteriors) {
            return(list(batch.adjusted = a.bdiffp.rep, results = bdiffp.rep, batch.effect = batch.bdiffp.rep, difference.posterior = bdiffp, batch.adjusted.difference.posterior = a.bdiffp, joint.posteriors = jpl))
        } else {
            return(list(batch.adjusted = a.bdiffp.rep, results = bdiffp.rep, batch.effect = batch.bdiffp.rep))
        }
    } else {
        # no batch correction return
        if(return.posteriors) {
            return(list(results = bdiffp.rep, difference.posterior = bdiffp, joint.posteriors = jpl))
        } else {
            return(bdiffp.rep)
        }
    }
}


##' View differential expression results in a browser
##'
##' Launches a browser app that shows the differential expression results, allowing the user to sort, filter, etc.
##' The arguments generally correspond to the \code{scde.expression.difference()} call, except that the results of that call are also passed here. Requires \code{Rook} and \code{rjson} packages to be installed.
##'
##' @param results result object returned by \code{scde.expression.difference()}. Note: to browse group posterior levels, use \code{return.posteriors = TRUE} in the \code{scde.expression.difference()} call.
##' @param models model matrix
##' @param counts count matrix
##' @param prior prior
##' @param groups group information
##' @param batch batch information
##' @param geneLookupURL The URL that will be used to construct links to view more information on gene names. By default (if can't guess the organism) the links will forward to ENSEMBL site search, using \code{geneLookupURL = "http://useast.ensembl.org/Multi/Search/Results?q={0}"}. The "{0}" in the end will be substituted with the gene name. For instance, to link to GeneCards, use \code{"http://www.genecards.org/cgi-bin/carddisp.pl?gene={0}"}.
##' @param server optional previously returned instance of the server, if want to reuse it.
##' @param name app name (needs to be altered only if adding more than one app to the server using \code{server} parameter)
##' @param port Interactive browser port
##'
##' @return server instance, on which $stop() function can be called to kill the process.
##'
##' @examples
##' data(es.mef.small)
##' cd <- clean.counts(es.mef.small, min.lib.size=1000, min.reads = 1, min.detected = 1)
##' sg <- factor(gsub("(MEF|ESC).*", "\\1", colnames(cd)), levels = c("ESC", "MEF"))
##' names(sg) <- colnames(cd)
##' \donttest{
##' o.ifm <- scde.error.models(counts = cd, groups = sg, n.cores = 10, threshold.segmentation = TRUE)
##' o.prior <- scde.expression.prior(models = o.ifm, counts = cd, length.out = 400, show.plot = FALSE)
##' # make sure groups corresponds to the models (o.ifm)
##' groups <- factor(gsub("(MEF|ESC).*", "\\1", rownames(o.ifm)), levels = c("ESC", "MEF"))
##' names(groups) <- row.names(o.ifm)
##' ediff <- scde.expression.difference(o.ifm, cd, o.prior, groups = groups, n.randomizations = 100, n.cores = 10, verbose = 1)
##' scde.browse.diffexp(ediff, o.ifm, cd, o.prior, groups = groups, geneLookupURL="http://www.informatics.jax.org/searchtool/Search.do?query={0}")  # creates browser
##' }
##'
##' @export
scde.browse.diffexp <- function(results, models, counts, prior, groups = NULL, batch = NULL, geneLookupURL = NULL, server = NULL, name = "scde", port = NULL) {
    # reuse the caller's server when given; otherwise obtain the shared one
    srv <- if(is.null(server)) get.scde.server(port) else server
    # build the viewer app and mount it under `name`
    viewer <- ViewDiff$new(results, models, counts, prior, groups = groups, batch = batch, geneLookupURL = geneLookupURL)
    srv$add(app = viewer, name = name)
    app.url <- paste(srv$full_url(name), "index.html", sep = "/")
    browseURL(app.url)
    return(srv)
}


##' View PAGODA application
##'
##' Installs a given pagoda app (or any other rook app) into a server, optionally
##' making a call to show it in the browser.
##'
##' @param app pagoda app (output of make.pagoda.app()) or another rook app
##' @param name URL path name for this app
##' @param browse whether a call should be made for browser to show the app
##' @param port optional port on which the server should be initiated
##' @param ip IP on which the server should listen (typically localhost)
##' @param server an (optional) Rook server instance (defaults to ___scde.server)
##'
##' @examples
##' \donttest{
##' app <- make.pagoda.app(tamr2, tam, varinfo, go.env, pwpca, clpca, col.cols=col.cols, cell.clustering=hc, title="NPCs")
##' # show app in the browser (port 1468)
##' show.app(app, "pollen", browse = TRUE, port=1468)
##' }
##'
##' @return Rook server instance
##'
##' @export
show.app <- function(app, name, browse = TRUE, port = NULL, ip = '127.0.0.1', server = NULL) {
    # replace special characters so the name can be used as a URL path component
    name <- gsub("[^[:alnum:]]", "_", name)

    # Warn if R's internal HTTP server is already running on a different port
    # than the one requested. The comparison is only meaningful for an explicit
    # port: with port = NULL, `httpd.port != port` is logical(0), which made the
    # original condition error whenever the internal server was running.
    httpd.port <- tools:::httpdPort()
    if (!is.null(port) && httpd.port != 0 && httpd.port != port) {
        cat("ERROR: port is already being used. The PAGODA app is currently incompatible with RStudio. Please try running the interactive app in the R console.\n")
    }
    # pass ip through as well (it was previously ignored)
    if(is.null(server)) { server <- get.scde.server(port = port, ip = ip) }
    server$add(app = app, name = name)
    if(browse) {
        browseURL(paste(server$full_url(name), "index.html", sep = "/"))
    }
    return(server)
}
# get SCDE server from saved session
# Returns the Rook server cached in the global environment under "___scde.server".
# If none is cached, a new Rook Rhttpd server is created, started on ip/port and
# cached for reuse by later calls.
get.scde.server <- function(port = NULL, ip = '127.0.0.1') {
    if(exists("___scde.server", envir = globalenv())) {
        return(get("___scde.server", envir = globalenv()))
    }
    # require() attaches Rook (kept as in the original, since other app code may
    # rely on it being on the search path), but it only returns FALSE when Rook is
    # missing, so check the result to give a clear error message.
    if(!require(Rook)) {
        stop("The Rook package is required to run the interactive browser. Please install it with install.packages('Rook').")
    }
    server <- Rhttpd$new()
    server$start(listen = ip, port = port)
    # cache only after a successful start, so a failed start is not reused later
    assign("___scde.server", server, envir = globalenv())
    return(server)
}


# calculate individual and joint posterior information
# models - all or a subset of models belonging to a particular group
#
##' Calculate joint expression magnitude posteriors across a set of cells
##'
##' Calculates expression magnitude posteriors for the individual cells, and then uses bootstrap resampling to calculate a joint expression posterior for all the specified cells. Alternatively during batch-effect correction procedure, the joint posterior can be calculated for a random composition of cells of different groups (see \code{batch} and \code{composition} parameters).
##'
##' @param models models models determined by \code{\link{scde.error.models}}
##' @param counts read count matrix
##' @param prior gene expression prior as determined by \code{\link{scde.expression.prior}}
##' @param n.randomizations number of bootstrap iterations to perform
##' @param batch a factor describing which batch group each cell (i.e. each row of \code{models} matrix) belongs to
##' @param composition a vector describing the batch composition of a group to be sampled
##' @param return.individual.posteriors whether expression posteriors of each cell should be returned
##' @param return.individual.posterior.modes whether modes of expression posteriors of each cell should be returned
##' @param ensemble.posterior Boolean of whether to calculate the ensemble posterior (sum of individual posteriors) instead of a joint (product) posterior. (default: FALSE)
##' @param n.cores number of cores to utilize
##'
##' @return \subsection{default}{ a posterior probability matrix, with rows corresponding to genes, and columns to expression levels (as defined by \code{prior$x})
##' }
##' \subsection{return.individual.posterior.modes}{ a list is returned, with the \code{$jp} slot giving the joint posterior matrix, as described above. The \code{$modes} slot gives a matrix of individual expression posterior mode values on log scale (rows - genes, columns -cells)}
##' \subsection{return.individual.posteriors}{ a list is returned, with the \code{$post} slot giving a list of individual posterior matrices, in a form analogous to the joint posterior matrix, but reported on log scale }
##'
##' @examples
##' data(es.mef.small)
##' cd <- clean.counts(es.mef.small, min.lib.size=1000, min.reads = 1, min.detected = 1)
##' data(o.ifm)  # Load precomputed model. Use ?scde.error.models to see how o.ifm was generated
##' o.prior <- scde.expression.prior(models = o.ifm, counts = cd, length.out = 400, show.plot = FALSE)
##' # calculate joint posteriors
##' jp <- scde.posteriors(o.ifm, cd, o.prior, n.cores = 1)
##'
##' @export
scde.posteriors <- function(models, counts, prior, n.randomizations = 100, batch = NULL, composition = NULL, return.individual.posteriors = FALSE, return.individual.posterior.modes = FALSE, ensemble.posterior = FALSE, n.cores = 20) {
    # Computes joint expression-magnitude posteriors for every gene (row of counts)
    # across the cells listed in rownames(models). The numerical work is done by the
    # compiled routines "logBootPosterior" (or "logBootBatchPosterior" when batch is
    # given). When n.cores > 1, genes are split into contiguous chunks and processed
    # via papply().
    # Return value depends on the requested outputs (postflag):
    #   0 - matrix (genes x prior grid values)
    #   1 - list(jp, modes)          modes: genes x cells
    #   2 - list(jp, post)           post: per-cell list of genes x grid matrices
    #   3 - list(jp, modes, post)
    if(!all(rownames(models) %in% colnames(counts))) { stop("ERROR: provided count data does not cover all of the cells specified in the model matrix") }
    if(!is.null(batch)) { # calculating batch-sampled posteriors instead of evenly sampled ones
        if(is.null(composition)) { stop("ERROR: group composition must be provided if the batch argument is passed") }
        # zero-based model-row indices grouped by batch, as passed to the compiled code
        batchil <- tapply(seq_len(nrow(models)) - 1, batch, I)
    }
    # order counts according to the cells
    ci <- match(rownames(models), colnames(counts))
    counts <- as.matrix(counts[, ci, drop = FALSE])
    # convert the prior grid (prior$x, inverted here via 10^x - 1) to natural-log magnitudes
    marginals <- 10^prior$x - 1
    marginals[marginals<0] <- 0
    marginals <- log(marginals)
    # column labels used for all posterior matrices
    grid.labels <- as.character(exp(marginals))

    min.slope <- 1e-10
    if(any(models$corr.a<min.slope)) {
        cat("WARNING: the following cells have negatively-correlated or 0-slope fits: ", paste(rownames(models)[models$corr.a<min.slope], collapse = " "), ". Setting slopes to 1e-10.\n")
        models$corr.a[models$corr.a<min.slope] <- min.slope
    }

    # encode the requested outputs for the compiled code
    postflag <- 0
    if(return.individual.posteriors) {
        postflag <- 2
        if(return.individual.posterior.modes) {
            postflag <- 3
        }
    } else if(return.individual.posterior.modes) {
        postflag <- 1
    }

    ensembleflag <- ifelse(ensemble.posterior, 1, 0)

    localthetaflag <- "corr.ltheta.b" %in% colnames(models)
    squarelogitconc <- "conc.a2" %in% colnames(models)

    # prepare matrix models: fixed column layout expected by the compiled code,
    # with NA for parameters absent from this model set
    mn <- c("conc.b", "conc.a", "fail.r", "corr.b", "corr.a", "corr.theta", "corr.ltheta.b", "corr.ltheta.t", "corr.ltheta.m", "corr.ltheta.s", "corr.ltheta.r", "conc.a2")
    mc <- match(c(mn), colnames(models))
    mm <- matrix(NA, nrow(models), length(mn))
    mm[, which(!is.na(mc))] <- as.matrix(models[, mc[!is.na(mc)], drop = FALSE])

    # split the index vector x into n contiguous, roughly equal chunks
    chunk <- function(x, n) split(x, sort(rank(x) %% n))
    if(n.cores > 1 && nrow(counts) > n.cores) { # split by genes
        xl <- papply(chunk(seq_len(nrow(counts)), n.cores), function(ii) {
            # unique count lists with matching (zero-based) indices for this gene chunk
            ucl <- lapply(seq_len(ncol(counts)), function(i) as.vector(unique(counts[ii, i, drop = FALSE])))
            uci <- do.call(cbind, lapply(seq_len(ncol(counts)), function(i) match(counts[ii, i, drop = FALSE], ucl[[i]])-1))
            if(!is.null(batch)) {
                x <- .Call("logBootBatchPosterior", mm, ucl, uci, marginals, batchil, composition, n.randomizations, ii[1], postflag, localthetaflag, squarelogitconc, PACKAGE = "scde")
            } else {
                x <- .Call("logBootPosterior", mm, ucl, uci, marginals, n.randomizations, ii[1], postflag, localthetaflag, squarelogitconc, ensembleflag, PACKAGE = "scde")
            }
        }, n.cores = n.cores)
        # reassemble the per-chunk results (chunks are in gene order)
        if(postflag == 0) {
            x <- do.call(rbind, xl)
        } else if(postflag == 1) {
            x <- list(jp = do.call(rbind, lapply(xl, function(d) d$jp)), modes = do.call(rbind, lapply(xl, function(d) d$modes)))
        } else if(postflag == 2) {
            x <- list(jp = do.call(rbind, lapply(xl, function(d) d$jp)), post = lapply(seq_along(xl[[1]]$post), function(pi) { do.call(rbind, lapply(xl, function(d) d$post[[pi]])) }))
        } else if(postflag == 3) {
            x <- list(jp = do.call(rbind, lapply(xl, function(d) d$jp)), modes = do.call(rbind, lapply(xl, function(d) d$modes)), post = lapply(seq_along(xl[[1]]$post), function(pi) { do.call(rbind, lapply(xl, function(d) d$post[[pi]])) }))
        }
        rm(xl)
        gc()
    } else {
        # unique count lists with matching indices
        ucl <- lapply(seq_len(ncol(counts)), function(i) as.vector(unique(counts[, i, drop = FALSE])))
        uci <- do.call(cbind, lapply(seq_len(ncol(counts)), function(i) match(counts[, i, drop = FALSE], ucl[[i]])-1))
        if(!is.null(batch)) {
            x <- .Call("logBootBatchPosterior", mm, ucl, uci, marginals, batchil, composition, n.randomizations, 1, postflag, localthetaflag, squarelogitconc, PACKAGE = "scde")
        } else {
            x <- .Call("logBootPosterior", mm, ucl, uci, marginals, n.randomizations, 1, postflag, localthetaflag, squarelogitconc, ensembleflag, PACKAGE = "scde")
        }
    }

    # label a genes x grid posterior matrix
    label.grid <- function(d) {
        rownames(d) <- rownames(counts)
        colnames(d) <- grid.labels
        d
    }
    if(postflag == 0) {
        x <- label.grid(x)
    } else {
        x$jp <- label.grid(x$jp)
        if(postflag %in% c(1, 3)) { # per-cell posterior modes: genes x cells
            rownames(x$modes) <- rownames(counts)
            colnames(x$modes) <- rownames(models)
        }
        if(postflag %in% c(2, 3)) { # individual per-cell posteriors
            names(x$post) <- rownames(models)
            x$post <- lapply(x$post, label.grid)
        }
    }
    return(x)
}


# get estimates of expression magnitude for a given set of models
# models - entire model matrix, or a subset of cells (i.e. select rows) of the model matrix for which the estimates should be obtained
# counts - count data that covers the desired set of genes (rows) and all specified cells (columns)
# return - a matrix of log(FPM) estimates with genes as rows and cells  as columns (in the model matrix order).
##' Return scaled expression magnitude estimates
##'
##' Return point estimates of expression magnitudes of each gene across a set of cells, based on the regression slopes determined during the model fitting procedure.
##'
##' @param models models determined by \code{\link{scde.error.models}}
##' @param counts count matrix
##'
##' @return a matrix of expression magnitudes on a log scale (rows - genes, columns - cells)
##'
##' @examples
##' data(es.mef.small)
##' cd <- clean.counts(es.mef.small, min.lib.size=1000, min.reads = 1, min.detected = 1)
##' data(o.ifm)  # Load precomputed model. Use ?scde.error.models to see how o.ifm was generated
##' # get expression magnitude estimates
##' lfpm <- scde.expression.magnitude(o.ifm, cd)
##'
##' @export
scde.expression.magnitude <- function(models, counts) {
    # Inverts each cell's fitted log-linear relationship:
    # magnitude = (log(count) - corr.b) / corr.a, using that cell's coefficients.
    cells <- rownames(models)
    if(!all(cells %in% colnames(counts))) { stop("ERROR: provided count data does not cover all of the cells specified in the model matrix") }
    # cells as rows, so the per-cell coefficient vectors recycle along each row
    log.by.cell <- t(log(counts[, cells, drop = FALSE]))
    scaled <- (log.by.cell - models$corr.b) / models$corr.a
    # back to genes as rows, cells (in model order) as columns
    t(scaled)
}


# calculate drop-out probability given either count data or magnitudes (log(FPM))
# magnitudes can either be a per-cell matrix or a single vector of values which will be evaluated for each cell
# returns a probability of a drop out event for every gene (rows) for every cell (columns)
##' Calculate drop-out probabilities given a set of counts or expression magnitudes
##'
##' Returns estimated drop-out probability for each cell (row of \code{models} matrix), given either a set of expression magnitudes or a count matrix.
##'
##' @param models models determined by \code{\link{scde.error.models}}
##' @param magnitudes a vector of expression magnitudes that is evaluated for every cell, or a matrix (columns correspond to cells) of expression magnitudes, given on a log scale
##' @param counts a matrix (columns correspond to cells) of read counts from which the expression magnitude should be estimated
##'
##' @return a vector or a matrix of drop-out probabilities
##'
##' @examples
##' data(es.mef.small)
##' cd <- clean.counts(es.mef.small, min.lib.size=1000, min.reads = 1, min.detected = 1)
##' data(o.ifm)  # Load precomputed model. Use ?scde.error.models to see how o.ifm was generated
##' o.prior <- scde.expression.prior(models = o.ifm, counts = cd, length.out = 400, show.plot = FALSE)
##' # calculate probability of observing a drop out at a given set of magnitudes in different cells
##' mags <- c(1.0, 1.5, 2.0)
##' p <- scde.failure.probability(o.ifm, magnitudes = mags)
##' # calculate probability of observing the dropout at a magnitude corresponding to the
##' # number of reads actually observed in each cell
##' self.p <- scde.failure.probability(o.ifm, counts = cd)
##'
##' @export
scde.failure.probability <- function(models, magnitudes = NULL, counts = NULL) {
    # Per-cell logistic drop-out model:
    #   p = 1/(exp(conc.a*m [+ conc.a2*m^2] + conc.b) + 1)
    # evaluated with each cell's coefficients (rows of models).
    # Returns a matrix with one row per magnitude (or gene) and one column per cell.
    if(is.null(magnitudes)) {
        if(!is.null(counts)) {
            magnitudes <- scde.expression.magnitude(models, counts)
        } else {
            stop("ERROR: either magnitudes or counts should be provided")
        }
    }
    if(is.matrix(magnitudes)) { # a different vector for every cell
        if(!all(rownames(models) %in% colnames(magnitudes))) { stop("ERROR: provided magnitude data does not cover all of the cells specified in the model matrix") }
        # align columns with the model rows so each cell is paired with its own
        # coefficients (and extra columns are dropped)
        magnitudes <- magnitudes[, rownames(models), drop = FALSE]
        if("conc.a2" %in% names(models)) {
            x <- t(1/(exp(t(magnitudes)*models$conc.a +t(magnitudes^2)*models$conc.a2 + models$conc.b)+1))
        } else {
            x <- t(1/(exp(t(magnitudes)*models$conc.a + models$conc.b)+1))
        }
    } else { # a common vector of magnitudes for all cells
        # outer products give a cells x magnitudes matrix, transposed afterwards
        if("conc.a2" %in% names(models)) {
            x <- t(1/(exp((models$conc.a %*% t(magnitudes)) + (models$conc.a2 %*% t(magnitudes^2)) + models$conc.b)+1))
        } else {
            x <- t(1/(exp((models$conc.a %*% t(magnitudes)) + models$conc.b)+1))
        }
    }
    # undefined values (e.g. from Inf - Inf in the exponent) are reported as 0
    x[is.nan(x)] <- 0
    colnames(x) <- rownames(models)
    x
}


##' Test differential expression and plot posteriors for a particular gene
##'
##' The function performs differential expression test and optionally plots posteriors for a specified gene.
##'
##' @param gene name of the gene to be tested
##' @param models models
##' @param counts read count matrix (must contain the row corresponding to the specified gene)
##' @param prior expression magnitude prior
##' @param groups a two-level factor specifying between which cells (rows of the models matrix) the comparison should be made
##' @param batch optional multi-level factor assigning the cells (rows of the model matrix) to different batches that should be controlled for (e.g. two or more biological replicates). The expression difference estimate will then take into account the likely difference between the two groups that is explained solely by their difference in batch composition. Not all batch configurations may be corrected this way.
##' @param batch.models optional set of models for batch comparison (typically the same as models, but can be more extensive, or recalculated within each batch)
##' @param n.randomizations number of bootstrap/sampling iterations that should be performed
##' @param show.plots whether the plots should be shown
##' @param return.details whether the posterior should be returned
##' @param verbose set to TRUE for some status output
##' @param ratio.range optionally specifies the range of the log2 expression ratio plot
##' @param show.individual.posteriors whether the individual cell expression posteriors should be plotted
##' @param n.cores number of cores to use (default = 1)
##'
##' @return by default returns MLE of log2 expression difference, 95% CI (upper, lower bound), and a Z-score testing for expression difference. If return.details = TRUE, a list is returned containing the above structure, as well as the expression fold difference posterior itself.
##'
##' @examples
##' data(es.mef.small)
##' cd <- clean.counts(es.mef.small, min.lib.size=1000, min.reads = 1, min.detected = 1)
##' data(o.ifm)  # Load precomputed model. Use ?scde.error.models to see how o.ifm was generated
##' o.prior <- scde.expression.prior(models = o.ifm, counts = cd, length.out = 400, show.plot = FALSE)
##' scde.test.gene.expression.difference("Tdh", models = o.ifm, counts = cd, prior = o.prior)
##'
##' @export
scde.test.gene.expression.difference <- function(gene, models, counts, prior, groups = NULL, batch = NULL, batch.models = models, n.randomizations = 1e3, show.plots = TRUE, return.details = FALSE, verbose = FALSE, ratio.range = NULL, show.individual.posteriors = TRUE, n.cores = 1) {
    if(!gene %in% rownames(counts)) {
        stop("ERROR: specified gene (", gene, ") is not found in the count data")
    }

    # restrict counts to the requested gene, with columns in model-row order
    ci <- match(rownames(models), colnames(counts))
    counts <- as.matrix(counts[gene, ci, drop = FALSE])


    if(is.null(groups)) { # recover groups from models
        # test the attribute itself: as.factor(NULL) is a zero-length factor, never NULL
        model.groups <- attr(models, "groups")
        if(is.null(model.groups)) stop("ERROR: groups factor is not provided, and models structure is lacking groups attribute")
        groups <- as.factor(model.groups)
        names(groups) <- rownames(models)
    } else {
        # accept non-factor group labels (e.g. character vectors) as well
        groups <- as.factor(groups)
    }
    if(length(levels(groups)) != 2) {
        stop(paste("ERROR: wrong number of levels in the grouping factor (", paste(levels(groups), collapse = " "), "), but must be two.", sep = ""))
    }

    if(verbose) {
        cat("comparing gene ", gene, " between groups:\n")
        print(table(as.character(groups)))
    }

    # calculate joint posteriors
    jpl <- tapply(seq_len(nrow(models)), groups, function(ii) {
        scde.posteriors(models = models[ii, , drop = FALSE], counts = counts[, ii, drop = FALSE], prior = prior, n.cores = n.cores, n.randomizations = n.randomizations, return.individual.posteriors = TRUE)
    })

    bdiffp <- calculate.ratio.posterior(jpl[[1]]$jp, jpl[[2]]$jp, prior, n.cores = n.cores)

    bdiffp.rep <- quick.distribution.summary(bdiffp)

    nam1 <- levels(groups)[1]
    nam2 <- levels(groups)[2]

    # batch control; convert to a factor BEFORE counting levels, otherwise a
    # non-factor batch vector has no levels and would be silently ignored
    if(!is.null(batch)) {
        batch <- as.factor(batch)
    }
    correct.batch <- !is.null(batch) && length(levels(batch)) > 1
    if(correct.batch) {
        # check batch-group interactions
        bgti <- table(groups, batch)
        bgti.ft <- fisher.test(bgti)
        if(verbose) {
            cat("controlling for batch effects. interaction:\n")
        }
        if(any(bgti == 0)) {
            cat("ERROR: cannot control for batch effect, as some batches are found only in one group:\n")
            print(bgti)
        }
        if(bgti.ft$p.value<1e-3) {
            cat("WARNING: strong interaction between groups and batches! Correction may be ineffective:\n")
            print(bgti)
            print(bgti.ft)
        }
        # calculate batch posterior
        batch.jpl <- tapply(seq_len(nrow(models)), groups, function(ii) {
            scde.posteriors(models = batch.models, counts = counts, prior = prior, batch = batch, composition = table(batch[ii]), n.cores = n.cores, n.randomizations = n.randomizations, return.individual.posteriors = FALSE)
        })
        batch.bdiffp <- calculate.ratio.posterior(batch.jpl[[1]], batch.jpl[[2]], prior, n.cores = n.cores)
        a.bdiffp <- calculate.ratio.posterior(bdiffp, batch.bdiffp, prior = data.frame(x = as.numeric(colnames(bdiffp)), y = rep(1/ncol(bdiffp), ncol(bdiffp))), skip.prior.adjustment = TRUE)
        a.bdiffp.rep <- quick.distribution.summary(a.bdiffp)
    }


    if(show.plots) {
        # show each posterior
        layout(matrix(c(1:3), 3, 1, byrow = TRUE), heights = c(2, 1, 2), widths = c(1), FALSE)
        par(mar = c(2.5, 3.5, 2.5, 3.5), mgp = c(1.5, 0.65, 0), cex = 0.9)

        pp <- exp(do.call(rbind, lapply(jpl[[1]]$post, as.numeric)))
        cols <- rainbow(nrow(pp), s = 0.8)
        plot(c(), c(), xlim = range(prior$x), ylim = range(c(0, pp)), xlab = "expression level", ylab = "individual posterior", main = nam1)
        if(show.individual.posteriors) {
            lapply(seq_len(nrow(pp)), function(i) lines(prior$x, pp[i, ], col = rgb(1, 0.5, 0, alpha = 0.25)))
        }
        if(correct.batch) {
            par(new = TRUE)
            plot(prior$x, batch.jpl[[1]][1, ], axes = FALSE, ylab = "", xlab = "", type = 'l', col = 8, lty = 1, lwd = 2)
        }
        pjpc <- jpl[[1]]$jp
        par(new = TRUE)
        jpr <- range(c(0, na.omit(pjpc)))
        plot(prior$x, pjpc, axes = FALSE, ylab = "", xlab = "", ylim = jpr, type = 'l', col = 1, lty = 1, lwd = 2)
        axis(4, pretty(jpr, 5), col = 1)
        mtext("joint posterior", side = 4, outer = FALSE, line = 2)


        # ratio plot
        if(is.null(ratio.range)) { ratio.range <- range(as.numeric(colnames(bdiffp))/log10(2)) }

        par(mar = c(2.5, 3.5, 0.5, 3.5), mgp = c(1.5, 0.65, 0), cex = 0.9)
        rv <- as.numeric(colnames(bdiffp))/log10(2)
        rp <- as.numeric(bdiffp[1, ])
        plot(rv, rp, xlab = "log2 expression ratio", ylab = "ratio posterior", type = 'l', lwd = ifelse(correct.batch, 1, 2), main = "", axes = FALSE, xlim = ratio.range, ylim = c(0, max(bdiffp)))
        axis(1, pretty(ratio.range, 5), col = 1)
        abline(v = 0, lty = 2, col = 8)
        if(correct.batch) { # with batch correction
            # show batch difference
            par(new = TRUE)
            plot(as.numeric(colnames(batch.bdiffp))/log10(2), as.numeric(batch.bdiffp[1, ]), xlab = "", ylab = "", type = 'l', lwd = 1, main = "", axes = FALSE, xlim = ratio.range, col = 8, ylim = c(0, max(batch.bdiffp)))
            # fill out the a.bdiffp confidence interval
            par(new = TRUE)
            rv <- as.numeric(colnames(a.bdiffp))/log10(2)
            rp <- as.numeric(a.bdiffp[1, ])
            plot(rv, rp, xlab = "", ylab = "", type = 'l', lwd = 2, main = "", axes = FALSE, xlim = ratio.range, col = 2, ylim = c(0, max(rp)))
            axis(2, pretty(c(0, max(a.bdiffp)), 2), col = 1)
            r.lb <- which.min(abs(rv-a.bdiffp.rep$lb))
            r.ub <- which.min(abs(rv-a.bdiffp.rep$ub))
            polygon(c(rv[r.lb], rv[r.lb:r.ub], rv[r.ub]), y = c(-10, rp[r.lb:r.ub], -10), col = rgb(1, 0, 0, alpha = 0.2), border = NA)
            abline(v = a.bdiffp.rep$mle, col = 2, lty = 2)
            abline(v = c(rv[r.ub], rv[r.lb]), col = 2, lty = 3)

            legend(x = ifelse(a.bdiffp.rep$mle > 0, "topleft", "topright"), legend = c(paste("MLE: ", round(a.bdiffp.rep$mle, 2), sep = ""), paste("95% CI: ", round(a.bdiffp.rep$lb, 2), " : ", round(a.bdiffp.rep$ub, 2), sep = ""), paste("Z = ", round(a.bdiffp.rep$Z, 2), sep = ""), paste("cZ = ", round(a.bdiffp.rep$cZ, 2), sep = "")), bty = "n")

        } else {  # without batch correction
            # fill out the bdiffp confidence interval
            axis(2, pretty(c(0, max(bdiffp)), 2), col = 1)

            r.lb <- which.min(abs(rv-bdiffp.rep$lb))
            r.ub <- which.min(abs(rv-bdiffp.rep$ub))
            polygon(c(rv[r.lb], rv[r.lb:r.ub], rv[r.ub]), y = c(-10, rp[r.lb:r.ub], -10), col = rgb(1, 0, 0, alpha = 0.2), border = NA)
            abline(v = bdiffp.rep$mle, col = 2, lty = 2)
            abline(v = c(rv[r.ub], rv[r.lb]), col = 2, lty = 3)

            legend(x = ifelse(bdiffp.rep$mle > 0, "topleft", "topright"), legend = c(paste("MLE: ", round(bdiffp.rep$mle, 2), sep = ""), paste("95% CI: ", round(bdiffp.rep$lb, 2), " : ", round(bdiffp.rep$ub, 2), sep = ""), paste("Z = ", round(bdiffp.rep$Z, 2), sep = ""), paste("aZ = ", round(bdiffp.rep$cZ, 2), sep = "")), bty = "n")
        }

        # distal plot
        par(mar = c(2.5, 3.5, 2.5, 3.5), mgp = c(1.5, 0.65, 0), cex = 0.9)
        dp <- exp(do.call(rbind, lapply(jpl[[2]]$post, as.numeric)))
        cols <- rainbow(nrow(dp), s = 0.8)
        plot(c(), c(), xlim = range(prior$x), ylim = range(c(0, dp)), xlab = "expression level", ylab = "individual posterior", main = nam2)
        if(show.individual.posteriors) {
            lapply(seq_len(nrow(dp)), function(i) lines(prior$x, dp[i, ], col = rgb(0, 0.5, 1, alpha = 0.25)))
        }
        if(correct.batch) {
            par(new = TRUE)
            plot(prior$x, batch.jpl[[2]][1, ], axes = FALSE, ylab = "", xlab = "", type = 'l', col = 8, lty = 1, lwd = 2)
        }
        djpc <- jpl[[2]]$jp
        par(new = TRUE)
        jpr <- range(c(0, na.omit(djpc)))
        plot(prior$x, djpc, axes = FALSE, ylab = "", xlab = "", ylim = jpr, type = 'l', col = 1, lty = 1, lwd = 2)
        axis(4, pretty(jpr, 5), col = 1)
        mtext("joint posterior", side = 4, outer = FALSE, line = 2)
    }

    if(return.details) {
        if(correct.batch) { # with batch correction
            return(list(results = a.bdiffp.rep, difference.posterior = a.bdiffp, results.nobatchcorrection = bdiffp.rep))
        } else {
            return(list(results = bdiffp.rep, difference.posterior = bdiffp, posteriors = jpl))
        }
    } else {
        if(correct.batch) { # with batch correction
            return(a.bdiffp.rep)
        } else {
            return(bdiffp.rep)
        }
    }
}


# fit models to external (bulk) reference
##' Fit scde models relative to provided set of expression magnitudes
##'
##' If group-average expression magnitudes are available (e.g. from bulk measurement), this method can be used
##' to fit individual cell error models relative to that reference
##'
##' @param counts count matrix
##' @param reference a vector of expression magnitudes (read counts) corresponding to the rows of the count matrix
##' @param min.fpm minimum reference fpm of genes that will be used to fit the models (defaults to 1). Note: fpm is calculated from the reference count vector as reference/sum(reference)*1e6
##' @param n.cores number of cores to use
##' @param zero.count.threshold read count to use as an initial guess for the zero threshold
##' @param nrep number of independent mixture fit iterations to try (default = 1)
##' @param save.plots whether to write out a pdf file showing the model fits
##' @param plot.filename model fit pdf filename
##' @param verbose verbose level
##'
##' @return matrix of scde models
##'
##' @examples
##' data(es.mef.small)
##' cd <- clean.counts(es.mef.small, min.lib.size=1000, min.reads = 1, min.detected = 1)
##' \donttest{
##' o.ifm <- scde.error.models(counts = cd, groups = sg, n.cores = 10, threshold.segmentation = TRUE)
##' o.prior <- scde.expression.prior(models = o.ifm, counts = cd, length.out = 400, show.plot = FALSE)
##' # calculate joint posteriors across all cells
##' jp <- scde.posteriors(models = o.ifm, cd, o.prior, n.cores = 10, return.individual.posterior.modes = TRUE, n.randomizations = 100)
##' # use expected expression magnitude for each gene
##' av.mag <- as.numeric(jp$jp %*% as.numeric(colnames(jp$jp)))
##' # translate into counts
##' av.mag.counts <- as.integer(round(av.mag))
##' # now, fit alternative models using av.mag as a reference (normally this would correspond to bulk RNA expression magnitude)
##' ref.models <- scde.fit.models.to.reference(cd, av.mag.counts, n.cores = 1)
##' }
##'
##' @export
scde.fit.models.to.reference <- function(counts, reference, n.cores = 10, zero.count.threshold = 1, nrep = 1, save.plots = FALSE, plot.filename = "reference.model.fits.pdf", verbose = 0, min.fpm = 1) {
    # Fits one error model per cell (column of counts) against a shared reference
    # magnitude vector. Returns a data frame of compressed model parameters
    # (rows - cells).
    return.compressed.models <- TRUE
    ids <- colnames(counts)
    # reference expression as fragments per million; identical for every cell
    ref.fpm <- reference/sum(reference)*1e6
    # fit data for one cell: its counts paired with the reference fpm,
    # restricted to genes whose reference fpm exceeds min.fpm
    cell.fit.data <- function(id) {
        df <- data.frame(count = counts[, id], fpm = ref.fpm)
        df[df$fpm > min.fpm, ]
    }

    ml <- papply(seq_along(ids), function(i) {
        df <- cell.fit.data(ids[i])
        # honor the caller's verbose level
        m1 <- fit.nb2.mixture.model(df, nrep = nrep, verbose = verbose, zero.count.threshold = zero.count.threshold)
        if(return.compressed.models) {
            v <- get.compressed.v1.model(m1)
            cl <- clusters(m1)
            rm(m1)
            gc()
            return(list(model = v, clusters = cl))
        } else {
            return(m1)
        }
    }, n.cores = n.cores)
    names(ml) <- ids

    # check if there were errors in the multithreaded portion
    for(i in seq_along(ml)) {
        if(inherits(ml[[i]], "try-error")) {
            message("ERROR encountered in building a model for cell ", ids[i], ":")
            message(ml[[i]])
            stop(paste("ERROR encountered in building a model for cell ", ids[i]))
        }
    }

    if(save.plots) {
        # model fits: one row of four panels per cell
        pdf(file = plot.filename, width = 13, height = 4)
        # make sure the pdf device is closed even if plotting fails part-way
        tryCatch({
            layout(matrix(seq(1, 4), nrow = 1, byrow = TRUE), rep(c(1, 1, 1, 0.5), 1), rep(1, 4), FALSE)
            par(mar = c(3.5, 3.5, 3.5, 0.5), mgp = c(2.0, 0.65, 0), cex = 0.9)
            for(i in seq_along(ids)) {
                plot.nb2.mixture.fit(ml[[i]], cell.fit.data(ids[i]), en = ids[i], do.par = FALSE, compressed.models = return.compressed.models)
            }
        }, finally = dev.off())
    }

    if(return.compressed.models) {
        # make a joint model matrix (rows - cells)
        jmm <- data.frame(do.call(rbind, lapply(ml, function(m) m$model)))
        rownames(jmm) <- names(ml)
        return(jmm)
    } else {
        return(ml)
    }
}


##' Determine principal components of a matrix using per-observation/per-variable weights
##'
##' Implements a weighted PCA
##'
##' @param mat matrix of variables (columns) and observations (rows)
##' @param matw  corresponding weights
##' @param npcs number of principal components to extract
##' @param nstarts number of random starts to use
##' @param smooth smoothing span
##' @param em.tol desired EM algorithm tolerance
##' @param em.maxiter maximum number of EM iterations
##' @param seed random seed
##' @param center whether mat should be centered (weighted centering)
##' @param n.shuffles optional number of per-observation randomizations that should be performed in addition to the main calculations to determine the lambda1 (PC1 eigenvalue) magnitude under such randomizations (returned in $randvar)
##'
##' @return a list containing eigenvector matrix ($rotation), projections ($scores), variance (weighted) explained by each component ($var), total (weighted) variance of the dataset ($totalvar)
##'
##' @examples
##' set.seed(0)
##' mat <- matrix( c(rnorm(5*10,mean=0,sd=1), rnorm(5*10,mean=5,sd=1)), 10, 10)  # random matrix
##' base.pca <- bwpca(mat)  # non-weighted pca, equal weights set automatically
##' matw <- matrix( c(rnorm(5*10,mean=0,sd=1), rnorm(5*10,mean=5,sd=1)), 10, 10)  # random weight matrix
##' matw <- abs(matw)/max(matw)
##' base.pca.weighted <- bwpca(mat, matw)  # weighted pca
##'
##' @export
bwpca <- function(mat, matw = NULL, npcs = 2, nstarts = 1, smooth = 0, em.tol = 1e-6, em.maxiter = 25, seed = 1, center = TRUE, n.shuffles = 0) {
    # Weighted PCA computed by the compiled "baileyWPCA" routine.
    # Rotation rows are labelled by the columns of mat; score rows by the rows of mat.
    # smoothing spans below 4 are treated as no smoothing
    if(smooth<4) { smooth <- 0 }
    if(any(is.nan(matw))) {
      stop("bwpca: weight matrix contains NaN values")
    }
    if(any(is.nan(mat))) {
      stop("bwpca: value matrix contains NaN values")
    }
    if(is.null(matw)) {
        # unweighted case: uniform weights and a single start
        matw <- matrix(1, nrow(mat), ncol(mat))
        nstarts <- 1
    }
    # weighted column centering
    if(center) { mat <- t(t(mat)-colSums(mat*matw)/colSums(matw)) }

    res <- .Call("baileyWPCA", mat, matw, npcs, nstarts, smooth, em.tol, em.maxiter, seed, n.shuffles, PACKAGE = "scde")
    rownames(res$rotation) <- colnames(mat)
    rownames(res$scores) <- rownames(mat)
    colnames(res$rotation) <- paste0("PC", seq_len(ncol(res$rotation)))
    res$sd <- t(sqrt(res$var))
    res
}


##' Winsorize matrix
##'
##' Sets the ncol(mat)*trim top outliers in each row to the next lowest value; the lowest outliers are treated the same way.
##'
##' @param mat matrix
##' @param trim fraction of outliers (on each side) that should be Winsorized, or (if the value is greater than 0.5) the number of outliers to be trimmed on each side
##'
##' @return Winsorized matrix
##'
##' @examples
##' set.seed(0)
##' mat <- matrix( c(rnorm(5*10,mean=0,sd=1), rnorm(5*10,mean=5,sd=1)), 10, 10)  # random matrix
##' mat[1,1] <- 1000  # make outlier
##' range(mat)  # look at range of values
##' win.mat <- winsorize.matrix(mat, 0.1)
##' range(win.mat)  # note outliers removed
##'
##' @export
winsorize.matrix <- function(mat, trim) {
    # a trim value above 0.5 is read as an absolute number of outliers per side,
    # so convert it to the fraction of columns passed to the compiled routine
    if(trim > 0.5) {
        trim <- trim/ncol(mat)
    }
    out <- .Call("winsorizeMatrix", mat, trim, PACKAGE = "scde")
    # carry over the input's row/column labels
    colnames(out) <- colnames(mat)
    rownames(out) <- rownames(mat)
    out
}


############################ PAGODA functions


##' Build error models for heterogeneous cell populations, based on K-nearest neighbor cells.
##'
##' Builds cell-specific error models assuming that there are multiple subpopulations present
##' among the measured cells. The models for each cell are based on average expression estimates
##' obtained from K closest cells within a given group (if groups = NULL, then within the entire
##' set of measured cells). The method implements fitting of both the original log-fit models
##' (when linear.fit = FALSE), or newer linear-fit models (linear.fit = TRUE, default) with locally
##' fit overdispersion coefficient (local.theta.fit = TRUE, default).
##'
##' @param counts count matrix (integer matrix, rows- genes, columns- cells)
##' @param groups optional groups partitioning known subpopulations
##' @param cor.method correlation measure to be used in determining k nearest cells
##' @param k number of nearest neighbor cells to use during fitting. If k is set sufficiently high, all of the cells within a given group will be used.
##' @param min.nonfailed minimum number of non-failed measurements (within the k nearest neighbor cells) required for a gene to be taken into account during error fitting procedure
##' @param min.size.entries minimum number of genes to use for model fitting
##' @param min.count.threshold minimum number of reads required for a measurement to be considered non-failed
##' @param save.model.plots whether model plots should be saved (file names are (group).models.pdf, or cell.models.pdf if no group was supplied)
##' @param max.model.plots maximum number of models to save plots for (saves time when there are too many cells)
##' @param n.cores number of cores to use through the calculations
##' @param min.fpm optional parameter to restrict model fitting to genes with group-average expression magnitude above a given value
##' @param verbose level of verbosity
##' @param fpm.estimate.trim trim fraction to be used in estimating group-average gene expression magnitude for model fitting (0.5 would be median, 0 would turn off trimming)
##' @param linear.fit whether newer linear model fit with zero intercept should be used (T), or the log-fit model published originally (F)
##' @param local.theta.fit whether local theta fitting should be used (only available for the linear fit models)
##' @param theta.fit.range allowed range of the theta values
##' @param alpha.weight.power 1/theta weight power used in fitting theta dependency on the expression magnitude
##'
##' @return a data frame with parameters of the fit error models (rows- cells, columns- fitted parameters)
##'
##' @examples
##' data(pollen)
##' cd <- clean.counts(pollen)
##' \donttest{
##' knn <- knn.error.models(cd, k=ncol(cd)/4, n.cores=10, min.count.threshold=2, min.nonfailed=5, max.model.plots=10)
##' }
##'
##' @export
knn.error.models <- function(counts, groups = NULL, k = round(ncol(counts)/2), min.nonfailed = 5, min.count.threshold = 1, save.model.plots = TRUE, max.model.plots = 50, n.cores = parallel::detectCores(), min.size.entries = 2e3, min.fpm = 0, cor.method = "pearson", verbose = 0, fpm.estimate.trim = 0.25, linear.fit = TRUE, local.theta.fit = linear.fit, theta.fit.range = c(1e-2, 1e2), alpha.weight.power = 1/2) {
    # For each cell, the expected expression (fpm) of every gene is estimated from the k most
    # similar cells of the same group, and an error model is fit to the cell's counts against
    # those expectations. Returns one row of compressed model parameters per fitted cell.

    # prior probability of the drop-out component for genes flagged as likely failures
    threshold.prior <- 1-1e-6

    # check for integer counts (a matrix has a single storage mode, so it can be tested
    # directly instead of element by element)
    counts.are.integer <- if(is.matrix(counts)) {
        is.integer(counts)
    } else {
        all(unlist(lapply(counts, is.integer)))
    }
    if(!counts.are.integer) {
      stop("Some of the supplied counts are not integer values (or stored as non-integer types). Aborting!\nThe method is designed to work on read counts - do not pass normalized read counts (e.g. FPKM values). If matrix contains read counts, but they are stored as numeric values, use counts<-apply(counts,2,function(x) {storage.mode(x) <- 'integer'; x}) to recast.");
    }

    if(is.null(groups)) {
        groups <- as.factor(rep("cell", ncol(counts)))
    }
    names(groups) <- colnames(counts)

    if(k >  ncol(counts)-1) {
        message("the value of k (", k, ") is too large, setting to ", (ncol(counts)-1))
        k <- ncol(counts)-1
    }

    # resolve the optional WGCNA dependency once; installed.packages() is slow and does not
    # guarantee that the package can actually be loaded
    use.wgcna <- requireNamespace("WGCNA", quietly = TRUE)

    ls <- estimate.library.sizes(counts, NULL, groups, min.size.entries, verbose = verbose, return.details = TRUE, vil = counts >= min.count.threshold)
    ca <- counts
    ca[ca<min.count.threshold] <- NA # a version of counts with all "drop-out" components set to NA

    mll <- tapply(colnames(counts), groups, function(ids) {
        group <- as.character(groups[ids[1]])

        # cell-cell similarity: correlation (cor.method) of sqrt-transformed counts on pairwise
        # complete observations (drop-outs are NA in ca)
        if(verbose > 0) {
            cat(group, ": calculating cell-cell similarities ...")
        }
        sqrt.counts <- sqrt(matrix(as.numeric(as.matrix(ca[, ids])), nrow = nrow(ca), ncol = length(ids)))
        if(use.wgcna) {
            celld <- WGCNA::cor(sqrt.counts, method = cor.method, use = "p", nThreads = n.cores)
        } else {
            celld <- stats::cor(sqrt.counts, method = cor.method, use = "p")
        }
        rownames(celld) <- colnames(celld) <- ids
        if(verbose > 0) {
            cat(" done\n")
        }

        # TODO: correct for batch effect in the cell-cell similarity matrix

        # Data used to fit (and plot) the model of cell i: its k most similar cells within the
        # group define trimmed-mean fpm expectations, restricted to genes detected in enough
        # of those neighbours.
        neighbour.fit.data <- function(i) {
            n.neighbours <- floor(min(k, length(ids)-1))
            oc <- ids[-i][order(celld[ids[i], -i, drop = FALSE], decreasing = TRUE)[seq_len(n.neighbours)]]
            fpm <- apply(t(ca[, oc, drop = FALSE])/(ls$ls[oc]), 2, mean, trim = fpm.estimate.trim, na.rm = TRUE)
            # genes detected in at least min.nonfailed neighbours (or all but one, if fewer neighbours)
            vi <- which(rowSums(counts[, oc, drop = FALSE] > min.count.threshold) >= min(length(oc)-1, min.nonfailed) & fpm > min.fpm)
            data.frame(count = counts[vi, ids[i]], fpm = fpm[vi])
        }

        if(verbose)  message(paste("fitting", group, "models:"))

        ml <- papply(seq_along(ids), function(i) { try({
            if(verbose)  message(paste0(group, ".", i, " : ", ids[i]))
            df <- neighbour.fit.data(i)
            if(nrow(df)<40)  message("WARNING: only ", nrow(df), " valid genes were found to fit ", ids[i], " model")

            # failed-component posteriors: genes at/below the count threshold whose neighbourhood
            # fpm is at least the median fpm of such genes are treated as likely drop-outs
            low <- df$count <= min.count.threshold
            fp <- ifelse(low & df$fpm >= median(df$fpm[low]), threshold.prior, 1-threshold.prior)
            cp <- cbind(fp, 1-fp)

            if(linear.fit) {
                # use a linear fit (nb2gth)
                m1 <- fit.nb2gth.mixture.model(df, prior = cp, nrep = 1, verbose = verbose, zero.count.threshold = min.count.threshold, full.theta.range = theta.fit.range, theta.fit.range = theta.fit.range, use.constant.theta.fit = !local.theta.fit, alpha.weight.power = alpha.weight.power)
            } else {
                # mixture fit (the originally published method)
                m1 <- fit.nb2.mixture.model(df, prior = cp, nrep = 1, verbose = verbose, zero.count.threshold = min.count.threshold)
            }
            list(model = get.compressed.v1.model(m1), clusters = clusters(m1))
        })}, n.cores = n.cores)

        # keep only the cells whose fit did not fail
        fit.ok <- vapply(seq_along(ml), function(i) {
            if(inherits(ml[[i]], "try-error")) {
                message("ERROR encountered in building a model for cell ", ids[i], " - skipping the cell. Error:")
                message(ml[[i]])
                return(FALSE)
            }
            TRUE
        }, logical(1))
        vic <- which(fit.ok)
        ml <- ml[vic]; names(ml) <- ids[vic];

        if(length(vic)<length(ids)) {
          message("ERROR fitting of ", (length(ids)-length(vic)), " out of ", length(ids), " cells resulted in errors reporting remaining ", length(vic), " cells")
        }

        if(save.model.plots && length(vic) > 0 && max.model.plots > 0) {
            # model fits
            if(verbose)  message("plotting ", group, " model fits... ")
            plot.device <- NULL
            tryCatch({
                pdf(file = paste(group, "model.fits.pdf", sep = "."), width = ifelse(local.theta.fit, 13, 15), height = 4)
                plot.device <- dev.cur()
                layout(matrix(seq(1, 4), nrow = 1, byrow = TRUE), widths = c(1, 1, 1, ifelse(local.theta.fit, 1, 0.5)), heights = rep(1, 4), respect = FALSE)
                par(mar = c(3.5, 3.5, 3.5, 0.5), mgp = c(2.0, 0.65, 0), cex = 0.9)
                for(i in vic[seq_len(floor(min(max.model.plots, length(vic))))]) {
                    plot.nb2.mixture.fit(ml[[ids[i]]], neighbour.fit.data(i), en = ids[i], do.par = FALSE, compressed.models = TRUE)
                }
            }, error = function(e) {
                message("ERROR encountered during model fit plot outputs:")
                message(e)
            }, finally = {
                # close only the device opened here (pdf() itself may have failed)
                if(!is.null(plot.device)) dev.off(plot.device)
            })
        }

        ml
    })

    # make a joint model matrix; rows are ordered group by group (factor level order), not in
    # the original column order of counts. The "groups" attribute records each row's group.
    jmm <- data.frame(do.call(rbind, lapply(mll, function(tl) do.call(rbind, lapply(tl, function(m) m$model)))))
    rownames(jmm) <- unlist(lapply(mll, names))
    attr(jmm, "groups") <- rep(names(mll), unlist(lapply(mll, length)))
    jmm
}


##' Normalize gene expression variance relative to transcriptome-wide expectations
##'
##' Normalizes gene expression magnitudes to ensure that the variance follows chi-squared statistics
##' with respect to its ratio to the transcriptome-wide expectation as determined by local regression
##' on expression magnitude (and optionally gene length). Corrects for batch effects.
##'
##' @param models model matrix (select a subset of rows to normalize variance within a subset of cells)
##' @param counts read count matrix
##' @param batch measurement batch (optional)
##' @param trim trim value for Winsorization (optional, can be set to 1-3 to reduce the impact of outliers, can be as large as 5 or 10 for datasets with several thousand cells)
##' @param prior expression magnitude prior
##' @param fit.genes a vector of gene names which should be used to establish the variance fit (default is NULL: use all genes). This can be used to specify, for instance, a set of spike-in control transcripts such as ERCC.
##' @param plot whether to plot the results
##' @param minimize.underdispersion whether underdispersion should be minimized (can increase sensitivity in datasets with high complexity of population, however cannot be effectively used in datasets where multiple batches are present)
##' @param n.cores number of cores to use
##' @param n.randomizations number of bootstrap sampling rounds to use in estimating average expression magnitude for each gene within the given set of cells
##' @param weight.k k value to use in the final weight matrix
##' @param verbose verbosity level
##' @param weight.df.power power factor to use in determining effective number of degrees of freedom (can be increased for datasets exhibiting particularly high levels of noise at low expression magnitudes)
##' @param smooth.df degrees of freedom to be used in calculating smoothed local regression between coefficient of variation and expression magnitude (and gene length, if provided). Leave at -1 for automated guess.
##' @param max.adj.var maximum value allowed for the estimated adjusted variance (capping of adjusted variance is recommended when scoring pathway overdispersion relative to randomly sampled gene sets)
##' @param theta.range valid theta range (should be the same as was set in knn.error.models() call)
##' @param gene.length optional vector of gene lengths, either named by gene or given in the order of the rows of the counts matrix
##'
##' @examples
##' data(pollen)
##' cd <- clean.counts(pollen)
##' \donttest{
##' knn <- knn.error.models(cd, k=ncol(cd)/4, n.cores=10, min.count.threshold=2, min.nonfailed=5, max.model.plots=10)
##' varinfo <- pagoda.varnorm(knn, counts = cd, trim = 3/ncol(cd), max.adj.var = 5, n.cores = 1, plot = FALSE)
##' }
##'
##' @return a list containing the following fields:
##' \describe{
##' \item{mat}{adjusted expression magnitude values}
##' \item{matw}{weight matrix corresponding to the expression matrix}
##' \item{arv}{a vector giving adjusted variance values for each gene}
##' \item{modes}{a list of batch-specific average expression magnitudes for each gene (same as avmodes when no batch correction is performed)}
##' \item{avmodes}{a vector of estimated average expression magnitudes for each gene}
##' \item{prior}{estimated (or supplied) expression magnitude prior}
##' \item{edf}{estimated effective degrees of freedom}
##' \item{batch}{the batch factor used (after merging batches with fewer than 2 cells), or NULL}
##' \item{trim}{trim parameter}
##' \item{fit.genes}{fit.genes parameter}
##' \item{bwvar.ratio}{(only when batch correction is performed) ratio of batch-specific to dataset-wide weighted variance for each gene}
##' }
##'
##' @export
pagoda.varnorm <- function(models, counts, batch = NULL, trim = 0, prior = NULL, fit.genes=NULL, plot = TRUE, minimize.underdispersion = FALSE, n.cores = parallel::detectCores(), n.randomizations = 100, weight.k = 0.9, verbose = 0, weight.df.power = 1, smooth.df = -1, max.adj.var = 10, theta.range = c(1e-2, 1e2), gene.length = NULL) {

    cd <- counts

    # internal settings
    min.edf <- 1                # genes with edf at or below this value are not tested (log p set to 0)
    weight.k.internal <- 1      # scaling of the failure probabilities in the variance-estimation weights
    use.mean.fpm <- FALSE       # estimate modes as mean fpm across cells rather than from joint posteriors
    use.expected.value <- TRUE  # use the joint posterior expected value (rather than its mode)
    cv.fit <- TRUE              # fit log10 cv^2 (TRUE) or weighted sd (FALSE) vs. expression magnitude
    edf.damping <- 1            # exponent damping the effective degrees of freedom (1 = none)
    min.sd <- 1e-5              # floor on the predicted sd (only used when cv.fit is FALSE)

    # gene.length is matched to genes by name; accept an unnamed vector aligned with rows of counts
    if(!is.null(gene.length) && is.null(names(gene.length))) {
        if(length(gene.length) != nrow(counts)) {
            stop("gene.length must either be named by gene or have one entry per row of counts")
        }
        names(gene.length) <- rownames(counts)
    }

    # load NB extensions
    data(scde.edff, envir = environment())

    # subset cd to the cells occurring in the models
    if(verbose) { cat("checking counts ... ") }
    if(!all(rownames(models) %in% colnames(cd))) {
        stop(paste("supplied count matrix (cd) is missing data for the following cells:[", paste(rownames(models)[!rownames(models) %in% colnames(cd)], collapse = ", "), "]", sep = ""))
    }
    if(!length(rownames(models)) == length(colnames(cd)) || !all(rownames(models) == colnames(cd))) {
        cd <- cd[, match(rownames(models), colnames(cd))]
    }
    if(verbose) { cat("done\n") }

    # trim counts according to the extreme fpm values
    if(trim > 0) {
        if(verbose) { cat("Winsorizing count matrix ... ") }
        fpm <- t((t(log(cd))-models$corr.b)/models$corr.a)
        tfpm <- winsorize.matrix(fpm, trim)
        rn <- rownames(cd)
        cn <- colnames(cd)
        cd <- round(exp(t(t(tfpm)*models$corr.a+models$corr.b)))
        cd[cd<0] <- 0
        rownames(cd) <- rn
        colnames(cd) <- cn
        rm(fpm, tfpm)
        cd <- cd[rowSums(cd) > 0, , drop = FALSE] # omit genes without any data after Winsorization
        if(verbose) { cat("done\n") }
    }

    # check/fix batch vector
    if(verbose) { cat("checking batch ... ") }
    if(!is.null(batch)) {
        if(!is.factor(batch)) {
            batch <- as.factor(batch)
        }
        if(is.null(names(batch))) {
            if(length(batch) != nrow(models)) {
                stop("invalid batch vector supplied: length differs from nrow(models)!")
            }
            names(batch) <- rownames(models)
        } else {
            if(!all(rownames(models) %in% names(batch))) {
                stop(paste("invalid batch vector supplied: the following cell(s) are not present: [", paste(rownames(models)[!rownames(models) %in% names(batch)], collapse = ", "), "]", sep = ""))
            }
            batch <- batch[rownames(models)]
        }

        bt <- table(batch)
        min.batch.level <- 2
        if(any(bt<min.batch.level)) {
            if(verbose) { cat("omitting small batch levels [", paste(names(bt)[bt<min.batch.level], collapse = " "), "] ... ") }
            batch[batch %in% names(bt)[bt<min.batch.level]] <- names(bt)[which.max(bt)]
        }
        # drop levels emptied above (or unused in the supplied factor) so that the number of
        # levels reflects the batches actually present
        batch <- droplevels(batch)
    }
    if(verbose) { cat("ok\n") }
    use.batch <- !is.null(batch) && length(levels(batch)) > 1

    # recalculate modes as needed
    if(verbose) { cat("calculating modes ... ") }
    if(is.null(prior)) {
        if(verbose) { cat("prior ") }
        prior <- scde.expression.prior(models = models, counts = cd, length.out = 400, show.plot = FALSE)
    }
    # average expression magnitude for the cells in ii (mean fpm, or joint posterior summary)
    estimate.modes <- function(ii) {
        if(use.mean.fpm) { # use mean fpm across cells
            return(rowMeans(exp(scde.expression.magnitude(models[ii, ], cd[, ii]))))
        }
        jp <- scde.posteriors(models = models[ii, ], cd[, ii], prior, n.cores = n.cores, return.individual.posterior.modes = TRUE, n.randomizations = n.randomizations)
        if(use.expected.value) {
            (jp$jp %*% as.numeric(colnames(jp$jp)))[, 1]
        } else { # use mode
            (as.numeric(colnames(jp$jp)))[max.col(jp$jp)]
        }
    }
    # dataset-wide mode
    avmodes <- modes <- estimate.modes(seq_len(nrow(models)))
    if(verbose) { cat(". ") }

    # batch-specific modes, if necessary
    if(use.batch) {
        if(verbose) { cat("batch: [ ") }
        modes <- tapply(seq_len(nrow(models)), batch, function(ii) {
            if(verbose) { cat(as.character(batch[ii[1]]), " ") }
            estimate.modes(ii)
        })
        if(verbose) { cat("] ") }
    }
    if(verbose) { cat("done\n") }

    # check/calculate weights
    if(verbose) { cat("calculating weight matrix ... ") }
    if(verbose) { cat("calculating ... ") }

    # weights = 1 - (mode failure probability x self-fail probability p(count|background)).
    # (An alternative that also used the NB failure probability was previously computed but unused.)
    # dataset-wide version of matw (disregarding batch)
    sfp <- do.call(cbind, lapply(seq_len(ncol(cd)), function(i) ppois(cd[, i]-1, exp(models[i, "fail.r"]), lower.tail = FALSE)))
    mfp <- scde.failure.probability(models = models, magnitudes = log(avmodes))
    matw <- 1-weight.k.internal*mfp*sfp

    # batch-specific version of the weight matrix if needed
    if(use.batch) {
        if(verbose) { cat("batch: [ ") }
        bmatw <- do.call(cbind, tapply(seq_len(nrow(models)), batch, function(ii) {
            if(verbose) { cat(as.character(batch[ii[1]]), " ") }
            sfp <- do.call(cbind, lapply(ii, function(i) ppois(cd[, i]-1, exp(models[i, "fail.r"]), lower.tail = FALSE)))
            mfp <- scde.failure.probability(models = models[ii, ], magnitudes = log(modes[[batch[ii[1]]]]))
            1-weight.k.internal*mfp*sfp
        }))
        # reorder
        bmatw <- bmatw[, rownames(models)]
        if(verbose) { cat("] ") }
    }
    if(verbose) { cat("done\n") }

    # effective degrees of freedom of cell i's NB model at log expression magnitudes lfpm
    cell.edf <- function(i, lfpm) {
        thetas <- get.corr.theta(models[i, ], lfpm, theta.range)
        edf <- exp(predict(scde.edff, data.frame(lt = log(thetas))))
        edf[thetas > 1e3] <- 1
        edf
    }
    # edf-weighted (w) squared deviation of cell i's counts from the expected count at lfpm,
    # scaled by the NB variance plus the background (failure) Poisson rate
    cell.deviance <- function(i, lfpm, w) {
        v <- models[i, ]
        mu <- exp(lfpm*v$corr.a + v$corr.b)
        thetas <- get.corr.theta(v, lfpm, theta.range)
        fail.lambda <- exp(as.numeric(v["fail.r"]))
        w[, i]*(cd[, i]-mu)^2/(mu+mu^2/thetas +  fail.lambda)
    }

    # calculate effective degrees of freedom (total per gene)
    if(verbose) { cat("calculating effective degrees of freedom ..") }
    ids <- seq_len(ncol(cd))
    names(ids) <- colnames(cd)
    # dataset-wide version
    edf.mat <- do.call(cbind, papply(ids, function(i) cell.edf(i, log(avmodes)), n.cores = n.cores))
    if(edf.damping != 1) {
        edf.mat <- ((edf.mat/ncol(edf.mat))^edf.damping) * ncol(edf.mat)
    }
    # incorporate weight into edf
    edf.mat <- (matw*edf.mat)^weight.df.power
    edf <- rowSums(edf.mat)+1 # summarize eDF per gene
    if(verbose) { cat(".") }

    # batch-specific version if necessary
    if(use.batch) {
        bedf.mat <- do.call(cbind, papply(ids, function(i) cell.edf(i, log(modes[[batch[i]]])), n.cores = n.cores))
        if(edf.damping != 1) {
            bedf.mat <- ((bedf.mat/ncol(bedf.mat))^edf.damping) * ncol(bedf.mat)
        }
        bedf.mat <- (bmatw*bedf.mat)^weight.df.power
        bedf <- rowSums(bedf.mat)+1 # summarize eDF per gene
        if(verbose) { cat(".") }
    }
    if(verbose) { cat(" done\n") }

    if(verbose) { cat("calculating normalized expression values ... ") }
    # evaluate negative binomial deviations weighted by effective degrees of freedom
    mat <- do.call(cbind, papply(ids, function(i) cell.deviance(i, log(avmodes), edf.mat), n.cores = n.cores))
    rownames(mat) <- rownames(cd)
    if(verbose) { cat(".") }
    # batch-specific version of mat; weighted by bedf.mat so that it is consistent with the
    # rowSums(bedf.mat) normalization of bwvar below
    if(use.batch) {
        bmat <- do.call(cbind, papply(ids, function(i) cell.deviance(i, log(modes[[batch[i]]]), bedf.mat), n.cores = n.cores))
        rownames(bmat) <- rownames(cd)
        if(verbose) { cat(".") }
    }
    if(verbose) { cat(" done\n") }

    # weighted variance (as a function of the batch-average expression mode)
    wvar <- rowSums(mat)/rowSums(edf.mat)

    if(use.batch) {
        # estimate the ratio of the batch-specific variance to the total dataset variance
        bwvar <- rowSums(bmat)/rowSums(bedf.mat)
        bwvar.ratio <- bwvar/wvar
        wvar <- bwvar # use the batch-specific variance from here on
        matw <- bmatw # use the batch-specific weights from here on
        # ALTERNATIVE: could adjust wvar for the bwvar.ratio here, before fitting expression dependency
    }
    fvi <- vi <- rowSums(matw) > 0 & is.finite(wvar) & wvar > 0
    if(!is.null(fit.genes)) { fvi <- fvi & rownames(mat) %in% fit.genes }
    if(!any(fvi)) { stop("unable to find a set of valid genes to establish the variance fit") }

    s <- mgcv::s # make s() resolvable when mgcv interprets the smooth terms of the formulas below
    if(cv.fit) {
        if(is.null(gene.length)) {
            df <- data.frame(lev = log10(avmodes), cv2 = log10(wvar/avmodes^2))
            fit.formula <- cv2 ~ s(lev, k = smooth.df)
        } else {
            df <- data.frame(lev = log10(avmodes), cv2 = log10(wvar/avmodes^2), len = gene.length[rownames(cd)])
            fit.formula <- cv2 ~ s(lev, k = smooth.df) + s(len, k = smooth.df)
        }
        x <- mgcv::gam(fit.formula, data = df[fvi, ], weights = rowSums(matw[fvi, ]))
        zval.m <- 10^(df$cv2[vi]-predict(x, newdata = df[vi, ]))

        if(plot) {
            par(mfrow = c(1, 2), mar = c(3.5, 3.5, 1.0, 1.0), mgp = c(2, 0.65, 0))
            smoothScatter(df$lev[vi], df$cv2[vi], nbin = 256, xlab = "expression magnitude (log10)", ylab = "cv^2 (log10)")
            lines(sort(df$lev[vi]), predict(x, newdata = df[vi, ])[order(df$lev[vi])], col = 2, pch = ".", cex = 1)
            if(!is.null(fit.genes)) { # show genes used for the fit
              points(df$lev[fvi],df$cv2[fvi],pch=".",col="green",cex=1)
            }
        }

        # optional : re-weight to minimize the underdispersed points
        if(minimize.underdispersion) {
            pv <- pchisq(zval.m*(edf[vi]-1), edf[vi], log.p = FALSE, lower.tail = TRUE)
            pv[edf[vi]<= min.edf] <- 0
            pv <- p.adjust(pv)
            # pv is indexed over vi; restrict it to the fit genes (fvi is a subset of vi)
            x <- mgcv::gam(fit.formula, data = df[fvi, ], weights = (pmin(10, -log(pv[fvi[vi]]))+1)*rowSums(matw[fvi, ]))
            zval.m <- 10^(df$cv2[vi]-predict(x, newdata = df[vi, ]))
            if(plot) {
              lines(sort(df$lev[vi]), predict(x, newdata = df[vi, ])[order(df$lev[vi])], col = 4, pch = ".", cex = 1)
            }
        }
    } else {
        df <- data.frame(lev = log10(avmodes), sd = sqrt(wvar))
        fit.formula <- sd ~ s(lev, k = smooth.df)
        x <- mgcv::gam(fit.formula, data = df[fvi, ], weights = rowSums(matw[fvi, ]))
        zval.m <- (as.numeric((df$sd[vi])/pmax(min.sd, predict(x, newdata = df[vi, ]))))^2

        if(plot) {
            par(mfrow = c(1, 2), mar = c(3.5, 3.5, 1.0, 1.0), mgp = c(2, 0.65, 0))
            smoothScatter(df$lev[vi], df$sd[vi], nbin = 256, xlab = "expression magnitude", ylab = "weighted sdiv")
            lines(sort(df$lev[vi]), predict(x, newdata = df[vi, ])[order(df$lev[vi])], col = 2, pch = ".", cex = 1)
            if(!is.null(fit.genes)) { # show genes used for the fit
              points(df$lev[fvi],df$sd[fvi],pch=".",col="green",cex=1)
            }
        }

        # optional : re-weight to minimize the underdispersed points
        if(minimize.underdispersion) {
            pv <- pchisq(zval.m*(edf[vi]-1), edf[vi], log.p = FALSE, lower.tail = TRUE)
            pv[edf[vi]<= min.edf] <- 0
            pv <- p.adjust(pv)
            # pv is indexed over vi; restrict it to the fit genes (fvi is a subset of vi)
            x <- mgcv::gam(fit.formula, data = df[fvi, ], weights = (pmin(20, -log(pv[fvi[vi]]))+1)*rowSums(matw[fvi, ]))
            # predict over all valid genes (fitted values only cover the fit genes)
            zval.m <- (as.numeric((df$sd[vi])/pmax(min.sd, predict(x, newdata = df[vi, ]))))^2
            if(plot) {
              lines(sort(df$lev[vi]), predict(x, newdata = df[vi, ])[order(df$lev[vi])], col = 4, pch = ".", cex = 1)
            }
        }
    }

    # adjust for inter-batch variance
    if(use.batch) {
        zval.m <- zval.m*pmin(bwvar.ratio[vi], 1/bwvar.ratio[vi]) # penalize for strong deviation in either direction
    }

    # calculate adjusted variance: map the upper-tail chi-squared log p-value (with gene-specific edf)
    # onto a chi-squared quantile with ncol(matw)-1 degrees of freedom, normalized by ncol(matw)
    qv <- pchisq(zval.m*(edf[vi]-1), edf[vi], log.p = TRUE, lower.tail = FALSE)
    qv[edf[vi]<= min.edf] <- 0
    qv[abs(qv)<1e-10] <- 0
    arv <- rep(NA, length(vi))
    arv[vi] <- qchisq(qv, ncol(matw)-1, lower.tail = FALSE, log.p = TRUE)/ncol(matw)
    arv <- pmin(max.adj.var, arv)
    names(arv) <- rownames(cd)
    if(plot) {
        smoothScatter(df$lev[vi], arv[vi], xlab = "expression magnitude (log10)", ylab = "adjusted variance", nbin = 256)
        abline(h = 1, lty = 2, col = 8)
        abline(h = max.adj.var, lty = 3, col = 2)
        if(!is.null(fit.genes)) {
          points(df$lev[fvi],arv[fvi],pch=".",col="green",cex=1)
        }
    }

    # Wilson score interval upper bound (with continuity correction) for k successes out of n
    wsu <- function(k, n, z = qnorm(0.975)) {
        p <- k/n
        pmin(1, (2*n*p+z^2+(z*sqrt(z^2-1/n+4*n*p*(1-p)-(4*p-2)) +1))/(2*(n+z^2)))
    }

    # use milder weight matrix for the PCA (1-0.9*sp*mf)
    matw <- 1-weight.k*(1-matw)
    matw <- matw/rowSums(matw)
    mat <- log10(exp(scde.expression.magnitude(models, cd))+1)

    # estimate observed variance (for scaling) before batch adjustments
    ov <- weightedMatVar(mat, matw)
    vr <- arv/ov
    vr[ov <=  0] <- 0

    if(use.batch) {
        # adjust proportion of zeros:
        # determine lowest upper bound of non-zero measurement probability among the batches (for each gene)
        nbub <- apply(do.call(cbind, tapply(seq_len(ncol(mat)), batch, function(ii) {
            wsu(rowSums(mat[, ii] > 0), length(ii), z = qnorm(1-1e-2))
        })), 1, min)

        # decrease the batch weights for each gene to match the total
        # expectation of the non-zero measurements
        nbo <- do.call(cbind, tapply(seq_len(ncol(mat)), batch, function(ii) {
            matw[, ii]*pmin(1, ceiling(nbub*length(ii))/rowSums(mat[, ii] > 0))
        }))
        matw <- nbo[, colnames(matw)]

        # center each batch on its weighted mean, then shift up by the weighted dataset mean
        amat <- mat
        nr <- ncol(matw)/rowSums(matw)
        amat.av <- rowMeans(amat*matw)*nr # dataset means
        amat <- do.call(cbind, tapply(seq_len(ncol(amat)), batch, function(ii) {
            amat[, ii]-(rowMeans(amat[, ii]*matw[, ii]*nr, na.rm = TRUE))
        }))
        amat <- amat[, colnames(matw)]
        mat <- amat+amat.av
    }

    # center (no batch)
    mat <- weightedMatCenter(mat, matw)
    mat <- mat*sqrt(vr)

    if(use.batch) {
        return(list(mat = mat, matw = matw, arv = arv, modes = modes, avmodes = avmodes, prior = prior, edf = edf, batch = batch, trim = trim, fit.genes = fit.genes, bwvar.ratio = bwvar.ratio))
    } else {
        return(list(mat = mat, matw = matw, arv = arv, modes = modes, avmodes = avmodes, prior = prior, edf = edf, batch = batch, trim = trim, fit.genes = fit.genes))
    }
}


##' Control for a particular aspect of expression heterogeneity in a given population
##'
##' Similar to subtracting the n-th principal component, the current procedure determines the
##' (weighted) projection of the expression matrix onto a specified aspect (some pattern
##' across cells, for instance sequencing depth, or a PC corresponding to an undesired process
##' such as ribosomal pathway variation) and subtracts it from the data so that it is controlled
##' for in the subsequent weighted PCA analysis.
##'
##' @param varinfo normalized variance info (from pagoda.varnorm())
##' @param aspect a numeric (or logical) vector giving a cell-to-cell variation pattern that should be controlled for (length should correspond to ncol(varinfo$mat)). The pattern must contain only finite values and must not be constant across cells.
##' @param center whether the matrix should be re-centered following pattern subtraction
##'
##' @return a modified varinfo object with adjusted expression matrix (varinfo$mat)
##'
##' @examples
##' data(pollen)
##' cd <- clean.counts(pollen)
##' \donttest{
##' knn <- knn.error.models(cd, k=ncol(cd)/4, n.cores=10, min.count.threshold=2, min.nonfailed=5, max.model.plots=10)
##' varinfo <- pagoda.varnorm(knn, counts = cd, trim = 3/ncol(cd), max.adj.var = 5, n.cores = 1, plot = FALSE)
##' # create go environment
##' library(org.Hs.eg.db)
##' # translate gene names to ids
##' ids <- unlist(lapply(mget(rownames(cd), org.Hs.egALIAS2EG, ifnotfound = NA), function(x) x[1]))
##' rids <- names(ids); names(rids) <- ids
##' go.env <- lapply(mget(ls(org.Hs.egGO2ALLEGS), org.Hs.egGO2ALLEGS), function(x) as.character(na.omit(rids[x])))
##' # clean GOs
##' go.env <- clean.gos(go.env)
##' # convert to an environment
##' go.env <- list2env(go.env)
##' # subtract the pattern
##' cc.pattern <- pagoda.show.pathways(ls(go.env)[1:2], varinfo, go.env, show.cell.dendrogram = TRUE, showRowLabels = TRUE)  # Look at pattern from 2 GO annotations
##' varinfo.cc <- pagoda.subtract.aspect(varinfo, cc.pattern)
##' }
##'
##' @export
pagoda.subtract.aspect <- function(varinfo, aspect, center = TRUE) {
    if(length(aspect) != ncol(varinfo$mat)) { stop("aspect should be a numeric vector of the same length as the number of cells (i.e. ncol(varinfo$mat))") }
    if(!(is.numeric(aspect) || is.logical(aspect)) || any(!is.finite(aspect))) {
        stop("aspect should contain only finite numeric values")
    }

    # center the pattern and scale it to unit length
    v <- as.numeric(aspect)
    v <- v - mean(v)
    v.len <- sqrt(sum(v^2))
    if(v.len == 0) {
        # a constant pattern has no direction to project onto; normalizing it would
        # divide by zero and silently fill the adjusted matrix with NaN
        stop("aspect should not be constant across cells")
    }
    v <- v/v.len

    # per-gene weighted least-squares coefficient of the expression on the pattern:
    # nr[g] = sum_j(w[g, j]*x[g, j]*v[j]) / sum_j(w[g, j]*v[j]^2)
    nr <- ((varinfo$mat * varinfo$matw) %*% v)/(varinfo$matw %*% v^2)
    # remove the fitted component nr[g]*v[j] from every gene (dimnames are kept from varinfo$mat)
    mat.c <- varinfo$mat - nr %*% t(v)
    if(center) {
        # this commonly re-introduces some background dependency because of the matw
        mat.c <- weightedMatCenter(mat.c, varinfo$matw)
    }
    varinfo$mat <- mat.c
    varinfo
}


##' Run weighted PCA analysis on pre-annotated gene sets
##'
##' For each valid gene set (having appropriate number of genes) in the provided environment (setenv),
##' the method will run weighted PCA analysis, along with analogous analyses of random gene sets of the
##' same size, or shuffled expression magnitudes for the same gene set.
##'
##' @param varinfo adjusted variance info from pagoda.varnorm() (or pagoda.subtract.aspect())
##' @param setenv environment listing gene sets (contains variables with names corresponding to gene set name, and values being vectors of gene names within each gene set)
##' @param n.components number of principal components to determine for each gene set
##' @param n.cores number of cores to use
##' @param min.pathway.size minimum number of observed genes that should be contained in a valid gene set
##' @param max.pathway.size maximum number of observed genes in a valid gene set
##' @param n.randomizations number of random gene sets (of the same size) to be evaluated in parallel with each gene set (can be kept at 5 or 10, but should be increased to 50-100 if the significance of pathway overdispersion will be determined relative to random gene set models)
##' @param n.internal.shuffles number of internal (independent row shuffles) randomizations of expression data that should be evaluated for each gene set (needed only if one is interested in gene set coherence P values, disabled by default; set to 10-30 to estimate)
##' @param n.starts number of random starts for the EM method in each evaluation
##' @param center whether the expression matrix should be recentered
##' @param batch.center whether batch-specific centering should be used
##' @param proper.gene.names alternative vector of gene names (replacing rownames(varinfo$mat)) to be used in cases when the provided setenv uses different gene names
##' @param verbose verbosity level
##'
##' @return a list of weighted PCA info with one element per valid gene set, each containing the
##' normalized component patterns (xv), the bwpca() result (xp), the PC1 standard deviations of the
##' random gene sets (z), the component standard deviations (sd) and the number of observed genes (n).
##' The list is returned invisibly.
##'
##' @examples
##' data(pollen)
##' cd <- clean.counts(pollen)
##' \donttest{
##' knn <- knn.error.models(cd, k=ncol(cd)/4, n.cores=10, min.count.threshold=2, min.nonfailed=5, max.model.plots=10)
##' varinfo <- pagoda.varnorm(knn, counts = cd, trim = 3/ncol(cd), max.adj.var = 5, n.cores = 1, plot = FALSE)
##' # create go environment
##' library(org.Hs.eg.db)
##' # translate gene names to ids
##' ids <- unlist(lapply(mget(rownames(cd), org.Hs.egALIAS2EG, ifnotfound = NA), function(x) x[1]))
##' rids <- names(ids); names(rids) <- ids
##' go.env <- lapply(mget(ls(org.Hs.egGO2ALLEGS), org.Hs.egGO2ALLEGS), function(x) as.character(na.omit(rids[x])))
##' # clean GOs
##' go.env <- clean.gos(go.env)
##' # convert to an environment
##' go.env <- list2env(go.env)
##' pwpca <- pagoda.pathway.wPCA(varinfo, go.env, n.components=1, n.cores=10, n.internal.shuffles=50)
##' }
##'
##' @export
pagoda.pathway.wPCA <- function(varinfo, setenv, n.components = 2, n.cores = detectCores(), min.pathway.size = 10, max.pathway.size = 1e3, n.randomizations = 10, n.internal.shuffles = 0, n.starts = 10, center = TRUE, batch.center = TRUE, proper.gene.names = NULL, verbose = 0) {
    mat <- varinfo$mat
    matw <- varinfo$matw
    smooth <- 0
    batch <- if(batch.center) varinfo$batch else NULL

    if(is.null(proper.gene.names)) { proper.gene.names <- rownames(mat) }

    if(center) {
        mat <- weightedMatCenter(mat, matw, batch = batch)
    }

    # remove constant (or NA) rows: they carry no variation
    vi <- apply(mat, 1, function(x) sum(abs(diff(x))) > 0)
    vi[is.na(vi)] <- FALSE
    mat <- mat[vi, , drop = FALSE]
    matw <- matw[vi, , drop = FALSE]
    proper.gene.names <- proper.gene.names[vi]

    # valid gene sets: number of observed genes within [min.pathway.size, max.pathway.size]
    gsl <- ls(envir = setenv)
    gsl.ng <- unlist(lapply(sn(gsl), function(go) sum(unique(get(go, envir = setenv)) %in% proper.gene.names)))
    gsl <- gsl[gsl.ng >= min.pathway.size & gsl.ng <= max.pathway.size]
    names(gsl) <- gsl
    if(verbose) {
        message("processing ", length(gsl), " valid pathways")
    }

    # transpose (cells x genes) so that gene subsets become column selections
    mat <- t(mat)
    matw <- t(matw)
    n.genes.total <- ncol(mat)

    mcm.pc <- papply(gsl, function(x) {
        lab <- proper.gene.names %in% get(x, envir = setenv)
        ngenes <- sum(lab)
        if(ngenes < 1) { return(NULL) }

        xp <- bwpca(mat[, lab, drop = FALSE], matw[, lab, drop = FALSE], npcs = n.components, center = FALSE, nstarts = n.starts, smooth = smooth, n.shuffles = n.internal.shuffles)

        # PC1 standard deviations of random gene sets of the same size
        z <- do.call(rbind, lapply(seq_len(n.randomizations), function(i) {
            si <- sample(seq_len(n.genes.total), ngenes)
            bwpca(mat[, si, drop = FALSE], matw[, si, drop = FALSE], npcs = 1, center = FALSE, nstarts = n.starts, smooth = smooth)$sd
        }))

        # flip each component so that its scores correlate positively with the
        # loading-weighted mean expression of the gene set
        cs <- unlist(lapply(seq_len(ncol(xp$scores)), function(i) sign(cor(xp$scores[, i], colMeans(t(mat[, lab, drop = FALSE])*abs(xp$rotation[, i]))))))
        # an undefined (NA) or exactly zero correlation carries no orientation information:
        # keep such components as they are instead of turning them into NA/0
        cs[is.na(cs) | cs == 0] <- 1

        xp$scores <- t(t(xp$scores)*cs)
        xp$rotation <- t(t(xp$rotation)*cs)

        # local normalization of each component relative to the sampled PC1 variance
        avar <- pmax(0, (xp$sd^2-mean(z[, 1]^2))/sd(z[, 1]^2))
        xv <- t(xp$scores)
        xv <- xv/apply(xv, 1, sd)*sqrt(avar)
        return(list(xv = xv, xp = xp, z = z, sd = xp$sd, n = ngenes))
    }, n.cores = n.cores)

    # the result has always been returned invisibly (large object)
    invisible(mcm.pc)
}


##' Estimate effective number of cells based on lambda1 of random gene sets
##'
##' Examines the dependency between the amount of variance explained by the first principal component
##' of a gene set and the number of genes in a gene set to determine the effective number of cells
##' for the Tracy-Widom distribution
##'
##' @param pwpca result of the pagoda.pathway.wPCA() call with n.randomizations > 1 (gene sets without random gene set results are ignored)
##' @param start optional starting value for the optimization. The optimized parameter is the square root of (effective number of cells - 1/2), constrained to the [1, number of cells] interval; starting values outside that interval are moved to the nearest bound, and the default is the upper bound (if the optimization breaks, trying high starting values usually fixes the local gradient problem)
##'
##' @return effective number of cells
##'
##' @examples
##' data(pollen)
##' cd <- clean.counts(pollen)
##' \donttest{
##' knn <- knn.error.models(cd, k=ncol(cd)/4, n.cores=10, min.count.threshold=2, min.nonfailed=5, max.model.plots=10)
##' varinfo <- pagoda.varnorm(knn, counts = cd, trim = 3/ncol(cd), max.adj.var = 5, n.cores = 1, plot = FALSE)
##' pwpca <- pagoda.pathway.wPCA(varinfo, go.env, n.components=1, n.cores=10, n.internal.shuffles=50)
##' pagoda.effective.cells(pwpca)
##' }
##'
##' @export
pagoda.effective.cells <- function(pwpca, start = NULL) {
    # only gene sets that were evaluated together with random gene sets inform the fit
    has.random <- vapply(pwpca, function(x) NROW(x$z) > 0, logical(1))
    if(!any(has.random)) {
        stop("pwpca contains no random gene set results; run pagoda.pathway.wPCA() with n.randomizations > 1")
    }
    pwpca <- pwpca[has.random]

    # one observation per random gene set: its size and its squared PC1 sd (lambda1)
    n.genes <- unlist(lapply(pwpca, function(x) rep(x$n, nrow(x$z))))
    lambda1 <- unlist(lapply(pwpca, function(x) x$z[, 1]))^2

    n.cells <- nrow(pwpca[[1]]$xp$scores)
    if(is.null(start)) { start <- n.cells*10 }
    # nlminb() is bounded to [1, n.cells]; make the projection of the starting value explicit
    start <- min(max(start, 1), n.cells)

    # sum of squared residuals between observed lambda1 values and the predicted mean PC1 variance
    # of a random set; s.cells = sqrt(effective cells - 1/2), sp = sqrt(set size - 1/2),
    # -1.2065335745820 is the Tracy-Widom mean constant
    of <- function(p, v, sp) {
        s.cells <- p[1]
        vfit <- (s.cells+sp)^2/(s.cells*s.cells+1/2) -1.2065335745820*(s.cells+sp)*((1/s.cells + 1/sp)^(1/3))/(s.cells*s.cells+1/2)
        sum((v-vfit)^2)
    }
    x <- nlminb(objective = of, start = start, v = lambda1, sp = sqrt(n.genes-1/2), lower = 1, upper = n.cells)
    (x$par)^2+1/2
}


##' Determine de-novo gene clusters and associated overdispersion info
##'
##' Determine de-novo gene clusters, their weighted PCA lambda1 values, and random matrix expectation.
##' Correlation and hierarchical clustering use WGCNA::cor and fastcluster::hclust when those packages
##' are installed, and stats::cor / stats::hclust otherwise.
##'
##' @param varinfo varinfo adjusted variance info from pagoda.varnorm() (or pagoda.subtract.aspect())
##' @param trim additional Winsorization trim value to be used in determining clusters (to remove clusters that group outliers occurring in a given cell). Use higher values (5-15) if the resulting clusters group outlier patterns
##' @param n.clusters number of clusters to be determined (recommended range is 100-200)
##' @param cor.method correlation method ("pearson", "spearman") to be used as a distance measure for clustering
##' @param n.samples number of randomly generated matrix samples to test the background distribution of lambda1 on
##' @param n.starts number of wPCA EM algorithm starts at each iteration
##' @param n.internal.shuffles number of internal shuffles to perform (only if interested in set coherence, which is quite high for clusters by definition, disabled by default; set to 10-30 shuffles to estimate)
##' @param n.cores number of cores to use
##' @param verbose verbosity level
##' @param plot whether a plot showing distribution of random lambda1 values should be shown (along with the extreme value distribution fit)
##' @param show.random whether the empirical random gene set values should be shown in addition to the Tracy-Widom analytical approximation
##' @param n.components number of PC to calculate (can be increased if the number of clusters is small and some contain strong secondary patterns - rarely the case)
##' @param method clustering method to be used in determining gene clusters
##' @param secondary.correlation whether clustering should be performed on the correlation of the correlation matrix instead
##' @param n.cells number of cells to use for the randomly generated cluster lambda1 model
##' @param old.results optionally, pass old results to reuse the observed clusters (and, if present, the sampled cluster statistics) instead of recalculating them
##'
##' @return a list containing the following fields:
##' \itemize{
##' \item{clusters} {a list of genes in each cluster}
##' \item{xf} { extreme value distribution fit for the standardized lambda1 of a randomly generated pattern}
##' \item{tci} { index of a top cluster in each random iteration}
##' \item{cl.goc} {weighted PCA info for each real gene cluster}
##' \item{varm} {standardized lambda1 values for each randomly generated matrix cluster}
##' \item{clvlm} {a linear model describing dependency of the cluster lambda1 on a Tracy-Widom lambda1 expectation}
##' \item{trim} {the Winsorization trim value that was used}
##' }
##'
##' @examples
##' data(pollen)
##' cd <- clean.counts(pollen)
##' \donttest{
##' knn <- knn.error.models(cd, k=ncol(cd)/4, n.cores=10, min.count.threshold=2, min.nonfailed=5, max.model.plots=10)
##' varinfo <- pagoda.varnorm(knn, counts = cd, trim = 3/ncol(cd), max.adj.var = 5, n.cores = 1, plot = FALSE)
##' clpca <- pagoda.gene.clusters(varinfo, trim=7.1/ncol(varinfo$mat), n.clusters=150, n.cores=10, plot=FALSE)
##' }
##'
##' @export
pagoda.gene.clusters <- function(varinfo, trim = 3.1/ncol(varinfo$mat), n.clusters = 150, n.samples = 60, cor.method = "p", n.internal.shuffles = 0, n.starts = 10, n.cores = detectCores(), verbose = 0, plot = FALSE, show.random = FALSE, n.components = 1, method = "ward.D", secondary.correlation = FALSE, n.cells = ncol(varinfo$mat), old.results = NULL) {

    smooth <- 0
    mat <- varinfo$mat
    matw <- varinfo$matw
    batch <- varinfo$batch

    # optional accelerated implementations; installed.packages() is slow, so look it up only once
    # instead of inside every sampling iteration
    installed <- installed.packages()[, 1]
    has.wgcna <- is.element("WGCNA", installed)
    has.fastcluster <- is.element("fastcluster", installed)

    # hierarchically cluster the rows of m (genes x cells) using 1-correlation distance
    # (optionally the correlation of the correlation matrix) and cut the tree into
    # n.clusters groups; returns the cutree() membership vector (one entry per row of m)
    cluster.rows <- function(m, n.threads) {
        if(has.wgcna) {
            gd <- as.dist(1-WGCNA::cor(t(m), method = cor.method, nThreads = n.threads))
        } else {
            gd <- as.dist(1-cor(t(m), method = cor.method))
        }
        if(secondary.correlation) {
            if(has.wgcna) {
                gd <- as.dist(1-WGCNA::cor(as.matrix(gd), method = "p", nThreads = n.threads))
            } else {
                gd <- as.dist(1-cor(as.matrix(gd), method = "p"))
            }
        }
        if(has.fastcluster) {
            gcl <- fastcluster::hclust(gd, method = method)
        } else {
            gcl <- stats::hclust(gd, method = method)
        }
        gcll <- cutree(gcl, n.clusters)
        rm(gd, gcl)
        gc()
        gcll
    }

    # indices of non-constant rows (sum of absolute successive differences > 0)
    varying.rows <- function(m) {
        which(apply(m, 1, function(x) sum(abs(diff(x)))) > 0)
    }

    if(trim > 0) {
        mat <- winsorize.matrix(mat, trim = trim)
    }
    if(!is.null(batch)) {
        # center mat by batch
        mat <- weightedMatCenter(mat, matw, batch)
    }

    if(!is.null(old.results)) {
        if(verbose) { cat("reusing old results for the observed clusters\n") }
        gcls <- old.results$clusters
        cl.goc <- old.results$cl.goc
    } else {
        if(verbose) { cat("determining gene clusters ...") }
        # actual clusters
        vi <- varying.rows(mat)
        gcll <- cluster.rows(mat[vi, , drop = FALSE], n.threads = n.cores)
        gcls <- tapply(rownames(mat)[vi], as.factor(gcll), I)
        names(gcls) <- paste("geneCluster", names(gcls), sep = ".")

        # determine PC1 for the actual clusters
        if(verbose) { cat(" cluster PCA ...") }
        il <- tapply(vi, factor(gcll, levels = seq_along(gcls)), I)
        cl.goc <- papply(il, function(ii) {
            xp <- bwpca(t(mat[ii, , drop = FALSE]), t(matw[ii, , drop = FALSE]), npcs = n.components, center = FALSE, nstarts = n.starts, smooth = smooth, n.shuffles = n.internal.shuffles)

            # orient each component to correlate positively with the loading-weighted mean expression;
            # an NA or zero correlation carries no orientation information, so keep such components as is
            cs <- unlist(lapply(seq_len(ncol(xp$scores)), function(i) sign(cor(xp$scores[, i], colMeans(mat[ii, , drop = FALSE]*abs(xp$rotation[, i]))))))
            cs[is.na(cs) | cs == 0] <- 1

            xp$scores <- t(t(xp$scores)*cs)
            xp$rotation <- t(t(xp$rotation)*cs)

            return(list(xp = xp, sd = xp$sd, n = length(ii)))
        }, n.cores = n.cores)
        names(cl.goc) <- paste("geneCluster", names(cl.goc), sep = ".")

        if(verbose) { cat("done\n") }
    }

    # sampled variation: cluster random normal matrices the same way to model the background lambda1
    if(!is.null(old.results) && !is.null(old.results$varm)) {
        if(verbose) { cat("reusing old results for the sampled clusters\n") }
        varm <- old.results$varm
    } else {
        if(verbose) { cat("generating", n.samples, "randomized samples ") }
        varm <- do.call(rbind, papply(seq_len(n.samples), function(i) { # each sampling iteration
            set.seed(i)
            # random normal matrix with the same number of genes and n.cells cells
            m <- matrix(rnorm(nrow(mat)*n.cells), nrow = nrow(mat), ncol = n.cells)

            if(show.random) {
                full.m <- t(m) # save untrimmed version of m for random gene set controls
            }

            if(trim > 0) {
                m <- winsorize.matrix(m, trim = trim)
            }

            # same non-constant row filter as for the observed genes
            vi <- varying.rows(m)
            gcll <- cluster.rows(m[vi, , drop = FALSE], n.threads = 1)

            # transpose (cells x genes) to save time
            m <- t(m)

            sdv <- tapply(vi, gcll, function(ii) {
                pcaMethods::sDev(pcaMethods::pca(m[, ii, drop = FALSE], nPcs = 1, center = FALSE))^2
            })

            pathsizes <- unlist(tapply(vi, gcll, length))
            names(pathsizes) <- pathsizes

            if(show.random) {
                # PC1 variance of random (unclustered) gene sets of the same sizes
                rsdv <- unlist(lapply(names(pathsizes), function(s) {
                    ri <- sample(seq_len(ncol(full.m)), as.integer(s))
                    pcaMethods::sDev(pcaMethods::pca(full.m[, ri, drop = FALSE], nPcs = 1, center = FALSE))^2
                }))
                if(verbose) { cat(".") }
                return(data.frame(n = as.integer(pathsizes), var = unlist(sdv), round = i, rvar = rsdv))
            }

            if(verbose) { cat(".") }
            data.frame(n = as.integer(pathsizes), var = unlist(sdv), round = i)
        }, n.cores = n.cores))
        if(verbose) { cat("done\n") }
    }

    # score relative to the Tracy-Widom distribution
    x <- RMTstat::WishartMaxPar(n.cells, varm$n)
    varm$pm <- x$centering-(1.2065335745820)*x$scaling # predicted mean of a random set
    # NOTE(review): the Tracy-Widom variance would scale with scaling^2; kept as is because
    # pagoda.top.aspects() standardizes cluster scores with the same definition
    varm$pv <- (1.607781034581)*x$scaling # predicted variance of a random set
    clvlm <- lm(var~0+pm+n, data = varm)
    varm$varst <- (varm$var-predict(clvlm))/sqrt(varm$pv)
    # index of the top (highest standardized lambda1) cluster in each sampling round
    tci <- tapply(seq_len(nrow(varm)), as.factor(varm$round), function(ii) ii[which.max(varm$varst[ii])])

    xf <- extRemes::fevd(varm$varst, type = "Gumbel") # fit on all clusters

    if(plot) {
        require(extRemes)
        par(mfrow = c(1, 2), mar = c(3.5, 3.5, 3.5, 1.0), mgp = c(2, 0.65, 0), cex = 0.9)
        smoothScatter(varm$n, varm$var, main = "simulations", xlab = "cluster size", ylab = "PC1 variance")
        if(show.random) {
            points(varm$n, varm$rvar, pch = ".", col = "red")
        }
        on <- order(varm$n, decreasing = TRUE)
        lines(varm$n[on], predict(clvlm)[on], col = 4, lty = 3)
        lines(varm$n[on], varm$pm[on], col = 2)
        lines(varm$n[on], (varm$pm+1.96*sqrt(varm$pv))[on], col = 2, lty = 2)
        lines(varm$n[on], (varm$pm-1.96*sqrt(varm$pv))[on], col = 2, lty = 2)
        legend(x = "bottomright", pch = c(1, 19, 19), col = c(1, 4, 2), legend = c("top clusters", "clusters", "random"), bty = "n")

        points(varm$n[tci], varm$var[tci], col = 1)
        extRemes::plot.fevd(xf, type = "density", main = "Gumbel fit")
        abline(v = 0, lty = 3, col = 4)
    }

    return(list(clusters = gcls, xf = xf, tci = tci, cl.goc = cl.goc, varm = varm, clvlm = clvlm, trim = trim))
}


##' Score statistical significance of gene set and cluster overdispersion
##'
##' Evaluates statistical significance of the gene set and cluster lambda1 values, returning
##' either a text table of Z scores, etc, a structure containing normalized values of significant
##' aspects, or a set of genes underlying the significant aspects.
##'
##' @param pwpca output of pagoda.pathway.wPCA()
##' @param clpca output of pagoda.gene.clusters() (optional)
##' @param n.cells effective number of cells (if not provided, will be determined using pagoda.effective.cells())
##' @param z.score Z score to be used as a cutoff for statistically significant patterns (defaults to the two-sided 0.05 P-value threshold, qnorm(0.05/2, lower.tail = FALSE))
##' @param return.table whether a text table of the significant aspects (sorted by score) should be returned
##' @param return.genes whether a set of genes driving significant aspects should be returned
##' @param plot whether to plot the cv/n vs. dataset size scatter showing significance models
##' @param adjust.scores whether the z.score cutoff should be applied to the multiple-testing adjusted Z scores (adj.z, computed with bh.adjust()) instead of the raw Z scores
##' @param score.alpha significance level of the confidence interval for determining upper/lower bounds
##' @param use.oe.scale whether the variance of the returned aspect patterns should be normalized using observed/expected value instead of the default chi-squared derived variance corresponding to overdispersion Z score
##' @param effective.cells.start starting value for the pagoda.effective.cells() call
##'
##' @return if return.table = TRUE, a data frame (columns name, npc, n, score, z, adj.z, sh.z, adj.sh.z) of the
##' significant aspects sorted by decreasing score; otherwise, if return.genes = TRUE, a vector giving for each gene
##' driving a significant aspect its maximum absolute loading; otherwise (default) a list structure containing the following items:
##' \itemize{
##' \item{xv} {a matrix of normalized aspect patterns (rows- significant aspects, columns- cells}
##' \item{xvw} { corresponding weight matrix }
##' \item{gw} { set of genes driving the significant aspects }
##' \item{df} { text table with the significance testing results }
##' }
##'
##' @examples
##' data(pollen)
##' cd <- clean.counts(pollen)
##' \donttest{
##' knn <- knn.error.models(cd, k=ncol(cd)/4, n.cores=10, min.count.threshold=2, min.nonfailed=5, max.model.plots=10)
##' varinfo <- pagoda.varnorm(knn, counts = cd, trim = 3/ncol(cd), max.adj.var = 5, n.cores = 1, plot = FALSE)
##' pwpca <- pagoda.pathway.wPCA(varinfo, go.env, n.components=1, n.cores=10, n.internal.shuffles=50)
##' tam <- pagoda.top.aspects(pwpca, return.table = TRUE, plot=FALSE, z.score=1.96)  # top aspects based on GO only
##' }
##'
##' @export
pagoda.top.aspects <- function(pwpca, clpca = NULL, n.cells = NULL, z.score = qnorm(0.05/2, lower.tail = FALSE), return.table = FALSE, return.genes = FALSE, plot = FALSE, adjust.scores = TRUE, score.alpha = 0.05, use.oe.scale = FALSE, effective.cells.start = NULL) {
    basevar <- 1

    if(is.null(n.cells)) {
        n.cells <- pagoda.effective.cells(pwpca, start = effective.cells.start)
    }

    # one row per (gene set, principal component): observed variance, set size, and (when internal
    # shuffles were run) the Z score of the variance relative to the shuffled variances
    vdf <- data.frame(do.call(rbind, lapply(seq_along(pwpca), function(i) {
        vars <- as.numeric((pwpca[[i]]$sd)^2)
        shz <- NA
        if(!is.null(pwpca[[i]]$xp$randvar)) { shz <- (vars - mean(pwpca[[i]]$xp$randvar))/sd(pwpca[[i]]$xp$randvar) }
        cbind(i = i, var = vars, n = pwpca[[i]]$n, npc = seq_len(ncol(pwpca[[i]]$xp$scores)), shz = shz)
    })))

    # add right tail approximation to ptw, which gives up quite early
    pWishartMaxFixed <- function (q, ndf, pdim, var = 1, beta = 1, lower.tail = TRUE) {
        params <- RMTstat::WishartMaxPar(ndf, pdim, var, beta)
        q.tw <- (q - params$centering)/(params$scaling)
        p <- RMTstat::ptw(q.tw, beta, lower.tail, log.p = TRUE)
        p[p == -Inf] <- pgamma((2/3)*q.tw[p == -Inf]^(3/2), 2/3, lower.tail = FALSE, log.p = TRUE) + lgamma(2/3) + log((2/3)^(1/3))
        p
    }

    # variance shift correction (currently disabled: both terms are 0)
    vshift <- 0
    ev <- 0
    vdf$var <- vdf$var-(vshift-ev)*vdf$n

    # expected (median) lambda1, upper-tail Z scores, adjusted Z scores and upper bounds under the
    # Tracy-Widom (largest Wishart eigenvalue) model
    vdf$exp <- RMTstat::qWishartMax(0.5, n.cells, vdf$n, var = basevar, lower.tail = FALSE)
    vdf$z <- qnorm(pWishartMaxFixed(vdf$var, n.cells, vdf$n, lower.tail = FALSE, var = basevar), lower.tail = FALSE, log.p = TRUE)
    vdf$cz <- qnorm(bh.adjust(pnorm(as.numeric(vdf$z), lower.tail = FALSE, log.p = TRUE), log = TRUE), lower.tail = FALSE, log.p = TRUE)
    vdf$ub <- RMTstat::qWishartMax(score.alpha/2, n.cells, vdf$n, var = basevar, lower.tail = FALSE)
    vdf$ub.stringent <- RMTstat::qWishartMax(score.alpha/nrow(vdf)/2, n.cells, vdf$n, var = basevar, lower.tail = FALSE)

    if(!is.null(clpca)) {
        # de-novo clusters are scored against their own Gumbel model of standardized random-cluster lambda1
        clpca$xf <- extRemes::fevd(varst, data = clpca$varm, type = "Gumbel")
        clpca$xf$results$par <- c(clpca$xf$results$par, c(shape = 0))

        clvdf <- data.frame(do.call(rbind, lapply(seq_along(clpca$cl.goc), function(i)  {
            vars <- as.numeric((clpca$cl.goc[[i]]$sd)^2)
            shz <- NA
            if(!is.null(clpca$cl.goc[[i]]$xp$randvar)) {
                shz <- (vars - mean(clpca$cl.goc[[i]]$xp$randvar))/sd(clpca$cl.goc[[i]]$xp$randvar)
            }
            cbind(i = i, var = vars, n = clpca$cl.goc[[i]]$n, npc = seq_len(ncol(clpca$cl.goc[[i]]$xp$scores)), shz = shz)
        })))

        clvdf$var <- clvdf$var-(vshift-ev)*clvdf$n

        x <- RMTstat::WishartMaxPar(n.cells, clvdf$n)
        clvdf$pm <- x$centering-(1.2065335745820)*x$scaling # predicted mean of a random set
        clvdf$pv <- (1.607781034581)*x$scaling # predicted variance of a random set
        pvar <- predict(clpca$clvlm, newdata = clvdf)
        clvdf$varst <- (clvdf$var-pvar)/sqrt(clvdf$pv)
        clvdf$exp <- clpca$xf$results$par[1]*sqrt(clvdf$pv)+pvar

        lp <- pgev.upper.log(clvdf$varst, clpca$xf$results$par[1], clpca$xf$results$par[2], rep(clpca$xf$results$par[3], nrow(clvdf)))
        clvdf$z <- qnorm(lp, lower.tail = FALSE, log.p = TRUE)
        clvdf$cz <- qnorm(bh.adjust(pnorm(as.numeric(clvdf$z), lower.tail = FALSE, log.p = TRUE), log = TRUE), lower.tail = FALSE, log.p = TRUE)

        # CI relative to the background
        clvdf$ub <- extRemes::qevd(score.alpha/2, loc = clpca$xf$results$par[1], scale = clpca$xf$results$par[2], shape = clpca$xf$results$par[3], lower.tail = FALSE)*sqrt(clvdf$pv) + pvar
        clvdf$ub.stringent <- extRemes::qevd(score.alpha/2/nrow(clvdf), loc = clpca$xf$results$par[1], scale = clpca$xf$results$par[2], shape = clpca$xf$results$par[3], lower.tail = FALSE)*sqrt(clvdf$pv) + pvar
    }

    if(plot) {
        par(mfrow = c(1, 1), mar = c(3.5, 3.5, 1.0, 1.0), mgp = c(2, 0.65, 0))
        on <- order(vdf$n, decreasing = FALSE)
        pccol <- colorRampPalette(c("black", "grey70"), space = "Lab")(max(vdf$npc))
        plot(vdf$n, vdf$var/vdf$n, xlab = "gene set size", ylab = "PC1 var/n", ylim = c(0, max(vdf$var/vdf$n)), col = pccol[vdf$npc])
        lines(vdf$n[on], (vdf$exp/vdf$n)[on], col = 2, lty = 1)
        lines(vdf$n[on], (vdf$ub.stringent/vdf$n)[on], col = 2, lty = 2)

        if(!is.null(clpca)) {
            pccol <- colorRampPalette(c("darkgreen", "lightgreen"), space = "Lab")(max(clvdf$npc))
            points(clvdf$n, clvdf$var/clvdf$n, col = pccol[clvdf$npc], pch = 1)

            on <- order(clvdf$n, decreasing = FALSE)
            lines(clvdf$n[on], (clvdf$exp/clvdf$n)[on], col = "darkgreen")
            lines(clvdf$n[on], (clvdf$ub.stringent/clvdf$n)[on], col = "darkgreen", lty = 2)
        }
    }

    if(!is.null(clpca)) { # merge in cluster stats based on their own model
        # all processing from here is common
        clvdf$i <- clvdf$i+length(pwpca) # shift cluster ids
        pwpca <- c(pwpca, clpca$cl.goc)
        vdf <- rbind(vdf, clvdf[, c("i", "var", "n", "npc", "exp", "cz", "z", "ub", "ub.stringent", "shz")])
    }

    vdf$adj.shz <- qnorm(bh.adjust(pnorm(as.numeric(vdf$shz), lower.tail = FALSE, log.p = TRUE), log = TRUE), lower.tail = FALSE, log.p = TRUE)
    # observed/expected lambda1 ratio (rs is the, currently zero, variance shift)
    rs <- (vshift-ev)*vdf$n
    vdf$oe <- (vdf$var+rs)/(vdf$exp+rs)

    df <- data.frame(name = names(pwpca)[vdf$i], npc = vdf$npc, n = vdf$n, score = vdf$oe, z = vdf$z, adj.z = vdf$cz, sh.z = vdf$shz, adj.sh.z = vdf$adj.shz, stringsAsFactors = FALSE)
    if(adjust.scores) {
        vdf$valid <- vdf$cz  >=  z.score
    } else {
        vdf$valid <- vdf$z  >=  z.score
    }
    # an undefined score cannot be called significant (NA would otherwise yield NA rows below)
    vdf$valid[is.na(vdf$valid)] <- FALSE

    if(return.table) {
        df <- df[vdf$valid, ]
        df <- df[order(df$score, decreasing = TRUE), ]
        return(df)
    }

    # determine genes driving significant pathways:
    # keep genes whose absolute loading is at least 1/3 of the maximum absolute loading
    gl <- lapply(which(vdf$valid), function(i) {
        s <- abs(pwpca[[vdf[i, "i"]]]$xp$rotation[, vdf[i, "npc"]])
        s[s >= max(s)/3]
    })
    gw <- tapply(abs(unlist(gl)), as.factor(unlist(lapply(gl, names))), max)
    if(return.genes) {
        return(gw)
    }

    # return combined data structure
    vi <- vdf$valid

    # weights, normalized per aspect; drop = FALSE keeps a matrix when a single aspect is significant
    xvw <- do.call(rbind, lapply(pwpca, function(x) t(x$xp$scoreweights)))
    xvw <- xvw[vi, , drop = FALSE]
    xvw <- xvw/rowSums(xvw)

    # centered and scaled patterns
    xmv <- do.call(rbind, lapply(pwpca, function(x) t(x$xp$scores)))
    xmv <- xmv[vi, , drop = FALSE]
    xmc <- xmv - rowMeans(xmv)
    if(use.oe.scale) {
        xmv <- xmc * (as.numeric(vdf$oe[vi])/sqrt(apply(xmv, 1, var)))
    } else {
        # chi-squared
        xmv <- xmc * sqrt((qchisq(pnorm(vdf$z[vi], lower.tail = FALSE, log.p = TRUE), n.cells, lower.tail = FALSE, log.p = TRUE)/n.cells)/apply(xmv, 1, var))
    }
    if(nrow(xmv) > 0) {
        rownames(xmv) <- paste("#PC", vdf$npc[vi], "# ", names(pwpca)[vdf$i[vi]], sep = "")
    }

    return(list(xv = xmv, xvw = xvw, gw = gw, df = df))
}


##' Collapse aspects driven by the same combinations of genes
##'
##' Examines PC loading vectors underlying the identified aspects and clusters aspects based
##' on a combination of loading-based and score-based similarity: the two similarities are
##' combined by their geometric mean, and the resulting distance is raised to corr.power.
##' Clusters of aspects driven by the same genes are determined by cutting the clustering tree
##' at distance.threshold, and are then collapsed.
##'
##' @param tam output of pagoda.top.aspects()
##' @param pwpca output of pagoda.pathway.wPCA()
##' @param clpca output of pagoda.gene.clusters() (optional)
##' @param plot whether to plot the resulting clustering
##' @param cluster.method one of the standard clustering methods to be used (fastcluster::hclust is used if available or stats::hclust)
##' @param distance.threshold height at which the aspect clustering tree is cut; aspects joined below this height are collapsed together
##' @param corr.power power to which the combined (loading and score) distance is raised
##' @param abs Boolean of whether to use absolute score correlation (if FALSE, negative score correlations are treated as zero)
##' @param n.cores number of cores to use during processing
##' @param ... additional arguments are passed to the view.aspects() method during plotting
##'
##' @return a list structure analogous to that returned by pagoda.top.aspects(), but with addition of a $cnam element containing a list of aspects summarized by each row of the new (reduced) $xv and $xvw
##'
##' @examples
##' data(pollen)
##' cd <- clean.counts(pollen)
##' \donttest{
##' knn <- knn.error.models(cd, k=ncol(cd)/4, n.cores=10, min.count.threshold=2, min.nonfailed=5, max.model.plots=10)
##' varinfo <- pagoda.varnorm(knn, counts = cd, trim = 3/ncol(cd), max.adj.var = 5, n.cores = 1, plot = FALSE)
##' pwpca <- pagoda.pathway.wPCA(varinfo, go.env, n.components=1, n.cores=10, n.internal.shuffles=50)
##' tam <- pagoda.top.aspects(pwpca, return.table = TRUE, plot=FALSE, z.score=1.96)  # top aspects based on GO only
##' tamr <- pagoda.reduce.loading.redundancy(tam, pwpca)
##' }
##'
##' @export
pagoda.reduce.loading.redundancy <- function(tam, pwpca, clpca = NULL, plot = FALSE, cluster.method = "complete", distance.threshold = 0.01, corr.power = 4, n.cores = detectCores(), abs = TRUE, ...) {
    # loading-based distance between aspects (de novo gene clusters appended to the pathways, if given)
    pclc <- pathway.pc.correlation.distance(c(pwpca, clpca$cl.goc), tam$xv, target.ndf = 100, n.cores = n.cores)

    # score-based distance between aspects
    cda <- cor(t(tam$xv))
    if(abs) {
        # the logical argument `abs` does not mask the function: R skips non-function bindings for calls
        cda <- abs(cda)
    } else {
        cda[cda < 0] <- 0
    }
    cda <- as.dist(1-cda)

    # combine both similarities by geometric mean, then sharpen the distance by corr.power
    cc <- (1-sqrt((1-pclc)*(1-cda)))^corr.power

    # requireNamespace() is a cheap availability check; installed.packages() scans every library
    if(requireNamespace("fastcluster", quietly = TRUE)) {
        y <- fastcluster::hclust(cc, method = cluster.method)
    } else {
        y <- stats::hclust(cc, method = cluster.method)
    }
    ct <- cutree(y, h = distance.threshold)
    ctf <- factor(ct, levels = sort(unique(ct)))
    xvl <- collapse.aspect.clusters(tam$xv, tam$xvw, ct, pick.top = FALSE, scale = TRUE)

    if(plot) {
        sc <- sample(colors(), length(levels(ctf)), replace = TRUE)
        view.aspects(tam$xv, row.clustering = y, row.cols = sc[as.integer(ctf)], ...)
    }

    # collapsed names: which original aspects each collapsed row summarizes
    if(!is.null(tam$cnam)) { # already has collapsed names -> expand them
        cnam <- tapply(rownames(tam$xv), ctf, function(xn) unlist(tam$cnam[xn]))
    } else {
        cnam <- tapply(rownames(tam$xv), ctf, I)
    }
    names(cnam) <- rownames(xvl$d)
    tam$xv <- xvl$d
    tam$xvw <- xvl$w
    tam$cnam <- cnam
    return(tam)
}


##' Collapse aspects driven by similar patterns (i.e. separate the same sets of cells)
##'
##' Clusters the aspects in tamr$xv based on the (optionally weighted) correlation of their patterns.
##' Clusters of aspects driven by the same patterns are determined by cutting the clustering tree at
##' distance.threshold, collapsed, and optionally restricted to the top most variable collapsed aspects.
##'
##' @param tamr output of pagoda.reduce.loading.redundancy()
##' @param distance.threshold height at which the aspect clustering tree is cut; aspects joined below this height are collapsed together
##' @param cluster.method one of the standard clustering methods to be used (fastcluster::hclust is used if available or stats::hclust)
##' @param distance optional precomputed distance (dist object) between the rows of tamr$xv; if NULL it is computed as 1 - correlation
##' @param weighted.correlation Boolean of whether to use a weighted correlation (weights tamr$xvw) in determining the similarity of patterns
##' @param plot Boolean of whether to show plot
##' @param top Restrict output to the top n aspects of heterogeneity (by variance of the collapsed pattern)
##' @param trim Winsorization trim to use prior to determining the top aspects (the trimmed patterns are returned)
##' @param abs Boolean of whether to use absolute correlation
##' @param ... additional arguments are passed to the view.aspects() method during plotting
##'
##' @return a list structure analogous to that returned by pagoda.top.aspects(), but with addition of a $cnam element containing a list of aspects summarized by each row of the new (reduced) $xv and $xvw
##'
##' @examples
##' data(pollen)
##' cd <- clean.counts(pollen)
##' \donttest{
##' knn <- knn.error.models(cd, k=ncol(cd)/4, n.cores=10, min.count.threshold=2, min.nonfailed=5, max.model.plots=10)
##' varinfo <- pagoda.varnorm(knn, counts = cd, trim = 3/ncol(cd), max.adj.var = 5, n.cores = 1, plot = FALSE)
##' pwpca <- pagoda.pathway.wPCA(varinfo, go.env, n.components=1, n.cores=10, n.internal.shuffles=50)
##' tam <- pagoda.top.aspects(pwpca, return.table = TRUE, plot=FALSE, z.score=1.96)  # top aspects based on GO only
##' tamr <- pagoda.reduce.loading.redundancy(tam, pwpca)
##' tamr2 <- pagoda.reduce.redundancy(tamr, distance.threshold = 0.9, plot = TRUE, labRow = NA, labCol = NA, box = TRUE, margins = c(0.5, 0.5), trim = 0)
##' }
##'
##' @export
pagoda.reduce.redundancy <- function(tamr, distance.threshold = 0.2, cluster.method = "complete", distance = NULL, weighted.correlation = TRUE, plot = FALSE, top = Inf, trim = 0, abs = FALSE, ...) {
    if(is.null(distance)) {
        if(weighted.correlation) {
            # weighted correlation between aspect patterns, computed in compiled code
            distance <- .Call("matWCorr", t(tamr$xv), t(tamr$xvw), PACKAGE = "scde")
            rownames(distance) <- colnames(distance) <- rownames(tamr$xv)
            if(abs) {
                distance <- stats::as.dist(1-abs(distance), upper = TRUE)
            } else {
                distance <- stats::as.dist(1-distance, upper = TRUE)
            }
        } else {
            if(abs) {
                distance <- stats::as.dist(1-abs(cor(t(tamr$xv))))
            } else {
                distance <- stats::as.dist(1-cor(t(tamr$xv)))
            }
        }
    }
    # requireNamespace() is a cheap availability check; installed.packages() scans every library
    if(requireNamespace("fastcluster", quietly = TRUE)) {
        y <- fastcluster::hclust(distance, method = cluster.method)
    } else {
        y <- stats::hclust(distance, method = cluster.method)
    }

    ct <- cutree(y, h = distance.threshold)
    ctf <- factor(ct, levels = sort(unique(ct)))
    xvl <- collapse.aspect.clusters(tamr$xv, tamr$xvw, ct, pick.top = FALSE, scale = TRUE)

    if(plot) {
        sc <- sample(colors(), length(levels(ctf)), replace = TRUE)
        view.aspects(tamr$xv, row.clustering = y, row.cols = sc[as.integer(ctf)], ...)
    }

    # collapsed names: which original aspects each collapsed row summarizes
    if(!is.null(tamr$cnam)) { # already has collapsed names -> expand them
        cnam <- tapply(rownames(tamr$xv), ctf, function(xn) unlist(tamr$cnam[xn]))
    } else {
        cnam <- tapply(rownames(tamr$xv), ctf, I)
    }
    names(cnam) <- rownames(xvl$d)

    if(trim > 0) { xvl$d <- winsorize.matrix(xvl$d, trim) } # trim prior to determining the top sets

    # keep the `top` collapsed aspects with the highest variance; seq_len() yields none for top = 0
    rcmvar <- apply(xvl$d, 1, var)
    vi <- order(rcmvar, decreasing = TRUE)[seq_len(min(length(rcmvar), top))]

    tamr2 <- tamr
    # drop = FALSE keeps $xv/$xvw as matrices even when a single aspect is retained
    tamr2$xv <- xvl$d[vi, , drop = FALSE]
    tamr2$xvw <- xvl$w[vi, , drop = FALSE]
    tamr2$cnam <- cnam[vi]
    return(tamr2)
}


##' Determine optimal cell clustering based on the genes driving the significant aspects
##'
##' Determines cell clustering (hclust result) based on a weighted correlation of genes
##' underlying the top aspects of transcriptional heterogeneity. Branch orientation is optimized
##' if 'cba' package is installed.
##'
##' @param tam result of pagoda.top.aspects() call
##' @param varinfo result of pagoda.varnorm() call
##' @param method clustering method ('ward.D' by default)
##' @param verbose 0 or 1 depending on level of desired verbosity
##' @param include.aspects whether the aspect patterns themselves should be included alongside with the individual genes in calculating cell distance
##' @param return.details Boolean of whether to return just the hclust result or a list containing the hclust result plus the distance matrix and gene values
##'
##' @return hclust result; if return.details is TRUE, a list with elements clustering (hclust result), distance (dist object) and genes (gene weights used)
##'
##' @examples
##' data(pollen)
##' cd <- clean.counts(pollen)
##' \donttest{
##' knn <- knn.error.models(cd, k=ncol(cd)/4, n.cores=10, min.count.threshold=2, min.nonfailed=5, max.model.plots=10)
##' varinfo <- pagoda.varnorm(knn, counts = cd, trim = 3/ncol(cd), max.adj.var = 5, n.cores = 1, plot = FALSE)
##' pwpca <- pagoda.pathway.wPCA(varinfo, go.env, n.components=1, n.cores=10, n.internal.shuffles=50)
##' tam <- pagoda.top.aspects(pwpca, return.table = TRUE, plot=FALSE, z.score=1.96)  # top aspects based on GO only
##' hc <- pagoda.cluster.cells(tam, varinfo)
##' plot(hc)
##' }
##'
##' @export
pagoda.cluster.cells <- function(tam, varinfo, method = "ward.D", include.aspects = FALSE, verbose = 0, return.details = FALSE) {
    # genes driving the top aspects, restricted to those whose weighted variance sum exceeds 1
    gw <- tam$gw
    gw <- gw[(rowSums(varinfo$matw)*varinfo$arv)[names(gw)] > 1]

    # gw/gw sets every retained gene's weight to 1 (i.e. genes are not weighted by loading)
    gw <- gw/gw
    mi <- match(names(gw), rownames(varinfo$mat))
    # drop = FALSE keeps a gene x cell matrix even when only one gene is retained
    wgm <- varinfo$mat[mi, , drop = FALSE]
    wgm <- wgm*as.numeric(gw)
    wgwm <- varinfo$matw[mi, , drop = FALSE]

    if(include.aspects) {
        if(verbose) { message("clustering cells based on ", nrow(wgm), " genes and ", nrow(tam$xv), " aspect patterns")}
        wgm <- rbind(wgm, tam$xv)
        wgwm <- rbind(wgwm, tam$xvw)
    } else {
        if(verbose) { message("clustering cells based on ", nrow(wgm), " genes")}
    }

    # cell-cell distance = 1 - weighted correlation (computed in compiled code)
    dm <- .Call("matWCorr", wgm, wgwm, PACKAGE = "scde")
    dm <- 1-dm
    rownames(dm) <- colnames(dm) <- colnames(wgm)
    wcord <- stats::as.dist(dm, upper = TRUE)
    hc <- hclust(wcord, method = method)

    # optional optimal leaf ordering; requireNamespace() avoids scanning installed.packages()
    if(requireNamespace("cba", quietly = TRUE)) {
        co <- cba::order.optimal(wcord, hc$merge)
        hc$merge <- co$merge
        hc$order <- co$order
    }
    if(return.details) {
        return(list(clustering = hc, distance = wcord, genes = gw))
    } else {
        return(hc)
    }
}


##' View PAGODA output
##'
##' Create static image of PAGODA output visualizing cell hierarchy and top aspects of transcriptional heterogeneity
##'
##' @param tamr Combined pathways that show similar expression patterns. Output of \code{\link{pagoda.reduce.redundancy}}
##' @param row.clustering Dendrogram of combined pathways clustering. By default computed (via hclust of Euclidean distances) on the aspects that remain after the top restriction.
##' @param top Restrict output to the top n aspects of heterogeneity (by pattern variance)
##' @param ... additional arguments are passed to the \code{\link{view.aspects}} method during plotting
##'
##' @return PAGODA heatmap
##'
##' @examples
##' data(pollen)
##' cd <- clean.counts(pollen)
##' \donttest{
##' knn <- knn.error.models(cd, k=ncol(cd)/4, n.cores=10, min.count.threshold=2, min.nonfailed=5, max.model.plots=10)
##' varinfo <- pagoda.varnorm(knn, counts = cd, trim = 3/ncol(cd), max.adj.var = 5, n.cores = 1, plot = FALSE)
##' pwpca <- pagoda.pathway.wPCA(varinfo, go.env, n.components=1, n.cores=10, n.internal.shuffles=50)
##' tam <- pagoda.top.aspects(pwpca, return.table = TRUE, plot=FALSE, z.score=1.96)  # top aspects based on GO only
##' pagoda.view.aspects(tam)
##' }
##'
##' @export
pagoda.view.aspects <- function(tamr, row.clustering = hclust(dist(tamr$xv)), top = Inf, ...) {
    if(is.finite(top)) {
        # keep the `top` aspects with the highest pattern variance
        rcmvar <- apply(tamr$xv, 1, var)
        vi <- order(rcmvar, decreasing = TRUE)[seq_len(min(length(rcmvar), top))]
        # drop = FALSE keeps matrices even when a single aspect is retained
        tamr$xv <- tamr$xv[vi, , drop = FALSE]
        tamr$xvw <- tamr$xvw[vi, , drop = FALSE]
        tamr$cnam <- tamr$cnam[vi]
    }

    # the default row.clustering is a lazy promise evaluated in this frame, so it is
    # computed on the already-subset tamr$xv when view.aspects() first uses it
    view.aspects(tamr$xv, row.clustering = row.clustering, ... )
}


##' View heatmap
##'
##' Internal function to visualize aspects of transcriptional heterogeneity as a heatmap. Used by \code{\link{pagoda.view.aspects}}.
##'
##' @param mat Numeric matrix
##' @param row.clustering Row dendrogram (an hclust result is converted to a dendrogram)
##' @param cell.clustering Column dendrogram (an hclust result is converted to a dendrogram)
##' @param zlim Range of the plotted values, given as c(lower_bound, upper_bound). Values outside this range are Winsorized (clamped). Useful for increasing the contrast of the heatmap visualizations. Default is plus/minus the 95th percentile of mat.
##' @param row.cols  Matrix of row colors. If NULL and show.row.var.colors is TRUE, a single white-to-black track proportional to row variance is shown.
##' @param col.cols  Matrix of column colors. Useful for visualizing cell annotations such as batch labels.
##' @param cols Heatmap colors
##' @param show.row.var.colors Boolean of whether to show row variance as a color track
##' @param top Currently unused; kept for interface compatibility
##' @param ... additional arguments for heatmap plotting
##'
##' @return A heatmap (the value returned by the internal heatmap routine)
##'
view.aspects <- function(mat, row.clustering = NA, cell.clustering = NA, zlim = c(-1, 1)*quantile(mat, probs = 0.95), row.cols = NULL, col.cols = NULL, cols = colorRampPalette(c("darkgreen", "white", "darkorange"), space = "Lab")(1024), show.row.var.colors = TRUE, top = Inf, ...) {
    #row.cols, col.cols are matrices for now
    # row variance is computed before clamping, so the variance track reflects the unclamped values
    rcmvar <- apply(mat, 1, var)
    # Winsorize to zlim (the default zlim is evaluated here, on the unclamped matrix)
    mat[mat < zlim[1]] <- zlim[1]
    mat[mat > zlim[2]] <- zlim[2]
    # inherits() is safe for objects with multiple classes, unlike class(x) == "hclust" inside if()
    if(inherits(row.clustering, "hclust")) { row.clustering <- as.dendrogram(row.clustering) }
    if(inherits(cell.clustering, "hclust")) { cell.clustering <- as.dendrogram(cell.clustering) }
    if(show.row.var.colors) {
        if(is.null(row.cols)) {
            icols <- colorRampPalette(c("white", "black"), space = "Lab")(1024)[1023*(rcmvar/max(rcmvar))+1]
            row.cols <- cbind(var = icols)
        }
    }
    my.heatmap2(mat, Rowv = row.clustering, Colv = cell.clustering, zlim = zlim, RowSideColors = row.cols, ColSideColors = col.cols, col = cols, ...)
}


##' Make the PAGODA app
##'
##' Create an interactive user interface to explore output of PAGODA.
##'
##' @param tamr Combined pathways that show similar expression patterns. Output of \code{\link{pagoda.reduce.redundancy}}
##' @param tam Combined pathways that are driven by the same gene sets. Output of \code{\link{pagoda.reduce.loading.redundancy}}
##' @param varinfo Variance information. Output of \code{\link{pagoda.varnorm}}
##' @param env Gene sets as an environment variable.
##' @param pwpca Weighted PC magnitudes for each gene set provided in the \code{env}. Output of \code{\link{pagoda.pathway.wPCA}}
##' @param clpca Weighted PC magnitudes for de novo gene sets identified by clustering on expression. Output of \code{\link{pagoda.gene.clusters}}
##' @param col.cols  Matrix of column colors. Useful for visualizing cell annotations such as batch labels. Default NULL.
##' @param cell.clustering Cell clustering (hclust result), e.g. output of \code{\link{pagoda.cluster.cells}}. Default NULL, in which case it is computed with \code{pagoda.cluster.cells(tam, varinfo)}.
##' @param row.clustering Clustering of the combined pathways (hclust result). Default NULL, in which case it is computed from Euclidean distances between rows of tamr$xv.
##' @param title Title text to be used in the browser label for the app. Default, set as 'pathway clustering'
##' @param zlim Range of the plotted values, given as c(lower_bound, upper_bound). Values outside this range will be Winsorized. Useful for increasing the contrast of the heatmap visualizations. Default is plus/minus the 95th percentile of tamr$xv.
##'
##' @return PAGODA app
##'
##' @export
make.pagoda.app <- function(tamr, tam, varinfo, env, pwpca, clpca = NULL, col.cols = NULL, cell.clustering = NULL, row.clustering = NULL, title = "pathway clustering", zlim = c(-1, 1)*quantile(tamr$xv, probs = 0.95)) {
    # rcm - xv
    # matvar
    if(is.null(cell.clustering)) {
        cell.clustering <- pagoda.cluster.cells(tam, varinfo)
    }
    if(is.null(row.clustering)) {
        row.clustering <- hclust(dist(tamr$xv))
        row.clustering$order <- rev(row.clustering$order)
    }

    # fct: for each row of tam$xv, the index of the tamr$xv row (collapsed aspect) that summarizes it
    cn <- tamr$cnam
    fct <- rep(seq_along(cn), vapply(cn, length, integer(1)))
    names(fct) <- unlist(cn)
    fct <- fct[rownames(tam$xv)]
    rcm <- tamr$xv
    rownames(rcm) <- as.character(seq_len(nrow(rcm)))
    fres <- list(hvc = cell.clustering, tvc = row.clustering, rcm = rcm, zlim2 = zlim, matvar = apply(tam$xv, 1, sd), ct = fct, matrcmcor = rep(1, nrow(tam$xv)), cols = colorRampPalette(c("darkgreen", "white", "darkorange"), space = "Lab")(1024), colcol = col.cols)

    # gene df, ordered by decreasing weighted variance
    gene.df <- data.frame(var = varinfo$arv*rowSums(varinfo$matw))
    gene.df$gene <- rownames(varinfo$mat)
    gene.df <- gene.df[order(gene.df$var, decreasing = TRUE), ]

    # prepare pathway df; descriptions come from a global myGOTERM environment if the user defined one
    df <- tamr$df
    if(exists("myGOTERM", envir = globalenv())) {
        df$desc <- mget(df$name, get("myGOTERM", envir = globalenv()), ifnotfound = "")
    } else {
        df$desc <- ""
    }
    # floor all Z scores at min.z for display
    min.z <- -9
    df$z[df$z<min.z] <- min.z
    df$adj.z[df$adj.z<min.z] <- min.z
    df$sh.z[df$sh.z<min.z] <- min.z
    df$adj.sh.z[df$adj.sh.z<min.z] <- min.z
    df <- data.frame(id = paste("#PC", df$npc, "# ", df$name, sep = ""), npc = df$npc, n = df$n, score = df$score, Z = df$z, aZ = df$adj.z, sh.Z = df$sh.z, sh.aZ = df$adj.sh.z, name = paste(df$name, df$desc))

    df <- df[order(df$score, decreasing = TRUE), ]

    # merge go.env with de novo gene clusters, if any
    if(!is.null(clpca)) {
        set.env <- list2env(c(as.list(env), clpca$clusters))
    } else {
        set.env <- env
    }
    # the final assignment makes the function return the app object invisibly
    sa <- ViewPagodaApp$new(fres, df, gene.df, varinfo$mat, varinfo$matw, set.env, name = title, trim = 0, batch = varinfo$batch)
}

##################### Internal functions

one.sided.test.id <- function(id, nam1, nam2, ifm, dm, prior, difference.prior = 0.5, bootstrap = TRUE, n.samples = 1e3, show.plots = TRUE, return.posterior = FALSE, return.both = FALSE) {
    gr <- 10^prior$x - 1
    gr[gr<0] <- 0
    lpp <- get.rep.set.general.model.logposteriors(ifm[[nam1]], dm[rep(id, length(gr)), names(ifm[[nam1]])], data.frame(fpm = gr), grid.weight = prior$grid.weight)
    ldp <- get.rep.set.general.model.logposteriors(ifm[[nam2]], dm[rep(id, length(gr)), names(ifm[[nam2]])], data.frame(fpm = gr), grid.weight = prior$grid.weight)

    if(bootstrap) {
        pjp <- do.call(cbind, lapply(seq_along(n.samples), function(i) {
            pjp <- rowSums(lpp[, sample(1:ncol(lpp), replace = TRUE)])
            pjp <- exp(pjp-max(pjp))
            pjp <- pjp/sum(pjp)
            return(pjp)
        }))
        pjp <- rowSums(pjp)
        pjp <- log(pjp/sum(pjp))

        djp <- do.call(cbind, lapply(seq_along(n.samples), function(i) {
            djp <- rowSums(ldp[, sample(1:ncol(ldp), replace = TRUE)])
            djp <- exp(djp-max(djp))
            djp <- djp/sum(djp)
            return(djp)
        }))
        djp <- rowSums(djp)
        djp <- log(djp/sum(djp))
    } else {
        pjp <- rowSums(lpp)
        djp <- rowSums(ldp)
    }

    dpy <- exp(prior$lp+djp)
    mpgr <- sum(exp(prior$lp+pjp+log(c(0, cumsum(dpy)[-length(dpy)])))) # m1
    mpls <- sum(exp(prior$lp+pjp+log(sum(dpy)-cumsum(dpy)))) # m0
    mpls/mpgr

    pjpc <- exp(prior$lp+pjp)
    pjpc <- pjpc/sum(pjpc)
    djpc <- exp(prior$lp+djp)
    djpc <- djpc/sum(djpc)

    if(show.plots || return.posterior || return.both) {
        # calculate log-fold-change posterior
        n <- length(pjpc)
        rp <- c(unlist(lapply(n:2, function(i) sum(pjpc[1:(n-i+1)]*djpc[i:n]))), unlist(lapply(seq_along(n), function(i) sum(pjpc[i:n]*djpc[1:(n-i+1)]))))
        rv <- seq(prior$x[1]-prior$x[length(prior$x)], prior$x[length(prior$x)]-prior$x[1], length = length(prior$x)*2-1)
        fcp <- data.frame(v = rv, p = rp)
    }

    if(show.plots) {
        # show each posterior
        layout(matrix(c(1:3), 3, 1, byrow = TRUE), heights = c(2, 1, 2), widths = c(1), FALSE)
        par(mar = c(2.5, 3.5, 2.5, 3.5), mgp = c(1.5, 0.65, 0), cex = 0.9)
        jpr <- range(c(0, pjpc), na.rm = TRUE)
        pp <- exp(lpp)
        cols <- rainbow(dim(pp)[2], s = 0.8)
        plot(c(), c(), xlim = range(prior$x), ylim = range(c(0, pp)), xlab = "expression level", ylab = "individual posterior", main = nam1)
        lapply(seq_len(ncol(pp)), function(i) lines(prior$x, pp[, i], col = cols[i]))
        legend(x = ifelse(which.max(na.omit(pjpc)) > length(pjpc)/2, "topleft", "topright"), bty = "n", col = cols, legend = colnames(pp), lty = rep(1, dim(pp)[2]))
        par(new = TRUE)
        plot(prior$x, pjpc, axes = FALSE, ylab = "", xlab = "", ylim = jpr, type = 'l', col = 1, lty = 1, lwd = 2)
        axis(4, pretty(jpr, 5), col = 1)
        mtext("joint posterior", side = 4, outer = FALSE, line = 2)

        # ratio plot
        par(mar = c(2.5, 3.5, 0.5, 3.5), mgp = c(1.5, 0.65, 0), cex = 0.9)
        plot(fcp$v, fcp$p, xlab = "log10 expression ratio", ylab = "ratio posterior", type = 'l', lwd = 2, main = "")
        r.mle <- fcp$v[which.max(fcp$p)]
        r.lb <- max(which(cumsum(fcp$p)<0.025))
        r.ub <- min(which(cumsum(fcp$p) > (1-0.025)))
        polygon(c(fcp$v[r.lb], fcp$v[r.lb:r.ub], fcp$v[r.ub]), y = c(-10, fcp$p[r.lb:r.ub], -10), col = "grey90")
        abline(v = r.mle, col = 2, lty = 2)
        abline(v = c(fcp$v[r.ub], fcp$v[r.lb]), col = 2, lty = 3)
        box()
        legend(x = ifelse(r.mle > 0, "topleft", "topright"