#' Remove intra-TEC doublets by finding cells with both a high cTEC signature and a high mTEC signature.
#'
#' @param dge Seurat object.
#' @param results_path Folder where the diagnostic plot ("tec_dub_cutoffs.pdf") is saved.
#' @param mTEC_thresh,cTEC_thresh Cells are kept when either signature is strictly below its
#' threshold; cells at or above both thresholds are removed.
#' If a threshold is NULL, it is selected automatically (see details).
#' @param pure_fraction_mTEC,pure_fraction_cTEC Fraction of all cells, ranked by the respective
#' signature, that is treated as "nearly pure" during automatic threshold selection.
#' Defaults are 0.02 (mTEC) and 0.2 (cTEC).
#' @param reject_rate Automatic thresholds are set at the (1 - reject_rate) quantile. 
#' Default 0.05, i.e. the 95th percentile.
#' @return A list with names "dge" (the filtered Seurat object), "mTEC_thresh", and "cTEC_thresh".
#' 
#'
#' @details For threshold selection, we isolate some nearly-pure cTECs by ranking all cells on 
#' the cTEC signature and retaining the top pure_fraction_cTEC (default 20%).
#' Among those cells, we take the (1 - reject_rate) quantile of the mTEC signature 
#' (default 95th percentile) as the mTEC rejection threshold. 
#' The same process, with the roles swapped, is used for the cTEC threshold, 
#' but pure_fraction_mTEC defaults to just 2%.
#' 
#' @export
remove_TEC_doublets = function( dge, 
                                results_path,
                                mTEC_thresh = NULL, 
                                cTEC_thresh = NULL, 
                                pure_fraction_mTEC = 0.02, 
                                pure_fraction_cTEC = 0.2, 
                                reject_rate = 0.05 ) {
  dir.create.nice( results_path )
  dge %<>% add_cTEC_mTEC_signatures
  X = FetchData(dge, c("mTEC_signature", "cTEC_signature"))

  # The mTEC threshold is calibrated on cells that look like pure cTECs, and vice versa:
  # a pure cell of one type should rarely show a high signature of the other type.
  if( is.null( mTEC_thresh ) ){ 
    cTEC_purity_cutoff = quantile( X$cTEC_signature, probs = (1 - pure_fraction_cTEC) )
    pure_cTECs = subset(X, cTEC_signature > cTEC_purity_cutoff )
    mTEC_thresh = quantile( pure_cTECs$mTEC_signature, probs = (1 - reject_rate) )
  }
  if( is.null( cTEC_thresh ) ){ 
    mTEC_purity_cutoff = quantile( X$mTEC_signature, probs = (1 - pure_fraction_mTEC) )
    pure_mTECs = subset(X, mTEC_signature > mTEC_purity_cutoff )
    cTEC_thresh = quantile( pure_mTECs$cTEC_signature, probs = (1 - reject_rate) )
  }

  # Keep a cell unless it reaches both thresholds.
  df_keep = subset( X, cTEC_signature < cTEC_thresh | mTEC_signature < mTEC_thresh )

  # Scatterplot of the two signatures with the rejection thresholds drawn in red.
  p = custom_feature_plot( dge, colour = NULL, 
                           axes = c( "mTEC_signature", "cTEC_signature" ) ) +
    ggtitle( paste0( "TEC doublet removal (", nrow( df_keep ), " of ", nrow( X ), " remain)" ) ) + 
    geom_hline( data = data.frame(cTEC_thresh), aes( yintercept = cTEC_thresh ), col = "red" ) +
    geom_vline( data = data.frame(mTEC_thresh), aes( xintercept = mTEC_thresh ), col = "red" ) 
  print( p )  
  ggsave( file.path( results_path, "tec_dub_cutoffs.pdf" ), plot = p, width = 7, height = 7 )

  # Subset the Seurat object to the retained cells.
  cu = rownames( df_keep )
  dge = SubsetData( dge, cells.use = cu )
  return(list( dge = dge, 
               mTEC_thresh = mTEC_thresh, 
               cTEC_thresh = cTEC_thresh ) )
}

# Doublet detection and classification of new cells ----

#' Given a named vector x with counts of various cell types, returns expected doublet quantities for each possible pairing.
#' 
#' @param x Named numeric vector (or 1-d table) of cell counts, one entry per cell type.
#' @param rate Assumed doublet rate. Default 0.05.
#'
#' @details Make sure you only feed this one replicate at a time! You can't get doublets across replicates.
#' Assumes a doublet rate of 5% by default. Your mileage (and flow rates) may vary.
#'
#' With N = sum(x), the expected count for a within-type pair is rate * x[i]^2 / N and 
#' for a mixed pair it is 2 * rate * x[i] * x[j] / N, so the output sums to (about) rate * N.
#'
#' @return A named numeric vector of expected cell counts, rounded to 3 decimals. 
#' For names, every combination of names(x) is present once in the output.
#' Order doesn't matter, so labels get alphabetized and concatenated with '_'.
#' Within-cluster doublets are included.
#' E.g. you get BLD_END and BLD_BLD but not END_BLD.
#' An empty `x` gives an empty vector.
#'
#' @export
expected_doublet_counts = function( x, rate = 0.05 ){
  if( length( x ) == 0 ){
    return( setNames( numeric( 0 ), character( 0 ) ) )
  }
  if( is.null( names( x ) ) ){
    stop( "`x` must be a named vector of cell-type counts.", call. = FALSE )
  }
  cell_types = names( x )
  counts = as.numeric( x )

  # Entry [i, j] is rate * N * p_i * p_j, where p_i = counts[i] / N.
  pair_counts = rate * outer( counts, counts ) / sum( counts )

  # Enumerate each unordered pair once (upper triangle, column-major order).
  # Zero-count pairs are kept explicitly, so no sparse-matrix round trip is needed.
  pairs = which( upper.tri( pair_counts, diag = TRUE ), arr.ind = TRUE )
  n_dubs = pair_counts[ pairs ]

  # Count BLD_TEC and TEC_BLD together: mixed pairs occur in two orders.
  is_mixed = pairs[, 1] != pairs[, 2]
  n_dubs[ is_mixed ] = 2 * n_dubs[ is_mixed ]

  # Alphabetize each pair of labels so the name does not depend on the order of `x`.
  pair_labels = vapply( seq_len( nrow( pairs ) ), 
                        function( k ) paste( sort( cell_types[ pairs[k, ] ] ), collapse = "_" ),
                        character( 1 ) )
  return( setNames( round( n_dubs, 3 ), nm = pair_labels ) )
}

#' Classify cells from one Seurat object in terms of another Seurat object's identity field, with a "reject option" for unfamiliar cells. 
#' 
#' @param dge_train Cells to train classifier on. Seurat object.
#' @param dge_test Cells to be classified. Seurat object.
#' @param ident.use Identity variable to use for training labels.
#' @param vars.all List of raw genes/features to use. If possible, will be accessed 
#' through `FetchData`; in this case, should be numeric. For others, zeroes are filled in.
#' If NULL, uses variable genes from both `dge_train` and `dge_test`.
#' @param my_transform NULL, character, or function. 
#' If `is.null(my_transform)`, then `my_transform` is the identity. 
#' If `my_transform` has the form "PCA_<integer>" (default "PCA_20"), then `my_transform` is an 
#' unscaled <integer>-dimensional PCA based on the training data. This option triggers special 
#' behavior for quantifying classifier badness, because NN will perform badly in a principal subspace.
#' If a function is given, `my_transform` should accept and return matrices where rows are cells.
#' @param badness Either "pc_dist" or "neighbor_dist" or `NULL`. If `NULL`, default depends on 
#' `my_transform`. You can't use "pc_dist" unless `my_transform` has the form "PCA_<integer>";
#' doing so raises an error.
#' @param k Number of nearest neighbors to use. Default 25. 
#' @param reject_prop Expected rate of false rejections you're willing to tolerate
#' on held-out training instances. Default is 0, meaning no cells are rejected. 
#' This is not honest if `my_transform` is chosen using the training data, and it cannot 
#' account for batch effects.
#' @return Seurat object identical to `dge_test` but with new/modified fields for 
#' - `classifier_ident` (predicted class) 
#' - `classifier_badness` (lower means higher confidence)
#' - `classifier_probs_<each identity class from trainset>` (predicted class probabilities)
#' @details Using k-nearest neighbors, classify cells from `dge_test` in terms of 
#' the options in `unique(FetchData(dge_train, ident.use))`, plus a reject option.
#' Rejection happens when the badness (usually distance to the nearest neighbors)
#' falls above a threshold (see `reject_prop`). Badness gets adjusted by cluster,
#' because some clusters naturally are less concentrated on the principal subspace
#' or the coordinates of interest.
#' @export
knn_classifier = function( dge_train, dge_test, ident.use = "ident", 
                           vars.all = NULL, my_transform = "PCA_20", badness = NULL,
                           k = 25, reject_prop = 0 ){

  if( is.null( vars.all ) ){ 
    vars.all = union( dge_train@var.genes, dge_test@var.genes ) 
  }

  # # Retrieve data, padding test data with zeroes as needed
  coords_train_orig = FetchDataZeroPad( dge_train, vars.all, warn = FALSE )
  coords_test_orig  = FetchDataZeroPad( dge_test,  vars.all, warn = FALSE )

  # # set the transform (and badness) 
  if( is.null( my_transform ) ){ my_transform = function(x) x }
  # Remember whether a PCA was requested: "pc_dist" needs the fitted PCA `w` below.
  transform_is_pca = is.character( my_transform )
  if( transform_is_pca ) { 
    pca_n = strsplit( my_transform, "_", fixed = TRUE )[[1]] 
    atae( "PCA",                  pca_n[1] )
    atat( !is.na( as.numeric( pca_n[2] ) ) )
    cat("Training transform: unscaled PCA of dimension", pca_n[2], " ...\n")
    w = irlba::prcomp_irlba( x = coords_train_orig, n = as.numeric( pca_n[2] ), 
                             retx = FALSE,
                             center = colMeans( coords_train_orig ), 
                             scale. = FALSE )
    # Center with the training means, then project onto the training principal axes.
    my_transform = function(x){
      sweep( x, 
             STATS = w$center, 
             FUN = "-",
             MARGIN = 2 ) %*% w$rotation 
    }
    if( is.null( badness ) ){ badness = "pc_dist" }
  } else {
    if( is.null( badness ) ){ badness = "neighbor_dist" }
  }
  # Fail fast instead of erroring on a missing PCA after the expensive neighbor search.
  if( badness == "pc_dist" && !transform_is_pca ){
    stop( "badness = \"pc_dist\" requires `my_transform` of the form \"PCA_<integer>\".", 
          call. = FALSE )
  }
  if ( badness == "pc_dist" ) {
    cat(" Using distance to principal subspace as badness. \n")
    # Badness = Euclidean norm of the residual after projecting onto the principal subspace.
    get_badness = function( nn_dists, x ){ 
      z = sweep( x, 
                 STATS = w$center, 
                 FUN = "-",
                 MARGIN = 2 ) 
      zproj = z %*% w$rotation %*% t( w$rotation )
      square = function( x ) x^2 
      ( z-zproj ) %>% apply( 1, square ) %>% apply( 2, sum ) %>% sapply( sqrt )
    }
  } else {
    cat(" Using average distance to neighbors as badness. \n")
    get_badness = function( nn_dists, x ) { rowMeans( nn_dists )}
  }

  cat("Transforming data...\n")
  coords_train_trans = my_transform( as.matrix( coords_train_orig ) ) %>% as.data.frame
  coords_test_trans  = my_transform( as.matrix( coords_test_orig ) ) %>% as.data.frame

  # # Find nearest neighbors & classify
  cat("Finding nearest neighbors...\n")
  nn_out = FNN::get.knnx( data = coords_train_trans, 
                          query = coords_test_trans,
                          k=k, algorithm=c( "cover_tree" ) )
  train_labels = FetchData( dge_train, ident.use )[[1]]
  get_label = function(idx) train_labels[idx]
  # Class probabilities are the label frequencies among the k neighbors.
  empir_prob = function(x) factor( x, levels = unique( train_labels ) ) %>% table %>% div_by_sum
  classifier_probs = apply( apply( nn_out$nn.index, 2, get_label ), 1, FUN = empir_prob ) %>% t %>% as.data.frame
  classifier_ident = apply( classifier_probs, 1, function(x) names( which.max( x ) ) )

  # # Get badness 
  cat("Calculating badness for each point...\n")
  # Neighbors of each training cell among the other training cells, used as a reference 
  # distribution of badness.
  nn_out_self = FNN::get.knn( data = coords_train_trans,
                              k=k, algorithm=c( "cover_tree" ) )
  classifier_prob_self = apply( apply( nn_out_self$nn.index, 2, get_label ), 
                                1, FUN = empir_prob ) %>% t %>% as.data.frame

  classifier_badness      = get_badness( nn_dists = nn_out$nn.dist,      x = as.matrix( coords_test_orig  ) )
  classifier_badness_self = get_badness( nn_dists = nn_out_self$nn.dist, x = as.matrix( coords_train_orig ) )

  # # Adjust badness via regression: some clusters are naturally less dense or farther from PC axes
  model_badness_by_cluster = lm( classifier_badness_self ~ . + 0, 
                                 data = cbind( classifier_badness_self, classifier_prob_self ) )  
  classifier_badness_self = classifier_badness_self - predict( object = model_badness_by_cluster ) 
  classifier_badness      = classifier_badness      - predict( object = model_badness_by_cluster, 
                                                               newdata = classifier_probs ) 
  # Scale the test badness first: the next line overwrites the vector whose sd is used.
  classifier_badness      = classifier_badness      / sd(classifier_badness_self)
  classifier_badness_self = classifier_badness_self / sd(classifier_badness_self)

  # # Set threshold and label rejects
  if( reject_prop > 0 ){ 
    cat("Labeling rejects with attempted controls on false rejection rate ... \n")
    threshold = quantile( classifier_badness_self, 1 - reject_prop )
    # Title and axis label go on the first histogram: they are ignored when add = TRUE.
    hist( classifier_badness_self,  breaks = 80, col = scales::alpha("blue", 0.5),
          main = "Held-out set badness and threshold",
          xlab = "Average distance to neighbors (train = blue, test = red)" )
    hist( classifier_badness,       breaks = 80, col = scales::alpha("red", 0.5),
          add = TRUE )
    abline( v = threshold )
    text( x = threshold*1.05, y = 100, labels = "Reject", srt = 90 )
    text( x = threshold*0.95, y = 100, labels = "Classify", srt = 90 )
    classifier_ident[classifier_badness >= threshold] = "reject"
  }

  # # Save data and return
  to_add = cbind(          classifier_ident,   
                           classifier_badness,            
                           classifier_probs )
  colnames( to_add ) = c( "classifier_ident", 
                          "classifier_badness" , 
                          "classifier_probs_" %>% paste0( colnames( classifier_probs ) ) )
  rownames( to_add ) = rownames( coords_test_orig )
  dge_test %<>% AddMetaData( to_add )

  cat("Done.\n")
  return( dge_test )
}

#' Train (and optionally save) a penalized multinomial logistic regression classifier.
#'
#' Class labels come from Seurat::FetchData(training_dge, vars.all = ident.use ).
#' Features are the rows of `training_dge@data` named in `genes.use`.
#'
#' @param training_dge Seurat object holding the training cells.
#' @param results_path When `do.save` is true, the fitted model and the training object 
#' are written into the subdirectory "classifier" of this folder.
#' @param ident.use Variable to fetch as class labels. Default "cell_type".
#' @param genes.use Genes to use as features; those absent from `training_dge@data` are dropped.
#' @param do.save Whether to save the model and training data. Default is not to save.
#' @return The fitted `glmnet::cv.glmnet` object.
#' @export
train_save_classifier = function(training_dge, results_path, ident.use = "cell_type", genes.use, do.save = F ){

  # # Class labels, one per training cell
  label_df = Seurat::FetchData( training_dge, vars.all = ident.use )
  training_labels = factor( vectorize_preserving_rownames( label_df ) )

  # # Feature matrix: restrict to available genes, then put cells on rows
  expression = training_dge@data
  genes.use = intersect( genes.use, rownames( expression ) )
  features_tf = t( expression[ genes.use, ] )

  # # Fit the classifier.
  # # alpha = 0 is the L2 (ridge) penalty; alpha = 1 would be LASSO; values in between, elastic net.
  print("Training classifier...")
  mlr_mod = glmnet::cv.glmnet( x = features_tf, 
                               y = training_labels, 
                               family = "multinomial", 
                               alpha = 0 )
  print("... classifier trained.")

  # # Optionally write model and training data to disk
  if( do.save ){
    classifier_dir = file.path( results_path, "classifier" )
    dir.create.nice( classifier_dir )
    print("Saving classifier...")
    saveRDS( mlr_mod,      file.path( classifier_dir, "glmnet_object.data" ) )
    saveRDS( training_dge, file.path( classifier_dir, "training_data.data" ) )
  }
  return( mlr_mod )
}

#' Extract coefficients from a glmnet multiclass logistic regression model.
#' 
#' @param mlr_mod Fitted model whose `coef` method, called with `s = "lambda.min"`, returns 
#' a named list with one single-column coefficient matrix per class 
#' (e.g. the output of `train_save_classifier`).
#' @return A data frame with one row per model term (row names made unique) and 
#' one column per class.
#' 
get_classifier_coefs = function( mlr_mod ){
  # Extract the coefficients once. No `newx` is needed (or defined here) for coefficients.
  coef_by_class = coefficients( mlr_mod, s = "lambda.min" )
  cell_types = names( coef_by_class )
  genes = rownames( coef_by_class[[1]] )

  coeffs = data.frame( matrix( NA, nrow = length(genes), ncol = length( cell_types ) ) )
  # Assign names directly so class labels are kept verbatim, even if non-syntactic.
  colnames( coeffs ) = cell_types
  rownames( coeffs ) = make.unique( genes )
  for( cell_type in cell_types ){
    coeffs[[cell_type]] = as.vector( coef_by_class[[cell_type]] )
  }
  return( coeffs )
}

#' Apply a machine learning classifier (from `train_save_classifier`) to new data
#' 
#' @param dge a Seurat object 
#' @param mlr_mod a glmnet multiclass logistic regression model.
#' You can feed the output of `train_save_classifier` into the `mlr_mod` argument.
#' @return The `type = "response"` predictions at `s = "lambda.min"`, with the trailing 
#' array dimension dropped (presumably a cells-by-classes matrix of probabilities; 
#' confirm against callers).
#' @export
classify_mlr = function( dge, mlr_mod ){
  # Recover the model's term names from the first class's coefficient matrix.
  first_class_coefs = down_idx( coef( mlr_mod ) )
  model_terms = down_idx( attributes( first_class_coefs )$Dimnames )
  genes_used = setdiff( model_terms, "(Intercept)" )

  # Genes missing from `dge` are zero-padded by the helper.
  # NOTE(review): the result is transposed before prediction; glmnet expects observations 
  # on rows of `newx`, so confirm the orientation returned by `FetchDataZeroPad`.
  features = t( FetchDataZeroPad( dge, vars.all = genes_used ) )
  predict( mlr_mod, newx = features, s = "lambda.min", type = "response" )[, , 1]
}

#' Visualize probabilistic classification results
#'
#' Each class is placed at a spoke on the unit circle, and each cell is plotted at the 
#' probability-weighted average of the spoke positions.
#'
#' @param dge a Seurat object
#' @param results_path folder to place output in
#' @param mlr_mod a glmnet multiclass logistic regression model. 
#' You can feed the output of `train_save_classifier` into the `mlr_mod` argument.
#' @param class_labels atomic character vector. If `mlr_mod` is not given, it names columns in 
#'  the Seurat object that contain class probabilities. If `mlr_mod` is given too, predictions 
#'  come from the model and `class_labels` only selects and orders the spokes; if it does not 
#'  match the model's classes, the model's own order is used. 
#'  If names(class_labels) is filled in, then that's how the spokes get labeled.
#' @param fig_name filename for output (".pdf" is appended); also used as the plot title.
#' @param facet_by Variable in Seurat object to facet resulting plots by; default is none
#' @param colour_by Variable in Seurat object to map color to; default is none
#' @param style If "points", plot each cell as a dot. 
#' If "density", then instead of plotting points, plot 2d density contours.
#' If "hexagons", use hexagonal binning with fill mapped to (log) bin count.
#' Any other value is treated like "points".
#' @param wheel_order Deprecated. Use `class_labels`.
#' @param do.density Deprecated. Use `style = "density"`.
#' @return The ggplot object, which is also saved as a PDF in `results_path`.
#' @export
wheel_plot = function( dge, results_path, mlr_mod = NULL, class_labels = NULL, fig_name, 
                       colour_by = NULL, facet_by = NULL, style = "points",
                       wheel_order = NULL, do.density = NULL ){
  # # Handle deprecated input
  if(!is.null(wheel_order)){
    warning( "`wheel_order` arg is deprecated. Use `class_labels` instead." )
    if( is.null( class_labels ) ){ class_labels = wheel_order} 
  }

  if( !is.null(do.density) && do.density){ 
    warning("do.density is deprecated. Use ` style = \"density\" ` instead.")
    style = "density"
  }

  # # Handle current input
  mlr_mod_given = !is.null( mlr_mod )
  labels_given = !is.null( class_labels )
  if( !mlr_mod_given && !labels_given ){
    stop("Please specify either `mlr_mod` or `class_labels`")
  } else if( !mlr_mod_given && labels_given ) {
    cat( "Fetching stored predictions from Seurat object.\n" )
    predictions = FetchData( dge, vars.all = class_labels ) %>% as.matrix
  } else if( mlr_mod_given && !labels_given ) {
    cat( "Making predictions from glmnet object.\n" )
    predictions = classify_mlr( dge, mlr_mod )
    class_labels = colnames( predictions )
  } else {
    cat( "Making predictions from glmnet object.\n" )
    warning( "Since `mlr_mod` was given, using `class_labels` 
             only to order wheel spokes, not to fetch stored predictions." )
    predictions = classify_mlr( dge, mlr_mod )
    if( all( class_labels %in% colnames( predictions ) ) ){
      # Honor the requested spoke order (and any names(class_labels) used as spoke labels).
      predictions = predictions[, class_labels, drop = FALSE]
    } else {
      warning( "Some `class_labels` are not classes of `mlr_mod`; using the model's class order instead." )
      class_labels = colnames( predictions )
    }
  }

  # # Make wheel: one spoke per class, evenly spaced around the unit circle
  lp1 = length( class_labels ) + 1 
  if( is.null( names( class_labels ) ) ){ names( class_labels ) = class_labels }
  unwrapped_circle = seq( 0, 2*pi, length.out = lp1 )
  # The last angle (2*pi) duplicates the first (0), so it is dropped.
  wheel_spokes = data.frame( z = names( class_labels ), 
                             x = cos( unwrapped_circle[-lp1] ), 
                             y = sin( unwrapped_circle[-lp1] ) )
  # # Process data
  cat("Processing data.\n")
  # Each cell sits at the probability-weighted average of the spoke coordinates.
  cell_positions = predictions %*% as.matrix( wheel_spokes[, c("x", "y")] )
  cell_positions = as.data.frame( cell_positions );   colnames( cell_positions ) = c( "x", "y" )
  if( !is.null( colour_by ) ){
      cell_positions[[colour_by]] = Seurat::FetchData(dge, vars.all = colour_by)[ rownames( predictions ),  ] 
  }
  if( !is.null( facet_by ) ){
    cell_positions[[facet_by]] = Seurat::FetchData(dge, vars.all = facet_by )[ rownames( predictions ),  ] 
  }

  # # Add wheel & label spokes
  wheel_plot = ggplot()   + geom_path ( data = wheel_spokes,   mapping = aes(     x,     y ) )
  wheel_plot = wheel_plot + geom_text ( data = wheel_spokes,   mapping = aes( 1.2*x, 1.2*y, label = z ) )
  # # Add data & facet
  if( style == "density"){
    cat( "Estimating density.\n" )
    wheel_plot = wheel_plot + geom_density_2d( data = cell_positions, 
                                               mapping = aes_string( "x", "y", colour = colour_by ) )
  } else if( style == "hexagons"){
    if( !is.null( colour_by ) ) {
      warning( "Won't use color with style==\"hexagons\" because fill is mapped to bin count. " )
    }
    wheel_plot = wheel_plot + geom_hex( data = cell_positions, mapping = aes_string( "x", "y" ) ) 
    wheel_plot = wheel_plot +  scale_fill_gradient(trans = "log")
  } else {
    wheel_plot = wheel_plot + geom_point( data = cell_positions, 
                                          mapping = aes_string( "x", "y", colour = colour_by ), alpha = 0.5 ) 
  }
  wheel_plot = wheel_plot + ggtitle( fig_name )
  # A continuous colour variable gets the blue-gray-red gradient.
  if( !is.null( colour_by )  && is.numeric( cell_positions[[colour_by]] ) ){
    wheel_plot = wheel_plot + scale_color_gradientn( colours = blue_gray_red )
  }
  if ( !is.null( facet_by ) ){
      wheel_plot = wheel_plot + facet_wrap( as.formula( paste( "~", facet_by ) ) )  
  }

  # # Save & return
  cat( "Done. Saving and returning plot.\n" )
  ggsave( filename = file.path( results_path, paste0( fig_name, ".pdf" ) ),
          plot = wheel_plot,
          width = 12, height = 10)
  return( wheel_plot )
}


# maehrlab/thymusatlastools documentation built on May 28, 2019, 2:32 a.m.