R/ml.R

Defines functions ml_tax_HFE .HFE

Documented in .HFE ml_tax_HFE

#' Supporting function for ml_tax_HFE(): hierarchical feature selection for one clade
#'
#' Aggregates abundances at every taxonomic level from Species up to
#' `class_level` (Phylum is the highest level aggregated, so 'Domain' yields
#' the same levels as 'Phylum'), optionally drops near-zero-variance features
#' (caret::nearZeroVar), then drops features that correlate strongly with
#' another feature (caret::findCorrelation).
#'
#' @param brk long table with columns Sample, Abundance, Phylum, Class, Order,
#'   Family, Genus, Species (e.g. one clade split from read_bracken() output)
#' @param class_level taxonomic level, captured with enexpr() (bare name or
#'   string): one of Species, Genus, Family, Order, Class, Phylum, Domain
#' @param corr_cutoff correlation cutoff passed to caret::findCorrelation
#' @param freqCut passed to caret::nearZeroVar; NULL skips that filter
#' @param uniqueCut passed to caret::nearZeroVar; NULL skips that filter
#' @param quiet if FALSE, print how many taxa were removed for this clade
#' @importFrom data.table rbindlist
#' @import tidytable
#' @return data.table with columns Taxonomy, Sample, Abundance holding only the
#'   retained features (zero rows if every feature has near-zero variance)
.HFE = function(brk, class_level, corr_cutoff=0.5, freqCut = 99/1, uniqueCut = 1, quiet=TRUE){
  tax_levs = c('Species', 'Genus', 'Family', 'Order', 'Class', 'Phylum', 'Domain')
  class_level = as.character(enexpr(class_level))
  # position of class_level in tax_levs; every level at or below it is aggregated
  lev_idx = match(class_level, tax_levs)
  if(length(lev_idx) != 1 || is.na(lev_idx)){
    stop('class_level must be one of: ', paste(tax_levs, collapse=', '), call. = FALSE)
  }
  # label for the optional progress message: the clade value(s) in this split
  clade = if(class_level %in% colnames(brk)){
    paste(unique(brk[[class_level]]), collapse=',')
  } else {
    class_level
  }

  # sum abundances per (Taxonomy, Sample)
  sum_abund = function(df){
    summarize.(df, Abundance = sum(Abundance), .by = c(Taxonomy, Sample))
  }
  # long (Taxonomy, Sample, Abundance) -> data.frame with one column per
  # taxon and sample names as rownames
  to_wide = function(df){
    wide = df %>%
      pivot_wider.(names_from=Taxonomy, values_from=Abundance) %>%
      as.data.frame()
    rownames(wide) = wide$Sample
    wide$Sample = NULL
    wide
  }

  # creating aggregated features (Species rows are kept without aggregation)
  dt_levs = list()
  dt_levs[['Species']] = brk %>%
    mutate.(Taxonomy = paste(Phylum, Class, Order, Family, Genus, Species, sep=';')) %>%
    select.(Taxonomy, Sample, Abundance)
  if(lev_idx >= 2){
    dt_levs[['Genus']] = brk %>%
      mutate.(Taxonomy = paste(Phylum, Class, Order, Family, Genus, sep=';')) %>%
      sum_abund()
  }
  if(lev_idx >= 3){
    dt_levs[['Family']] = brk %>%
      mutate.(Taxonomy = paste(Phylum, Class, Order, Family, sep=';')) %>%
      sum_abund()
  }
  if(lev_idx >= 4){
    dt_levs[['Order']] = brk %>%
      mutate.(Taxonomy = paste(Phylum, Class, Order, sep=';')) %>%
      sum_abund()
  }
  if(lev_idx >= 5){
    dt_levs[['Class']] = brk %>%
      mutate.(Taxonomy = paste(Phylum, Class, sep=';')) %>%
      sum_abund()
  }
  if(lev_idx >= 6){
    dt_levs[['Phylum']] = brk %>%
      rename.('Taxonomy' = Phylum) %>%
      sum_abund()
  }

  # combining & creating wide table
  brk = data.table::rbindlist(dt_levs, use.names=TRUE)
  rm(dt_levs)
  brk_w = to_wide(brk)

  # removing near-zero-variance features
  if(!is.null(freqCut) && !is.null(uniqueCut)){
    to_rm = caret::nearZeroVar(brk_w, freqCut=freqCut, uniqueCut=uniqueCut, names=TRUE)
    if(length(to_rm) >= ncol(brk_w)){
      return(brk %>% filter.(FALSE))
    }
    brk = brk %>%
      filter.(! Taxonomy %in% !!to_rm)
    # rebuilt from the filtered long table (note: this also drops samples
    # that have no remaining taxa)
    brk_w = to_wide(brk)
  }

  # correlation among features; warnings from cor() (e.g. zero-variance
  # columns) are suppressed locally instead of toggling the global 'warn'
  # option, which previously reset any user setting to 0
  cor_mat = suppressWarnings(as.matrix(cor(brk_w)))
  to_rm = character(0)
  if(ncol(cor_mat) > 1){
    cor_mat[is.na(cor_mat)] = 0
    to_rm = caret::findCorrelation(cor_mat, cutoff=corr_cutoff, names=TRUE)
  }
  if(!quiet){
    cat('Clade: ', clade, sep = '')
    # 'Sample' was already removed from brk_w, so ncol() is the taxa count
    cat('; Removing', length(to_rm), 'of', ncol(brk_w), 'taxa\n')
  }
  if(length(to_rm) > 0){
    brk = brk %>%
      filter.(! Taxonomy %in% !!to_rm)
  }
  brk
}

#' Hierarchical Feature Selection
#'
#' For each clade (defined by tax_level), aggregate species abundances at each
#' taxonomic level up to the user-defined "tax_level", optionally filter out
#' near-zero-variance features, then filter out strongly correlated taxa
#' (only one taxon is kept from each group of correlated taxa).
#'
#' @param brk data.table generated by read_bracken(). Columns: Sample, Abundance, Phylum=>Species
#' @param tax_level which taxonomic level to use (bare column name, e.g. Genus)?
#'   If Species, no selection is done; the species table is just pivoted to wide format.
#' @param corr_cutoff correlation cutoff for caret::findCorrelation; of features
#'   correlating above it, just one is kept
#' @param threads number of parallel workers (via doParallel); 1 runs sequentially
#' @param freqCut as in caret::nearZeroVar; use NULL to skip
#' @param uniqueCut as in caret::nearZeroVar; use NULL to skip
#' @param quiet if FALSE, print per-clade counts of removed taxa
#' @return wide data.table of filtered features (a Sample column plus one
#'   column per retained taxon)
#' @export
#' @import tidytable
#' @importFrom data.table rbindlist
ml_tax_HFE = function(brk, tax_level, corr_cutoff=0.7, threads=2, freqCut = 95/1, uniqueCut = 5, quiet=TRUE){
  tax_level = rlang::enexpr(tax_level)
  # Species level: nothing to aggregate or select; just pivot to wide format
  if(as.character(tax_level) == 'Species'){
    brk = brk %>%
      mutate.(Taxonomy = paste(Phylum, Class, Order, Family, Genus, Species, sep=';')) %>%
      select.(Taxonomy, Abundance, Sample) %>%
      pivot_wider.(names_from=Taxonomy, values_from=Abundance)
    return(brk)
  }
  # caret & plyr are only used via `::`; verify they are installed instead of
  # attaching them with require(), which silently returns FALSE when missing
  for(pkg in c('caret', 'plyr')){
    if(!requireNamespace(pkg, quietly=TRUE)){
      stop('Package "', pkg, '" is required by ml_tax_HFE()', call. = FALSE)
    }
  }
  if(threads > 1){
    doParallel::registerDoParallel(threads)
    # stop any implicit cluster doParallel created and fall back to the
    # sequential foreach backend so no worker processes are leaked
    on.exit({
      doParallel::stopImplicitCluster()
      foreach::registerDoSEQ()
    }, add = TRUE)
  }
  # one .HFE() run per clade, then recombine & pivot to wide format
  brk %>%
    group_split.(!!tax_level) %>%
    plyr::llply(.HFE, class_level=!!tax_level,
                corr_cutoff=corr_cutoff, quiet=quiet,
                freqCut=freqCut, uniqueCut=uniqueCut,
                .parallel=threads > 1) %>%
    data.table::rbindlist(use.names=TRUE) %>%
    pivot_wider.(names_from=Taxonomy, values_from=Abundance)
}
leylabmpi/LeyLabRMisc documentation built on Nov. 3, 2022, 3:45 p.m.