R/train_cpc.R

Defines functions train_model_cpc

Documented in train_model_cpc

#' @title Train CPC inspired model
#'   
#' @description
#' Train a CPC (Oord et al.) inspired neural network on genomic data.
#' 
#' @inheritParams generator_fasta_lm
#' @inheritParams generator_fasta_label_folder
#' @inheritParams generator_fasta_label_header_csv
#' @inheritParams train_model
#' @param train_type Either `"CPC"` or `"Self-GenomeNet"`. Passed on to the CPC loss construction.
#' @param encoder A keras encoder for the cpc function. Called as \code{encoder(maxlen = , patchlen = , nopatches = )}
#' and must return a keras model.
#' @param context A keras context model for the cpc function.
#' @param path Path to training data. Can be a single directory or file, or a vector/list of directories and/or files.
#' @param path_val Path to validation data. See `path` argument for details. Ignored when \code{train_val_split_csv}
#' is given; validation files are then taken from `path`.
#' @param path_checkpoint Path to checkpoints folder or `NULL`. If `NULL`, checkpoints don't get stored.
#' A sub-folder named after the run is created inside this folder.
#' @param path_tensorboard Path to tensorboard directory or `NULL`. If `NULL`, training not tracked on tensorboard.
#' @param path_file_log Currently ignored. If \code{path_checkpoint} is given, the file log is written to
#' \code{filelog.csv} inside the run folder; otherwise no file log is written.
#' @param train_val_ratio Number of validation batches per epoch as a fraction of \code{steps_per_epoch}
#' (rounded up, at least one batch).
#' @param run_name Name of the run. Name will be used to identify output from callbacks.
#' @param batch_size Number of samples used for one network update.
#' @param epochs Number of epochs to train.
#' @param steps_per_epoch Number of training batches per epoch.
#' @param shuffle_file_order Boolean, whether to go through files sequentially or shuffle beforehand.
#' @param initial_epoch Number of the first epoch. Training runs for \code{epochs} epochs, numbered
#' \code{initial_epoch} to \code{initial_epoch + epochs - 1}.
#' @param seed Sets seed for reproducible results.
#' @param file_limit Integer or `NULL`. If integer, use only specified number of randomly sampled files for training. Ignored if greater than number of files in \code{path}.
#' @param patchlen The length of a patch when splitting the input sequence.
#' @param nopatches The number of patches when splitting the input sequence. 
#' @param step Frequency of sampling steps. Defaults to \code{maxlen} (non-overlapping samples).
#' @param stride Distance between the starts of two consecutive patches, as a fraction of \code{patchlen}.
#' Currently ignored: the stride is fixed to 0.4 internally.
#' @param pretrained_model Path to a pretrained keras model (hdf5 file) for which training will be continued.
#' Model settings are restored from the \code{modelspecs.rds} file in the same directory.
#' @param learningrate A Tensor, floating point value. If a schedule is defined, this value gives the initial learning rate. Defaults to 0.001.
#' @param learningrate_schedule A schedule for a non-constant learning rate over the training. Either "cosine_annealing", "step_decay", or "exp_decay".
#' @param k Value of k for sparse top k categorical accuracy. Defaults to 5.
#' @param stepsmin In CPC, a patch is predicted given another patch. stepsmin defines how many patches between these two should be ignored during prediction.
#' @param stepsmax The maximum distance between the predicted patch and the given patch.
#' @param emb_scale Scales the impact of a patches context.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' 
#' #create dummy data
#' path_train_1 <- tempfile()
#' path_train_2 <- tempfile()
#' path_val_1 <- tempfile()
#' path_val_2 <- tempfile()
#' 
#' for (current_path in c(path_train_1, path_train_2,
#'                        path_val_1, path_val_2)) {
#'   dir.create(current_path)
#'   deepG::create_dummy_data(file_path = current_path,
#'                            num_files = 3,
#'                            seq_length = 10,
#'                            num_seq = 5,
#'                            vocabulary = c("a", "c", "g", "t"))
#' }
#' 
#' # create model
#' encoder <- function(maxlen = NULL,
#'                     patchlen = NULL,
#'                     nopatches = NULL,
#'                     eval = FALSE) {
#'   if (is.null(nopatches)) {
#'     nopatches <- nopatchescalc(patchlen, maxlen, patchlen * 0.4)
#'   }
#'   inp <- keras::layer_input(shape = c(maxlen, 4))
#'   stridelen <- as.integer(0.4 * patchlen)
#'   createpatches <- inp %>%
#'     keras::layer_reshape(list(maxlen, 4L, 1L), name = "prep_reshape1", dtype = "float32") %>%
#'     tensorflow::tf$image$extract_patches(
#'       sizes = list(1L, patchlen, 4L, 1L),
#'       strides = list(1L, stridelen, 4L, 1L),
#'       rates = list(1L, 1L, 1L, 1L),
#'       padding = "VALID",
#'       name = "prep_patches"
#'     ) %>%
#'     keras::layer_reshape(list(nopatches, patchlen, 4L),
#'                          name = "prep_reshape2") %>%
#'     tensorflow::tf$reshape(list(-1L, patchlen, 4L),
#'                            name = "prep_reshape3")
#' 
#'   danQ <- createpatches %>%
#'     keras::layer_conv_1d(
#'       input_shape = c(maxlen, 4L),
#'       filters = 320L,
#'       kernel_size = 26L,
#'       activation = "relu"
#'     ) %>%
#'     keras::layer_max_pooling_1d(pool_size = 13L, strides = 13L) %>%
#'     keras::layer_dropout(0.2) %>%
#'     keras::layer_lstm(units = 320, return_sequences = TRUE) %>%
#'     keras::layer_dropout(0.5) %>%
#'     keras::layer_flatten() %>%
#'     keras::layer_dense(925, activation = "relu")
#'   patchesback <- danQ %>%
#'     tensorflow::tf$reshape(list(-1L, tensorflow::tf$cast(nopatches, tensorflow::tf$int16), 925L))
#'   keras::keras_model(inp, patchesback)
#' }
#' 
#' context <- function(latents) {
#'   cres <- latents
#'   cres_dim = cres$shape
#'   predictions <-
#'     cres %>%
#'     keras::layer_lstm(
#'       return_sequences = TRUE,
#'       units = 256,  # WAS: 2048,
#'       name = paste("context_LSTM_1",
#'                    sep = ""),
#'       activation = "relu"
#'     )
#'   return(predictions)
#' }
#' 
#' # train model
#' temp_dir <- tempdir()
#' hist <- train_model_cpc(train_type = "CPC",
#'                         ### cpc functions ###
#'                         encoder = encoder,
#'                         context = context,
#'                         #### Generator settings ####
#'                         path_checkpoint = temp_dir,
#'                         path = c(path_train_1, path_train_2),
#'                         path_val = c(path_val_1, path_val_2),
#'                         run_name = "TEST",
#'                         batch_size = 8,
#'                         epochs = 3,
#'                         steps_per_epoch = 6,
#'                         patchlen = 100,
#'                         nopatches = 8)
#'                 
#'  
#' @returns A `keras_training_history` object (returned invisibly) holding the per-epoch
#' training and validation loss and accuracy.
#' @export
train_model_cpc <-
  function(train_type = "CPC",
           ### cpc functions ###
           encoder = NULL,
           context = NULL,
           #### Generator settings ####
           path,
           path_val = NULL,
           path_checkpoint = NULL,
           path_tensorboard = NULL,
           train_val_ratio = 0.2,
           run_name,
           
           batch_size = 32,
           epochs = 100,
           steps_per_epoch = 2000,
           shuffle_file_order = FALSE,
           initial_epoch = 1,
           seed = 1234,
           
           path_file_log = TRUE,
           train_val_split_csv = NULL,
           file_limit = NULL,
           proportion_per_seq = NULL,
           max_samples = NULL,
           maxlen = NULL,
           
           patchlen = NULL,
           nopatches = NULL,
           step = NULL,
           file_filter = NULL,
           stride = 0.4,
           pretrained_model = NULL,
           learningrate = 0.001,
           learningrate_schedule = NULL,
           k = 5,
           stepsmin = 2,
           stepsmax = 3,
           emb_scale = 0.1) {
    
    # Environment the function was called from. The argument expressions
    # captured by match.call() (written to tensorboard) must be evaluated
    # there, not in this function's own scope.
    call_env <- parent.frame()
    
    # Stride (distance between patch starts, as a fraction of patchlen) is
    # fixed to 0.4 FOR NOW; the `stride` argument is ignored.
    stride <- 0.4
    
    ########################################################################################################
    ############################### Warning messages if wrong initialization ###############################
    ########################################################################################################
    
    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Model specification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    ## Three options:
    ## 1. Define maxlen and patchlen          -> nopatches is calculated
    ## 2. Define number of patches and patchlen -> maxlen is calculated
    ## 3. Pretrained model is giving specs (read from its modelspecs.rds)
    ## Error if none of those is fulfilled.
    
    if (is.null(pretrained_model)) {
      ## If no pretrained model, patchlen has to be defined. Check BEFORE
      ## coercing: as.integer(NULL) is integer(0), not NULL.
      if (is.null(patchlen)) {
        stop("Please define patchlen")
      }
      patchlen <- as.integer(patchlen)
      
      ## Either maxlen or number of patches is needed; the missing one is derived
      if (is.null(maxlen) && is.null(nopatches)) {
        stop("Please define either maxlen or nopatches")
      } else if (is.null(maxlen)) {
        maxlen <- (nopatches - 1) * (stride * patchlen) + patchlen
      } else if (is.null(nopatches)) {
        nopatches <-
          as.integer((maxlen - patchlen) / (stride * patchlen) + 1)
      }
      ## If step is not defined, we do not use overlapping sequences
      if (is.null(step)) {
        step <- maxlen
      }
    } else {
      ## Restore model settings stored next to the pretrained model file
      specs <-
        readRDS(file.path(dirname(pretrained_model), "modelspecs.rds"))
      patchlen          <- specs$patchlen
      maxlen            <- specs$maxlen
      nopatches         <- specs$nopatches
      stride            <- specs$stride
      step              <- specs$step
      k                 <- specs$k
      emb_scale         <- specs$emb_scale
    }
    
    
    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Learning rate schedule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    ## If learning_rate schedule is wanted, all necessary parameters must be given
    LRstop(learningrate_schedule)
    ########################################################################################################
    #################################### Preparation: Data, paths metrics ##################################
    ########################################################################################################
    
    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Path definition ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    runname <-
      paste0(run_name, format(Sys.time(), "_%y%m%d_%H%M%S"))
    
    ## Create a per-run folder for model checkpoints, configs and the file log
    if (!is.null(path_checkpoint)) {
      model_dir <- file.path(path_checkpoint, runname)
      dir.create(model_dir, recursive = TRUE)
      path_file_log <- file.path(model_dir, "filelog.csv")
    } else {
      path_file_log <- NULL
    }
    
    # train/validation split via csv file
    if (!is.null(train_val_split_csv)) {
      if (is.null(path_val)) {
        path_val <- path
      } else {
        if (!all(unlist(path_val) %in% unlist(path))) {
          warning("Train/validation split done via file in train_val_split_csv. Only using files from path argument.")
        }
        path_val <- path
      }
      
      train_val_file <- utils::read.csv2(train_val_split_csv, header = TRUE, stringsAsFactors = FALSE)
      # Single column means the file is comma- rather than semicolon-separated
      if (dim(train_val_file)[2] == 1) {
        train_val_file <- utils::read.csv(train_val_split_csv, header = TRUE, stringsAsFactors = FALSE)
      }
      train_val_file <- dplyr::distinct(train_val_file)
      
      if (!all(c("file", "type") %in% names(train_val_file))) {
        stop("Column names of train_val_split_csv file must be 'file' and 'type'")
      }
      
      if (length(train_val_file$file) != length(unique(train_val_file$file))) {
        stop("In train_val_split_csv all entires in 'file' column must be unique")
      }
      
      ## file_filter[[1]]: training files, file_filter[[2]]: validation files
      file_filter <- list()
      file_filter[[1]] <- train_val_file %>% dplyr::filter(type == "train")
      file_filter[[1]] <- as.character(file_filter[[1]]$file)
      file_filter[[2]] <- train_val_file %>% dplyr::filter(type == "val" | type == "validation")
      file_filter[[2]] <- as.character(file_filter[[2]]$file)
    }
    
    ## Generator configurations. Built AFTER the csv split above so that a
    ## path_val reassigned there is used by the validation generator.
    GenConfig <-
      GenParams(maxlen, batch_size, step, proportion_per_seq, max_samples)
    GenTConfig <-
      GenTParams(path, shuffle_file_order, path_file_log, seed)
    GenVConfig <- GenVParams(path_val, shuffle_file_order)
    
    
    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ File count ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    if (is.null(file_filter) && is.null(train_val_split_csv)) {
      if (is.null(file_limit)) {
        if (is.list(path)) {
          num_files <- 0
          for (i in seq_along(path)) {
            num_files <- num_files + length(list.files(path[[i]]))
          }
        } else {
          num_files <- length(list.files(path))
        }
      } else {
        num_files <- file_limit * length(path)
      }
    } else {
      # Count the training files themselves: file_filter[1] would be a
      # one-element list and always have length 1.
      num_files <- length(file_filter[[1]])
    }
    
    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Creation of generators ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    message(format(Sys.time(), "%F %R"), ": Preparing the data\n")
    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Training Generator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    fastrain <-
      do.call(generator_fasta_lm,
              c(GenConfig, GenTConfig, file_filter = file_filter[1]))
    
    
    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Validation Generator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    fasval <-
      do.call(
        generator_fasta_lm,
        c(
          GenConfig,
          GenVConfig,
          seed = seed,
          file_filter = file_filter[2]
        )
      )
    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Creation of metrics ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    message(format(Sys.time(), "%F %R"), ": Preparing the metrics\n")
    train_loss <- tensorflow::tf$keras$metrics$Mean(name = 'train_loss')
    val_loss <- tensorflow::tf$keras$metrics$Mean(name = 'val_loss')
    train_acc <- tensorflow::tf$keras$metrics$Mean(name = 'train_acc')
    val_acc <- tensorflow::tf$keras$metrics$Mean(name = 'val_acc')
    
    ########################################################################################################
    ###################################### History object preparation ######################################
    ########################################################################################################
    
    history <- list(
      params = list(
        batch_size = batch_size,
        epochs = 0,
        steps = steps_per_epoch,
        samples = steps_per_epoch * batch_size,
        verbose = 1,
        do_validation = TRUE,
        metrics = c("loss", "accuracy", "val_loss", "val_accuracy")
      ),
      metrics = list(
        loss = c(),
        accuracy = c(),
        val_loss = c(),
        val_accuracy = c()
      )
    )
    
    ## Per-epoch validation loss/accuracy, indexed by epoch number
    eploss <- list()
    epacc <- list()
    
    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Reformat to S3 object ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    class(history) <- "keras_training_history"
    
    ########################################################################################################
    ############################################ Model creation ############################################
    ########################################################################################################
    if (is.null(pretrained_model)) {
      ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Unsupervised Build from scratch ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
      message(format(Sys.time(), "%F %R"), ": Creating the model\n")
      ## Build encoder
      enc <-
        encoder(maxlen = maxlen,
                patchlen = patchlen,
                nopatches = nopatches)
      
      ## Build model: encoder input -> CPC loss output
      model <-
        keras::keras_model(
          enc$input,
          cpcloss(
            enc$output,
            context,
            batch_size = batch_size,
            steps_to_ignore = stepsmin,
            steps_to_predict = stepsmax,
            train_type = train_type,
            k = k,
            emb_scale = emb_scale
          )
        )
      
      ## Build optimizer
      optimizer <- # keras::optimizer_adam(
        tensorflow::tf$keras$optimizers$legacy$Adam(
          learning_rate = learningrate,
          beta_1 = 0.8,
          epsilon = 10 ^ -8,
          decay = 0.999,
          clipnorm = 0.01
        )
      
    } else {
      ####~~~~~~~~~~~~~~~~~~~~~~~~~~ Unsupervised Read if pretrained model given ~~~~~~~~~~~~~~~~~~~~~~~~~####
      message(format(Sys.time(), "%F %R"), ": Loading the trained model.\n")
      ## Read model
      model <- keras::load_model_hdf5(pretrained_model, compile = FALSE)
      optimizer <- ReadOpt(pretrained_model)
      optimizer$learning_rate$assign(learningrate)
    }
    
    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving necessary model objects ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    ## Optimizer configuration and model parameters (the latter are read back
    ## when this run is later used as `pretrained_model`)
    if (!is.null(path_checkpoint)) {
      saveRDS(optimizer$get_config(),
              file.path(model_dir, "optconfig.rds"))
      saveRDS(
        list(
          maxlen = maxlen,
          patchlen = patchlen,
          stride = stride,
          nopatches = nopatches,
          step = step,
          batch_size = batch_size,
          epochs = epochs,
          steps_per_epoch = steps_per_epoch,
          train_val_ratio = train_val_ratio,
          max_samples = max_samples,
          k = k,
          emb_scale = emb_scale,
          learningrate = learningrate
        ),
        file.path(model_dir, "modelspecs.rds")
      )
    }
    ########################################################################################################
    ######################################## Tensorboard connection ########################################
    ########################################################################################################
    
    if (!is.null(path_tensorboard)) {
      ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Initialize Tensorboard writers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
      logdir <- path_tensorboard
      writertrain <-
        tensorflow::tf$summary$create_file_writer(file.path(logdir, runname, "train"))
      writerval <-
        tensorflow::tf$summary$create_file_writer(file.path(logdir, runname, "validation"))
      
      ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Write parameters to Tensorboard ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
      ## Drops the function name and the first two supplied arguments
      ## (presumably meant to skip encoder/context -- TODO confirm). Short
      ## values are written as their value, long ones as their expression.
      tftext <-
        lapply(as.list(match.call())[-1][-c(1, 2)], function(x)
          ifelse(all(nchar(deparse(
            eval(x, envir = call_env)
          )) < 20) && !is.null(eval(x, envir = call_env)),
          eval(x, envir = call_env), deparse(x)))
      
      with(writertrain$as_default(), {
        tensorflow::tf$summary$text("Specification",
                                    paste(
                                      names(tftext),
                                      tftext,
                                      sep = " = ",
                                      collapse = "  \n"
                                    ),
                                    step = 0L)
      })
    }
    
    ########################################################################################################
    ######################################## Training loop function ########################################
    ########################################################################################################
    
    ## Runs one epoch: `batches` training steps followed by validation steps.
    ## Returns list(history, eploss, epacc) with this epoch's results added.
    train_val_loop <-
      function(batches = steps_per_epoch, epoch, train_val_ratio) {
        ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Start of loop ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
        for (i in c("train", "val")) {
          if (i == "val") {
            ## Validation batches: a fraction of the training batches, but at
            ## least one so the validation metrics of this epoch are defined.
            batches <- max(1, ceiling(batches * train_val_ratio))
          } else if (!is.null(learningrate_schedule)) {
            ## The scheduled learning rate depends only on the epoch, so it is
            ## set once before the training batches.
            optimizer$learning_rate$assign(getEpochLR(learningrate_schedule, epoch))
          }
          
          for (b in seq_len(batches)) {
            if (i == "train") {
              ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Optimization step ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
              #with(tensorflow::tf$GradientTape() %as% tape, {
              with(reticulate::`%as%`(tensorflow::tf$GradientTape(), tape), {
                
                out <-
                  modelstep(fastrain(),
                            model,
                            train_type,
                            TRUE)
                l <- out[1]
                acc <- out[2]
              })
              
              gradients <-
                tape$gradient(l, model$trainable_variables)
              optimizer$apply_gradients(purrr::transpose(list(
                gradients, model$trainable_variables
              )))
              train_loss(l)
              train_acc(acc)
              
            } else {
              ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Validation step ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
              out <-
                modelstep(fasval(),
                          model,
                          train_type,
                          FALSE)
              
              l <- out[1]
              acc <- out[2]
              val_loss(l)
              val_acc(acc)
              
            }
            
            ## Print status of epoch (roughly every 10% of the batches)
            if (b %in% seq(0, batches, by = batches / 10)) {
              message("-")
            }
          }
          
          ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ End of Epoch ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
          if (i == "train") {
            ## Training step
            # Write epoch result metrics value to tensorboard
            if (!is.null(path_tensorboard)) {
              TB_loss_acc(writertrain, train_loss, train_acc, epoch)
              with(writertrain$as_default(), {
                tensorflow::tf$summary$scalar('epoch_lr',
                                              optimizer$learning_rate,
                                              step = tensorflow::tf$cast(epoch, "int64"))
                # The file log only exists when checkpoints are stored
                if (!is.null(path_file_log) && file.exists(path_file_log)) {
                  tensorflow::tf$summary$scalar(
                    'training files seen',
                    nrow(
                      readr::read_csv(
                        path_file_log,
                        col_names = FALSE,
                        col_types = readr::cols()
                      )
                    ) / num_files,
                    step = tensorflow::tf$cast(epoch, "int64")
                  )
                }
              })
            }
            # Print epoch result metric values to console
            tensorflow::tf$print(" Train Loss",
                                 train_loss$result(),
                                 ", Train Acc",
                                 train_acc$result())
            
            # Save epoch result metric values to history object
            history$params$epochs <- epoch
            history$metrics$loss[epoch] <-
              as.double(train_loss$result())
            history$metrics$accuracy[epoch]  <-
              as.double(train_acc$result())
            
            # Reset states
            train_loss$reset_states()
            train_acc$reset_states()
            
          } else {
            ## Validation step
            # Write epoch result metrics value to tensorboard
            if (!is.null(path_tensorboard)) {
              TB_loss_acc(writerval, val_loss, val_acc, epoch)
            }
            
            # Print epoch result metric values to console
            tensorflow::tf$print(" Validation Loss",
                                 val_loss$result(),
                                 ", Validation Acc",
                                 val_acc$result())
            
            # Save results for best model saving condition
            eploss[[epoch]] <- as.double(val_loss$result())
            epacc[[epoch]] <- as.double(val_acc$result())
            
            # Save epoch result metric values to history object
            history$metrics$val_loss[epoch] <-
              as.double(val_loss$result())
            history$metrics$val_accuracy[epoch]  <-
              as.double(val_acc$result())
            
            # Reset states
            val_loss$reset_states()
            val_acc$reset_states()
          }
        }
        return(list(history, eploss, epacc))
      }
    
    ########################################################################################################
    ############################################# Training run #############################################
    ########################################################################################################
    
    
    message(format(Sys.time(), "%F %R"), ": Starting Training\n")
    
    ## Training loop: `epochs` epochs, numbered from initial_epoch
    for (i in seq(initial_epoch, (epochs + initial_epoch - 1))) {
      message(format(Sys.time(), "%F %R"), ": EPOCH ", i, " \n")
      
      ## Epoch loop
      out <- train_val_loop(epoch = i, train_val_ratio = train_val_ratio)
      history <- out[[1]]
      eploss <- out[[2]]
      epacc <- out[[3]]
      ## Save checkpoints
      # best model (smallest validation loss so far)
      if (eploss[[i]] == min(unlist(eploss))) {
        savechecks("best", runname, model, optimizer, history, path_checkpoint)
      }
      # backup model every 2 epochs
      if (i %% 2 == 0) {
        savechecks("backup", runname, model, optimizer, history, path_checkpoint)
      }
    }
    
    ########################################################################################################
    ############################################# Final saves ##############################################
    ########################################################################################################
    
    savechecks(cp = "FINAL", runname, model, optimizer, history, path_checkpoint)
    if (!is.null(path_tensorboard)) {
      writegraph <-
        tensorflow::tf$keras$callbacks$TensorBoard(file.path(logdir, runname))
      writegraph$set_model(model)
    }
    
    invisible(history)
  }
GenomeNet/deepG documentation built on Dec. 24, 2024, 12:11 p.m.