R/generator_lm.R

Defines functions generator_fasta_lm

Documented in generator_fasta_lm

#' Language model generator for fasta/fastq files
#'
#' @description Iterates over folder containing fasta/fastq files and produces encoding of predictor sequences
#' and target variables. Will take a sequence of fixed size and use some part of sequence as input and other part as target. 
#'
#' @inheritParams train_model
#' @param path_corpus Input directory where fasta files are located or path to single file ending with fasta or fastq
#' (as specified in format argument). Can also be a list of directories and/or files.
#' @param format File format, either `"fasta"` or `"fastq"`.
#' @param batch_size Number of samples in one batch.
#' @param maxlen Length of predictor sequence.
#' @param max_iter Stop after `max_iter` iterations have failed to produce a new batch.
#' @param shuffle_file_order Logical, whether to go through files randomly or sequentially.
#' @param step How often to take a sample.
#' @param seed Sets seed for `set.seed` function for reproducible results.
#' @param shuffle_input Whether to shuffle entries in every fasta/fastq file before extracting samples.
#' @param verbose Whether to show messages.
#' @param path_file_log Write name of files to csv file if path is specified.
#' @param reverse_complement Boolean, for every new file decide randomly to use original data or its reverse complement.
#' @param ambiguous_nuc How to handle nucleotides outside vocabulary, either `"zero"`, `"discard"`, `"empirical"` or `"equal"`.
#' \itemize{
#' \item If `"zero"`, input gets encoded as zero vector.
#' \item If `"equal"`, input is repetition of `1/length(vocabulary)`.
#' \item If `"discard"`, samples containing nucleotides outside vocabulary get discarded.
#' \item If `"empirical"`, use nucleotide distribution of current file.
#' }
#' @param proportion_per_seq Numerical value between 0 and 1. Proportion of sequence to take samples from (use random subsequence).
#' @param use_quality_score Whether to use fastq quality scores. If TRUE input is not one-hot-encoding but corresponds to probabilities.
#' For example (0.97, 0.01, 0.01, 0.01) instead of (1, 0, 0, 0).
#' @param padding Whether to pad sequences too short for one sample with zeros.
#' @param added_label_path Path to file with additional input labels. Should be a csv file with one column named "file". Other columns should correspond to labels.
#' @param add_input_as_seq Boolean vector specifying for each entry in \code{added_label_path} if rows from csv should be encoded as a sequence or used directly.
#' If a row in your csv file is a sequence this should be `TRUE`. For example you may want to add another sequence, say ACCGT. Then this would correspond to 1,2,2,3,4 in
#' csv file (if vocabulary = c("A", "C", "G", "T")).  If \code{add_input_as_seq} is `TRUE`, 12234 gets one-hot encoded, so added input is a 3D tensor.  If \code{add_input_as_seq} is
#' `FALSE` this will feed network just raw data (a 2D tensor).
#' @param skip_amb_nuc Threshold of ambiguous nucleotides to accept in fasta entry. Complete entry will get discarded otherwise.
#' @param max_samples Maximum number of samples to use from one file. If not `NULL` and file has more than \code{max_samples} samples, will randomly choose a
#' subset of \code{max_samples} samples.
#' @param concat_seq Character string or `NULL`. If not `NULL` all entries from file get concatenated to one sequence with `concat_seq` string between them.
#' Example: If 1.entry AACC, 2. entry TTTG and `concat_seq = "ZZZ"` this becomes AACCZZZTTTG.
#' @param target_len Number of nucleotides to predict at once for language model.
#' @param file_filter Vector of file names to use from path_corpus.
#' @param use_coverage Integer or `NULL`. If not `NULL`, use coverage as encoding rather than one-hot encoding and normalize.
#' Coverage information must be contained in fasta header: there must be a string `"cov_n"` in the header, where `n` is some integer.
#' @param proportion_entries Proportion of fasta entries to keep. For example, if fasta file has 50 entries and `proportion_entries = 0.1`,
#' will randomly select 5 entries.
#' @param sample_by_file_size Sample new file weighted by file size (bigger files more likely).
#' @param n_gram Integer, encode target not nucleotide wise but combine n nucleotides at once. For example for `n=2, "AA" ->  (1, 0,..., 0),`
#' `"AC" ->  (0, 1, 0,..., 0), "TT" -> (0,..., 0, 1)`, where the one-hot vectors have length `length(vocabulary)^n`.
#' @param add_noise `NULL` or list of arguments. If not `NULL`, list must contain the following arguments: \code{noise_type} can be `"normal"` or `"uniform"`;
#' optional arguments `sd` or `mean` if noise_type is `"normal"` (default is `sd=1` and `mean=0`) or `min, max` if `noise_type` is `"uniform"`
#' (default is `min=0, max=1`).
#' @param return_int Whether to return integer encoding or one-hot encoding.
#' @param reshape_xy Can be a list of functions to apply to input and/or target. List elements (containing the reshape functions)
#'  must be called x for input or y for target and each have arguments called x and y. For example: 
#'  `reshape_xy = list(x = function(x, y) {return(x+1)}, y = function(x, y) {return(x+y)})` .
#' For the rds generator, each reshape function needs an additional argument called sw.
#' @rawNamespace import(data.table, except = c(first, last, between))
#' @importFrom magrittr %>%
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # create dummy fasta files
#' path_input_1 <- tempfile()
#' dir.create(path_input_1)
#' create_dummy_data(file_path = path_input_1,
#'                   num_files = 2,
#'                   seq_length = 8,
#'                   num_seq = 1,
#'                   vocabulary = c("a", "c", "g", "t"))
#' 
#' gen <- generator_fasta_lm(path_corpus = path_input_1, batch_size = 2,
#'                                    maxlen = 7)
#' z <- gen()
#' dim(z[[1]])
#' z[[2]]
#' 
#' @returns A generator function.  
#' @export
generator_fasta_lm <- function(path_corpus,
                               format = "fasta",
                               batch_size = 256,
                               maxlen = 250,
                               max_iter = 10000,
                               vocabulary = c("a", "c", "g", "t"),
                               verbose = FALSE,
                               shuffle_file_order = FALSE,
                               step = 1,
                               seed = 1234,
                               shuffle_input = FALSE,
                               file_limit = NULL,
                               path_file_log = NULL,
                               reverse_complement = FALSE,
                               output_format = "target_right",
                               ambiguous_nuc = "zeros",
                               use_quality_score = FALSE,
                               proportion_per_seq = NULL,
                               padding = TRUE,
                               added_label_path = NULL,
                               add_input_as_seq = NULL,
                               skip_amb_nuc = NULL,
                               max_samples = NULL,
                               concat_seq = NULL,
                               target_len = 1,
                               file_filter = NULL,
                               use_coverage = NULL,
                               proportion_entries = NULL,
                               sample_by_file_size = FALSE,
                               n_gram = NULL,
                               n_gram_stride = 1,
                               add_noise = NULL,
                               return_int = FALSE,
                               reshape_xy = NULL) {
  
  # NOTE(review): the default ambiguous_nuc = "zeros" differs from the
  # documented option "zero". The value is forwarded unchanged below; confirm
  # which spelling the downstream generator expects.
  
  ## TODO: add option for n_gram_stride values other than n_gram.
  # With n-gram encoding only non-overlapping n-grams are supported for now.
  if (!is.null(n_gram) && n_gram_stride != n_gram) {
    stop("When using train_type='lm' with n_gram encoding, n_gram_stride must be equal to n_gram.")
  }
  
  # Work out which reshape functions were supplied and validate their
  # signatures. The check has to look at the *names* of the formals;
  # formals() itself is a pairlist of default values.
  reshape_xy_bool <- !is.null(reshape_xy)
  reshape_x_bool <- FALSE
  reshape_y_bool <- FALSE
  if (reshape_xy_bool) {
    reshape_x_bool <- !is.null(reshape_xy$x)
    if (reshape_x_bool && !all(c("x", "y") %in% names(formals(reshape_xy$x)))) {
      stop("function reshape_xy$x needs to have arguments named x and y")
    }
    reshape_y_bool <- !is.null(reshape_xy$y)
    if (reshape_y_bool && !all(c("x", "y") %in% names(formals(reshape_xy$y)))) {
      stop("function reshape_xy$y needs to have arguments named x and y")
    }
  }
  
  # The underlying generator is asked for windows of length
  # maxlen + target_len; each window is split into input and target later.
  total_seq_len <- maxlen + target_len
  gen <- generator_fasta_label_folder(path_corpus = path_corpus,
                                      format = format,
                                      batch_size = batch_size,
                                      maxlen = total_seq_len,
                                      max_iter = max_iter,
                                      vocabulary = vocabulary,
                                      shuffle_file_order = shuffle_file_order,
                                      step = step,
                                      seed = seed,
                                      shuffle_input = shuffle_input,
                                      file_limit = file_limit,
                                      path_file_log = path_file_log,
                                      reverse_complement = reverse_complement,
                                      reverse_complement_encoding = FALSE,
                                      num_targets = 1,
                                      ones_column = 1,
                                      ambiguous_nuc = ambiguous_nuc,
                                      proportion_per_seq = proportion_per_seq,
                                      read_data = FALSE,
                                      use_quality_score = use_quality_score,
                                      padding = padding,
                                      added_label_path = added_label_path,
                                      add_input_as_seq = add_input_as_seq,
                                      skip_amb_nuc = skip_amb_nuc,
                                      max_samples = max_samples,
                                      concat_seq = concat_seq,
                                      file_filter = file_filter,
                                      use_coverage = use_coverage,
                                      proportion_entries = proportion_entries,
                                      sample_by_file_size = sample_by_file_size,
                                      n_gram = n_gram,
                                      n_gram_stride = n_gram_stride,
                                      masked_lm = NULL,
                                      add_noise = add_noise,
                                      return_int = return_int)
  
  # Returned generator: every call draws one batch and splits it into
  # input/target.
  function() {
    
    batch <- gen()[[1]]
    if (is.null(added_label_path)) {
      xy <- batch
    } else {
      # With added labels the first element is a list whose last entry is the
      # sequence tensor; everything before it is additional input.
      # Negative indexing stays correct even if there is only one entry.
      added_input <- batch[-length(batch)]
      xy <- batch[[length(batch)]]
    }
    
    xy_list <- slice_tensor_lm(xy = xy,
                               output_format = output_format,
                               target_len = target_len,
                               n_gram = n_gram,
                               n_gram_stride = n_gram_stride,
                               total_seq_len = total_seq_len,
                               return_int = return_int)
    
    # Keep input and target in locals so that the reshape step below still
    # sees them after the (unnamed) list for added labels has been built.
    x <- xy_list$x
    y <- xy_list$y
    
    if (!is.null(added_label_path)) {
      x <- append(added_input, list(x))
      xy_list <- list(x, y)
    }
    
    if (reshape_xy_bool) {
      xy_list <- f_reshape(x = x, y = y,
                           reshape_xy = reshape_xy,
                           reshape_x_bool = reshape_x_bool,
                           reshape_y_bool = reshape_y_bool,
                           reshape_sw_bool = FALSE, sw = NULL)
    }
    
    xy_list
  }
}
GenomeNet/deepG documentation built on Dec. 24, 2024, 12:11 p.m.