R/predict.R

Defines functions summarize_states load_prediction predict_model_one_pred_per_entry predict_model_by_entry_one_file predict_model_by_entry predict_model_one_seq predict_model

Documented in load_prediction predict_model summarize_states

#' Make prediction for nucleotide sequence or entries in fasta/fastq file
#'
#' @description Removes layers (optional) from pretrained model and calculates states of fasta/fastq file or nucleotide sequence.
#' Writes states to h5 or csv file (access content of h5 output with \code{\link{load_prediction}} function).
#' There are several options on how to process an input file:
#' \itemize{
#' \item If `"one_seq"`, computes prediction for sequence argument or fasta/fastq file.
#' Combines fasta entries in file to one sequence. This means predictor sequences can contain elements from more than one fasta entry.
#' \item If `"by_entry"`, will output a separate file for each fasta/fastq entry.
#' Names of output files are: `output_dir` + "/" + `filename` + "_nr_" + i + "." + `output_type`, where i is the number of the
#' fasta entry (a trailing "." + `output_type` extension is removed from `filename` first).
#' \item If `"by_entry_one_file"`, will store prediction for all fasta entries in one h5 file.
#' \item If `"one_pred_per_entry"`, will make one prediction for each entry by either picking random sample for long sequences
#' or pad sequence for short sequences.
#' }
#' 
#' @inheritParams get_generator
#' @inheritParams train_model
#' @param layer_name Name of layer to get output from. If `NULL`, will use the last layer.
#' @param path_input Path to fasta/fastq file.
#' @param sequence Character string, ignores path_input if argument given.
#' @param round_digits Number of decimal places.
#' @param mode Either `"lm"` for language model or `"label"` for label classification.
#' @param include_seq Whether to include input sequence in h5 file.
#' @param output_format Either `"one_seq"`, `"by_entry"`, `"by_entry_one_file"`, `"one_pred_per_entry"`.
#' @param output_type `"h5"` or `"csv"`. If `output_format` is `"by_entry_one_file"` or `"one_pred_per_entry"`, can only be `"h5"`.
#' @param return_states Return predictions as data frame. Only supported for output_format `"one_seq"`.
#' @param padding Either `"none"`, `"maxlen"`, `"standard"` or `"self"`.
#' \itemize{
#' \item If `"none"`, apply no padding and skip sequences that are too short.
#' \item If `"maxlen"`, pad with maxlen number of zeros vectors.
#' \item If `"standard"`, pad with zero vectors only if sequence is shorter than maxlen. Pads to minimum size required for one prediction.
#' \item If `"self"`, concatenate sequence with itself until sequence is long enough for one prediction.
#' Example: if sequence is "ACGT" and maxlen is 10, make prediction for "ACGTACGTAC". 
#' Only applied if sequence is shorter than maxlen.
#' }
#' @param verbose Boolean.
#' @param filename Filename to store states in. No file output if argument is `NULL`.
#' If `output_format = "by_entry"`, adds "_nr_" + "i" after name, where i is entry number.
#' @param output_dir Directory for file output.
#' @param use_quality Whether to use quality scores.
#' @param quality_string String for encoding with quality scores (as used in fastq format).
#' @param lm_format Either `"target_right"`, `"target_middle_lstm"`, `"target_middle_cnn"` or `"wavenet"`.
#' @param ... Further arguments for sequence encoding with \code{\link{seq_encoding_label}}.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # make prediction for single sequence and write to h5 file
#' model <- create_model_lstm_cnn(maxlen = 20, layer_lstm = 8, layer_dense = 2, verbose = FALSE)
#' vocabulary <- c("a", "c", "g", "t")
#' sequence <- paste(sample(vocabulary, 200, replace = TRUE), collapse = "")
#' output_file <- tempfile(fileext = ".h5")
#' predict_model(output_format = "one_seq", model = model, step = 10,
#'              sequence = sequence, filename = output_file, mode = "label")
#' 
#' # make prediction for fasta file with multiple entries, write output to separate h5 files
#' fasta_path <- tempfile(fileext = ".fasta")
#' create_dummy_data(file_path = fasta_path, num_files = 1,
#'                  num_seq = 5, seq_length = 100,
#'                  write_to_file_path = TRUE)
#' model <- create_model_lstm_cnn(maxlen = 20, layer_lstm = 8, layer_dense = 2, verbose = FALSE)
#' output_dir <- tempfile()
#' dir.create(output_dir)
#' predict_model(output_format = "by_entry", model = model, step = 10, verbose = FALSE,
#'                output_dir = output_dir, mode = "label", path_input = fasta_path)
#' list.files(output_dir)
#' 
#' @returns If `return_states = TRUE` returns a list of model predictions and position of corresponding sequences.
#' If additionally `include_seq = TRUE`, list contains sequence strings.
#' If `return_states = FALSE` returns nothing, just writes output to file(s).  
#' @export
predict_model <- function(model, output_format = "one_seq", layer_name = NULL, sequence = NULL, path_input = NULL,
                          round_digits = NULL, filename = "states.h5", step = 1, vocabulary = c("a", "c", "g", "t"),
                          batch_size = 256, verbose = TRUE, return_states = FALSE, 
                          output_type = "h5", padding = "none", use_quality = FALSE, quality_string = NULL,
                          mode = "label", lm_format = "target_right", output_dir = NULL,
                          format = "fasta", include_seq = FALSE, reverse_complement_encoding = FALSE,
                          ambiguous_nuc = "zero", ...) {
  
  stopifnot(padding %in% c("standard", "self", "none", "maxlen"))
  stopifnot(output_format %in% c("one_seq", "by_entry", "by_entry_one_file", "one_pred_per_entry"))
  # these two formats always write a single h5 file
  if (output_format %in% c("by_entry_one_file", "one_pred_per_entry") && output_type == "csv") {
    message("by_entry_one_file or one_pred_per_entry only implemented for h5 output.
            Setting output_type to h5")
    output_type <- "h5"
  }
  
  if (output_format == "one_seq") {
    output_list <- predict_model_one_seq(layer_name = layer_name, sequence = sequence, path_input = path_input,
                                         round_digits = round_digits, filename = filename, step = step, vocabulary = vocabulary,
                                         batch_size = batch_size, verbose = verbose, return_states = return_states, 
                                         padding = padding, quality_string = quality_string, use_quality = use_quality,
                                         output_type = output_type, model = model, mode = mode, lm_format = lm_format,
                                         format = format, include_seq = include_seq, ambiguous_nuc = ambiguous_nuc,
                                         reverse_complement_encoding = reverse_complement_encoding, ...)
    return(output_list)
  }
  
  if (output_format == "by_entry") {
    # quality scores are read per entry from the input file, so quality_string is not forwarded
    predict_model_by_entry(layer_name = layer_name, path_input = path_input,
                           round_digits = round_digits, filename = filename, step = step, 
                           vocabulary = vocabulary, ambiguous_nuc = ambiguous_nuc,
                           batch_size = batch_size, verbose = verbose, use_quality = use_quality,
                           output_type = output_type, model = model, mode = mode, lm_format = lm_format,
                           output_dir = output_dir, format = format, include_seq = include_seq,
                           padding = padding, reverse_complement_encoding = reverse_complement_encoding, ...)
  }
  
  if (output_format == "by_entry_one_file") {
    # quality_string and ambiguous_nuc are not forwarded for this output format
    predict_model_by_entry_one_file(layer_name = layer_name, path_input = path_input,
                                    round_digits = round_digits, filename = filename, step = step, 
                                    vocabulary = vocabulary,
                                    batch_size = batch_size, verbose = verbose, use_quality = use_quality,
                                    model = model, mode = mode, lm_format = lm_format, format = format,
                                    padding = padding, include_seq = include_seq, 
                                    reverse_complement_encoding = reverse_complement_encoding, ...)
  }
  
  if (output_format == "one_pred_per_entry") {
    if (mode == "lm") {
      stop("one_pred_per_entry only implemented for label classification")
    }
    predict_model_one_pred_per_entry(layer_name = layer_name, path_input = path_input,
                                     round_digits = round_digits, filename = filename, 
                                     batch_size = batch_size, verbose = verbose, model = model, format = format,
                                     vocabulary = vocabulary, ambiguous_nuc = ambiguous_nuc,
                                     reverse_complement_encoding = reverse_complement_encoding, ...)
  }
  
  invisible(NULL)
}


#' Write output of specific model layer to h5 or csv file.
#'
#' Removes layers (optional) from pretrained model and calculates states of fasta file, writes states to h5/csv file.
#' Function combines fasta entries in file to one sequence. This means predictor sequences can contain elements from more than one fasta entry.
#' h5 file also contains sequence and positions of targets corresponding to states.
#'
#' @inheritParams generator_fasta_lm
#' @param layer_name Name of layer to get output from. If `NULL`, will use the last layer.
#' @param path_input Path to fasta file.
#' @param sequence Character string, ignores path_input if argument given.
#' @param round_digits Number of decimal places.
#' @param batch_size Number of samples to evaluate at once. Does not change output, only relevant for speed and memory.
#' @param step Frequency of sampling steps.
#' @param filename Filename to store states in. No file output (and no file created) if argument is `NULL`.
#' @param vocabulary Vector of allowed characters, character outside vocabulary get encoded as 0-vector.
#' @param return_states Logical scalar, return states matrix.
#' @param target_len For `mode = "lm"`, number of positions added to the model input length to get the sample length.
#' @param ambiguous_nuc `"zero"` or `"equal"`. 
#' @param verbose Whether to print model before and after removing layers.
#' @param output_type Either `"h5"` or `"csv"`.
#' @param model A keras model. 
#' @param mode Either `"lm"` for language model or `"label"` for label classification.
#' @param format Either `"fasta"` or `"fastq"`.
#' @param include_seq Whether to include input sequence in h5 file.
#' @param reverse_complement_encoding Whether the model takes the sequence and its reverse complement as two inputs.
#' Only supported for an A/C/G/T vocabulary.
#' @param ... Further arguments for sequence encoding with \code{\link{seq_encoding_label}}.
#' @returns If `return_states = TRUE`, a list with `states`, `sample_end_position` and (if `include_seq = TRUE`)
#' `sequence`; otherwise `NULL`.
#' @noRd
predict_model_one_seq <- function(model, layer_name = NULL, sequence = NULL, path_input = NULL, round_digits = 2,
                                  filename = "states.h5", step = 1, vocabulary = c("a", "c", "g", "t"), batch_size = 256, verbose = TRUE,
                                  return_states = FALSE, target_len = 1, use_quality = FALSE, quality_string = NULL,
                                  output_type = "h5", mode = "lm", lm_format = "target_right",
                                  ambiguous_nuc = "zero", padding = "none", format = "fasta", output_dir = NULL,
                                  include_seq = TRUE, reverse_complement_encoding = FALSE, ...) {
  
  vocabulary <- tolower(vocabulary)
  stopifnot(mode %in% c("lm", "label"))
  stopifnot(output_type %in% c("h5", "csv"))
  if (!is.null(quality_string)) use_quality <- TRUE
  
  # without a filename the states are only returned, nothing is written to disk
  file_output <- !is.null(filename)
  if (!file_output && !return_states) {
    stop("If filename is NULL, return_states must be TRUE; otherwise function produces no output.")
  }
  stopifnot(batch_size > 0)
  if (file_output) stopifnot(!file.exists(filename))
  
  # the vocabulary must be exactly a, c, g, t (every element has to match, not just one)
  if (reverse_complement_encoding &&
      !identical(sort(unname(vocabulary)), c("a", "c", "g", "t"))) {
    stop("reverse_complement_encoding only implemented for A,C,G,T vocabulary")
  }
  
  # token for ambiguous nucleotides: first letter of the alphabet not used by the vocabulary
  amb_nuc_token <- setdiff(letters, vocabulary)[1]
  if (is.na(amb_nuc_token)) {
    stop("Vocabulary contains all letters; no token left to encode ambiguous nucleotides.")
  }
  
  tokenizer <- keras::fit_text_tokenizer(keras::text_tokenizer(char_level = TRUE, lower = TRUE, oov_token = "0"),
                                         c(vocabulary, amb_nuc_token))
  
  if (is.null(layer_name)) {
    layer_name <- model$output_names
    if (verbose) message(paste("layer_name not specified. Using layer", layer_name))
  }
  
  # an explicit sequence argument takes precedence over path_input
  fasta.file <- NULL
  if (!is.null(sequence) && !missing(sequence) && sequence != "") {
    nt_seq <- tolower(sequence)
  } else {
    if (format == "fasta") {
      fasta.file <- microseq::readFasta(path_input)
    } else if (format == "fastq") {
      fasta.file <- microseq::readFastq(path_input)
    } else {
      stop("format must be either \"fasta\" or \"fastq\"")
    }
    if (nrow(fasta.file) > 1 && verbose) {
      text_1 <- paste("Your file has", nrow(fasta.file), "entries. 'one_seq'  output_format will concatenate them to a single sequence.\n")
      text_2 <- "Use 'by_entry' or 'by_entry_one_file' output_format to evaluate them separately."
      message(paste0(text_1, text_2))
    }
    nt_seq <- tolower(paste(fasta.file$Sequence, collapse = ""))
  }
  
  # replace every character outside the vocabulary by the ambiguous-nucleotide token
  pattern <- paste0("[^", paste0(vocabulary, collapse = ""), "]")
  nt_seq <- stringr::str_replace_all(string = nt_seq, pattern = pattern, amb_nuc_token)
  
  if (use_quality && is.null(quality_string)) {
    if (is.null(fasta.file)) {
      stop("use_quality = TRUE requires quality_string when sequence is given directly.")
    }
    quality_string <- paste(fasta.file$Quality, collapse = "")
  }
  
  # extract maxlen (target_middle models take the two flanks as separate inputs)
  target_middle <- mode == "lm" && lm_format %in% c("target_middle_lstm", "target_middle_cnn")
  if (target_middle) {
    maxlen <- model$input[[1]]$shape[[2]] + model$input[[2]]$shape[[2]]
  } else if (reverse_complement_encoding) {
    maxlen <- model$input[[1]]$shape[[2]]
  } else {
    maxlen <- model$input$shape[[2]]
  }
  
  total_seq_len <- if (mode == "lm") maxlen + target_len else maxlen
  
  # pad sequence
  unpadded_seq_len <- nchar(nt_seq)
  pad_len <- 0
  if (padding == "maxlen") {
    pad_len <- maxlen
  }
  if (padding == "standard" && unpadded_seq_len < total_seq_len) {
    pad_len <- total_seq_len - unpadded_seq_len
  }
  if (padding == "self" && unpadded_seq_len < total_seq_len) {
    if (unpadded_seq_len == 0) stop("Cannot apply 'self' padding to an empty sequence.")
    num_repeats <- ceiling(total_seq_len / unpadded_seq_len)
    nt_seq <- substr(strrep(nt_seq, num_repeats), 1, total_seq_len)
    if (use_quality) {
      quality_string <- substr(strrep(quality_string, num_repeats), 1, total_seq_len)
    }
  } else {
    # "0" is the tokenizer's out-of-vocabulary token
    nt_seq <- paste0(strrep("0", pad_len), nt_seq)
    if (use_quality) quality_string <- paste0(strrep("0", pad_len), quality_string)
  }
  
  if (nchar(nt_seq) < total_seq_len) {
    stop(paste0("Input sequence is shorter than required length (", total_seq_len, "). Use padding argument to pad sequence to bigger size."))
  }
  
  quality_vector <- if (use_quality) quality_to_probability(quality_string) else NULL
  
  # start of samples
  start_indices <- seq(1, nchar(nt_seq) - total_seq_len + 1, by = step)
  num_samples <- length(start_indices)
  
  check_layer_name(model, layer_name)
  model <- tensorflow::tf$keras$Model(model$input, model$get_layer(layer_name)$output)
  if (verbose) {
    cat("Computing output for model at layer", layer_name,  "\n")
    print(model)
  }
  
  # 3D layer output is only accepted for a layer named "lstm"
  if (length(model$output$shape$dims) == 3 &&
      !("lstm" %in% tolower(model$output_names))) {
    stop("Output dimension of layer is > 1, format not supported yet")
  }
  
  # tokenize sequence
  nt_seq <- tolower(nt_seq)
  tokSeq <- keras::texts_to_sequences(tokenizer, nt_seq)[[1]] - 1
  
  # seq end position, relative to the unpadded sequence
  pos_arg <- start_indices + total_seq_len - pad_len - 1
  
  if (include_seq) {
    output_seq <- substr(nt_seq, pad_len + 1, nchar(nt_seq))
  }
  
  # create h5 file to store states; the handle is released even if prediction fails
  h5_file <- NULL
  on.exit(if (!is.null(h5_file)) h5_file$close_all(), add = TRUE)
  if (file_output && output_type == "h5") {
    h5_file <- hdf5r::H5File$new(filename, mode = "w")
    h5_file[["multi_entries"]] <- FALSE
    h5_file[["sample_end_position"]] <- pos_arg
    if (include_seq) h5_file[["sequence"]] <- output_seq
  }
  
  number_batches <- ceiling(num_samples / batch_size)
  pred_list <- vector("list", number_batches)
  
  # target_middle formats: split each sample into the parts left and right of the target
  if (target_middle) {
    index_x_1 <- 1:ceiling((total_seq_len - target_len) / 2)
    index_x_2 <- (max(index_x_1) + target_len + 1):total_seq_len
  }
  
  for (i in seq_len(number_batches)) {
    
    index_start <- ((i - 1) * batch_size) + 1
    index_end <- min(c(num_samples + 1, index_start + batch_size)) - 1
    index <- index_start:index_end
    
    x <- seq_encoding_label(sequence = tokSeq, 
                            maxlen = total_seq_len,
                            vocabulary = vocabulary,
                            start_ind = start_indices[index],
                            ambiguous_nuc = ambiguous_nuc,
                            tokenizer = NULL,
                            adjust_start_ind = FALSE,
                            quality_vector = quality_vector,
                            ...
    )
    
    if (mode == "lm" && lm_format == "target_middle_lstm") {
      x1 <- x[ , index_x_1, ]
      x2 <- x[ , index_x_2, ]
      
      # re-add a leading dimension when subsetting dropped one
      if (length(index_x_1) == 1 || dim(x)[1] == 1) {
        x1 <- array(x1, dim = c(1, dim(x1)))
      }
      if (length(index_x_2) == 1 || dim(x)[1] == 1) {
        x2 <- array(x2, dim = c(1, dim(x2)))
      }
      
      x2 <- x2[ , dim(x2)[2]:1, ] # reverse order
      
      if (length(dim(x2)) == 2) {
        x2 <- array(x2, dim = c(1, dim(x2)))
      }
      
      x <- list(x1, x2)
    } 
    
    if (mode == "lm" && lm_format == "target_middle_cnn") {
      x <- x[ , c(index_x_1, index_x_2), ]
    } 
    
    if (reverse_complement_encoding) x <- list(x, reverse_complement_tensor(x))
    
    y <- stats::predict(model, x, verbose = 0)
    if (!is.null(round_digits)) y <- round(y, round_digits)
    pred_list[[i]] <- y
  }
  
  states <- do.call(rbind, pred_list)
  
  if (file_output) {
    if (output_type == "h5") {
      h5_file[["states"]] <- states
      h5_file$close_all()
      h5_file <- NULL
    } else {
      colnames(states) <- paste0("N", seq_len(ncol(states)))
      utils::write.csv(x = states, file = filename, row.names = FALSE)
    }
  }
  
  if (!return_states) {
    return(NULL)
  }
  output_list <- list(states = states, sample_end_position = pos_arg)
  if (include_seq) output_list$sequence <- output_seq
  output_list
}

#' Write states to h5 or csv files, one file per entry
#'
#' @description Removes layers (optional) from pretrained model and calculates states of fasta file, writes a separate
#' output file for every fasta entry in fasta file (via \code{predict_model_one_seq}).
#' Names of output files are: `output_dir` + "/" + `filename` + "_nr_" + i + "." + `output_type`, where i is the number
#' of the fasta entry. Entries shorter than the model input length are skipped when `padding = "none"`.
#'
#' @param filename Filename to store states, function adds "_nr_" + "i" after name, where i is entry number.
#' A trailing "." + `output_type` extension (and any directory part) is removed first.
#' @param output_dir Path to folder, where to write output.
#' @param ambiguous_nuc Encoding of ambiguous nucleotides, passed on for every entry.
#' @returns `NULL`, invisibly; called for its file output.
#' @noRd
predict_model_by_entry <- function(model, layer_name = NULL, path_input, round_digits = 2,
                                   filename = "states.h5", output_dir = NULL, step = 1, vocabulary = c("a", "c", "g", "t"),
                                   batch_size = 256, output_type = "h5", mode = "lm",
                                   lm_format = "target_right", format = "fasta", use_quality = FALSE,
                                   reverse_complement_encoding = FALSE, padding = "none", 
                                   verbose = FALSE, include_seq = FALSE, ambiguous_nuc = "zero", ...) {
  
  stopifnot(mode %in% c("lm", "label"))
  stopifnot(!is.null(filename))
  stopifnot(!is.null(output_dir))
  
  # strip the extension so "_nr_i" can be inserted before it
  if (endsWith(filename, paste0(".", output_type))) {
    filename <- basename(stringr::str_remove(filename, paste0(".", output_type, "$")))
  }
  
  # extract maxlen (target_middle models take the two flanks as separate inputs)
  target_middle <- mode == "lm" && lm_format %in% c("target_middle_lstm", "target_middle_cnn")
  if (target_middle) {
    maxlen <- model$input[[1]]$shape[[2]] + model$input[[2]]$shape[[2]]
  } else if (reverse_complement_encoding) {
    maxlen <- model$input[[1]]$shape[[2]]
  } else {
    maxlen <- model$input$shape[[2]]
  }
  
  # load fasta file
  if (format == "fasta") {
    fasta.file <- microseq::readFasta(path_input)
  } else if (format == "fastq") {
    fasta.file <- microseq::readFastq(path_input)
  } else {
    stop("format must be either \"fasta\" or \"fastq\"")
  }
  df <- fasta.file[ , c("Sequence", "Header")]
  names(df) <- c("seq", "header")
  rownames(df) <- NULL
  num_skipped_seq <- 0
  
  for (i in seq_len(nrow(df))) {
    
    # $ extraction always yields a character string (also for tibble input)
    entry_seq <- df$seq[i]
    
    # skip entry if too short
    if (nchar(entry_seq) < maxlen && padding == "none") {
      num_skipped_seq <- num_skipped_seq + 1
      next
    } 
    
    quality_string <- if (use_quality) fasta.file$Quality[i] else NULL
    
    current_file <- file.path(output_dir, paste0(filename, "_nr_", i, ".", output_type))
    
    predict_model_one_seq(layer_name = layer_name, sequence = entry_seq,
                          round_digits = round_digits, path_input = path_input,
                          filename = current_file, quality_string = quality_string,
                          step = step, vocabulary = vocabulary, batch_size = batch_size,
                          # only print model summary for the first entry
                          verbose = if (i > 1) FALSE else verbose, 
                          output_type = output_type, mode = mode,
                          lm_format = lm_format, model = model, include_seq = include_seq,
                          padding = padding, ambiguous_nuc = ambiguous_nuc, 
                          reverse_complement_encoding = reverse_complement_encoding, ...)
  }
  
  if (verbose && num_skipped_seq > 0) {
    message(paste0("Skipped ", num_skipped_seq,
                   ifelse(num_skipped_seq == 1, " entry", " entries"),
                   ". Use different padding option to evaluate all."))
  }  
  
  invisible(NULL)
}

#' Write states of all entries to one h5 file
#'
#' @description Removes layers (optional) from pretrained model and calculates states of fasta file,
#' writes separate states matrix in one .h5 file for every fasta entry.
#' States, sample end positions and (if `include_seq = TRUE`) sequences are stored in the groups
#' `"states"`, `"sample_end_position"` and `"sequence"`, one dataset per entry named `"entry_<i>"`.
#' Entries shorter than the model input length are skipped when `padding = "none"`.
#' @param ambiguous_nuc Encoding of ambiguous nucleotides, passed on for every entry.
#' @returns `NULL`, invisibly; called for its file output.
#' @noRd
predict_model_by_entry_one_file <- function(model, path_input, round_digits = 2, filename = "states.h5",
                                            step = 1,  vocabulary = c("a", "c", "g", "t"), batch_size = 256, layer_name = NULL,
                                            verbose = TRUE, mode = "lm", use_quality = FALSE,
                                            lm_format = "target_right", padding = "none",
                                            format = "fasta", include_seq = TRUE, reverse_complement_encoding = FALSE,
                                            ambiguous_nuc = "zero", ...) {
  
  vocabulary <- stringr::str_to_lower(vocabulary)
  stopifnot(mode %in% c("lm", "label"))
  
  # extract maxlen (target_middle models take the two flanks as separate inputs)
  target_middle <- mode == "lm" && lm_format %in% c("target_middle_lstm", "target_middle_cnn")
  if (target_middle) {
    maxlen <- model$input[[1]]$shape[[2]] + model$input[[2]]$shape[[2]]
  } else if (reverse_complement_encoding) {
    maxlen <- model$input[[1]]$shape[[2]]
  } else {
    maxlen <- model$input$shape[[2]]
  }
  
  # load fasta file
  if (format == "fasta") {
    fasta.file <- microseq::readFasta(path_input)
  } else if (format == "fastq") {
    fasta.file <- microseq::readFastq(path_input)
  } else {
    stop("format must be either \"fasta\" or \"fastq\"")
  }
  df <- fasta.file[ , c("Sequence", "Header")]
  names(df) <- c("seq", "header")
  rownames(df) <- NULL
  
  if (verbose && anyDuplicated(df$header) > 0) {
    message("Header names are not unique, adding '_header_x' to names (x being the header number)")
    df$header <- paste0(df$header, paste0("_header_", seq_along(df$header)))
  }
  
  # create h5 file to store states; the handle is released even if prediction fails
  h5_file <- hdf5r::H5File$new(filename, mode = "w")
  on.exit(if (!is.null(h5_file)) h5_file$close_all(), add = TRUE)
  h5_file[["multi_entries"]] <- TRUE
  states.grp <- h5_file$create_group("states")
  sample_end_position.grp <- h5_file$create_group("sample_end_position")
  if (include_seq) seq.grp <- h5_file$create_group("sequence")
  
  num_skipped_seq <- 0
  
  for (i in seq_len(nrow(df))) {
    
    # skip entry if too short
    if (nchar(df$seq[i]) < maxlen && padding == "none") {
      num_skipped_seq <- num_skipped_seq + 1
      next
    } 
    
    quality_string <- if (use_quality) fasta.file$Quality[i] else NULL
    
    # predict_model_one_seq needs a fresh output path; the file is removed once its states are copied
    temp_file <- tempfile(fileext = ".h5")
    output_list <- predict_model_one_seq(layer_name = layer_name, sequence = df$seq[i], path_input = path_input,
                                         round_digits = round_digits, filename = temp_file, step = step, vocabulary = vocabulary,
                                         batch_size = batch_size, return_states = TRUE, quality_string = quality_string,
                                         output_type = "h5", model = model, mode = mode, lm_format = lm_format,
                                         ambiguous_nuc = ambiguous_nuc, verbose = if (i > 1) FALSE else verbose, 
                                         padding = padding, format = format, include_seq = include_seq,
                                         reverse_complement_encoding = reverse_complement_encoding, ...)
    unlink(temp_file)
    
    seq_name <- paste0("entry_", i)
    states.grp[[seq_name]] <- output_list$states
    sample_end_position.grp[[seq_name]] <- output_list$sample_end_position
    if (include_seq) seq.grp[[seq_name]] <- output_list$sequence
  }
  
  if (verbose && num_skipped_seq > 0) {
    message(paste0("Skipped ", num_skipped_seq,
                   ifelse(num_skipped_seq == 1, " entry", " entries"),
                   ". Use different padding option to evaluate all."))
  }  
  
  h5_file$close_all()
  h5_file <- NULL
  invisible(NULL)
}

#' Get states for label classification model.
#'
#' Computes output at specified model layer. Forces every fasta entry to have length maxlen by either padding sequences
#' shorter than maxlen (prepending "N" characters) or taking a random subsequence (drawn with `sample`) for longer sequences.
#' Writes header names, states and (if `include_seq = TRUE`) the processed sequences to one h5 file.
#'
#' @inheritParams predict_model_one_seq
#' @param padding Unused; kept for interface compatibility.
#' @returns The states matrix if `return_states = TRUE`, otherwise `NULL` (invisibly).
#' @noRd
predict_model_one_pred_per_entry <- function(model, layer_name = NULL, path_input, round_digits = 2, format = "fasta",
                                             ambiguous_nuc = "zero", filename = "states.h5", padding = "none",
                                             vocabulary = c("a", "c", "g", "t"), batch_size = 256, verbose = TRUE,
                                             return_states = FALSE, reverse_complement_encoding = FALSE, 
                                             include_seq = FALSE, use_quality = FALSE, ...) {
  
  vocabulary <- stringr::str_to_lower(vocabulary)
  stopifnot(batch_size > 0)
  stopifnot(!file.exists(filename))
  # "N" is the tokenizer's out-of-vocabulary token; it is also used below to pad short entries
  tokenizer <- keras::fit_text_tokenizer(keras::text_tokenizer(char_level = TRUE, lower = TRUE, oov_token = "N"), vocabulary)
  
  if (is.null(layer_name)) {
    layer_name <- model$output_names
    if (verbose) message(paste("layer_name not specified. Using layer", layer_name))
  }
  
  # extract maxlen
  maxlen <- if (reverse_complement_encoding) model$input[[1]]$shape[[2]] else model$input$shape[[2]]
  
  if (format == "fasta") {
    fasta.file <- microseq::readFasta(path_input)
  } else if (format == "fastq") {
    fasta.file <- microseq::readFastq(path_input)
  } else {
    stop("format must be either \"fasta\" or \"fastq\"")
  }
  
  num_samples <- nrow(fasta.file)
  if (num_samples == 0) stop("Input file contains no entries.")
  if (use_quality && !("Quality" %in% names(fasta.file))) {
    stop("use_quality = TRUE requires fastq input with quality scores.")
  }
  
  nucSeq <- as.character(fasta.file$Sequence)
  seq_length <- nchar(nucSeq)
  # start from the full quality strings so entries of exactly maxlen keep theirs
  quality_string <- if (use_quality) as.character(fasta.file$Quality) else NULL
  
  for (i in seq_len(num_samples)) {
    if (seq_length[i] > maxlen) {
      # take random subsequence
      start <- sample(seq_len(seq_length[i] - maxlen + 1), size = 1)
      stop_pos <- start + maxlen - 1
      nucSeq[i] <- substr(nucSeq[i], start = start, stop = stop_pos)
      if (use_quality) quality_string[i] <- substr(quality_string[i], start = start, stop = stop_pos)
    } else if (seq_length[i] < maxlen) {
      # pad sequence on the left
      num_pad <- maxlen - seq_length[i]
      nucSeq[i] <- paste0(strrep("N", num_pad), nucSeq[i])
      if (use_quality) quality_string[i] <- paste0(strrep("0", num_pad), quality_string[i])
    }
  }
  
  check_layer_name(model, layer_name)
  model <- tensorflow::tf$keras$Model(model$input, model$get_layer(layer_name)$output)
  if (verbose) {
    cat("Computing output for model at layer", layer_name,  "\n")
    print(model)
  } 
  
  # extract number of neurons in last layer
  if (length(model$output$shape$dims) == 3) {
    if (!("lstm" %in% stringr::str_to_lower(model$output_names))) {
      stop("Output dimension of layer is > 1, format not supported yet")
    }
    layer.size <- model$output$shape[[3]]
  } else {
    layer.size <- model$output$shape[[2]]
  }
  
  # create h5 file to store states; the handle is released even if prediction fails
  h5_file <- hdf5r::H5File$new(filename, mode = "w")
  on.exit(if (!is.null(h5_file)) h5_file$close_all(), add = TRUE)
  h5_file[["fasta_file"]] <- path_input
  h5_file[["header_names"]] <- fasta.file$Header
  if (include_seq) h5_file[["sequences"]] <- nucSeq
  h5_file[["states"]] <- array(0, dim = c(0, layer.size))
  h5_file[["multi_entries"]] <- FALSE
  writer <- h5_file[["states"]]
  
  rm(fasta.file)
  
  # encode entries first..last (each exactly maxlen characters long) and return the layer activations;
  # used for full and partial batches alike so both get identical encoding arguments
  predict_entries <- function(first, last, ...) {
    char_seq <- paste(nucSeq[first:last], collapse = "")
    quality_vector <- NULL
    if (use_quality) {
      quality_vector <- quality_to_probability(paste(quality_string[first:last], collapse = ""))
    }
    one_hot_batch <- seq_encoding_label(sequence = NULL, maxlen = maxlen, vocabulary = vocabulary,
                                        start_ind = seq(1, nchar(char_seq), maxlen),
                                        ambiguous_nuc = ambiguous_nuc, char_sequence = char_seq,
                                        quality_vector = quality_vector, tokenizer = tokenizer,
                                        adjust_start_ind = TRUE, ...)
    if (reverse_complement_encoding) {
      one_hot_batch <- list(one_hot_batch, reverse_complement_tensor(one_hot_batch))
    }
    activations <- keras::predict_on_batch(model, one_hot_batch)
    if (!is.null(round_digits)) activations <- round(activations, round_digits)
    activations
  }
  
  number_batches <- ceiling(num_samples / batch_size)
  if (verbose) cat("Evaluating", number_batches, ifelse(number_batches > 1, "batches", "batch"), "\n")
  ten_percent_steps <- seq(number_batches / 10, number_batches, length.out = 10)
  percentage_index <- 1
  row <- 1
  
  # all batches except the last one are full
  for (i in seq_len(number_batches - 1)) {
    batch_end <- row + batch_size - 1
    writer[row:batch_end, ] <- predict_entries(row, batch_end, ...)
    row <- batch_end + 1
    
    if (verbose && percentage_index < 10 && i > ten_percent_steps[percentage_index]) {
      cat("Progress: ", percentage_index * 10 ,"% \n")
      percentage_index <- percentage_index + 1
    }
  }
  
  # last batch might be shorter
  activations <- predict_entries(row, num_samples, ...)
  writer[row:num_samples, ] <- activations[1:length(row:num_samples), ]
  
  if (verbose) cat("Progress: 100 % \n")
  
  states <- if (return_states) writer[ , ] else NULL
  h5_file$close_all()
  h5_file <- NULL
  if (return_states) return(states)
  invisible(NULL)
}


#' Read states from h5 file
#'
#' Reads h5 file created by \code{\link{predict_model}} function.
#'
#' @param h5_path Path to h5 file.
#' @param rows Range of rows to read. If `NULL` read everything.
#' @param get_sample_position Return position of sample corresponding to state if `TRUE`.
#' Ignored (with a message) if the file contains neither a `sample_end_position` nor a `target_position` dataset.
#' @param get_seq Return nucleotide sequence if `TRUE`.
#' Ignored (with a message) if the file contains neither a `sequence` nor a `sequences` dataset.
#' @param verbose Boolean.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # make prediction for single sequence and write to h5 file
#' model <- create_model_lstm_cnn(maxlen = 20, layer_lstm = 8, layer_dense = 2, verbose = FALSE)
#' vocabulary <- c("a", "c", "g", "t")
#' sequence <- paste(sample(vocabulary, 200, replace = TRUE), collapse = "")
#' output_file <- tempfile(fileext = ".h5")
#' predict_model(output_format = "one_seq", model = model, step = 10,
#'               sequence = sequence, filename = output_file, mode = "label")
#' load_prediction(h5_path = output_file)
#' 
#' @returns For a single-entry file, a list with element `states` (a matrix with one row per sample)
#' and, if requested and available, `sample_end_position` and `sequence`.
#' For a multi-entry file, a named list containing one such list per entry.
#' @export
load_prediction <- function(h5_path, rows = NULL, verbose = FALSE,
                            get_sample_position = FALSE, get_seq = FALSE) {
  
  # `rows = NULL` means "read all rows"
  complete <- is.null(rows)
  h5_file <- hdf5r::H5File$new(h5_path, mode = "r")
  # close the file on every exit path, including errors while reading
  on.exit(h5_file$close_all(), add = TRUE)
  file_content <- names(h5_file)
  
  multi_entries <- "multi_entries" %in% file_content &&
    isTRUE(as.logical(h5_file[["multi_entries"]][]))
  
  # Positions may be stored under either name; they are always returned
  # as `sample_end_position`.
  target_name <- intersect(c("sample_end_position", "target_position"), file_content)[1]
  if (get_sample_position && is.na(target_name)) {
    get_sample_position <- FALSE
    message("File does not contain target positions.")
  }
  
  # Depending on the writer, the input sequence is stored as "sequence" or
  # "sequences"; it is always returned as `sequence`.
  seq_name <- intersect(c("sequence", "sequences"), file_content)[1]
  if (get_seq && is.na(seq_name)) {
    get_seq <- FALSE
    message("File does not contain sequence.")
  }
  
  if (!multi_entries) {
    read_states <- h5_file[["states"]]
    
    if (verbose) {
      state_dims <- dim(read_states[ , ])
      cat("states matrix has", state_dims[1], "rows and", state_dims[2], "columns \n")
    }
    
    states <- if (complete) read_states[ , ] else read_states[rows, ]
    # a single row comes back as a plain vector; restore the matrix shape
    if (is.null(dim(states))) {
      states <- matrix(states, nrow = 1)
    }
    
    output_list <- list(states = states)
    if (get_sample_position) {
      read_target_pos <- h5_file[[target_name]]
      output_list$sample_end_position <- if (complete) read_target_pos[] else read_target_pos[rows]
    }
    if (get_seq) {
      output_list$sequence <- h5_file[[seq_name]][]
    }
    return(output_list)
  }
  
  # Multiple entries: "states" (and the optional position / sequence objects)
  # are groups holding one dataset per entry name.
  entry_names <- names(h5_file[["states"]])
  if (verbose) {
    cat("file contains", length(entry_names), "entries \n")
  }
  
  output_list <- vector("list", length(entry_names))
  names(output_list) <- entry_names
  
  for (entry_name in entry_names) {
    states <- h5_file[["states"]][[entry_name]][ , ]
    if (is.null(dim(states))) {
      states <- matrix(states, nrow = 1)
    }
    if (!complete) {
      # drop = FALSE keeps the matrix orientation for single rows/columns
      states <- states[rows, , drop = FALSE]
    }
    entry <- list(states = states)
    
    if (get_sample_position) {
      target_pos <- h5_file[[target_name]][[entry_name]][]
      entry$sample_end_position <- if (complete) target_pos else target_pos[rows]
    }
    if (get_seq) {
      entry$sequence <- h5_file[[seq_name]][[entry_name]][]
    }
    output_list[[entry_name]] <- entry
  }
  
  output_list
}

#' Create summary of predictions 
#'
#' Create summary data frame for confidence predictions over one or several state files or a data frame.
#' Columns in file or data frame should be confidence predictions for one class,
#' i.e. each row should sum to 1 and have nonnegative entries. 
#' Output data frame contains average confidence scores, max score and percentage of votes for each class.
#'
#' @param states_path Folder containing state files or a single file with same ending as `file_type`.
#' If a folder is given, only files ending with `file_type` are summarized.
#' @param label_names Names of predicted classes. If `NULL`, classes are named `X_1`, `X_2`, ...
#' @param file_type `"h5"` or `"csv"`.
#' @param df A states data frame. Ignore `states_path` argument if not `NULL`.
#' @examples 
#' m <- c(0.9,  0.1, 0.2, 0.01,
#'        0.05, 0.7, 0.2, 0,
#'        0.05, 0.2, 0.6, 0.99) %>% matrix(ncol = 3)
#' 
#' label_names <- paste0("class_", 1:3)
#' df <- as.data.frame(m)
#' pred_summary <- summarize_states(label_names = label_names, df = df)
#' pred_summary
#' 
#' @returns A data frame (`data.table`) of prediction summaries with one row per state file,
#' or a single row if `df` is given.
#' @export
summarize_states <- function(states_path = NULL, label_names = NULL, file_type = "h5", df = NULL) {
  
  # A user-supplied data frame takes precedence over any files.
  use_df <- !is.null(df)
  if (use_df) {
    states_path <- NULL
  }
  if (!use_df && is.null(states_path)) {
    stop("Either `states_path` or `df` must be given.", call. = FALSE)
  }
  
  if (is.null(states_path)) {
    # single placeholder so the loop below runs exactly once on `df`
    state_files <- 1
  } else if (endsWith(states_path, file_type)) {
    state_files <- states_path
  } else {
    state_files <- list.files(states_path, full.names = TRUE)
    # ignore files of other types in the folder
    state_files <- state_files[endsWith(state_files, file_type)]
  }
  
  num_labels <- if (is.null(label_names)) NULL else length(label_names)
  summary_list <- vector("list", length(state_files))
  
  for (i in seq_along(state_files)) {
    state_file <- state_files[i]
    
    # Use a fresh local for every file; reusing `df` would make every file
    # after the first be summarized with the first file's data.
    if (use_df) {
      states_df <- df
    } else if (file_type == "h5") {
      states_df <- load_prediction(h5_path = state_file, get_sample_position = FALSE, verbose = FALSE)
      states_df <- as.data.frame(states_df$states)
    } else if (file_type == "csv") {
      states_df <- utils::read.csv(state_file)
      # wrong column count suggests a semicolon-separated file
      if (!is.null(num_labels) && ncol(states_df) != num_labels) {
        states_df <- utils::read.csv2(state_file)
      }
    } else {
      stop("`file_type` must be \"h5\" or \"csv\".", call. = FALSE)
    }
    
    if (is.null(label_names)) {
      label_names <- paste0("X_", seq_len(ncol(states_df)))
      num_labels <- length(label_names)
    }
    
    stopifnot(ncol(states_df) == num_labels)
    names(states_df) <- label_names
    
    mean_conf <- vapply(states_df, mean, numeric(1))
    max_conf <- vapply(states_df, max, numeric(1))
    
    mean_df <- data.frame(matrix(mean_conf, nrow = 1, ncol = num_labels))
    names(mean_df) <- paste0("mean_conf_", label_names)
    max_df <- data.frame(matrix(max_conf, nrow = 1, ncol = num_labels))
    names(max_df) <- paste0("max_conf_", label_names)
    
    # each row votes for the class with the highest confidence
    vote_distribution <- apply(states_df[label_names], 1, which.max)
    vote_perc <- table(factor(vote_distribution, levels = seq_len(num_labels))) / length(vote_distribution)
    votes_df <- data.frame(matrix(vote_perc, nrow = 1, ncol = num_labels))
    names(votes_df) <- paste0("vote_perc_", label_names)
    
    mean_prediction <- label_names[which.max(mean_conf)]
    max_prediction <- label_names[which.max(max_conf)]
    vote_prediction <- label_names[which.max(vote_perc)]
    
    file_name <- if (is.null(states_path)) NA else basename(state_file)
    
    summary_list[[i]] <- data.frame(file_name, mean_df, max_df, votes_df,
                                    mean_prediction, max_prediction, vote_prediction,
                                    num_prediction = nrow(states_df))
  }
  
  data.table::rbindlist(summary_list)
}
GenomeNet/deepG documentation built on Dec. 24, 2024, 12:11 p.m.