# R/rliger.R

#' @import Matrix
#' @importFrom grDevices dev.off pdf
#' @import hdf5r
#' @importFrom methods new
#' @importFrom utils packageVersion
#' @importFrom Rcpp evalCpp
NULL

#' The LIGER Class
#'
#' The liger object is created from two or more single cell datasets. To construct a
#' liger object, the user needs to provide at least two expression (or another
#' single-cell modality) matrices. The class provides functions for data
#' preprocessing, integrative analysis, and visualization.
#'
#' The key slots used in the liger object are described below.
#'
#' @slot raw.data List of raw data matrices, one per experiment/dataset (genes by cells)
#' @slot norm.data List of normalized matrices (genes by cells)
#' @slot scale.data List of scaled matrices (cells by genes)
#' @slot sample.data List of sampled matrices (genes by cells)
#' @slot scale.unshared.data List of scaled matrices of unshared features
#' @slot h5file.info List of HDF5-related information for each input dataset: paths to raw data,
#'       indices, indptr, barcodes, and genes; the pipeline through which the HDF5 file was
#'       formatted (10X, AnnData, etc.); and the type of sampled data (raw, normalized or scaled).
#' @slot cell.data Data frame of cell attributes across all datasets (number of rows equals the
#'   total number of cells across all datasets)
#' @slot var.genes Subset of informative genes shared across datasets to be used in matrix
#'   factorization
#' @slot var.unshared.features Highly variable unshared features selected from each dataset
#' @slot H Cell loading factors (one matrix per dataset, dimensions cells by k)
#' @slot H.norm Normalized cell loading factors (cells across all datasets combined into single
#'   matrix)
#' @slot W Shared gene loading factors (k by genes)
#' @slot V Dataset-specific gene loading factors (one matrix per dataset, dimensions k by genes)
#' @slot A Matrices used for online learning (XH)
#' @slot B Matrices used for online learning (HTH)
#' @slot U Matrices used for unshared matrix factorization
#' @slot tsne.coords Matrix of 2D coordinates obtained from running t-SNE on H.norm or H matrices
#' @slot alignment.clusters Initial joint cluster assignments from shared factor alignment
#' @slot clusters Joint cluster assignments for cells
#' @slot snf List of values associated with shared nearest factor matrix for use in clustering and
#'   alignment (out.summary contains edge weight information between cell combinations)
#' @slot agg.data Data aggregated within clusters
#' @slot parameters List of parameters used throughout analysis
#' @slot version Version of package used to create object
#'
#' @name liger-class
#' @rdname liger-class
#' @aliases liger-class
#' @exportClass liger
#' @useDynLib rliger
liger <- methods::setClass(
  "liger",
  slots = c(
    raw.data = "list",
    norm.data = "list",
    scale.data = "list",
    sample.data = "list",
    scale.unshared.data = "list",
    h5file.info = "list",
    cell.data = "data.frame",
    var.genes = "vector",
    var.unshared.features = "list",
    H = "list",
    H.norm = "matrix",
    W = "matrix",
    V = "list",
    A = "list",
    B = "list",
    U = "list",
    tsne.coords = "matrix",
    alignment.clusters = "factor",
    clusters = "factor",
    agg.data = "list",
    parameters = "list",
    snf = "list",
    version = "ANY"
  )
)
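
# Dimension conventions for the factorization slots (a quick reference sketch
# derived from the slot documentation above, assuming k factors, g variable
# genes, and n_d cells in dataset d):
#   object@H[[d]]  : n_d x k            (cell loadings per dataset)
#   object@H.norm  : (sum of n_d) x k   (normalized cell loadings, all datasets)
#   object@W       : k x g              (shared gene loadings)
#   object@V[[d]]  : k x g              (dataset-specific gene loadings)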

#' show method for liger
#'
#' @param object liger object
#' @name show
#' @aliases show,liger-method
#' @docType methods
#' @rdname show-methods
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl))
#' show(ligerex)
setMethod(
  f = "show",
  signature = "liger",
  definition = function(object) {
    cat(
      "An object of class",
      class(object),
      "\nwith",
      length(object@raw.data),
      "datasets and\n",
      nrow(object@cell.data),
      "total cells."
    )
    invisible(x = NULL)
  }
)

#######################################################################################
#### Data Preprocessing

#' Read 10X alignment data (including V3)
#'
#' This function generates a sparse matrix (genes x cells) from the data generated by 10X's
#' cellranger count pipeline. It can process V2 and V3 data together, producing either a single
#' merged matrix or list of matrices. Also handles multiple data types produced by 10X V3 (Gene
#' Expression, Antibody Capture, CRISPR, CUSTOM).
#'
#' @param sample.dirs List of directories containing either matrix.mtx(.gz) file along with genes.tsv,
#'   (features.tsv), and barcodes.tsv, or outer level 10X output directory (containing outs directory).
#' @param sample.names Vector of names to use for samples (corresponding to sample.dirs)
#' @param merge Whether to merge all matrices of the same data type across samples or leave as list
#'   of matrices (default TRUE).
#' @param num.cells Optional limit on number of cells returned for each sample (only for Gene
#'   Expression data). Retains the cells with the highest numbers of transcripts (default NULL).
#' @param min.umis Minimum UMI threshold for cells (default 0).
#' @param use.filtered Whether to use 10X's filtered data (as opposed to raw). Only relevant for
#'   sample.dirs containing 10X outs directory (default FALSE).
#' @param reference For 10X V<3, specify which reference directory to use if sample.dir is outer
#'   level 10X directory (only necessary if more than one reference used for sequencing).
#'   (default NULL)
#' @param data.type Indicates the protocol of the input data. If not specified, input data will be
#' considered scRNA-seq data (default 'rna', alternatives: 'atac').
#' @param verbose Print messages (TRUE by default)
#'
#' @return List of merged matrices across data types (returns sparse matrix if only one data type
#'   detected), or nested list of matrices organized by sample if merge = FALSE.
#'
#' @importFrom utils read.delim read.table
#'
#' @export
#' @examples
#' \dontrun{
#' # 10X output directory V2 -- contains outs/raw_gene_bc_matrices/<reference>/...
#' sample.dir1 <- "path/to/outer/dir1"
#' # 10X output directory V3 -- for two data types, Gene Expression and CUSTOM
#' sample.dir2 <- "path/to/outer/dir2"
#' dges1 <- read10X(list(sample.dir1, sample.dir2), c("sample1", "sample2"), min.umis = 50)
#' ligerex <- createLiger(expr = dges1[["Gene Expression"]], custom = dges1[["CUSTOM"]])
#' }
read10X <- function(sample.dirs, sample.names, merge = TRUE, num.cells = NULL, min.umis = 0,
                    use.filtered = FALSE, reference = NULL, data.type = "rna", verbose = TRUE) {
  datalist <- list()
  datatypes <- c("Gene Expression")

  if (length(num.cells) == 1) {
    num.cells <- rep(num.cells, length(sample.dirs))
  }
  for (i in seq_along(sample.dirs)) {
    print(paste0("Processing sample ", sample.names[i]))
    sample.dir <- sample.dirs[[i]]
    inner1 <- paste0(sample.dir, "/outs")
    if (dir.exists(inner1)) {
      sample.dir <- inner1
      is_v3 <- dir.exists(paste0(sample.dir, "/filtered_feature_bc_matrix"))
      matrix.prefix <- ifelse(use.filtered, "filtered", "raw")
      if (is_v3) {
        sample.dir <- paste0(sample.dir, "/", matrix.prefix, "_feature_bc_matrix")
      } else {
        if (is.null(reference)) {
          references <- list.dirs(paste0(sample.dir, "/raw_gene_bc_matrices"),
                                  full.names = FALSE,
                                  recursive = FALSE
          )
          if (length(references) > 1) {
            stop("Multiple reference genomes found. Please specify a single one.")
          } else {
            reference <- references[1]
          }
        }
        sample.dir <- paste0(sample.dir, "/", matrix.prefix, "_gene_bc_matrices/", reference)
      }
    } else {
      is_v3 <- file.exists(paste0(sample.dir, "/features.tsv.gz"))
    }
    suffix <- ifelse(is_v3, ".gz", "")
    if (data.type == "rna") {
      features.file <- ifelse(is_v3, paste0(sample.dir, "/features.tsv.gz"),
                              paste0(sample.dir, "/genes.tsv")
      )
    } else if (data.type == "atac") {
      features.file <- ifelse(is_v3, paste0(sample.dir, "/peaks.bed.gz"),
                              paste0(sample.dir, "/peaks.bed")
      )
    }
    matrix.file <- paste0(sample.dir, "/matrix.mtx", suffix)
    barcodes.file <- paste0(sample.dir, "/barcodes.tsv", suffix)

    rawdata <- readMM(matrix.file)
    # convert to dgc matrix
    if (class(rawdata)[1] == "dgTMatrix") {
      rawdata <- as(rawdata, "CsparseMatrix")
    }

    # filter for UMIs first to increase speed
    umi.pass <- which(colSums(rawdata) > min.umis)
    if (length(umi.pass) == 0) {
      message("No cells pass UMI cutoff. Please lower it.")
    }
    rawdata <- rawdata[, umi.pass, drop = FALSE]

    barcodes <- readLines(barcodes.file)[umi.pass]
    # Remove -1 tag from barcodes
    if (all(grepl(barcodes, pattern = "\\-1$"))) {
      barcodes <- as.vector(sapply(barcodes, function(x) {
        strsplit(x, "-")[[1]][1]
      }))
    }
    if (data.type == "rna") {
      features <- read.delim(features.file, header = FALSE, stringsAsFactors = FALSE)
      rownames(rawdata) <- make.unique(features[, 2])
    } else if (data.type == "atac") {
      features <- read.table(features.file, header = FALSE)
      features <- paste0(features[, 1], ":", features[, 2], "-", features[, 3])
      rownames(rawdata) <- features
    }
    # since some genes are only differentiated by ENSMBL
    colnames(rawdata) <- barcodes

    # split based on 10X datatype -- V3 has Gene Expression, Antibody Capture, CRISPR, CUSTOM
    # V2 has only Gene Expression by default and just two columns
    if (is.null(ncol(features))) {
      samplelist <- list(rawdata)
      names(samplelist) <- c("Chromatin Accessibility")
    } else if (ncol(features) < 3) {
      samplelist <- list(rawdata)
      names(samplelist) <- c("Gene Expression")
    } else {
      sam.datatypes <- features[, 3]
      sam.datatypes.unique <- unique(sam.datatypes)
      # keep track of all unique datatypes
      datatypes <- union(datatypes, unique(sam.datatypes))
      samplelist <- lapply(sam.datatypes.unique, function(x) {
        rawdata[which(sam.datatypes == x), ]
      })
      names(samplelist) <- sam.datatypes.unique
    }

    # num.cells filter applies only to Gene Expression / Chromatin Accessibility data
    if (!is.null(num.cells)) {
      data_label <- intersect(c("Gene Expression", "Chromatin Accessibility"), names(samplelist))[1]
      if (!is.na(data_label)) {
        cs <- colSums(samplelist[[data_label]])
        limit <- ncol(samplelist[[data_label]])
        if (num.cells[i] > limit) {
          if (verbose) {
            message("You selected more cells than are in matrix ", i,
                    ". Returning all ", limit, " cells.")
          }
          num.cells[i] <- limit
        }
        samplelist[[data_label]] <- samplelist[[data_label]][, order(cs, decreasing = TRUE)
                                                             [1:num.cells[i]]]
      }

    }

    datalist[[i]] <- samplelist
  }
  if (merge) {
    if (verbose) {
      message("Merging samples")
    }
    return_dges <- lapply(datatypes, function(x) {
      mergelist <- lapply(datalist, function(d) {
        d[[x]]
      })
      keep <- !sapply(mergelist, is.null)
      MergeSparseDataAll(mergelist[keep], sample.names[keep])
    })
    names(return_dges) <- datatypes

    # if only one type of data present
    if (length(return_dges) == 1) {
      if (verbose){
        message("Returning ", datatypes, " data matrix")
      }
      return(return_dges[[1]])
    }
    return(return_dges)
  } else {
    names(datalist) <- sample.names
    return(datalist)
  }
}
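
# Usage sketch (hypothetical paths): with merge = FALSE, read10X returns a nested
# list indexed first by sample and then by 10X data type, e.g.:
#   dges <- read10X(list("path/to/outer/dir1", "path/to/outer/dir2"),
#                   c("sample1", "sample2"), merge = FALSE)
#   rna1 <- dges[["sample1"]][["Gene Expression"]]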

#' Merge hdf5 files
#'
#' This function merges HDF5 files generated from different libraries (Cell Ranger outputs by
#' default) before they are preprocessed through the LIGER pipeline.
#'
#' @param file.list List of path to hdf5 files.
#' @param library.names Vector of library names (corresponding to file.list)
#' @param new.filename Base name for the merged HDF5 file; a ".h5" extension is appended
#'   (e.g. "merged" produces merged.h5).
#' @param format.type string of HDF5 format (10X CellRanger by default).
#' @param data.name Path to the data values stored in HDF5 file.
#' @param indices.name Path to the indices of data points stored in HDF5 file.
#' @param indptr.name Path to the pointers stored in HDF5 file.
#' @param genes.name Path to the gene names stored in HDF5 file.
#' @param barcodes.name Path to the barcodes stored in HDF5 file.
#' @return Directly generates newly merged hdf5 file.
#' @export
#' @examples
#' \dontrun{
#' # For instance, we want to merge two datasets saved in HDF5 files (10X CellRanger)
#' # paths to datasets: "library1.h5","library2.h5"
#' # dataset names: "lib1", "lib2"
#' # base name for output HDF5 file: "merged" (the file is written as merged.h5)
#' mergeH5(list("library1.h5","library2.h5"), c("lib1","lib2"), "merged")
#' }
mergeH5 <- function(file.list,
                    library.names,
                    new.filename,
                    format.type = "10X",
                    data.name = NULL,
                    indices.name = NULL,
                    indptr.name = NULL,
                    genes.name = NULL,
                    barcodes.name = NULL){
  h5_merged = hdf5r::H5File$new(paste0(new.filename,".h5"), mode = "w")
  h5_merged$create_group("matrix")
  h5_merged$create_group("matrix/features")
  num_data_prev = 0
  num_indptr_prev = 0
  num_cells_prev = 0
  last_indptr = 0
  for (i in 1:length(file.list)){
    h5file = hdf5r::H5File$new(file.list[[i]], mode = "r")
    if (format.type == "10X"){
      data = h5file[["matrix/data"]][]
      indices = h5file[["matrix/indices"]][]
      indptr = h5file[["matrix/indptr"]][]
      barcodes = paste0(library.names[i], "_", h5file[["matrix/barcodes"]][])
      genes = h5file[["matrix/features/name"]][]
    } else if (format.type == "AnnData"){
      data = h5file[["raw.X/data"]][]
      indices = h5file[["raw.X/indices"]][]
      indptr = h5file[["raw.X/indptr"]][]
      barcodes = paste0(library.names[i], "_", h5file[["obs"]][]$cell)
      genes = h5file[["raw.var"]][]$index

    } else {
      data = h5file[[data.name]][]
      indices = h5file[[indices.name]][]
      indptr = h5file[[indptr.name]][]
      barcodes = paste0(library.names[i], "_", h5file[[barcodes.name]][])
      genes = h5file[[genes.name]][]
    }

    if (i != 1) indptr = indptr[2:length(indptr)]
    num_data = length(data)
    num_indptr = length(indptr)
    num_cells = length(barcodes)
    indptr = indptr + last_indptr
    last_indptr = indptr[num_indptr]
    if (i == 1) {
      h5_merged[["matrix/data"]] = data
      h5_merged[["matrix/indices"]] = indices
      h5_merged[["matrix/indptr"]] = indptr
      h5_merged[["matrix/barcodes"]] = barcodes
      h5_merged[["matrix/features/name"]] = genes
    } else {
      h5_merged[["matrix/data"]][(num_data_prev + 1):(num_data_prev + num_data)] = data
      h5_merged[["matrix/indices"]][(num_data_prev + 1):(num_data_prev + num_data)] = indices
      h5_merged[["matrix/indptr"]][(num_indptr_prev + 1):(num_indptr_prev + num_indptr)] = indptr
      h5_merged[["matrix/barcodes"]][(num_cells_prev + 1):(num_cells_prev + num_cells)] = barcodes
    }
    num_data_prev = num_data_prev + num_data
    num_indptr_prev = num_indptr_prev + num_indptr
    num_cells_prev = num_cells_prev + num_cells
    h5file$close_all()
  }
  h5_merged$close_all()
}
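
# Follow-up sketch (hypothetical file names): the merged file can be passed
# straight to createLiger as an on-disk (HDF5) dataset.
#   mergeH5(list("library1.h5", "library2.h5"), c("lib1", "lib2"), "merged")
#   ligerex <- createLiger(list(merged = "merged.h5"))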


#' Restore links (to HDF5 files) for a reloaded online Liger object
#'
#' When loading a saved online Liger object in a new R session, the links to the HDF5 files may be
#' corrupted. This function restores those links so that new analyses can be carried out.
#'
#' @param object \code{liger} object.
#' @param file.path List of paths to hdf5 files.
#' @return \code{liger} object with restored links.
#' @export
#' @examples
#' \dontrun{
#' # We want to restore the ligerex (liger object based on HDF5 files)
#' # It has broken connections to HDF5 files
#' # Call the following function and provide the paths to the corresponding files
#' ligerex = restoreOnlineLiger(ligerex, file.path = list("path1/library1.h5", "path2/library2.h5"))
#' }
restoreOnlineLiger <- function(object, file.path = NULL) {
  if (is.null(file.path) & is.null(object@h5file.info[[1]][["file.path"]])) { # file path is not provided by file.path param or liger object
    stop('File path information is not stored in the liger object. Please provide a list of file paths through file.path parameter.')
  }

  if (!is.null(file.path)) { # if new file path is provided, update liger object h5file.info
    for (i in 1:length(object@h5file.info)) {
      object@h5file.info[[i]][["file.path"]] = file.path[[i]]
    }
  }
  # restore access to corresponding h5 files
  object@raw.data = lapply(object@h5file.info, function(x) hdf5r::H5File$new(x[["file.path"]], mode="r+"))
  object@norm.data = lapply(object@raw.data, function(x) x[["norm.data"]])
  object@scale.data = lapply(object@raw.data, function(x) x[["scale.data"]])

  for (i in 1:length(object@raw.data)){
    if (object@h5file.info[[i]][["format.type"]] == "10X"){
      barcodes.name = "matrix/barcodes"
      data.name = "matrix/data"
      indices.name = "matrix/indices"
      indptr.name = "matrix/indptr"
      genes.name = "matrix/features/name"
    } else if (object@h5file.info[[i]][["format.type"]] == "AnnData"){
      barcodes.name = "obs"
      data.name = "raw.X/data"
      indices.name = "raw.X/indices"
      indptr.name = "raw.X/indptr"
      genes.name = "raw.var"
    } else {
      stop("Unrecognized format.type for dataset ", i,
           "; restoreOnlineLiger currently supports HDF5 files in 10X or AnnData format.")
    }
    object@h5file.info[[i]][["data"]] = object@raw.data[[i]][[data.name]]
    object@h5file.info[[i]][["indices"]] = object@raw.data[[i]][[indices.name]]
    object@h5file.info[[i]][["indptr"]] = object@raw.data[[i]][[indptr.name]]
    object@h5file.info[[i]][["barcodes"]] = object@raw.data[[i]][[barcodes.name]]
    object@h5file.info[[i]][["genes"]] = object@raw.data[[i]][[genes.name]]
  }
  return(object)
}
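
# Typical save/reload cycle (sketch, hypothetical file names): open HDF5 handles
# do not survive serialization, so relink after loading in a new session.
#   saveRDS(ligerex, "ligerex.rds")
#   ligerex <- readRDS("ligerex.rds")  # HDF5 links are now stale
#   ligerex <- restoreOnlineLiger(ligerex,
#                                 file.path = list("lib1.h5", "lib2.h5"))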

#' Create a liger object.
#'
#' This function initializes a liger object with the raw data passed in. It requires a list of
#' expression (or another single-cell modality) matrices (gene by cell) for at least two datasets.
#' By default, it converts all passed data into sparse matrices (dgCMatrix) to reduce object size.
#' It initializes cell.data with nUMI and nGene calculated for every cell.
#'
#' @param raw.data List of expression matrices (gene by cell). Should be named by dataset.
#' @param take.gene.union Whether to fill out raw.data matrices with union of genes across all
#'   datasets (filling in 0 for missing data) (default FALSE).
#' @param remove.missing Whether to remove cells not expressing any measured genes, and genes not
#'   expressed in any cells (if take.gene.union = TRUE, removes only genes not expressed in any
#'   dataset) (default TRUE).
#' @param format.type HDF5 format (10X CellRanger by default).
#' @param data.name Path to the data values stored in HDF5 file.
#' @param indices.name Path to the indices of data points stored in HDF5 file.
#' @param indptr.name Path to the pointers stored in HDF5 file.
#' @param genes.name Path to the gene names stored in HDF5 file.
#' @param barcodes.name Path to the barcodes stored in HDF5 file.
#' @param verbose Print messages (TRUE by default)
#' @return \code{liger} object with raw.data slot set.
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
createLiger <- function(raw.data,
                        take.gene.union = FALSE,
                        remove.missing = TRUE,
                        format.type = "10X",
                        data.name = NULL,
                        indices.name = NULL,
                        indptr.name = NULL,
                        genes.name = NULL,
                        barcodes.name = NULL,
                        verbose = TRUE) {
  if (class(raw.data[[1]])[1] == "character") { #HDF5 filenames instead of in-memory matrices
    object <- methods::new(Class = "liger", raw.data = raw.data,
                           version = packageVersion("rliger"))
    object@V = rep(list(NULL), length(raw.data))
    object@H = rep(list(NULL), length(raw.data))
    cell.data = list()
    format.type.list = format.type
    if (length(format.type) == 1) format.type.list = rep(format.type, length(raw.data))
    for (i in 1:length(raw.data)){
      file.h5 = hdf5r::H5File$new(raw.data[[i]], mode="r+")
      object@raw.data[[i]] = file.h5
      if (format.type.list[i] == "10X"){
        barcodes.name = "matrix/barcodes"
        barcodes = file.h5[[barcodes.name]][]
        num_cells = file.h5[[barcodes.name]]$dims
        data.name = "matrix/data"
        indices.name = "matrix/indices"
        indptr.name = "matrix/indptr"
        genes.name = "matrix/features/name"
      } else if (format.type.list[i] == "AnnData"){
        barcodes.name = "obs"
        barcodes = file.h5[[barcodes.name]][]$cell
        num_cells = length(file.h5[[barcodes.name]][]$cell)
        data.name = "raw.X/data"
        indices.name = "raw.X/indices"
        indptr.name = "raw.X/indptr"
        genes.name = "raw.var"
      } else {
        barcodes = file.h5[[barcodes.name]][]
        num_cells = length(barcodes)
      }
      object@h5file.info[[i]] = list(data = file.h5[[data.name]],
                                     indices = file.h5[[indices.name]],
                                     indptr = file.h5[[indptr.name]],
                                     barcodes = file.h5[[barcodes.name]],
                                     genes = file.h5[[genes.name]],
                                     format.type = format.type.list[i],
                                     sample.data.type = NULL,
                                     file.path = raw.data[[i]])
      if (file.h5$exists("norm.data")){
        object@norm.data[[i]] = file.h5[["norm.data"]]
        names(object@norm.data)[[i]] = names(object@raw.data)[[i]]
      }

      if (file.h5$exists("scale.data")){
        object@scale.data[[i]] = file.h5[["scale.data"]]
        names(object@scale.data)[[i]] = names(object@raw.data)[[i]]
      }

      if (file.h5$exists("cell.data")){
        cell.data[[i]] = data.frame(dataset = file.h5[["cell.data"]][]$dataset,
                                    nUMI = file.h5[["cell.data"]][]$nUMI,
                                    nGene = file.h5[["cell.data"]][]$nGene)
        rownames(cell.data[[i]]) = file.h5[["cell.data"]][]$barcode
      } else {
        dataset = rep(names(object@raw.data)[i], num_cells)
        cell.data[[i]] = data.frame(dataset)
        rownames(cell.data[[i]]) = barcodes
      }
    }
    if (is.null(names(object@raw.data))){
      names(object@raw.data) <- as.character(paste0("data",1:length(object@raw.data)))
    }
    object@cell.data = Reduce(rbind, cell.data)
    names(object@H) <- names(object@V) <- names(object@h5file.info) <- names(object@raw.data)
    return(object)
  }

  raw.data <- lapply(raw.data, function(x) {
    if (class(x)[1] == "dgTMatrix" | class(x)[1] == 'dgCMatrix') {
      mat <- as(x, 'CsparseMatrix')
      # Check if dimnames exist
      if (is.null(x@Dimnames[[1]])) {
        stop('Raw data must have both row (gene) and column (cell) names.')
      }
      mat@Dimnames <- x@Dimnames
      return(mat)
    } else {
      as(as.matrix(x), 'CsparseMatrix')
    }
  })

  if (length(Reduce(intersect, lapply(raw.data, colnames))) > 0 & length(raw.data) > 1) {
    stop('At least one cell name is repeated across datasets; please make sure all cell names are unique.')
  }
  if (take.gene.union) {
    merged.data <- MergeSparseDataAll(raw.data)
    if (remove.missing) {
      missing_genes <- which(rowSums(merged.data) == 0)
      if (length(missing_genes) > 0) {
        if (verbose) {
          message("Removing ", length(missing_genes),
                  " genes not expressed in any cells across merged datasets.")
        }
        if (length(missing_genes) < 25) {
          if (verbose) {
            message(rownames(merged.data)[missing_genes])
          }
        }
        merged.data <- merged.data[-missing_genes, ]
      }
    }
    raw.data <- lapply(raw.data, function(x) {
      merged.data[, colnames(x)]
    })
  }
  object <- methods::new(
    Class = "liger",
    raw.data = raw.data,
    version = packageVersion("rliger")
  )
  # remove missing cells
  if (remove.missing) {
    object <- removeMissingObs(object, use.cols = TRUE, verbose = verbose)
    # remove missing genes if not already merged
    if (!take.gene.union) {
      object <- removeMissingObs(object, use.cols = FALSE, verbose = verbose)
    }
  }

  # Initialize cell.data for object with nUMI, nGene, and dataset
  nUMI <- unlist(lapply(object@raw.data, function(x) {
    colSums(x)
  }), use.names = FALSE)
  nGene <- unlist(lapply(object@raw.data, function(x) {
    colSums(x > 0)
  }), use.names = FALSE)
  dataset <- unlist(lapply(seq_along(object@raw.data), function(i) {
    rep(names(object@raw.data)[i], ncol(object@raw.data[[i]]))
  }), use.names = FALSE)
  object@cell.data <- data.frame(nUMI, nGene, dataset)
  rownames(object@cell.data) <- unlist(lapply(object@raw.data, function(x) {
    colnames(x)
  }), use.names = FALSE)

  return(object)
}
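
# HDF5-backed creation (sketch, hypothetical paths): passing file names instead
# of matrices keeps raw data on disk; any norm.data/scale.data already present
# in the files is picked up automatically.
#   ligerex <- createLiger(list(lib1 = "library1.h5", lib2 = "library2.h5"),
#                          format.type = "10X")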

# Create a dataset in the HDF5 file if it does not exist; otherwise extend an existing
# scale.data/gene_vars dataset when the number of variable genes has grown
safe_h5_create = function(object, idx, dataset_name, dims, mode="double", chunk_size = dims)
{
  if (!object@raw.data[[idx]]$exists(dataset_name)) {
    object@raw.data[[idx]]$create_dataset(name = dataset_name,dims = dims,dtype = mode, chunk_dims = chunk_size)
  } else {
    if (object@raw.data[[idx]]$exists("scale.data")) {
      if (object@raw.data[[idx]][["scale.data"]]$dims[1] < length(object@var.genes)){
        extendDataSet(object@raw.data[[idx]][["scale.data"]], c(length(object@var.genes), object@raw.data[[idx]][["scale.data"]]$dims[2]))
      }
    } else if (object@raw.data[[idx]]$exists("gene_vars")) {
      if (object@raw.data[[idx]][["gene_vars"]]$dims[1] < length(object@var.genes)){
        extendDataSet(object@raw.data[[idx]][["gene_vars"]], length(object@var.genes))
      }
    }
  }
}
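
# Example intent (sketch mirroring the call made in normalize() below, where
# num_entries is the length of the file's data array):
#   safe_h5_create(object, idx = 1, dataset_name = "/norm.data",
#                  dims = num_entries, mode = h5types$double, chunk_size = 1000)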

#' Normalize raw datasets to column sums
#'
#' This function normalizes data to account for total gene expression across a cell.
#'
#' @param object \code{liger} object.
#' @param chunk Size of chunks in HDF5 file (default 1000).
#' @param format.type string of HDF5 format (10X CellRanger by default).
#' @param remove.missing Whether to remove cells not expressing any measured genes, and genes not
#'   expressed in any cells (default TRUE).
#' @param verbose Print progress bar/messages (TRUE by default)
#' @return \code{liger} object with norm.data slot set.
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
normalize <- function(object,
                      chunk = 1000,
                      format.type = "10X",
                      remove.missing = TRUE,
                      verbose = TRUE) {
  if (class(object@raw.data[[1]])[1] == "H5File") {
    hdf5_files = names(object@raw.data)
    nUMI = c()
    nGene = c()
    for (i in 1:length(hdf5_files))
    {
      if (verbose) {
        message(hdf5_files[i])
      }
      chunk_size = chunk
      num_entries = object@h5file.info[[i]][["data"]]$dims
      num_cells = object@h5file.info[[i]][["barcodes"]]$dims
      num_genes = object@h5file.info[[i]][["genes"]]$dims


      prev_end_col = 1
      prev_end_data = 1
      prev_end_ind = 0
      gene_sum_sq = rep(0,num_genes)
      gene_means = rep(0,num_genes)

      safe_h5_create(object = object, idx = i, dataset_name = "/norm.data", dims = num_entries, mode = h5types$double, chunk_size = chunk_size)
      safe_h5_create(object = object, idx = i, dataset_name = "/cell_sums", dims = num_cells, mode = h5types$int, chunk_size = chunk_size)

      num_chunks = ceiling(num_cells/chunk_size)
      if (verbose) {
        pb = txtProgressBar(0,num_chunks,style = 3)
      }
      ind = 0
      while(prev_end_col < num_cells)
      {
        ind = ind + 1
        if (num_cells - prev_end_col < chunk_size)
        {
          chunk_size = num_cells - prev_end_col + 1
        }
        start_inds = object@h5file.info[[i]][["indptr"]][prev_end_col:(prev_end_col+chunk_size)]
        row_inds = object@h5file.info[[i]][["indices"]][(prev_end_ind+1):(tail(start_inds, 1))]
        counts = object@h5file.info[[i]][["data"]][(prev_end_ind+1):(tail(start_inds, 1))]
        raw.data = sparseMatrix(i = row_inds[1:length(counts)] + 1,
                                p = start_inds[1:(chunk_size + 1)] - prev_end_ind,
                                x = counts, dims = c(num_genes, chunk_size))
        nUMI = c(nUMI, colSums(raw.data))
        nGene = c(nGene, colSums(raw.data > 0))
        norm.data = Matrix.column_norm(raw.data)
        object@raw.data[[i]][["norm.data"]][(prev_end_ind+1):(tail(start_inds, 1))] = norm.data@x
        object@raw.data[[i]][["cell_sums"]][prev_end_col:(prev_end_col+chunk_size-1)] = Matrix::colSums(raw.data)
        prev_end_col = prev_end_col + chunk_size
        prev_end_data = prev_end_data + length(norm.data@x)
        prev_end_ind = tail(start_inds, 1)

        # calculate row sum and sum of squares using normalized data
        row_sums = Matrix::rowSums(norm.data)
        gene_sum_sq = gene_sum_sq + rowSums(norm.data*norm.data)
        gene_means = gene_means + row_sums
        if (verbose) {
          setTxtProgressBar(pb,ind)
        }
      }
      if (verbose) {
        setTxtProgressBar(pb,num_chunks)
        cat("\n")
      }
      gene_means = gene_means / num_cells
      safe_h5_create(object = object, idx = i, dataset_name = "gene_means", dims=num_genes, mode=h5types$double)
      safe_h5_create(object = object, idx = i, dataset_name = "gene_sum_sq", dims=num_genes, mode=h5types$double)
      object@raw.data[[i]][["gene_means"]][1:length(gene_means)] = gene_means
      object@raw.data[[i]][["gene_sum_sq"]][1:length(gene_sum_sq)] = gene_sum_sq
      object@norm.data[[i]] = object@raw.data[[i]][["norm.data"]]
      rm(row_sums)
      rm(raw.data)
    }
    object@cell.data$nUMI = nUMI
    object@cell.data$nGene = nGene

    for (i in 1:length(object@raw.data)){
      if (!object@raw.data[[i]]$exists("cell.data")) {
        cell.data.i = object@cell.data[object@cell.data$dataset == names(object@raw.data)[i], ]
        cell.data.i$barcode = rownames(cell.data.i)
        object@raw.data[[i]][["cell.data"]] = cell.data.i
      }
    }

    names(object@norm.data) = names(object@raw.data)
  } else {
    if (remove.missing) {
      object <- removeMissingObs(object, slot.use = "raw.data", use.cols = TRUE)
    }
    if (class(object@raw.data[[1]])[1] == "dgTMatrix" |
        class(object@raw.data[[1]])[1] == "dgCMatrix") {
      object@norm.data <- lapply(object@raw.data, Matrix.column_norm)
    } else {
      object@norm.data <- lapply(object@raw.data, function(x) {
        sweep(x, 2, colSums(x), "/")
      })
    }
  }
  return(object)
}
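
# How the chunked read above rebuilds a column block (illustrative toy example,
# not package API): indptr/indices/data are standard CSC arrays, so a 3 x 2
# block with columns (5, 0, 1) and (0, 4, 0) is reconstructed as
#   m <- Matrix::sparseMatrix(i = c(0, 2, 1) + 1, p = c(0, 2, 3),
#                             x = c(5, 1, 4), dims = c(3, 2))
# matching the sparseMatrix(i = row_inds + 1, p = start_inds - prev_end_ind, ...)
# calls in the loop.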

#' Calculate variance of gene expression across cells in an online fashion
#'
#' This function calculates the variance of gene expression values across cells for hdf5 files.
#'
#' @param object \code{liger} object. The input raw.data should be a list of hdf5 files.
#'    Should call normalize before calling.
#' @param chunk Size of chunks in HDF5 file (default 1000).
#' @param verbose Print progress bar/messages (TRUE by default)
#' @return \code{liger} object with per-gene variances written to the "gene_vars" dataset of each
#'    HDF5 file.
calcGeneVars = function (object, chunk = 1000, verbose = TRUE)
{
  hdf5_files = names(object@raw.data)
  for (i in 1:length(hdf5_files)) {
    if (verbose) {
      message(hdf5_files[i])
    }
    chunk_size = chunk
    num_cells = object@h5file.info[[i]][["barcodes"]]$dims
    num_genes = object@h5file.info[[i]][["genes"]]$dims
    num_entries = object@h5file.info[[i]][["data"]]$dims

    prev_end_col = 1
    prev_end_data = 1
    prev_end_ind = 0
    gene_vars = rep(0,num_genes)
    gene_means = object@raw.data[[i]][["gene_means"]][]
    gene_num_pos = rep(0,num_genes)

    num_chunks = ceiling(num_cells/chunk_size)
    if (verbose) {
      pb = txtProgressBar(0, num_chunks, style = 3)
    }
    ind = 0
    while (prev_end_col < num_cells) {
      ind = ind + 1
      if (num_cells - prev_end_col < chunk_size) {
        chunk_size = num_cells - prev_end_col + 1
      }
      start_inds = object@h5file.info[[i]][["indptr"]][prev_end_col:(prev_end_col+chunk_size)]
      row_inds = object@h5file.info[[i]][["indices"]][(prev_end_ind+1):(tail(start_inds, 1))]
      counts = object@norm.data[[i]][(prev_end_ind+1):(tail(start_inds, 1))]
      norm.data = sparseMatrix(i = row_inds[1:length(counts)] + 1,
                               p = start_inds[1:(chunk_size + 1)] - prev_end_ind,
                               x = counts, dims = c(num_genes, chunk_size))

      num_read = length(counts)
      prev_end_col = prev_end_col + chunk_size
      prev_end_data = prev_end_data + num_read
      prev_end_ind = tail(start_inds, 1)
      gene_vars = gene_vars + sumSquaredDeviations(norm.data,gene_means)
      if (verbose) {
        setTxtProgressBar(pb, ind)
      }
    }
    if (verbose) {
      setTxtProgressBar(pb, num_chunks)
      cat("\n")
    }
    gene_vars = gene_vars/(num_cells - 1)
    safe_h5_create(object = object, idx = i, dataset_name = "/gene_vars", dims = num_genes, mode = h5types$double)
    object@raw.data[[i]][["gene_vars"]][1:num_genes] = gene_vars
  }
  return(object)
}
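
# The accumulation above streams over column chunks and computes, per gene g,
#   var_g = sum_over_cells (x_gc - mean_g)^2 / (num_cells - 1)
# where mean_g comes from the "gene_means" dataset written by normalize().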

#' Select a subset of informative genes
#'
#' This function identifies highly variable genes from each dataset and combines these gene sets
#' (either by union or intersection) for use in downstream analysis. Assuming that gene
#' expression approximately follows a Poisson distribution, this function identifies genes with
#' gene expression variance above a given variance threshold (relative to mean gene expression).
#' It also provides a log plot of gene variance vs gene expression (with a line indicating expected
#' expression across genes and cells). Selected genes are plotted in green.
#'
#' @param object \code{liger} object. Should have already called normalize.
#' @param var.thresh Variance threshold. Main threshold used to identify variable genes. Genes with
#'   expression variance greater than threshold (relative to mean) are selected.
#'   (higher threshold -> fewer selected genes). Accepts single value or vector with separate
#'   var.thresh for each dataset. (default 0.1)
#' @param alpha.thresh Alpha threshold. Controls upper bound for expected mean gene expression
#'   (lower threshold -> higher upper bound). (default 0.99)
#' @param num.genes Number of genes to find for each dataset. Optimizes the value of var.thresh
#'   for each dataset to get this number of genes. Accepts single value or vector with same length
#'   as number of datasets (optional, default=NULL).
#' @param tol Tolerance to use for optimization if num.genes values passed in (default 0.0001).
#' @param datasets.use List of datasets to include for discovery of highly variable genes.
#'   (default 1:length(object@raw.data))
#' @param combine How to combine variable genes across experiments. Either "union" or "intersection".
#'   (default "union")
#' @param capitalize Capitalize gene names to match homologous genes (i.e. across species)
#'   (default FALSE)
#' @param do.plot Display log plot of gene variance vs. gene expression for each dataset.
#'   Selected genes are plotted in green. (default FALSE)
#' @param cex.use Point size for plot.
#' @param chunk Size of chunks in HDF5 file (default 1000).
#' @param unshared Whether to consider unshared features (Default FALSE)
#' @param unshared.datasets A list of the datasets to consider unshared features for, i.e. list(2), to use the second dataset
#' @param unshared.thresh A list of threshold values to apply to each unshared dataset. If only one value is provided, it will apply to all unshared
#'  datasets. If a list is provided, it must match the length of the unshared datasets submitted.
#' @return \code{liger} object with var.genes slot set.
#' @importFrom stats optimize
#' @importFrom graphics abline plot points title
#' @importFrom stats qnorm
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
selectGenes <- function(object, var.thresh = 0.1, alpha.thresh = 0.99, num.genes = NULL,
                        tol = 0.0001, datasets.use = 1:length(object@raw.data), combine = "union",
                        capitalize = FALSE, do.plot = FALSE, cex.use = 0.3, chunk=1000, unshared = FALSE, unshared.datasets = NULL, unshared.thresh = NULL)
{
  if (class(object@raw.data[[1]])[1] == "H5File") {
    if (!object@raw.data[[1]]$exists("gene_vars")) {
      object = calcGeneVars(object,chunk)
    }
    hdf5_files = names(object@raw.data)
    if (length(var.thresh) == 1) {
      var.thresh <- rep(var.thresh, length(hdf5_files))
    }
    genes.use <- c()
    for (i in datasets.use) {
      if (object@h5file.info[[i]][["format.type"]] == "AnnData"){
        genes = object@h5file.info[[i]][["genes"]][]$index
      } else {
        genes = object@h5file.info[[i]][["genes"]][]
      }

      if (capitalize) {
        genes = toupper(genes)
      }
      trx_per_cell = object@raw.data[[i]][["cell_sums"]][]
      gene_expr_mean = object@raw.data[[i]][["gene_means"]][]
      gene_expr_var = object@raw.data[[i]][["gene_vars"]][]

      names(gene_expr_mean) <- names(gene_expr_var) <- genes # assign gene names
      nolan_constant <- mean((1/trx_per_cell))
      alphathresh.corrected <- alpha.thresh/length(genes)
      genemeanupper <- gene_expr_mean + qnorm(1 - alphathresh.corrected/2) *
        sqrt(gene_expr_mean * nolan_constant/length(trx_per_cell))
      genes.new <- names(gene_expr_var)[which(gene_expr_var/nolan_constant >
                                                genemeanupper & log10(gene_expr_var) > log10(gene_expr_mean) +
                                                (log10(nolan_constant) + var.thresh[i]))]
      if (do.plot) {
        plot(log10(gene_expr_mean), log10(gene_expr_var),
             cex = cex.use, xlab = "Gene Expression Mean (log10)",
             ylab = "Gene Expression Variance (log10)")
        points(log10(gene_expr_mean[genes.new]), log10(gene_expr_var[genes.new]),
               cex = cex.use, col = "green")
        abline(log10(nolan_constant), 1, col = "purple")
        legend("bottomright", paste0("Selected genes: ",
                                     length(genes.new)), pch = 20, col = "green")
        title(main = hdf5_files[i])
      }
      if (combine == "union") {
        genes.use <- union(genes.use, genes.new)
      }
      if (combine == "intersection") {
        if (length(genes.use) == 0) {
          genes.use <- genes.new
        }
        genes.use <- intersect(genes.use, genes.new)
      }
    }

    for (i in 1:length(hdf5_files)) {
      if (object@h5file.info[[i]][["format.type"]] == "AnnData"){
        genes = object@h5file.info[[i]][["genes"]][]$index
      } else {
        genes = object@h5file.info[[i]][["genes"]][]
      }
      genes.use <- genes.use[genes.use %in% genes]
    }

    if (length(genes.use) == 0) {
      warning("No genes were selected; lower var.thresh values or choose 'union' for combine parameter",
              immediate. = TRUE)
    }
    object@var.genes = genes.use
  } else {
    # Expand if only single var.thresh passed
    if (length(var.thresh) == 1) {
      var.thresh <- rep(var.thresh, length(object@raw.data))
    }
    if (length(num.genes) == 1) {
      num.genes <- rep(num.genes, length(object@raw.data))
    }
    if (!identical(intersect(datasets.use, 1:length(object@raw.data)),datasets.use)) {
      datasets.use = intersect(datasets.use, 1:length(object@raw.data))
    }
    genes.use <- c()
    for (i in datasets.use) {
      if (capitalize) {
        rownames(object@raw.data[[i]]) <- toupper(rownames(object@raw.data[[i]]))
        rownames(object@norm.data[[i]]) <- toupper(rownames(object@norm.data[[i]]))
      }
      trx_per_cell <- colSums(object@raw.data[[i]])
      # Each gene's mean expression level (across all cells)
      gene_expr_mean <- rowMeansFast(object@norm.data[[i]])
      # Each gene's expression variance (across all cells)
      gene_expr_var <- rowVarsFast(object@norm.data[[i]], gene_expr_mean)
      names(gene_expr_mean) <- names(gene_expr_var) <- rownames(object@norm.data[[i]])
      nolan_constant <- mean((1 / trx_per_cell))
      alphathresh.corrected <- alpha.thresh / nrow(object@raw.data[[i]])
      genemeanupper <- gene_expr_mean + qnorm(1 - alphathresh.corrected / 2) *
        sqrt(gene_expr_mean * nolan_constant / ncol(object@raw.data[[i]]))
      basegenelower <- log10(gene_expr_mean * nolan_constant)

      num_varGenes <- function(x, num.genes.des){
        # This function returns the difference between the desired number of genes and
        # the number actually obtained when thresholded on x
        y <- length(which(gene_expr_var / nolan_constant > genemeanupper &
                            log10(gene_expr_var) > basegenelower + x))
        return(abs(num.genes.des - y))
      }

      if (!is.null(num.genes)) {
        # Optimize to find value of x which gives the desired number of genes for this dataset
        # if very small number of genes requested, var.thresh may need to exceed 1
        optimized <- optimize(num_varGenes, c(0, 1.5), tol = tol,
                              num.genes.des = num.genes[i])
        var.thresh[i] <- optimized$minimum
        if (optimized$objective > 1) {
          warning(paste0("Returned number of genes for dataset ", i, " differs from requested by ",
                         optimized$objective, ". Lower tol or alpha.thresh for better results."))
        }
      }

      genes.new <- names(gene_expr_var)[which(gene_expr_var / nolan_constant > genemeanupper &
                                                log10(gene_expr_var) > basegenelower + var.thresh[i])]

      if (do.plot) {
        graphics::plot(log10(gene_expr_mean), log10(gene_expr_var), cex = cex.use,
                       xlab='Gene Expression Mean (log10)',
                       ylab='Gene Expression Variance (log10)')

        graphics::points(log10(gene_expr_mean[genes.new]), log10(gene_expr_var[genes.new]),
                         cex = cex.use, col = "green")
        graphics::abline(log10(nolan_constant), 1, col = "purple")

        legend("bottomright", paste0("Selected genes: ", length(genes.new)), pch = 20, col = "green")
        graphics::title(main = names(object@raw.data)[i])
      }
      if (combine == "union") {
        genes.use <- union(genes.use, genes.new)
      }
      if (combine == "intersection") {
        if (length(genes.use) == 0) {
          genes.use <- genes.new
        }
        genes.use <- intersect(genes.use, genes.new)
      }
    }

    for (i in 1:length(object@raw.data)) {
      genes.use <- genes.use[genes.use %in% rownames(object@raw.data[[i]])]
    }

    if (length(genes.use) == 0) {
      warning("No genes were selected; lower var.thresh values or choose 'union' for combine parameter",
              immediate. = TRUE)
    }
    object@var.genes <- genes.use
  }
  # Only for unshared Features
  if (isTRUE(unshared)) {
    ind.thresh = c()
    # If only one threshold is provided, apply it to all unshared datasets
    if (length(unshared.thresh) == 1) {
      ind.thresh = rep(unshared.thresh, length(object@raw.data))
    } else { # If thresholds are provided for every dataset, use the respective threshold for each dataset
      if (length(unshared.thresh) != length(unshared.datasets)) {
        warning("The number of thresholds does not match the number of datasets; please provide either a single threshold value or a value for each unshared dataset.",
                immediate. = TRUE)
      }
      names(unshared.thresh) = unshared.datasets
      for (i in unshared.datasets){
        ind.thresh[[i]] = unshared.thresh[[as.character(i)]]
      }
    }
    unshared.feats <- c()

    for (i in 1:length(object@raw.data)){
      unshared.feats[i] <- list(NULL)
    }

    #construct a list of shared features
    shared_names = rownames(object@raw.data[[1]])
    for (j in 2:length(object@raw.data)){
      shared_names = shared_names[shared_names %in% rownames(object@raw.data[[j]])]
    }

    for (i in unshared.datasets){
      #Provides normalized subset of unshared features
      normalized_unshared = object@norm.data[[i]][!rownames(object@norm.data[[i]]) %in% shared_names,]
      #Selects top variable features
      trx_per_cell <- colSums(object@raw.data[[i]])
      # Each gene's mean expression level (across all cells)
      gene_expr_mean <- rowMeansFast(normalized_unshared)
      # Each gene's expression variance (across all cells)
      gene_expr_var <- rowVarsFast(normalized_unshared, gene_expr_mean)
      names(gene_expr_mean) <- names(gene_expr_var) <- rownames(normalized_unshared)
      nolan_constant <- mean((1 / trx_per_cell))
      alphathresh.corrected <- alpha.thresh / nrow(object@raw.data[[i]])
      genemeanupper <- gene_expr_mean + qnorm(1 - alphathresh.corrected / 2) *
        sqrt(gene_expr_mean * nolan_constant / ncol(object@raw.data[[i]]))
      basegenelower <- log10(gene_expr_mean * nolan_constant)
      genes.unshared <- names(gene_expr_var)[which(gene_expr_var / nolan_constant > genemeanupper &
                                                     log10(gene_expr_var) > basegenelower + ind.thresh[[i]])]
      if (length(genes.unshared) == 0) {
        warning('Dataset ', i, ' does not contain any unshared features. Please remove this dataset from the unshared.datasets list and rerun the function.', immediate. = TRUE)
      } else {
        unshared.feats[[i]] <- genes.unshared
      }
    }
    names(unshared.feats) <- names(object@raw.data)
    object@var.unshared.features <- unshared.feats
    for (i in unshared.datasets){
      message("Selected ", length(unshared.feats[[i]]), " unshared features from the ",
              names(unshared.feats)[i], " dataset")
    }
  }
  return(object)
}
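
# Selection rule used above (sketch): with cell-size correction
# c = mean(1 / nUMI) (nolan_constant in the code) and Bonferroni-corrected
# alpha' = alpha.thresh / n_genes, a gene g is kept when
#   var_g / c > mean_g + qnorm(1 - alpha'/2) * sqrt(mean_g * c / n_cells)
#   and log10(var_g) > log10(mean_g * c) + var.thresh
# i.e. its variance exceeds both the Poisson-expected upper bound and the
# var.thresh offset above the expected mean-variance line.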

#' Scale genes by root-mean-square across cells
#'
#' This function scales normalized gene expression data after variable genes have been selected.
#' Note that the data is not mean-centered before scaling because expression values must remain
#' positive (NMF only accepts positive values). It also removes cells which do not have any
#' expression across the genes selected, by default.
#'
#' @param object \code{liger} object. Should call normalize and selectGenes before calling.
#' @param remove.missing Whether to remove cells from scale.data with no gene expression
#'   (default TRUE).
#' @param chunk size of chunks in hdf5 file. (default 1000)
#' @param verbose Print progress bar/messages (TRUE by default)
#' @return \code{liger} object with scale.data slot set.
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
scaleNotCenter <- function(object, remove.missing = TRUE, chunk = 1000, verbose = TRUE) {
  if (class(object@raw.data[[1]])[1] == "H5File") {
    hdf5_files = names(object@raw.data)
    vargenes = object@var.genes
    for (i in 1:length(hdf5_files)) {
      if (verbose) {
        message(hdf5_files[i])
      }
      chunk_size = chunk

      if (object@h5file.info[[i]][["format.type"]] == "AnnData"){
        genes = object@raw.data[[i]][["raw.var"]][]$index
      } else {
        genes = object@h5file.info[[i]][["genes"]][]
      }
      num_cells = object@h5file.info[[i]][["barcodes"]]$dims
      num_genes = length(genes)
      num_entries = object@h5file.info[[i]][["data"]]$dims

      prev_end_col = 1
      prev_end_data = 1
      prev_end_ind = 0
      gene_vars = rep(0,num_genes)
      gene_means = object@raw.data[[i]][["gene_means"]][1:num_genes]
      gene_sum_sq = object@raw.data[[i]][["gene_sum_sq"]][1:num_genes]

      gene_inds = which(genes %in% vargenes)
      gene_root_mean_sum_sq = sqrt(gene_sum_sq/(num_cells-1))
      safe_h5_create(object = object, idx = i, dataset_name = "scale.data", dims = c(length(vargenes), num_cells), mode = h5types$double, chunk_size = c(length(vargenes), chunk_size))
      num_chunks = ceiling(num_cells/chunk_size)
      if (verbose) {
        pb = txtProgressBar(0, num_chunks, style = 3)
      }
      ind = 0
      while (prev_end_col < num_cells) {
        ind = ind + 1
        if (num_cells - prev_end_col < chunk_size) {
          chunk_size = num_cells - prev_end_col + 1
        }
        start_inds = object@h5file.info[[i]][["indptr"]][prev_end_col:(prev_end_col+chunk_size)]
        row_inds = object@h5file.info[[i]][["indices"]][(prev_end_ind+1):(tail(start_inds, 1))]
        counts = object@norm.data[[i]][(prev_end_ind+1):(tail(start_inds, 1))]
        scaled = sparseMatrix(i=row_inds[1:length(counts)]+1,p=start_inds[1:(chunk_size+1)]-prev_end_ind,x=counts,dims=c(num_genes,chunk_size))
        scaled = scaled[gene_inds, ]
        scaled = as.matrix(scaled)
        root_mean_sum_sq = gene_root_mean_sum_sq[gene_inds]
        scaled = sweep(scaled, 1, root_mean_sum_sq, "/")
        rownames(scaled) = genes[gene_inds]
        scaled = scaled[vargenes, ]
        scaled[is.na(scaled)] = 0
        scaled[scaled == Inf] = 0
        object@raw.data[[i]][["scale.data"]][1:length(vargenes),prev_end_col:(prev_end_col+chunk_size-1)] = scaled
        num_read = length(counts)
        prev_end_col = prev_end_col + chunk_size
        prev_end_data = prev_end_data + num_read
        prev_end_ind = tail(start_inds, 1)
        if (verbose) {
          setTxtProgressBar(pb, ind)
        }
      }
      object@scale.data[[i]] = object@raw.data[[i]][["scale.data"]]
      if (verbose) {
        setTxtProgressBar(pb, num_chunks)
        cat("\n")
      }
    }
    names(object@scale.data) <- names(object@raw.data)
  } else {
    object@scale.data <- lapply(1:length(object@norm.data), function(i) {
      scaleNotCenterFast(t(object@norm.data[[i]][object@var.genes, , drop = FALSE]))
    })
    # TODO: Preserve sparseness later on (convert inside optimizeALS)
    object@scale.data <- lapply(object@scale.data, function(x) {
      as.matrix(x)
    })

    names(object@scale.data) <- names(object@norm.data)
    for (i in 1:length(object@scale.data)) {
      object@scale.data[[i]][is.na(object@scale.data[[i]])] <- 0
      rownames(object@scale.data[[i]]) <- colnames(object@raw.data[[i]])
      colnames(object@scale.data[[i]]) <- object@var.genes
    }
    # may want to remove such cells before scaling -- should not matter for large datasets?
  }
  #Scale unshared features
  if (length(object@var.unshared.features) != 0){
    for (i in 1:length(object@raw.data)){
      if (!is.null(object@var.unshared.features[[i]])){
        if (class(object@raw.data[[i]])[1] == "dgTMatrix" ||
            class(object@raw.data[[i]])[1] == "dgCMatrix") {
          object@scale.unshared.data[[i]] <- scaleNotCenterFast(t(object@norm.data[[i]][object@var.unshared.features[[i]],]))
          object@scale.unshared.data[[i]] <- as.matrix(object@scale.unshared.data[[i]])
        } else {
          object@scale.unshared.data[[i]] <- scale(t(object@norm.data[[i]][object@var.unshared.features[[i]], ]), center = FALSE, scale = TRUE)
        }
        #names(object@scale.unshared.data) <- names(object@norm.data)
        object@scale.unshared.data[[i]][is.na(object@scale.unshared.data[[i]])] <- 0
        rownames(object@scale.unshared.data[[i]]) <- colnames(object@raw.data[[i]])
        colnames(object@scale.unshared.data[[i]]) <- object@var.unshared.features[[i]]
        #Remove cells that were deemed missing for the shared features
        object@scale.unshared.data[[i]] <- t(object@scale.unshared.data[[i]][rownames(object@scale.data[[i]]),])
      } else{object@scale.unshared.data[i]<- NA}
    }
    names(object@scale.unshared.data) <- names(object@norm.data)
  }
  return(object)
}
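
# Equivalent in-memory computation (sketch) for one dataset: divide each gene by
# its root-mean-square across cells, without centering (cells x genes output):
#   x <- t(as.matrix(ligerex@norm.data[[1]][ligerex@var.genes, ]))
#   scaled <- sweep(x, 2, sqrt(colSums(x^2) / (nrow(x) - 1)), "/")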

#' Remove cells/genes with no expression across any genes/cells
#'
#' Removes cells/genes from chosen slot with no expression in any genes or cells respectively.
#'
#' @param object \code{liger} object (scale.data or norm.data must be set).
#' @param slot.use The data slot to filter (takes "raw.data" and "scale.data") (default "raw.data").
#' @param use.cols Treat each column as a cell (default TRUE).
#' @param verbose Print messages (TRUE by default)
#'
#' @return \code{liger} object with modified raw.data (or chosen slot) (dataset names preserved).
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' if (any(rowSums(ctrl) == 0) || any(rowSums(stim) == 0)) {
#'     # example datasets do not have missing data, thus put in a condition
#'     # Though the function will return unchanged object if no missing found
#'     ligerex <- removeMissingObs(ligerex)
#' }
removeMissingObs <- function(object, slot.use = "raw.data", use.cols = TRUE, verbose = TRUE) {
  filter.data <- slot(object, slot.use)
  removed <- ifelse(((slot.use %in% c("raw.data", "norm.data")) & (use.cols == TRUE)) |
                      ((slot.use == "scale.data") & (use.cols == FALSE)) ,
                    yes = "cells", no = "genes")
  expressed <- ifelse(removed == "cells", yes = " any genes", no = "")
  filter.data <- lapply(seq_along(filter.data), function(x) {
    if (use.cols) {
      missing <- which(colSums(filter.data[[x]]) == 0)
    } else {
      missing <- which(rowSums(filter.data[[x]]) == 0)
    }
    if (length(missing) > 0) {
      if (verbose) {
        message("Removing ",  length(missing), " ", removed, " not expressing", expressed, " in ",
                names(object@raw.data)[x], ".")
      }
      if (use.cols) {
        if (length(missing) < 25) {
          if (verbose) {
            writeLines(colnames(filter.data[[x]])[missing])
          }
        }
        subset <- filter.data[[x]][, -missing]
      } else {
        if (length(missing) < 25) {
          if (verbose) {
            writeLines(rownames(filter.data[[x]])[missing])
          }
        }
        subset <- filter.data[[x]][-missing, ]
      }
    } else {
      subset <- filter.data[[x]]
    }
    subset
  })
  names(filter.data) <- names(object@raw.data)
  slot(object, slot.use) <- filter.data
  return(object)
}
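# A minimal sketch of the filtering above for one dense matrix `m`
# (hypothetical), treating columns as cells:
#   keep <- colSums(m) > 0
#   m <- m[, keep, drop = FALSE]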

# Helper function for readSubset
# Samples cell barcodes from specified datasets
# balance=NULL (default) means that max_cells are sampled from among all cells.
# balance="cluster" samples up to max_cells from each cluster in each dataset
# balance="dataset" samples up to max_cells from each dataset
# datasets.use restricts sampling to the specified datasets. Default is NULL (all datasets)
# seed sets the random seed for reproducibility (default 1).
# verbose controls printing of messages (default TRUE)
# Returns: list of sampled cell barcodes (one vector per dataset)
downsample <- function(object, balance = NULL, max_cells = 1000, datasets.use = NULL, seed = 1, verbose = TRUE) {
  set.seed(seed)
  if(is.null(datasets.use))
  {
    datasets.use = names(object@raw.data)
    if (verbose) {
      message(datasets.use)
    }
  }
  inds = c()
  inds_ds = list()
  if (is.null(balance))
  {
    for (ds in 1:length(datasets.use))
    {
      idx = match(datasets.use[ds], names(object@raw.data))
      inds = c(inds, rownames(object@H[[idx]]))
    }
    num_to_samp = min(max_cells, length(inds))
    inds = sample(inds, num_to_samp)
    for (ds in 1:length(datasets.use))
    {
      idx = match(datasets.use[ds], names(object@raw.data))
      inds_ds[[ds]] = intersect(inds, rownames(object@H[[idx]]))
    }
  }
  else if (balance == "dataset")
  {
    for (ds in 1:length(datasets.use))
    {
      idx = match(datasets.use[ds], names(object@raw.data))
      num_to_samp = min(max_cells, nrow(object@H[[idx]]))
      inds_ds[[ds]] = rownames(object@H[[idx]])[sample(1:nrow(object@H[[idx]]), num_to_samp)]
    }
  }
  else #balance clusters
  {
    if (nrow(object@cell.data)==0)
    {
      dataset <- unlist(lapply(seq_along(object@H), function(i) {
        rep(names(object@H)[i], nrow(object@H[[i]]))
      }), use.names = FALSE)
      object@cell.data <- data.frame(dataset)
      rownames(object@cell.data) <- unlist(lapply(object@H,
                                                  function(x) {
                                                    rownames(x)
                                                  }), use.names = FALSE)
    }
    for (ds in 1:length(datasets.use))
    {
      inds_ds[[ds]] = character()
      for (i in levels(object@clusters))
      {
        inds_to_samp = names(object@clusters)[object@clusters == i &
                                                object@cell.data[["dataset"]] == datasets.use[ds]]
        num_to_samp = min(max_cells, length(inds_to_samp))
        inds_ds[[ds]] = c(inds_ds[[ds]], sample(inds_to_samp, num_to_samp))
      }
    }
  }
  return(inds_ds)
}
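# Usage sketch (assumes a factorized liger object `lig` with clusters set):
#   barcodes_by_dataset <- downsample(lig, balance = "cluster", max_cells = 500)
#   # returns a list with one vector of sampled cell barcodes per dataset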

#' Sample data for plotting
#'
#' This function samples raw/normalized/scaled data from HDF5-backed or in-memory datasets for plotting.
#' This function assumes that the cell barcodes are unique across all datasets.
#'
#' @param object \code{liger} object. Should call normalize and selectGenes before calling.
#' @param slot.use Type of data for sampling (raw.data, norm.data (default), scale.data).
#' @param balance Type of sampling. NULL means that max.cells are sampled from among all cells;
#'                balance="dataset" samples up to max.cells from each dataset;
#'                balance="cluster" samples up to max.cells from each cluster.
#' @param max.cells Total number of cells to sample (default 1000).
#' @param chunk Max number of cells to read from disk at a time (default 1000).
#' @param datasets.use Uses only the specified datasets for sampling. Default is NULL (all datasets)
#' @param genes.use Samples from only the specified genes. Default is NULL (all genes)
#' @param rand.seed Random seed for reproducibility (default 1).
#' @param verbose Print progress bar/messages (TRUE by default)
#' @return \code{liger} object with sample.data slot set.
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' if (length(ligerex@H) > 0) {
#'     # Downsampling is based on the factorization result
#'     ligerex <- readSubset(ligerex, slot.use = "norm.data", max.cells = 100)
#' }
readSubset <- function(object,
                       slot.use = "norm.data",
                       balance = NULL,
                       max.cells = 1000,
                       chunk = 1000,
                       datasets.use = NULL,
                       genes.use = NULL,
                       rand.seed = 1,
                       verbose = TRUE) {
  if (class(object@raw.data[[1]])[1] == "H5File") {
    if (verbose) {
      message("Start sampling")
    }
    if(is.null(datasets.use))
    {
      datasets.use=names(object@H)
    }
    cell_inds = downsample(object, balance = balance, max_cells = max.cells, datasets.use = datasets.use, seed = rand.seed, verbose = verbose)

    hdf5_files = names(object@raw.data)
    #vargenes = object@var.genes

    # find the intersection of genes across the input datasets
    genes = c()
    if (slot.use != "scale.data"){
      for (i in 1:length(hdf5_files)) {
        if (object@h5file.info[[i]][["format.type"]] == "AnnData"){
          genes_i = object@h5file.info[[i]][["genes"]][]$index
        } else {
          genes_i = object@h5file.info[[i]][["genes"]][]
        }
        if (i == 1) genes = genes_i else genes = intersect(genes, genes_i)
      }
    } else {
      genes = object@var.genes
    }

    if(is.null(genes.use))
    {
      genes.use = genes
    }

    for (i in 1:length(hdf5_files)) {
      if (verbose) {
        message(hdf5_files[i])
      }
      if (slot.use == "scale.data") {
        data.subset = c()
      } else {
        data.subset = Matrix(nrow=length(genes.use),ncol=0,sparse=TRUE)
      }
      chunk_size = chunk
      if (object@h5file.info[[i]][["format.type"]] == "AnnData"){
        barcodes = object@h5file.info[[i]][["barcodes"]][]$cell
        genes = object@h5file.info[[i]][["genes"]][]$index
      } else {
        barcodes = object@h5file.info[[i]][["barcodes"]][]
        genes = object@h5file.info[[i]][["genes"]][]
      }
      num_cells = length(barcodes)
      num_genes = length(genes)

      prev_end_col = 1
      prev_end_data = 1
      prev_end_ind = 0


      #gene_inds = which(genes %in% vargenes)

      num_chunks = ceiling(num_cells/chunk_size)
      if (verbose) {
        pb = txtProgressBar(0, num_chunks, style = 3)
      }
      ind = 0

      while (prev_end_col < num_cells) {
        ind = ind + 1
        if (num_cells - prev_end_col < chunk_size) {
          chunk_size = num_cells - prev_end_col + 1
        }
        if (slot.use != "scale.data"){
          start_inds = object@h5file.info[[i]][["indptr"]][prev_end_col:(prev_end_col+chunk_size)]
          row_inds = object@h5file.info[[i]][["indices"]][(prev_end_ind+1):(tail(start_inds, 1))]
          if (slot.use=="raw.data")
          {
            counts = object@h5file.info[[i]][["data"]][(prev_end_ind+1):(tail(start_inds, 1))]
          }
          if (slot.use=="norm.data")
          {
            counts = object@norm.data[[i]][(prev_end_ind+1):(tail(start_inds, 1))]
          }
          one_chunk = sparseMatrix(i=row_inds[1:length(counts)]+1,p=start_inds[1:(chunk_size+1)]-prev_end_ind,x=counts,dims=c(num_genes,chunk_size))
          rownames(one_chunk) = genes
          colnames(one_chunk) = barcodes[(prev_end_col):(prev_end_col+chunk_size-1)]
          use_these = intersect(colnames(one_chunk),cell_inds[[i]])
          one_chunk = one_chunk[genes.use,use_these]
          data.subset = cbind(data.subset,one_chunk)

          num_read = length(counts)
          prev_end_col = prev_end_col + chunk_size
          prev_end_data = prev_end_data + num_read
          prev_end_ind = tail(start_inds, 1)
          if (verbose) {
            setTxtProgressBar(pb, ind)
          }
        } else {
          one_chunk = object@scale.data[[i]][,prev_end_col:(prev_end_col + chunk_size - 1)]
          rownames(one_chunk) = object@var.genes
          colnames(one_chunk) = barcodes[(prev_end_col):(prev_end_col+chunk_size-1)]
          use_these = intersect(colnames(one_chunk),cell_inds[[i]])
          one_chunk = one_chunk[genes.use,use_these]
          data.subset = cbind(data.subset,one_chunk)

          prev_end_col = prev_end_col + chunk_size
          if (verbose) {
            setTxtProgressBar(pb, ind)
          }
        }
        if (class(object@raw.data[[i]])[1] == "H5File") {
          object@sample.data[[i]] = data.subset
        } else if (class(object@raw.data[[i]])[1] != "H5File" & slot.use == "scale.data") {
          object@sample.data[[i]] = t(data.subset)
        }
        object@h5file.info[[i]][["sample.data.type"]] = slot.use
      }
      if (verbose) {
        setTxtProgressBar(pb, num_chunks)
        cat("\n")
      }
    }
  } else {
    if (verbose) {
      message("Start sampling")
    }
    if(is.null(datasets.use))
    {
      datasets.use = names(object@H)
    }
    cell_inds = downsample(object, balance = balance, max_cells = max.cells, datasets.use = datasets.use, seed = rand.seed, verbose = verbose)

    files = names(object@raw.data)
    # find the intersection of genes across the input datasets
    genes = c()
    for (i in 1:length(files)) {
      genes_i = rownames(object@raw.data[[i]])
      if (i == 1) genes = genes_i else genes = intersect(genes, genes_i)
    }
    if(is.null(genes.use))
    {
      genes.use = genes
    }
    if (verbose) {
      pb = txtProgressBar(0, length(files), style = 3)
    }
    for (i in 1:length(files)){
      if (slot.use=="raw.data")
      {
        data.subset_i = object@raw.data[[i]][genes.use, cell_inds[[i]]]
      }
      if (slot.use=="norm.data")
      {
        data.subset_i = object@norm.data[[i]][genes.use, cell_inds[[i]]]
      }
      if(slot.use=="scale.data")
      {
        data.subset_i = t(object@scale.data[[i]][cell_inds[[i]], genes.use])
      }
      if (verbose) {
        setTxtProgressBar(pb, i)
      }
      object@sample.data[[i]] = data.subset_i
    }
    if (verbose){
      cat("\n")
    }
  }
  names(object@sample.data) = names(object@raw.data)
  return(object)
}
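# Sketch of the CSC chunk reconstruction used in the loop above, with
# hypothetical full-length `indptr`, `indices` and `data` vectors as stored in
# 10X-style HDF5 files. For cells a..b (1-based):
#   p     <- indptr[a:(b + 1)]                 # 0-based column pointers for the slice
#   nnz   <- (p[1] + 1):p[length(p)]           # 1-based positions of the nonzeros
#   chunk <- Matrix::sparseMatrix(i = indices[nnz] + 1, p = p - p[1],
#                                 x = data[nnz], dims = c(num_genes, b - a + 1))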

#######################################################################################
#### Factorization

#' Perform online iNMF on scaled datasets
#'
#' @description
#' Perform online integrative non-negative matrix factorization to represent multiple single-cell datasets
#' in terms of H, W, and V matrices. It optimizes the iNMF objective function using online learning (non-negative
#' least squares for H matrix, hierarchical alternating least squares for W and V matrices), where the
#' number of factors is set by k. The function allows online learning in 3 scenarios: (1) fully observed datasets;
#' (2) iterative refinement using continually arriving datasets; and (3) projection of new datasets without updating
#' the existing factorization. All three scenarios require fixed memory independent of the number of cells.
#'
#' For each dataset, this factorization produces an H matrix (cells by k), a V matrix (k by genes),
#' and a shared W matrix (k by genes). The H matrices represent the cell factor loadings.
#' W is identical among all datasets, as it represents the shared components of the metagenes
#' across datasets. The V matrices represent the dataset-specific components of the metagenes.
#'
#' @param object \code{liger} object with data stored in HDF5 files. Should normalize, select genes, and scale before calling.
#' @param X_new List of new datasets for scenario 2 or scenario 3. Each list element should be the name of an HDF5 file.
#' @param projection Perform data integration by shared metagene (W) projection (scenario 3). (default FALSE)
#' @param W.init Optional initialization for W. (default NULL)
#' @param V.init Optional initialization for V (default NULL)
#' @param H.init Optional initialization for H (default NULL)
#' @param A.init Optional initialization for A (default NULL)
#' @param B.init Optional initialization for B (default NULL)
#' @param k Inner dimension of factorization--number of metagenes (default 20). A value in the range 20-50 works well for most analyses.
#' @param lambda Regularization parameter. Larger values penalize dataset-specific effects more
#'   strongly (i.e. alignment should increase as lambda increases). We recommend always using the default value except
#'   possibly for analyses with relatively small differences (biological replicates, male/female comparisons, etc.)
#'   in which case a lower value such as 1.0 may improve reconstruction quality. (default 5.0).
#' @param max.epochs Maximum number of epochs (complete passes through the data). (default 5)
#' @param miniBatch_max_iters Maximum number of block coordinate descent (HALS algorithm) iterations to perform for
#' each update of W and V (default 1). Changing this parameter is not recommended.
#' @param miniBatch_size Total number of cells in each minibatch (default 5000). This is a reasonable default, but a smaller value
#' such as 1000 may be necessary for analyzing very small datasets. In general, minibatch size should be no larger than the number
#' of cells in the smallest dataset.
#' @param h5_chunk_size Chunk size of input hdf5 files (default 1000). The chunk size should be no larger than the batch size.
#' @param seed Random seed to allow reproducible results (default 123).
#' @param verbose Print progress bar/messages (TRUE by default)
#'
#' @return \code{liger} object with H, W, V, A and B slots set.
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' if (length(ligerex@h5file.info) > 0) {
#'     # This function only works for HDF5 based liger object
#'     ligerex <- normalize(ligerex)
#'     ligerex <- selectGenes(ligerex)
#'     ligerex <- scaleNotCenter(ligerex)
#'     # `miniBatch_size` has to be no larger than the number of cells in the smallest dataset
#'     ligerex <- online_iNMF(ligerex, miniBatch_size = 100)
#' }
online_iNMF <- function(object,
                        X_new = NULL,
                        projection = FALSE,
                        W.init = NULL,
                        V.init = NULL,
                        H.init = NULL,
                        A.init = NULL,
                        B.init = NULL,
                        k = 20,
                        lambda = 5,
                        max.epochs = 5,
                        miniBatch_max_iters = 1,
                        miniBatch_size = 5000,
                        h5_chunk_size = 1000,
                        seed = 123,
                        verbose = TRUE){
  if (!is.null(X_new)){ # if there is new dataset
    raw.data_prev = object@raw.data
    norm.data_prev = object@norm.data
    h5file.info_prev = object@h5file.info
    scale.data_prev = object@scale.data
    cell.data_prev = object@cell.data
    names(raw.data_prev) = names(object@raw.data)

    # assuming only one new dataset arrives at a time
    raw.data = c()
    norm.data = c()
    h5file.info = c()
    scale.data = c()
    cell.data = c()
    for (i in 1:length(X_new)){
      raw.data = c(raw.data, X_new[[i]]@raw.data)
      norm.data = c(norm.data, X_new[[i]]@norm.data)
      h5file.info = c(h5file.info, X_new[[i]]@h5file.info)
      scale.data = c(scale.data, X_new[[i]]@scale.data)
      cell.data = rbind(cell.data, X_new[[i]]@cell.data)
    }
    object@raw.data = raw.data
    object@norm.data = norm.data
    object@h5file.info = h5file.info
    object@scale.data = scale.data
    object@cell.data = cell.data

    # check whether X_new needs to be processed
    for (i in 1:length(object@raw.data)){
      if (class(object@raw.data[[i]])[1] == "H5File"){
        processed = object@raw.data[[i]]$exists("scale.data")
      } else {
        processed = !is.null(X_new[[i]]@scale.data)
      }

      if (processed) {
        if (verbose) {
          cat("New dataset", i, "already preprocessed.", "\n")
        }
      } else {
        if (verbose) {
          cat("New dataset", i, "not preprocessed. Preprocessing...", "\n")
        }
        object = normalize(object, chunk = h5_chunk_size)
        object = scaleNotCenter(object, remove.missing = TRUE, chunk = h5_chunk_size)
        if (verbose) {
          cat("New dataset", i, "Processed.", "\n")
        }
      }
    }


    object@raw.data = c(raw.data_prev, object@raw.data)
    object@norm.data = c(norm.data_prev, object@norm.data)
    object@h5file.info = c(h5file.info_prev, object@h5file.info)
    object@scale.data = c(scale.data_prev, object@scale.data)
    object@cell.data = rbind(cell.data_prev, object@cell.data)
    # k x gene -> gene x k & cell x k -> k x cell
    object@W = t(object@W)
    object@V = lapply(object@V, t)
    object@H = lapply(object@H, t)
  }

  for (i in 1:length(object@raw.data)){
    if (class(object@raw.data[[i]])[1] != "H5File") object@scale.data[[i]] = t(object@scale.data[[i]])
  }

  ## extract required information and initialize algorithm
  num_files = length(object@raw.data) # number of total input hdf5 files
  num_prev_files = 0 # number of input hdf5 files processed in last step
  num_new_files = 0 # number of new input hdf5 files since last step
  if (is.null(X_new)) {
    num_prev_files = 0 # start from scratch
    num_new_files = num_files
  } else {
    num_new_files = length(X_new)
    num_prev_files = num_files - num_new_files
    if (verbose) {
      cat(num_new_files, "new datasets detected.", "\n")
    }
  }

  file_idx = 1:num_files # indices for all input files
  file_idx_new = (num_prev_files+1):num_files # indices only for new input files
  file_idx_prev = setdiff(file_idx,file_idx_new)

  vargenes = object@var.genes
  file_names = names(object@raw.data)
  gene_names = vargenes # genes selected for analysis
  num_genes = length(vargenes) # number of the selected genes

  cell_barcodes = list() # cell barcodes for each dataset
  for (i in file_idx){
    cell_barcodes[[i]] = rownames(object@cell.data)[object@cell.data$dataset == file_names[i]]
  }
  num_cells = unlist(lapply(cell_barcodes, length)) # number of cells in each dataset
  num_cells_new = num_cells[(num_prev_files+1):num_files]
  minibatch_sizes = rep(0, num_files)

  for (i in file_idx_new) {
    minibatch_sizes[i] = round((num_cells[i]/sum(num_cells[file_idx_new])) * miniBatch_size)
    if (minibatch_sizes[i] > num_cells[i]){
      stop(paste0("\nNumber of cells to be sampled (n=", minibatch_sizes[i],") is larger than the size of input dataset ", i, " (n=",num_cells[i],").",
                  "\nPlease use a smaller mini-batch size."))
    }
  }
  minibatch_sizes_orig = minibatch_sizes

  if (!projection) {

    if(!is.null(seed)){
      set.seed(seed)
    }

    # W matrix initialization
    if (is.null(X_new)) {
      object@W = matrix(abs(runif(num_genes * k, 0, 2)), num_genes, k)
      for (j in 1:k){
        object@W[, j] = object@W[, j] / sqrt(sum(object@W[, j]^2))
      }
    } else {
      object@W = if(!is.null(W.init)) W.init else object@W
    }
    # V_i matrix initialization
    if (is.null(X_new)) {
      object@V = list()
      for (i in file_idx){
        V_init_idx = sample(1:num_cells_new[i], k) # pick k random cells from the dataset to initialize V
        object@V[[i]] = object@scale.data[[i]][1:num_genes, V_init_idx]
        #object@V[[i]] = matrix(data = abs(x = runif(n = num_genes * k, min = 0, max = 2)),
        #                       nrow = num_genes,
        #                       ncol = k)
      }

      # normalize the columns of the V_i matrices
      for (j in 1:k){
        for (i in file_idx){ # normalize columns of dictionaries
          object@V[[i]][, j] = object@V[[i]][, j] / sqrt(sum(object@V[[i]][, j]^2))
        }
      }
    } else { # if previous Vs are provided
      object@V[file_idx_prev] = if(!is.null(V.init)) V.init else object@V
      V_init_idx = list()
      for (i in file_idx_new){
        V_init_idx = sample(1:num_cells[i], k)
        object@V[[i]] = object@scale.data[[i]][1:num_genes, V_init_idx] # initialize the Vi for new dataset
        for (j in 1:k){
          object@V[[i]][, j] = object@V[[i]][, j] / sqrt(sum(object@V[[i]][, j]^2))
        }
      }
    }
    # H_i matrices initialization
    if (is.null(X_new)) {
      object@H = rep(list(NULL),num_files)
      H_minibatch = list()
    } else { # if previous Hs are provided
      object@H[file_idx_prev] = if(!is.null(H.init)) H.init else object@H
      object@H[file_idx_new] = rep(list(NULL),num_new_files)
      H_minibatch = list()
    }
    # A = HiHi^t, B = XiHit
    A_old = list()
    B_old = list()

    if (is.null(X_new)) {
      object@A = rep(list(matrix(0, k, k)), num_new_files)
      object@B = rep(list(matrix(0, num_genes, k)), num_new_files)
      A_old = rep(list(matrix(0, k, k)), num_new_files) # save information older than 2 epochs
      B_old = rep(list(matrix(0, num_genes, k)), num_new_files) # save information older than 2 epochs

    } else {
      object@A[file_idx_prev] = if(!is.null(A.init)) A.init else object@A
      object@B[file_idx_prev] = if(!is.null(B.init)) B.init else object@B
      A_old[file_idx_prev] = rep(list(NULL), num_prev_files)
      B_old[file_idx_prev] = rep(list(NULL), num_prev_files)
      object@A[(num_prev_files+1):num_files] = rep(list(matrix(0, k, k)), num_new_files)
      object@B[(num_prev_files+1):num_files] = rep(list(matrix(0, num_genes, k)), num_new_files)
      A_old[(num_prev_files+1):num_files] = rep(list(matrix(0, k, k)), num_new_files) # save information older than 2 epochs
      B_old[(num_prev_files+1):num_files] = rep(list(matrix(0, num_genes, k)), num_new_files) # save information older than 2 epochs
    }

    iter = 1
    epoch = rep(0, num_files) # initialize the number of epochs for each dataset
    epoch_prev = rep(0, num_files) # initialize the previous number of epochs for each dataset
    epoch_next = rep(FALSE, num_files)
    sqrt_lambda = sqrt(lambda)
    total_time = 0 # track the total amount of time used for the online learning


    num_chunks = rep(NULL, num_files)
    chunk_idx = rep(list(NULL), num_files)
    all_idx = rep(list(NULL), num_files)

    # chunk permutation
    for (i in file_idx_new){
      num_chunks[i] = ceiling(num_cells[i]/h5_chunk_size)
      chunk_idx[[i]] = sample(1:num_chunks[i],num_chunks[i])
      # idx in the first chunk
      if(chunk_idx[[i]][1]!=num_chunks[i]){
        all_idx[[i]] = (1+h5_chunk_size*(chunk_idx[[i]][1]-1)):(chunk_idx[[i]][1]*h5_chunk_size)
      } else {
        all_idx[[i]] = (1+h5_chunk_size*(chunk_idx[[i]][1]-1)):(num_cells[i])
      }

      for (j in chunk_idx[[i]][-1]){
        if (j != num_chunks[i]){
          all_idx[[i]] = c(all_idx[[i]],(1+h5_chunk_size*(j-1)):(j*h5_chunk_size))
        } else {
          all_idx[[i]] = c(all_idx[[i]],(1+h5_chunk_size*(j-1)):num_cells[i])
        }
      }
    }

    total.iters = floor(sum(num_cells_new) * max.epochs / miniBatch_size)
    if (verbose) {
      cat("Starting Online iNMF...", "\n")
      pb <- txtProgressBar(min = 1, max = total.iters+1, style = 3)
    }

    while(epoch[file_idx_new[1]] < max.epochs) {
      # track epochs
      minibatch_idx = rep(list(NULL), num_files) # indices of samples in each dataset used for this iteration
      if ((max.epochs * num_cells_new[1] - (iter-1) * minibatch_sizes[file_idx_new[1]]) >= minibatch_sizes[file_idx_new[1]]){ # check if the size of the last mini-batch == pre-specified mini-batch size
        for (i in file_idx_new){
          epoch[i] = (iter * minibatch_sizes[i]) %/% num_cells[i] # calculate the current epoch
          if ((epoch_prev[i] != epoch[i]) & ((iter * minibatch_sizes[i]) %% num_cells[i] != 0)){ # if current iter cycles through the data and starts a new cycle
            epoch_next[i] = TRUE
            epoch_prev[i] = epoch[i]
            # shuffle dataset before the next epoch
            minibatch_idx[[i]] = all_idx[[i]][c(((((iter - 1) * minibatch_sizes[i]) %% num_cells[i]) + 1):num_cells[i])]
            chunk_idx[[i]] = sample(1:num_chunks[i],num_chunks[i])
            all_idx[[i]] = 0
            for (j in chunk_idx[[i]]){
              if (j != num_chunks[i]){
                all_idx[[i]] = c(all_idx[[i]],(1+h5_chunk_size*(j-1)):(j*h5_chunk_size))
              }else{
                all_idx[[i]] = c(all_idx[[i]],(1+h5_chunk_size*(j-1)):num_cells[i])
              }
            }
            all_idx[[i]] = all_idx[[i]][-1] # remove the first element 0
            minibatch_idx[[i]] = c(minibatch_idx[[i]],all_idx[[i]][1:((iter * minibatch_sizes[i]) %% num_cells[i])])

          } else if ((epoch_prev[i] != epoch[i]) & ((iter * minibatch_sizes[i]) %% num_cells[i] == 0)){ # if current iter finishes this cycle without starting a new one
            epoch_next[i] = TRUE
            epoch_prev[i] = epoch[i]

            minibatch_idx[[i]] = all_idx[[i]][((((iter-1) * minibatch_sizes[i]) %% num_cells[i]) + 1):num_cells[i]]
            chunk_idx[[i]] = sample(1:num_chunks[i],num_chunks[i])
            all_idx[[i]] = 0
            for (j in chunk_idx[[i]]){
              if (j != num_chunks[i]){
                all_idx[[i]] = c(all_idx[[i]],(1+h5_chunk_size*(j-1)):(j*h5_chunk_size))
              }else{
                all_idx[[i]] = c(all_idx[[i]],(1+h5_chunk_size*(j-1)):num_cells[i])
              }
            }
            all_idx[[i]] = all_idx[[i]][-1] # remove the first element 0
          } else {                                                                        # if current iter stays within a single cycle
            minibatch_idx[[i]] = all_idx[[i]][(((iter-1) * minibatch_sizes[i]) %% num_cells[i] + 1):((iter * minibatch_sizes[i]) %% num_cells[i])]
          }
        }
      } else {
        for (i in file_idx_new){
          minibatch_sizes[i] = max.epochs * num_cells[i] - (iter-1) * minibatch_sizes[i]
          minibatch_idx[[i]] = (((iter-1) * minibatch_sizes_orig[i] + 1) %% num_cells[i]):num_cells[i]
        }
        epoch[file_idx_new[1]] = max.epochs # last epoch
      }


      if (length(minibatch_idx[[file_idx_new[1]]]) == minibatch_sizes_orig[file_idx_new[1]]){
        X_minibatch = rep(list(NULL), num_files)
        for (i in file_idx_new){
          X_minibatch[[i]] = object@scale.data[[i]][1:num_genes ,minibatch_idx[[i]]]
        }

        # update H_i by ANLS Hi_minibatch[[i]]
        H_minibatch = rep(list(NULL), num_files)
        for (i in file_idx_new){
          H_minibatch[[i]] = solveNNLS(rbind(object@W + object@V[[i]], sqrt_lambda * object@V[[i]]),
                                       rbind(X_minibatch[[i]], matrix(0, num_genes, minibatch_sizes[i])))
        }

        # update A and B matrices
        if (iter == 1){
          scale_param = c(rep(0, num_prev_files), rep(0, num_new_files))
        } else if(iter == 2){
          scale_param = c(rep(0, num_prev_files), rep(1, num_new_files) / minibatch_sizes[file_idx_new])
        } else {
          scale_param = c(rep(0, num_prev_files), rep((iter - 2) / (iter - 1), num_new_files))
        }


        if (epoch[file_idx_new[1]] > 0 & epoch_next[file_idx_new[1]] == TRUE){ # remove information older than 2 epochs
          for (i in file_idx_new){
            object@A[[i]] = object@A[[i]] - A_old[[i]]
            A_old[[i]] = scale_param[i] * object@A[[i]]
            object@B[[i]] = object@B[[i]] - B_old[[i]]
            B_old[[i]] = scale_param[i] * object@B[[i]]
          }
        } else{ # otherwise scale the old information
          for (i in file_idx_new){
            A_old[[i]] = scale_param[i] * A_old[[i]]
            B_old[[i]] = scale_param[i] * B_old[[i]]
          }
        }

        for (i in file_idx_new){
          object@A[[i]] = scale_param[i] * object@A[[i]] + H_minibatch[[i]] %*% t(H_minibatch[[i]]) / minibatch_sizes[i]   # HiHit
          diag(object@A[[i]])[diag(object@A[[i]])==0] = 1e-15
          object@B[[i]] = scale_param[i] * object@B[[i]] + X_minibatch[[i]] %*% t(H_minibatch[[i]]) / minibatch_sizes[i]   # XiHit
        }


        # update W, V_i by HALS
        iter_miniBatch = 1
        delta_miniBatch = Inf
        max_iters_miniBatch = miniBatch_max_iters

        while(iter_miniBatch <= max_iters_miniBatch){
          # update W
          for (j in 1:k){
            W_update_numerator = rep(0, num_genes)
            W_update_denominator = 0
            for (i in file_idx){
              W_update_numerator = W_update_numerator + object@B[[i]][, j] - (object@W + object@V[[i]]) %*% object@A[[i]][, j]
              W_update_denominator = W_update_denominator +  object@A[[i]][j,j]
            }

            object@W[, j] = nonneg(object@W[, j] + W_update_numerator / W_update_denominator)
          }

          # update V_i
          for (j in 1:k){
            for (i in file_idx_new){
              object@V[[i]][, j] = nonneg(object@V[[i]][, j] + (object@B[[i]][, j] - (object@W + (1 + lambda) * object@V[[i]]) %*% object@A[[i]][, j]) /
                                            ((1 + lambda) * object@A[[i]][j, j]))
            }
          }

          iter_miniBatch = iter_miniBatch + 1
        }
        epoch_next = rep(FALSE, num_files) # reset epoch change indicator
        iter = iter + 1
        if (verbose) {
          setTxtProgressBar(pb = pb, value = iter)
        }
      }
    }
    if (verbose) {
      cat("\nCalculate metagene loadings...", "\n")
    }
    object@H = rep(list(NULL), num_files)
    for (i in file_idx){
      if (num_cells[i] %% miniBatch_size == 0) num_batch = num_cells[i] %/% miniBatch_size else num_batch = num_cells[i] %/% miniBatch_size + 1
      if (num_batch == 1){
        X_i = object@scale.data[[i]][1:num_genes,]
        object@H[[i]] = solveNNLS(rbind(object@W + object@V[[i]],sqrt_lambda * object@V[[i]]), rbind(X_i, matrix(0, num_genes , num_cells[i])))
      } else {
        for (batch_idx in 1:num_batch){
          if (batch_idx != num_batch){
            cell_idx = ((batch_idx - 1) * miniBatch_size + 1):(batch_idx * miniBatch_size)
          } else {
            cell_idx = ((batch_idx - 1) * miniBatch_size + 1):num_cells[i]
          }
          X_i_batch = object@scale.data[[i]][1:num_genes,cell_idx]
          object@H[[i]] = cbind(object@H[[i]], solveNNLS(rbind(object@W + object@V[[i]], sqrt_lambda * object@V[[i]]),
                                                         rbind(X_i_batch, matrix(0, num_genes , length(cell_idx)))))
        }
      }
      colnames(object@H[[i]]) = cell_barcodes[[i]]
    }

    rownames(object@W) = gene_names
    colnames(object@W) = NULL

    for (i in file_idx){
      rownames(object@V[[i]]) = gene_names
      colnames(object@V[[i]]) = NULL
    }

  } else {
    if (verbose) {
      cat("Metagene projection", "\n")
    }
    object@W = if(!is.null(W.init)) W.init else object@W
    object@H[file_idx_new] = rep(list(NULL), num_new_files)
    object@V[file_idx_new] = rep(list(NULL), num_new_files)
    for (i in file_idx_new){
      if (num_cells[i] %% miniBatch_size == 0) num_batch = num_cells[i] %/% miniBatch_size else num_batch = num_cells[i] %/% miniBatch_size + 1
      if (num_cells[i] <= miniBatch_size){
        object@H[[i]] = solveNNLS(object@W, object@scale.data[[i]][1:num_genes,])
      } else {
        for (batch_idx in 1:num_batch){
          if (batch_idx != num_batch){
            cell_idx = ((batch_idx - 1) * miniBatch_size + 1):(batch_idx * miniBatch_size)
          } else {
            cell_idx = ((batch_idx - 1) * miniBatch_size + 1):num_cells[i]
          }
          object@H[[i]] = cbind(object@H[[i]],solveNNLS(object@W, object@scale.data[[i]][1:num_genes,cell_idx]))
        }
      }
      colnames(object@H[[i]]) = cell_barcodes[[i]]
      object@V[[i]] = matrix(0, num_genes, k)
    }
  }

  # gene x k -> k x gene & k x cell -> cell x k
  object@W = t(object@W)
  object@V = lapply(object@V, t)
  object@H = lapply(object@H, t)
  for (i in 1:length(object@raw.data)){
    if (class(object@raw.data[[i]])[1] != "H5File") object@scale.data[[i]] = t(object@scale.data[[i]])
  }

  if (!is.null(X_new)){
    names(object@scale.data) <- names(object@raw.data) <- c(names(raw.data_prev), names(X_new))
  }
  names(object@H) <- names(object@V) <- names(object@raw.data)
  return(object)
}
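# Sketch of one HALS column update as performed above, for a single dataset
# with hypothetical small matrices. Given sufficient statistics
# A = H %*% t(H) / n and B = X %*% t(H) / n, column j of W is refined while
# staying non-negative:
#   W[, j] <- nonneg(W[, j] + (B[, j] - (W + V) %*% A[, j]) / A[j, j])
# and V gets the analogous update with the (1 + lambda) penalty factors.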


#' Perform thresholding on dense matrix
#'
#' @description
#' Perform thresholding on the input dense matrix. Replace any values smaller than eps with eps.
#' Helper function for online_iNMF
#'
#' @param x Dense matrix.
#' @param eps Threshold. Should be a small positive value. (default 1e-16)
#' @return Dense matrix with no values smaller than eps.
#' @noRd
nonneg <- function(x, eps = 1e-16) {
  x[x < eps] = eps
  return(x)
}
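# For example, with the default eps:
#   nonneg(c(-0.5, 0, 2))  # returns c(1e-16, 1e-16, 2)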


#' Perform iNMF on scaled datasets
#'
#' @description
#' Perform integrative non-negative matrix factorization to return factorized H, W, and V matrices.
#' It optimizes the iNMF objective function using block coordinate descent (alternating non-negative
#' least squares), where the number of factors is set by k. The objective function minimized is
#' \deqn{\sum_i ||E_i - H_i(W + V_i)||_F^2 + \lambda \sum_i ||H_i V_i||_F^2}
#' subject to non-negativity constraints, where \eqn{E_i} is the scaled expression matrix (cells by genes) of dataset i.
#'
#' For each dataset, this factorization produces an H matrix (cells by k), a V matrix (k by genes),
#' and a shared W matrix (k by genes). The H matrices represent the cell factor loadings.
#' W is held consistent among all datasets, as it represents the shared components of the metagenes
#' across datasets. The V matrices represent the dataset-specific components of the metagenes.
#'
#' @param object \code{liger} object. Should normalize, select genes, and scale before calling.
#' @param k Inner dimension of factorization (number of factors). Run suggestK to determine
#'   appropriate value; a general rule of thumb is that a higher k will be needed for datasets with
#'   more sub-structure.
#' @param lambda Regularization parameter. Larger values penalize dataset-specific effects more
#'   strongly (i.e. alignment should increase as lambda increases). Run suggestLambda to determine
#'   most appropriate value for balancing dataset alignment and agreement (default 5.0).
#' @param thresh Convergence threshold. Convergence occurs when |obj0-obj|/(mean(obj0,obj)) < thresh.
#'   (default 1e-6)
#' @param max.iters Maximum number of block coordinate descent iterations to perform (default 30).
#' @param nrep Number of restarts to perform (iNMF objective function is non-convex, so taking the
#'   best objective from multiple successive initializations is recommended). For easier
#'   reproducibility, this increments the random seed by 1 for each consecutive restart, so future
#'   factorizations of the same dataset can be run with one rep if necessary. (default 1)
#' @param H.init Initial values to use for H matrices. (default NULL)
#' @param W.init Initial values to use for W matrix (default NULL)
#' @param V.init Initial values to use for V matrices (default NULL)
#' @param rand.seed Random seed to allow reproducible results (default 1).
#' @param print.obj Print objective function values after convergence (default FALSE).
#' @param use.unshared Whether to run the UANLS method to integrate datasets using previously identified unshared variable genes. Requires running selectGenes with unshared = TRUE and then scaleNotCenter. (default FALSE)
#' @param verbose Print progress bar/messages (TRUE by default)
#' @param ... Arguments passed to other methods
#'
#' @return \code{liger} object with H, W, and V slots set.
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' # Minimum specification for fast example pass
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
optimizeALS <- function(
  object,
  ...
) {
  UseMethod(generic = 'optimizeALS', object = object)
}

#' @rdname optimizeALS
#' @importFrom stats runif
#' @importFrom utils setTxtProgressBar txtProgressBar
#' @export
#' @method optimizeALS list
optimizeALS.list <- function(
  object,
  k,
  lambda = 5.0,
  thresh = 1e-6,
  max.iters = 30,
  nrep = 1,
  H.init = NULL,
  W.init = NULL,
  V.init = NULL,
  use.unshared = FALSE,
  rand.seed = 1,
  print.obj = FALSE,
  verbose = TRUE,
  ...
) {
  if (!all(sapply(X = object, FUN = is.matrix))) {
    stop("All values in 'object' must be a matrix")
  }
  E <- object
  N <- length(x = E)
  ns <- sapply(X = E, FUN = nrow)
  #if (k >= min(ns)) {
  #  stop('Select k lower than the number of cells in smallest dataset: ', min(ns))
  #}
  tmp <- gc()
  g <- ncol(x = E[[1]])
  if (k >= g) {
    stop('Select k lower than the number of variable genes: ', g)
  }
  W_m <- matrix(data = 0, nrow = k, ncol = g)
  V_m <- lapply(
    X = 1:N,
    FUN = function(i) {
      return(matrix(data = 0, nrow = k, ncol = g))
    }
  )
  H_m <- lapply(
    X = ns,
    FUN = function(n) {
      return(matrix(data = 0, nrow = n, ncol = k))
    }
  )
  tmp <- gc()
  best_obj <- Inf
  run_stats <- matrix(data = 0, nrow = nrep, ncol = 2)
  for (i in 1:nrep) {
    set.seed(seed = rand.seed + i - 1)
    start_time <- Sys.time()
    W <- matrix(
      data = abs(x = runif(n = g * k, min = 0, max = 2)),
      nrow = k,
      ncol = g
    )

    V <- lapply(
      X = 1:N,
      FUN = function(i) {
        return(matrix(
          data = abs(x = runif(n = g * k, min = 0, max = 2)),
          nrow = k,
          ncol = g
        ))
      }
    )

    H <- lapply(
      X = ns,
      FUN = function(n) {
        return(matrix(
          data = abs(x = runif(n = n * k, min = 0, max = 2)),
          nrow = n,
          ncol = k
        ))
      }
    )
    tmp <- gc()
    if (!is.null(x = W.init)) {
      W <- W.init
    }
    if (!is.null(x = V.init)) {
      V <- V.init
    }
    if (!is.null(x = H.init)) {
      H <- H.init
    }
    delta <- 1
    iters <- 0
    if (verbose) {
      pb <- txtProgressBar(min = 0, max = max.iters, style = 3)
    }
    sqrt_lambda <- sqrt(x = lambda)
    obj0 <- sum(sapply(
      X = 1:N,
      FUN = function(i) {
        return(norm(x = E[[i]] - H[[i]] %*% (W + V[[i]]), type = "F") ^ 2)
      }
    )) +
      sum(sapply(
        X = 1:N,
        FUN = function(i) {
          return(lambda * norm(x = H[[i]] %*% V[[i]], type = "F") ^ 2)
        }
      ))
    tmp <- gc()
    while (delta > thresh & iters < max.iters) {
      H <- lapply(
        X = 1:N,
        FUN = function(i) {
          return(t(x = solveNNLS(
            C = rbind(t(x = W) + t(x = V[[i]]), sqrt_lambda * t(x = V[[i]])),
            B = rbind(t(x = E[[i]]), matrix(data = 0, nrow = g, ncol = ns[i]))
          )))
        }
      )
      tmp <- gc()
      V <- lapply(
        X = 1:N,
        FUN = function(i) {
          return(solveNNLS(
            C = rbind(H[[i]], sqrt_lambda * H[[i]]),
            B = rbind(E[[i]] - H[[i]] %*% W, matrix(data = 0, nrow = ns[[i]], ncol = g))
          ))
        }
      )
      tmp <- gc()
      W <- solveNNLS(
        C = rbindlist(mat_list = H),
        B = rbindlist(mat_list = lapply(
          X = 1:N,
          FUN = function(i) {
            return(E[[i]] - H[[i]] %*% V[[i]])
          }
        ))
      )
      tmp <- gc()
      obj <- sum(sapply(
        X = 1:N,
        FUN = function(i) {
          return(norm(x = E[[i]] - H[[i]] %*% (W + V[[i]]), type = "F") ^ 2)
        }
      )) +
        sum(sapply(
          X = 1:N,
          FUN = function(i) {
            return(lambda * norm(x = H[[i]] %*% V[[i]], type = "F") ^ 2)
          }
        ))
      tmp <- gc()
      delta <- abs(x = obj0 - obj) / mean(c(obj0, obj))
      obj0 <- obj
      iters <- iters + 1
      if (verbose) {
        setTxtProgressBar(pb = pb, value = iters)
      }
    }
    if (verbose) {
      setTxtProgressBar(pb = pb, value = max.iters)
    }
    # if (iters == max.iters) {
    #   print("Warning: failed to converge within the allowed number of iterations.
    #         Re-running with a higher max.iters is recommended.")
    # }
    if (obj < best_obj) {
      W_m <- W
      H_m <- H
      V_m <- V
      best_obj <- obj
      best_seed <- rand.seed + i - 1
    }
    end_time <- difftime(time1 = Sys.time(), time2 = start_time, units = "auto")
    run_stats[i, 1] <- as.double(x = end_time)
    run_stats[i, 2] <- iters
    if (verbose) {
      cat(
        "\nFinished in ",
        run_stats[i, 1],
        " ",
        units(x = end_time),
        ", ",
        iters,
        " iterations.\n",
        "Max iterations set: ",
        max.iters,
        ".\n",
        "Final objective delta: ",
        delta,
        '.\n',
        sep = ""
      )
    }

    if (verbose) {
      if (print.obj) {
        cat("Objective:", obj, "\n")
      }
      cat("Best results with seed ", best_seed, ".\n", sep = "")
    }
  }
  out <- list()
  out$H <- H_m
  for (i in 1:length(x = object)) {
    rownames(x = out$H[[i]]) <- rownames(x = object[[i]])
  }
  out$V <- V_m
  names(x = out$V) <- names(x = out$H) <- names(x = object)
  out$W <- W_m
  return(out)
}
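# Sketch of the stacked NNLS trick used in the updates above (hypothetical
# shapes; g genes, k factors, n_i cells). The H update for dataset i folds the
# lambda penalty into extra rows, so one solveNNLS call handles both terms:
#   C <- rbind(t(W) + t(V_i), sqrt(lambda) * t(V_i))    # (2g) x k
#   B <- rbind(t(E_i), matrix(0, nrow = g, ncol = n_i)) # (2g) x n_i
#   H_i <- t(solveNNLS(C, B))                           # n_i x k
# valid because ||E_i - H_i(W + V_i)||_F^2 + lambda ||H_i V_i||_F^2 decomposes
# over cells (columns of t(E_i)).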

#' @importFrom methods slot<-
#'
#' @rdname optimizeALS
#' @export
#' @method optimizeALS liger
#'
optimizeALS.liger <- function(
  object,
  k,
  lambda = 5.0,
  thresh = 1e-6,
  max.iters = 30,
  nrep = 1,
  H.init = NULL,
  W.init = NULL,
  V.init = NULL,
  use.unshared = FALSE,
  rand.seed = 1,
  print.obj = FALSE,
  verbose = TRUE,
  ...
) {

  if (isFALSE(use.unshared)) {
    object <- removeMissingObs(
      object = object,
      slot.use = 'scale.data',
      use.cols = FALSE,
      verbose = verbose
    )
    out <- optimizeALS(
      object = object@scale.data,
      k = k,
      lambda = lambda,
      thresh = thresh,
      max.iters = max.iters,
      nrep = nrep,
      H.init = H.init,
      W.init = W.init,
      V.init = V.init,
      use.unshared = FALSE,
      rand.seed = rand.seed,
      print.obj = print.obj,
      verbose = verbose
    )
    names(x = out$H) <- names(x = out$V) <- names(x = object@raw.data)
    for (i in 1:length(x = object@scale.data)) {
      rownames(x = out$H[[i]]) <- rownames(x = object@scale.data[[i]])
    }
    colnames(x = out$W) <- object@var.genes
    for (i in names(x = out)) {
      slot(object = object, name = i) <- out[[i]]
    }
    object@parameters$lambda <- lambda
    return(object)
  }
  if (isTRUE(use.unshared)) {
    object <- optimize_UANLS(object = object,
                             k = k,
                             lambda = lambda,
                             thresh = thresh,
                             max.iters = max.iters,
                             nrep = nrep,
                             rand.seed = rand.seed,
                             print.obj = print.obj)
    return(object)
  }
}

#' Perform factorization for new value of k
#'
#' This uses an efficient strategy for updating that takes advantage of the information in the
#' existing factorization. It is most suitable for values of k smaller than the current value,
#' where it is more likely to speed up the factorization.
#'
#' @param object \code{liger} object. Should call optimizeALS before calling.
#' @param k.new Inner dimension of factorization (number of factors)
#' @param lambda Regularization parameter. By default, this will use the lambda last used with
#'   optimizeALS.
#' @param thresh Convergence threshold. Convergence occurs when |obj0-obj|/(mean(obj0,obj)) < thresh
#'   (default 1e-4).
#' @param max.iters Maximum number of block coordinate descent iterations to perform (default 100).
#' @param rand.seed Random seed to set. Only relevant if k.new > k. (default 1)
#' @param verbose Print progress bar/messages (TRUE by default)
#'
#' @return \code{liger} object with H, W, and V slots reset.
#'
#' @importFrom plyr rbind.fill.matrix
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' k <- 5
#' # Minimum specification for fast example pass
#' ligerex <- optimizeALS(ligerex, k = k, max.iters = 1)
#' if (k != 5) {
#'     ligerex <- optimizeNewK(ligerex, k.new = k, max.iters = 1)
#' }
optimizeNewK <- function(object, k.new, lambda = NULL, thresh = 1e-4, max.iters = 100,
                         rand.seed = 1, verbose = TRUE) {
  if (is.null(lambda)) {
    lambda <- object@parameters$lambda
  }
  k <- ncol(object@H[[1]])
  if (k.new == k) {
    return(object)
  }
  H <- object@H
  W <- object@W
  V <- object@V

  if (k.new > k) {
    set.seed(rand.seed)
    sqrt_lambda <- sqrt(lambda)
    g <- ncol(W)
    N <- length(H)
    ns <- sapply(H, nrow)
    W_new <- matrix(abs(runif(g * k, 0, 2)), k.new - k, g)
    V_new <- lapply(1:N, function(i) {
      matrix(abs(runif(g * (k.new - k), 0, 2)), k.new - k, g)
    })
    H_new <- lapply(ns, function(n) {
      matrix(abs(runif(n * (k.new - k), 0, 2)), n, k.new - k)
    })
    H_new <- lapply(1:N, function(i) {
      t(solveNNLS(
        rbind(t(W_new) + t(V_new[[i]]), sqrt_lambda * t(V_new[[i]])),
        rbind(
          t(object@scale.data[[i]] - H[[i]] %*% (W + V[[i]])),
          matrix(0, nrow = g, ncol = ns[i])
        )
      ))
    })
    V_new <- lapply(1:N, function(i) {
      solveNNLS(
        rbind(H_new[[i]], sqrt_lambda * H_new[[i]]),
        rbind(
          object@scale.data[[i]] - H[[i]] %*% (W + V[[i]]) - H_new[[i]] %*% W_new,
          matrix(0, nrow = ns[[i]], ncol = g)
        )
      )
    })
    W_new <- solveNNLS(
      rbind.fill.matrix(H_new),
      rbind.fill.matrix(lapply(1:N, function(i) {
        object@scale.data[[i]] - H[[i]] %*% (W + V[[i]]) - H_new[[i]] %*% V_new[[i]]
      }))
    )
    H <- lapply(1:N, function(i) {
      cbind(H[[i]], H_new[[i]])
    })
    V <- lapply(1:N, function(i) {
      rbind(V[[i]], V_new[[i]])
    })
    W <- rbind(W, W_new)
  }
  else {
    deltas <- rep(0, k)
    for (i in 1:length(object@H))
    {
      deltas <- deltas + sapply(1:k, function(x) {
        norm(H[[i]][, x] %*% t(W[x, ] + V[[i]][x, ]), "F")
      })
    }
    k.use <- order(deltas, decreasing = TRUE)[1:k.new]
    W <- W[k.use, ]
    H <- lapply(H, function(x) {
      x[, k.use]
    })
    V <- lapply(V, function(x) {
      x[k.use, ]
    })
  }
  object <- optimizeALS(object, k.new,
                        lambda = lambda, thresh = thresh, max.iters = max.iters, H.init = H,
                        W.init = W, V.init = V, rand.seed = rand.seed, verbose = verbose)
  return(object)
}
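# Sketch of the factor-expansion idea above (hypothetical names): when
# k.new > k, the added factors are fit to the residual of the current model,
#   R_i <- scale.data_i - H_i %*% (W + V_i)
# so H_new, V_new and W_new come from NNLS problems on R_i rather than the raw
# data, and the concatenated cbind(H, H_new), rbind(V, V_new), rbind(W, W_new)
# warm-start the final optimizeALS run.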

#' Perform factorization for new data
#'
#' Uses an efficient strategy for updating that takes advantage of the information in the existing
#' factorization. Assumes that selected genes (var.genes) are represented in the new datasets.
#'
#' @param object \code{liger} object. Should call optimizeALS before calling.
#' @param new.data List of raw data matrices (one or more). Each list entry should be named.
#' @param which.datasets List of datasets to append new.data to if add.to.existing is true.
#'   Otherwise, the most similar existing datasets for each entry in new.data.
#' @param add.to.existing Add the new data to existing datasets or treat as totally new datasets
#'   (calculate new Vs?) (default TRUE)
#' @param lambda Regularization parameter. By default, this will use the lambda last used with
#'   optimizeALS.
#' @param thresh Convergence threshold. Convergence occurs when |obj0-obj|/(mean(obj0,obj)) < thresh
#'   (default 1e-4).
#' @param max.iters Maximum number of block coordinate descent iterations to perform (default 100).
#' @param verbose Print progress bar/messages (TRUE by default)
#'
#' @return \code{liger} object with H, W, and V slots reset. Raw.data, norm.data, and scale.data will
#'   also be updated to include the new data.
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' \donttest{
#' # Assume we are performing the factorization
#' # Specification for minimal example test time, not converging
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' # Suppose we have new data, namely Y_new and Z_new, from the same cell types.
#' # Add them to the existing datasets.
#' new_data <- list(Y_set = ctrl, Z_set = stim)
#' # max.iters = 1 will not converge; kept low for minimal test time
#' ligerex2 <- optimizeNewData(ligerex, new.data = new_data,
#'                             which.datasets = list('ctrl', 'stim'),
#'                             max.iters = 1)
#' # Given new data from a different cell type (X), add it as another dataset;
#' # it is probably most similar to ctrl
#' X <- ctrl
#' # max.iters = 1 will not converge; kept low for minimal test time
#' ligerex3 <- optimizeNewData(ligerex, new.data = list(x_set = X),
#'                             which.datasets = list('ctrl'),
#'                             add.to.existing = FALSE,
#'                             max.iters = 1)
#' }
optimizeNewData <- function(object, new.data, which.datasets, add.to.existing = TRUE, lambda = NULL,
                            thresh = 1e-4, max.iters = 100, verbose = TRUE) {
  if (is.null(lambda)) {
    lambda <- object@parameters$lambda
  }
  if (add.to.existing) {
    for (i in 1:length(new.data)) {
      if (verbose) {
        message(dim(object@raw.data[[which.datasets[[i]]]]))
      }
      object@raw.data[[which.datasets[[i]]]] <- cbind(
        object@raw.data[[which.datasets[[i]]]],
        new.data[[i]]
      )
      if (verbose) {
        message(dim(object@raw.data[[which.datasets[[i]]]]))
      }
    }
    object <- normalize(object)
    object <- scaleNotCenter(object)
    sqrt_lambda <- sqrt(lambda)
    g <- ncol(object@W)
    H_new <- lapply(1:length(new.data), function(i) {
      t(solveNNLS(
        rbind(
          t(object@W) + t(object@V[[which.datasets[[i]]]]),
          sqrt_lambda * t(object@V[[which.datasets[[i]]]])
        ),
        rbind(
          t(object@scale.data[[which.datasets[[i]]]][colnames(new.data[[i]]), ]),
          matrix(0, nrow = g, ncol = ncol(new.data[[i]]))
        )
      ))
    })
    for (i in 1:length(new.data)) {
      object@H[[which.datasets[[i]]]] <- rbind(object@H[[which.datasets[[i]]]], H_new[[i]])
    }
  } else {
    old.names <- names(object@raw.data)
    new.names <- names(new.data)
    combined.names <- c(old.names, new.names)
    for (i in 1:length(which.datasets)) {
      object@V[[names(new.data)[i]]] <- object@V[[which.datasets[[i]]]]
    }
    object@raw.data <- c(object@raw.data, new.data)
    names(object@raw.data) <- names(object@V) <- combined.names
    object <- normalize(object)
    object <- scaleNotCenter(object)
    ns <- lapply(object@raw.data, ncol)
    N <- length(ns)
    g <- ncol(object@W)
    sqrt_lambda <- sqrt(lambda)
    if (verbose) {
      for (i in 1:N) {
        message(ns[[i]])
        message(dim(object@raw.data[[i]]))
        message(dim(object@norm.data[[i]]))
        message(dim(object@scale.data[[i]]))
        message(dim(object@V[[i]]))
      }
    }
    H_new <- lapply(1:length(new.data), function(i) {
      t(solveNNLS(rbind(t(object@W) + t(object@V[[new.names[i]]]),
                        sqrt_lambda * t(object@V[[new.names[i]]])),
                  rbind(t(object@scale.data[[new.names[i]]]),
                        matrix(0, nrow = g, ncol = ncol(new.data[[i]])))
      )
      )
    })
    object@H <- c(object@H, H_new)
    names(object@H) <- combined.names
  }
  k <- ncol(object@H[[1]])
  object <- optimizeALS(object, k, lambda, thresh, max.iters,
                        H.init = object@H, W.init = object@W,
                        V.init = object@V, verbose = verbose)
  return(object)
}
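# Sketch of the warm-start projection used above (hypothetical shapes): before
# re-optimizing, each new cell's factor loadings are solved against the frozen
# W and the V of its assigned dataset,
#   H_new <- t(solveNNLS(rbind(t(W) + t(V), sqrt(lambda) * t(V)),
#                        rbind(t(E_new), matrix(0, nrow = g, ncol = n_new))))
# which mirrors the H update inside optimizeALS.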

#' Perform factorization for subset of data
#'
#' Uses an efficient strategy for updating that takes advantage of the information in the existing
#' factorization. Can use either cell names or cluster names to subset. For more basic subsetting
#' functionality (without automatic optimization), see subsetLiger.
#'
#' @param object \code{liger} object. Should call optimizeALS before calling.
#' @param cell.subset List of cell names to retain from each dataset (same length as number of
#'   datasets).
#' @param cluster.subset Clusters for which to keep cells (e.g. c(1, 5, 6)). Should pass in either
#'   cell.subset or cluster.subset but not both.
#' @param lambda Regularization parameter. By default, uses last used lambda.
#' @param thresh Convergence threshold. Convergence occurs when |obj0-obj|/(mean(obj0,obj)) < thresh
#'   (default 1e-4).
#' @param max.iters Maximum number of block coordinate descent iterations to perform (default 100).
#' @param datasets.scale Names of datasets to rescale after subsetting (default NULL).
#'
#' @return \code{liger} object with H, W, and V slots reset. Scale.data
#'   (if desired) will also be updated to reflect the subset.
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' \donttest{
#' # Assume we are performing the factorization
#' # Specification for minimal example run time, not converging.
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' # Preparing subset with random sampling.
#' # Subset can also be obtained with prior knowledge from metadata.
#' cell_names_1 <- sample(rownames(ligerex@H[[1]]), 20)
#' cell_names_2 <- sample(rownames(ligerex@H[[2]]), 20)
#'
#' ligerex2 <- optimizeSubset(ligerex, cell.subset = list(cell_names_1, cell_names_2),
#'                            max.iters = 1)
#' }
optimizeSubset <- function(object, cell.subset = NULL, cluster.subset = NULL, lambda = NULL,
                           thresh = 1e-4, max.iters = 100, datasets.scale = NULL) {
  if (is.null(lambda)) {
    lambda <- object@parameters$lambda
  }
  if (is.null(cell.subset) & is.null(cluster.subset)) {
    stop("Please specify a cell subset or cluster subset.")
  }
  else if (is.null(cell.subset) & !is.null(cluster.subset)) {
    cell.subset <- lapply(1:length(object@scale.data), function(i) {
      which(object@clusters[rownames(object@scale.data[[i]])] %in% cluster.subset)
    })
  }
  old_names <- names(object@raw.data)
  H <- object@H
  H <- lapply(1:length(object@H), function(i) {
    object@H[[i]][cell.subset[[i]], ]
  })
  object@raw.data <- lapply(1:length(object@raw.data), function(i) {
    object@raw.data[[i]][, cell.subset[[i]]]
  })
  all.cell.subset <- Reduce(c, cell.subset)
  object@cell.data <- droplevels(object@cell.data[all.cell.subset, ])
  for (i in 1:length(object@norm.data)) {
    object@norm.data[[i]] <- object@norm.data[[i]][, cell.subset[[i]]]
    if (names(object@norm.data)[i] %in% datasets.scale) {
      object@scale.data[[i]] <- scale(t(object@norm.data[[i]][object@var.genes, ]),
                                      scale = TRUE, center = FALSE)
      object@scale.data[[i]][is.na(object@scale.data[[i]])] <- 0
    } else {
      object@scale.data[[i]] <- as.matrix(t(object@norm.data[[i]][object@var.genes, ]))
    }
  }

  names(object@raw.data) <- names(object@norm.data) <- names(object@H) <- old_names
  k <- ncol(H[[1]])
  object <- optimizeALS(object, k = k, lambda = lambda, thresh = thresh, max.iters = max.iters,
                        H.init = H, W.init = object@W, V.init = object@V)
  return(object)
}
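# Usage sketch (assumes a factorized liger object `lig` with clusters set):
#   lig_sub <- optimizeSubset(lig, cluster.subset = c(1, 5, 6),
#                             datasets.scale = names(lig@raw.data))
#   # rescales the retained cells before re-optimizing with the previous k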

#' Perform factorization for new lambda value
#'
#' Uses an efficient strategy for updating that takes advantage of the information in the existing
#' factorization; uses previous k. Recommended mainly when re-optimizing for higher lambda and when
#' new lambda value is significantly different; otherwise may not return optimal results.
#'
#' @param object \code{liger} object. Should call optimizeALS before calling.
#' @param new.lambda Regularization parameter. Larger values penalize dataset-specific effects more
#' strongly.
#' @param thresh Convergence threshold. Convergence occurs when |obj0-obj|/(mean(obj0,obj)) < thresh (default 1e-4).
#' @param max.iters Maximum number of block coordinate descent iterations to perform (default 100).
#' @param rand.seed Random seed for reproducibility (default 1).
#' @param verbose Print progress bar/messages (TRUE by default)
#'
#' @return \code{liger} object with optimized factorization values
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' \donttest{
#' # Assume we are performing the factorization
#' # Specification for minimal example run time, not converging.
#' ligerex <- optimizeALS(ligerex, k = 5, lambda = 5, max.iters = 1)
#' # decide to run with lambda = 15 instead (keeping k the same)
#' ligerex <- optimizeNewLambda(ligerex, new.lambda = 15, max.iters = 1)
#' }
optimizeNewLambda <- function(object, new.lambda, thresh = 1e-4, max.iters = 100, rand.seed = 1, verbose = TRUE) {
  k <- ncol(object@H[[1]])
  H <- object@H
  W <- object@W
  if (new.lambda < object@parameters$lambda && verbose) {
    message("New lambda less than current lambda; new factorization may not be optimal. ",
            "Re-optimization with optimizeAlS recommended instead.")
  }
  object <- optimizeALS(object, k, lambda = new.lambda, thresh = thresh, max.iters = max.iters,
                        H.init = H, W.init = W, rand.seed = rand.seed, verbose = verbose)
  return(object)
}

#' Visually suggest appropriate lambda value
#'
#' Can be used to select an appropriate value of lambda for factorization of a particular dataset.
#' Plots alignment and agreement for various test values of lambda. The most appropriate lambda
#' is likely around the "elbow" of the alignment plot (where alignment stops increasing). This will
#' likely also correspond to a slower decrease in agreement. Depending on the number of cores used,
#' this process can take 10-20 minutes.
#'
#' @param object \code{liger} object. Should normalize, select genes, and scale before calling.
#' @param k Number of factors to use in test factorizations. See optimizeALS documentation.
#' @param lambda.test Vector of lambda values to test. If not given, use default set spanning
#'   0.25 to 60
#' @param rand.seed Random seed for reproducibility (default 1).
#' @param num.cores Number of cores to use for optimizing factorizations in parallel (default 1).
#' @param thresh Convergence threshold. Convergence occurs when |obj0-obj|/(mean(obj0,obj)) < thresh
#' @param max.iters Maximum number of block coordinate descent iterations to perform
#' @param knn_k Number of nearest neighbors for within-dataset knn in quantileAlignSNF (default 20).
#' @param k2 Horizon parameter for quantileAlignSNF (default 500).
#' @param ref_dataset Reference dataset for quantileAlignSNF (defaults to larger dataset).
#' @param resolution Resolution for quantileAlignSNF (default 1).
#' @param gen.new Do not use optimizeNewLambda in factorizations. Recommended to set TRUE
#'   when looking at only a small range of lambdas (i.e. 1:7) (default FALSE)
#' @param nrep Number of restarts to perform at each lambda value tested (increase to produce
#'   smoother curve if results unclear) (default 1).
#' @param return.data Whether to return list of data matrices (raw) or dataframe (processed)
#'   instead of ggplot object (default FALSE).
#' @param return.raw If return.data TRUE, whether to return raw data (in format described below),
#'   or dataframe used to produce ggplot object. Raw data is matrix of alignment values for each
#'   lambda value tested (each column represents a different rep for nrep). (default FALSE)
#' @param verbose Print progress bar/messages (TRUE by default)
#' @return Matrix of results if indicated or ggplot object. Plots alignment vs. lambda to console.
#' @import doParallel
#' @import parallel
#' @importFrom foreach foreach
#' @importFrom foreach "%dopar%"
#' @importFrom ggplot2 ggplot aes aes_string geom_point geom_line guides guide_legend labs theme theme_classic
#' @export
#' @examples
#' \donttest{
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' suggestLambda(ligerex, k = 20, lambda.test = c(5, 10), max.iters = 1)
#' }
suggestLambda <- function(object, k, lambda.test = NULL, rand.seed = 1, num.cores = 1, thresh = 1e-4,
                          max.iters = 100, knn_k = 20, k2 = 500, ref_dataset = NULL, resolution = 1,
                          gen.new = FALSE, nrep = 1, return.data = FALSE, return.raw = FALSE, verbose = TRUE) {
  if (is.null(lambda.test)) {
    lambda.test <- c(seq(0.25, 1, 0.25), seq(2, 10, 1), seq(15, 60, 5))
  }
  time_start <- Sys.time()
  # optimize smallest lambda value first to take advantage of efficient updating
  if (verbose) {
    message("This operation may take several minutes depending on number of values being tested")
  }
  rep_data <- list()
  for (r in 1:nrep) {
    if (verbose) {
      message("Preprocessing for rep ", r,
              ": optimizing initial factorization with smallest test lambda=",
              lambda.test[1])
    }
    object <- optimizeALS(object, k = k, thresh = thresh, lambda = lambda.test[1],
                          max.iters = max.iters, nrep = 1, rand.seed = (rand.seed + r - 1), verbose = verbose)
    if (verbose) {
      message('Testing different choices of lambda values')
    }
    cl <- parallel::makeCluster(num.cores)
    doParallel::registerDoParallel(cl)
    i <- 0
    data_matrix <- foreach(i = 1:length(lambda.test), .combine = "rbind",
                           .packages = 'rliger') %dopar% {
                             if (i != 1) {
                               if (gen.new) {
                                 ob.test <- optimizeALS(object,
                                                        k = k, lambda = lambda.test[i], thresh = thresh,
                                                        max.iters = max.iters, rand.seed = (rand.seed + r - 1)
                                 )
                               } else {
                                 ob.test <- optimizeNewLambda(object,
                                                              new.lambda = lambda.test[i], thresh = thresh,
                                                              max.iters = max.iters, rand.seed = (rand.seed + r - 1)
                                 )
                               }
                             } else {
                               ob.test <- object
                             }
                             ob.test <- quantileAlignSNF(ob.test, knn_k = knn_k,
                                                         k2 = k2, resolution = resolution,
                                                         ref_dataset = ref_dataset,
                                                         id.number = i
                             )
                             calcAlignment(ob.test)
                           }
    parallel::stopCluster(cl)
    rep_data[[r]] <- data_matrix
  }

  aligns <- Reduce(cbind, rep_data)
  if (is.null(dim(aligns))) {
    aligns <- matrix(aligns, ncol = 1)
  }
  mean_aligns <- apply(aligns, 1, mean)

  time_elapsed <- difftime(Sys.time(), time_start, units = "auto")
  if (verbose) {
    cat(paste("\nCompleted in:", as.double(time_elapsed), units(time_elapsed)))
  }
  # make dataframe
  df_al <- data.frame(align = mean_aligns, lambda = lambda.test)

  p1 <- ggplot(df_al, aes_string(x = 'lambda', y = 'align')) + geom_line(size=1) +
    geom_point() +
    theme_classic() + labs(y = 'Alignment', x = 'Lambda') +
    guides(col = guide_legend(title = "", override.aes = list(size = 2))) +
    theme(legend.position = 'top')

  if (return.data) {
    print(p1)
    if (return.raw) {
      rownames(aligns) <- lambda.test
      return(aligns)
    }
    return(df_al)
  }
  return(p1)
}
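
# Hedged sketch: one way to read a lambda off the data frame returned by
# suggestLambda(..., return.data = TRUE) (columns `align` and `lambda`). The
# elbow heuristic below (first lambda whose marginal alignment gain drops
# under a tolerance) is illustrative only, not part of the package API.
.example_pick_lambda <- function(df_al, tol = 0.01) {
  gain <- diff(df_al$align) / diff(df_al$lambda) # alignment gained per unit lambda
  flat <- which(gain < tol)
  if (length(flat) == 0) {
    return(df_al$lambda[nrow(df_al)])
  }
  df_al$lambda[min(flat)]
}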

#' Visually suggest appropriate k value
#'
#' @description
#' This can be used to select an appropriate value of k for factorization of a particular dataset.
#' Plots median (across cells in all datasets) K-L divergence from uniform for cell factor loadings
#' as a function of k. This should increase as k increases but is expected to level off above a
#' sufficiently high number of factors (k), because cells should have factor loadings which
#' are not uniformly distributed once an appropriate number of factors is reached.
#'
#' Depending on number of cores used, this process can take 10-20 minutes.
#'
#' @param object \code{liger} object. Should normalize, select genes, and scale before calling.
#' @param k.test Set of factor numbers to test (default seq(5, 50, 5)).
#' @param lambda Lambda to use for all factorizations (default 5).
#' @param thresh Convergence threshold. Convergence occurs when |obj0-obj|/(mean(obj0,obj)) < thresh
#' @param max.iters Maximum number of block coordinate descent iterations to perform
#' @param num.cores Number of cores to use for optimizing factorizations in parallel (default 1)
#' @param rand.seed Random seed for reproducibility (default 1).
#' @param gen.new Do not use optimizeNewK in factorizations. Results in slower factorizations.
#'   (default FALSE).
#' @param nrep Number of restarts to perform at each k value tested (increase to produce
#'   smoother curve if results unclear) (default 1).
#' @param plot.log2 Plot log2 curve for reference on K-L plot (log2 is an upper bound and can
#'   sometimes help in identifying the "elbow" of the plot). (default TRUE)
#' @param return.data Whether to return list of data matrices (raw) or dataframe (processed)
#'   instead of ggplot object (default FALSE).
#' @param return.raw If return.data TRUE, whether to return raw data (in format described below),
#'   or dataframe used to produce ggplot object. Raw data is list of matrices of K-L divergences
#'   (length(k.test) by n_cells). Length of list corresponds to nrep. (default FALSE)
#' @param verbose Print progress bar/messages (TRUE by default)
#'
#' @return Matrix of results if indicated or ggplot object. Plots K-L divergence vs. k to console.
#'
#' @import doParallel
#' @import parallel
#' @importFrom foreach foreach
#' @importFrom foreach "%dopar%"
#' @importFrom ggplot2 ggplot aes aes_string geom_point geom_line guides guide_legend labs theme theme_classic
#'
#' @export
#' @examples
#' \donttest{
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' suggestK(ligerex, k.test = c(5,6), max.iters = 1)
#' }
suggestK <- function(object, k.test = seq(5, 50, 5), lambda = 5, thresh = 1e-4, max.iters = 100,
                     num.cores = 1, rand.seed = 1, gen.new = FALSE, nrep = 1, plot.log2 = TRUE,
                     return.data = FALSE, return.raw = FALSE, verbose = TRUE) {
  if (length(object@scale.data) == 0) {
    stop("scaleNotCenter should be run on the object before running suggestK.")
  }
  time_start <- Sys.time()
  # optimize largest k value first to take advantage of efficient updating
  if (verbose) {
    message("This operation may take several minutes depending on number of values being tested")
  }
  rep_data <- list()
  for (r in 1:nrep) {
    if (verbose) {
      message("Preprocessing for rep ", r,
              ": optimizing initial factorization with largest test k=",
              k.test[length(k.test)])
    }
    object <- optimizeALS(object, k = k.test[length(k.test)], lambda = lambda, thresh = thresh,
                          max.iters = max.iters, nrep = 1, rand.seed = (rand.seed + r - 1))
    if (verbose) {
      message('Testing different choices of k')
    }
    cl <- parallel::makeCluster(num.cores)
    doParallel::registerDoParallel(cl)
    i <- 0
    data_matrix <- foreach(i = length(k.test):1, .combine = "rbind",
                           .packages = 'rliger') %dopar% {
                             if (i != length(k.test)) {
                               if (gen.new) {
                                 ob.test <- optimizeALS(object,
                                                        k = k.test[i], lambda = lambda, thresh = thresh,
                                                        max.iters = max.iters, rand.seed = (rand.seed + r - 1)
                                 )
                               } else {
                                 ob.test <- optimizeNewK(object,
                                                         k.new = k.test[i], lambda = lambda, thresh = thresh,
                                                         max.iters = max.iters, rand.seed = (rand.seed + r - 1)
                                 )
                               }
                             } else {
                               ob.test <- object
                             }
                             dataset_split <- kl_divergence_uniform(ob.test)
                             unlist(dataset_split)
                           }
    parallel::stopCluster(cl)
    data_matrix <- data_matrix[nrow(data_matrix):1, ]
    rep_data[[r]] <- data_matrix
  }

  medians <- Reduce(cbind, lapply(rep_data, function(x) {apply(x, 1, median)}))
  if (is.null(dim(medians))) {
    medians <- matrix(medians, ncol = 1)
  }
  mean_kls <- apply(medians, 1, mean)

  time_elapsed <- difftime(Sys.time(), time_start, units = "auto")
  if (verbose) {
    cat(paste("\nCompleted in:", as.double(time_elapsed), units(time_elapsed)))
  }
  # make dataframe
  df_kl <- data.frame(median_kl = c(mean_kls, log2(k.test)), k = c(k.test, k.test),
                      calc = c(rep('KL_div', length(k.test)), rep('log2(k)', length(k.test))))
  if (!plot.log2) {
    df_kl <- df_kl[df_kl$calc == 'KL_div', ]
  }

  p1 <- ggplot(df_kl, aes_string(x = 'k', y = 'median_kl', col = 'calc')) + geom_line(size=1) +
    geom_point() +
    theme_classic() + labs(y='Median KL divergence (across all cells)', x = 'K') +
    guides(col=guide_legend(title="", override.aes = list(size = 2))) +
    theme(legend.position = 'top')

  if (return.data) {
    print(p1)
    if (return.raw) {
      rep_data <- lapply(rep_data, function(x) {
        rownames(x) <- k.test
        return(x)
      })
      return(rep_data)
    }
    return(df_kl)
  }
  return(p1)
}
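
# Hedged sketch: reading a k off the data frame returned by
# suggestK(..., return.data = TRUE). Keeps only the K-L rows (dropping the
# log2(k) reference curve) and picks the first k where the marginal gain in
# median K-L divergence falls below a tolerance; the heuristic is illustrative.
.example_pick_k <- function(df_kl, tol = 0.05) {
  kl <- df_kl[df_kl$calc == 'KL_div', ]
  gain <- diff(kl$median_kl) / diff(kl$k)
  flat <- which(gain < tol)
  if (length(flat) == 0) {
    return(kl$k[nrow(kl)])
  }
  kl$k[min(flat)]
}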


#######################################################################################
#### Quantile Alignment/Normalization
#' Quantile align (normalize) factor loadings
#'
#' This process builds a shared factor neighborhood graph to jointly cluster cells, then quantile
#' normalizes corresponding clusters.
#'
#' The first step, building the shared factor neighborhood graph, is performed in SNF(), and
#' produces a graph representation where edge weights between cells (across all datasets)
#' correspond to their similarity in the shared factor neighborhood space. An important parameter
#' here is knn_k, the number of neighbors used to build the shared factor space.
#'
#' Next we perform quantile alignment for each dataset, factor, and cluster (by
#' stretching/compressing datasets' quantiles to better match those of the reference dataset). These
#' aligned factor loadings are combined into a single matrix and returned as H.norm.
#'
#' @param object \code{liger} object. Should run optimizeALS before calling.
#' @param knn_k Number of nearest neighbors for within-dataset knn graph (default 20).
#' @param ref_dataset Name of dataset to use as a "reference" for normalization. By default,
#'   the dataset with the largest number of cells is used.
#' @param min_cells Minimum number of cells to consider a cluster shared across datasets (default 20)
#' @param quantiles Number of quantiles to use for quantile normalization (default 50).
#' @param eps The error bound of the nearest neighbor search. Lower values give more
#'   accurate nearest neighbor graphs but take much longer to compute. (default 0.9)
#' @param dims.use Indices of factors to use for shared nearest factor determination (default
#'   1:ncol(H[[1]])).
#' @param do.center Centers the data when scaling factors (useful for less sparse modalities like
#'   methylation data). (default FALSE)
#' @param max_sample Maximum number of cells used for quantile normalization of each cluster
#' and factor. (default 1000)
#' @param refine.knn Whether to increase robustness of cluster assignments using a KNN graph. (default TRUE)
#' @param rand.seed Random seed to allow reproducible results (default 1)
#' @param ... Arguments passed to other methods
#'
#' @return \code{liger} object with 'H.norm' and 'clusters' slots set.
#' @importFrom stats approxfun
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
quantile_norm <- function(
  object,
  ...
) {
  UseMethod(generic = 'quantile_norm', object = object)
}

#' @rdname quantile_norm
#' @export
#' @method quantile_norm list
#'
quantile_norm.list <- function(
  object,
  quantiles = 50,
  ref_dataset = NULL,
  min_cells = 20,
  knn_k = 20,
  dims.use = NULL,
  do.center = FALSE,
  max_sample = 1000,
  eps = 0.9,
  refine.knn = TRUE,
  rand.seed = 1,
  ...
) {
  set.seed(rand.seed)
  if (!all(sapply(X = object, FUN = is.matrix))) {
    stop("All values in 'object' must be a matrix")
  }
  if (is.null(x = names(x = object))) {
    stop("'object' must be a named list of matrices")
  }
  if (is.character(x = ref_dataset) && !ref_dataset %in% names(x = object)) {
    stop("Cannot find reference dataset")
  } else if (!inherits(x = ref_dataset, what = c('character', 'numeric', 'integer'))) {
    stop("'ref_dataset' must be a character or integer specifying which dataset is the reference")
  }
  labels <- list()
  if (is.null(dims.use)) {
    use_these_factors <- 1:ncol(object[[1]])
  } else {
    use_these_factors <- dims.use
  }
  # fast max factor assignment with Rcpp code
  labels <- lapply(object, max_factor, dims_use = use_these_factors, center_cols = do.center)
  clusters <- as.factor(unlist(lapply(labels, as.character)))
  names(clusters) <- unlist(lapply(object, rownames))

  # increase robustness of cluster assignments using knn graph
  if (refine.knn) {
    clusters <- refine_clusts_knn(object, clusters, k = knn_k, eps = eps)
  }
  cluster_assignments <- clusters
  clusters <- lapply(object, function(x) {
    clusters[rownames(x)]
  })
  names(clusters) <- names(object)
  dims <- ncol(object[[ref_dataset]])

  dataset <- unlist(lapply(1:length(object), function(i) {
    rep(names(object)[i], nrow(object[[i]]))
  }))
  Hs <- object
  num_clusters <- dims
  for (k in 1:length(object)) {
    for (j in 1:num_clusters) {
      cells2 <- which(clusters[[k]] == j)
      cells1 <- which(clusters[[ref_dataset]] == j)
      for (i in 1:dims) {
        num_cells2 <- length(cells2)
        num_cells1 <- length(cells1)
        if (num_cells1 < min_cells || num_cells2 < min_cells) {
          next
        }
        if (num_cells2 == 1) {
          Hs[[k]][cells2, i] <- mean(Hs[[ref_dataset]][cells1, i])
          next
        }
        q2 <- quantile(sample(Hs[[k]][cells2, i], min(num_cells2, max_sample)), seq(0, 1, by = 1 / quantiles))
        q1 <- quantile(sample(Hs[[ref_dataset]][cells1, i], min(num_cells1, max_sample)), seq(0, 1, by = 1 / quantiles))
        if (sum(q1) == 0 || sum(q2) == 0 ||
            length(unique(q1)) < 2 || length(unique(q2)) < 2) {
          new_vals <- rep(0, num_cells2)
        } else {
          warp_func <- withCallingHandlers(stats::approxfun(q2, q1, rule = 2), warning=function(w){invokeRestart("muffleWarning")})
          new_vals <- warp_func(Hs[[k]][cells2, i])
        }
        Hs[[k]][cells2, i] <- new_vals
      }
    }
  }
  out <- list(
    'H.norm' = Reduce(rbind, Hs),
    'clusters' = cluster_assignments
  )
  return(out)
}
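
# Hedged sketch of the core warping step used above, in isolation: the query
# quantiles are mapped onto the reference quantiles with a piecewise-linear
# function (stats::approxfun, rule = 2 to clamp at the ends), which is then
# applied to every query value. Inputs here are plain numeric vectors.
.example_quantile_warp <- function(ref_vals, query_vals, quantiles = 50) {
  probs <- seq(0, 1, by = 1 / quantiles)
  q_ref <- stats::quantile(ref_vals, probs)
  q_query <- stats::quantile(query_vals, probs)
  warp_func <- stats::approxfun(q_query, q_ref, rule = 2)
  warp_func(query_vals)
}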

#' @rdname quantile_norm
#' @export
#' @method quantile_norm liger
quantile_norm.liger <- function(
  object,
  quantiles = 50,
  ref_dataset = NULL,
  min_cells = 20,
  knn_k = 20,
  dims.use = NULL,
  do.center = FALSE,
  max_sample = 1000,
  eps = 0.9,
  refine.knn = TRUE,
  rand.seed = 1,
  ...
) {
  if (is.null(x = ref_dataset)) {
    ns <- sapply(X = object@H, FUN = nrow)
    ref_dataset <- names(x = object@H)[which.max(x = ns)]
  }
  out <- quantile_norm(
    object = object@H,
    quantiles = quantiles,
    ref_dataset = ref_dataset,
    min_cells = min_cells,
    knn_k = knn_k,
    dims.use = dims.use,
    do.center = do.center,
    max_sample = max_sample,
    eps = eps,
    refine.knn = refine.knn,
    rand.seed = rand.seed
  )
  for (i in names(x = out)) {
    slot(object = object, name = i) <- out[[i]]
  }
  return(object)
}

#' Louvain algorithm for community detection
#'
#' @description
#' After quantile normalization, users can additionally run the Louvain algorithm
#' for community detection, which is widely used in single-cell analysis and excels at merging
#' small clusters into broad cell classes.
#'
#' @param object \code{liger} object. Should run quantile_norm before calling.
#' @param k The maximum number of nearest neighbors to compute. (default 20)
#' @param resolution Value of the resolution parameter, use a value above (below) 1.0 if you want
#' to obtain a larger (smaller) number of communities. (default 1.0)
#' @param prune Sets the cutoff for acceptable Jaccard index when
#' computing the neighborhood overlap for the SNN construction. Any edges with
#' values less than or equal to this will be set to 0 and removed from the SNN
#' graph. Essentially sets the stringency of pruning (0 --- no pruning, 1 ---
#' prune everything). (default 1/15)
#' @param eps The error bound of the nearest neighbor search. (default 0.1)
#' @param nRandomStarts Number of random starts. (default 10)
#' @param nIterations Maximal number of iterations per random start. (default 100)
#' @param random.seed Seed of the random number generator. (default 1)
#' @param verbose Print messages (TRUE by default)
#' @param dims.use Indices of factors to use for Louvain clustering (default 1:ncol(H[[1]])).
#' @return \code{liger} object with refined 'clusters' slot set.
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' ligerex <- louvainCluster(ligerex, resolution = 0.3)
louvainCluster <- function(object, resolution = 1.0, k = 20, prune = 1 / 15, eps = 0.1, nRandomStarts = 10,
                           nIterations = 100, random.seed = 1, verbose = TRUE, dims.use = NULL) {
  tmpdir <- tempdir()
  output_path <- paste0('edge_', sub('\\s', '_', Sys.time()), '.txt')
  output_path <- gsub(':', '_', output_path, fixed = TRUE)
  output_path <- file.path(tmpdir, output_path)

  if (is.null(dims.use)) {
    use_these_factors <- 1:ncol(object@H[[1]])
  } else {
    use_these_factors <- dims.use
  }

  if (dim(object@H.norm)[1] == 0){
    if (verbose) {
      message("Louvain Clustering on unnormalized cell factor loadings.")
    }
    knn <- RANN::nn2(Reduce(rbind, object@H)[,use_these_factors], k = k, eps = eps)
  } else {
    if (verbose) {
      message("Louvain Clustering on quantile normalized cell factor loadings.")
    }
    knn <- RANN::nn2(object@H.norm[,use_these_factors], k = k, eps = eps)
  }
  snn <- ComputeSNN(knn$nn.idx, prune = prune)
  WriteEdgeFile(snn, output_path, display_progress = FALSE)
  clusts <- RunModularityClusteringCpp(snn,
                                       modularityFunction = 1, resolution = resolution, nRandomStarts = nRandomStarts,
                                       nIterations = nIterations, algorithm = 1, randomSeed = random.seed, printOutput = FALSE,
                                       edgefilename = output_path
  )
  names(clusts) <- rownames(object@cell.data)
  rownames(snn) <- rownames(object@cell.data)
  colnames(snn) <- rownames(object@cell.data)
  clusts <- GroupSingletons(ids = clusts, SNN = snn, verbose = FALSE)
  object@clusters <- as.factor(clusts)
  unlink(output_path)
  return(object)
}

# Group single cells that make up their own cluster in with the cluster they are
# most connected to. (Adopted from Seurat v3)
#
# @param ids Named vector of cluster ids
# @param SNN SNN graph used in clustering
# @param group.singletons Group singletons into nearest cluster (TRUE by default). If FALSE,
# assign all singletons to a "singleton" group
# @param verbose Print message
#
# @return Returns updated cluster assignment with all singletons merged with most connected cluster
#
GroupSingletons <- function(ids, SNN, group.singletons = TRUE, verbose = FALSE) {
  # identify singletons
  singletons <- names(x = which(x = table(ids) == 1))
  singletons <- intersect(x = unique(x = ids), singletons)
  if (!group.singletons) {
    ids[which(ids %in% singletons)] <- "singleton"
    return(ids)
  }
  # calculate connectivity of singletons to other clusters, add singleton
  # to cluster it is most connected to
  cluster_names <- as.character(x = unique(x = ids))
  cluster_names <- setdiff(x = cluster_names, y = singletons)
  connectivity <- vector(mode = "numeric", length = length(x = cluster_names))
  names(x = connectivity) <- cluster_names
  new.ids <- ids
  for (i in singletons) {
    i.cells <- names(which(ids == i))
    for (j in cluster_names) {
      j.cells <- names(which(ids == j))
      subSNN <- SNN[i.cells, j.cells]
      set.seed(1) # to match previous behavior, random seed being set in WhichCells
      if (is.object(x = subSNN)) {
        connectivity[j] <- sum(subSNN) / (nrow(x = subSNN) * ncol(x = subSNN))
      } else {
        connectivity[j] <- mean(x = subSNN)
      }
    }
    m <- max(connectivity, na.rm = TRUE)
    mi <- which(x = connectivity == m, arr.ind = TRUE)
    closest_cluster <- sample(x = names(x = connectivity[mi]), 1)
    ids[i.cells] <- closest_cluster
  }
  if (length(x = singletons) > 0 && verbose) {
    message(paste(
      length(x = singletons),
      "singletons identified.",
      length(x = unique(x = ids)),
      "final clusters."
    ))
  }
  return(ids)
}
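
# Hedged toy demonstration of GroupSingletons (values are made up): cell "c"
# forms a singleton cluster "2" and gets merged into cluster "1", the cluster
# it is most connected to in the symmetric SNN matrix.
.example_group_singletons <- function() {
  snn <- matrix(c(1, 0.5, 0.1,
                  0.5, 1, 0.4,
                  0.1, 0.4, 1),
                nrow = 3,
                dimnames = list(c("a", "b", "c"), c("a", "b", "c")))
  ids <- c(a = "1", b = "1", c = "2")
  GroupSingletons(ids = ids, SNN = snn, verbose = TRUE)
}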


#' Impute the query cell expression matrix
#'
#' Impute query features from a reference dataset using KNN.
#'
#' @param object \code{liger} object.
#' @param knn_k The maximum number of nearest neighbors to search. (default 20)
#' @param reference Dataset containing values to impute into query dataset(s).
#' @param queries Dataset(s) to be augmented by imputation. If not specified, all datasets
#' except the reference will be used.
#' @param weight Whether to use KNN distances as weight matrix (default TRUE).
#' @param norm Whether to normalize the imputed data with default parameters (default TRUE).
#' @param scale Whether to scale but not center the imputed data with default parameters (default FALSE).
#' @param verbose Print progress bar/messages (TRUE by default)
#'
#' @return \code{liger} object with raw data in raw.data slot replaced by imputed data (genes by cells)
#'
#' @importFrom FNN get.knnx
#'
#' @export
#' @examples
#' \dontrun{
#' # Only runable for ATAC dataset. See tutorial on GitHub.
#' # ligerex (liger object), factorization complete
#' # impute every dataset other than the reference dataset
#' ligerex <- imputeKNN(ligerex, reference = "y_set", weight = FALSE)
#' # impute only z_set dataset
#' ligerex <- imputeKNN(ligerex, reference = "y_set", queries = list("z_set"), knn_k = 50)
#' }
imputeKNN <- function(object, reference, queries, knn_k = 20, weight = TRUE, norm = TRUE, scale = FALSE, verbose = TRUE) {
  if (verbose) {
    cat("NOTE: This function will discard the raw data previously stored in the liger object and",
        "replace the raw.data slot with the imputed data.\n\n")
  }

  if (length(reference) > 1) {
    stop("Can only have ONE reference dataset")
  }
  if (missing(queries)) { # all datasets
    queries <- names(object@raw.data)
    queries <- as.list(queries[!queries %in% reference])
    if (verbose) {
      cat(
        "Imputing ALL the datasets except the reference dataset\n",
        "Reference dataset:\n",
        paste("  ", reference, "\n"),
        "Query datasets:\n",
        paste("  ", as.character(queries), "\n")
      )
    }
  }
  else { # only given query datasets
    queries <- as.list(queries)
    if (reference %in% queries) {
      stop("Reference dataset CANNOT be inclued in the query datasets")
    }
    else {
      if (verbose) {
        cat(
          "Imputing given query datasets\n",
          "Reference dataset:\n",
          paste("  ", reference, "\n"),
          "Query datasets:\n",
          paste("  ", as.character(queries), "\n")
        )
      }
    }
  }

  reference_cells <- colnames(object@raw.data[[reference]]) # cells by genes
  for (query in queries) {
    query_cells <- colnames(object@raw.data[[query]])

    # create a (number of reference cells x number of query cells) weight matrix holding either KNN weights or uniform weights
    nn.k <- get.knnx(object@H.norm[reference_cells, ], object@H.norm[query_cells, ], k = knn_k, algorithm = "CR")
    weights <- Matrix(0, nrow = ncol(object@raw.data[[reference]]), ncol = nrow(nn.k$nn.index), sparse = TRUE)
    if (isTRUE(weight)){ # for weighted situation
      # find nearest neighbors for query cell in normed ref datasets
      for (n in 1:nrow(nn.k$nn.index)) { # record ref-query cell-cell distances
        weights[nn.k$nn.index[n, ], n] <- exp(-nn.k$nn.dist[n, ]) / sum(exp(-nn.k$nn.dist[n, ]))
      }
    }
    else{ # for unweighted situation
      for (n in 1:nrow(nn.k$nn.index)) {
        weights[nn.k$nn.index[n, ], n] <- 1/knn_k # uniform weights: a simple mean over the k neighbors
      }
    }

    # (genes by ref cell num) multiply by the weight matrix (ref cell num by query cell num)
    imputed_vals <- object@raw.data[[reference]] %*% weights
    # assigning dimnames
    colnames(imputed_vals) <- query_cells
    rownames(imputed_vals) <- rownames(object@raw.data[[reference]])

    # formatting the matrix
    if (class(object@raw.data[[reference]])[1] %in% c("dgTMatrix", "dgCMatrix")) {
      imputed_vals <- as(imputed_vals, "dgCMatrix")
    } else {
      imputed_vals <- as.matrix(imputed_vals)
    }

    object@raw.data[[query]] <- imputed_vals
  }

  if (norm) {
    if (verbose) {
      cat('\nNormalizing data...\n')
    }
    object <- rliger::normalize(object)
  }
  if (scale) {
    if (verbose) {
      cat('Scaling (but not centering) data...')
    }
    object <- rliger::scaleNotCenter(object)
  }

  return(object)
}
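
# Hedged sketch of the per-cell weighting scheme used above: with weight = TRUE
# a query cell's k reference neighbors receive softmax weights of their
# negative distances, so closer neighbors contribute more; with weight = FALSE
# every neighbor contributes equally (a plain mean). `nn.dist` is the vector of
# distances for a single query cell, as returned by FNN::get.knnx.
.example_knn_weights <- function(nn.dist, weighted = TRUE) {
  if (weighted) {
    exp(-nn.dist) / sum(exp(-nn.dist))
  } else {
    rep(1 / length(nn.dist), length(nn.dist))
  }
}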

#' Perform Wilcoxon rank-sum test
#'
#' Perform Wilcoxon rank-sum tests on specified dataset using given method.
#'
#' @param object \code{liger} object.
#' @param data.use This selects which dataset(s) to use. (default 'all')
#' @param compare.method This indicates the mode of comparison: either 'clusters' or 'datasets'.
#'
#' @return A data.frame with 10 columns of test results.
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' ligerex <- louvainCluster(ligerex, resolution = 0.3)
#' wilcox.results <- runWilcoxon(ligerex, compare.method = "clusters")
#' wilcox.results <- runWilcoxon(ligerex, compare.method = "datasets", data.use = c(1, 2))
#' if (length(ligerex@h5file.info) > 0) {
#'     # For HDF5 based object
#'     # Need to sample cells and read into memory before running Wilcoxon test
#'     ligerex <- readSubset(ligerex, slot.use = "norm.data", max.cells = 1000)
#'     wilcox.results <- runWilcoxon(ligerex, compare.method = "clusters")
#' }
runWilcoxon <- function(object, data.use = "all", compare.method = c("clusters", "datasets")) {
  # check parameter inputs
  compare.method <- match.arg(compare.method)
  if (compare.method == "datasets") {
    if (length(names(object@norm.data)) < 2) {
      stop("Should have at least TWO inputs to compare between datasets")
    }
    if (!missing(data.use) & length(data.use) < 2) {
      stop("Should have at least TWO inputs to compare between datasets")
    }
  }

  if (class(object@raw.data[[1]])[1] == "H5File"){
    if (is.null(object@h5file.info[[1]][["sample.data.type"]])){
      message("Need to sample data before Wilcoxon test for HDF5 input.")
    } else {
      message("Running Wilcoxon test on ", object@h5file.info[[1]][["sample.data.type"]])
    }
  }

  ### create feature x sample matrix
  if (data.use[1] == "all" | length(data.use) > 1) { # at least two datasets
    if (data.use[1] == "all") {
      message("Performing Wilcoxon test on ALL datasets: ", toString(names(object@norm.data)))
      sample.list <- attributes(object@norm.data)$names
    }
    else {
      message("Performing Wilcoxon test on GIVEN datasets: ", toString(data.use))
      sample.list <- data.use
    }
    # get all shared genes of every datasets
    genes <- Reduce(intersect, lapply(sample.list, function(sample) {
      if (class(object@norm.data[[sample]])[[1]] == "dgCMatrix") {
        return(object@norm.data[[sample]]@Dimnames[[1]])
      } else {
        return(rownames(object@sample.data[[sample]]))
      }
    }))
    if (class(object@norm.data[[sample.list[1]]])[[1]] == "dgCMatrix") {
      feature_matrix <- Reduce(cbind, lapply(sample.list, function(sample) {
        object@norm.data[[sample]][genes, ]
      })) # get feature matrix, shared genes as rows and all barcodes as columns
    } else {
      feature_matrix <- Reduce(cbind, object@sample.data[sample.list])
    }
    # get labels of clusters and datasets
    cell_source <- object@cell.data[["dataset"]] # from which dataset
    names(cell_source) <- names(object@clusters)
    cell_source <- cell_source[colnames(feature_matrix), drop = TRUE]
    clusters <- object@clusters[colnames(feature_matrix), drop = TRUE] # from which cluster
  } else { # for one dataset only
    message("Performing Wilcoxon test on GIVEN dataset: ", data.use)
    if (class(object@norm.data[[data.use]])[[1]] == "dgCMatrix") {
      feature_matrix <- object@norm.data[[data.use]]
      clusters <- object@clusters[object@norm.data[[data.use]]@Dimnames[[2]], drop = TRUE] # from which cluster
    } else {
      feature_matrix <- object@sample.data[[data.use]]
      clusters <- object@clusters[colnames(object@sample.data[[data.use]]), drop = TRUE] # from which cluster
    }
  }

  ### perform wilcoxon test
  if (compare.method == "clusters") { # compare between clusters across datasets
    len <- nrow(feature_matrix)
    if (len > 100000) {
      message("Calculating Large-scale Input...")
      results <- Reduce(rbind, lapply(suppressWarnings(split(seq(len), seq(len / 100000))), function(index) {
        wilcoxauc(log1p(1e6*feature_matrix[index, ]), clusters)
      }))
    } else {
      results <- wilcoxauc(log1p(1e6*feature_matrix), clusters)
    }
  }

  if (compare.method == "datasets") { # compare between datasets within each cluster
    results <- Reduce(rbind, lapply(levels(clusters), function(cluster) {
      sub_barcodes <- names(clusters[clusters == cluster]) # every barcode within this cluster
      sub_label <- paste0(cluster, "-", cell_source[sub_barcodes]) # data source for each cell
      sub_matrix <- feature_matrix[, sub_barcodes]
      if (length(unique(cell_source[sub_barcodes])) == 1) { # if cluster has only 1 data source
        message("Note: Skip Cluster ", cluster, " since it has only ONE data source.")
        return()
      }
      return(wilcoxauc(log1p(1e6*sub_matrix), sub_label))
    }))
  }
  return(results)
}
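
# Hedged sketch: filtering runWilcoxon output for marker features. The result
# is a data.frame in wilcoxauc format, whose columns include `logFC` and
# `padj`; the thresholds below are illustrative, not package defaults.
.example_filter_markers <- function(wilcox.results, fc = 1, padj = 0.05) {
  wilcox.results[wilcox.results$padj < padj & wilcox.results$logFC > fc, ]
}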

#' Linking genes to putative regulatory elements
#'
#' Evaluate the relationships between pairs of genes and peaks based on specified distance metric.
#'
#' @param gene_counts A gene expression matrix (genes by cells) of normalized counts.
#' This matrix has to share the same column names (cell barcodes) as the matrix passed to peak_counts
#' @param peak_counts A peak-level matrix (peaks by cells) of normalized accessibility values, such as the one resulting from imputeKNN.
#' This matrix must share the same column names (cell barcodes) as the matrix passed to gene_counts.
#' @param genes.list A list of the gene symbols to be tested. If not specified,
#' this function will use all the gene symbols from the matrix passed to gene_counts by default.
#' @param alpha Significance threshold for correlation p-value. Peak-gene correlations with p-values below
#' this threshold are considered significant. The default is 0.05.
#' @param dist This indicates the type of correlation to calculate -- one of "spearman" (default), "pearson", or "kendall".
#' @param path_to_coords Path to the gene coordinates file.
#' @param verbose Print messages (TRUE by default)
#'
#' @return a sparse matrix with peak names as rows and gene symbols as columns, with each element indicating the
#' correlation between peak i and gene j (or 0 if the gene and peak are not significantly linked).
#'
#' @importFrom utils read.csv2
#' @export
#' @examples
#' \dontrun{
#' # Only runable for ATAC datasets, see tutorial on GitHub
#' # some gene counts matrix: gmat.small
#' # some peak counts matrix: pmat.small
#' regnet <- linkGenesAndPeaks(gmat.small, pmat.small, dist = "spearman",
#' alpha = 0.05, path_to_coords = 'some_path')
#' }
linkGenesAndPeaks <- function(gene_counts, peak_counts, genes.list = NULL, dist = "spearman",
                              alpha = 0.05, path_to_coords, verbose = TRUE) {
  ## check dependency
  if (!requireNamespace("GenomicRanges", quietly = TRUE)) {
    stop("Package \"GenomicRanges\" needed for this function to work. Please install it by command:\n",
         "BiocManager::install('GenomicRanges')",
         call. = FALSE
    )
  }

  if (!requireNamespace("IRanges", quietly = TRUE)) {
    stop("Package \"IRanges\" needed for this function to work. Please install it by command:\n",
         "BiocManager::install('IRanges')",
         call. = FALSE
    )
  }

  ### make Granges object for peaks
  peak.names <- strsplit(rownames(peak_counts), "[:-]")
  chrs <- Reduce(append, lapply(peak.names, function(peak) {
    peak[1]
  }))
  chrs.start <- Reduce(append, lapply(peak.names, function(peak) {
    peak[2]
  }))
  chrs.end <- Reduce(append, lapply(peak.names, function(peak) {
    peak[3]
  }))
  peaks.pos <- GenomicRanges::GRanges(
    seqnames = chrs,
    ranges = IRanges::IRanges(as.numeric(chrs.start), end = as.numeric(chrs.end))
  )

  ### make Granges object for genes
  gene.names <- read.csv2(path_to_coords, sep = "\t", header = FALSE, stringsAsFactors = FALSE)
  gene.names <- gene.names[complete.cases(gene.names), ]
  genes.coords <- GenomicRanges::GRanges(
    seqnames = gene.names$V1,
    ranges = IRanges::IRanges(as.numeric(gene.names$V2), end = as.numeric(gene.names$V3))
  )
  names(genes.coords) <- gene.names$V4

  ### construct regnet
  gene_counts <- t(gene_counts) # cells x genes
  peak_counts <- t(peak_counts) # cells x peaks

  # find overlap peaks for each gene
  if (is.null(genes.list)) {
    genes.list <- colnames(gene_counts)
  }
  missing_genes <- !genes.list %in% names(genes.coords)
  if (sum(missing_genes)!=0 && verbose){
    message("Removing ", sum(missing_genes), " genes not found in given gene coordinates...")
  }
  genes.list <- genes.list[!missing_genes]
  genes.coords <- genes.coords[genes.list]

  if (verbose) {
    message("Calculating correlation for gene-peak pairs...")
  }
  each.len <- 0

  elements <- lapply(seq(length(genes.list)), function(pos) {
    gene.use <- genes.list[pos]
    # re-scale the window for each gene
    gene.loci <- GenomicRanges::trim(suppressWarnings(GenomicRanges::promoters(GenomicRanges::resize(
      genes.coords[gene.use],
      width = 1, fix = "start"
    ),
    upstream = 500000, downstream = 500000
    )))
    peaks.use <- S4Vectors::queryHits(GenomicRanges::findOverlaps(peaks.pos, gene.loci))
    if (length(peaks.use) == 0L) { # if no peaks in window, skip this iteration
      return(list(NULL, as.numeric(each.len), NULL))
    }
    ### compute correlation and p-adj for genes and peaks ###
    res <- suppressWarnings(psych::corr.test(
      x = gene_counts[, gene.use], y = as.matrix(peak_counts[, peaks.use]),
      method = dist, adjust = "holm", ci = FALSE, use = "complete"
    ))
    pick <- res[["p"]] < alpha # filter by p-value
    pick[is.na(pick)] <- FALSE

    if (sum(pick) == 0) { # if no peaks are important, skip this iteration
      return(list(NULL, as.numeric(each.len), NULL))
    }
    else {
      res.corr <- as.numeric(res[["r"]][pick])
      peaks.use <- peaks.use[pick]
    }
    assign('each.len', each.len + length(peaks.use), envir = parent.frame(2))
    return(list(as.numeric(peaks.use), as.numeric(each.len), res.corr))
  })

  i_index <- Reduce(append, lapply(elements, function(ele) {
    ele[[1]]
  }))
  p_index <- c(0, Reduce(append, lapply(elements, function(ele) {
    ele[[2]]
  })))
  value_list <- Reduce(append, lapply(elements, function(ele) {
    ele[[3]]
  }))

  # make final sparse matrix
  regnet <- sparseMatrix(
    i = i_index, p = p_index, x = value_list,
    dims = c(ncol(peak_counts), length(genes.list)),
    dimnames = list(colnames(peak_counts), genes.list)
  )

  return(regnet)
}
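
# Hedged sketch of the window construction used above, in isolation: each gene
# is collapsed to its start position (a width-1 range) and expanded 500 kb in
# both directions, giving the ~1 Mb window in which overlapping peaks are
# tested. Requires the GenomicRanges package; `gene.coords` is a GRanges.
.example_gene_window <- function(gene.coords, up = 500000, down = 500000) {
  tss <- GenomicRanges::resize(gene.coords, width = 1, fix = "start")
  GenomicRanges::trim(suppressWarnings(
    GenomicRanges::promoters(tss, upstream = up, downstream = down)
  ))
}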


#' Export predicted gene-pair interaction
#'
#' Export the predicted gene-pair interactions calculated by upstream function 'linkGenesAndPeaks'
#'  into an Interact Track file which is compatible with UCSC Genome Browser.
#'
#' @param corr.mat A sparse matrix with peak names as rows and gene symbols as columns.
#' @param genes.list A list of the gene symbols to be tested. If not specified,
#' this function will use all the gene symbols from the matrix passed to corr.mat by default.
#' @param output_path Path in which the output file will be stored.
#' @param path_to_coords Path to the gene coordinates file.
#'
#' @return An Interact Track file stored in the specified path.
#'
#' @importFrom stats complete.cases
#' @importFrom utils write.table
#'
#' @export
#' @examples
#' \dontrun{
#' # Only runable for ATAC datasets, see tutorial on GitHub
#' # some gene-peak correlation matrix: regnet
#' makeInteractTrack(regnet, path_to_coords = 'some_path_to_gene_coordinates/hg19_genes.bed')
#' }
makeInteractTrack <- function(corr.mat, genes.list, output_path, path_to_coords) {
  # get genomic coordinates
  if (missing(path_to_coords)) {
    stop("Parameter 'path_to_coords' cannot be empty.")
  }

  ### make Granges object for genes
  genes.coords <- read.csv2(path_to_coords,
                            sep = "\t", header = FALSE, colClasses =
                              c("character", "integer", "integer", "character", "NULL", "NULL")
  )
  genes.coords <- genes.coords[complete.cases(genes.coords$V4), ]
  rownames(genes.coords) <- genes.coords[, 4]

  # split peak names into chrom and coordinates
  peak.names <- strsplit(rownames(corr.mat), "[:-]")
  chrs <- Reduce(append, lapply(peak.names, function(peak) {
    peak[1]
  }))
  chrs.start <- as.numeric(Reduce(append, lapply(peak.names, function(peak) {
    peak[2]
  })))
  chrs.end <- as.numeric(Reduce(append, lapply(peak.names, function(peak) {
    peak[3]
  })))

  # check genes.list
  if (missing(genes.list)) {
    genes.list <- colnames(corr.mat)
  }

  # check output_path
  if (missing(output_path)) {
    output_path <- getwd()
  }

  output_path <- paste0(output_path, "/Interact_Track.bed")
  track.doc <- paste0('track type=interact name="Interaction Track" description="Gene-Peaks Links"',
                      ' interactDirectional=true maxHeightPixels=200:100:50 visibility=full')
  write(track.doc, file = output_path)

  genes_not_existed <- 0
  filtered_genes <- 0

  for (gene in genes.list) {
    if (!gene %in% colnames(corr.mat)) { # if gene not in the corr.mat
      genes_not_existed <- genes_not_existed + 1
      next
    }
    peaks.sel <- which(corr.mat[, gene] != 0)
    if (length(peaks.sel) == 0) {
      filtered_genes <- filtered_genes + 1
      next
    }

    track <- data.frame(
      chrom = chrs[peaks.sel],
      chromStart = chrs.start[peaks.sel],
      chromEnd = chrs.end[peaks.sel],
      name = paste0(gene, "/", rownames(corr.mat)[peaks.sel]),
      score = 0,
      value = as.numeric(corr.mat[peaks.sel, gene]),
      exp = ".",
      color = 5,
      sourceChrom = chrs[peaks.sel],
      sourceStart = chrs.start[peaks.sel],
      sourceEnd = chrs.start[peaks.sel] + 1,
      sourceName = ".",
      sourceStrand = ".",
      targetChrom = genes.coords[gene, 1],
      targetStart = genes.coords[gene, 2],
      targetEnd = genes.coords[gene, 2] + 1,
      targetName = gene,
      targetStrand = "."
    )
    write.table(track,
                file = output_path, append = TRUE,
                quote = FALSE, sep = "\t", eol = "\n", na = "NA", dec = ".",
                row.names = FALSE, col.names = FALSE, qmethod = c("escape", "double"),
                fileEncoding = ""
    )
  }

  message("A total of ", genes_not_existed, " genes do not exist in input matrix.")
  message("A total of ", filtered_genes, " genes do not have significant correlated peaks.")
  message("The Interaction Track is stored in Path: ", output_path)
}
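
# Hedged sketch: reading the exported track back for a quick sanity check. The
# first line is the UCSC track definition, so it is skipped; column 6 holds
# the peak-gene correlation written as `value` above.
.example_read_interact_track <- function(output_dir) {
  utils::read.table(file.path(output_dir, "Interact_Track.bed"),
                    sep = "\t", skip = 1, stringsAsFactors = FALSE)
}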


#' Analyze biological interpretations of metagene
#'
#' Identify the biological pathways (gene sets from Reactome) that each metagene (factor) might belong to.
#'
#' @param object \code{liger} object.
#' @param gene_sets A list of the Reactome gene sets names to be tested. If not specified,
#' this function will use all the gene sets from the Reactome by default
#' @param mat_w This indicates whether to use the shared factor loadings 'W' (default TRUE)
#' @param mat_v This indicates which V matrices to add to the analysis. It can be a single
#' dataset index or a list of indices. (default 0, i.e. none)
#' @param custom_gene_sets A named list of character vectors of entrez gene ids. If not specified,
#' this function will test the Reactome gene sets by default.
#'
#' @return A list of matrices with GSEA analysis for each factor
#'
#' @importFrom methods .hasSlot
#'
#' @export
#' @examples
#' \donttest{
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' # Specification for minimal example run time, not converging
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' result <- runGSEA(ligerex)
#' }
runGSEA <- function(object, gene_sets = c(), mat_w = TRUE, mat_v = 0, custom_gene_sets = c()) {
  if (!requireNamespace("org.Hs.eg.db", quietly = TRUE)) {
    stop("Package \"org.Hs.eg.db\" needed for this function to work. Please install it by command:\n",
         "BiocManager::install('org.Hs.eg.db')",
         call. = FALSE
    )
  }

  if (!requireNamespace("reactome.db", quietly = TRUE)) {
    stop("Package \"reactome.db\" needed for this function to work. Please install it by command:\n",
         "BiocManager::install('reactome.db')",
         call. = FALSE
    )
  }

  if (!requireNamespace("fgsea", quietly = TRUE)) {
    stop("Package \"fgsea\" needed for this function to work. Please install it by command:\n",
         "BiocManager::install('fgsea')",
         call. = FALSE
    )
  }

  if (length(mat_v) > length(object@V)) {
    stop("The gene loading input is invalid.", call. = FALSE)
  }

  if (!.hasSlot(object, "W") || !.hasSlot(object, "V")) {
    stop("There is no W or V matrix. Please do iNMF first.", call. = FALSE)
  }

  if (mat_w) {
    gene_loadings <- object@W
    if (mat_v) {
      gene_loadings <- gene_loadings + Reduce("+", lapply(mat_v, function(v) {
        object@V[[v]]
      }))
    }
  } else {
    gene_loadings <- Reduce("+", lapply(mat_v, function(v) {
      object@V[[v]]
    }))
  }

  gene_ranks <- t(apply(gene_loadings, MARGIN = 1, function(x) {
    rank(x)
  }))

  colnames(gene_ranks) <- sapply(colnames(gene_ranks), toupper)
  gene_id <- as.character(AnnotationDbi::mapIds(org.Hs.eg.db::org.Hs.eg.db, colnames(gene_ranks), "ENTREZID", "SYMBOL"))
  colnames(gene_ranks) <- gene_id
  gene_ranks <- gene_ranks[, !is.na(colnames(gene_ranks))]
  if (inherits((custom_gene_sets)[1], "tbl_df")) {
    pathways <- split(custom_gene_sets, x = custom_gene_sets$entrez_gene, f = custom_gene_sets$gs_name)
    pathways <- lapply(pathways, function(x) {
      as.character(x)
    })
  } else if (length(custom_gene_sets)) {
    pathways <- custom_gene_sets
  } else {
    pathways <- fgsea::reactomePathways(colnames(gene_ranks))
    if (length(gene_sets)) {
      pathways <- pathways[intersect(gene_sets, names(pathways))]
    }
  }
  # gsea <- list()
  gsea <- apply(gene_ranks, MARGIN = 1, function(x) {
    fgsea::fgsea(pathways, x, minSize = 15, maxSize = 500, nperm = 10000)
  })
  gsea <- lapply(gsea, function(x) {
    as.matrix(x[order(x$pval), ])
  })
  return(gsea)
}
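
# Hedged sketch: pulling the top pathways for a single factor from the list
# returned by runGSEA. Each list element is a matrix ordered by p-value whose
# columns include `pathway`, `pval` and `padj` (fgsea output).
.example_top_pathways <- function(gsea, factor.use = 1, n = 5) {
  utils::head(gsea[[factor.use]][, c("pathway", "pval", "padj")], n)
}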


#######################################################################################
#### Dimensionality Reduction

#' Perform t-SNE dimensionality reduction
#'
#' Runs t-SNE on the normalized cell factors (or raw cell factors) to generate a 2D embedding for
#' visualization. Has option to run on subset of factors. Note that running multiple times will
#' reset tsne.coords values.
#'
#' In order to run fftRtsne (recommended for large datasets), you must first install FIt-SNE as
#' detailed \href{https://github.com/KlugerLab/FIt-SNE}{here}. Include the path to the cloned
#' FIt-SNE directory as the fitsne.path parameter, though this is only necessary for the first call
#' to runTSNE. For more detailed FIt-SNE installation instructions, see the liger repo README.
#'
#' @param object \code{liger} object. Should run quantile_norm before calling with defaults.
#' @param use.raw Whether to use un-aligned cell factor loadings (H matrices) (default FALSE).
#' @param dims.use Factors to use for computing tSNE embedding (default 1:ncol(H.norm)).
#' @param use.pca Whether to perform initial PCA step for Rtsne (default FALSE).
#' @param perplexity Parameter to pass to Rtsne (expected number of neighbors) (default 30).
#' @param theta Speed/accuracy trade-off (increase for less accuracy), set to 0.0 for exact TSNE
#'   (default 0.5).
#' @param method Supports two methods for estimating tSNE values: Rtsne (Barnes-Hut implementation
#'   of t-SNE) and fftRtsne (FFT-accelerated Interpolation-based t-SNE) (using Kluger Lab
#'   implementation). (default Rtsne)
#' @param fitsne.path Path to the cloned FIt-SNE directory (ie. '/path/to/dir/FIt-SNE') (required
#'   for using fftRtsne -- only first time runTSNE is called) (default NULL).
#' @param rand.seed Random seed for reproducibility (default 42).
#'
#' @return \code{liger} object with tsne.coords slot set.
#'
#' @importFrom Rtsne Rtsne
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' # Specification for minimal example run time, not converging
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' ligerex <- runTSNE(ligerex)
runTSNE <- function(object, use.raw = FALSE, dims.use = 1:ncol(object@H.norm), use.pca = FALSE,
                    perplexity = 30, theta = 0.5, method = "Rtsne", fitsne.path = NULL,
                    rand.seed = 42) {
  if (use.raw) {
    data.use <- do.call(rbind, object@H)
    if (identical(dims.use, 1:0)) {
      dims.use <- 1:ncol(data.use)
    }
  } else {
    data.use <- object@H.norm
  }
  if (method == "Rtsne") {
    set.seed(rand.seed)
    object@tsne.coords <- Rtsne(data.use[, dims.use],
                                pca = use.pca, check_duplicates = FALSE,
                                theta = theta, perplexity = perplexity
    )$Y
  } else if (method == "fftRtsne") {
    object@tsne.coords <- fftRtsne(
      X = data.use[, dims.use], rand_seed = rand.seed,
      fast_tsne_path = fitsne.path, theta = theta,
      perplexity = perplexity
    )
  } else {
    stop("Invalid method: Please choose Rtsne or fftRtsne")
  }
  rownames(object@tsne.coords) <- rownames(data.use)
  return(object)
}

#' Perform UMAP dimensionality reduction
#'
#' @description
#' Run UMAP on the normalized cell factors (or raw cell factors) to generate a 2D embedding for
#' visualization (or general dimensionality reduction). Has option to run on subset of factors.
#' Note that running multiple times will overwrite tsne.coords values. It is generally
#' recommended to use this method for dimensionality reduction with extremely large datasets.
#'
#' Note that this method requires that the package uwot is installed. It does not depend
#' on reticulate or python umap-learn.
#'
#' @param object \code{liger} object. Should run quantile_norm before calling with defaults.
#' @param use.raw Whether to use un-aligned cell factor loadings (H matrices) (default FALSE).
#' @param dims.use Factors to use for computing UMAP embedding (default 1:ncol(H.norm)).
#' @param k Number of dimensions to reduce to (default 2).
#' @param distance Metric used to measure distance in the input space.
#'   (default "euclidean", alternatives: "cosine", "manhattan", "hamming")
#' @param n_neighbors Number of neighboring points used in local approximations of manifold
#'   structure. Larger values will result in more global structure being preserved at the loss of
#'   detailed local structure. In general this parameter should often be in the range 5 to 50, with
#'   a choice of 10 to 15 being a sensible default. (default 10)
#' @param min_dist Controls how tightly the embedding is allowed to compress points together. Larger
#'   values ensure embedded points are more evenly distributed, while smaller values allow the
#'   algorithm to optimise more accurately with regard to local structure. Sensible values are in
#'   the range 0.001 to 0.5, with 0.1 being a reasonable default. (default 0.1)
#' @param rand.seed Random seed for reproducibility (default 42).
#'
#' @return \code{liger} object with tsne.coords slot set.
#'
#' @export
#' @examples
#' \donttest{
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' # Specification for minimal example run time, not converging
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' if (packageVersion("Matrix") <= package_version("1.6.1.1")) {
#'   ligerex <- runUMAP(ligerex)
#' }
#' }
runUMAP <- function(object, use.raw = FALSE, dims.use = 1:ncol(object@H.norm), k = 2,
                    distance = "euclidean", n_neighbors = 10, min_dist = 0.1, rand.seed = 42) {
  set.seed(rand.seed)
  if (use.raw) {
    raw.data <- do.call(rbind, object@H)
    # if H.norm not set yet
    if (identical(dims.use, 1:0)) {
      dims.use <- 1:ncol(raw.data)
    }
    object@tsne.coords <- uwot::umap(raw.data[, dims.use],
                                     n_components = as.integer(k), metric = distance,
                                     n_neighbors = as.integer(n_neighbors), min_dist = min_dist
    )
    rownames(object@tsne.coords) <- rownames(raw.data)
  } else {
    object@tsne.coords <- uwot::umap(object@H.norm[, dims.use],
                                     n_components = as.integer(k), metric = distance,
                                     n_neighbors = as.integer(n_neighbors), min_dist = min_dist
    )
    rownames(object@tsne.coords) <- rownames(object@H.norm)
  }
  return(object)
}


#######################################################################################
#### Metrics

#' Calculate a dataset-specificity score for each factor
#'
#' This score represents the relative magnitude of the dataset-specific components of each factor's
#' gene loadings compared to the shared components for two datasets. First, for each dataset we
#' calculate the norm of the sum of each factor's shared loadings (W) and dataset-specific loadings
#' (V). We then take the ratio of these two norms and subtract it from 1, reporting the result
#' as a percentage: 100 * (1 - norm1 / norm2).
#'
#' @param object \code{liger} object. Should run optimizeALS before calling.
#' @param dataset1 Name of first dataset (by default takes first two datasets for dataset1 and 2)
#' @param dataset2 Name of second dataset
#' @param do.plot Display barplot of dataset specificity scores (by factor) (default TRUE).
#'
#' @return List containing three elements. First two elements are the norm of each metagene factor
#' for each dataset. Last element is the vector of dataset specificity scores.
#'
#' @importFrom graphics barplot
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' # Specification for minimal example run time, not converging
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' calcDatasetSpecificity(ligerex)
calcDatasetSpecificity <- function(object, dataset1 = NULL, dataset2 = NULL, do.plot = TRUE) {
  if (is.null(dataset1) || is.null(dataset2)) {
    dataset1 <- names(object@H)[1]
    dataset2 <- names(object@H)[2]
  }
  k <- ncol(object@H[[1]])
  pct1 <- rep(0, k)
  pct2 <- rep(0, k)
  for (i in 1:k) {
    pct1[i] <- norm(as.matrix(object@V[[dataset1]][i, ] + object@W[i, ]), "F")
    # norm(object@H[[1]][,i] %*% t(object@W[i,] + object@V[[1]][i,]),"F")
    pct2[i] <- norm(as.matrix(object@V[[dataset2]][i, ] + object@W[i, ]), "F")
    # norm(object@H[[2]][,i] %*% t(object@W[i,] + object@V[[2]][i,]),"F")
  }
  # pct1 = pct1/sum(pct1)
  # pct2 = pct2/sum(pct2)
  if (do.plot) {
    graphics::barplot(100 * (1 - (pct1 / pct2)),
                      xlab = "Factor",
                      ylab = "Percent Specificity", main = "Dataset Specificity of Factors",
                      names.arg = 1:k, cex.names = 0.75, mgp = c(2, 0.5, 0)
    ) # or possibly abs(pct1-pct2)
  }
  return(list(pct1, pct2, 100 * (1 - (pct1 / pct2))))
}
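
# Worked sketch of the score above on toy loading vectors rather than a real
# `liger` object (kept as a comment so nothing runs at package load):
#   w  <- c(1.0, 0.5, 0.2)                # shared loadings for one factor
#   v1 <- c(0.1, 0.1, 0.0)                # dataset1-specific loadings
#   v2 <- c(0.8, 0.6, 0.4)                # dataset2-specific loadings
#   pct1 <- norm(as.matrix(v1 + w), "F")  # ~1.27
#   pct2 <- norm(as.matrix(v2 + w), "F")  # ~2.19
#   100 * (1 - pct1 / pct2)               # ~42.1; larger when dataset2 loads more heavily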

#' Calculate agreement metric
#'
#' @description
#' This metric quantifies how much the factorization and alignment distorts the geometry of the
#' original datasets. The greater the agreement, the less distortion of geometry there is. This is
#' calculated by performing dimensionality reduction on the original and quantile aligned (or just
#' factorized) datasets, and measuring similarity between the k nearest neighbors for each cell in
#' original and aligned datasets. The Jaccard index is used to quantify similarity, and the final
#' metric is the Jaccard index averaged across all cells.
#'
#' Note that for most datasets, the greater the chosen k, the greater the agreement in general.
#' There are several options for dimensionality reduction, with the default being 'NMF' as it is
#' expected to be most similar to iNMF. Although agreement can theoretically approach 1, in practice
#' it is usually no higher than 0.2-0.3 (particularly for non-deterministic approaches like NMF).
#'
#' @param object \code{liger} object. Should call quantile_norm before calling.
#' @param dr.method Dimensionality reduction method to use for assessing pre-alignment geometry
#'   (either "PCA", "NMF", or "ICA"). (default "NMF")
#' @param ndims Number of dimensions to use in dimensionality reduction (recommended to use the
#'   same as number of factors) (default 40).
#' @param k Number of nearest neighbors to use in calculating Jaccard index (default 15).
#' @param use.aligned Whether to use quantile aligned or unaligned cell factor loadings (default
#'   TRUE).
#' @param rand.seed Random seed for reproducibility (default 42).
#' @param by.dataset Return agreement calculated for each dataset (default FALSE).
#'
#' @return Agreement metric (or vector of agreement per dataset).
#'
#' @importFrom FNN get.knn
#' @importFrom ica icafast
#' @importFrom irlba prcomp_irlba
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' # Specification for minimal example run time, not converging
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' agreement <- calcAgreement(ligerex)
calcAgreement <- function(object, dr.method = "NMF", ndims = 40, k = 15, use.aligned = TRUE,
                          rand.seed = 42, by.dataset = FALSE) {
  # if (!requireNamespace("NNLM", quietly = TRUE) & dr.method == "NMF") {
  #   stop("Package \"NNLM\" needed for this function to perform NMF. Please install it.",
  #        call. = FALSE
  #   )
  # }
  if (class(object@raw.data[[1]])[1] == "H5File") {
    if (object@h5file.info[[1]][["sample.data.type"]] != "scale.data"){
      stop("HDF5-based Liger object requires sampled scale.data for calculating agreement.")
    }
  }

  message("Reducing dimensionality using ", dr.method)
  set.seed(rand.seed)
  dr <- list()
  if (dr.method == "NMF") {
    if (class(object@raw.data[[1]])[1] == "H5File") {
      dr <- lapply(object@sample.data, function(x) {
        nmf_hals(t(x), k = ndims)[[1]]
      })
    } else {
      dr <- lapply(object@scale.data, function(x) {
        nmf_hals(x, k = ndims)[[1]]
      })
    }
  }
  else if (dr.method == "ICA") {
    if (class(object@raw.data[[1]])[1] == "H5File") {
      dr <- list()
      for (i in 1:length(object@H)){
        dr[[i]] <- icafast(t(object@sample.data[[i]]), nc = ndims)$S
      }

    } else {
      dr <- lapply(object@scale.data, function(x) {
        icafast(x, nc = ndims)$S
      })
    }
  } else {
    if (class(object@raw.data[[1]])[1] == "H5File") {
      dr <- list()
      for (i in 1:length(object@H)){
        dr[[i]] <- suppressWarnings(prcomp_irlba(object@sample.data[[i]],
                                                 n = ndims,
                                                 scale. = (colSums(t(object@sample.data[[i]])) > 0), center = FALSE
        )$rotation)
        rownames(dr[[i]]) <- colnames(object@sample.data[[i]])
      }

    } else {
      dr <- lapply(object@scale.data, function(x) {
        suppressWarnings(prcomp_irlba(t(x),
                                      n = ndims,
                                      scale. = (colSums(x) > 0), center = FALSE
        )$rotation)
      })
      for (i in 1:length(dr)) {
        rownames(dr[[i]]) <- rownames(object@scale.data[[i]])
      }
    }
  }
  ns <- sapply(dr, nrow)
  n <- sum(ns)
  jaccard_inds <- c()
  distorts <- c()

  for (i in 1:length(dr)) {
    jaccard_inds_i <- c()
    if (use.aligned) {
      original <- object@H.norm[rownames(dr[[i]]), ]
    } else {
      original <- object@H[[i]]
    }
    fnn.1 <- get.knn(dr[[i]], k = k)
    fnn.2 <- get.knn(original, k = k)
    jaccard_inds_i <- c(jaccard_inds_i, sapply(1:ns[i], function(j) {
      intersect <- intersect(fnn.1$nn.index[j, ], fnn.2$nn.index[j, ])
      union <- union(fnn.1$nn.index[j, ], fnn.2$nn.index[j, ])
      length(intersect) / length(union)
    }))
    jaccard_inds_i <- jaccard_inds_i[is.finite(jaccard_inds_i)]
    jaccard_inds <- c(jaccard_inds, jaccard_inds_i)

    distorts <- c(distorts, mean(jaccard_inds_i))
  }
  if (by.dataset) {
    return(distorts)
  }
  return(mean(jaccard_inds))
}
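
# Minimal sketch of the per-cell Jaccard step used above, on hypothetical
# neighbor index vectors (kept as a comment so nothing runs at package load):
#   nn_original <- c(2, 5, 7, 9)              # k = 4 neighbors before alignment
#   nn_aligned  <- c(2, 5, 8, 10)             # k = 4 neighbors after alignment
#   length(intersect(nn_original, nn_aligned)) /
#     length(union(nn_original, nn_aligned))  # Jaccard index = 2/6, ~0.33
# calcAgreement() averages this quantity over all cells (or within each dataset
# when by.dataset = TRUE).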

#' Calculate alignment metric
#'
#' This metric quantifies how well-aligned two or more datasets are. Alignment is defined as in the
#' documentation for Seurat. We randomly downsample all datasets to have as many cells as the
#' smallest one. We construct a nearest-neighbor graph and calculate for each cell how many of its
#' neighbors are from the same dataset. We average across all cells and compare to the expected
#' value for perfectly mixed datasets, and scale the value from 0 to 1. Note that in practice,
#' alignment can occasionally exceed 1 (when cells have fewer same-dataset neighbors than expected
#' by chance).
#'
#' @param object \code{liger} object. Should call quantile_norm before calling.
#' @param k Number of nearest neighbors to use in calculating alignment. By default, this will be
#'   floor(0.01 * total number of cells), with a lower bound of 10 in all cases except where the
#'   total number of sampled cells is less than 10.
#' @param rand.seed Random seed for reproducibility (default 1).
#' @param cells.use Vector of cells across all datasets to use in calculating alignment
#' @param cells.comp Vector of cells across all datasets to compare to cells.use when calculating
#'   alignment (instead of dataset designations). These can be from the same dataset as cells.use.
#'   (default NULL)
#' @param clusters.use Names of clusters to use in calculating alignment (default NULL).
#' @param by.cell Return alignment calculated individually for each cell (default FALSE).
#' @param by.dataset Return alignment calculated for each dataset (default FALSE).
#'
#' @return Alignment metric.
#'
#' @importFrom FNN get.knn
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' # Specification for minimal example run time, not converging
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' alignment <- calcAlignment(ligerex)
calcAlignment <- function(object, k = NULL, rand.seed = 1, cells.use = NULL, cells.comp = NULL,
                          clusters.use = NULL, by.cell = FALSE, by.dataset = FALSE) {
  if (is.null(cells.use)) {
    cells.use <- rownames(object@H.norm)
  }
  if (!is.null(clusters.use)) {
    cells.use <- names(object@clusters)[which(object@clusters %in% clusters.use)]
  }
  if (!is.null(cells.comp)) {
    nmf_factors <- object@H.norm[c(cells.use, cells.comp), ]
    num_cells <- length(c(cells.use, cells.comp))
    func_H <- list(cells1 = nmf_factors[cells.use, ],
                   cells2 = nmf_factors[cells.comp, ])
    message('Using designated sets cells.use and cells.comp as subsets to compare')
  } else {
    nmf_factors <- object@H.norm[cells.use, ]
    num_cells <- length(cells.use)
    func_H <- lapply(seq_along(object@H), function(x) {
      cells.overlap <- intersect(cells.use, rownames(object@H[[x]]))
      if (length(cells.overlap) > 0) {
        object@H[[x]][cells.overlap, ]
      } else {
        warning(paste0("Selected subset eliminates dataset ", names(object@H)[x]),
                immediate. = TRUE
        )
        return(NULL)
      }
    })
    func_H <- func_H[!sapply(func_H, is.null)]
  }
  num_factors <- ncol(object@H.norm)
  N <- length(func_H)
  if (N == 1) {
    warning("Alignment null for single dataset", immediate. = TRUE)
  }
  set.seed(rand.seed)
  min_cells <- min(sapply(func_H, function(x) {
    nrow(x)
  }))
  sampled_cells <- unlist(lapply(1:N, function(x) {
    sample(rownames(func_H[[x]]), min_cells)
  }))
  max_k <- length(sampled_cells) - 1
  if (is.null(k)) {
    k <- min(max(floor(0.01 * num_cells), 10), max_k)
  } else if (k > max_k) {
    stop(paste0("Please select k <=", max_k))
  }
  knn_graph <- get.knn(nmf_factors[sampled_cells, 1:num_factors], k)
  # Generate new "datasets" for desired cell groups
  if (!is.null(cells.comp)) {
    dataset <- unlist(sapply(1:N, function(x) {
      rep(paste0('group', x), nrow(func_H[[x]]))
    }))
  } else {
    dataset <- unlist(sapply(1:N, function(x) {
      rep(names(object@H)[x], nrow(func_H[[x]]))
    }))
  }
  names(dataset) <- rownames(nmf_factors)
  dataset <- dataset[sampled_cells]

  num_sampled <- N * min_cells
  num_same_dataset <- rep(k, num_sampled)

  alignment_per_cell <- c()
  for (i in 1:num_sampled) {
    inds <- knn_graph$nn.index[i, ]
    num_same_dataset[i] <- sum(dataset[inds] == dataset[i])
    alignment_per_cell[i] <- 1 - (num_same_dataset[i] - (k / N)) / (k - k / N)
  }
  if (by.dataset) {
    alignments <- c()
    for (i in 1:N) {
      start <- 1 + (i - 1) * min_cells
      end <- i * min_cells
      alignment <- mean(alignment_per_cell[start:end])
      alignments <- c(alignments, alignment)
    }
    return(alignments)
  } else if (by.cell) {
    names(alignment_per_cell) <- sampled_cells
    return(alignment_per_cell)
  }
  return(mean(alignment_per_cell))
}
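
# The per-cell statistic above rescales the number of same-dataset neighbors x
# so that perfect mixing (x = k/N on average) gives 1 and complete separation
# (x = k) gives 0. Toy numeric check (kept as a comment, not executed):
#   k <- 20; N <- 2; x <- 12       # 12 of 20 neighbors come from the same dataset
#   1 - (x - k / N) / (k - k / N)  # = 1 - (12 - 10) / 10 = 0.8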

#' Calculate alignment for each cluster
#'
#' Returns alignment for each cluster in the analysis (see documentation for calcAlignment).
#'
#' @param object \code{liger} object. Should call quantile_norm before calling.
#' @param rand.seed Random seed for reproducibility (default 1).
#' @param k Number of nearest neighbors in calculating alignment (see calcAlignment for default).
#'   Can pass in single value or vector with same length as number of clusters.
#' @param by.dataset Return alignment calculated for each dataset in cluster (default FALSE).
#'
#' @return Vector of alignment statistics (with names of clusters).
#'
#' @importFrom FNN get.knn
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' # Specification for minimal example run time, not converging
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' alignments <- calcAlignmentPerCluster(ligerex)
calcAlignmentPerCluster <- function(object, rand.seed = 1, k = NULL, by.dataset = FALSE) {
  clusters <- levels(object@clusters)
  if (typeof(k) == "double") {
    if (length(k) == 1) {
      k <- rep(k, length(clusters))
    } else if (length(k) != length(clusters)) {
      stop("Length of k does not match length of clusters")
    }
  }
  align_metrics <- sapply(seq_along(clusters), function(x) {
    calcAlignment(object,
                  k = k[x], rand.seed = rand.seed,
                  clusters.use = clusters[x],
                  by.dataset = by.dataset
    )
  })
  if (by.dataset) {
    colnames(align_metrics) <- levels(object@clusters)
    rownames(align_metrics) <- names(object@H)
  } else {
    names(align_metrics) <- levels(object@clusters)
  }
  return(align_metrics)
}

#' Calculate adjusted Rand index
#'
#' Computes the adjusted Rand index (ARI) for a \code{liger} clustering and an external clustering.
#' The ARI equals 1 when the two clusterings agree perfectly and has an expected value of 0 for
#' random labelings (small negative values are possible).
#'
#' @param object \code{liger} object. Should run quantileAlignSNF before calling.
#' @param clusters.compare Clustering with which to compare (named vector).
#' @param verbose Print messages (TRUE by default)
#'
#' @return Adjusted Rand index value.
#'
#' @importFrom mclust adjustedRandIndex
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' # Specification for minimal example run time, not converging
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' ari <- calcARI(ligerex, ligerex@clusters)
calcARI <- function(object, clusters.compare, verbose = TRUE) {
  if (length(clusters.compare) < length(object@clusters) && verbose) {
    message("Calculating ARI for subset of all cells")
  }
  return(adjustedRandIndex(object@clusters[names(clusters.compare)],
                           clusters.compare))
}
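
# Toy check of the metric above (kept as a comment, not executed): identical
# partitions yield ARI = 1 regardless of label names, while random labelings
# are expected to score near 0.
#   mclust::adjustedRandIndex(c(1, 1, 2, 2), c("a", "a", "b", "b"))  # = 1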

#' Calculate purity
#'
#' Calculates purity for \code{liger} clustering and external clustering (true clusters/classes).
#' Purity can sometimes be a more useful metric when the clustering to be tested contains more
#' subgroups or clusters than the true clusters (or classes). Purity also ranges from 0 to 1,
#' with a score of 1 representing a pure, or accurate, clustering.
#'
#' @param object \code{liger} object. Should run quantileAlignSNF before calling.
#' @param classes.compare Clustering with which to compare (named vector).
#' @param verbose Print messages (TRUE by default)
#'
#' @return Purity value.
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' # Specification for minimal example run time, not converging
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' purity <- calcPurity(ligerex, ligerex@clusters)
calcPurity <- function(object, classes.compare, verbose = TRUE) {
  if (length(classes.compare) < length(object@clusters) && verbose) {
    print("Calculating purity for subset of full cells")
  }
  clusters <- object@clusters[names(classes.compare)]
  purity <- sum(apply(table(classes.compare, clusters), 2, max)) / length(clusters)

  return(purity)
}
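
# Toy illustration of the purity computation above (kept as a comment, not
# executed): each cluster contributes the count of its most common class, and
# purity is the sum of these counts divided by the number of cells.
#   classes  <- factor(c("A", "A", "B", "B", "B"))
#   clusters <- factor(c(1, 1, 1, 2, 2))
#   sum(apply(table(classes, clusters), 2, max)) / length(clusters)  # = 4/5 = 0.8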

#' Calculate proportion mitochondrial contribution
#'
#' Calculates proportion of mitochondrial contribution based on raw or normalized data.
#'
#' @param object \code{liger} object.
#' @param use.norm Whether to use cell normalized data in calculating contribution (default FALSE).
#' @param mito.pattern Regex pattern for identifying mitochondrial genes. The default "^mt-" is
#' typical for mouse gene symbols; use "^MT-" for human.
#' @return Named vector containing proportion of mitochondrial contribution for each cell.
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' # Expect a warning because the test data does not contain mito genes
#' ligerex@cell.data$mito <- getProportionMito(ligerex, mito.pattern = "^MT-")
getProportionMito <- function(object, use.norm = FALSE, mito.pattern = "^mt-") {
  all.genes <- Reduce(union, lapply(object@raw.data, rownames))
  mito.genes <- grep(pattern = mito.pattern, x = all.genes, value = TRUE)
  if (length(mito.genes) == 0) {
    warning("No mito genes identified with pattern \"", mito.pattern, "\". ")
  }
  data.use <- object@raw.data
  if (use.norm) {
    data.use <- object@norm.data
  }
  percent_mito <- unlist(lapply(unname(data.use), function(x) {
    colSums(x[mito.genes, ]) / colSums(x)
  }), use.names = TRUE)

  return(percent_mito)
}
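
# Toy check of the proportion computed above (kept as a comment, not executed):
# two genes, one matching the mito pattern, across three cells.
#   counts <- matrix(c(1, 9, 2, 8, 5, 5), nrow = 2,
#                    dimnames = list(c("mt-Nd1", "Actb"), paste0("cell", 1:3)))
#   mito <- grep("^mt-", rownames(counts), value = TRUE)
#   colSums(counts[mito, , drop = FALSE]) / colSums(counts)  # 0.1, 0.2, 0.5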

#######################################################################################
#### Visualization

#' Plot t-SNE coordinates of cells across datasets
#'
#' Generates two plots of all cells across datasets, one colored by dataset and one colored by
#' cluster. These are useful for visually examining the alignment and cluster distributions,
#' respectively. If clusters have not been set yet (quantileAlignSNF not called), the second plot
#' will use a single color. It is also possible to pass in another clustering (as long as its
#' names match those of the cells).
#'
#' @param object \code{liger} object. Should call runTSNE or runUMAP before calling.
#' @param clusters Another clustering to use for coloring second plot (must have same names as
#'   clusters slot) (default NULL).
#' @param title Plot titles (list or vector of length 2) (default NULL).
#' @param pt.size Controls size of points representing cells (default 0.3).
#' @param text.size Controls size of plot text (cluster center labels) (default 3).
#' @param do.shuffle Randomly shuffle points so that points from same dataset are not plotted
#'   one after the other (default TRUE).
#' @param rand.seed Random seed for reproducibility of point shuffling (default 1).
#' @param axis.labels Vector of two strings to use as x and y labels respectively.
#' @param do.legend Display legend on plots (default TRUE).
#' @param legend.size Size of legend on plots (default 5).
#' @param reorder.idents Logical indicating whether to reorder the datasets from the default order
#'   before plotting (default FALSE).
#' @param new.order New dataset factor order for plotting; requires reorder.idents = TRUE.
#' @param return.plots Return ggplot plot objects instead of printing directly (default FALSE).
#' @param legend.fonts.size Controls the font size of the legend.
#' @param raster Whether to rasterize points (default NULL). If NULL, points are automatically
#'   rasterized when there are over 100,000 cells to plot.
#'
#' @return List of ggplot plot objects (only if return.plots TRUE, otherwise prints plots to
#'   console).
#'
#' @importFrom ggplot2 ggplot geom_point geom_text ggtitle guides guide_legend aes theme xlab ylab
#' @importFrom dplyr %>% group_by summarize
#' @importFrom scattermore geom_scattermore
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' ligerex <- runTSNE(ligerex)
#' ligerex <- louvainCluster(ligerex)
#' plotByDatasetAndCluster(ligerex, pt.size = 1)
plotByDatasetAndCluster <- function(object, clusters = NULL, title = NULL, pt.size = 0.3,
                                    text.size = 3, do.shuffle = TRUE, rand.seed = 1,
                                    axis.labels = NULL, do.legend = TRUE, legend.size = 5,
                                    reorder.idents = FALSE, new.order = NULL,
                                    return.plots = FALSE, legend.fonts.size = 12, raster = NULL) {
  # check raster and set by number of cells total if NULL
  if (is.null(x = raster)) {
    if (nrow(x = object@cell.data) > 1e5) {
      raster <- TRUE
      message("NOTE: Points are rasterized as number of cells/nuclei plotted exceeds 100,000.
              \n To plot in vector form set `raster = FALSE`.")
    } else {
      raster <- FALSE
    }
  }

  tsne_df <- data.frame(object@tsne.coords)
  colnames(tsne_df) <- c("Dim1", "Dim2")
  tsne_df[['Dataset']] <- unlist(lapply(1:length(object@H), function(x) {
    rep(names(object@H)[x], nrow(object@H[[x]]))
  }))
  if (isTRUE(reorder.idents)){
    tsne_df$Dataset <- factor(tsne_df$Dataset, levels = new.order)
  }
  c_names <- names(object@clusters)
  if (is.null(clusters)) {
    # if clusters have not been set yet
    if (length(object@clusters) == 0) {
      clusters <- rep(1, nrow(object@tsne.coords))
      names(clusters) <- c_names <- rownames(object@tsne.coords)
    } else {
      clusters <- object@clusters
      c_names <- names(object@clusters)
    }
  }
  tsne_df[['Cluster']] <- clusters[c_names]
  if (do.shuffle) {
    set.seed(rand.seed)
    idx <- sample(1:nrow(tsne_df))
    tsne_df <- tsne_df[idx, ]
  }


  if (isTRUE(x = raster)) {
    p1 <- ggplot(tsne_df, aes_string(x = 'Dim1', y = 'Dim2', color = 'Dataset')) + theme_bw() +
      theme_cowplot(legend.fonts.size) + geom_scattermore(pointsize = pt.size) +
      guides(color = guide_legend(override.aes = list(size = legend.size)))

    centers <- tsne_df %>% group_by(.data[['Cluster']]) %>% summarize(
      Dim1 = median(x = .data[['Dim1']]),
      Dim2 = median(x = .data[['Dim2']])
    )

    p2 <- ggplot(tsne_df, aes_string(x = 'Dim1', y = 'Dim2', color = 'Cluster')) +
      theme_cowplot(legend.fonts.size) + geom_scattermore(pointsize = pt.size) +
      geom_text(data = centers, mapping = aes_string(label = 'Cluster'), colour = "black", size = text.size) +
      guides(color = guide_legend(override.aes = list(size = legend.size)))
  } else {
    p1 <- ggplot(tsne_df, aes_string(x = 'Dim1', y = 'Dim2', color = 'Dataset')) + theme_bw() +
      theme_cowplot(legend.fonts.size) + geom_point(size = pt.size, stroke = 0.2) +
      guides(color = guide_legend(override.aes = list(size = legend.size)))

    centers <- tsne_df %>% group_by(.data[['Cluster']]) %>% summarize(
      Dim1 = median(x = .data[['Dim1']]),
      Dim2 = median(x = .data[['Dim2']])
    )

    p2 <- ggplot(tsne_df, aes_string(x = 'Dim1', y = 'Dim2', color = 'Cluster')) +
      theme_cowplot(legend.fonts.size) + geom_point(size = pt.size, stroke = 0.2) +
      geom_text(data = centers, mapping = aes_string(label = 'Cluster'), colour = "black", size = text.size) +
      guides(color = guide_legend(override.aes = list(size = legend.size)))
  }


  if (!is.null(title)) {
    p1 <- p1 + ggtitle(title[1])
    p2 <- p2 + ggtitle(title[2])
  }
  if (!is.null(axis.labels)) {
    p1 <- p1 + xlab(axis.labels[1]) + ylab(axis.labels[2])
    p2 <- p2 + xlab(axis.labels[1]) + ylab(axis.labels[2])
  }
  p1 <- p1 + theme_cowplot(12)
  p2 <- p2 + theme_cowplot(12)
  if (!do.legend) {
    p1 <- p1 + theme(legend.position = "none")
    p2 <- p2 + theme(legend.position = "none")
  }
  if (return.plots) {
    return(list(p1, p2))
  } else {
    print(p1)
    print(p2)
  }
}
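
# Hedged usage sketch (kept as a comment, not executed): with return.plots =
# TRUE the two ggplot objects can be arranged side by side, e.g. with cowplot,
# instead of being printed sequentially.
#   plots <- plotByDatasetAndCluster(ligerex, return.plots = TRUE)
#   cowplot::plot_grid(plots[[1]], plots[[2]], nrow = 1)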

#' Plot specific feature on t-SNE coordinates
#'
#' Generates one plot for each dataset, colored by chosen feature (column) from cell.data slot.
#' Feature can be categorical (factor) or continuous.
#' Can also plot all datasets combined with by.dataset = FALSE.
#'
#' @param object \code{liger} object. Should call runTSNE or runUMAP before calling.
#' @param feature Feature to plot (should be column from cell.data slot).
#' @param by.dataset Whether to generate separate plot for each dataset (default TRUE).
#' @param discrete Whether to treat feature as discrete; if left NULL, this is inferred from the
#'   column class in cell.data (factors are treated as discrete) (default NULL).
#' @param title Plot title (default NULL).
#' @param pt.size Controls size of points representing cells (default 0.3).
#' @param text.size Controls size of plot text (cluster center labels) (default 3).
#' @param do.shuffle Randomly shuffle points so that points from same dataset are not plotted
#'   one after the other (default TRUE).
#' @param rand.seed Random seed for reproducibility of point shuffling (default 1).
#' @param do.labels Print centroid labels for categorical features (default FALSE).
#' @param axis.labels Vector of two strings to use as x and y labels respectively.
#' @param do.legend Display legend on plots (default TRUE).
#' @param legend.size Size of legend spots for discrete data (default 5).
#' @param option Colormap option to use for ggplot2's scale_color_viridis (default 'plasma').
#' @param cols.use Vector of colors to form gradient over instead of viridis colormap (low to high).
#'   Only applies to continuous features (default NULL).
#' @param zero.color Color to use for zero values (no expression) (default '#F5F5F5').
#' @param return.plots Return ggplot plot objects instead of printing directly (default FALSE).
#'
#' @return List of ggplot plot objects (only if return.plots TRUE, otherwise prints plots to
#'   console).
#'
#' @importFrom ggplot2 ggplot geom_point geom_text ggtitle aes guides guide_legend labs
#' scale_color_viridis_c scale_color_gradientn theme xlab ylab
#' @importFrom dplyr %>% group_by summarize
#' @importFrom stats median
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' ligerex <- runTSNE(ligerex)
#' plotFeature(ligerex, "nUMI", pt.size = 1)
plotFeature <- function(object, feature, by.dataset = TRUE, discrete = NULL, title = NULL,
                        pt.size = 0.3, text.size = 3, do.shuffle = TRUE, rand.seed = 1, do.labels = FALSE,
                        axis.labels = NULL, do.legend = TRUE, legend.size = 5, option = 'plasma',
                        cols.use = NULL, zero.color = '#F5F5F5', return.plots = FALSE) {
  dr_df <- data.frame(object@tsne.coords)
  colnames(dr_df) <- c("dr1", "dr2")
  if (!(feature %in% colnames(object@cell.data))) {
    stop('Please select existing feature in cell.data, or add it before calling.')
  }
  dr_df$feature <- object@cell.data[, feature]
  if (is.null(discrete)) {
    discrete <- is.factor(dr_df$feature)
  }
  if (!discrete){
    dr_df$feature[dr_df$feature == 0] <- NA
  }
  if (by.dataset) {
    dr_df$dataset <- object@cell.data$dataset
  } else {
    dr_df$dataset <- factor("single")
  }
  if (do.shuffle) {
    set.seed(rand.seed)
    idx <- sample(1:nrow(dr_df))
    dr_df <- dr_df[idx, ]
  }
  p_list <- list()
  for (sub_df in split(dr_df, f = dr_df$dataset)) {
    ggp <- ggplot(sub_df, aes_string(x = 'dr1', y = 'dr2', color = 'feature')) + geom_point(size = pt.size)

    # if data is discrete
    if (discrete) {
      ggp <- ggp + guides(color = guide_legend(override.aes = list(size = legend.size))) +
        labs(col = feature)
      if (do.labels) {
        # per-group median coordinates for placing category labels
        centers <- sub_df %>% group_by(feature) %>% summarize(
          dr1 = median(.data[['dr1']]),
          dr2 = median(.data[['dr2']])
        )
        ggp <- ggp + geom_text(data = centers, mapping = aes(label = feature),
                               colour = "black", size = text.size)
      }
    } else {
      if (is.null(cols.use)) {
        ggp <- ggp + scale_color_viridis_c(option = option,
                                           direction = -1,
                                           na.value = zero.color) + labs(col = feature)
      } else {
        ggp <- ggp + scale_color_gradientn(colors = cols.use,
                                           na.value = zero.color) + labs(col = feature)
      }

    }
    if (by.dataset) {
      base <- as.character(sub_df$dataset[1])
    } else {
      base <- ""
    }
    if (!is.null(title)) {
      base <- paste(title, base)
    }
    ggp <- ggp + ggtitle(base)
    if (!is.null(axis.labels)) {
      ggp <- ggp + xlab(axis.labels[1]) + ylab(axis.labels[2])
    }
    if (!do.legend) {
      ggp <- ggp + theme(legend.position = "none")
    }
    p_list[[as.character(sub_df$dataset[1])]] <- ggp
  }
  if (by.dataset) {
    p_list <- p_list[names(object@raw.data)]
  }

  if (return.plots){
    if (length(p_list) == 1) {
      return(p_list[[1]])
    } else {
      return(p_list)
    }
  } else {
    for (plot in p_list) {
      print(plot)
    }
  }
}

#' Plot scatter plots of unaligned and aligned factor loadings
#'
#' @description
#' Generates scatter plots of factor loadings vs cells for both unaligned and aligned
#' (normalized) factor loadings. This allows for easier visualization of the changes made to the
#' factor loadings during the alignment step. Lists a subset of highly loading genes for each factor.
#' Also provides an option to plot t-SNE coordinates of the cells colored by aligned factor loadings.
#'
#' Because of the large number of plots produced, it is recommended to direct the output of this
#' function to a PDF device.
#'
#' @param object \code{liger} object. Should call quantileAlignSNF before calling.
#' @param num.genes Number of genes to display for each factor (default 10).
#' @param cells.highlight Names of specific cells to highlight in plot (black) (default NULL).
#' @param plot.tsne Plot t-SNE coordinates for each factor (default FALSE).
#' @param verbose Print messages (TRUE by default)
#'
#' @return Plots to console (1-2 pages per factor)
#'
#' @importFrom graphics legend par plot
#' @importFrom grDevices rainbow
#'
#' @export
#' @examples
#' \donttest{
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' plotFactors(ligerex)
#' ligerex <- runTSNE(ligerex)
#' plotFactors(ligerex, plot.tsne = TRUE)
#' }
plotFactors <- function(object, num.genes = 10, cells.highlight = NULL, plot.tsne = FALSE, verbose = TRUE) {
  k <- ncol(object@H.norm)
  if (verbose) {
    pb <- txtProgressBar(min = 0, max = k, style = 3)
  }
  W <- t(object@W)
  rownames(W) <- colnames(object@H[[1]])
  Hs_norm <- object@H.norm
  # restore default settings when the current function exits
  init_par <- graphics::par(no.readonly = TRUE)
  on.exit(graphics::par(init_par))
  for (i in 1:k) {
    graphics::par(mfrow = c(2, 1))
    top_genes.W <- rownames(W)[order(W[, i], decreasing = TRUE)[1:num.genes]]
    top_genes.W.string <- paste0(top_genes.W, collapse = ", ")
    factor_textstring <- paste0("Factor", i)

    plot_title1 <- paste(factor_textstring, "\n", top_genes.W.string, "\n")
    cols <- rep("gray", times = nrow(Hs_norm))
    names(cols) <- rownames(Hs_norm)
    cols.use <- grDevices::rainbow(length(object@H))

    for (cl in 1:length(object@H)) {
      cols[rownames(object@H[[cl]])] <- rep(cols.use[cl], times = nrow(object@H[[cl]]))
    }
    if (!is.null(cells.highlight)) {
      cols[cells.highlight] <- rep("black", times = length(cells.highlight))
    }
    graphics::plot(1:nrow(Hs_norm), do.call(rbind, object@H)[, i],
                   cex = 0.2, pch = 20,
                   col = cols, main = plot_title1, xlab = "Cell", ylab = "Raw H Score"
    )
    graphics::legend("top", names(object@H), pch = 20, col = cols.use, horiz = TRUE, cex = 0.75)
    graphics::plot(1:nrow(Hs_norm), object@H.norm[, i],
                   pch = 20, cex = 0.2,
                   col = cols, xlab = "Cell", ylab = "H_norm Score"
    )
    if (plot.tsne) {
      graphics::par(mfrow = c(1, 1))
      fplot(object@tsne.coords, object@H.norm[, i], title = paste0("Factor ", i))
    }
    if (verbose) {
      setTxtProgressBar(pb, i)
    }
  }
}

#' Generate word clouds and t-SNE plots
#'
#' @description
#' Plots t-SNE coordinates of all cells by their loadings on each factor. Underneath it displays the
#' most highly loading shared and dataset-specific genes, with the size of the marker indicating
#' the magnitude of the loading.
#'
#' Because of the large number of plots produced, it is recommended to direct the output of this
#' function to a PDF device.
#'
#' @param object \code{liger} object. Should call runTSNE before calling.
#' @param dataset1 Name of first dataset (by default takes first two datasets for dataset1 and 2)
#' @param dataset2 Name of second dataset
#' @param num.genes Number of genes to show in word clouds (default 30).
#' @param min.size Size of smallest gene symbol in word cloud (default 1).
#' @param max.size Size of largest gene symbol in word cloud (default 4).
#' @param factor.share.thresh Use only factors with a dataset specificity less than or equal to
#'   this threshold (default 10).
#' @param log.fc.thresh Lower log-fold change threshold for differential expression in markers
#'   (default 1).
#' @param pval.thresh Upper p-value threshold for Wilcoxon rank test for gene expression
#'   (default 0.05).
#' @param do.spec.plot Include dataset specificity plot in printout (default TRUE).
#' @param return.plots Return ggplot objects instead of printing directly (default FALSE).
#' @param verbose Print progress bar/messages (TRUE by default)
#'
#' @return List of ggplot plot objects (only if return.plots TRUE, otherwise prints plots to
#'   console).
#'
#' @importFrom ggrepel geom_text_repel
#' @importFrom ggplot2 ggplot aes aes_string geom_point ggtitle scale_color_gradient scale_size
#' scale_x_continuous scale_y_continuous coord_fixed labs
#' @importFrom grid roundrectGrob
#' @importFrom grid gpar
#' @importFrom cowplot draw_grob
#'
#' @export
#' @examples
#' \donttest{
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' ligerex <- runTSNE(ligerex)
#' plotWordClouds(ligerex, do.spec.plot = FALSE)
#' }
plotWordClouds <- function(object, dataset1 = NULL, dataset2 = NULL, num.genes = 30, min.size = 1,
                           max.size = 4, factor.share.thresh = 10, log.fc.thresh = 1, pval.thresh = 0.05,
                           do.spec.plot = TRUE, return.plots = FALSE, verbose = TRUE) {
  if (is.null(dataset1) || is.null(dataset2)) {
    dataset1 <- names(object@H)[1]
    dataset2 <- names(object@H)[2]
  }

  if (class(object@raw.data[[1]])[1] == "H5File") {
    sample.idx <- unlist(lapply(object@sample.data, colnames))
    H_aligned <- object@H.norm[sample.idx, ]
    tsne_coords <- object@tsne.coords[sample.idx, ]
  } else {
    H_aligned <- object@H.norm
    tsne_coords <- object@tsne.coords
  }

  W <- t(object@W)
  V1 <- t(object@V[[dataset1]])
  V2 <- t(object@V[[dataset2]])
  W <- pmin(W + V1, W + V2)

  dataset.specificity <- calcDatasetSpecificity(object, dataset1 = dataset1,
                                                dataset2 = dataset2, do.plot = do.spec.plot)
  factors.use <- which(abs(dataset.specificity[[3]]) <= factor.share.thresh)

  markers <- getFactorMarkers(object, dataset1 = dataset1, dataset2 = dataset2,
                              factor.share.thresh = factor.share.thresh,
                              num.genes = num.genes, log.fc.thresh = log.fc.thresh,
                              pval.thresh = pval.thresh,
                              dataset.specificity = dataset.specificity,
                              verbose = verbose
  )

  rownames(W) <- rownames(V1) <- rownames(V2) <- object@var.genes
  loadings_list <- list(V1, W, V2)
  names_list <- list(dataset1, "Shared", dataset2)
  if (verbose) {
    pb <- txtProgressBar(min = 0, max = length(factors.use), style = 3)
  }
  return_plots <- list()
  for (i in factors.use) {
    tsne_df <- data.frame(H_aligned[, i], tsne_coords)
    factorlab <- paste("Factor", i, sep = "")
    colnames(tsne_df) <- c(factorlab, "Dim1", "Dim2")
    factor_ds <- paste("Factor", i, "Dataset Specificity:", dataset.specificity[[3]][i])
    p1 <- ggplot(tsne_df, aes_string(x = "Dim1", y = "Dim2", color = factorlab)) + geom_point() +
      scale_color_gradient(low = "yellow", high = "red") + ggtitle(label = factor_ds)

    top_genes_V1 <- markers[[1]]$gene[markers[[1]]$factor_num == i]
    top_genes_W <- markers[[2]]$gene[markers[[2]]$factor_num == i]
    top_genes_V2 <- markers[[3]]$gene[markers[[3]]$factor_num == i]

    top_genes_list <- list(top_genes_V1, top_genes_W, top_genes_V2)
    plot_list <- lapply(seq_along(top_genes_list), function(x) {
      top_genes <- top_genes_list[[x]]
      gene_df <- data.frame(
        genes = top_genes,
        loadings = loadings_list[[x]][top_genes, i]
      )
      if (length(top_genes) == 0) {
        gene_df <- data.frame(genes = c("no genes"), loadings = c(1))
      }
      out_plot <- ggplot(gene_df, aes(x = 1, y = 1, size = loadings, label = .data[['genes']])) +
        geom_text_repel(force = 100, segment.color = NA) +
        scale_size(range = c(min.size, max.size), guide = "none") +
        scale_y_continuous(breaks = NULL) +
        scale_x_continuous(breaks = NULL) +
        labs(x = "", y = "") + ggtitle(label = names_list[[x]]) + coord_fixed() + ggplot2::theme_void()
      return(out_plot)
    })

    p2 <- (plot_grid(plotlist = plot_list, align = "hv", nrow = 1)
           + draw_grob(roundrectGrob(
             x = 0.33, y = 0.5, width = 0.67, height = 0.70,
             gp = gpar(fill = "khaki1", col = "Black", alpha = 0.5, lwd = 2)
           ))
           + draw_grob(roundrectGrob(
             x = 0.67, y = 0.5, width = 0.67, height = 0.70,
             gp = gpar(fill = "indianred1", col = "Black", alpha = 0.5, lwd = 2)
           )))
    return_plots[[i]] <- plot_grid(p1, p2, nrow = 2, align = "h")
    if (!return.plots) {
      print(return_plots[[i]])
    }
    if (verbose) {
      setTxtProgressBar(pb, i)
    }
  }
  if (return.plots) {
    return(return_plots)
  }
}

#' Generate t-SNE plots and gene loading plots
#'
#' @description
#' Plots t-SNE coordinates of all cells by their loadings on each factor. Underneath it displays the
#' most highly loading shared and dataset-specific genes, along with the overall gene loadings
#' for each dataset.
#'
#' Because of the large number of plots produced, it is recommended to direct the output of this
#' function to a PDF device.
#'
#' @param object \code{liger} object. Should call runTSNE before calling.
#' @param dataset1 Name of first dataset (by default takes first two datasets for dataset1 and 2)
#' @param dataset2 Name of second dataset
#' @param num.genes Number of genes to show in word clouds (default 30).
#' @param num.genes.show Number of genes displayed as y-axis labels in the gene loading plots at
#' the bottom (default 12)
#' @param mark.top.genes Plot points corresponding to top loading genes in different color (default
#'   TRUE).
#' @param factor.share.thresh Use only factors with a dataset specificity less than or equal to
#'   threshold (default 10).
#' @param log.fc.thresh Lower log-fold change threshold for differential expression in markers
#'   (default 1).
#' @param umi.thresh Lower UMI threshold for markers (default 30).
#' @param frac.thresh Lower threshold for fraction of cells expressing marker (default 0).
#' @param pval.thresh Upper p-value threshold for Wilcoxon rank test for gene expression
#'   (default 0.05).
#' @param do.spec.plot Include dataset specificity plot in printout (default TRUE).
#' @param max.val Value between 0 and 1 at which color gradient should saturate to max color. Set to
#'   NULL to revert to default gradient scaling. (default 0.1)
#' @param pt.size Point size for plots (default 0.4).
#' @inheritParams plotGene
#' @param return.plots Return ggplot objects instead of printing directly (default FALSE).
#' @param axis.labels Vector of two strings to use as x and y labels respectively (default NULL).
#' @param do.title Include top title with cluster and Dataset Specificity (default FALSE).
#' @param verbose Print progress bar/messages (TRUE by default)
#' @param raster Whether to rasterize points (default NULL). If NULL, points are automatically
#'   rasterized when there are over 100,000 cells to plot.
#'
#' @return List of ggplot plot objects (only if return.plots TRUE, otherwise prints plots to
#'   console).
#'
#' @importFrom ggplot2 aes aes_string annotate coord_cartesian element_blank ggplot geom_point
#' ggtitle scale_color_viridis_c theme
#' theme_bw
#' @importFrom grid gpar unit
#' @import patchwork
#' @importFrom stats loadings
#' @importFrom cowplot theme_cowplot
#' @importFrom scattermore geom_scattermore
#'
#' @export
#' @examples
#' \donttest{
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' ligerex <- runTSNE(ligerex)
#' plotGeneLoadings(ligerex, "stim", "ctrl", do.spec.plot = FALSE)
#' }
plotGeneLoadings <- function(object, dataset1 = NULL, dataset2 = NULL, num.genes.show = 12,
                             num.genes = 30, mark.top.genes = TRUE, factor.share.thresh = 10,
                             log.fc.thresh = 1, umi.thresh = 30, frac.thresh = 0,
                             pval.thresh = 0.05, do.spec.plot = TRUE, max.val = 0.1, pt.size = 0.4,
                             option = "plasma", zero.color = "#F5F5F5", return.plots = FALSE,
                             axis.labels = NULL, do.title = FALSE, verbose = TRUE, raster = NULL) {
  # check raster and set by number of cells total if NULL
  if (is.null(x = raster)) {
    if (nrow(x = object@cell.data) > 1e5) {
      raster <- TRUE
      message("NOTE: Points are rasterized as number of cells/nuclei plotted exceeds 100,000.
              \n To plot in vector form set `raster = FALSE`.")
    } else {
      raster <- FALSE
    }
  }

  if (is.null(dataset1) || is.null(dataset2)) {
    dataset1 <- names(object@H)[1]
    dataset2 <- names(object@H)[2]
  }

  if (class(object@raw.data[[1]])[1] == "H5File") {
    sample.idx <- unlist(lapply(object@sample.data, colnames))
    H_aligned <- object@H.norm[sample.idx, ]
    tsne_coords <- object@tsne.coords[sample.idx, ]
  } else {
    H_aligned <- object@H.norm
    tsne_coords <- object@tsne.coords
  }

  W_orig <- t(object@W)
  V1 <- t(object@V[[dataset1]])
  V2 <- t(object@V[[dataset2]])
  W <- pmin(W_orig + V1, W_orig + V2)

  dataset.specificity <- calcDatasetSpecificity(object,
                                                dataset1 = dataset1,
                                                dataset2 = dataset2, do.plot = do.spec.plot
  )

  factors.use <- which(abs(dataset.specificity[[3]]) <= factor.share.thresh)


  markers <- getFactorMarkers(object,
                              dataset1 = dataset1, dataset2 = dataset2,
                              factor.share.thresh = factor.share.thresh,
                              num.genes = num.genes, log.fc.thresh = log.fc.thresh,
                              pval.thresh = pval.thresh,
                              dataset.specificity = dataset.specificity,
                              verbose = verbose
  )

  rownames(W) <- rownames(V1) <- rownames(V2) <- rownames(W_orig) <- object@var.genes
  loadings_list <- list(V1, W, V2)
  names_list <- list(dataset1, "Shared", dataset2)
  if (verbose) {
    pb <- txtProgressBar(min = 0, max = length(factors.use), style = 3)
  }
  return_plots <- list()
  for (i in factors.use) {
    tsne_df <- data.frame(H_aligned[, i], tsne_coords)
    factorlab <- paste("Factor", i, sep = "")
    colnames(tsne_df) <- c(factorlab, "Dim1", "Dim2")
    tsne_df[[factorlab]][tsne_df[[factorlab]] == 0] <- NA
    factor_ds <- paste("Factor", i, "Dataset Specificity:", dataset.specificity[[3]][i])
    data.max <- max(object@H.norm[, i])
    # plot TSNE
    if (!is.null(max.val)) {
      values <- c(0, max.val, 1)
    } else {
      values <- NULL
    }

    if (isTRUE(x = raster)) {
      p1 <- ggplot(tsne_df, aes_string(x = "Dim1", y = "Dim2", color = factorlab)) +
        geom_scattermore(pointsize = pt.size) +
        scale_color_viridis_c(
          option = option,
          direction = -1,
          na.value = zero.color, values = values
        ) +
        theme_cowplot(12)
    } else {
      p1 <- ggplot(tsne_df, aes_string(x = "Dim1", y = "Dim2", color = factorlab)) +
        geom_point(size = pt.size) +
        scale_color_viridis_c(
          option = option,
          direction = -1,
          na.value = zero.color, values = values
        ) +
        theme_cowplot(12)
    }


    if (!is.null(axis.labels)) {
      p1 <- p1 + xlab(axis.labels[1]) + ylab(axis.labels[2])
    }
    if (do.title) {
      p1 <- p1 + ggtitle(label = factor_ds)
    }

    # subset to specific factor and sort by p-value
    top_genes_V1 <- markers[[1]][markers[[1]]$factor_num == i, ]
    top_genes_V1 <- top_genes_V1[order(top_genes_V1$p_value), ]$gene
    # don't sort for W
    top_genes_W <- markers[[2]][markers[[2]]$factor_num == i, ]$gene
    top_genes_V2 <- markers[[3]][markers[[3]]$factor_num == i, ]
    top_genes_V2 <- top_genes_V2[order(top_genes_V2$p_value), ]$gene

    top_genes_list <- list(top_genes_V1, top_genes_W, top_genes_V2)
    # subset down to those which will be shown if sorting by p-val

    top_genes_list <- lapply(top_genes_list, function(x) {
      if (length(x) > num.genes.show) {
        # to avoid subset warning
        x <- x[1:num.genes.show]
      }
      x
    })

    plot_list <- lapply(seq_along(top_genes_list), function(x) {
      top_genes <- top_genes_list[[x]]
      # make dataframe for cum gene loadings plot
      sorted <- sort(loadings_list[[x]][, i])
      # sort by loadings instead - still only showing num.genes.show
      # look through top num.genes in loadings
      top_loaded <- names(rev(sorted[(length(sorted) - num.genes + 1):length(sorted)]))
      top_genes <- top_loaded[which(top_loaded %in% top_genes)]
      if (length(top_genes) == 0) {
        top_genes <- c("no genes")
      }

      gene_df <- data.frame(
        loadings = sorted,
        xpos = seq(0, 1, length.out = length(sorted)),
        top_k = names(sorted) %in% top_genes
      )
      y_lim_text <- max(gene_df$loadings)
      # plot and annotate with top genes

      out_plot <- ggplot(gene_df, aes_string(x = 'xpos', y = 'loadings')) +
        geom_point(size = pt.size) +
        theme_bw() +
        theme(
          axis.ticks.x = element_blank(),
          axis.line.x = element_blank(),
          axis.title = element_blank(),
          axis.text.x = element_blank(),
          panel.grid.major.x = element_blank(),
          panel.grid.minor.x = element_blank()
        ) +
        ggtitle(label = names_list[[x]]) +
        annotate("text",
                 x = 1.1,
                 y = seq(y_lim_text, 0, length.out = num.genes.show)[1:length(top_genes)],
                 label = top_genes, hjust = 0, col = "#8227A0"
        ) +
        coord_cartesian(
          xlim = c(0, 1), # This focuses the x-axis on the range of interest
          clip = "off"
        ) +
        theme(plot.margin = unit(c(1, 4, 1, 1), "lines"))

      if (mark.top.genes) {
        out_plot <- out_plot + geom_point(
          data = subset(gene_df, gene_df[['top_k']] == TRUE),
          aes_string('xpos', 'loadings'),
          col = "#8227A0", size = 0.5
        )
      }
      return(out_plot)
    })

    # p2 <- plot_grid(plotlist = plot_list, nrow = 1)

    return_plots[[i]] <- p1 / (plot_list[[1]] | plot_list[[2]] | plot_list[[3]])
    # if can figure out how to make cowplot work, might bring this back
    # return_plots[[i]] <- plot_grid(p1, p2, nrow = 2, align = "h")
    if (!return.plots) {
      print(return_plots[[i]])
    }
    if (verbose) {
      setTxtProgressBar(pb, i)
    }
  }
  if (return.plots) {
    return(return_plots)
  }
}

#' Plot violin plots for gene expression
#'
#' Generates violin plots of expression of specified gene for each dataset.
#'
#' @param object \code{liger} object.
#' @param gene Gene for which to plot relative expression.
#' @param methylation.indices Indices of datasets in object with methylation data (these data are
#'   not multiplied by 10,000 and log-transformed).
#' @param by.dataset Plots gene expression for each dataset separately (default TRUE).
#' @param return.plots Return ggplot objects instead of printing directly to console (default
#'   FALSE).
#'
#' @return List of ggplot plot objects (only if return.plots TRUE, otherwise prints plots to
#'   console).
#'
#' @importFrom cowplot plot_grid
#' @importFrom ggplot2 aes_string ggplot geom_point geom_boxplot geom_violin ggtitle labs
#' scale_color_gradient2 theme
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 2)
#' ligerex <- quantile_norm(ligerex)
#' ligerex <- louvainCluster(ligerex)
#' plotGeneViolin(ligerex, "CD74", by.dataset = FALSE)
#' plotGeneViolin(ligerex, "CD74")
plotGeneViolin <- function(object, gene, methylation.indices = NULL,
                           by.dataset = TRUE, return.plots = FALSE) {
  if (class(object@raw.data[[1]])[1] == "H5File"){
    if (object@h5file.info[[1]][["sample.data.type"]] != "norm.data"){
      stop("norm.data should be sampled for making violin plots.")
    }
  }

  gene_vals <- c()
  gene_df <- data.frame(Clusters = object@clusters)

  for (i in 1:length(object@raw.data)) {
    if (class(object@raw.data[[i]])[1] == "H5File"){
      if (i %in% methylation.indices) {
        gene_vals <- c(gene_vals, object@sample.data[[i]][gene, ])
      } else {
        if (gene %in% rownames(object@sample.data[[i]])) {
          gene_vals_int <- log1p(10000 * object@sample.data[[i]][gene, ])
        }
        else {
          gene_vals_int <- rep(list(0), ncol(object@sample.data[[i]]))
          names(gene_vals_int) <- colnames(object@sample.data[[i]])
        }
        gene_vals <- c(gene_vals, gene_vals_int)
      }
    } else {
      if (i %in% methylation.indices) {
        gene_vals <- c(gene_vals, object@norm.data[[i]][gene, ])
      } else {
        if (gene %in% rownames(object@norm.data[[i]])) {
          gene_vals_int <- log1p(10000 * object@norm.data[[i]][gene, ])
        }
        else {
          gene_vals_int <- rep(list(0), ncol(object@norm.data[[i]]))
          names(gene_vals_int) <- colnames(object@norm.data[[i]])
        }
        gene_vals <- c(gene_vals, gene_vals_int)
      }
    }
  }

  gene_df$Gene <- as.numeric(gene_vals[rownames(gene_df)])
  gene_plots <- list()
  for (i in 1:length(object@scale.data)) {
    if (by.dataset) {
      gene_df.sub <- gene_df[colnames(object@norm.data[[i]]), ]
      gene_df.sub$Cluster <- object@clusters[colnames(object@norm.data[[i]])]
      title <- names(object@norm.data)[i]
    } else {
      gene_df.sub <- gene_df
      gene_df.sub$Cluster <- object@clusters
      title <- "All Datasets"
    }
    max_v <- max(gene_df.sub["Gene"], na.rm = TRUE)
    min_v <- min(gene_df.sub["Gene"], na.rm = TRUE)
    midpoint <- (max_v - min_v) / 2
    plot_i <- ggplot(gene_df.sub, aes_string(x = "Cluster", y = "Gene", fill = "Cluster")) +
      geom_boxplot(position = "dodge", width = 0.4, outlier.shape = NA, alpha = 0.7) +
      geom_violin(position = "dodge", alpha = 0.7) +
      ggtitle(title)
    gene_plots[[i]] <- plot_i + theme(legend.position = "none") + labs(y = gene)
    if (i == 1 & !by.dataset) {
      break
    }
  }
  if (return.plots) {
    return(gene_plots)
  } else {
    for (i in 1:length(gene_plots)) {
      print(gene_plots[[i]])
    }
  }
}

#' Plot gene expression on dimensional reduction (t-SNE) coordinates
#'
#' Generates plot of dimensional reduction coordinates (default t-SNE) colored by expression of
#' specified gene. Data can be scaled by dataset or by a selected feature column from cell.data (or
#' across all cells). Plots can also be split by feature.
#'
#' @param object \code{liger} object. Should call runTSNE before calling.
#' @param gene Gene for which to plot expression.
#' @param use.raw Plot raw UMI values instead of normalized, log-transformed data (default FALSE).
#' @param use.scaled Plot values scaled across specified groups of cells (with log transformation)
#'   (default FALSE).
#' @param scale.by Grouping of cells by which to scale gene (can be any factor column in cell.data
#'   or 'none' for scaling across all cells) (default 'dataset').
#' @param log2scale Whether to show log2 transformed values or original normalized, raw, or scaled
#'   values (as stored in object). Default value is FALSE if use.raw = TRUE, otherwise TRUE.
#' @param methylation.indices Indices of datasets in object with methylation data (this data is not
#'   log transformed and must use normalized values). (default NULL)
#' @param plot.by How to group cells for plotting (can be any factor column in cell.data or 'none'
#'   for plotting all cells in a single plot). Note that this can result in a large number of plots.
#'   Users are encouraged to use the same value as for scale.by (default 'dataset').
#' @param set.dr.lims Whether to keep dimensional reduction coordinates consistent when multiple
#'   plots created (default FALSE).
#' @param pt.size Point size for plots (default 0.1).
#' @param min.clip Minimum value for expression values plotted. Can pass in quantile (0-1) or
#'   absolute cutoff (set clip.absolute = TRUE). Can also pass in vector if expecting multiple plots;
#'   users are encouraged to pass in named vector (from levels of desired feature) to avoid
#'   mismatches in order (default NULL).
#' @param max.clip Maximum value for expression values plotted. Can pass in quantile (0-1) or
#'   absolute cutoff (set clip.absolute = TRUE). Can also pass in vector if expecting multiple plots;
#'   users are encouraged to pass in named vector (from levels of desired feature) to avoid
#'   mismatches in order (default NULL).
#' @param clip.absolute Whether to treat clip values as absolute cutoffs instead of quantiles
#'   (default FALSE).
#' @param points.only Remove axes, background, and legend when plotting coordinates (default FALSE).
#' @param option Colormap option to use for ggplot2's scale_color_viridis (default 'plasma').
#' @param cols.use Vector of colors to form gradient over instead of viridis colormap (low to high).
#'   (default NULL).
#' @param zero.color Color to use for zero values (no expression) (default '#F5F5F5').
#' @param axis.labels Vector of two strings to use as x and y labels respectively. (default NULL)
#' @param do.legend Display legend on plots (default TRUE).
#' @param return.plots Return ggplot objects instead of printing directly (default FALSE).
#' @param keep.scale Maintain min/max color scale across all plots when using plot.by (default FALSE)
#' @param raster Whether to rasterize points (default NULL). If NULL, points are automatically
#'   rasterized when there are over 100,000 cells to plot.
#'
#' @return If returning single plot, returns ggplot object; if returning multiple plots; returns
#'   list of ggplot objects.
#'
#' @importFrom dplyr %>% group_by mutate_at vars group_cols
#' @importFrom ggplot2 ggplot geom_point aes_string element_blank ggtitle labs xlim ylim
#' scale_color_viridis_c scale_color_gradientn theme
#' @importFrom stats quantile
#' @importFrom scattermore geom_scattermore
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' ligerex <- runTSNE(ligerex)
#' plotGene(ligerex, "CD74", pt.size = 1)
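#' # Return the plot objects instead of printing them, e.g. for further customization
#' plts <- plotGene(ligerex, "CD74", pt.size = 1, return.plots = TRUE)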
plotGene <- function(object, gene, use.raw = FALSE, use.scaled = FALSE, scale.by = 'dataset',
                     log2scale = NULL, methylation.indices = NULL, plot.by = 'dataset',
                     set.dr.lims = FALSE, pt.size = 0.1, min.clip = NULL, max.clip = NULL,
                     clip.absolute = FALSE, points.only = FALSE, option = 'plasma', cols.use = NULL,
                     zero.color = '#F5F5F5', axis.labels = NULL, do.legend = TRUE, return.plots = FALSE,
                     keep.scale = FALSE, raster = NULL) {
  if ((plot.by != scale.by) && use.scaled) {
    warning("Provided values for plot.by and scale.by do not match; results may not be very interpretable.")
  }

  # check raster and set by number of cells total if NULL
  if (is.null(x = raster)) {
    if (nrow(x = object@cell.data) > 1e5) {
      raster <- TRUE
      message("NOTE: Points are rasterized as the number of cells/nuclei plotted exceeds ",
              "100,000. To plot in vector form set `raster = FALSE`.")
    } else {
      raster <- FALSE
    }
  }


  if (use.raw) {
    if (is.null(log2scale)) {
      log2scale <- FALSE
    }
    # drop only outer level names
    if (inherits(object@raw.data[[1]], "H5File")) {
      if (object@h5file.info[[1]][["sample.data.type"]] != "raw.data"){
        stop("raw.data should be sampled for this plot.")
      }
      gene_vals <- getGeneValues(object@sample.data, gene, log2scale = log2scale)
    } else {
      gene_vals <- getGeneValues(object@raw.data, gene, log2scale = log2scale)
    }
  } else {
    if (is.null(log2scale)) {
      log2scale <- TRUE
    }
    # rescale in case requested gene not highly variable
    if (use.scaled) {
      # check for feature
      if (!(scale.by %in% colnames(object@cell.data)) & scale.by != 'none') {
        stop("Please select existing feature in cell.data to scale.by, or add it before calling.")
      }
      if (inherits(object@raw.data[[1]], "H5File")) {
        if (object@h5file.info[[1]][["sample.data.type"]] != "norm.data"){
          stop("norm.data should be sampled for this plot.")
        }
        gene_vals <- getGeneValues(object@sample.data, gene)
        cells <- unlist(lapply(object@sample.data, colnames))
      } else {
        gene_vals <- getGeneValues(object@norm.data, gene)
        cells <- unlist(lapply(object@norm.data, colnames))
      }
      cellnames <- names(gene_vals)
      # set up dataframe with groups
      gene_df <- data.frame(gene = gene_vals)
      if (scale.by == 'none') {
        gene_df[['scaleby']] = 'none'
      } else {
        gene_df[['scaleby']] = factor(object@cell.data[cells,][[scale.by]])
      }
      gene_df1 <- gene_df %>%
        group_by(.data[['scaleby']]) %>%
        # scale by selected feature
        mutate_at(vars(-group_cols()), function(x) { scale(x, center = FALSE)})
      gene_vals <- gene_df1$gene
      if (log2scale) {
        gene_vals <- log2(10000 * gene_vals + 1)
      }
      names(gene_vals) <- cellnames
    } else {
      # using normalized data
      # indicate methylation indices here
      if (inherits(object@raw.data[[1]], "H5File")) {
        if (object@h5file.info[[1]][["sample.data.type"]] != "norm.data"){
          stop("norm.data should be sampled for this plot.")
        }
        gene_vals <- getGeneValues(object@sample.data, gene, methylation.indices = methylation.indices,
                                   log2scale = log2scale)
      } else {
        gene_vals <- getGeneValues(object@norm.data, gene, methylation.indices = methylation.indices,
                                   log2scale = log2scale)
      }
    }
  }
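  # Convert zeros to NA so they are drawn with zero.color (supplied to the color scales
  # below via na.value)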
  gene_vals[gene_vals == 0] <- NA
  # Extract min and max expression values for plot scaling if keep.scale = TRUE
  if (keep.scale){
    max_exp_val <- max(gene_vals, na.rm = TRUE)
    min_exp_val <- min(gene_vals, na.rm = TRUE)
  }

  if (inherits(object@raw.data[[1]], "H5File")) {
    cells <- unlist(lapply(object@sample.data, colnames))
    dr_df <- data.frame(object@tsne.coords[cells,])
  } else {
    dr_df <- data.frame(object@tsne.coords)
    rownames(dr_df) <- rownames(object@cell.data)
  }
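  # Note: for HDF5-backed objects the coordinates above are restricted to the sampled
  # cells, since expression values exist only for those cells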
  dr_df$gene <- as.numeric(gene_vals[rownames(dr_df)])
  colnames(dr_df) <- c("dr1", "dr2", "gene")
  # get dr limits for later
  lim1 <- c(min(dr_df$dr1), max(dr_df$dr1))
  lim2 <- c(min(dr_df$dr2), max(dr_df$dr2))

  if (plot.by != 'none') {
    if (!(plot.by %in% colnames(object@cell.data))) {
      stop("Please select existing feature in cell.data to plot.by, or add it before calling.")
    }
    dr_df$plotby <- factor(object@cell.data[rownames(dr_df),][[plot.by]])
  } else {
    dr_df$plotby <- factor("none")
  }
  # expand clip values if only single provided
  num_levels <- length(levels(dr_df$plotby))
  if (length(min.clip) == 1) {
    min.clip <- rep(min.clip, num_levels)
    names(min.clip) <- levels(dr_df$plotby)
  }
  if (length(max.clip) == 1) {
    max.clip <- rep(max.clip, num_levels)
    names(max.clip) <- levels(dr_df$plotby)
  }
  if (!is.null(min.clip) & is.null(names(min.clip))) {
    if (num_levels > 1) {
      message("Adding names to min.clip according to levels in plot.by group; order may not be ",
              "preserved as intended if multiple clip values passed in. Pass in a named vector ",
              "to prevent this.")
    }
    names(min.clip) <- levels(dr_df$plotby)
  }
  if (!is.null(max.clip) & is.null(names(max.clip))) {
    if (num_levels > 1) {
      message("Adding names to max.clip according to levels in plot.by group; order may not be ",
              "preserved as intended if multiple clip values passed in. Pass in a named vector ",
              "to prevent this.")
    }
    names(max.clip) <- levels(dr_df$plotby)
  }
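  # Build one plot per plot.by group, applying per-group expression clipping first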
  p_list <- list()
  for (sub_df in split(dr_df, f = dr_df$plotby)) {
    # apply quantile-based or absolute clipping to this group's expression values
    group_name <- as.character(sub_df$plotby[1])
    if (!clip.absolute) {
      max_v <- quantile(sub_df$gene, probs = max.clip[group_name], na.rm = TRUE)
      min_v <- quantile(sub_df$gene, probs = min.clip[group_name], na.rm = TRUE)
    } else {
      max_v <- max.clip[group_name]
      min_v <- min.clip[group_name]
    }
    sub_df$gene[sub_df$gene < min_v & !is.na(sub_df$gene)] <- min_v
    sub_df$gene[sub_df$gene > max_v & !is.na(sub_df$gene)] <- max_v

    if (isTRUE(x = raster)) {
      ggp <- ggplot(sub_df, aes_string(x = 'dr1', y = 'dr2', color = 'gene')) + geom_scattermore(pointsize = pt.size) +
        labs(col = gene)
    } else {
      ggp <- ggplot(sub_df, aes_string(x = 'dr1', y = 'dr2', color = 'gene')) + geom_point(size = pt.size) +
        labs(col = gene)
    }

    if (!is.null(cols.use)) {
      if (keep.scale) {
        ggp <- ggp + scale_color_gradientn(colors = cols.use,
                                           na.value = zero.color,
                                           limits = c(min_exp_val, max_exp_val))
      } else {
        ggp <- ggp + scale_color_gradientn(colors = cols.use,
                                           na.value = zero.color)
      }
    } else {
      if (keep.scale) {
        ggp <- ggp + scale_color_viridis_c(option = option,
                                           direction = -1,
                                           na.value = zero.color,
                                           limits = c(min_exp_val, max_exp_val))
      } else {
        ggp <- ggp + scale_color_viridis_c(option = option,
                                           direction = -1,
                                           na.value = zero.color)
      }
    }
    if (set.dr.lims) {
      ggp <- ggp + xlim(lim1) + ylim(lim2)
    }

    if (plot.by != 'none') {
      base <- as.character(sub_df$plotby[1])
    } else {
      base <- ""
    }
    ggp <- ggp + ggtitle(base)

    if (!is.null(axis.labels)) {
      ggp <- ggp + xlab(axis.labels[1]) + ylab(axis.labels[2])
    }
    if (!do.legend) {
      ggp <- ggp + theme(legend.position = "none")
    }
    if (points.only) {
      ggp <- ggp + theme(
        axis.line = element_blank(), axis.text.x = element_blank(),
        axis.text.y = element_blank(), axis.ticks = element_blank(),
        axis.title.x = element_blank(),
        axis.title.y = element_blank(), legend.position = "none",
        panel.background = element_blank(), panel.border = element_blank(),
        panel.grid.major = element_blank(), panel.grid.minor = element_blank(),
        plot.background = element_blank(), plot.title = element_blank()
      )
    }
    p_list[[as.character(sub_df$plotby[1])]] <- ggp + theme_cowplot(12)
  }
  if (plot.by == 'dataset') {
    p_list <- p_list[names(object@raw.data)]
  }

  if (return.plots){
    if (length(p_list) == 1) {
      return(p_list[[1]])
    } else {
      return(p_list)
    }
  } else {
    for (plot in p_list) {
      print(plot)
    }
  }
}

#' Plot expression of multiple genes
#'
#' Uses plotGene to plot each gene (and dataset) on a separate page. Directing the output to a
#' PDF device is recommended due to the large number of plots produced.
#'
#' @param object \code{liger} object. Should call runTSNE before calling.
#' @param genes Vector of gene names.
#' @param ... arguments passed from \code{\link[rliger]{plotGene}}
#'
#' @return NULL (invisibly); each plot is printed to the current graphics device.
#'
#' @importFrom ggplot2 ggplot geom_point aes_string scale_color_gradient2 ggtitle
#'
#' @export
#' @examples
#' \donttest{
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' ligerex <- optimizeALS(ligerex, k = 5, max.iter = 1)
#' ligerex <- quantile_norm(ligerex)
#' ligerex <- runTSNE(ligerex)
#' plotGenes(ligerex, c("CD74", "NKG7"), pt.size = 1)
#' }
plotGenes <- function(object, genes, ...) {
  for (i in seq_along(genes)) {
    print(genes[i])
    plotGene(object, genes[i], ...)
  }
}

#' Generate a river (Sankey) plot
#'
#' Creates a riverplot to show how separate cluster assignments from two datasets map onto a
#' joint clustering. The joint clustering is by default the object clustering, but an external one
#' can also be passed in. Uses the riverplot package to construct and plot the riverplot object.
#'
#' @param object \code{liger} object. Should run quantileAlignSNF before calling.
#' @param cluster1 Cluster assignments for dataset 1. Note that cluster names should be distinct
#'   across datasets.
#' @param cluster2 Cluster assignments for dataset 2. Note that cluster names should be distinct
#'   across datasets.
#' @param cluster_consensus Optional external consensus clustering (to use instead of object
#'   clusters)
#' @param min.frac Minimum fraction of cluster for edge to be shown (default 0.05).
#' @param min.cells Minimum number of cells for edge to be shown (default 10).
#' @param river.yscale y-scale to pass to riverplot -- scales the edge width values by this factor,
#'   can be used to squeeze vertically (default 1).
#' @param river.lty Line style to pass to riverplot (default 0).
#' @param river.node_margin Node_margin to pass to riverplot -- how much vertical space to keep
#'   between the nodes (default 0.1).
#' @param label.cex Size of text labels (default 1).
#' @param label.col Color of text labels (default "black").
#' @param lab.srt Angle of text labels (default 0).
#' @param river.usr Coordinates at which to draw the plot in form (x0, x1, y0, y1).
#' @param node.order Order of clusters in each set (list with three vectors of ordinal numbers).
#'   By default will try to automatically order them appropriately.
#'
#' @return NULL for now. May be restored if the archived CRAN dependency riverplot returns.
#'
#' @importFrom plyr mapvalues
#' @importFrom grDevices hcl
#' @importFrom utils capture.output
#'
#' @export
#' @examples
#' \dontrun{
#' # Riverplot currently archived, cannot run this example
#' # ligerex (liger object), factorization complete input
#' # toy clusters
#' cluster1 <- sample(c('type1', 'type2', 'type3'), ncol(ligerex@raw.data[[1]]), replace = TRUE)
#' names(cluster1) <- colnames(ligerex@raw.data[[1]])
#' cluster2 <- sample(c('type4', 'type5', 'type6'), ncol(ligerex@raw.data[[2]]), replace = TRUE)
#' names(cluster2) <- colnames(ligerex@raw.data[[2]])
#' # create riverplot
#' makeRiverplot(ligerex, cluster1, cluster2)
#' }
makeRiverplot <- function(object, cluster1, cluster2, cluster_consensus = NULL, min.frac = 0.05,
                          min.cells = 10, river.yscale = 1, river.lty = 0, river.node_margin = 0.1,
                          label.cex = 1, label.col = "black", lab.srt = 0, river.usr = NULL,
                          node.order = "auto") {
  .Deprecated(NULL, msg = "CRAN package 'riverplot' is archived, so this function is disabled for now.")
  return(NULL)
  # cluster1 <- droplevels(cluster1)
  # cluster2 <- droplevels(cluster2)
  # if (is.null(cluster_consensus)) {
  #   cluster_consensus <- droplevels(object@clusters)
  # }
  # # Make cluster names unique if necessary
  # if (length(intersect(levels(cluster1), levels(cluster2))) > 0 |
  #     length(intersect(levels(cluster1), levels(cluster_consensus))) > 0 |
  #     length(intersect(levels(cluster2), levels(cluster_consensus))) > 0) {
  #   message("Duplicate cluster names detected. Adding 1- and 2- to make unique names.")
  #   cluster1 <- mapvalues(cluster1, from = levels(cluster1),
  #                         to = paste("1", levels(cluster1), sep = "-"))
  #   cluster2 <- mapvalues(cluster2, from = levels(cluster2),
  #                         to = paste("2", levels(cluster2), sep = "-"))
  # }
  # cluster1 <- cluster1[intersect(names(cluster1), names(cluster_consensus))]
  # cluster2 <- cluster2[intersect(names(cluster2), names(cluster_consensus))]
  #
  # # set node order
  # if (identical(node.order, "auto")) {
  #   tab.1 <- table(cluster1, cluster_consensus[names(cluster1)])
  #   tab.1 <- sweep(tab.1, 1, rowSums(tab.1), "/")
  #   tab.2 <- table(cluster2, cluster_consensus[names(cluster2)])
  #   tab.2 <- sweep(tab.2, 1, rowSums(tab.2), "/")
  #   whichmax.1 <- apply(tab.1, 1, which.max)
  #   whichmax.2 <- apply(tab.2, 1, which.max)
  #   ord.1 <- order(whichmax.1)
  #   ord.2 <- order(whichmax.2)
  #   cluster1 <- factor(cluster1, levels = levels(cluster1)[ord.1])
  #   cluster2 <- factor(cluster2, levels = levels(cluster2)[ord.2])
  # } else {
  #   if (is.list(node.order)) {
  #     cluster1 <- factor(cluster1, levels = levels(cluster1)[node.order[[1]]])
  #     cluster_consensus <- factor(cluster_consensus,
  #                                 levels = levels(cluster_consensus)[node.order[[2]]])
  #     cluster2 <- factor(cluster2, levels = levels(cluster2)[node.order[[3]]])
  #   }
  # }
  # cluster1 <- cluster1[!is.na(cluster1)]
  # cluster2 <- cluster2[!is.na(cluster2)]
  # nodes1 <- levels(cluster1)[table(cluster1) > 0]
  # nodes2 <- levels(cluster2)[table(cluster2) > 0]
  # nodes_middle <- levels(cluster_consensus)[table(cluster_consensus) > 0]
  # node_Xs <- c(
  #   rep(1, length(nodes1)), rep(2, length(nodes_middle)),
  #   rep(3, length(nodes2))
  # )
  #
  # # first set of edges
  # edge_list <- list()
  # for (i in 1:length(nodes1)) {
  #   temp <- list()
  #   i_cells <- names(cluster1)[cluster1 == nodes1[i]]
  #   for (j in 1:length(nodes_middle)) {
  #     if (length(which(cluster_consensus[i_cells] == nodes_middle[j])) / length(i_cells) > min.frac &
  #         length(which(cluster_consensus[i_cells] == nodes_middle[j])) > min.cells) {
  #       temp[[nodes_middle[j]]] <- sum(cluster_consensus[i_cells] ==
  #                                        nodes_middle[j]) / length(cluster1)
  #     }
  #   }
  #   edge_list[[nodes1[i]]] <- temp
  # }
  # # second set of edges
  # cluster3 <- cluster_consensus[names(cluster2)]
  # for (i in 1:length(nodes_middle)) {
  #   temp <- list()
  #   i_cells <- names(cluster3)[cluster3 == nodes_middle[i]]
  #   for (j in 1:length(nodes2)) {
  #     j_cells <- names(cluster2)[cluster2 == nodes2[j]]
  #     if (length(which(cluster_consensus[j_cells] == nodes_middle[i])) / length(j_cells) > min.frac &
  #         length(which(cluster_consensus[j_cells] == nodes_middle[i])) > min.cells) {
  #       if (!is.na(sum(cluster2[i_cells] == nodes2[j]))) {
  #         temp[[nodes2[j]]] <- sum(cluster2[i_cells] ==
  #                                    nodes2[j]) / length(cluster2)
  #       }
  #     }
  #   }
  #   edge_list[[nodes_middle[i]]] <- temp
  # }
  # # set cluster colors
  # node_cols <- list()
  # ggplotColors <- function(g) {
  #   d <- 360 / g
  #   h <- cumsum(c(15, rep(d, g - 1)))
  #   grDevices::hcl(h = h, c = 100, l = 65)
  # }
  # pal <- ggplotColors(length(nodes1))
  # for (i in 1:length(nodes1)) {
  #   node_cols[[nodes1[i]]] <- list(col = pal[i], textcex = label.cex,
  #                                  textcol = label.col, srt = lab.srt)
  # }
  # pal <- ggplotColors(length(nodes_middle))
  # for (i in 1:length(nodes_middle)) {
  #   node_cols[[nodes_middle[i]]] <- list(col = pal[i], textcex = label.cex,
  #                                        textcol = label.col, srt = lab.srt)
  # }
  # pal <- ggplotColors(length(nodes2))
  # for (i in 1:length(nodes2)) {
  #   node_cols[[nodes2[i]]] <- list(col = pal[i], textcex = label.cex,
  #                                  textcol = label.col, srt = lab.srt)
  # }
  # # create nodes and riverplot object
  # nodes <- list(nodes1, nodes_middle, nodes2)
  # node.limit <- max(unlist(lapply(nodes, length)))
  #
  # node_Ys <- lapply(1:length(nodes), function(i) {
  #   seq(1, node.limit, by = node.limit / length(nodes[[i]]))
  # })
  # rp <- makeRiver(c(nodes1, nodes_middle, nodes2), edge_list,
  #                 node_xpos = node_Xs, node_ypos = unlist(node_Ys), node_styles = node_cols
  # )
  # prevent normal riverplot output being printed to console
  # invisible(capture.output(riverplot(rp,
  #                                    yscale = river.yscale, lty = river.lty,
  #                                    node_margin = river.node_margin, usr = river.usr
  # )))
}

#' Plot cluster proportions by dataset
#'
#' Generates plot of clusters sized by the proportion of total cells
#'
#' @param object \code{liger} object. Should call quantileAlignSNF before calling.
#' @param return.plot Return ggplot object (default FALSE)
#'
#' @return Prints plot to console (return.plot = FALSE); returns ggplot object (return.plot = TRUE).
#'
#' @importFrom grid unit
#' @importFrom ggplot2 ggplot aes aes_string coord_fixed element_blank geom_point guides
#' guide_legend scale_size scale_y_discrete theme
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' ligerex <- optimizeALS(ligerex, k = 5, max.iter = 2)
#' ligerex <- quantile_norm(ligerex)
#' ligerex <- louvainCluster(ligerex)
#' plotClusterProportions(ligerex)
plotClusterProportions <- function(object, return.plot = FALSE) {

  sample_names <- unlist(lapply(seq_along(object@H), function(i) {
    rep(names(object@H)[i], nrow(object@H[[i]]))
  }))
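  # Cross-tabulate cluster assignment against dataset of origin, then convert the counts
  # to per-dataset proportions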
  cell_names <- unlist(lapply(object@H, rownames), use.names = FALSE)
  freq_table <- data.frame(object@clusters[cell_names], sample_names)
  freq_table <- table(freq_table[, 1], freq_table[, 2])
  for (i in 1:ncol(freq_table)) {
    freq_table[, i] <- freq_table[, i] / sum(freq_table[, i])
  }
  freq_table <- as.data.frame(freq_table)
  colnames(freq_table) <- c("Cluster", "Sample", "Proportion")
  p1 <- ggplot(freq_table, aes_string(x = "Cluster", y = "Sample")) +
    geom_point(aes_string(size = 'Proportion', fill = 'Cluster', color = 'Cluster')) +
    scale_size(guide = "none") + theme(
      axis.line = element_blank(),
      axis.text.x = element_blank(),
      axis.title.y = element_blank(),
      axis.ticks = element_blank(),
      axis.title.x = element_blank(),
      legend.title = element_blank(),
      legend.position = 'bottom',
      plot.margin = unit(c(0, 0, 0, 0), "cm"),
      legend.justification = "center"
    ) + scale_y_discrete(position = "right") +
    guides(fill = guide_legend(ncol = 6, override.aes = list(size = 4))) +
    coord_fixed(ratio = 0.5)
  if (return.plot) {
    return(p1)
  }
  else {
    print(p1)
  }
}

#' Plot heatmap of cluster/factor correspondence
#'
#' Generates matrix of cluster/factor correspondence, using sum of row-normalized factor loadings
#' for every cell in each cluster. Plots heatmap of matrix, with red representing high total
#' loadings for a factor, black low. Optionally can also include dendrograms and sorting for
#' factors and clusters.
#'
#' @param object \code{liger} object.
#' @param use.aligned Use quantile normalized factor loadings to generate matrix (default FALSE).
#' @param Rowv Determines if and how the row dendrogram should be computed and reordered. Either a
#'   dendrogram or a vector of values used to reorder the row dendrogram or NA to suppress any row
#'   dendrogram (and reordering) (default NA for no dendrogram).
#' @param Colv Determines if and how the column dendrogram should be reordered. Has the same options
#'   as the Rowv argument (default 'Rowv' to match Rowv).
#' @param col Color map to use (defaults to red and black)
#' @param return.data Return matrix of total factor loadings for each cluster (default FALSE).
#' @param ... Additional parameters to pass on to heatmap()
#'
#' @return If requested, matrix of size num_cluster x num_factor
#'
#' @importFrom grDevices colorRampPalette
#' @importFrom stats heatmap
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' ligerex <- optimizeALS(ligerex, k = 5, max.iter = 2)
#' ligerex <- quantile_norm(ligerex)
#' ligerex <- louvainCluster(ligerex)
#' plotClusterFactors(ligerex)
plotClusterFactors <- function(object, use.aligned = FALSE, Rowv = NA, Colv = "Rowv", col = NULL,
                               return.data = FALSE, ...) {
  if (use.aligned) {
    data.mat <- object@H.norm
  } else {
    scaled <- lapply(object@H, function(i) {
      scale(i, center = FALSE, scale = TRUE)
    })
    data.mat <- Reduce(rbind, scaled)
  }
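  # Row-normalize so each cell's factor loadings sum to 1, then total the normalized
  # loadings within each cluster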
  row.scaled <- t(apply(data.mat, 1, function(x) {
    x / sum(x)
  }))
  cluster.bars <- list()
  for (cluster in levels(object@clusters)) {
    cluster.bars[[cluster]] <- colSums(row.scaled[names(object@clusters)
                                                  [which(object@clusters == cluster)], ])

  }
  cluster.bars <- Reduce(rbind, cluster.bars)
  if (is.null(col)) {
    colfunc <- grDevices::colorRampPalette(c("black", "red"))
    col <- colfunc(15)
  }
  rownames(cluster.bars) <- levels(object@clusters)
  colnames(cluster.bars) <- 1:ncol(cluster.bars)
  title <- ifelse(use.aligned, "H.norm", "raw H")
  stats::heatmap(cluster.bars,
                 Rowv = Rowv, Colv = Colv, col = col, xlab = "Factor", ylab = "Cluster",
                 main = title, ...
  )
  if (return.data) {
    return(cluster.bars)
  }
}

#######################################################################################
#### Marker/Cell Analysis

#' Find shared and dataset-specific markers
#'
#' Applies various filters to genes on the shared (W) and dataset-specific (V) components of the
#' factorization, before selecting those which load most significantly on each factor (in a shared
#' or dataset-specific way).
#'
#' @param object \code{liger} object. Should call optimizeALS before calling.
#' @param dataset1 Name of first dataset (default first dataset by order)
#' @param dataset2 Name of second dataset (default second dataset by order)
#' @param factor.share.thresh Use only factors with a dataset specificity less than or equal to
#'   threshold (default 10).
#' @param dataset.specificity Pre-calculated dataset specificity if available. Will calculate if not
#'   available.
#' @param log.fc.thresh Lower log-fold change threshold for differential expression in markers
#'   (default 1).
#' @param pval.thresh Upper p-value threshold for Wilcoxon rank test for gene expression
#'   (default 0.05).
#' @param num.genes Max number of genes to report for each dataset (default 30).
#' @param print.genes Print ordered markers passing the log fold-change and p-value thresholds
#'   (default FALSE).
#' @param verbose Print messages (TRUE by default)
#'
#' @return List of shared and specific factors. First three elements are dataframes of dataset1-
#'   specific, shared, and dataset2-specific markers. Last two elements are tables indicating the
#'   number of factors in which each marker appears.
#'
#' @importFrom stats wilcox.test
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' ligerex <- optimizeALS(ligerex, k = 5, max.iter = 2)
#' ligerex <- quantile_norm(ligerex)
#' fm <- getFactorMarkers(ligerex, dataset1 = "stim", dataset2 = "ctrl")
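#' # Inspect the shared markers
#' head(fm$shared)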
getFactorMarkers <- function(object, dataset1 = NULL, dataset2 = NULL, factor.share.thresh = 10,
                             dataset.specificity = NULL, log.fc.thresh = 1, pval.thresh = 0.05,
                             num.genes = 30, print.genes = FALSE, verbose = TRUE) {
  if (is.null(dataset1)) {
    dataset1 <- names(object@H)[1]
  }
  if (is.null(dataset2)) {
    dataset2 <- names(object@H)[2]
  }
  if (is.null(num.genes)) {
    num.genes <- length(object@var.genes)
  }
  if (is.null(dataset.specificity)) {
    dataset.specificity <- calcDatasetSpecificity(object, dataset1 = dataset1,
                                                  dataset2 = dataset2, do.plot = FALSE)
  }
  factors.use <- which(abs(dataset.specificity[[3]]) <= factor.share.thresh)

  if (length(factors.use) < 2 && verbose) {
    message(
      "Warning: only ", length(factors.use),
      " factors passed the dataset specificity threshold."
    )
  }

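  # Scale the cell factor loadings and assign each cell to the factor (among factors.use)
  # on which it loads most strongly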
  Hs_scaled <- lapply(object@H, function(x) {
    scale(x, scale = TRUE, center = TRUE)
  })
  labels <- list()
  for (i in seq_along(Hs_scaled)) {
    if (inherits(object@raw.data[[1]], "H5File")) {
      if (object@h5file.info[[1]][["sample.data.type"]] != "norm.data"){
        stop("norm.data should be sampled for obtaining factor markers.")
      }
      labels[[i]] <- factors.use[as.factor(apply(Hs_scaled[[i]][colnames(object@sample.data[[i]]), factors.use], 1, which.max))]
    } else {
      labels[[i]] <- factors.use[as.factor(apply(Hs_scaled[[i]][, factors.use], 1, which.max))]
    }
  }
  names(labels) <- names(object@H)

  V1_matrices <- list()
  V2_matrices <- list()
  W_matrices <- list()
  for (j in 1:length(factors.use)) {
    i <- factors.use[j]

    W <- t(object@W)
    V1 <- t(object@V[[dataset1]])
    V2 <- t(object@V[[dataset2]])
    rownames(W) <- rownames(V1) <- rownames(V2) <- object@var.genes
    # if not max factor for any cell in either dataset
    if (sum(labels[[dataset1]] == i) <= 1 | sum(labels[[dataset2]] == i) <= 1) {
      message("Warning: factor", i, "did not appear as max in any cell in either dataset")
      next
    }
    # filter genes by log fold-change and p-value thresholds
    if (inherits(object@raw.data[[1]], "H5File")) {
      if (object@h5file.info[[1]][["sample.data.type"]] != "norm.data"){
        stop("Sampled norm.data are required for this analysis.")
      }
      expr_mat = Reduce(cbind, object@sample.data[c(dataset1,dataset2)])[object@var.genes, c(labels[[dataset1]] == i, labels[[dataset2]] == i)]
      cell_label = rep(c(dataset1, dataset2), c(sum(labels[[dataset1]] == i), sum(labels[[dataset2]] == i)))
      wilcoxon_result = wilcoxauc(log(expr_mat + 1e-10), cell_label)

    } else {
      expr_mat = cbind(object@norm.data[[dataset1]][object@var.genes, labels[[dataset1]] == i],
                       object@norm.data[[dataset2]][object@var.genes, labels[[dataset2]] == i])
      cell_label = rep(c(dataset1, dataset2), c(sum(labels[[dataset1]] == i), sum(labels[[dataset2]] == i)))
      wilcoxon_result = wilcoxauc(log(expr_mat + 1e-10), cell_label)
    }
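    # logFC is taken from the dataset1 rows of the Wilcoxon results: positive values
    # indicate enrichment in dataset1, negative values enrichment in dataset2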
    d1_result = wilcoxon_result[wilcoxon_result$group == dataset1, ]
    log2fc = d1_result$logFC
    names(log2fc) = d1_result$feature
    filtered_genes_V1 = d1_result$feature[d1_result$logFC > log.fc.thresh & d1_result$pval < pval.thresh]
    filtered_genes_V2 = d1_result$feature[d1_result$logFC < -log.fc.thresh & d1_result$pval < pval.thresh]

    W <- pmin(W + V1, W + V2)
    V1 <- V1[filtered_genes_V1, , drop = FALSE]
    V2 <- V2[filtered_genes_V2, , drop = FALSE]

    if (length(filtered_genes_V1) == 0) {
      top_genes_V1 <- character(0)
    } else {
      top_genes_V1 <- row.names(V1)[order(V1[, i], decreasing = TRUE)[1:num.genes] ]
      top_genes_V1 <- top_genes_V1[!is.na(top_genes_V1)]
      top_genes_V1 <- top_genes_V1[which(V1[top_genes_V1, i] > 0)]
    }
    if (length(filtered_genes_V2) == 0) {
      top_genes_V2 <- character(0)
    } else {
      top_genes_V2 <- row.names(V2)[order(V2[, i], decreasing = TRUE)[1:num.genes] ]
      top_genes_V2 <- top_genes_V2[!is.na(top_genes_V2)]
      top_genes_V2 <- top_genes_V2[which(V2[top_genes_V2, i] > 0)]
    }
    top_genes_W <- row.names(W)[order(W[, i], decreasing = TRUE)[1:num.genes] ]
    top_genes_W <- top_genes_W[!is.na(top_genes_W)]
    top_genes_W <- top_genes_W[which(W[top_genes_W, i] > 0)]

    if (print.genes && verbose) {
      message("Factor ", i)
      message("Dataset 1: ", paste(top_genes_V1, collapse = ", "))
      message("Shared: ", paste(top_genes_W, collapse = ", "))
      message("Dataset 2: ", paste(top_genes_V2, collapse = ", "))
    }

    pvals <- list() # order is V1, V2, W
    top_genes <- list(top_genes_V1, top_genes_V2, top_genes_W)
    for (k in 1:length(top_genes)) {
      pvals[[k]] <- wilcoxon_result[wilcoxon_result$feature %in% top_genes[[k]] & wilcoxon_result$group == dataset1, ]$pval
    }
    # bind values in matrices
    V1_matrices[[j]] <- Reduce(cbind, list(
      rep(i, length(top_genes_V1)), top_genes_V1,
      log2fc[top_genes_V1], pvals[[1]]
    ))
    V2_matrices[[j]] <- Reduce(cbind, list(
      rep(i, length(top_genes_V2)), top_genes_V2,
      log2fc[top_genes_V2], pvals[[2]]
    ))
    W_matrices[[j]] <- Reduce(cbind, list(
      rep(i, length(top_genes_W)), top_genes_W,
      log2fc[top_genes_W], pvals[[3]]
    ))
  }
  V1_genes <- data.frame(Reduce(rbind, V1_matrices), stringsAsFactors = FALSE)
  V2_genes <- data.frame(Reduce(rbind, V2_matrices), stringsAsFactors = FALSE)
  W_genes <- data.frame(Reduce(rbind, W_matrices), stringsAsFactors = FALSE)
  df_cols <- c("factor_num", "gene", "log2fc", "p_value")
  output_list <- list(V1_genes, W_genes, V2_genes)
  output_list <- lapply(seq_along(output_list), function(x) {
    df <- output_list[[x]]
    colnames(df) <- df_cols
    df <- transform(df,
                    factor_num = as.numeric(df$'factor_num'), gene = as.character(df$'gene'),
                    log2fc = as.numeric(df$'log2fc'), p_value = as.numeric(df$'p_value')
    )
    # Cutoff only applies to dataset-specific dfs
    if (x != 2) {
      df[which(df$p_value < pval.thresh), ]
    } else {
      df
    }
  })
  names(output_list) <- c(dataset1, "shared", dataset2)
  output_list[["num_factors_V1"]] <- table(output_list[[dataset1]]$'gene')
  output_list[["num_factors_V2"]] <- table(output_list[[dataset2]]$'gene')
  return(output_list)
}

#######################################################################################
#### Conversion/Transformation

#' Create a Seurat object containing the data from a liger object
#'
#' Merges raw.data and scale.data of object, and creates Seurat object with these values along with
#' tsne.coords, iNMF factorization, and cluster assignments. Requires Seurat version 3 or newer.
#'
#' Stores original dataset identity by default in new object metadata if dataset names are passed
#' in nms. iNMF factorization is stored in dim.reduction object with key "iNMF".
#'
#' @param object \code{liger} object.
#' @param nms By default, labels cell names with dataset of origin (this is to account for cells in
#'   different datasets which may have same name). Other names can be passed here as vector, must
#'   have same length as the number of datasets. (default names(H))
#' @param renormalize Whether to log-normalize raw data using Seurat defaults (default FALSE).
#' @param use.liger.genes Whether to carry over variable genes (default TRUE).
#' @param by.dataset Include dataset of origin in cluster identity in Seurat object (default FALSE).
#' @param assay Assay name to set in the Seurat object (default "RNA").
#' @return Seurat object with raw counts, normalized and scaled data, iNMF and TSNE dimension
#'   reductions, and cluster identities set where available.
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' if (packageVersion("Matrix") <= package_version("1.6.1.1")) {
#'   # 1.6.2 is not compatible thus don't test
#'   # but can use `setOldClass("mMatrix")` as a hack
#'   srt <- ligerToSeurat(ligerex)
#' }
ligerToSeurat <- function(object, nms = NULL, renormalize = FALSE, use.liger.genes = TRUE,
                          by.dataset = FALSE, assay = "RNA") {
  if (!requireNamespace("Seurat", quietly = TRUE)) {
    stop("Package \"Seurat\" needed for this function to work. Please install it.",
         call. = FALSE)
  }
  if (!inherits(object@raw.data[[1]], 'dgCMatrix')) {
    object@raw.data <- lapply(object@raw.data, as, Class = "CsparseMatrix")
  }
  raw.data <- MergeSparseDataAll(object@raw.data, nms)
  new.seurat <- Seurat::CreateSeuratObject(raw.data, assay = assay)
  if (isTRUE(renormalize)) {
    new.seurat <- Seurat::NormalizeData(new.seurat)
  } else {
    if (length(object@norm.data) > 0) {
      norm.data <- MergeSparseDataAll(object@norm.data, nms)
      new.seurat <- SeuratObject::SetAssayData(new.seurat, layer = "data", slot = "data", new.data = norm.data)  
    }
  }
  if (length(object@var.genes) > 0 && use.liger.genes) {
    Seurat::VariableFeatures(new.seurat) <- object@var.genes
  }
  if (length(object@scale.data) > 0) {
    scale.data <- t(Reduce(rbind, object@scale.data))
    colnames(scale.data) <- colnames(raw.data)
    new.seurat <- SeuratObject::SetAssayData(object = new.seurat, layer = "scale.data", slot = "scale.data", new.data = scale.data)
  }
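  # Transfer the iNMF factorization as a DimReduc: H.norm supplies the cell embeddings
  # and t(W) the feature loadings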
  if (all(dim(object@W) > 0) && all(dim(object@H.norm) > 0)) {
    inmf.loadings <- t(x = object@W)
    dimnames(inmf.loadings) <- list(object@var.genes, 
                                    paste0("iNMF_", seq_len(ncol(inmf.loadings))))
    inmf.embeddings <- object@H.norm
    dimnames(inmf.embeddings) <- list(unlist(lapply(object@scale.data, rownames), use.names = FALSE),
                                      paste0("iNMF_", seq_len(ncol(inmf.loadings))))
    inmf.obj <- Seurat::CreateDimReducObject(
      embeddings = inmf.embeddings,
      loadings = inmf.loadings,
      assay = assay,
      key = "iNMF_"
    )
    new.seurat[["iNMF"]] <- inmf.obj
  }
  if (all(dim(object@tsne.coords) > 0)) {
    tsne.embeddings <- object@tsne.coords
    dimnames(tsne.embeddings) <- list(rownames(object@H.norm),
                                      c("TSNE_1", "TSNE_2"))
    tsne.obj <- Seurat::CreateDimReducObject(
      embeddings = tsne.embeddings,
      assay = assay,
      key = "TSNE_"
    )
    new.seurat[["TSNE"]] <- tsne.obj
  }
  new.seurat$orig.ident <- object@cell.data$dataset
  
  idents <- object@clusters
  if (length(idents) == 0 || isTRUE(by.dataset)) idents <- object@cell.data$dataset
  Seurat::Idents(new.seurat) <- idents

  return(new.seurat)
}

#' Create liger object from one or more Seurat objects
#'
#' This function creates a \code{liger} object from multiple (disjoint) Seurat objects or a single
#' (combined-analysis) Seurat object. It includes options for keeping the variable genes and cluster
#' identities from the original Seurat objects. Seurat V2 and V3 supported (though all objects
#' should share the same major version).
#'
#' @param objects One or more Seurat objects. If passing multiple objects, they should be in a list.
#' @param combined.seurat Whether Seurat object (single) already contains multiple datasets (default
#'   FALSE).
#' @param names Names to use for datasets in new liger object. If use-projects, takes project names
#'   from individual Seurat objects; if use-meta, takes value of object meta.data in meta.var column
#'   for each dataset; otherwise, user can pass in vector of names with same length
#'   as number of datasets. If combined.seurat, infers project names based on whether meta.var
#'   or assays.use is present (at least one required).
#' @param meta.var Seurat meta.data column name to use in naming datasets. Either meta.var or
#'   assays.use required if combined.seurat is TRUE (default NULL).
#' @param assays.use Names of Seurat v3 assays to use as separate datasets in conversion (e.g. RNA,
#'   ADT) (default NULL).
#' @param raw.assay Name of Seurat v3 assay to use for raw data if meta.var used to split combined
#'   Seurat object -- in case integrated assay has been set as default (default "RNA").
#' @param remove.missing Whether to remove missing genes/cells when converting raw.data to liger object
#'   (default TRUE).
#' @param renormalize Whether to automatically normalize raw.data once \code{liger} object is created
#'   (default TRUE).
#' @param use.seurat.genes Carry over variable genes from Seurat objects. If num.hvg.info is set, uses
#'   that value to get top most highly variable genes from hvg.info slot in Seurat objects. Otherwise
#'   uses var.genes slot in Seurat objects. For multiple datasets, takes the union of the variable
#'   genes. (default TRUE)
#' @param num.hvg.info Number of highly variable genes to include from each object's hvg.info slot.
#'   Only available for Seurat v2 objects. If set, recommended value is 2000 (default NULL).
#' @param use.idents Carry over cluster identities from Seurat objects. If multiple objects with
#'   overlapping cluster names, will preface cluster names by dataset names to distinguish. (default
#'   TRUE).
#' @param use.tsne Carry over t-SNE coordinates from Seurat object (only meaningful for combined
#'   analysis Seurat object). Useful for plotting directly afterwards. (default TRUE)
#' @param cca.to.H Carry over CCA (and aligned) loadings and insert them into H (and H.norm) slot in
#'   liger object (only meaningful for combined analysis Seurat object). Useful for plotting directly
#'   afterwards. (default FALSE)
#' @return \code{liger} object.
#' @export
#' @examples
#' if (packageVersion("Matrix") <= package_version("1.6.1.1")) {
#'   ctrl.srt <- Seurat::CreateSeuratObject(ctrl, project = "ctrl")
#'   stim.srt <- Seurat::CreateSeuratObject(stim, project = "stim")
#'   ligerex <- seuratToLiger(list(ctrl = ctrl.srt, stim = stim.srt),
#'                            use.seurat.genes = FALSE)
#' }
seuratToLiger <- function(objects, combined.seurat = FALSE, names = "use-projects", meta.var = NULL,
                          assays.use = NULL, raw.assay = "RNA", remove.missing = TRUE, renormalize = TRUE,
                          use.seurat.genes = TRUE, num.hvg.info = NULL, use.idents = TRUE, use.tsne = TRUE,
                          cca.to.H = FALSE) {
  if (!requireNamespace("Seurat", quietly = TRUE)) {
    stop("Package \"Seurat\" needed for this function to work. Please install it.",
         call. = FALSE
    )
  }

  # Remind to set combined.seurat
  if ((typeof(objects) != "list") & (!combined.seurat)) {
    stop("Please pass a list of objects or set combined.seurat = TRUE")
  }
  # Get Seurat versions
  if (typeof(objects) != "list") {
    version <- package_version(objects@version)$major
  } else {
    version <- sapply(objects, function(x) {
      package_version(x@version)$major
    })
    if (min(version) != max(version)) {
      stop("Please ensure all Seurat objects have the same major version.")
    } else {
      version <- version[1]
    }
  }

  # Only a single seurat object expected if combined.seurat
  if (combined.seurat) {
    if ((is.null(meta.var)) & (is.null(assays.use))) {
      stop("Please provide Seurat meta.var or assays.use to use in identifying individual datasets.")
    }
    if (!is.null(meta.var)) {
      # using meta.var column as division split
      if (version > 2) {
        # if integrated assay present, want to make sure to use original raw data
        object.raw <- Seurat::GetAssayData(objects, assay = raw.assay, slot = "counts")
      } else {
        object.raw <- objects@raw.data
      }
      if (nrow(objects@meta.data) != ncol(object.raw)) {
        message("Warning: Mismatch between meta.data and raw.data in this Seurat object.\n",
                "Some cells will not be assigned to a raw dataset.\n",
                "Repeat Seurat analysis without filters to allow all cells to be assigned.\n")
      }
      raw.data <- lapply(unique(objects@meta.data[[meta.var]]), function(x) {
        cells <- rownames(objects@meta.data[objects@meta.data[[meta.var]] == x, ])
        object.raw[, cells]
      })
      names(raw.data) <- unique(objects@meta.data[[meta.var]])
    } else {
      # using different assays in v3 object
      raw.data <- lapply(assays.use, function(x) {
        Seurat::GetAssayData(objects, assay = x, slot = "counts")
      })
      names(raw.data) <- assays.use
    }

    if (version > 2) {
      var.genes <- Seurat::VariableFeatures(objects)
      idents <- Seurat::Idents(objects)
      if (is.null(objects@reductions$tsne)) {
        message("Warning: no t-SNE coordinates available for this Seurat object.")
        tsne.coords <- NULL
      } else {
        tsne.coords <- objects@reductions$tsne@cell.embeddings
      }
    } else {
      # Get var.genes
      var.genes <- objects@var.genes
      # Get idents/clusters
      idents <- objects@ident
      # Get tsne.coords
      if (is.null(objects@dr$tsne)) {
        message("Warning: no t-SNE coordinates available for this Seurat object.")
        tsne.coords <- NULL
      } else {
        tsne.coords <- objects@dr$tsne@cell.embeddings
      }
    }
  } else {
    # for multiple Seurat objects
    raw.data <- lapply(objects, function(x) {
      if (version > 2) {
        # assuming default assays have been set for each v3 object
        Seurat::GetAssayData(x, slot = "counts")
      } else {
        x@raw.data
      }
    })
    names(raw.data) <- lapply(seq_along(objects), function(x) {
      if (identical(names, "use-projects")) {
        if (!is.null(meta.var)) {
          message("Warning: meta.var value is set - set names = 'use-meta' to use meta.var for names.\n")
        }
        objects[[x]]@project.name
      } else if (identical(names, "use-meta")) {
        if (is.null(meta.var)) {
          stop("Please provide meta.var to use in naming individual datasets.")
        }
        objects[[x]]@meta.data[[meta.var]][1]
      } else {
        names[x]
      }
    })
    # tsne coords not very meaningful for separate objects
    tsne.coords <- NULL

    if (version > 2) {
      var.genes <- Reduce(union, lapply(objects, function(x) {
        Seurat::VariableFeatures(x)
      }))
      # Get idents, label by dataset
      idents <- unlist(lapply(seq_along(objects), function(x) {
        idents <- rep("NA", ncol(raw.data[[x]]))
        names(idents) <- colnames(raw.data[[x]])
        idents[names(Seurat::Idents(objects[[x]]))] <- as.character(Seurat::Idents(objects[[x]]))
        idents <- paste0(names(raw.data)[x], idents)
      }))
      idents <- factor(idents)
    } else {
      var.genes <- Reduce(union, lapply(objects, function(x) {
        if (!is.null(num.hvg.info)) {
          rownames(head(x@hvg.info, num.hvg.info))
        } else {
          x@var.genes
        }
      }))
      # Get idents, label by dataset
      idents <- unlist(lapply(seq_along(objects), function(x) {
        idents <- rep("NA", ncol(objects[[x]]@raw.data))
        names(idents) <- colnames(objects[[x]]@raw.data)
        idents[names(objects[[x]]@ident)] <- as.character(objects[[x]]@ident)
        idents <- paste0(names(raw.data)[x], idents)
      }))
      idents <- factor(idents)
    }
  }
  new.liger <- createLiger(raw.data = raw.data, remove.missing = remove.missing)
  if (renormalize) {
    new.liger <- normalize(new.liger)
  }
  if (use.seurat.genes) {
    # Include only genes which appear in all datasets
    for (i in 1:length(new.liger@raw.data)) {
      var.genes <- intersect(var.genes, rownames(new.liger@raw.data[[i]]))
      # Seurat has an extra CheckGenes step which we can include here
      # Remove genes with no expression anywhere
      var.genes <- var.genes[rowSums(new.liger@raw.data[[i]][var.genes, ]) > 0]
      var.genes <- var.genes[!is.na(var.genes)]
    }

    new.liger@var.genes <- var.genes
  }
  if (use.idents) {
    new.liger@clusters <- idents
  }
  if ((use.tsne) & (!is.null(tsne.coords))) {
    new.liger@tsne.coords <- tsne.coords
  }
  # Get CCA loadings if requested
  if (cca.to.H & combined.seurat) {
    if (version > 2) {
      message("Warning: no CCA loadings available for Seurat v3 objects.\n")
      return(new.liger)
    }
    if (is.null(objects@dr$cca)) {
      message("Warning: no CCA loadings available for this Seurat object.\n")
    } else {
      new.liger@H <- lapply(unique(objects@meta.data[[meta.var]]), function(x) {
        cells <- rownames(objects@meta.data[objects@meta.data[[meta.var]] == x, ])
        objects@dr$cca@cell.embeddings[cells, ]
      })
      new.liger@H <- lapply(seq_along(new.liger@H), function(x) {
        addMissingCells(new.liger@raw.data[[x]], new.liger@H[[x]])
      })
      names(new.liger@H) <- names(new.liger@raw.data)
    }
    if (is.null(objects@dr$cca.aligned)) {
      message("Warning: no aligned CCA loadings available for this Seurat object.\n")
    } else {
      new.liger@H.norm <- objects@dr$cca.aligned@cell.embeddings
      new.liger@H.norm <- addMissingCells(Reduce(rbind, new.liger@H), new.liger@H.norm,
                                          transpose = TRUE)
    }
  }
  return(new.liger)
}

#' Construct a liger object with a specified subset
#'
#' The subset can be based on cell names or clusters. This function applies the subsetting to
#' raw.data, norm.data, scale.data, cell.data, H, W, V, H.norm, tsne.coords, and clusters.
#' Note that it does NOT reoptimize the factorization. See optimizeSubset for this functionality.
#'
#' @param object \code{liger} object. Should run quantileAlignSNF and runTSNE before calling.
#' @param clusters.use Clusters to use for subset.
#' @param cells.use Vector of cell names to keep from any dataset.
#' @param remove.missing Whether to remove genes/cells with no expression when creating new object
#'   (default TRUE).
#'
#' @return \code{liger} object with subsetting applied to raw.data, norm.data, scale.data, H, W, V,
#'   H.norm, tsne.coords, and clusters.
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' lig.small <- subsetLiger(ligerex, cells.use = c(colnames(ctrl)[1:100], colnames(stim)[1:100]))
subsetLiger <- function(object, clusters.use = NULL, cells.use = NULL, remove.missing = TRUE) {
  if (!is.null(clusters.use)) {
    cells.use <- names(object@clusters)[which(object@clusters %in% clusters.use)]
  }
  raw.data <- lapply(seq_along(object@raw.data), function(q) {
    cells <- intersect(cells.use, colnames(object@raw.data[[q]]))
    if (length(cells) > 0) {
      if (length(cells) < 25) {
        warning("Number of subsetted cells too small (less than 25), please check cells.use!")
      }
      object@raw.data[[q]][, cells, drop = FALSE]
    } else {
      warning("Selected subset eliminates dataset ", names(object@raw.data)[q])
      return(NULL)
    }
  })
  missing <- sapply(raw.data, is.null)
  raw.data <- raw.data[!missing]
  nms <- names(object@raw.data)[!missing]
  names(raw.data) <- nms
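  # Rebuild a liger object from the subsetted raw data, then carry over the matching
  # norm.data, scale.data, and H entries for the retained cells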
  a <- createLiger(raw.data, remove.missing = remove.missing)
  a@norm.data <- lapply(1:length(a@raw.data), function(i) {
    object@norm.data[[nms[i]]][, colnames(a@raw.data[[i]])]
  })
  a@scale.data <- lapply(1:length(a@raw.data), function(i) {
    object@scale.data[[nms[i]]][colnames(a@raw.data[[i]]), ]
  })
  a@H <- lapply(1:length(a@raw.data), function(i) {
    object@H[[nms[i]]][colnames(a@raw.data[[i]]), ]
  })
  a@clusters <- object@clusters[unlist(lapply(a@H, rownames))]
  a@clusters <- droplevels(a@clusters)
  a@tsne.coords <- object@tsne.coords[names(a@clusters), ]
  a@H.norm <- object@H.norm[names(a@clusters), ]
  cell.names <- unname(unlist(lapply(a@raw.data, colnames)))
  # Add back additional cell.data
  if (ncol(a@cell.data) < ncol(object@cell.data)) {
    a@cell.data <- droplevels(data.frame(object@cell.data[cell.names, ]))
  }

  a@W <- object@W
  a@V <- object@V
  a@var.genes <- object@var.genes
  names(a@scale.data) <- names(a@norm.data) <- names(a@H) <- nms
  return(a)
}

#' Construct a liger object organized by another feature
#'
#' Using the same data, rearrange the constituent datasets by another discrete feature in
#' cell.data. This removes most computed data slots, though cell.data and current clustering can
#' be retained.
#'
#' @param object \code{liger} object.
#' @param by.feature Column in cell.data to use in reorganizing raw data.
#' @param keep.meta Whether to carry over all existing data in cell.data slot (default TRUE).
#' @param new.label If cell.data is to be retained, new column name for original organizing feature
#'   (previously labeled as dataset) (default "orig.dataset")
#' @param ... Additional parameters passed on to createLiger.
#'
#' @return \code{liger} object with rearranged raw.data slot.
#'
#' @import Matrix
#'
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' # Create a random variable of two categories
#' ligerex@cell.data$foo <- factor(sample(c(1,2), 600, replace = TRUE))
#' ligerexFoo <- reorganizeLiger(ligerex, "foo")
reorganizeLiger <- function(object, by.feature, keep.meta = TRUE, new.label = "orig.dataset",
                            ...) {
  if (!(by.feature %in% colnames(object@cell.data))) {
    stop("Please select existing feature in cell.data to reorganize by, or add it before calling.")
  }
  if (!is.factor(object@cell.data[, by.feature])) {
    stop("cell.data feature must be of class 'factor' to reorganize object. Please change the column to a factor and re-run reorganizeLiger.")
  }
  if (length(object@clusters) > 0) {
    object@cell.data[['orig.clusters']] <- object@clusters
  }
  orig.data <- object@cell.data
  colnames(orig.data)[colnames(orig.data) == "dataset"] <- new.label

  # make this less memory intensive for large datasets
  all.data <- MergeSparseDataAll(object@raw.data)

  new.raw <- lapply(levels(orig.data[[by.feature]]), function(x) {
    cells.keep <- rownames(orig.data)[which(orig.data[[by.feature]] == x)]
    all.data[, cells.keep]
  })
  names(new.raw) <- levels(orig.data[[by.feature]])
  rm(all.data)
  gc()
  new.object <- createLiger(raw.data = new.raw, ...)

  if (keep.meta) {
    cols.to.add <- setdiff(colnames(orig.data), colnames(new.object@cell.data))
    cols.to.add <- cols.to.add[which(cols.to.add != by.feature)]
    for (col in cols.to.add) {
      new.object@cell.data[[col]] = orig.data[rownames(new.object@cell.data), col]
    }
  }
  return(new.object)
}

#' Convert older liger object into most current version (based on class definition)
#'
#' Also works for Analogizer objects (but must have both liger and Analogizer loaded). Transfers
#' data in slots with same names from old class object to new, leaving slots defined only in new
#' class NULL.
#'
#' @param object \code{liger} object.
#' @param override.raw Keep original raw.data without any modifications (removing missing cells
#'   etc.) (default FALSE).
#' @param verbose Print progress bar/messages (TRUE by default)
#'
#' @return Updated \code{liger} object.
#'
#' @importFrom methods .hasSlot slot slotNames
#'
#' @export
#' @examples
#' \dontrun{
#' # Not able to generate old object from current version, thus not run
#' ligerex <- convertOldLiger(analogy)
#' }
convertOldLiger <- function(object, override.raw = FALSE, verbose = TRUE) {
  new.liger <- createLiger(object@raw.data)
  slots_new <- slotNames(new.liger)
  slots_old <- slotNames(object)
  slots_exist <- sapply(slots_new, function(x) {
    .hasSlot(object, x)
  })

  slots <- slots_new[slots_exist]
  for (slotname in slots) {
    if (!(slotname %in% c('raw.data')) | (override.raw)) {
      slot(new.liger, slotname) <- slot(object, slotname)
    }
  }
  if (verbose) {
    message('Old slots not transferred: ', paste(setdiff(slots_old, slots_new), collapse = ", "))
    # compare to slots since it's possible that the analogizer object
    # class has slots that this particular object does not
    message('New slots not filled: ', paste(setdiff(slots_new[slots_new != "cell.data"], slots), collapse = ", "))
  }
  return(new.liger)
}

#' Perform iNMF on scaled datasets, including scaled and normalized unshared features
#' @param object \code{liger} object. Should normalize, select genes, and scale before calling.
#' @param k Inner dimension of factorization (number of factors).
#' @param lambda The lambda penalty (default 5).
#' @param thresh Convergence threshold. Convergence occurs when |obj0-obj|/(mean(obj0,obj)) < thresh.
#'   (default 1e-10)
#' @param max.iters Maximum number of block coordinate descent iterations to perform (default 30).
#' @param nrep Number of restarts to perform (iNMF objective function is non-convex, so taking the
#'   best objective from multiple successive initializations is recommended). For easier
#'   reproducibility, this increments the random seed by 1 for each consecutive restart, so future
#'   factorizations of the same dataset can be run with one rep if necessary. (default 1)
#' @param rand.seed Random seed to allow reproducible results (default 1).
#' @param print.obj Print objective function values after convergence (default FALSE).
#' @param vectorized.lambda Whether or not to expect a vectorized lambda parameter (default FALSE).
#' @noRd
optimize_UANLS = function(object, k = 30, lambda = 5, max.iters = 30, nrep = 1, thresh = 1e-10,
                          rand.seed = 1, print.obj = FALSE, vectorized.lambda = FALSE) {

  set.seed(seed = rand.seed)
  message('Performing factorization using UINMF and unshared features')
  # If a scalar lambda was given, recycle the penalty across all datasets
  if (vectorized.lambda == FALSE) {
    lambda = rep(lambda, length(names(object@raw.data)))
  }

  # Get a list of all the matrices
  mlist = list()
  xdim =  list()
  for (i in 1:length(object@scale.data)){
    mlist[[i]] = t(object@scale.data[[i]])
    xdim[[i]] = dim(mlist[[i]])
  }

  # Record which datasets have unshared features, and the dimensions of those features
  u_dim <- c()
  max_feats = 0
  unshared <- c()
  ulist <- c()
  for (i in seq_along(object@var.unshared.features)) {
    if (length(object@var.unshared.features[[i]])) {
      u_dim[[i]] <- dim(object@scale.unshared.data[[i]])
      names(u_dim)[i] <- i
      unshared = c(unshared, i)
      if (u_dim[[i]][1] > max_feats) {
        max_feats = u_dim[[i]][1]
      }
      ulist[[i]] = t(object@scale.unshared.data[[i]])
    }
  }
  ############## For datasets with unshared features, append those features below the shared features
  for (i in seq_along(object@scale.data)) {
    if (i %in% unshared) {
      mlist[[i]] <- rbind(mlist[[i]], object@scale.unshared.data[[i]])
    }
    # Datasets without unshared features keep only the shared features
    else {
      mlist[[i]] <- rbind(mlist[[i]])
    }
  }

  X <- mlist
  ################# Create a zero matrix the size of U for each unshared-feature dataset, so it can be stacked with W
  zero_matrix_u_full <- c()
  zero_matrix_u_partial <- c()
  for (i in 1:length(object@raw.data)){
    if (i %in% unshared){
      zero_matrix_u_full[[i]] <- matrix(0, nrow = u_dim[[i]][1], ncol = u_dim[[i]][2])
      zero_matrix_u_partial[[i]] <- matrix(0, nrow = u_dim[[i]][1], ncol = k)
    }
  }

  num_cells = c()
  for (i in 1:length(X)){
    num_cells = c(num_cells, ncol(X[[i]]))
  }

  num_genes = length(object@var.genes)

  best_obj <- Inf
  for (rep_i in 1:nrep) {
    message("Processing restart ", rep_i, " of ", nrep)
    # seed each restart; per the documentation the seed increments by 1 per restart
    current <- rand.seed + rep_i - 1
    set.seed(current)
    # initialization
    idX = list()
    for (i in 1:length(X)){
      idX[[i]] = sample(1:num_cells[i], k)
    }
    V = list()

    #Establish V from only the RNA dimensions

    for (i in 1:length(X)){
      V[[i]] = t(object@scale.data[[i]])[,idX[[i]]]
    }
    #Establish W from the shared gene dimensions

    W = matrix(abs(runif(num_genes * k, 0, 2)), num_genes, k)

    H = list()

    #Initialize U
    U = list()
    for (i in 1:length(X)){
      if (i %in% unshared){
        U[[i]] = t(ulist[[i]])[,idX[[i]]]
      }
    }

    iter = 0
    total_time = 0
    pb <- txtProgressBar(min = 0, max = max.iters, style = 3)
    sqrt_lambda = list()
    for (i in 1:length(X)){
      sqrt_lambda[[i]]= sqrt(lambda[[i]])
    }
    ############################ Initial Training Objects

    obj_train_approximation = 0
    obj_train_penalty = 0

    for (i in 1:length(X)){
      H[[i]] = matrix(abs(runif(k * num_cells[i], 0, 2)), k, num_cells[i])
      if (i %in% unshared){
        obj_train_approximation = obj_train_approximation + norm(X[[i]] - (rbind(W,zero_matrix_u_partial[[i]]) + rbind(V[[i]],U[[i]])) %*% H[[i]],"F")^2
        obj_train_penalty = obj_train_penalty + lambda[[i]]*norm(rbind(V[[i]],U[[i]])%*% H[[i]], "F")^2
      }
      else {
        obj_train_approximation = obj_train_approximation + norm(X[[i]] - (W+ V[[i]]) %*% H[[i]],"F")^2
        obj_train_penalty = obj_train_penalty + lambda[[i]]*norm(V[[i]]%*% H[[i]], "F")^2

      }
    }
    obj_train = obj_train_approximation + obj_train_penalty

    ######################### Initialize Object Complete ###########################
    ########################## Begin Updates########################################
    delta = Inf
    objective_value_list = list()

    iter = 1
    while(delta > thresh & iter <= max.iters){
      iter_start_time = Sys.time()


      #H- Updates
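      # Each H_i solves one NNLS problem that encodes both objective terms: the
      # reconstruction block is stacked on top of sqrt(lambda_i) times the loading
      # matrices, with a zero matrix as the penalty block's target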
      for (i in 1:length(X)){
        if (!(i %in% unshared)){
          H[[i]] = solveNNLS(rbind((W + V[[i]]), sqrt_lambda[[i]] * V[[i]]), rbind(X[[i]], matrix(0, num_genes, xdim[[i]][2])))
        }
        else {
          # For unshared datasets, W is zero-padded over the unshared rows before stacking
          H[[i]] = solveNNLS(rbind(rbind(W, zero_matrix_u_partial[[i]]) + rbind(V[[i]], U[[i]]),
                                   sqrt_lambda[[i]] * rbind(V[[i]], U[[i]])),
                             rbind(X[[i]], matrix(0, num_genes + u_dim[[i]][1], xdim[[i]][2])))
        }
      }

      #V - updates
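      # V_i touches only the shared gene rows: solve the transposed NNLS with design
      # t(H_i) against the residual X_i - W H_i, again with a zero block for the penalty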
      for (i in 1:length(X)) {
        V[[i]] = t(solveNNLS(rbind(t(H[[i]]), sqrt_lambda[[i]] * t(H[[i]])),
                             rbind(t(X[[i]][1:num_genes, ] - W %*% H[[i]]),
                                   matrix(0, num_cells[i], num_genes))))
      }
      ################################################# Updating U##################################
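      # U_i touches only the unshared rows, where the padded W block is zero, so the
      # NNLS target is simply those rows of X_i (plus the zero penalty block)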

      for (i in 1:length(X)) {
        if (i %in% unshared) {
          U[[i]] = t(solveNNLS(rbind(t(H[[i]]), sqrt_lambda[[i]] * t(H[[i]])),
                               rbind(t(X[[i]][(num_genes + 1):(u_dim[[i]][1] + num_genes), ]),
                                     t(zero_matrix_u_full[[i]]))))
        }
      }


      ##############################################################################################
      ################################################# Updating W #################################
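      # W is shared across datasets: stack every t(H_i) into one design matrix and
      # every shared-row residual X_i - V_i H_i into one target, then solve NNLS once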
      H_t_stack = c()
      for (i in 1:length(X)) {
        H_t_stack = rbind(H_t_stack, t(H[[i]]))
      }
      diff_stack_w = c()
      for (i in 1:length(X)) {
        diff_stack_w = rbind(diff_stack_w, t(X[[i]][1:num_genes, ] - V[[i]] %*% H[[i]]))
      }
      W = t(solveNNLS(H_t_stack, diff_stack_w))

      ############################################################################################
      iter_end_time = Sys.time()
      iter_time = as.numeric(difftime(iter_end_time, iter_start_time, units = "secs"))
      total_time = total_time + iter_time

      # Recompute the training objective for the convergence check
      obj_train_prev = obj_train
      obj_train_approximation = 0
      obj_train_penalty = 0

      for (i in 1:length(X)){
        if (i %in% unshared){
          obj_train_approximation = obj_train_approximation + norm(X[[i]] - (rbind(W,zero_matrix_u_partial[[i]]) + rbind(V[[i]],U[[i]])) %*% H[[i]],"F")^2
          obj_train_penalty = obj_train_penalty + lambda[[i]]*norm(rbind(V[[i]],U[[i]])%*% H[[i]], "F")^2
        }
        else {
          obj_train_approximation = obj_train_approximation + norm(X[[i]] - (W+ V[[i]]) %*% H[[i]],"F")^2
          obj_train_penalty = obj_train_penalty + lambda[[i]]*norm(V[[i]]%*% H[[i]], "F")^2

        }
      }

      obj_train = obj_train_approximation + obj_train_penalty
      delta = abs(obj_train_prev-obj_train)/mean(c(obj_train_prev,obj_train))
      iter = iter + 1
      setTxtProgressBar(pb = pb, value = iter)
    }
    setTxtProgressBar(pb = pb, value = max.iters)
    close(pb)
    cat("\nCurrent seed ", current, ", current objective ", obj_train)
    if (obj_train < best_obj){
      W_m <- W
      H_m <- H
      V_m <- V
      U_m <- U
      best_obj <- obj_train
      best_seed <- current
    }
  }

  # Attach feature and cell names: shared genes for W and V, unshared features for U
  rownames(W_m) = rownames(X[[1]])[1:num_genes]
  colnames(W_m) = NULL

  for (i in 1:length(X)) {
    if (i %in% unshared) {
      rownames(U_m[[i]]) = rownames(X[[i]])[(num_genes + 1):(u_dim[[i]][1] + num_genes)]
      colnames(U_m[[i]]) = NULL
    }
    rownames(V_m[[i]]) = rownames(X[[i]])[1:num_genes]
    colnames(V_m[[i]]) = NULL
    colnames(H_m[[i]]) = colnames(X[[i]])
  }

  ################################## Returns Results Section #########################################################
  object@W <- t(W_m)
  for (i in 1:length(X)){
    object@V[[i]] <- t(V_m[[i]])
    object@H[[i]] <- t(H_m[[i]])
    if(i %in% unshared){
      object@U[[i]] <- t(U_m[[i]])
    }
  }
  titles <- names(object@raw.data)
  names(object@H) <- titles
  names(object@V) <- titles
  if (length(unshared) > 0) {
    names(object@U) <- titles
  }
  if (print.obj) {
    cat("\n", "Objective:", best_obj, "\n")
  }

  # Restrict cell metadata to the cells actually present in the scaled data
  rel_cells <- unlist(lapply(object@scale.data, rownames))
  object@cell.data <- object@cell.data[rel_cells, ]
  cat("\n", "Best results with seed ", best_seed, ".\n", sep = "")
  return (object)
}
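
# Example (not run): a minimal sketch of the UINMF workflow that reaches optimize_UANLS.
# optimize_UANLS is internal; in this package it is reached through optimizeALS() with
# use.unshared = TRUE after selecting unshared features. The dataset names, thresholds,
# and k below are illustrative assumptions, not prescriptions.
#
#   ligerex <- createLiger(list(rna = rna_counts, atac = atac_gene_counts))
#   ligerex <- normalize(ligerex)
#   # Select shared variable genes, plus unshared features from dataset 2
#   ligerex <- selectGenes(ligerex, unshared = TRUE, unshared.datasets = list(2),
#                          unshared.thresh = 0.4)
#   ligerex <- scaleNotCenter(ligerex)
#   # Dispatches to optimize_UANLS because use.unshared = TRUE
#   ligerex <- optimizeALS(ligerex, k = 30, use.unshared = TRUE)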



#' Calculate loadings for each factor
#'
#' Calculates the contribution of each factor of W, V, and U to the reconstruction.
#'
#' @param object \code{liger} object. \code{quantile_norm} should be called before this function.
#' @return A dataframe with one column of factor contributions per matrix (W, V_1, V_2, etc.),
#'   plus a \code{factors} index column.
#' @export
#' @examples
#' ligerex <- createLiger(list(ctrl = ctrl, stim = stim))
#' ligerex <- normalize(ligerex)
#' ligerex <- selectGenes(ligerex)
#' ligerex <- scaleNotCenter(ligerex)
#' # Minimum specification for fast example pass
#' ligerex <- optimizeALS(ligerex, k = 5, max.iters = 1)
#' ligerex <- quantile_norm(ligerex)
#' calcNormLoadings(ligerex)
calcNormLoadings = function(object) {
  H_norm = object@H.norm
  W_norm = object@W
  V_norm = object@V
  U_norm = object@U
  ##### Calculation of Contribution #########################
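  # For each factor i, the contribution of a loading matrix M is measured as
  # norm(h_i %*% m_i, "F") / ncol(M): the Frobenius norm of that factor's rank-one
  # reconstruction (h_i = column i of H.norm, m_i = row i of M), normalized by the
  # number of features in M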
  w_loadings = list()
  u_loadings = list()
  v_loadings = list()
  for (i in 1:length(object@raw.data)) {
    u_loadings[[i]] = list()
    v_loadings[[i]] = list()
  }
  for (i in 1:dim(object@H.norm)[[2]]) {
    hi = as.matrix(H_norm[, i])
    ####### Calculate W
    wi = t(as.matrix(W_norm[i, ]))
    hw = hi %*% wi
    frob_hw = norm(hw, type = "F") / dim(W_norm)[[2]]
    w_loadings = append(w_loadings, frob_hw)

    ###### Calculate V
    for (j in 1:length(object@raw.data)) {
      temp_v = t(as.matrix(V_norm[[j]][i, ]))
      hv_temp = hi %*% temp_v
      frob_hv = norm(hv_temp, type = "F") / dim(V_norm[[j]])[[2]]
      v_loadings[[j]] = append(v_loadings[[j]], frob_hv)
    }
    ###### Calculate U (only for datasets with unshared features)
    if (length(object@U) != 0) {
      for (j in 1:length(object@raw.data)) {
        if (length(object@U[[j]]) != 0) {
          temp_u = t(as.matrix(U_norm[[j]][i, ]))
          hu_temp = hi %*% temp_u
          frob_hu = norm(hu_temp, type = "F") / dim(U_norm[[j]])[[2]]
          u_loadings[[j]] = append(u_loadings[[j]], frob_hu)
        }
      }
    }
  }

  ################# Format the return object
  w_loadings = unlist(w_loadings)
  factors = 1:dim(object@H.norm)[[2]]
  results = data.frame(factors, w_loadings)

  # Append one column per dataset for V
  for (j in 1:length(object@raw.data)) {
    results = cbind(results, unlist(v_loadings[[j]]))
    colnames(results)[[2 + j]] = paste0("V_", j, "_loadings")
  }
  # Append one column per dataset with unshared features for U
  if (length(object@U) != 0) {
    for (j in 1:length(object@raw.data)) {
      name_di = dim(results)[[2]]
      if (length(object@U[[j]]) != 0) {
        results = cbind(results, unlist(u_loadings[[j]]))
        colnames(results)[[name_di + 1]] = paste0("U_", j, "_loadings")
      }
    }
  }

  return(results)
}
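
# Example (not run): inspecting calcNormLoadings output. Column names follow the
# V_<i>_loadings / U_<i>_loadings pattern built above; this is an illustrative
# sketch of downstream use, not part of the package API.
#
#   loadings <- calcNormLoadings(ligerex)
#   head(loadings)
#   # Rank factors by the shared (W) contribution
#   loadings[order(loadings$w_loadings, decreasing = TRUE), c("factors", "w_loadings")]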
