R/train.R

Defines functions trainNetwork

Documented in trainNetwork

#' @title Trains a (mostly) LSTM model on genomic data. Designed for developing genome based language models (GenomeNet)
#'
#' @description
#' Depth and number of neurons per layer of the network can be specified. First layer can be a Convolutional Neural Network (CNN) that is designed to capture codons.
#' If a path to a folder where FASTA files are located is provided, batches will be generated using an external generator which
#' is recommended for big training sets. Alternatively, a dataset can be supplied that holds the preprocessed batches (generated by \code{preprocessSemiRedundant()})
#' and keeps them in RAM. Also supports training on instances with multiple GPUs and scales linearly with the number of GPUs present.
#' @param train_type Either "lm" for language model, "label_header" or "label_folder". Language model is trained to predict next character in sequence.
#' label_header/label_folder are trained to predict a corresponding class, given a sequence as input. If "label_header", class will be read from fasta headers
#' (\code{labelVocabulary} must then be given). If "label_folder", class will be read from folder, i.e. all fasta files in one folder must belong to the same class.
#' @param model A keras model. Either \code{model} or \code{model_path} has to be given.
#' @param model_path Path to a pretrained model (hdf5 file). If given, the model is loaded from this file and \code{model} is ignored.
#' @param path Path to folder where individual or multiple FASTA files are located for training. If \code{train_type} is \code{label_folder}, should be a vector
#' containing a path for each class.
#' @param path.val Path to folder where individual or multiple FASTA files are located for validation. If \code{train_type} is \code{label_folder}, should be a vector
#' containing a path for each class.
#' @param dataset Dataframe holding training samples in RAM instead of using generator. Must contain elements \code{X} (input) and \code{Y} (targets).
#' @param checkpoint_path Path to checkpoints folder. A subfolder \code{<run.name>_checkpoints} is created inside it.
#' @param validation.split Defines the fraction of the batches that will be used for validation (compared to size of training data).
#' @param run.name Name of the run (without file ending). Name will be used to identify output from callbacks.
#' @param batch.size Number of samples that are used for one network update.
#' @param epochs Number of epochs (passed as \code{epochs} to the keras fit function).
#' @param max.queue.size Maximum size of the generator queue passed to \code{fit_generator()}.
#' @param lr.plateau.factor Factor of decreasing learning rate when plateau is reached.
#' @param patience Number of epochs waiting for decrease in loss before reducing learning rate.
#' @param cooldown Number of epochs without changing learning rate.
#' @param steps.per.epoch Number of batches to finish one epoch.
#' @param step Frequency of sampling steps.
#' @param randomFiles TRUE/FALSE go through files sequentially or shuffle beforehand.
#' @param vocabulary Vector of allowed characters, characters outside the vocabulary get encoded as 0-vector.
#' @param initial_epoch Epoch at which to start training, set to 0 if no \code{model_path} argument is given. If \code{NULL} and
#' \code{model_path} is given, it is extracted from the file name of \code{model_path} (pattern \code{"Ep.<number>"}); if no such
#' pattern is found, 0 is used and a warning is issued. Note that network will run for (\code{epochs} - \code{initial_epoch}) rounds and not \code{epochs} rounds.
#' @param tensorboard.log Path to tensorboard log directory. Only needed when tensorboard output is enabled.
#' @param save_best_only Only save model that improved on best val_loss score.
#' @param compile Whether to compile the model after loading.
#' @param solver Optimization method, options are "adam", "adagrad", "rmsprop" or "sgd". Only used (and then required) when pretrained model is given
#' (\code{model_path} is not NULL) and compile is FALSE. Otherwise solver is determined when model is created.
#' @param learning.rate Learning rate for optimizer. Only used when pretrained model is given (\code{model_path} is not NULL) and compile is FALSE.
#' If \code{NULL}, the optimizer's default learning rate is used. Otherwise learning rate is determined when model is created.
#' @param max_iter Stop after max_iter number of iterations failed to produce new sample.
#' @param seed Sets seed for set.seed function, for reproducible results when using \code{randomFiles} or \code{shuffleFastaEntries}.
#' First entry is used for the training generator, second entry for the validation generator.
#' @param shuffleFastaEntries Logical, shuffle entries in file.
#' @param output List of optional outputs with logical entries \code{none}, \code{checkpoints}, \code{tensorboard}, \code{log},
#' \code{serialize_model} and \code{full_model}. No output if \code{none} is TRUE. Entries left out are taken from the default.
#' @param format File format, "fasta" or "fastq".
#' @param fileLog Write name of files to csv file if path is specified.
#' @param labelVocabulary Character vector of possible targets. Targets outside \code{labelVocabulary} will get discarded.
#' @param numberOfFiles Use only specified number of files, ignored if greater than number of files in corpus.dir.
#' Only used when \code{train_type} is \code{label_folder}.
#' @param reverseComplements Logical, half of batch contains sequences and other its reverse complements. Reverse complement
#' is given by reversed order of sequence and switching A/T and C/G. \code{batch.size} argument has to be even, otherwise 1 will be added
#' to \code{batch.size}
#' @return The training history returned by \code{keras::fit_generator()} (generator input) or \code{keras::fit()}
#' (\code{dataset} input).
#' @export
trainNetwork <- function(train_type = "lm",
                         model_path = NULL,
                         model = NULL,
                         path = NULL,
                         path.val = NULL,
                         dataset = NULL,
                         checkpoint_path,
                         validation.split = 0.2,
                         run.name = "run",
                         batch.size = 64,
                         epochs = 10,
                         max.queue.size = 100,
                         lr.plateau.factor = 0.9,
                         patience = 5,
                         cooldown = 5,
                         steps.per.epoch = 1000,
                         step = 1,
                         randomFiles = FALSE,
                         initial_epoch = NULL,
                         vocabulary = c("a", "c", "g", "t"),
                         tensorboard.log,
                         save_best_only = TRUE,
                         compile = TRUE,
                         learning.rate = NULL,
                         solver = NULL,
                         max_iter = 1000,
                         seed = c(1234, 4321),
                         shuffleFastaEntries = FALSE,
                         output = list(none = FALSE,
                                       checkpoints =TRUE,
                                       tensorboard = TRUE,
                                       log = TRUE,
                                       serialize_model = TRUE,
                                       full_model = TRUE
                         ),
                         format = "fasta",
                         fileLog = NULL,
                         labelVocabulary = NULL,
                         numberOfFiles = NULL,
                         reverseComplements = FALSE) {

  stopifnot(length(train_type) == 1,
            train_type %in% c("lm", "label_header", "label_folder"))
  # label_header and label_folder are both label generators; label_folder
  # additionally reads the class from the folder instead of the header
  labelGen <- train_type != "lm"
  labelByFolder <- train_type == "label_folder"
  if (train_type == "label_header") {
    stopifnot(!is.null(labelVocabulary))
  }

  # Fill in output flags the caller left out, so a partial list such as
  # list(none = TRUE) works instead of failing later on `if (NULL)`.
  output <- utils::modifyList(
    list(none = FALSE, checkpoints = TRUE, tensorboard = TRUE, log = TRUE,
         serialize_model = TRUE, full_model = TRUE),
    output
  )
  if (isTRUE(output$none)) {
    output[c("checkpoints", "tensorboard", "log", "serialize_model", "full_model")] <- FALSE
  }

  if (is.null(model) && is.null(model_path)) {
    stop("Either 'model' or 'model_path' has to be specified.", call. = FALSE)
  }

  # extract maxlen from model (re-extracted below if the model is loaded from model_path)
  maxlen <- model$input$shape[1]

  if (labelByFolder && length(path) == 1) {
    warning("Training with just one label")
  }

  if (output$checkpoints) {
    ## create folder for checkpoints using run.name
    ## filenames contain epoch, validation loss and validation accuracy
    checkpoint_dir <- file.path(checkpoint_path, paste0(run.name, "_checkpoints"))
    dir.create(checkpoint_dir, showWarnings = FALSE, recursive = TRUE)
    filepath_checkpoints <- file.path(checkpoint_dir, "Ep.{epoch:03d}-val_loss{val_loss:.2f}-val_acc{val_acc:.3f}.hdf5")
  }

  # Check if run.name is unique; `&&` ensures tensorboard.log is only
  # evaluated (it has no default) when tensorboard output is requested.
  if (output$tensorboard && dir.exists(file.path(tensorboard.log, run.name))) {
    stop(paste0("Tensorboard entry '", run.name , "' is already present. Please give your run a unique name."))
  }

  # Load pretrained model
  if (!is.null(model_path)) {
    # extract initial_epoch from checkpoint file name (e.g. "Ep.005-...") if no argument is given
    if (is.null(initial_epoch)) {
      epochFromFilename <- stringr::str_extract(basename(model_path), "Ep\\.\\d+")
      initial_epoch <- as.integer(substring(epochFromFilename, 4, nchar(epochFromFilename)))
      if (is.na(initial_epoch)) {
        warning("Could not extract initial_epoch from model_path, using initial_epoch = 0.", call. = FALSE)
        initial_epoch <- 0L
      }
    }
    # epochs arguments can be misleading
    if (initial_epoch >= epochs) {
      stop("Networks trains (epochs - initial_epochs) rounds overall, NOT epochs rounds. Increase epochs or decrease initial_epoch.")
    }

    # load model
    model <- keras::load_model_hdf5(model_path, compile = compile)
    summary(model)

    # extract maxlen
    maxlen <- model$input$shape[1]

    if (compile && (!is.null(learning.rate) || !is.null(solver))) {
      message("Arguments for solver and learning rate will be ignored. Set compile to FALSE to use custom solver and learning rate.")
    }

    if (!compile) {
      if (is.null(solver)) {
        stop("Argument 'solver' has to be specified when compile is FALSE.", call. = FALSE)
      }
      # choose optimization method
      optimizer_fun <- switch(solver,
                              adam = keras::optimizer_adam,
                              adagrad = keras::optimizer_adagrad,
                              rmsprop = keras::optimizer_rmsprop,
                              sgd = keras::optimizer_sgd,
                              stop("Unknown solver '", solver,
                                   "', choose one of \"adam\", \"adagrad\", \"rmsprop\" or \"sgd\".",
                                   call. = FALSE))
      # fall back to the optimizer's own default learning rate when none is given
      optimizer_args <- if (is.null(learning.rate)) list() else list(lr = learning.rate)
      optimizer <- do.call(optimizer_fun, optimizer_args)

      model %>% keras::compile(loss = "categorical_crossentropy",
                               optimizer = optimizer, metrics = c("acc"))
    }
  } else {
    initial_epoch <- 0
  }

  if (output$tensorboard) {
    hp <- reticulate::import("tensorboard.plugins.hparams.api")

    model_hparam <- get_hyper_param(model)

    # list of hyperparameters
    hparams <- reticulate::dict(
      HP_VOCABULARY = paste(vocabulary, collapse = ","),
      HP_PATH = paste(path, collapse = ", "),
      HP_REVERSE_COMP = reverseComplements,
      HP_LABEL.VOC = paste(labelVocabulary, collapse = ", "),
      HP_LAYER.SIZE =  model_hparam$HP_LAYER.SIZE,
      HP_OPTIMIZER = model_hparam$HP_OPTIMIZER,
      HP_MAXLEN = maxlen,
      HP_USE.CUDNN = model_hparam$HP_USE.CUDNN,
      HP_EPOCHS = epochs,
      HP_MAX.QUEUE.SIZE = max.queue.size,
      HP_LR.PLATEAU.FACTOR = lr.plateau.factor,
      HP_NUM_LAYERS = model_hparam$HP_NUM_LAYERS,
      HP_BATCH.SIZE = batch.size,
      HP_LEARNING.RATE = model_hparam$HP_LEARNING.RATE,
      HP_DROPOUT = model_hparam$HP_DROPOUT,
      HP_USE.CODON.CNN = model_hparam$HP_USE.CODON.CNN,
      HP_PATIENCE = patience,
      HP_COOLDOWN = cooldown,
      HP_SPEPS.PER.EPOCHE = steps.per.epoch,
      HP_STEP = step,
      HP_RANDOM.FILES = randomFiles,
      HP_BIDIRECTIONAL = model_hparam$HP_BIDIRECTIONAL
    )
  }

  # if no dataset is supplied, external fasta generator will generate batches
  if (is.null(dataset)) {
    message("Starting fasta generator...")
    if (!labelGen) {

      # generator for training
      gen <- fastaFileGenerator(corpus.dir = path, batch.size = batch.size,
                                maxlen = maxlen, step = step, randomFiles = randomFiles,
                                vocabulary = vocabulary, max_iter = max_iter, seed = seed[1],
                                shuffleFastaEntries = shuffleFastaEntries, format = format,
                                fileLog = fileLog, reverseComplements = reverseComplements)

      # generator for validation
      gen.val <- fastaFileGenerator(corpus.dir = path.val, batch.size = batch.size,
                                    maxlen = maxlen, step = step, randomFiles = randomFiles,
                                    vocabulary = vocabulary, max_iter = max_iter, seed = seed[2],
                                    shuffleFastaEntries = shuffleFastaEntries, format = format,
                                    fileLog = NULL, reverseComplements = FALSE)
      # label generator
    } else if (labelByFolder) {

      # initialize training generators
      initializeGenerators(directories = path,
                           format = format,
                           batch.size = batch.size,
                           maxlen = maxlen,
                           max_iter = max_iter,
                           vocabulary = vocabulary,
                           verbose = FALSE,
                           randomFiles = randomFiles,
                           step = step,
                           showWarnings = FALSE,
                           seed = seed[1],
                           shuffleFastaEntries = shuffleFastaEntries,
                           numberOfFiles = numberOfFiles,
                           fileLog = fileLog,
                           reverseComplements = reverseComplements,
                           val = FALSE)

      # initialize validation generators
      initializeGenerators(directories = path.val,
                           format = format,
                           batch.size = batch.size,
                           maxlen = maxlen,
                           max_iter = max_iter,
                           vocabulary = vocabulary,
                           verbose = FALSE,
                           randomFiles = randomFiles,
                           step = step,
                           showWarnings = FALSE,
                           seed = seed[2],
                           shuffleFastaEntries = shuffleFastaEntries,
                           numberOfFiles = NULL,
                           fileLog = fileLog,
                           reverseComplements = FALSE,
                           val = TRUE)
      gen <- labelByFolderGeneratorWrapper(val = FALSE, path = path)
      gen.val <- labelByFolderGeneratorWrapper(val = TRUE, path = path.val)

    } else {

      # generator for training (labels from fasta headers)
      gen <- fastaLabelGenerator(corpus.dir = path,
                                 format = format,
                                 batch.size = batch.size,
                                 maxlen = maxlen,
                                 max_iter = max_iter,
                                 vocabulary = vocabulary,
                                 verbose = FALSE,
                                 randomFiles = randomFiles,
                                 step = step,
                                 showWarnings = FALSE,
                                 seed = seed[1],
                                 shuffleFastaEntries = shuffleFastaEntries,
                                 fileLog = fileLog,
                                 labelVocabulary = labelVocabulary,
                                 reverseComplements = reverseComplements
      )

      # generator for validation (labels from fasta headers)
      gen.val <- fastaLabelGenerator(corpus.dir = path.val,
                                     format = format,
                                     batch.size = batch.size,
                                     maxlen = maxlen,
                                     max_iter = max_iter,
                                     vocabulary = vocabulary,
                                     verbose = FALSE,
                                     randomFiles = randomFiles,
                                     step = step,
                                     showWarnings = FALSE,
                                     seed = seed[2],
                                     shuffleFastaEntries = shuffleFastaEntries,
                                     fileLog = NULL,
                                     labelVocabulary = labelVocabulary,
                                     reverseComplements = FALSE
      )
    }

    # callback list; optional callbacks are appended at the end so no
    # callback can overwrite another one
    callbacks <- list(keras::callback_reduce_lr_on_plateau(
      monitor = "loss",
      factor = lr.plateau.factor,
      patience = patience,
      cooldown = cooldown
    ))

    if (output$checkpoints) {
      callbacks[[length(callbacks) + 1]] <- keras::callback_model_checkpoint(filepath = filepath_checkpoints,
                                                                             save_weights_only = FALSE,
                                                                             save_best_only = save_best_only,
                                                                             verbose = 1)
    }

    if (output$tensorboard) {
      callbacks[[length(callbacks) + 1]] <- keras::callback_tensorboard(file.path(tensorboard.log, run.name),
                                                                        write_graph = TRUE,
                                                                        histogram_freq = 1,
                                                                        write_images = TRUE,
                                                                        write_grads = TRUE)
      # log hparams
      callbacks[[length(callbacks) + 1]] <- hp$KerasCallback(file.path(tensorboard.log, run.name), hparams, trial_id = run.name)

      # create string with function arguments, e.g.
      # c("trainNetwork(", "path = \"x\" ,", ..., "epochs = 5)")
      argumentList <- as.list(match.call(expand.dots = FALSE))
      argAsChar <- as.character(argumentList)
      argNames <- names(argumentList)
      argsInQuotes <- c("model_path", "path", "path.val", "checkpoint_path", "run.name", "solver",
                        "tensorboard.log", "fileLog", "train_type")
      n_args <- length(argumentList)
      argText <- if (n_args > 1) "trainNetwork(" else "trainNetwork()"
      # element 1 of argumentList is the function name, arguments follow
      for (i in seq_len(n_args)[-1]) {
        value <- argAsChar[[i]]
        if (argNames[i] %in% argsInQuotes) {
          value <- paste0('"', value, '"')
        }
        # every argument ends with " ," except the last one, which closes the call
        closing <- if (i == n_args) ")" else " ,"
        argText[i] <- paste0(argNames[i], " = ", value, closing)
      }

      # write function arguments as text in tensorboard
      trainNetworkArguments <- keras::callback_lambda(
        on_train_begin = function(logs) {
          file.writer <- tensorflow::tf$summary$create_file_writer(file.path(tensorboard.log, run.name))
          file.writer$set_as_default()
          tensorflow::tf$summary$text(name = "Arguments",  data = argText, step = 0L)
          file.writer$flush()
        }
      )
      callbacks[[length(callbacks) + 1]] <- trainNetworkArguments

      # confusion matrix over validation batches, written to tensorboard as an image
      confMat <- keras::callback_lambda(
        on_epoch_end = function(epoch, logs) {
          n_val_steps <- ceiling(steps.per.epoch * validation.split)
          # nothing to evaluate when no validation steps are configured
          if (n_val_steps < 1) {
            return(invisible(NULL))
          }

          file.writer <- tensorflow::tf$summary$create_file_writer(file.path(tensorboard.log, run.name))
          file.writer$set_as_default()

          # labelGen is TRUE for both label_header and label_folder training
          if (labelGen) {
            confMatLabels <- labelVocabulary
          } else {
            confMatLabels <- vocabulary
          }

          # collect 0-based true and predicted class indices over ALL validation
          # batches (these calls consume batches from gen.val)
          y_true_batches <- vector("list", n_val_steps)
          y_pred_batches <- vector("list", n_val_steps)
          for (j in seq_len(n_val_steps)) {
            z <- gen.val()
            y_true_batches[[j]] <- apply(z[[2]], 1, which.max) - 1
            y_pred_batches[[j]] <- keras::predict_classes(model, z[[1]])
          }

          class_ids <- seq_along(confMatLabels) - 1
          df_true_pred <- data.frame(
            true = factor(unlist(y_true_batches), levels = class_ids, labels = confMatLabels),
            pred = factor(unlist(y_pred_batches), levels = class_ids, labels = confMatLabels)
          )

          cm <- yardstick::conf_mat(df_true_pred, true, pred)
          suppressMessages(
            cm_plot <- ggplot2::autoplot(cm, type = "heatmap") +
              ggplot2::scale_fill_gradient(low="#D6EAF8", high = "#2E86C1")
          )

          plot_path <- paste0(getwd(), "/", run.name, ".png")
          suppressMessages(ggplot2::ggsave(filename = plot_path, plot = cm_plot))

          # convert saved image to array
          np <- reticulate::import("numpy", convert = FALSE)
          # python module pillow needs to be installed
          PIL <- reticulate::import("PIL", convert = FALSE)
          im <- np$asarray(PIL$Image$open(plot_path))
          im_R <- reticulate::py_to_r(im)
          # add a leading batch dimension of 1 for tf$summary$image
          im_R <- array(im_R, dim = c(1, dim(im_R)))

          tensorflow::tf$summary$image(name = "confusion matrix", data = im_R/255, step = epoch)
          file.writer$flush()
        }
      )
      callbacks[[length(callbacks) + 1]] <- confMat
    }

    if (output$log) {
      callbacks[[length(callbacks) + 1]] <- keras::callback_csv_logger(
        paste0(run.name, "_log.csv"),
        separator = ";",
        append = TRUE)
    }

    # training
    message("Start training ...")
    history <-
      model %>% keras::fit_generator(
        generator = gen,
        validation_data = gen.val,
        validation_steps = ceiling(steps.per.epoch * validation.split),
        steps_per_epoch = steps.per.epoch,
        max_queue_size = max.queue.size,
        epochs = epochs,
        initial_epoch = initial_epoch,
        callbacks = callbacks
      )
  } else {
    message("Start training ...")
    history <- model %>% keras::fit(
      dataset$X,
      dataset$Y,
      batch_size = batch.size,
      validation_split = validation.split,
      epochs = epochs)
  }

  # save final model
  message("Training done.\nSave model.")

  if (output$serialize_model) {
    Rmodel <-
      keras::serialize_model(model, include_optimizer = TRUE)
    save(Rmodel, file = paste0(run.name, "_full_model.Rdata"))
  }

  if (output$full_model) {
    keras::save_model_hdf5(
      model,
      paste0(run.name, "_full_model.hdf5"),
      overwrite = TRUE,
      include_optimizer = TRUE
    )
  }
  return(history)
}
hiddengenome/altum documentation built on April 22, 2020, 9:33 p.m.