R/predict.R

Defines functions evaluateFasta replaceChar predictNextNucleotide

Documented in evaluateFasta predictNextNucleotide replaceChar

#' Predict next nucleotides in sequence
#'
#' Uses the rightmost \code{model$input_shape[[2]]} characters of
#' \code{sequence} as model input and picks the most probable next character.
#'
#' @param sequence input sequence, length should be in sync with the model.
#' If length exceeds input.shape of model then only the right side of the
#' sequence will be used.
#' @param model trained model from the function \code{trainNetwork()}
#' @param vocabulary vocabulary of input sequence; its order defines the
#' one-hot columns and the mapping from output index to character.
#' @param verbose TRUE/FALSE
#' @return An S4 object of class \code{prediction} with slots
#'   \code{next_char} (the predicted character), \code{probability} (its
#'   predicted probability), \code{index} (its position in \code{vocabulary}),
#'   \code{alternative_probability} (the full output of
#'   \code{keras::predict_proba}) and \code{solution} (\code{sequence} with
#'   \code{next_char} appended).
#' @examples
#' \dontrun{
#' example.model <- keras::load_model_hdf5("example_model.hdf5")
#' sequence <- strrep("A", 100)
#' predictNextNucleotide(sequence, example.model)}
#' @export
predictNextNucleotide <- function(sequence,
                                  model,
                                  vocabulary = c("l", "a", "c", "g", "t"),
                                  verbose = FALSE) {
  stopifnot(!missing(sequence))
  stopifnot(!missing(model))

  # Second entry of the model's input shape; NOTE(review): presumably the
  # sequence length the model was trained with.
  maxlen <- as.numeric(model$input_shape[[2]])
  stopifnot(nchar(sequence) >= maxlen)

  # The sequence can be longer than the model input: keep only the last
  # `maxlen` characters.
  seed <- substr(sequence, nchar(sequence) - maxlen + 1, nchar(sequence))
  sentence <- tokenizers::tokenize_characters(
    stringr::str_to_lower(seed),
    strip_non_alphanum = FALSE, simplify = TRUE)

  # One-hot encoding: one row per character of `sentence`, one column per
  # vocabulary entry. `outer()` always yields a matrix, whereas `sapply()`
  # collapsed to a plain vector when `sentence` had a single character.
  x <- outer(sentence, vocabulary, FUN = "==") + 0
  x <- keras::array_reshape(x, c(1, dim(x)))

  if (verbose) {
    message("Prediction ...")
  }

  preds <- keras::predict_proba(model, x)
  next_index <- which.max(preds)
  next_char <- vocabulary[next_index]

  new("prediction",
      next_char = next_char,
      probability = preds[next_index],
      index = next_index,
      alternative_probability = preds,
      solution = paste0(sequence, next_char))
}


#' Replaces specific nucleotides in a sequence
#'
#' Repeatedly finds the leftmost occurrence of \code{char} in \code{sequence}
#' and replaces it with the character that \code{predictNextNucleotide()}
#' predicts from the \code{model$input_shape[[2]]} characters directly to its
#' left. Earlier replacements therefore become context for later ones.
#'
#' @param sequence input sequence, length should be in sync with the model.
#' If length exceeds input.shape of model then only the right side of the
#' sequence will be used.
#' @param model trained model from the function \code{trainNetwork()}
#' @param char character in the sequence that will be replaced; matched
#' literally (not as a regular expression) and must be non-empty.
#' @param vocabulary ordered vocabulary of input sequence
#' @return \code{sequence} with every occurrence of \code{char} replaced by a
#' predicted character. An error is raised if an occurrence has fewer than
#' \code{model$input_shape[[2]]} characters before it.
#' @examples
#' \dontrun{
#' example.model <- keras::load_model_hdf5("example_model.hdf5")
#' sequence <- paste0(strrep("a", 100), "X", "cgt")
#' replaceChar(sequence = sequence, model = example.model)}
#' @export
replaceChar <- function(sequence,
                        model,
                        char = "X",
                        vocabulary = c("l", "a", "c", "g", "t")) {
  stopifnot(!missing(sequence))
  stopifnot(!missing(model))
  # An empty pattern would match forever and never be consumed.
  stopifnot(is.character(char), length(char) == 1, nchar(char) > 0)

  maxlen <- as.numeric(model$input_shape[[2]])
  stopifnot(nchar(sequence) >= maxlen)

  repeat {
    # Leftmost literal occurrence of `char`; regexpr() returns -1 if none.
    hit <- regexpr(char, sequence, fixed = TRUE)
    start <- as.integer(hit)
    if (start == -1L) break
    end <- start + attr(hit, "match.length") - 1L

    if (start - maxlen < 1) {
      stop("Cannot replace '", char, "' at position ", start,
           ": fewer than ", maxlen, " characters precede it.", call. = FALSE)
    }

    # Seed for the model: exactly the `maxlen` characters left of the match.
    seed <- substr(sequence, start - maxlen, start - 1)
    prediction <- predictNextNucleotide(seed, model, vocabulary = vocabulary)

    # Splice the prediction in place of the match, keeping the full prefix
    # (not just the seed window) and the suffix.
    sequence <- paste0(substr(sequence, 1, start - 1),
                       prediction@next_char,
                       substr(sequence, end + 1, nchar(sequence)))
  }
  sequence
}

#' Evaluates a trained model on .fasta/fastq files
#'
#' Returns accuracies per batch and overall confusion matrix. Evaluates
#' \code{batch.size} * \code{numberOfBatches} samples.
#'
#' @param fasta.path Input directory where fasta/fastq files are located.
#' @param model.path Path to pretrained model.
#' @param batch.size Number of samples per batch.
#' @param step How often to take a sample.
#' @param seqStart Inserts character at beginning of sequence from one file.
#' @param seqEnd Insert character at end of sequence from one file.
#' @param withinFile Insert characters between fasta entries.
#' @param vocabulary Vector of allowed characters, character outside vocabulary get encoded as 0-vector.
#' @param label_vocabulary Labels for targets. Equal to \code{vocabulary} if
#'   not given or \code{NULL}.
#' @param numberOfBatches How many batches to evaluate (at least 1).
#' @param filePath Where to store output, if missing output won't be written.
#' @param format File format, "fasta" or "fastq".
#' @param filename Name of output file.
#' @param plot Returns density plot of accuracies if TRUE.
#' @param mode Either "lm" for language model and "label_header" or "label_folder" for label classification.
#' @return A list with elements \code{acc} (numeric vector holding the
#'   accuracy of each batch) and \code{confMat} (confusion matrix accumulated
#'   over all batches, one row and one column per entry of
#'   \code{label_vocabulary}).
#' @export
evaluateFasta <- function(fasta.path,
                          model.path,
                          batch.size = 100,
                          step = 1,
                          vocabulary = c("a", "c", "g", "t"),
                          label_vocabulary = vocabulary,
                          numberOfBatches = 10,
                          filePath = NULL,
                          format = "fasta",
                          filename = "",
                          plot = TRUE,
                          mode = "lm") {
  stopifnot(mode %in% c("lm", "label_header", "label_folder"))
  stopifnot(numberOfBatches >= 1)

  # The default already falls back to `vocabulary`; also honour explicit NULL.
  if (is.null(label_vocabulary)) label_vocabulary <- vocabulary

  model <- keras::load_model_hdf5(model.path)
  # NOTE(review): indexes the model's input shape; presumably yields the
  # sequence length — confirm index base with the keras/reticulate version.
  maxlen <- model$input$shape[1]

  if (mode == "lm") {
    gen <- fastaFileGenerator(corpus.dir = fasta.path,
                              format = format,
                              batch.size = batch.size,
                              maxlen = maxlen,
                              max_iter = 100,
                              vocabulary = vocabulary,
                              verbose = FALSE,
                              randomFiles = FALSE,
                              step = step,
                              showWarnings = FALSE,
                              shuffleFastaEntries = FALSE,
                              reverseComplements = FALSE)
  }
  if (mode == "label_header") {
    gen <- fastaLabelGenerator(corpus.dir = fasta.path,
                               format = format,
                               batch.size = batch.size,
                               maxlen = maxlen,
                               max_iter = 100,
                               vocabulary = vocabulary,
                               verbose = FALSE,
                               randomFiles = FALSE,
                               step = step,
                               showWarnings = FALSE,
                               shuffleFastaEntries = FALSE,
                               labelVocabulary = label_vocabulary,
                               reverseComplements = FALSE)
  }
  if (mode == "label_folder") {
    # NOTE(review): open question from the original author — does the order
    # of classes equal the order of `fasta.path`?
    initializeGenerators(directories = fasta.path,
                         format = format,
                         batch.size = batch.size,
                         maxlen = maxlen,
                         vocabulary = vocabulary,
                         verbose = FALSE,
                         randomFiles = FALSE,
                         step = step,
                         showWarnings = FALSE,
                         shuffleFastaEntries = FALSE,
                         numberOfFiles = NULL,
                         fileLog = NULL,
                         reverseComplements = FALSE,
                         val = FALSE)
    gen <- labelByFolderGeneratorWrapper(val = FALSE, path = fasta.path)
  }

  n_labels <- length(label_vocabulary)
  label_levels <- seq_len(n_labels)
  acc <- numeric(numberOfBatches)
  confMat <- matrix(0, nrow = n_labels, ncol = n_labels)

  for (i in seq_len(numberOfBatches)) {
    batch <- gen()
    x <- batch[[1]]
    y <- batch[[2]]

    # predict_classes() returns 0-based class indices; shift to 1-based.
    y_pred <- keras::predict_classes(model, x, verbose = 0) + 1
    # Targets presumably one-hot, one row per sample: take the argmax column.
    y_true <- apply(y, 1, FUN = which.max)

    df_true_pred <- data.frame(
      true = factor(y_true, levels = label_levels, labels = label_vocabulary),
      pred = factor(y_pred, levels = label_levels, labels = label_vocabulary)
    )

    # Use the actual number of samples rather than assuming `batch.size`.
    acc[i] <- mean(y_pred == y_true)
    cm <- yardstick::conf_mat(df_true_pred, true, pred)
    confMat <- confMat + cm[[1]]
  }

  if (plot) {
    df <- data.frame(accuracies = acc)
    acc_plot <- ggplot2::ggplot(df) +
      ggplot2::geom_density(ggplot2::aes(x = accuracies),
                            color = "blue", alpha = 0.1, fill = "blue")
    print(acc_plot)
  }

  if (!is.null(filePath)) {
    save(acc, file = file.path(filePath, paste0(filename, "Acc.Rdata")))
    save(confMat, file = file.path(filePath, paste0(filename, "ConfMat.Rdata")))
    if (plot) {
      ggplot2::ggsave(acc_plot,
                      filename = file.path(filePath,
                                           paste0(filename, "AccPlot.pdf")))
    }
  }
  list(acc = acc, confMat = confMat)
}
hiddengenome/altum documentation built on April 22, 2020, 9:33 p.m.