R/model_performance.R

Defines functions perplexity resamplePerplexity plotGridSearchPerplexity plotGridSearchPerplexity.celda_CG plotGridSearchPerplexity.celda_C plotGridSearchPerplexity.celda_G resampleCountMatrix

Documented in perplexity plotGridSearchPerplexity plotGridSearchPerplexity.celda_C plotGridSearchPerplexity.celda_CG plotGridSearchPerplexity.celda_G resamplePerplexity

#' Calculate the perplexity from a single celda model
#'
#' Perplexity can be seen as a measure of how well a provided set of
#' cluster assignments fits the data being clustered.
#'
#' @param counts Integer matrix. Rows represent features and columns represent cells. This matrix should be the same as the one used to generate `celda.mod`.
#' @param celda.mod Celda object of class "celda_C", "celda_G" or "celda_CG".
#' @param new.counts A new counts matrix used to calculate perplexity. If NULL, perplexity will be calculated for the 'counts' matrix. Default NULL.
#' @return Numeric. The perplexity for the provided count data and model.
#' @examples
#' perp <- perplexity(celda.CG.sim$counts, celda.CG.mod)
#' @export
perplexity <- function(counts, celda.mod, new.counts = NULL) {
  # Validate `counts` against the model first; dispatch then happens on the
  # class of `celda.mod` (not on `counts`, the first argument).
  compareCountMatrix(counts, celda.mod)
  UseMethod("perplexity", celda.mod)
}


#' Calculate perplexity of all models in a celda_list, with count resampling
#'
#' Calculates the perplexity of each model's cluster assignments given the provided count matrix,
#' as well as resamplings of that count matrix, providing a distribution of perplexities and a better sense of the
#' quality of a given K/L choice. Use plotGridSearchPerplexity() to visualize the result.
#'
#' @param counts Integer matrix. Rows represent features and columns represent cells. This matrix should be the same as the one used to generate the models in `celda.list`.
#' @param celda.list Object of class 'celda_list'.
#' @param resample Integer. The number of times to resample the counts matrix for evaluating perplexity. Must be a single value of at least 1; fractional values are truncated. Default 5.
#' @param seed Parameter to set.seed() for random number generation. Note that this resets the global random number generator state. Default 12345.
#' @return celda_list. Returns the provided `celda.list` with a `perplexity` property: a matrix with one row per model in `celda.list$res.list` and one column per resampled counts matrix.
#' @examples
#' celda.CG.grid.search.res = resamplePerplexity(celda.CG.sim$counts,
#'                                               celda.CG.grid.search.res)
#' plotGridSearchPerplexity(celda.CG.grid.search.res)
#' @export
resamplePerplexity <- function(counts, celda.list, resample = 5, seed = 12345) {
  if (!isTRUE(class(celda.list)[1] == "celda_list")) {
    stop("celda.list parameter was not of class celda_list.")
  }
  if (!isTRUE(is.numeric(resample))) {
    stop("Provided resample parameter was not numeric.")
  }
  # `1:resample` silently misbehaves for values below 1 (1:0 is c(1, 0)),
  # which previously surfaced as an opaque "subscript out of bounds" error.
  if (length(resample) != 1 || is.na(resample) || resample < 1) {
    stop("Provided resample parameter must be a single number >= 1.")
  }
  # Truncate exactly as 1:resample and matrix(ncol = resample) did before.
  n.resample <- as.integer(resample)

  set.seed(seed)
  counts.list <- lapply(seq_len(n.resample), function(x) {
    resampleCountMatrix(counts)
  })

  models <- celda.list$res.list
  perp.res <- matrix(NA, nrow = length(models), ncol = n.resample)
  # seq_along()/seq_len() make an empty res.list yield a 0-row matrix
  # instead of failing on res.list[[1]].
  for (i in seq_along(models)) {
    for (j in seq_len(n.resample)) {
      # Each model is scored on the resampled matrix (passed as new.counts).
      perp.res[i, j] <- perplexity(counts, models[[i]], counts.list[[j]])
    }
  }
  celda.list$perplexity <- perp.res

  celda.list
}


#' Visualize perplexity of every model in a celda_list, by unique K/L combinations
#'
#' Generic function; the method used is chosen from the class of `celda.list`.
#' Methods are provided for celda_CG, celda_C and celda_G.
#'
#' @param celda.list Object of class 'celda_list'.
#' @return A ggplot plot object showing perplexity as a function of clustering parameters.
#' @examples
#' ## Run various combinations of parameters with 'celdaGridSearch'
#' celda.CG.grid.search.res = resamplePerplexity(celda.CG.sim$counts,
#'                                               celda.CG.grid.search.res)
#' plotGridSearchPerplexity(celda.CG.grid.search.res)
#' @export
plotGridSearchPerplexity <- function(celda.list) {
  # S3 dispatch on class(celda.list).
  UseMethod("plotGridSearchPerplexity", celda.list)
}

#' Plot perplexity as a function of K and L from celda_CG models
#'
#' This function plots perplexity as a function of the cell/gene (K/L) clusters as generated by celdaGridSearch().
#' Each point is one perplexity value (one model, one resample), colored by L;
#' lines connect the mean perplexity of each K within a given L.
#'
#' @param celda.list Object of class 'celda_list'. Its `run.params` must contain `K` and `L` columns, and it must carry a `perplexity` matrix (one row per model, one column per resample) as added by resamplePerplexity().
#' @return A ggplot plot object showing perplexity as a function of clustering parameters.
#' @examples
#' celda.CG.grid.search.res = resamplePerplexity(celda.CG.sim$counts,
#'                                               celda.CG.grid.search.res)
#' plotGridSearchPerplexity(celda.CG.grid.search.res)
#' @export
plotGridSearchPerplexity.celda_CG <- function(celda.list) {
  if (!all(c("K", "L") %in% colnames(celda.list$run.params))) {
    stop("celda.list$run.params needs K and L columns.")
  }
  if (is.null(celda.list$perplexity)) {
    stop("No perplexity measurements available. First run 'resamplePerplexity' with celda.list object.")
  }

  perp <- celda.list$perplexity
  # Flatten the model-by-resample matrix row by row, repeating each model's
  # run parameters once per resample.
  ix1 <- rep(seq_len(nrow(perp)), each = ncol(perp))
  ix2 <- rep(seq_len(ncol(perp)), times = nrow(perp))
  # drop = FALSE guarantees a table (never a bare vector) so column names
  # such as K and L survive the row subset.
  df <- data.frame(celda.list$run.params[ix1, , drop = FALSE],
                   perplexity = perp[cbind(ix1, ix2)])
  df$K <- as.factor(df$K)
  df$L <- as.factor(df$L)

  # Mean perplexity for every K/L combination; one line is drawn per L.
  l.means.by.k <- stats::aggregate(df$perplexity, by = list(df$K, df$L),
                                   FUN = mean)
  colnames(l.means.by.k) <- c("K", "L", "mean_perplexity")
  l.means.by.k$K <- as.factor(l.means.by.k$K)
  l.means.by.k$L <- as.factor(l.means.by.k$L)

  plot <- ggplot2::ggplot(df, ggplot2::aes_string(x = "K", y = "perplexity")) +
    ggplot2::geom_jitter(height = 0, width = 0.1,
                         mapping = ggplot2::aes_string(color = "L")) +
    ggplot2::scale_color_discrete(name = "L") +
    ggplot2::geom_path(data = l.means.by.k,
                       mapping = ggplot2::aes_string(x = "K",
                                                     y = "mean_perplexity",
                                                     group = "L",
                                                     color = "L")) +
    ggplot2::ylab("Perplexity") +
    ggplot2::xlab("K") +
    ggplot2::theme_bw()

  plot
}


#' Plot perplexity as a function of K from celda_C models
#'
#' Plots perplexity as a function of the cell (K) clusters as generated by celdaGridSearch().
#' Each point is one perplexity value (one model, one resample); the line
#' connects the mean perplexity of each K.
#'
#' @param celda.list Object of class 'celda_list'. Its `run.params` must contain a `K` column, and it must carry a `perplexity` matrix (one row per model, one column per resample) as added by resamplePerplexity().
#' @return A ggplot plot object showing perplexity as a function of clustering parameters.
#' @examples
#' celda.CG.grid.search.res = resamplePerplexity(celda.CG.sim$counts,
#'                                               celda.CG.grid.search.res)
#' plotGridSearchPerplexity(celda.CG.grid.search.res)
#' @export
plotGridSearchPerplexity.celda_C <- function(celda.list) {
  if (!("K" %in% colnames(celda.list$run.params))) {
    stop("celda.list$run.params needs the column K.")
  }
  if (is.null(celda.list$perplexity)) {
    stop("No perplexity measurements available. First run 'resamplePerplexity' with celda.list object.")
  }

  perp <- celda.list$perplexity
  # Flatten the model-by-resample matrix row by row, repeating each model's
  # run parameters once per resample.
  ix1 <- rep(seq_len(nrow(perp)), each = ncol(perp))
  ix2 <- rep(seq_len(ncol(perp)), times = nrow(perp))
  # drop = FALSE: without it, a run.params table whose only column is K is
  # collapsed to a bare vector, the K column name is lost and df$K is NULL.
  df <- data.frame(celda.list$run.params[ix1, , drop = FALSE],
                   perplexity = perp[cbind(ix1, ix2)])
  df$K <- as.factor(df$K)

  # Mean perplexity for each K, connected by a single line.
  means.by.k <- stats::aggregate(df$perplexity, by = list(df$K), FUN = mean)
  colnames(means.by.k) <- c("K", "mean_perplexity")
  means.by.k$K <- as.factor(means.by.k$K)

  plot <- ggplot2::ggplot(df, ggplot2::aes_string(x = "K", y = "perplexity")) +
    ggplot2::geom_jitter(height = 0, width = 0.1) +
    ggplot2::geom_path(data = means.by.k,
                       mapping = ggplot2::aes_string(x = "K",
                                                     y = "mean_perplexity",
                                                     group = 1)) +
    ggplot2::ylab("Perplexity") +
    ggplot2::xlab("K") +
    ggplot2::theme_bw()

  plot
}


#' Plot perplexity as a function of L from a celda_G model
#'
#' Plots perplexity as a function of the gene (L) clusters as generated by celdaGridSearch().
#' Each point is one perplexity value (one model, one resample); the line
#' connects the mean perplexity of each L.
#'
#' @param celda.list Object of class 'celda_list'. Its `run.params` must contain an `L` column, and it must carry a `perplexity` matrix (one row per model, one column per resample) as added by resamplePerplexity().
#' @return A ggplot plot object showing perplexity as a function of clustering parameters.
#' @examples
#' celda.CG.grid.search.res = resamplePerplexity(celda.CG.sim$counts,
#'                                               celda.CG.grid.search.res)
#' plotGridSearchPerplexity(celda.CG.grid.search.res)
#' @export
plotGridSearchPerplexity.celda_G <- function(celda.list) {
  if (!("L" %in% colnames(celda.list$run.params))) {
    stop("celda.list$run.params needs the column L.")
  }
  if (is.null(celda.list$perplexity)) {
    stop("No perplexity measurements available. First run 'resamplePerplexity' with celda.list object.")
  }

  perp <- celda.list$perplexity
  # Flatten the model-by-resample matrix row by row, repeating each model's
  # run parameters once per resample.
  ix1 <- rep(seq_len(nrow(perp)), each = ncol(perp))
  ix2 <- rep(seq_len(ncol(perp)), times = nrow(perp))
  # drop = FALSE: without it, a run.params table whose only column is L is
  # collapsed to a bare vector, the L column name is lost and df$L is NULL.
  df <- data.frame(celda.list$run.params[ix1, , drop = FALSE],
                   perplexity = perp[cbind(ix1, ix2)])
  df$L <- as.factor(df$L)

  # Mean perplexity for each L, connected by a single line.
  means.by.l <- stats::aggregate(df$perplexity, by = list(df$L), FUN = mean)
  colnames(means.by.l) <- c("L", "mean_perplexity")
  means.by.l$L <- as.factor(means.by.l$L)

  plot <- ggplot2::ggplot(df, ggplot2::aes_string(x = "L", y = "perplexity")) +
    ggplot2::geom_jitter(height = 0, width = 0.1) +
    ggplot2::geom_path(data = means.by.l,
                       mapping = ggplot2::aes_string(x = "L",
                                                     y = "mean_perplexity",
                                                     group = 1)) +
    ggplot2::ylab("Perplexity") +
    ggplot2::xlab("L") +
    ggplot2::theme_bw()

  plot
}


# Resample a counts matrix for evaluating perplexity
#
# For each column (cell), draws a new column from a multinomial distribution
# whose size is the column's total count and whose probabilities are that
# column's counts divided by its total. Column totals are therefore
# preserved while the individual entries vary between draws.
#
# This is primarily used to evaluate the stability of the perplexity for a
# given K/L combination.
#
# @param count.matrix Counts matrix. Rows represent features and columns
#   represent cells.
# @return An integer matrix with the same dimensions as `count.matrix`
#   (dimnames are not carried over). Columns whose total is 0 come back as
#   all zeros.
resampleCountMatrix <- function(count.matrix) {
  n.features <- nrow(count.matrix)
  n.cells <- ncol(count.matrix)
  colsums <- colSums(count.matrix)
  resample <- vapply(seq_len(n.cells), function(idx) {
    # An empty column has no valid probability vector (0 / 0 is NaN, which
    # rmultinom rejects); the only possible resample is all zeros.
    if (colsums[idx] == 0) {
      return(integer(n.features))
    }
    stats::rmultinom(n = 1, size = colsums[idx],
                     prob = count.matrix[, idx] / colsums[idx])[, 1]
  }, FUN.VALUE = integer(n.features))
  # vapply (like sapply before it) returns a plain vector when there is a
  # single feature; always hand back a features-by-cells matrix.
  matrix(unname(resample), nrow = n.features, ncol = n.cells)
}
compbiomed/celda documentation built on May 25, 2019, 3:58 a.m.