R/canyon_plot.R

Defines functions canyon.plot

Documented in canyon.plot

#' Visualize results from an NMF cross-validation experiment
#'
#' @description
#' Renders a ggplot2 object given output from \code{nnmf.cv}.
#'
#' @details
#' Find the optimal rank by running \code{scNMF::nnmf.cv()} and then passing
#' its result to \code{canyon.plot}, which draws it with ggplot2. Any ggplot2
#' grammar can be applied to the returned ggplot2 object.
#'
#' Calling this function attaches the ggplot2 and dplyr packages.
#'
#' @param nnmf.cv The result from \code{nnmf.cv}. Only its
#'   \code{factor.angles} element is used; it must be a data frame with the
#'   columns \code{k} (rank), \code{seed} (run) and \code{factor.angle}.
#' @param ribbon Either "none", "fit" (default), or "sd" (partial matching is
#'   allowed; any other value is an error). "fit" draws the confidence band of
#'   a loess fit through the individual factor angles. "sd" draws a band
#'   centred on the mean factor angle at each rank whose total width is
#'   \code{ribbon.confidence} standard deviations (half above and half below
#'   the mean).
#' @param points Logical; whether to overlay jittered points showing the angle
#'   of each factor pair (default FALSE). When TRUE the y-axis is labelled
#'   "angle between factors", otherwise "angle between models".
#' @param ribbon.collapse If TRUE (default), draw a single grey ribbon pooled
#'   over all runs; if FALSE, draw one ribbon per run, coloured like its line.
#'   Ribbons are always collapsed when lines are collapsed.
#' @param line.collapse If FALSE (default), draw one line per run showing the
#'   mean factor angle at each rank. If TRUE, draw a single black line averaged
#'   over all runs and hide the legend.
#' @param ribbon.confidence For \code{ribbon = "fit"}, the confidence level of
#'   the loess band (default 0.95). For \code{ribbon = "sd"}, the total band
#'   width in standard deviations (default 1.0).
#' @param title Plot title, passed to \code{ggtitle}.
#' @return A ggplot2 object. The legend is hidden when lines are collapsed or
#'   when there is only one run.
#' @seealso \code{\link{nnmf.cv}}
canyon.plot <- function(nnmf.cv, ribbon = "fit", points = FALSE,
                        ribbon.collapse = TRUE, line.collapse = FALSE,
                        ribbon.confidence = ifelse(ribbon == "fit", 0.95, 1.00),
                        title = "Canyon Plot") {
  # library() errors if a dependency is missing, where require() only warns.
  library(ggplot2)
  library(dplyr)

  # Validate `ribbon` before forcing `ribbon.confidence`: the latter's default
  # reads `ribbon`, so it must see the matched value.
  ribbon <- match.arg(ribbon, c("none", "fit", "sd"))
  force(ribbon.confidence)
  # Collapsing the lines always collapses the ribbons too.
  collapse.ribbons <- ribbon.collapse || line.collapse

  # Grouping columns for the per-line means and for the "sd" band statistics.
  line.vars <- if (line.collapse) "k" else c("k", "seed")
  ribbon.vars <- if (collapse.ribbons) "k" else c("k", "seed")

  data <- as_tibble(nnmf.cv$factor.angles) %>%
    group_by(across(all_of(line.vars))) %>%
    mutate(model.angle = mean(factor.angle)) %>%
    group_by(across(all_of(ribbon.vars))) %>%
    mutate(
      ribbon.lower = mean(factor.angle) - sd(factor.angle) / 2 * ribbon.confidence,
      ribbon.upper = mean(factor.angle) + sd(factor.angle) / 2 * ribbon.confidence
    ) %>%
    # Ungroup before converting `seed`: a grouped factor() call would compute
    # levels per group (or refuse to modify a grouping column).
    ungroup() %>%
    arrange(as.numeric(k)) %>%
    # `all` is a constant column used to force a single ribbon group.
    mutate(seed = factor(seed), all = "all")

  # One hue per run, spaced around the colour wheel the way ggplot2 does by
  # default; a single black line when only one line is drawn.
  num.colors <- length(unique(data$seed))
  if (line.collapse || num.colors == 1) {
    colors <- "#000000"
  } else {
    hues <- seq(15, 375, length.out = num.colors + 1)
    colors <- hcl(h = hues, l = 65, c = 100)[seq_len(num.colors)]
  }

  # Break function that shows only (unique) integers on the x-axis.
  integer_breaks <- function(n = 5, ...) {
    function(x) {
      breaks <- floor(pretty(x, n, ...))
      names(breaks) <- attr(breaks, "labels")
      # floor() can map neighbouring pretty breaks onto the same integer
      breaks[!duplicated(breaks)]
    }
  }

  # Base layer: mean angle per rank, one line per run or a single line.
  if (line.collapse) {
    p <- ggplot(data, aes(x = k, y = factor.angle)) +
      geom_line(aes(y = model.angle), size = 1) +
      scale_color_manual(values = colors) +
      scale_fill_manual(values = colors)
  } else {
    run.labels <- paste0("Run ", seq_len(num.colors))
    p <- ggplot(data, aes(x = k, y = factor.angle, group = seed)) +
      geom_line(aes(y = model.angle, color = seed), size = 1) +
      scale_color_manual(values = colors, labels = run.labels) +
      scale_fill_manual(values = colors, labels = run.labels)
  }

  p <- p + xlab("rank k") +
    theme_classic() +
    theme(aspect.ratio = 1) +
    scale_y_continuous(expand = c(0, 0)) +
    scale_x_continuous(breaks = integer_breaks(), expand = c(0, 0)) +
    theme(legend.title = element_blank()) +
    theme(plot.title = element_text(hjust = 0.5, size = 12)) +
    ggtitle(title)

  # Optional jittered points for the individual factor-pair angles.
  if (points) {
    p <- p + ylab("angle between factors")
    if (line.collapse) {
      p <- p + geom_jitter(width = 0.15)
    } else {
      p <- p + geom_jitter(aes(color = seed), width = 0.15)
    }
  } else {
    p <- p + ylab("angle between models")
  }

  # Plain ggplot2 equivalent of Seurat::NoLegend(), so Seurat is not required.
  if (num.colors == 1 || line.collapse) {
    p <- p + theme(legend.position = "none")
  }

  # Optional ribbon: a loess confidence band, or a standard-deviation band.
  # Collapsed ribbons override the inherited `group = seed` with the constant
  # `all` column so a single band is drawn instead of one per run.
  if (ribbon == "fit") {
    if (collapse.ribbons) {
      p <- p + geom_ribbon(aes(group = all), stat = "smooth", method = "loess",
                           formula = y ~ x, level = ribbon.confidence,
                           fill = "#a0a0a0", alpha = 0.3)
    } else {
      p <- p + geom_ribbon(aes(fill = seed), stat = "smooth", method = "loess",
                           formula = y ~ x, level = ribbon.confidence,
                           alpha = 0.3)
    }
  } else if (ribbon == "sd") {
    if (collapse.ribbons) {
      p <- p + geom_ribbon(aes(ymin = ribbon.lower, ymax = ribbon.upper, group = all),
                           fill = "#a0a0a0", alpha = 0.3)
    } else {
      p <- p + geom_ribbon(aes(ymin = ribbon.lower, ymax = ribbon.upper, fill = seed),
                           alpha = 0.3)
    }
  }

  p
}
zdebruine/scNMF documentation built on Jan. 1, 2021, 1:50 p.m.