R/plotting.R

#' Gene expression heatmap
#'
#' Draws a heatmap of single cell gene expression using ggplot2.
#'
#' @param object Seurat object
#' @param data.use Option to pass in data to use in the heatmap. Default will pick from either
#' object@@data or object@@scale.data depending on use.scaled parameter. Should have cells as columns
#' and genes as rows.
#' @param use.scaled Whether to use the data or scaled data if data.use is NULL. When FALSE, the data
#' are converted to a dense matrix and a disp.max left at its default of 2.5 is raised to 10.
#' @param cells.use Cells to include in the heatmap (default is all cells)
#' @param genes.use Genes to include in the heatmap (ordered)
#' @param disp.min Minimum display value (all values below are clipped)
#' @param disp.max Maximum display value (all values above are clipped)
#' @param group.by Groups cells by this variable. Default is object@@ident. When NULL, cells are
#' labelled by object@@ident but are not split into group panels.
#' @param draw.line Draw vertical lines delineating different groups (drawn as spacing between the
#' group panels)
#' @param col.low Color for lowest expression value
#' @param col.mid Color for mid expression value
#' @param col.high Color for highest expression value
#' @param slim.col.label display only the identity class name once for each group
#' @param remove.key Removes the color key from the plot.
#' @param rotate.key Rotate color scale horizontally
#' @param title Title for plot
#' @param cex.col Controls size of column labels (cells)
#' @param cex.row Controls size of row labels (genes)
#' @param group.label.loc Place group labels on bottom or top of plot.
#' @param group.label.rot Whether to rotate the group label.
#' @param group.cex Size of group label text
#' @param group.spacing Controls amount of space between columns.
#' @param do.plot Whether to display (print) the plot.
#'
#' @return Returns a ggplot2 plot object (invisibly)
#'
#' @importFrom reshape2 melt
#' @importFrom dplyr %>%
#'
#' @export
#'
DoHeatmap <- function(
  object,
  data.use = NULL,
  use.scaled = TRUE,
  cells.use = NULL,
  genes.use = NULL,
  disp.min = -2.5,
  disp.max = 2.5,
  group.by = "ident",
  draw.line = TRUE,
  col.low = "#FF00FF",
  col.mid = "#000000",
  col.high = "#FFFF00",
  slim.col.label = FALSE,
  remove.key = FALSE,
  rotate.key = FALSE,
  title = NULL,
  cex.col = 10,
  cex.row = 10,
  group.label.loc = "bottom",
  group.label.rot = FALSE,
  group.cex = 15,
  group.spacing = 0.15,
  do.plot = TRUE
) {
  if (is.null(x = data.use)) {
    if (use.scaled) {
      data.use <- GetAssayData(object, assay.type = "RNA", slot = "scale.data")
    } else {
      data.use <- GetAssayData(object, assay.type = "RNA", slot = "data")
    }
  }
  # note: data.use should have cells as column names, genes as row names
  cells.use <- SetIfNull(x = cells.use, default = object@cell.names)
  cells.use <- intersect(x = cells.use, y = colnames(x = data.use))
  if (length(x = cells.use) == 0) {
    stop("No cells given to cells.use present in object")
  }
  genes.use <- SetIfNull(x = genes.use, default = rownames(x = data.use))
  genes.use <- intersect(x = genes.use, y = rownames(x = data.use))
  if (length(x = genes.use) == 0) {
    stop("No genes given to genes.use present in object")
  }
  if (is.null(x = group.by) || group.by == "ident") {
    cells.ident <- object@ident[cells.use]
  } else {
    cells.ident <- factor(x = FetchData(
      object = object,
      cells.use = cells.use,
      vars.all = group.by
    )[, 1])
    names(x = cells.ident) <- cells.use
  }
  # Drop identity classes that have no cells among cells.use (level order kept)
  cells.ident <- factor(
    x = cells.ident,
    labels = intersect(x = levels(x = cells.ident), y = cells.ident)
  )
  # drop = FALSE keeps a matrix when a single gene or a single cell is selected
  data.use <- data.use[genes.use, cells.use, drop = FALSE]
  if (!use.scaled) {
    data.use <- as.matrix(x = data.use)
    # Unscaled values span a wider range; only widen the untouched default
    if (disp.max == 2.5) {
      disp.max <- 10
    }
  }
  data.use <- MinMax(data = data.use, min = disp.min, max = disp.max)
  # Long format: one row per (cell, gene) with the clipped expression value
  data.use <- as.data.frame(x = t(x = data.use))
  data.use$cell <- rownames(x = data.use)
  colnames(x = data.use) <- make.unique(names = colnames(x = data.use))
  data.use <- melt(data = data.use, id.vars = "cell")
  names(x = data.use)[names(x = data.use) == "variable"] <- "gene"
  names(x = data.use)[names(x = data.use) == "value"] <- "expression"
  data.use$ident <- cells.ident[data.use$cell]
  # Reverse gene order so the first gene in genes.use is drawn at the top
  data.use$gene <- factor(
    x = data.use$gene,
    levels = rev(x = unique(x = as.character(x = data.use$gene)))
  )
  data.use$cell <- factor(x = data.use$cell, levels = cells.use)
  if (rotate.key) {
    key.direction <- "horizontal"
    key.title.pos <- "top"
  } else {
    key.direction <- "vertical"
    key.title.pos <- "left"
  }
  heatmap <- ggplot(
    data = data.use,
    mapping = aes(x = cell, y = gene, fill = expression)
  ) +
    geom_tile() +
    scale_fill_gradient2(
      low = col.low,
      mid = col.mid,
      high = col.high,
      name = "Expression",
      guide = guide_colorbar(
        direction = key.direction,
        title.position = key.title.pos
      )
    ) +
    scale_y_discrete(position = "right", labels = rev(genes.use)) +
    theme(
      axis.line = element_blank(),
      axis.title.y = element_blank(),
      axis.ticks.y = element_blank(),
      strip.text.x = element_text(size = group.cex),
      axis.text.y = element_text(size = cex.row),
      axis.text.x = element_text(size = cex.col),
      axis.title.x = element_blank()
    )
  if (slim.col.label) {
    heatmap <- heatmap +
      theme(
        axis.title.x = element_blank(),
        axis.text.x = element_blank(),
        axis.ticks.x = element_blank(),
        axis.line = element_blank(),
        axis.title.y = element_blank(),
        axis.ticks.y = element_blank()
      )
  } else {
    heatmap <- heatmap + theme(axis.text.x = element_text(angle = 90))
  }
  if (!is.null(x = group.by)) {
    # NULL keeps strips on top; "x" moves them below the panels
    strip.switch <- if (group.label.loc == "top") NULL else "x"
    heatmap <- heatmap +
      facet_grid(
        facets = ~ident,
        drop = TRUE,
        space = "free",
        scales = "free",
        switch = strip.switch
      ) +
      scale_x_discrete(expand = c(0, 0), drop = TRUE)
    # The "lines" between groups are the gaps between facet panels
    panel.spacing <- unit(
      x = if (draw.line) group.spacing else 0,
      units = "lines"
    )
    heatmap <- heatmap +
      theme(strip.background = element_blank(), panel.spacing = panel.spacing)
    if (group.label.rot) {
      heatmap <- heatmap + theme(strip.text.x = element_text(angle = 90))
    }
  }
  if (remove.key) {
    heatmap <- heatmap + theme(legend.position = "none")
  }
  if (!is.null(x = title)) {
    heatmap <- heatmap + labs(title = title)
  }
  # Print explicitly: a bare expression inside `if` is never auto-printed, so
  # do.plot previously had no effect. Returning invisibly avoids a second draw.
  if (do.plot) {
    print(heatmap)
  }
  invisible(x = heatmap)
}

#' Single cell violin plot
#'
#' Draws a violin plot of single cell data (gene expression, metrics, PC
#' scores, etc.)
#'
#' @param object Seurat object
#' @param features.plot Features to plot (gene expression, metrics, PC scores,
#' anything that can be retrieved by FetchData)
#' @param ident.include Which classes to include in the plot (default is all)
#' @param nCol Number of columns if multiple plots are displayed
#' @param do.sort Sort identity classes (on the x-axis) by the average
#' expression of the attribute being plotted
#' @param y.max Maximum y axis value
#' @param same.y.lims Set all the y-axis limits to the same values
#' @param size.x.use X axis title font size
#' @param size.y.use Y axis title font size
#' @param size.title.use Main title font size
#' @param adjust.use Adjust parameter for geom_violin
#' @param point.size.use Point size for geom_violin
#' @param cols.use Colors to use for plotting
#' @param group.by Group (color) cells in different ways (for example, orig.ident)
#' @param y.log plot Y axis on log scale
#' @param x.lab.rot Rotate x-axis labels
#' @param y.lab.rot Rotate y-axis labels
#' @param legend.position Position the legend for the plot
#' @param single.legend Remove the legend from each individual plot and, when several
#' features are plotted, draw one shared legend (bottom or right) instead
#' @param remove.legend Remove the legend from the plot (with single.legend, suppresses
#' the shared legend as well)
#' @param do.return Return a ggplot2 object (default : FALSE)
#' @param return.plotlist Return the list of individual plots instead of compiled plot.
#' @param \dots additional parameters to pass to FetchData (for example, use.imputed, use.scaled, use.raw)
#'
#' @import ggplot2
#' @importFrom cowplot plot_grid get_legend
#'
#' @return By default, prints the plot and invisibly returns the combined plot. If
#' do.return=TRUE, returns the combined plot without printing it, or the list of
#' individual ggplot objects when return.plotlist=TRUE.
#'
#' @export
VlnPlot <- function(
  object,
  features.plot,
  ident.include = NULL,
  nCol = NULL,
  do.sort = FALSE,
  y.max = NULL,
  same.y.lims = FALSE,
  size.x.use = 16,
  size.y.use = 16,
  size.title.use = 20,
  adjust.use = 1,
  point.size.use = 1,
  cols.use = NULL,
  group.by = NULL,
  y.log = FALSE,
  x.lab.rot = FALSE,
  y.lab.rot = FALSE,
  legend.position = "right",
  single.legend = TRUE,
  remove.legend = FALSE,
  do.return = FALSE,
  return.plotlist = FALSE,
  ...
) {
  if (is.null(x = nCol)) {
    if (length(x = features.plot) > 9) {
      nCol <- 4
    } else {
      nCol <- min(length(x = features.plot), 3)
    }
  }
  data.use <- data.frame(
    FetchData(object = object, vars.all = features.plot, ...),
    check.names = FALSE
  )
  if (is.null(x = ident.include)) {
    cells.to.include <- object@cell.names
  } else {
    cells.to.include <- WhichCells(object = object, ident = ident.include)
  }
  data.use <- data.use[cells.to.include, , drop = FALSE]
  if (!is.null(x = group.by)) {
    ident.use <- as.factor(x = FetchData(
      object = object,
      vars.all = group.by
    )[cells.to.include, 1])
  } else {
    ident.use <- object@ident[cells.to.include]
  }
  gene.names <- colnames(x = data.use)[colnames(x = data.use) %in% rownames(x = object@data)]
  # With a shared legend every panel drops its own legend, while the user's
  # remove.legend still decides whether the shared one is drawn. Overwriting
  # remove.legend here made the shared-legend branch below unreachable.
  remove.panel.legend <- remove.legend || single.legend
  if (same.y.lims && is.null(x = y.max)) {
    y.max <- max(data.use)
  }
  plots <- lapply(
    X = features.plot,
    FUN = function(x) {
      return(SingleVlnPlot(
        feature = x,
        data = data.use[, x, drop = FALSE],
        cell.ident = ident.use,
        do.sort = do.sort, y.max = y.max,
        size.x.use = size.x.use,
        size.y.use = size.y.use,
        size.title.use = size.title.use,
        adjust.use = adjust.use,
        point.size.use = point.size.use,
        cols.use = cols.use,
        gene.names = gene.names,
        y.log = y.log,
        x.lab.rot = x.lab.rot,
        y.lab.rot = y.lab.rot,
        legend.position = legend.position,
        remove.legend = remove.panel.legend
      ))
    }
  )
  if (length(x = features.plot) > 1) {
    plots.combined <- plot_grid(plotlist = plots, ncol = nCol)
    draw.shared <- single.legend && !remove.legend &&
      !identical(x = legend.position, y = "none")
    if (draw.shared) {
      legend <- get_legend(
        plot = plots[[1]] + theme(legend.position = legend.position)
      )
      if (identical(x = legend.position, y = "bottom")) {
        plots.combined <- plot_grid(
          plots.combined,
          legend,
          ncol = 1,
          rel_heights = c(1, .2)
        )
      } else if (identical(x = legend.position, y = "right")) {
        plots.combined <- plot_grid(
          plots.combined,
          legend,
          rel_widths = c(3, .3)
        )
      } else {
        warning("Shared legends must be at the bottom or right of the plot")
      }
    }
  } else {
    plots.combined <- plots[[1]]
  }
  if (do.return) {
    if (return.plotlist) {
      return(plots)
    }
    return(plots.combined)
  }
  # Print explicitly so the plot also appears when called inside loops or other
  # functions; the object is still returned (invisibly) for callers that keep it.
  print(plots.combined)
  invisible(x = plots.combined)
}

#' Single cell joy plot
#'
#' Draws a joy plot of single cell data (gene expression, metrics, PC
#' scores, etc.)
#'
#' @param object Seurat object
#' @param features.plot Features to plot (gene expression, metrics, PC scores,
#' anything that can be retrieved by FetchData)
#' @param ident.include Which classes to include in the plot (default is all)
#' @param nCol Number of columns if multiple plots are displayed
#' @param do.sort Sort identity classes (on the x-axis) by the average
#' expression of the attribute being plotted
#' @param y.max Maximum y axis value
#' @param same.y.lims Set all the y-axis limits to the same values
#' @param size.x.use X axis title font size
#' @param size.y.use Y axis title font size
#' @param size.title.use Main title font size
#' @param cols.use Colors to use for plotting
#' @param group.by Group (color) cells in different ways (for example, orig.ident)
#' @param y.log plot Y axis on log scale
#' @param x.lab.rot Rotate x-axis labels
#' @param y.lab.rot Rotate y-axis labels
#' @param legend.position Position the legend for the plot
#' @param single.legend Remove the legend from each individual plot and, when several
#' features are plotted, draw one shared legend (bottom or right) instead
#' @param remove.legend Remove the legend from the plot (with single.legend, suppresses
#' the shared legend as well)
#' @param do.return Return a ggplot2 object (default : FALSE)
#' @param return.plotlist Return the list of individual plots instead of compiled plot.
#' @param \dots additional parameters to pass to FetchData (for example, use.imputed, use.scaled, use.raw)
#'
#' @import ggplot2
#' @importFrom cowplot plot_grid get_legend
#' @importFrom ggjoy geom_joy theme_joy
#'
#' @return By default, prints the plot and invisibly returns the combined plot. If
#' do.return=TRUE, returns the combined plot without printing it, or the list of
#' individual ggplot objects when return.plotlist=TRUE.
#'
#' @export
#'
JoyPlot <- function(
  object,
  features.plot,
  ident.include = NULL,
  nCol = NULL,
  do.sort = FALSE,
  y.max = NULL,
  same.y.lims = FALSE,
  size.x.use = 16,
  size.y.use = 16,
  size.title.use = 20,
  cols.use = NULL,
  group.by = NULL,
  y.log = FALSE,
  x.lab.rot = FALSE,
  y.lab.rot = FALSE,
  legend.position = "right",
  single.legend = TRUE,
  remove.legend = FALSE,
  do.return = FALSE,
  return.plotlist = FALSE,
  ...
) {
  if (is.null(x = nCol)) {
    if (length(x = features.plot) > 9) {
      nCol <- 4
    } else {
      nCol <- min(length(x = features.plot), 3)
    }
  }
  data.use <- data.frame(
    FetchData(object = object, vars.all = features.plot, ...),
    check.names = FALSE
  )
  if (is.null(x = ident.include)) {
    cells.to.include <- object@cell.names
  } else {
    cells.to.include <- WhichCells(object = object, ident = ident.include)
  }
  data.use <- data.use[cells.to.include, , drop = FALSE]
  if (!is.null(x = group.by)) {
    ident.use <- as.factor(x = FetchData(
      object = object,
      vars.all = group.by
    )[cells.to.include, 1])
  } else {
    ident.use <- object@ident[cells.to.include]
  }
  gene.names <- colnames(x = data.use)[colnames(x = data.use) %in% rownames(x = object@data)]
  # With a shared legend every panel drops its own legend, while the user's
  # remove.legend still decides whether the shared one is drawn. Overwriting
  # remove.legend here made the shared-legend branch below unreachable.
  remove.panel.legend <- remove.legend || single.legend
  if (same.y.lims && is.null(x = y.max)) {
    y.max <- max(data.use)
  }
  plots <- lapply(
    X = features.plot,
    FUN = function(x) {
      return(SingleJoyPlot(
        feature = x,
        data = data.use[, x, drop = FALSE],
        cell.ident = ident.use,
        do.sort = do.sort, y.max = y.max,
        size.x.use = size.x.use,
        size.y.use = size.y.use,
        size.title.use = size.title.use,
        cols.use = cols.use,
        gene.names = gene.names,
        y.log = y.log,
        x.lab.rot = x.lab.rot,
        y.lab.rot = y.lab.rot,
        legend.position = legend.position,
        remove.legend = remove.panel.legend
      ))
    }
  )
  if (length(x = features.plot) > 1) {
    plots.combined <- plot_grid(plotlist = plots, ncol = nCol)
    draw.shared <- single.legend && !remove.legend &&
      !identical(x = legend.position, y = "none")
    if (draw.shared) {
      legend <- get_legend(
        plot = plots[[1]] + theme(legend.position = legend.position)
      )
      if (identical(x = legend.position, y = "bottom")) {
        plots.combined <- plot_grid(
          plots.combined,
          legend,
          ncol = 1,
          rel_heights = c(1, .2)
        )
      } else if (identical(x = legend.position, y = "right")) {
        plots.combined <- plot_grid(
          plots.combined,
          legend,
          rel_widths = c(3, .3)
        )
      } else {
        warning("Shared legends must be at the bottom or right of the plot")
      }
    }
  } else {
    plots.combined <- plots[[1]]
  }
  if (do.return) {
    if (return.plotlist) {
      return(plots)
    }
    return(plots.combined)
  }
  # Print explicitly so the plot also appears when called inside loops or other
  # functions; the object is still returned (invisibly) for callers that keep it.
  print(plots.combined)
  invisible(x = plots.combined)
}

#' Old Dot plot visualization (pre-ggplot implementation)
#'
#' Intuitive way of visualizing how gene expression changes across different identity classes (clusters).
#' The size of the dot encodes the percentage of cells within a class, while the color encodes the
#' AverageExpression level of 'expressing' cells (green is high).
#'
#' @param object Seurat object
#' @param genes.plot Input vector of genes
#' @param cex.use Scaling factor for the dots (scales all dot sizes)
#' @param cols.use colors to plot (defaults to a red-to-green CustomPalette)
#' @param thresh.col Scaled average expression is clipped to [-thresh.col, thresh.col] and mapped
#' onto cols.use, so -thresh.col corresponds to the first color (red by default, lowest expression)
#' @param dot.min Constant added to every dot size, so classes with no detected expression still get
#' a dot of this size (default is 0.05)
#' @param group.by Factor to group the cells by
#'
#' @return Only graphical output
#'
#' @importFrom graphics axis
#'
#' @export
#'
DotPlotOld <- function(
  object,
  genes.plot,
  cex.use = 2,
  cols.use = NULL,
  thresh.col = 2.5,
  dot.min = 0.05,
  group.by = NULL
) {
  if (!is.null(x = group.by)) {
    object <- SetAllIdent(object = object, id = group.by)
  }
  # Replace the data slot with only the requested genes (genes as rows)
  object@data <- data.frame(t(x = FetchData(object = object, vars.all = genes.plot)))
  # Restore cell names in case data.frame() altered them (e.g. a '-' in the name)
  colnames(x = object@data) <- object@cell.names
  avg.exp <- AverageExpression(object = object)
  avg.alpha <- AverageDetectionRate(object = object)
  cols.use <- SetIfNull(x = cols.use, default = CustomPalette(low = "red", high = "green"))
  # Scale each gene (row) across identity classes, then clip to +/- thresh.col
  exp.scale <- t(x = scale(x = t(x = avg.exp)))
  exp.scale <- MinMax(data = exp.scale, max = thresh.col, min = (-1) * thresh.col)
  n.col <- length(x = cols.use)
  # One point per (gene, class) pair: x is the gene row, y is the class column
  n.genes <- nrow(x = avg.exp)
  n.classes <- ncol(x = avg.exp)
  data.x <- rep(x = seq_len(n.genes), each = n.classes)
  data.y <- rep(x = seq_len(n.classes), times = n.genes)
  point.index <- cbind(data.x, data.y)
  data.avg <- as.matrix(x = exp.scale)[point.index]
  col.index <- floor(x = n.col * (data.avg + thresh.col) / (2 * thresh.col) + .5)
  # Clamp to 1..n.col: values clipped to -thresh.col give index 0, and
  # cols.use[0] silently drops the entry, shifting every following point's color
  exp.col <- cols.use[pmin(pmax(col.index, 1), n.col)]
  data.cex <- as.matrix(x = avg.alpha)[point.index] * cex.use + dot.min
  plot(
    x = data.x,
    y = data.y,
    cex = data.cex,
    pch = 16,
    col = exp.col,
    xaxt = "n",
    xlab = "",
    ylab = "",
    yaxt = "n"
  )
  axis(side = 1, at = seq_along(genes.plot), labels = genes.plot)
  axis(
    side = 2,
    at = seq_len(ncol(x = avg.alpha)),
    labels = colnames(x = avg.alpha),
    las = 1
  )
}

#' Dot plot visualization
#'
#' Intuitive way of visualizing how gene expression changes across different
#' identity classes (clusters). The size of the dot encodes the percentage of
#' cells within a class, while the color encodes the AverageExpression level of
#' 'expressing' cells (blue is high).
#'
#' @param object Seurat object
#' @param genes.plot Input vector of genes
#' @param cols.use colors to plot
#' @param col.min Minimum scaled average expression threshold (everything smaller
#'  will be set to this)
#' @param col.max Maximum scaled average expression threshold (everything larger
#' will be set to this)
#' @param dot.min The fraction of cells at which to draw the smallest dot
#' (default is 0). All cell groups with less than this expressing the given
#' gene will have no dot drawn.
#' @param dot.scale Scale the size of the points, similar to cex
#' @param group.by Factor to group the cells by
#' @param plot.legend plots the legends
#' @param x.lab.rot Rotate x-axis labels
#' @param do.return Return ggplot2 object
#' @return default, no return, only graphical output. If do.return=TRUE, returns a ggplot2 object
#' @importFrom dplyr %>% group_by summarize summarize_each mutate ungroup
#' @importFrom tidyr gather
#' @export
DotPlot <- function(
  object,
  genes.plot,
  cols.use = c("lightgrey", "blue"),
  col.min = -2.5,
  col.max = 2.5,
  dot.min = 0,
  dot.scale = 6,
  group.by,
  plot.legend = FALSE,
  do.return = FALSE,
  x.lab.rot = FALSE
) {
  if (!missing(x = group.by)) {
    object <- SetAllIdent(object = object, id = group.by)
  }
  data.to.plot <- data.frame(FetchData(object = object, vars.all = genes.plot))
  # data.frame() sanitizes column names with make.names(unique = TRUE) (e.g.
  # '-' becomes '.'); derive the same keys so every gene gets a factor level
  gene.keys <- make.names(names = genes.plot, unique = TRUE)
  data.to.plot$cell <- rownames(x = data.to.plot)
  data.to.plot$id <- object@ident
  # Per (identity, gene): mean of expm1(expression) and percent of cells > 0,
  # then scale the averages within each gene and clip to [col.min, col.max]
  data.to.plot <- data.to.plot %>%
    gather(key = genes.plot, value = expression, -c(cell, id)) %>%
    group_by(id, genes.plot) %>%
    summarize(
      avg.exp = mean(expm1(x = expression)),
      pct.exp = PercentAbove(x = expression, threshold = 0)
    ) %>%
    ungroup() %>%
    group_by(genes.plot) %>%
    mutate(avg.exp.scale = scale(x = avg.exp)) %>%
    mutate(avg.exp.scale = MinMax(
      data = avg.exp.scale,
      max = col.max,
      min = col.min
    ))
  data.to.plot$genes.plot <- factor(
    x = data.to.plot$genes.plot,
    levels = rev(x = gene.keys)
  )
  # NA sizes are not drawn, so groups below dot.min get no dot
  data.to.plot$pct.exp[data.to.plot$pct.exp < dot.min] <- NA
  p <- ggplot(data = data.to.plot, mapping = aes(x = genes.plot, y = id)) +
    geom_point(mapping = aes(size = pct.exp, color = avg.exp.scale)) +
    scale_radius(range = c(0, dot.scale)) +
    scale_color_gradient(low = cols.use[1], high = cols.use[2]) +
    theme(axis.title.x = element_blank(), axis.title.y = element_blank())
  if (!plot.legend) {
    p <- p + theme(legend.position = "none")
  }
  if (x.lab.rot) {
    p <- p + theme(axis.text.x = element_text(angle = 90, vjust = 0.5))
  }
  suppressWarnings(print(p))
  if (do.return) {
    return(p)
  }
}


#' Split Dot plot visualization
#'
#' Intuitive way of visualizing how gene expression changes across different identity classes (clusters).
#' The size of the dot encodes the percentage of cells within a class, while the color encodes the
#' scaled average expression level: grey-to-blue for the first value of grouping.var and grey-to-red
#' for the second. Splits the cells into two groups based on a grouping variable.
#' Still in BETA
#'
#' @param object Seurat object
#' @param grouping.var Grouping variable for splitting the dataset (only its first two distinct
#' values are used)
#' @param genes.plot Input vector of genes
#' @param cols.use colors to plot. NOTE: currently unused; the fixed grey/blue and grey/red
#' palettes are always applied.
#' @param col.min Minimum scaled average expression threshold (everything smaller will be set to this)
#' @param col.max Maximum scaled average expression threshold (everything larger will be set to this)
#' @param dot.min The fraction of cells at which to draw the smallest dot (default is 0). Groups with
#' a smaller fraction expressing the gene have no dot drawn.
#' @param dot.scale Scale the size of the points, similar to cex
#' @param group.by Factor to group the cells by
#' @param plot.legend plots the legends
#' @param x.lab.rot Rotate x-axis labels
#' @param do.return Return ggplot2 object
#' @param gene.groups Add labeling bars to the top of the plot: one group label per gene in
#' genes.plot (same order), used to facet the genes
#' @return default, no return, only graphical output. If do.return=TRUE, returns a ggplot2 object
#' @importFrom dplyr %>% group_by summarize summarize_each mutate ungroup
#' @importFrom tidyr gather
#' @export
SplitDotPlotGG <- function(
  object,
  grouping.var,
  genes.plot,
  gene.groups,
  cols.use = c("green", "red"),
  col.min = -2.5,
  col.max = 2.5,
  dot.min = 0,
  dot.scale = 6,
  group.by,
  plot.legend = FALSE,
  do.return = FALSE,
  x.lab.rot = FALSE
) {
  if (!missing(x = group.by)) {
    object <- SetAllIdent(object = object, id = group.by)
  }
  grouping.data <- FetchData(
    object = object,
    vars.all = grouping.var
  )[names(x = object@ident), 1]
  idents.old <- levels(x = object@ident)
  # Only the first two distinct grouping values get their own identities/colors
  split.groups <- as.character(x = unique(x = grouping.data))[1:2]
  # Build the combined "<ident>_<group>" identity as a factor before storing it,
  # keeping cell names. NOTE(review): assigning the intermediate character vector
  # straight into object@ident is rejected if the slot is declared as a factor.
  ident.split <- paste(object@ident, grouping.data, sep = "_")
  names(x = ident.split) <- names(x = object@ident)
  object@ident <- factor(
    x = ident.split,
    levels = unlist(x = lapply(
      X = idents.old,
      FUN = function(ident) {
        return(paste(ident, split.groups, sep = "_"))
      }
    )),
    ordered = TRUE
  )
  # data.frame() sanitizes column names with make.names(unique = TRUE); use the
  # same keys for factor levels and for gene.groups lookups
  gene.keys <- make.names(names = genes.plot, unique = TRUE)
  data.to.plot <- data.frame(FetchData(object = object, vars.all = genes.plot))
  data.to.plot$cell <- rownames(x = data.to.plot)
  data.to.plot$id <- object@ident
  data.to.plot <- data.to.plot %>%
    gather(key = genes.plot, value = expression, -c(cell, id)) %>%
    group_by(id, genes.plot) %>%
    summarize(
      avg.exp = ExpMean(x = expression),
      pct.exp = PercentAbove(x = expression, threshold = 0)
    )
  # Rows belonging to each grouping value, used below to pick the palette
  ids.2 <- paste(idents.old, split.groups[2], sep = "_")
  vals.2 <- which(x = data.to.plot$id %in% ids.2)
  ids.1 <- paste(idents.old, split.groups[1], sep = "_")
  vals.1 <- which(x = data.to.plot$id %in% ids.1)
  # Scale averages within each gene, clip, then bin into 20 palette steps
  data.to.plot <- data.to.plot %>%
    ungroup() %>%
    group_by(genes.plot) %>%
    mutate(avg.exp = scale(avg.exp)) %>%
    mutate(avg.exp.scale = as.numeric(x = cut(
      x = MinMax(data = avg.exp, max = col.max, min = col.min),
      breaks = 20
    )))
  data.to.plot$genes.plot <- factor(
    x = data.to.plot$genes.plot,
    levels = rev(x = gene.keys)
  )
  data.to.plot$pct.exp[data.to.plot$pct.exp < dot.min] <- NA
  palette.1 <- CustomPalette(low = "grey", high = "blue", k = 20)
  palette.2 <- CustomPalette(low = "grey", high = "red", k = 20)
  data.to.plot$ptcolor <- "grey"
  data.to.plot[vals.1, "ptcolor"] <- palette.1[as.matrix(
    x = data.to.plot[vals.1, "avg.exp.scale"]
  )[, 1]]
  data.to.plot[vals.2, "ptcolor"] <- palette.2[as.matrix(
    x = data.to.plot[vals.2, "avg.exp.scale"]
  )[, 1]]
  if (!missing(x = gene.groups)) {
    # Look groups up by (sanitized) gene name. Indexing with the factor column
    # itself would use its integer codes, which follow the reversed level order.
    names(x = gene.groups) <- gene.keys
    data.to.plot$gene.groups <- unname(
      obj = gene.groups[as.character(x = data.to.plot$genes.plot)]
    )
  }
  p <- ggplot(data = data.to.plot, mapping = aes(x = genes.plot, y = id)) +
    geom_point(mapping = aes(size = pct.exp, color = ptcolor)) +
    scale_radius(range = c(0, dot.scale)) +
    scale_color_identity() +
    theme(axis.title.x = element_blank(), axis.title.y = element_blank())
  if (!missing(x = gene.groups)) {
    p <- p +
      facet_grid(
        facets = ~gene.groups,
        scales = "free_x",
        space = "free_x",
        switch = "y"
      ) +
      theme(
        panel.spacing = unit(x = 1, units = "lines"),
        strip.background = element_blank(),
        strip.placement = "outside"
      )
  }
  if (!plot.legend) {
    p <- p + theme(legend.position = "none")
  }
  if (x.lab.rot) {
    p <- p + theme(axis.text.x = element_text(angle = 90, vjust = 0.5))
  }
  suppressWarnings(print(p))
  if (do.return) {
    return(p)
  }
}

#' Visualize 'features' on a dimensional reduction plot
#'
#' Colors single cells on a dimensional reduction plot according to a 'feature'
#' (i.e. gene expression, PC scores, number of genes detected, etc.)
#'
#'
#' @param object Seurat object
#' @param features.plot Vector of features to plot
#' @param min.cutoff Vector of minimum cutoff values for each feature, may specify quantile in the form of 'q##' where '##' is the quantile (eg, 1, 10)
#' @param max.cutoff Vector of maximum cutoff values for each feature, may specify quantile in the form of 'q##' where '##' is the quantile (eg, 1, 10)
#' @param dim.1 Dimension for x-axis (default 1)
#' @param dim.2 Dimension for y-axis (default 2)
#' @param cells.use Vector of cells to plot (default is all cells)
#' @param pt.size Adjust point size for plotting
#' @param cols.use The two colors to form the gradient over. Provide as string vector with
#' the first color corresponding to low values, the second to high. Also accepts a Brewer
#' color scale or vector of colors. Note: this will bin the data into number of colors provided.
#' @param pch.use Pch for plotting
#' @param overlay Plot two features overlayed one on top of the other
#' @param do.hover Enable hovering over points to view information
#' @param data.hover Data to add to the hover, pass a character vector of features to add. Defaults to cell name and identity. Pass 'NULL' to remove extra data.
#' @param do.identify Opens a locator session to identify clusters of cells
#' @param reduction.use Which dimensionality reduction to use. Default is
#' "tsne", can also be "pca", or "ica", assuming these are precomputed.
#' @param use.imputed Use imputed values for gene expression (default is FALSE)
#' @param nCol Number of columns to use when plotting multiple features.
#' @param no.axes Remove axis labels
#' @param no.legend Remove legend from the graph. Default is TRUE.
#' @param dark.theme Plot in a dark theme
#' @param do.return return the ggplot2 object
#'
#' @importFrom RColorBrewer brewer.pal.info
#'
#' @return No return value, only a graphical output
#'
#' @export
#'
FeaturePlot <- function(
  object,
  features.plot,
  min.cutoff = NA,
  max.cutoff = NA,
  dim.1 = 1,
  dim.2 = 2,
  cells.use = NULL,
  pt.size = 1,
  cols.use = c("yellow", "red"),
  pch.use = 16,
  overlay = FALSE,
  do.hover = FALSE,
  data.hover = 'ident',
  do.identify = FALSE,
  reduction.use = "tsne",
  use.imputed = FALSE,
  nCol = NULL,
  no.axes = FALSE,
  no.legend = TRUE,
  dark.theme = FALSE,
  do.return = FALSE
) {
  cells.use <- SetIfNull(x = cells.use, default = colnames(x = object@data))
  num.features <- length(x = features.plot)
  #   Each cutoff must be a single value (applied to every feature) or one
  #   value per feature. This is checked before mapply() below, because
  #   mapply() silently recycles shorter vectors and would hide the mismatch.
  cutoff.lengths <- c(length(x = min.cutoff), length(x = max.cutoff))
  if (! all(cutoff.lengths %in% c(1, num.features))) {
    stop('There must be the same number of minimum and maximum cuttoffs as there are features')
  }
  #   Default grid width grows with the number of features
  if (is.null(x = nCol)) {
    nCol <- 2
    if (num.features == 1) {
      nCol <- 1
    }
    if (num.features > 6) {
      nCol <- 3
    }
    if (num.features > 9) {
      nCol <- 4
    }
  }
  num.row <- floor(x = num.features / nCol - 1e-5) + 1
  if (overlay | do.hover) {
    num.row <- 1
    nCol <- 1
  }
  par(mfrow = c(num.row, nCol))
  dim.code <- GetDimReduction(
    object = object,
    reduction.type = reduction.use,
    slot = 'key'
  )
  dim.codes <- paste0(dim.code, c(dim.1, dim.2))
  embeddings <- GetCellEmbeddings(
    object = object,
    reduction.type = reduction.use,
    dims.use = c(dim.1, dim.2),
    cells.use = cells.use
  )
  #   Build the plotting frame with explicit column names. The previous
  #   approach renamed a five-column frame with names() <- c('x', 'y'), which
  #   left its remaining three columns named NA.
  data.plot <- data.frame(
    x = embeddings[, dim.codes[1]],
    y = embeddings[, dim.codes[2]],
    pt.size = pt.size,
    row.names = rownames(x = embeddings)
  )
  #   Features as rows, cells as columns
  data.use <- t(x = FetchData(
    object = object,
    vars.all = features.plot,
    cells.use = cells.use,
    use.imputed = use.imputed
  ))
  #   Replace NA cutoffs with the observed min/max of each feature, ignoring
  #   missing values. Any other cutoff (numeric or a 'q##' string) is passed
  #   through unchanged.
  min.cutoff <- mapply(
    FUN = function(cutoff, feature) {
      if (is.na(x = cutoff)) min(data.use[feature, ], na.rm = TRUE) else cutoff
    },
    cutoff = min.cutoff,
    feature = features.plot
  )
  max.cutoff <- mapply(
    FUN = function(cutoff, feature) {
      if (is.na(x = cutoff)) max(data.use[feature, ], na.rm = TRUE) else cutoff
    },
    cutoff = max.cutoff,
    feature = features.plot
  )
  if (overlay) {
    #   Wrap in a list so the overlay plot is handled like the per-feature list
    pList <- list(
      BlendPlot(
        data.use = data.use,
        features.plot = features.plot,
        data.plot = data.plot,
        pt.size = pt.size,
        pch.use = pch.use,
        cols.use = cols.use,
        dim.codes = dim.codes,
        min.cutoff = min.cutoff,
        max.cutoff = max.cutoff,
        no.axes = no.axes,
        no.legend = no.legend,
        dark.theme = dark.theme
      )
    )
  } else {
    #   mapply iterates over feature and its per-feature cutoffs together
    pList <- mapply(
      FUN = SingleFeaturePlot,
      feature = features.plot,
      min.cutoff = min.cutoff,
      max.cutoff = max.cutoff,
      MoreArgs = list( # Arguments that are not being repeated
        data.use = data.use,
        data.plot = data.plot,
        pt.size = pt.size,
        pch.use = pch.use,
        cols.use = cols.use,
        dim.codes = dim.codes,
        no.axes = no.axes,
        no.legend = no.legend,
        dark.theme = dark.theme
      ),
      SIMPLIFY = FALSE # Get list, not matrix
    )
  }
  if (do.hover) {
    if (length(x = pList) != 1) {
      stop("'do.hover' only works on a single feature or an overlayed FeaturePlot")
    }
    if (is.null(x = data.hover)) {
      features.info <- NULL
    } else {
      features.info <- FetchData(object = object, vars.all = data.hover)
    }
    #   Use pList[[1]] to extract the ggplot out of the plot list
    return(HoverLocator(
      plot = pList[[1]],
      data.plot = data.plot,
      features.info = features.info,
      dark.theme = dark.theme,
      title = features.plot
    ))
  } else if (do.identify) {
    if (length(x = pList) != 1) {
      stop("'do.identify' only works on a single feature or an overlayed FeaturePlot")
    }
    #   Use pList[[1]] to extract the ggplot out of the plot list
    return(FeatureLocator(
      plot = pList[[1]],
      data.plot = data.plot,
      dark.theme = dark.theme
    ))
  } else {
    print(x = cowplot::plot_grid(plotlist = pList, ncol = nCol))
  }
  ResetPar()
  if (do.return) {
    return(pList)
  }
}

#' Visualization of multiple features
#'
#' Similar to FeaturePlot, however, also splits the plot by visualizing each
#' identity class separately.
#'
#' Particularly useful for seeing if the same groups of cells co-exhibit a
#' common feature (i.e. co-express a gene), even within an identity class. Best
#' understood by example.
#'
#' @param object Seurat object
#' @param features.plot Vector of features to plot
#' @param dim.1 Dimension for x-axis (default 1)
#' @param dim.2 Dimension for y-axis (default 2)
#' @param idents.use Which identity classes to display (default is all identity
#' classes). Scaling is still computed on all cells; only the display is
#' restricted.
#' @param pt.size Adjust point size for plotting
#' @param cols.use Two colors (low, high) for the expression gradient. Default
#' is c("grey", "red").
#' @param pch.use Pch for plotting
#' @param reduction.use Which dimensionality reduction to use. Default is
#' "tsne", can also be "pca", or "ica", assuming these are precomputed.
#' @param group.by Group cells in different ways (for example, orig.ident)
#' @param sep.scale Scale each group separately. Default is FALSE.
#' @param max.exp Max cutoff for scaled expression value, supports quantiles in the form of 'q##' (see FeaturePlot)
#' @param min.exp Min cutoff for scaled expression value, supports quantiles in the form of 'q##' (see FeaturePlot)
#' @param rotate.key rotate the legend
#' @param plot.horiz rotate the plot such that the features are columns, groups are the rows
#' @param key.position position of the legend ("top", "right", "bottom", "left")
#' @param do.return Return the ggplot2 object instead of printing it
#'
#' @return If do.return is TRUE, the ggplot2 object. Otherwise the plot is
#' printed.
#'
#' @importFrom dplyr %>% mutate_each group_by select ungroup
#'
#' @seealso \code{FeaturePlot}
#'
#' @export
#'
FeatureHeatmap <- function(
  object,
  features.plot,
  dim.1 = 1,
  dim.2 = 2,
  idents.use = NULL,
  pt.size = 2,
  cols.use = c("grey", "red"),
  pch.use = 16,
  reduction.use = "tsne",
  group.by = NULL,
  sep.scale = FALSE,
  do.return = FALSE,
  min.exp = -Inf,
  max.exp = Inf,
  rotate.key = FALSE,
  plot.horiz = FALSE,
  key.position = "right"
) {
  if (! is.null(x = group.by)) {
    object <- SetAllIdent(object = object, id = group.by)
  }
  idents.use <- SetIfNull(x = idents.use, default = sort(x = unique(x = object@ident)))
  dim.code <- GetDimReduction(
    object = object,
    reduction.type = reduction.use,
    slot = 'key'
  )
  dim.codes <- paste0(dim.code, c(dim.1, dim.2))
  #   check.names = FALSE keeps feature names (e.g. 'HLA-DRB1', '7SK') intact,
  #   so they still match features.plot when reshaping below
  data.plot <- data.frame(
    FetchData(object = object, vars.all = c(dim.codes, features.plot)),
    check.names = FALSE
  )
  colnames(x = data.plot)[1:2] <- c("dim1", "dim2")
  data.plot$ident <- as.character(x = object@ident)
  data.plot$cell <- rownames(x = data.plot)
  #   Long format: one row per (cell, gene)
  data.plot <- data.plot %>%
    gather(gene, expression, features.plot, -dim1, -dim2, -ident, -cell)
  #   Z-score each gene, optionally within each identity class as well.
  #   as.vector() flattens the one-column matrix returned by scale(), and
  #   ungroup() drops the grouping before the column edits that follow.
  if (sep.scale) {
    data.plot <- data.plot %>% group_by(ident, gene)
  } else {
    data.plot <- data.plot %>% group_by(gene)
  }
  data.plot <- data.plot %>%
    mutate(scaled.expression = as.vector(x = scale(x = expression))) %>%
    ungroup()
  min.exp <- SetQuantile(cutoff = min.exp, data = data.plot$scaled.expression)
  max.exp <- SetQuantile(cutoff = max.exp, data = data.plot$scaled.expression)
  data.plot$gene <- factor(x = data.plot$gene, levels = features.plot)
  data.plot$scaled.expression <- MinMax(
    data = data.plot$scaled.expression,
    min = min.exp,
    max = max.exp
  )
  #   Only display the requested identity classes. This filter runs after
  #   scaling, so displayed values match those computed over all cells.
  data.plot <- data.plot[data.plot$ident %in% as.character(x = idents.use), ]
  if (rotate.key) {
    key.guide <- guide_colorbar(
      direction = "horizontal",
      title.position = "top",
      title = "Scaled Expression"
    )
  } else {
    key.guide <- guide_colorbar(title = "Scaled Expression")
  }
  p <- ggplot(data = data.plot, mapping = aes(x = dim1, y = dim2)) +
    geom_point(mapping = aes(colour = scaled.expression), size = pt.size) +
    scale_colour_gradient(
      low = cols.use[1],
      high = cols.use[2],
      guide = key.guide
    )
  if (plot.horiz) {
    p <- p + facet_grid(ident ~ gene)
  } else {
    p <- p + facet_grid(gene ~ ident)
  }
  p <- p +
    theme_bw() +
    NoGrid() +
    ylab(label = dim.codes[2]) +
    xlab(label = dim.codes[1]) +
    theme(legend.position = key.position)
  if (do.return) {
    return(p)
  }
  print(p)
}

#' Gene expression heatmap
#'
#' Draws a heatmap of single cell gene expression using the heatmap.2 function. Has been replaced by the ggplot2
#' version (now in DoHeatmap), but kept for legacy
#'
#' @param object Seurat object
#' @param cells.use Cells to include in the heatmap (default is all cells)
#' @param genes.use Genes to include in the heatmap (ordered)
#' @param disp.min Minimum display value (all values below are clipped). If
#' NULL, defaults to -2.5 when do.scale is TRUE and -Inf otherwise.
#' @param disp.max Maximum display value (all values above are clipped). If
#' NULL, defaults to 2.5 when do.scale is TRUE and Inf otherwise.
#' @param draw.line Draw vertical lines delineating cells in different identity
#' classes.
#' @param do.return Default is FALSE. If TRUE, return a matrix of scaled values
#' which would be passed to heatmap.2
#' @param order.by.ident Order cells in the heatmap by identity class (default
#' is TRUE). If FALSE, cells are ordered based on their order in cells.use
#' @param col.use Color palette to use
#' @param slim.col.label if (order.by.ident==TRUE) then instead of displaying
#' every cell name on the heatmap, display only the identity class name once
#' for each group
#' @param group.by If (order.by.ident==TRUE) default,  you can group cells in
#' different ways (for example, orig.ident)
#' @param remove.key Removes the color key from the plot.
#' @param cex.col Positive number passed as cexCol (column label size) when
#' slim.col.label is TRUE. By default derived from the number of identity
#' classes.
#' @param do.scale whether to use the data or scaled data
#' @param ... Additional parameters to heatmap.2. Common examples are cexRow
#' and cexCol, which set row and column text sizes
#'
#' @return If do.return==TRUE, a matrix of scaled values which would be passed
#' to heatmap.2. Otherwise, no return value, only a graphical output
#'
#' @importFrom gplots heatmap.2
#'
#' @export
#'
OldDoHeatmap <- function(
  object,
  cells.use = NULL,
  genes.use = NULL,
  disp.min = NULL,
  disp.max = NULL,
  draw.line = TRUE,
  do.return = FALSE,
  order.by.ident = TRUE,
  col.use = PurpleAndYellow(),
  slim.col.label = FALSE,
  group.by = NULL,
  remove.key = FALSE,
  cex.col = NULL,
  do.scale = TRUE,
  ...
) {
  cells.use <- SetIfNull(x = cells.use, default = object@cell.names)
  cells.use <- intersect(x = cells.use, y = object@cell.names)
  if (is.null(x = group.by)) {
    cells.ident <- object@ident[cells.use]
  } else {
    #   Fetch the grouping for cells.use only, so it lines up with cells.use
    cells.ident <- factor(x = FetchData(
      object = object,
      vars.all = group.by,
      cells.use = cells.use
    )[, 1])
  }
  #   Drop identity classes that have no cells among cells.use
  cells.ident <- droplevels(x = cells.ident)
  if (order.by.ident) {
    cells.use <- cells.use[order(cells.ident)]
  } else {
    #   Levels in order of first appearance
    cells.ident <- factor(
      x = cells.ident,
      levels = as.vector(x = unique(x = cells.ident))
    )
  }
  #   Choose the data slot and fill in whichever display limit was left unset
  if (do.scale) {
    slot.use <- "scale.data"
    disp.min <- SetIfNull(x = disp.min, default = -2.5)
    disp.max <- SetIfNull(x = disp.max, default = 2.5)
  } else {
    slot.use <- "data"
    disp.min <- SetIfNull(x = disp.min, default = -Inf)
    disp.max <- SetIfNull(x = disp.max, default = Inf)
  }
  #   Stack the requested genes from the RNA assay and any additional assays
  data.use <- NULL
  for (assay.check in c("RNA", names(x = object@assay))) {
    data.assay <- GetAssayData(
      object = object,
      assay.type = assay.check,
      slot = slot.use
    )
    genes.intersect <- intersect(x = genes.use, y = rownames(x = data.assay))
    new.data <- data.assay[genes.intersect, cells.use, drop = FALSE]
    if (! (is.matrix(x = new.data))) {
      new.data <- as.matrix(x = new.data)
    }
    data.use <- rbind(data.use, new.data)
  }
  data.use <- MinMax(data = data.use, min = disp.min, max = disp.max)
  if (remove.key) {
    hmFunction <- heatmap2NoKey
  } else {
    hmFunction <- heatmap.2
  }
  colsep.use <- NULL
  if (draw.line) {
    colsep.use <- cumsum(x = table(cells.ident))
  }
  if (slim.col.label) {
    col.lab <- rep("", length(x = cells.use))
    if (order.by.ident) {
      #   Put each class name once, near the middle of its block of columns
      group.sizes <- table(cells.ident)
      col.lab[round(x = cumsum(x = group.sizes) - group.sizes / 2) + 1] <- levels(x = cells.ident)
    }
    #   Guard against log10(1) == 0 (an infinite size) when there is one group
    num.groups <- length(x = unique(x = cells.ident))
    cex.col <- SetIfNull(
      x = cex.col,
      default = 0.2 + 1 / log10(x = max(num.groups, 2))
    )
    hmFunction(
      data.use,
      Rowv = NA,
      Colv = NA,
      trace = "none",
      col = col.use,
      colsep = colsep.use,
      labCol = col.lab,
      cexCol = cex.col,
      ...
    )
  } else {
    hmFunction(
      data.use,
      Rowv = NA,
      Colv = NA,
      trace = "none",
      col = col.use,
      colsep = colsep.use,
      ...
    )
  }
  if (do.return) {
    return(data.use)
  }
}

#' JackStraw Plot
#'
#' Plots the results of the JackStraw analysis for PCA significance. For each
#' PC, plots a QQ-plot comparing the distribution of p-values for all genes
#' across each PC, compared with a uniform distribution. Also determines a
#' p-value for the overall significance of each PC (see Details).
#'
#' Significant PCs should show a p-value distribution (black curve) that is
#' strongly skewed to the left compared to the null distribution (dashed line)
#' The p-value for each PC is based on a proportion test comparing the number
#' of genes with a p-value below a particular threshold (score.thresh), compared with the
#' proportion of genes expected under a uniform distribution of p-values.
#'
#' @param object Seurat object
#' @param PCs Which PCs to examine
#' @param nCol Number of columns
#' @param score.thresh Threshold to use for the proportion test of PC
#' significance (see Details)
#' @param plot.x.lim X-axis maximum on each QQ plot.
#' @param plot.y.lim Y-axis maximum on each QQ plot.
#'
#' @return A ggplot object
#'
#' @author Thanks to Omri Wurtzel for integrating with ggplot
#'
#' @import gridExtra
#' @importFrom stats qqplot runif prop.test qunif
#'
#' @export
#'
JackStrawPlot <- function(
  object,
  PCs = 1:5,
  nCol = 3,
  score.thresh = 1e-5,
  plot.x.lim = 0.1,
  plot.y.lim = 0.3
) {
  pAll <- GetDimReduction(
    object = object,
    reduction.type = "pca",
    slot = "jackstraw"
  )@emperical.p.value
  pAll <- as.data.frame(x = pAll[, PCs, drop = FALSE])
  num.genes <- nrow(x = pAll)
  #   Proportion test per PC: the number of genes with p <= score.thresh versus
  #   the number expected under a uniform null. After subsetting, column j of
  #   pAll holds PC PCs[j], so columns are indexed by position, not PC number.
  pc.scores <- vapply(
    X = seq_along(along.with = PCs),
    FUN = function(j) {
      num.sig <- sum(pAll[, j] <= score.thresh, na.rm = TRUE)
      if (num.sig == 0) {
        return(1)
      }
      suppressWarnings(prop.test(
        x = c(num.sig, floor(x = num.genes * score.thresh)),
        n = c(num.genes, num.genes)
      )$p.value)
    },
    FUN.VALUE = numeric(length = 1)
  )
  score.labels <- paste0("PC", PCs, " ", sprintf("%1.3g", pc.scores))
  pAll$Contig <- rownames(x = pAll)
  pAll.l <- melt(data = pAll, id.vars = "Contig")
  colnames(x = pAll.l) <- c("Contig", "PC", "Value")
  #   melt() stacks the PC columns in order, one block of num.genes rows per
  #   PC; label each block with its PC name and score for faceting
  pAll.l$PC.Score <- factor(
    x = rep(x = score.labels, each = num.genes),
    levels = score.labels
  )
  gp <- ggplot(data = pAll.l, mapping = aes(sample = Value)) +
    stat_qq(distribution = qunif) +
    facet_wrap("PC.Score", ncol = nCol) +
    labs(x = "Theoretical [runif(1000)]", y = "Empirical") +
    #   Limits are swapped on purpose: coord_flip() exchanges the axes
    xlim(0, plot.y.lim) +
    ylim(0, plot.x.lim) +
    coord_flip() +
    geom_abline(intercept = 0, slope = 1, linetype = "dashed", na.rm = TRUE) +
    theme_bw()
  return(gp)
}

#' Scatter plot of single cell data
#'
#' Creates a scatter plot of two features (typically gene expression), across a
#' set of single cells. Cells are colored by their identity class.
#'
#' @param object Seurat object
#' @inheritParams FetchData
#' @param gene1 First feature to plot. Typically gene expression but can also
#' be metrics, PC scores, etc. - anything that can be retrieved with FetchData
#' @param gene2 Second feature to plot.
#' @param cell.ids Cells to include on the scatter plot.
#' @param col.use Colors to use for identity class plotting. A vector longer
#' than one is indexed by identity class; by default the class number is used.
#' @param pch.use Pch argument for plotting
#' @param cex.use Cex argument for plotting
#' @param use.imputed Use imputed values for gene expression (Default is FALSE)
#' @param use.scaled Use scaled data
#' @param use.raw Use raw data
#' @param do.hover Enable hovering over points to view information
#' @param data.hover Data to add to the hover, pass a character vector of features to add. Defaults to cell name and ident. Pass 'NULL' to clear extra information.
#' @param do.identify Opens a locator session to identify clusters of cells.
#' @param dark.theme Use a dark theme for the plot
#' @param do.spline Overlay loess-fitted values (dark blue points)
#' @param spline.span span passed to the loess fit used by do.spline
#' @param \dots Additional arguments to be passed to plot.
#'
#' @return If do.hover is TRUE, the result of HoverLocator; if do.identify is
#' TRUE, the result of FeatureLocator. Otherwise no return value, only
#' graphical output.
#'
#' @export
#'
GenePlot <- function(
  object,
  gene1,
  gene2,
  cell.ids = NULL,
  col.use = NULL,
  pch.use = 16,
  cex.use = 1.5,
  use.imputed = FALSE,
  use.scaled = FALSE,
  use.raw = FALSE,
  do.hover = FALSE,
  data.hover = 'ident',
  do.identify = FALSE,
  dark.theme = FALSE,
  do.spline = FALSE,
  spline.span = 0.75,
  ...
) {
  cell.ids <- SetIfNull(x = cell.ids, default = object@cell.names)
  #   Keep cells as rows (no transpose) for compatibility with FeatureLocator
  data.use <- as.data.frame(
    x = FetchData(
      object = object,
      vars.all = c(gene1, gene2),
      cells.use = cell.ids,
      use.imputed = use.imputed,
      use.scaled = use.scaled,
      use.raw = use.raw
    )
  )
  #   Restrict to the requested cells and the two features, in that order
  data.plot <- data.use[cell.ids, c(gene1, gene2)]
  #   Generic names keep the plotting code independent of the feature names
  names(x = data.plot) <- c('x', 'y')
  ident.use <- as.factor(x = object@ident[cell.ids])
  if (length(x = col.use) > 1) {
    #   One colour per identity class, looked up by each cell's class
    col.use <- col.use[as.numeric(x = ident.use)]
  } else {
    col.use <- SetIfNull(x = col.use, default = as.numeric(x = ident.use))
  }
  gene.cor <- round(x = cor(x = data.plot$x, y = data.plot$y), digits = 2)
  if (dark.theme) {
    #   Paint a black background and restore the caller's setting on exit
    old.par <- par(bg = 'black')
    on.exit(expr = par(old.par), add = TRUE)
    #   Pure black points would be invisible on the black background
    col.use <- sapply(
      X = col.use,
      FUN = function(color) ifelse(
        test = all(col2rgb(color) == 0),
        yes = 'white',
        no = color
      )
    )
    axes <- FALSE
    col.lab <- 'white'
  } else {
    axes <- TRUE
    col.lab <- 'black'
  }
  plot(
    x = data.plot$x,
    y = data.plot$y,
    xlab = gene1,
    ylab = gene2,
    col = col.use,
    cex = cex.use,
    main = gene.cor,
    pch = pch.use,
    axes = axes,
    col.lab = col.lab,
    col.main = col.lab,
    ...
  )
  if (dark.theme) {
    #   Draw the axes in the light label colour
    axis(
      side = 1,
      at = NULL,
      labels = TRUE,
      col.axis = col.lab,
      col = col.lab
    )
    axis(
      side = 2,
      at = NULL,
      labels = TRUE,
      col.axis = col.lab,
      col = col.lab
    )
  }
  if (do.spline) {
    #   Overlay loess-fitted values. The former smooth.spline() fit was never
    #   drawn and errored when x had fewer than four unique values.
    loess.fit <- loess(formula = y ~ x, data = data.plot, span = spline.span)
    points(x = data.plot$x, y = loess.fit$fitted, col = 'darkblue')
  }
  if (do.identify | do.hover) {
    #   Rebuild the plot in ggplot2 for the interactive helpers
    p <- ggplot2::ggplot(data = data.plot, mapping = aes(x = x, y = y)) +
      geom_point(size = cex.use, shape = pch.use, color = col.use) +
      labs(title = gene.cor, x = gene1, y = gene2)
    if (do.hover) {
      names(x = data.plot) <- c(gene1, gene2)
      if (is.null(x = data.hover)) {
        features.info <- NULL
      } else {
        features.info <- FetchData(object = object, vars.all = data.hover)
      }
      return(HoverLocator(
        plot = p,
        data.plot = data.plot,
        features.info = features.info,
        dark.theme = dark.theme,
        title = gene.cor
      ))
    } else if (do.identify) {
      return(FeatureLocator(
        plot = p,
        data.plot = data.plot,
        dark.theme = dark.theme
      ))
    }
  }
}

#' Cell-cell scatter plot
#'
#' Creates a scatter plot of genes across two single cells
#'
#' @param object Seurat object
#' @param cell1 Cell 1 name (can also be a number, representing the position in
#' object@@cell.names)
#' @param cell2 Cell 2 name (can also be a number, representing the position in
#' object@@cell.names)
#' @param gene.ids Genes to plot (default, all genes)
#' @param col.use Colors to use for the points
#' @param nrpoints.use Parameter for smoothScatter
#' @param pch.use Point symbol to use
#' @param cex.use Point size
#' @param do.hover Enable hovering over points to view information
#' @param do.identify Opens a locator session to identify clusters of
#' points to reveal gene names (hit ESC to stop)
#' @param \dots Additional arguments to pass to smoothScatter
#'
#' @return If do.hover is TRUE, the result of HoverLocator; if do.identify is
#' TRUE, the result of FeatureLocator. Otherwise no return value (plots a
#' scatter plot).
#'
#' @importFrom stats cor
#' @importFrom graphics smoothScatter
#'
#' @export
#'
CellPlot <- function(
  object,
  cell1,
  cell2,
  gene.ids = NULL,
  col.use = "black",
  nrpoints.use = Inf,
  pch.use = 16,
  cex.use = 0.5,
  do.hover = FALSE,
  do.identify = FALSE,
  ...
) {
  gene.ids <- SetIfNull(x = gene.ids, default = rownames(x = object@data))
  #   Transpose so that genes are rows (one point per gene), which lets
  #   do.identify report gene names
  data.plot <- as.data.frame(
    x = t(
      x = FetchData(
        object = object,
        vars.all = gene.ids,
        cells.use = c(cell1, cell2)
      )
    )
  )
  #   Set names for easy calling with ggplot
  names(x = data.plot) <- c('x', 'y')
  gene.cor <- round(x = cor(x = data.plot$x, y = data.plot$y), digits = 2)
  #   Extra arguments go to smoothScatter, as documented
  smoothScatter(
    x = data.plot$x,
    y = data.plot$y,
    xlab = cell1,
    ylab = cell2,
    col = col.use,
    nrpoints = nrpoints.use,
    pch = pch.use,
    cex = cex.use,
    main = gene.cor,
    ...
  )
  if (do.identify | do.hover) {
    #   Rebuild the plot in ggplot2 for the interactive helpers
    p <- ggplot2::ggplot(data = data.plot, mapping = aes(x = x, y = y)) +
      geom_point(size = cex.use, shape = pch.use, color = col.use) +
      labs(title = gene.cor, x = cell1, y = cell2)
    if (do.hover) {
      names(x = data.plot) <- c(cell1, cell2)
      return(HoverLocator(plot = p, data.plot = data.plot, title = gene.cor))
    } else if (do.identify) {
      return(FeatureLocator(plot = p, data.plot = data.plot))
    }
  }
}

#' Dimensional reduction heatmap
#'
#' Draws a heatmap focusing on a dimensional reduction component. Both cells
#' and genes are sorted by their scores for that component. Allows for nice
#' visualization of sources of heterogeneity in the dataset.
#'
#' @param object Seurat object.
#' @param reduction.type Which dimensional reduction to use
#' @param dim.use Dimensions to plot
#' @param cells.use A list of cells to plot. If numeric, just plots the top cells.
#' @param num.genes Number of genes to plot
#' @param use.full Use the full PCA (projected PCA). Default is FALSE
#' @param disp.min Minimum display value (all values below are clipped)
#' @param disp.max Maximum display value (all values above are clipped)
#' @param do.return If TRUE, returns the clipped data matrix drawn in the last
#' heatmap
#' @param col.use Color to plot.
#' @param use.scale Default is TRUE: plot scaled data. If FALSE, plot raw data on the heatmap.
#' @param do.balanced Plot an equal number of genes with both + and - scores.
#' @param remove.key Removes the color key from the plot. The key is always
#' removed when more than one dimension is plotted.
#' @param label.columns Whether to label the columns (cells). Default is TRUE
#' for a single dimension, FALSE otherwise.
#' @param ... Extra parameters for heatmap plotting.
#'
#' @return If do.return==TRUE, the matrix of clipped values passed to the
#' heatmap function for the last dimension. Otherwise, no return value, only a
#' graphical output
#'
#' @importFrom graphics par
#'
#' @export
#'
DimHeatmap <- function(
  object,
  reduction.type = "pca",
  dim.use = 1,
  cells.use = NULL,
  num.genes = 30,
  use.full = FALSE,
  disp.min = -2.5,
  disp.max = 2.5,
  do.return = FALSE,
  col.use = PurpleAndYellow(),
  use.scale = TRUE,
  do.balanced = FALSE,
  remove.key = FALSE,
  label.columns = NULL,
  ...
) {
  #   Lay out up to three heatmaps per row; restore the caller's layout on
  #   exit, including when do.return returns early
  num.row <- floor(x = length(x = dim.use) / 3.01) + 1
  orig_par <- par()$mfrow
  par(mfrow = c(num.row, min(length(x = dim.use), 3)))
  on.exit(expr = par(mfrow = orig_par), add = TRUE)
  cells <- cells.use
  if (is.null(x = label.columns)) {
    label.columns <- ! (length(x = dim.use) > 1)
  }
  #   Loop-invariant lookups
  dim.scores <- GetDimReduction(
    object = object,
    reduction.type = reduction.type,
    slot = "cell.embeddings"
  )
  dim.key <- GetDimReduction(
    object = object,
    reduction.type = reduction.type,
    slot = "key"
  )
  slot.use <- if (use.scale) "scale.data" else "data"
  assays.use <- c("RNA", names(x = object@assay))
  if (remove.key || length(x = dim.use) > 1) {
    hmFunction <- heatmap2NoKey
  } else {
    hmFunction <- heatmap.2
  }
  for (ndim in dim.use) {
    if (is.numeric(x = cells)) {
      cells.use <- DimTopCells(
        object = object,
        dim.use = ndim,
        reduction.type = reduction.type,
        num.cells = cells,
        do.balanced = do.balanced
      )
    } else {
      cells.use <- SetIfNull(x = cells, default = object@cell.names)
    }
    genes.use <- rev(x = DimTopGenes(
      object = object,
      dim.use = ndim,
      reduction.type = reduction.type,
      num.genes = num.genes,
      use.full = use.full,
      do.balanced = do.balanced
    ))
    #   Order cells by their score on this dimension
    cells.ordered <- cells.use[order(dim.scores[cells.use, paste0(dim.key, ndim)])]
    #   Stack the selected genes from the RNA assay and any additional assays
    data.use <- NULL
    for (assay.check in assays.use) {
      data.assay <- GetAssayData(
        object = object,
        assay.type = assay.check,
        slot = slot.use
      )
      genes.intersect <- intersect(x = genes.use, y = rownames(x = data.assay))
      #   drop = FALSE keeps a genes x cells matrix when only one gene matches
      new.data <- data.assay[genes.intersect, cells.ordered, drop = FALSE]
      if (! is.matrix(x = new.data)) {
        new.data <- as.matrix(x = new.data)
      }
      data.use <- rbind(data.use, new.data)
    }
    data.use <- MinMax(data = data.use, min = disp.min, max = disp.max)
    hmTitle <- paste(dim.key, ndim)
    if (label.columns) {
      hmFunction(
        data.use,
        Rowv = NA,
        Colv = NA,
        trace = "none",
        col = col.use,
        dimTitle = hmTitle,
        ...
      )
    } else {
      hmFunction(
        data.use,
        Rowv = NA,
        Colv = NA,
        trace = "none",
        col = col.use,
        dimTitle = hmTitle,
        labCol = '',
        ...
      )
    }
  }
  if (do.return) {
    return(data.use)
  }
}

#' Principal component heatmap
#'
#' Convenience wrapper around \code{DimHeatmap} with the reduction fixed to
#' "pca". Cells and genes are ordered by their principal component scores,
#' which gives a quick view of the sources of heterogeneity in the dataset.
#'
#' @param object Seurat object.
#' @param pc.use PCs to plot
#' @param cells.use A list of cells to plot. If numeric, just plots the top cells.
#' @param num.genes Number of genes to plot
#' @param use.full Use the full PCA (projected PCA). Default is FALSE
#' @param disp.min Minimum display value (all values below are clipped)
#' @param disp.max Maximum display value (all values above are clipped)
#' @param do.return If TRUE, return the data matrix used for the last heatmap
#' @param col.use Color to plot.
#' @param use.scale Default is TRUE: plot scaled data. If FALSE, plot raw data on the heatmap.
#' @param do.balanced Plot an equal number of genes with both + and - scores.
#' @param remove.key Removes the color key from the plot.
#' @param label.columns Whether to label the columns. Default is TRUE for 1 PC, FALSE for > 1 PC
#' @param ... Extra parameters for DimHeatmap
#'
#' @return If do.return==TRUE, a matrix of scaled values which would be passed
#' to heatmap.2. Otherwise, no return value, only a graphical output
#'
#' @export
#'
PCHeatmap <- function(
  object,
  pc.use = 1,
  cells.use = NULL,
  num.genes = 30,
  use.full = FALSE,
  disp.min = -2.5,
  disp.max = 2.5,
  do.return = FALSE,
  col.use = PurpleAndYellow(),
  use.scale = TRUE,
  do.balanced = FALSE,
  remove.key = FALSE,
  label.columns = NULL,
  ...
) {
  # Delegate everything to DimHeatmap; only the reduction type is fixed here
  DimHeatmap(
    object = object,
    dim.use = pc.use,
    reduction.type = "pca",
    cells.use = cells.use,
    num.genes = num.genes,
    use.full = use.full,
    use.scale = use.scale,
    disp.min = disp.min,
    disp.max = disp.max,
    col.use = col.use,
    do.balanced = do.balanced,
    remove.key = remove.key,
    label.columns = label.columns,
    do.return = do.return,
    ...
  )
}

#' Independent component heatmap
#'
#' Convenience wrapper around \code{DimHeatmap} with the reduction fixed to
#' "ica". Cells and genes are ordered by their independent component scores,
#' which gives a quick view of the sources of heterogeneity in the dataset.
#'
#' @param object Seurat object
#' @param ic.use Components to use
#' @param cells.use A list of cells to plot. If numeric, just plots the top cells.
#' @param num.genes Number of genes to plot
#' @param disp.min Minimum display value (all values below are clipped)
#' @param disp.max Maximum display value (all values above are clipped)
#' @param do.return If TRUE, return the data matrix used for the last heatmap
#' @param col.use Colors to plot.
#' @param use.scale Default is TRUE: plot scaled data. If FALSE, plot raw data on the heatmap.
#' @param do.balanced Plot an equal number of genes with both + and - scores.
#' @param remove.key Removes the color key from the plot.
#' @param label.columns Whether to label the columns. Default is TRUE for 1 IC, FALSE for > 1 IC
#' @param ... Extra parameters passed to DimHeatmap
#'
#' @return If do.return==TRUE, a matrix of scaled values which would be passed
#' to heatmap.2. Otherwise, no return value, only a graphical output
#'
#' @export
#'
ICHeatmap <- function(
  object,
  ic.use = 1,
  cells.use = NULL,
  num.genes = 30,
  disp.min = -2.5,
  disp.max = 2.5,
  do.return = FALSE,
  col.use = PurpleAndYellow(),
  use.scale = TRUE,
  do.balanced = FALSE,
  remove.key = FALSE,
  label.columns = NULL,
  ...
) {
  # Delegate everything to DimHeatmap; only the reduction type is fixed here
  DimHeatmap(
    object = object,
    dim.use = ic.use,
    reduction.type = "ica",
    cells.use = cells.use,
    num.genes = num.genes,
    use.scale = use.scale,
    disp.min = disp.min,
    disp.max = disp.max,
    col.use = col.use,
    do.balanced = do.balanced,
    remove.key = remove.key,
    label.columns = label.columns,
    do.return = do.return,
    ...
  )
}


#' Visualize Dimensional Reduction genes
#'
#' Visualize top genes associated with reduction components
#'
#' @param object Seurat object
#' @param reduction.type Reduction technique to visualize results for
#' @param dims.use Dimensions to display
#' @param num.genes Number of genes to display
#' @param use.full Use reduction values for full dataset (i.e. projected dimensional reduction values)
#' @param font.size Font size
#' @param nCol Number of columns to display. By default 2, or 3 when more than
#' 6 dimensions are shown, or 4 when more than 9 are shown.
#' @param do.balanced Return an equal number of genes with + and - scores. If FALSE (default), returns
#' the top genes ranked by the scores absolute values
#'
#' @return Graphical, no return value
#'
#' @importFrom graphics axis
#'
#' @export
#'
VizDimReduction <- function(
  object,
  reduction.type = "pca",
  dims.use = 1:5,
  num.genes = 30,
  use.full = FALSE,
  font.size = 0.5,
  nCol = NULL,
  do.balanced = FALSE
) {
  dim.scores <- GetDimReduction(
    object = object,
    reduction.type = reduction.type,
    slot = if (use.full) "gene.loadings.full" else "gene.loadings"
  )
  #   Axis labels use the reduction's own key (e.g. "PC", "IC") rather than a
  #   hard-coded "PC"
  dim.key <- GetDimReduction(
    object = object,
    reduction.type = reduction.type,
    slot = "key"
  )
  if (is.null(x = nCol)) {
    #   Test the larger threshold first; otherwise the > 9 case is unreachable
    if (length(x = dims.use) > 9) {
      nCol <- 4
    } else if (length(x = dims.use) > 6) {
      nCol <- 3
    } else {
      nCol <- 2
    }
  }
  num.row <- floor(x = length(x = dims.use) / nCol - 1e-5) + 1
  par(mfrow = c(num.row, nCol))
  for (i in dims.use) {
    #   drop = FALSE keeps a matrix even if only one gene is returned
    subset.use <- dim.scores[DimTopGenes(
      object = object,
      dim.use = i,
      reduction.type = reduction.type,
      num.genes = num.genes,
      use.full = use.full,
      do.balanced = do.balanced
    ), , drop = FALSE]
    gene.positions <- seq_len(length.out = nrow(x = subset.use))
    plot(
      x = subset.use[, i],
      y = gene.positions,
      pch = 16,
      col = "blue",
      xlab = paste0(dim.key, i),
      yaxt = "n",
      ylab = ""
    )
    axis(
      side = 2,
      at = gene.positions,
      labels = rownames(x = subset.use),
      las = 1,
      cex.axis = font.size
    )
  }
  ResetPar()
}

#' Visualize PCA genes
#'
#' Plot the genes with the largest loadings on each requested principal
#' component. This is VizDimReduction with the reduction fixed to "pca".
#'
#' @param object Seurat object
#' @param pcs.use PCs to display
#' @param num.genes Number of genes to display per PC
#' @param use.full Use full PCA (i.e. the projected PCA, by default FALSE)
#' @param font.size Font size of the gene labels
#' @param nCol Number of panel columns to display
#' @param do.balanced Return an equal number of genes with both + and - PC scores.
#' If FALSE (by default), returns the top genes ranked by the score's absolute values
#'
#' @return Graphical, no return value
#'
#' @export
#'
VizPCA <- function(
  object,
  pcs.use = 1:5,
  num.genes = 30,
  use.full = FALSE,
  font.size = 0.5,
  nCol = NULL,
  do.balanced = FALSE
) {
  # The PC indices become VizDimReduction's generic dims.use.
  return(VizDimReduction(
    object = object,
    dims.use = pcs.use,
    reduction.type = "pca",
    num.genes = num.genes,
    do.balanced = do.balanced,
    use.full = use.full,
    nCol = nCol,
    font.size = font.size
  ))
}

#' Visualize ICA genes
#'
#' Visualize top genes associated with independent components. This is
#' VizDimReduction with the reduction fixed to "ica".
#'
#' @param object Seurat object
#' @param ics.use ICs to display
#' @param num.genes Number of genes to display
#' @param use.full Use full ICA (i.e. the projected ICA, by default FALSE)
#' @param font.size Font size
#' @param nCol Number of columns to display
#' @param do.balanced Return an equal number of genes with both + and - IC scores.
#' If FALSE (by default), returns the top genes ranked by the score's absolute values
#'
#' @return Graphical, no return value
#'
#' @export
#'
VizICA <- function(
  object,
  ics.use = 1:5,
  num.genes = 30,
  use.full = FALSE,
  font.size = 0.5,
  nCol = NULL,
  do.balanced = FALSE
) {
  VizDimReduction(
    object = object,
    reduction.type = "ica",
    # Previously referenced the undefined `pcs.use`, so every call errored.
    dims.use = ics.use,
    num.genes = num.genes,
    use.full = use.full,
    font.size = font.size,
    nCol = nCol,
    do.balanced = do.balanced
  )
}

#' Dimensional reduction plot
#'
#' Scatter plot of two dimensions of a dimensional reduction (PCA by default).
#' Each cell is coloured by its identity class, or by another grouping
#' variable when group.by is set.
#'
#' @param object Seurat object
#' @param reduction.use Which dimensionality reduction to use. Default is
#' "pca", can also be "tsne", or "ica", assuming these are precomputed.
#' @param dim.1 Dimension for x-axis (default 1)
#' @param dim.2 Dimension for y-axis (default 2)
#' @param cells.use Vector of cells to plot (default is all cells)
#' @param pt.size Adjust point size for plotting
#' @param do.return Return a ggplot2 object (default : FALSE)
#' @param do.bare Do only minimal formatting (default : FALSE)
#' @param cols.use Vector of colors, each color corresponds to an identity
#' class. By default, ggplot assigns colors.
#' @param group.by Group (color) cells in different ways (for example, orig.ident)
#' @param pt.shape If NULL, all points are circles (default). You can specify any
#' cell attribute (that can be pulled with FetchData) allowing for both different colors and
#' different shapes on cells. Numeric attributes are binned into 5 intervals.
#' @param do.hover Enable hovering over points to view information
#' @param data.hover Data to add to the hover, pass a character vector of features to add. Defaults to cell name and ident. Pass 'NULL' to clear extra information.
#' @param do.identify Opens a locator session to identify clusters of cells.
#' @param do.label Whether to label the clusters (at the median position of each group)
#' @param label.size Sets size of labels
#' @param no.legend Setting to TRUE will remove the legend
#' @param no.axes Setting to TRUE will remove the axes
#' @param dark.theme Use a dark theme for the plot
#' @param ... Extra parameters to FeatureLocator for do.identify = TRUE
#'
#' @return If do.hover or do.identify is TRUE, the result of HoverLocator or
#' FeatureLocator. Otherwise, if do.return==TRUE, returns a ggplot2 object;
#' if not, the plot is only printed.
#'
#' @seealso \code{FeatureLocator}
#'
#' @import SDMTools
#' @importFrom stats median
#' @importFrom dplyr summarize group_by
#'
#' @export
#'
DimPlot <- function(
  object,
  reduction.use = "pca",
  dim.1 = 1,
  dim.2 = 2,
  cells.use = NULL,
  pt.size = 3,
  do.return = FALSE,
  do.bare = FALSE,
  cols.use = NULL,
  group.by = "ident",
  pt.shape = NULL,
  do.hover = FALSE,
  data.hover = 'ident',
  do.identify = FALSE,
  do.label = FALSE,
  label.size = 1,
  no.legend = FALSE,
  no.axes = FALSE,
  dark.theme = FALSE,
  ...
) {
  embeddings.use <- GetDimReduction(
    object = object,
    reduction.type = reduction.use,
    slot = "cell.embeddings"
  )
  if (length(x = embeddings.use) == 0) {
    stop(paste(reduction.use, "has not been run for this object yet."))
  }
  cells.use <- SetIfNull(x = cells.use, default = colnames(x = object@data))
  # Embedding columns are named <key><dimension>, e.g. "PC1".
  dim.key <- GetDimReduction(
    object = object,
    reduction.type = reduction.use,
    slot = "key"
  )
  dim.codes <- paste0(dim.key, c(dim.1, dim.2))
  embeddings.df <- as.data.frame(x = embeddings.use)
  # Only cells that actually have coordinates can be drawn.
  cells.use <- intersect(x = cells.use, y = rownames(x = embeddings.df))
  data.plot <- embeddings.df[cells.use, dim.codes]

  # Colour groups: active identities by default, otherwise any FetchData variable.
  if (group.by == "ident") {
    ident.use <- as.factor(x = object@ident[cells.use])
  } else {
    ident.use <- as.factor(x = FetchData(
      object = object,
      vars.all = group.by
    )[cells.use, 1])
  }
  data.plot$ident <- ident.use
  data.plot$x <- data.plot[, dim.codes[1]]
  data.plot$y <- data.plot[, dim.codes[2]]
  data.plot$pt.size <- pt.size

  if (is.null(x = pt.shape)) {
    point.layer <- geom_point(
      mapping = aes(colour = factor(x = ident)),
      size = pt.size
    )
  } else {
    shape.val <- FetchData(object = object, vars.all = pt.shape)[cells.use, 1]
    # Continuous attributes are binned so they can be mapped to shapes.
    if (is.numeric(shape.val)) {
      shape.val <- cut(x = shape.val, breaks = 5)
    }
    data.plot[, "pt.shape"] <- shape.val
    point.layer <- geom_point(
      mapping = aes(colour = factor(x = ident), shape = factor(x = pt.shape)),
      size = pt.size
    )
  }
  p <- ggplot(data = data.plot, mapping = aes(x = x, y = y)) + point.layer
  if (!is.null(x = cols.use)) {
    p <- p + scale_colour_manual(values = cols.use)
  }

  # Fully formatted variant; `p` stays the bare variant for do.bare.
  p3 <- p +
    xlab(label = dim.codes[[1]]) +
    ylab(label = dim.codes[[2]]) +
    scale_size(range = c(pt.size, pt.size)) +
    SetXAxisGG() +
    SetYAxisGG() +
    SetLegendPointsGG(x = 6) +
    SetLegendTextGG(x = 12) +
    no.legend.title +
    theme_bw() +
    NoGrid() +
    theme(legend.title = element_blank())

  if (do.label) {
    # Label each group at the median of its cells' coordinates.
    centers <- data.plot %>%
      dplyr::group_by(ident) %>%
      summarize(x = median(x = x), y = median(x = y))
    p3 <- p3 +
      geom_point(data = centers, mapping = aes(x = x, y = y), size = 0, alpha = 0) +
      geom_text(data = centers, mapping = aes(label = ident), size = label.size)
  }
  if (dark.theme) {
    p <- p + DarkTheme()
    p3 <- p3 + DarkTheme()
  }
  if (no.legend) {
    p3 <- p3 + theme(legend.position = "none")
  }
  if (no.axes) {
    blank.axes <- theme(
      axis.line = element_blank(),
      axis.text.x = element_blank(),
      axis.text.y = element_blank(),
      axis.ticks = element_blank(),
      axis.title.x = element_blank(),
      axis.title.y = element_blank(),
      panel.background = element_blank(),
      panel.border = element_blank(),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      plot.background = element_blank()
    )
    p3 <- p3 + blank.axes
  }

  plot.use <- if (do.bare) p else p3
  # Hover takes precedence over identify when both are requested.
  if (do.hover) {
    features.info <- if (is.null(x = data.hover)) {
      NULL
    } else {
      FetchData(object = object, vars.all = data.hover)
    }
    return(HoverLocator(
      plot = plot.use,
      data.plot = data.plot,
      features.info = features.info,
      dark.theme = dark.theme
    ))
  }
  if (do.identify) {
    return(FeatureLocator(
      plot = plot.use,
      data.plot = data.plot,
      dark.theme = dark.theme,
      ...
    ))
  }
  if (do.return) {
    return(plot.use)
  }
  print(plot.use)
}

#' Plot PCA map
#'
#' Graphs the output of a PCA analysis
#' Cells are colored by their identity class.
#'
#' This function is a wrapper for DimPlot. See ?DimPlot for a full list of possible
#' arguments which can be passed in here.
#'
#' @param object Seurat object
#' @param \dots Additional parameters to DimPlot, for example, which dimensions to plot.
#' @param label.size Size of cluster labels (used when do.label = TRUE).
#' Must be given by its full name. Defaults to 6.
#'
#' @export
#'
PCAPlot <- function(object, ..., label.size = 6) {
  # label.size is a formal (after `...`, so positional calls are unchanged)
  # rather than a hard-coded DimPlot argument; otherwise passing label.size
  # errors with "formal argument matched by multiple actual arguments".
  return(DimPlot(
    object = object,
    reduction.use = "pca",
    label.size = label.size,
    ...
  ))
}

#' Plot Diffusion map
#'
#' Graphs the output of a Diffusion map analysis
#' Cells are colored by their identity class.
#'
#' This function is a wrapper for DimPlot. See ?DimPlot for a full list of possible
#' arguments which can be passed in here.
#'
#' @param object Seurat object
#' @param \dots Additional parameters to DimPlot, for example, which dimensions to plot.
#' @param label.size Size of cluster labels (used when do.label = TRUE).
#' Must be given by its full name. Defaults to 6.
#'
#' @export
DMPlot <- function(object, ..., label.size = 6) {
  # label.size is a formal (after `...`, so positional calls are unchanged)
  # so callers can override it without a "matched by multiple actual
  # arguments" error.
  return(DimPlot(
    object = object,
    reduction.use = "dm",
    label.size = label.size,
    ...
  ))
}

#' Plot ICA map
#'
#' Draws the cell embeddings of an ICA run, with cells coloured by their
#' identity class.
#'
#' Thin wrapper around DimPlot with reduction.use fixed to "ica"; see ?DimPlot
#' for every option that can be forwarded through \dots.
#'
#' @param object Seurat object
#' @param \dots Additional parameters to DimPlot, for example, which dimensions to plot.
#'
#' @export
#'
ICAPlot <- function(object, ...) {
  DimPlot(object = object, reduction.use = "ica", ...)
}

#' Plot tSNE map
#'
#' Draws the tSNE embedding of the cells, coloured by identity class.
#'
#' Thin wrapper around DimPlot with reduction.use fixed to "tsne"; see ?DimPlot
#' for every option that can be forwarded through \dots.
#'
#' @param object Seurat object
#' @param do.label FALSE by default. If TRUE, plots an alternate view where the center of each
#' cluster is labeled
#' @param pt.size Set the point size
#' @param label.size Set the size of the text labels
#' @param cells.use Vector of cell names to use in the plot.
#' @param colors.use Manually set the color palette to use for the points
#' (forwarded to DimPlot as cols.use)
#' @param \dots Additional parameters to DimPlot, for example, which dimensions to plot.
#'
#' @seealso DimPlot
#'
#' @export
#'
TSNEPlot <- function(
  object,
  do.label = FALSE,
  pt.size = 1,
  label.size = 4,
  cells.use = NULL,
  colors.use = NULL,
  ...
) {
  # DimPlot names its palette argument cols.use; colors.use is this
  # wrapper's public name for it.
  DimPlot(
    object = object,
    reduction.use = "tsne",
    do.label = do.label,
    label.size = label.size,
    pt.size = pt.size,
    cells.use = cells.use,
    cols.use = colors.use,
    ...
  )
}

#' Quickly Pick Relevant Dimensions
#'
#' Plots the standard deviations (or approximate singular values if running PCAFast)
#' of the principal components for easy identification of an elbow in the graph.
#' This elbow often corresponds well with the significant dims and is much faster to run than
#' Jackstraw
#'
#'
#' @param object Seurat object
#' @param reduction.type  Type of dimensional reduction to plot data for
#' @param dims.plot Number of dimensions to plot sd for
#' @param xlab X axis label. When empty and reduction.type is "pca" or "ica",
#' a default ("PC"/"IC") is used.
#' @param ylab Y axis label. When empty and reduction.type is "pca" or "ica",
#' a default ("Standard Deviation of PC"/"Standard Deviation of IC") is used.
#' @param title Plot title
#'
#' @return Returns ggplot object
#'
#' @export
#'
DimElbowPlot <- function(
  object,
  reduction.type = "pca",
  dims.plot = 20,
  xlab = "",
  ylab = "",
  title = ""
) {
  data.use <- GetDimReduction(
    object = object,
    reduction.type = reduction.type,
    slot = "sdev"
  )
  if (length(x = data.use) == 0) {
    stop(paste("No standard deviation info stored for", reduction.type))
  }
  # Short component label for reductions with built-in defaults; NULL otherwise.
  dim.label <- switch(EXPR = reduction.type, pca = "PC", ica = "IC", NULL)
  if (length(x = data.use) < dims.plot) {
    # Name the components correctly instead of always saying "PCs".
    warning(paste(
      "The object only has information for",
      length(x = data.use),
      if (is.null(x = dim.label)) "dimensions." else paste0(dim.label, "s.")
    ))
    dims.plot <- length(x = data.use)
  }
  # seq_len/seq_along avoid the 1:0 == c(1, 0) trap for zero lengths.
  data.use <- data.use[seq_len(length.out = dims.plot)]
  dims <- seq_along(along.with = data.use)
  data.plot <- data.frame(dims, data.use)
  # User-supplied labels always win; PC/IC defaults only fill empty ones
  # (previously xlab/ylab were silently ignored for pca and ica).
  if (!is.null(x = dim.label)) {
    if (identical(x = xlab, y = "")) {
      xlab <- dim.label
    }
    if (identical(x = ylab, y = "")) {
      ylab <- paste("Standard Deviation of", dim.label)
    }
  }
  plot <- ggplot(data = data.plot, mapping = aes(x = dims, y = data.use)) +
    geom_point() +
    labs(y = ylab, x = xlab, title = title)
  return(plot)
}

#' Quickly Pick Relevant PCs
#'
#' Shorthand for DimElbowPlot on the PCA reduction: plots the standard
#' deviations (or approximate singular values if running PCAFast) of the
#' principal components so an elbow in the curve can be spotted. The elbow
#' often tracks the significant PCs and is much faster to compute.
#'
#' @param object Seurat object
#' @param num.pc Number of PCs to plot
#'
#' @return Returns ggplot object
#'
#' @export
#'
PCElbowPlot <- function(object, num.pc = 20) {
  DimElbowPlot(object = object, reduction.type = "pca", dims.plot = num.pc)
}

#' View variable genes
#'
#' Scatter plot of gene dispersion against average expression, using the
#' values stored in object@@hvg.info. Genes passing the mean/dispersion
#' cutoffs can be labelled.
#'
#' @param object Seurat object
#' @param do.text Add text names of variable genes to plot (default is TRUE)
#' @param cex.use Point size
#' @param cex.text.use Text size
#' @param do.spike FALSE by default. If TRUE, color all genes starting with ^ERCC a different color
#' @param pch.use Pch value for points
#' @param col.use Color to use
#' @param spike.col.use if do.spike, color for spike-in genes
#' @param plot.both Plot both the unscaled and scaled dispersion panels side by side.
#' @param do.contour Draw contour lines calculated based on all genes
#' @param contour.lwd Contour line width
#' @param contour.col Contour line color
#' @param contour.lty Contour line type
#' @param x.low.cutoff Bottom cutoff on x-axis for identifying variable genes
#' @param x.high.cutoff Top cutoff on x-axis for identifying variable genes
#' @param y.cutoff Bottom cutoff on y-axis for identifying variable genes
#' @param y.high.cutoff Top cutoff on y-axis for identifying variable genes
#'
#' @return Called for its plotting side effect; returns NULL invisibly.
#'
#' @importFrom stats cor loess smooth.spline
#' @importFrom grDevices col2rgb
#' @importFrom graphics axis contour par points smoothScatter text
#'
#' @export
#'
VariableGenePlot <- function(
  object,
  do.text = TRUE,
  cex.use = 0.5,
  cex.text.use = 0.5,
  do.spike = FALSE,
  pch.use = 16,
  col.use = "black",
  spike.col.use = "red",
  plot.both = FALSE,
  do.contour = TRUE,
  contour.lwd = 3,
  contour.col = "white",
  contour.lty = 2,
  x.low.cutoff = 0.1,
  x.high.cutoff = 8,
  y.cutoff = 1,
  y.high.cutoff = Inf
) {
  gene.mean <- object@hvg.info[, 1]
  gene.dispersion <- object@hvg.info[, 2]
  gene.dispersion.scaled <- object@hvg.info[, 3]
  # NOTE(review): assumes the rows of object@hvg.info are in the same order as
  # rownames(object@data) -- confirm against the code that fills hvg.info.
  names(x = gene.mean) <- names(x = gene.dispersion) <-
    names(x = gene.dispersion.scaled) <- rownames(x = object@data)
  pass.cutoff <- names(x = gene.mean)[which(
    x = gene.mean > x.low.cutoff &
      gene.mean < x.high.cutoff &
      gene.dispersion.scaled > y.cutoff &
      gene.dispersion.scaled < y.high.cutoff
  )]
  if (do.spike) {
    spike.genes <- rownames(x = SubsetRow(data = object@data, code = "^ERCC"))
  }

  # Draws one dispersion-vs-mean panel with the optional contour, spike-in
  # and gene-label overlays; `dispersion` is the y-axis vector.
  draw.panel <- function(dispersion) {
    smoothScatter(
      x = gene.mean,
      y = dispersion,
      pch = pch.use,
      cex = cex.use,
      col = col.use,
      xlab = "Average expression",
      ylab = "Dispersion",
      nrpoints = Inf
    )
    if (do.contour) {
      # NOTE(review): kde2d is presumably MASS::kde2d; it is not imported in
      # this block's roxygen tags -- confirm it is imported elsewhere.
      data.kde <- kde2d(x = gene.mean, y = dispersion)
      contour(
        x = data.kde,
        add = TRUE,
        lwd = contour.lwd,
        col = contour.col,
        lty = contour.lty
      )
    }
    if (do.spike) {
      # nrpoints is a smoothScatter argument, not a graphical parameter, so
      # it is not passed to points() (doing so raised a warning).
      points(
        x = gene.mean[spike.genes],
        y = dispersion[spike.genes],
        pch = 16,
        cex = cex.use,
        col = spike.col.use
      )
    }
    if (do.text) {
      text(
        x = gene.mean[pass.cutoff],
        y = dispersion[pass.cutoff],
        labels = pass.cutoff,
        cex = cex.text.use
      )
    }
  }

  if (plot.both) {
    # Restore the caller's layout afterwards instead of leaking mfrow = c(1, 2).
    old.par <- par(mfrow = c(1, 2))
    on.exit(expr = par(old.par), add = TRUE)
    draw.panel(dispersion = gene.dispersion)
  }
  draw.panel(dispersion = gene.dispersion.scaled)
  invisible(x = NULL)
}

#' Highlight classification results
#'
#' This function is useful to view where proportionally the clusters returned from
#' classification map to the clusters present in the given object. Utilizes the FeaturePlot()
#' function to color clusters in object.
#'
#' @param object Seurat object on which the classifier was trained and
#' onto which the classification results will be highlighted
#' @param clusters vector of cluster ids (output of ClassifyCells)
#' @param ... additional parameters to pass to FeaturePlot(). If cols.use is
#' not supplied, c("#f6f6f6", "black") is used.
#'
#' @return Returns a feature plot with clusters highlighted by proportion of cells
#' mapping to that cluster
#'
#' @export
#'
VizClassification <- function(object, clusters, ...) {
  # Fraction of the classified cells assigned to each cluster id
  # (previously computed from an undefined variable `out`).
  cluster.dist <- prop.table(x = table(clusters))
  object@meta.data$Classification <- numeric(length = nrow(x = object@meta.data))
  for (cluster.name in names(x = cluster.dist)) {
    cells.to.highlight <- WhichCells(object, cluster.name)
    if (length(x = cells.to.highlight) > 0) {
      object@meta.data[cells.to.highlight, "Classification"] <-
        as.numeric(x = cluster.dist[[cluster.name]])
    }
  }
  # Inspect the names of the forwarded arguments rather than regex-matching
  # the deparsed call, which also matched "cols.use" inside argument values.
  dot.names <- names(x = match.call(expand.dots = FALSE)$...)
  if ("cols.use" %in% dot.names) {
    return(FeaturePlot(object, "Classification", ...))
  }
  cols.use <- c("#f6f6f6", "black")
  return(FeaturePlot(object, "Classification", cols.use = cols.use, ...))
}

#' Plot phylogenetic tree
#'
#' Draws the first tree stored in object@@cluster.tree (as produced by
#' BuildClusterTree) top-down and annotates its internal nodes.
#'
#' @param object Seurat object
#' @param \dots Additional arguments for plotting the phylogeny (passed to plot.phylo)
#'
#' @return Plots dendogram (must be precomputed using BuildClusterTree), returns no value
#'
#' @importFrom ape plot.phylo
#' @importFrom ape nodelabels
#'
#' @export
#'
PlotClusterTree <- function(object, ...) {
  stored.trees <- object@cluster.tree
  if (length(x = stored.trees) < 1) {
    stop("Phylogenetic tree does not exist, build using BuildClusterTree")
  }
  plot.phylo(x = stored.trees[[1]], direction = "downwards", ...)
  nodelabels()
}

#' Color tSNE Plot Based on Split
#'
#' Returns a tSNE plot colored based on whether the cells fall in clusters
#' to the left or to the right of a node split in the cluster tree.
#'
#' @param object Seurat object
#' @param node Node in cluster tree on which to base the split
#' @param color1 Color for the left side of the split
#' @param color2 Color for the right side of the split
#' @param color3 Color for all other cells
#' @inheritDotParams TSNEPlot -object
#' @return Returns a tSNE plot
#' @export
ColorTSNESplit <- function(
  object,
  node,
  color1 = "red",
  color2 = "blue",
  color3 = "gray",
  ...
) {
  tree <- object@cluster.tree[[1]]
  # Children of `node`; direct column indexing cannot collapse a one-row
  # selection into a vector the way `[rows, ][, 2]` could.
  split <- tree$edge[tree$edge[, 1] == node, 2]
  all.children <- DFT(
    tree = tree,
    node = tree$edge[, 1][1],
    only.children = TRUE
  )
  left.group <- DFT(tree = tree, node = split[1], only.children = TRUE)
  right.group <- DFT(tree = tree, node = split[2], only.children = TRUE)
  # An NA result is treated as "this child is itself a leaf".
  if (any(is.na(x = left.group))) {
    left.group <- split[1]
  }
  if (any(is.na(x = right.group))) {
    right.group <- split[2]
  }
  remaining.group <- setdiff(x = all.children, y = c(left.group, right.group))
  left.cells <- WhichCells(object = object, ident = left.group)
  right.cells <- WhichCells(object = object, ident = right.group)
  remaining.cells <- WhichCells(object = object, ident = remaining.group)
  object <- SetIdent(
    object = object,
    cells.use = left.cells,
    ident.use = "Left Split"
  )
  object <- SetIdent(
    object = object,
    cells.use = right.cells,
    ident.use = "Right Split"
  )
  object <- SetIdent(
    object = object,
    cells.use = remaining.cells,
    ident.use = "Not in Split"
  )
  # Name the colours by identity label. The old positional vector relied on
  # all three labels being present in alphabetical order; when "Not in Split"
  # was empty, the right split was drawn in color3.
  colors.use <- c(
    "Left Split" = color1,
    "Right Split" = color2,
    "Not in Split" = color3
  )
  return(TSNEPlot(object = object, colors.use = colors.use, ...))
}

#' Plot k-means clusters
#'
#' Heatmap of the genes belonging to the selected k-means gene clusters,
#' drawn with DoHeatmap.
#'
#' @param object A Seurat object
#' @param cells.use Cells to include in the heatmap
#' @param genes.cluster Clusters to include in heatmap (default: all k-means gene clusters)
#' @param max.genes Maximum number of genes to include in the heatmap
#' @param slim.col.label Instead of displaying every cell name on the heatmap,
#' display only the identity class name once for each group
#' @param remove.key Removes the color key from the plot
#' @param row.lines Color separations of clusters. Currently ignored; kept for
#' backward compatibility.
#' @param ... Extra parameters to DoHeatmap
#'
#' @return The value returned by DoHeatmap.
#'
#' @seealso \code{DoHeatmap}
#'
#' @export
#'
KMeansHeatmap <- function(
  object,
  cells.use = object@cell.names,
  genes.cluster = NULL,
  max.genes = 1e6,
  slim.col.label = TRUE,
  remove.key = TRUE,
  row.lines = TRUE,
  ...
) {
  genes.cluster <- SetIfNull(
    x = genes.cluster,
    default = unique(x = object@kmeans@gene.kmeans.obj$cluster)
  )
  genes.use <- GenesInCluster(
    object = object,
    cluster.num = genes.cluster,
    max.genes = max.genes
  )
  # row.lines is not used. The ggplot-based DoHeatmap documented in this file
  # exposes no row-separator argument, so the per-cluster gene counts formerly
  # computed for it were dead work and have been removed.
  DoHeatmap(
    object = object,
    cells.use = cells.use,
    genes.use = genes.use,
    slim.col.label = slim.col.label,
    remove.key = remove.key,
    ...
  )
}

#' Node Heatmap
#'
#' Takes an object, a marker list (output of FindAllMarkers), and a node
#' and plots a heatmap where genes are ordered vertically by the splits present
#' in the object@@cluster.tree slot.
#'
#' @param object Seurat object. Must have the cluster.tree slot filled (use BuildClusterTree)
#' @param marker.list List of marker genes given from the FindAllMarkersNode function
#' @param node Node in the cluster tree from which to start the plot, defaults to highest node in marker list
#' @param max.genes Maximum number of genes to keep for each division
#' @param ... Additional parameters to pass to DoHeatmap
#'
#' @importFrom dplyr %>% group_by filter top_n select
#'
#' @return Plots the heatmap and returns the value of DoHeatmap.
#'
#' @export
#'
NodeHeatmap <- function(object, marker.list, node = NULL, max.genes = 10, ...) {
  tree <- object@cluster.tree[[1]]
  node <- SetIfNull(x = node, default = min(marker.list$cluster))
  node.order <- c(node, DFT(tree = tree, node = node))
  # Rank = original row order of marker.list (lower rank is kept first).
  # seq_len avoids seq(1:0) == 1:2 on an empty marker list.
  marker.list$rank <- seq_len(length.out = nrow(x = marker.list))
  pos.genes <- marker.list %>%
    group_by(cluster) %>%
    filter(avg_diff > 0) %>%
    top_n(max.genes, -rank) %>%
    select(gene, cluster)
  neg.genes <- marker.list %>%
    group_by(cluster) %>%
    filter(avg_diff < 0) %>%
    top_n(max.genes, -rank) %>%
    select(gene, cluster)
  gene.list <- vector()
  node.stack <- vector()
  for (n in node.order) {
    if (NodeHasChild(tree = tree, node = n)) {
      gene.list <- c(
        gene.list,
        subset(x = pos.genes, subset = cluster == n)$gene,
        subset(x = neg.genes, subset = cluster == n)$gene
      )
      if (NodeHasOnlyChildren(tree = tree, node = n)) {
        # Close the most recently opened node: append its negative markers
        # and pop it from the stack.
        gene.list <- c(
          gene.list,
          subset(
            x = neg.genes,
            subset = cluster == node.stack[length(x = node.stack)]
          )$gene
        )
        node.stack <- node.stack[-length(x = node.stack)]
      }
    } else {
      gene.list <- c(gene.list, subset(x = pos.genes, subset = cluster == n)$gene)
      node.stack <- append(x = node.stack, values = n)
    }
  }
  # Leaves below `node`: descendants that never appear as a parent in the edge matrix.
  descendants <- GetDescendants(tree = tree, node = node)
  children <- descendants[!descendants %in% tree$edge[, 1]]
  DoHeatmap(
    object = object,
    cells.use = WhichCells(object = object, ident = children),
    genes.use = gene.list,
    slim.col.label = TRUE,
    remove.key = TRUE,
    ...
  )
}

#' Posterior Plot
#'
#' Violin plot of every column of object@@spatial@@mix.probs selected by the
#' given spatial code (via SubsetColumn).
#'
#' @param object A Seurat object
#' @param name Spatial code
#'
#' @seealso \code{SubsetColumn}
#' @seealso \code{VlnPlot}
#'
#' @export
#'
PosteriorPlot <- function(object, name) {
  mix.probs <- object@spatial@mix.probs
  matching.columns <- SubsetColumn(data = mix.probs, code = name)
  return(VlnPlot(
    object = object,
    features.plot = colnames(x = matching.columns),
    inc.first = TRUE,
    inc.final = TRUE,
    by.k = TRUE
  ))
}
# mayer-lab/SeuratForMayer2018 documentation built on May 25, 2019, 9:34 p.m.