# R/DLmodel.R

#' DLmodel Class
#'
#' @docType class
#' @importFrom R6 R6Class
#'
#' @export
#' @keywords data
#'
#' @return Object of class \code{\link{R6Class}} and \code{DLmodel}.
#'
#' @format \code{\link{R6Class}} object.
#'
#' @examples
#' DLmodel$new()
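#'
#' # A sketch of a typical workflow; 'my_keras_model', 'train_x' and
#' # 'train_y' are placeholders, and keras must be available, so not run:
#' \dontrun{
#' model <- DLmodel$new(model = my_keras_model, width = 32)
#' model$use_data(use = "train", x_files = train_x, y_files = train_y)
#' model$fit(epochs = 5, keep_best = TRUE, path = tempdir(), prefix = "run1")
#' model$plot_history()
#' }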
#'
#' @section Methods:
#' \describe{
#'   \item{Documentation}{For full documentation of each method, follow the corresponding link.}
#'   \item{\code{initialize(...)}}{Create new object. Documented in \link{DLmodel.initialize}.}
#'   \item{\code{update(...)}}{Update fields of the object. Documented in \link{DLmodel.update}.}
#'   \item{\code{summary()}}{Model summary. Documented in \link{DLmodel.summary}.}
#'   \item{\code{log(level = c("DEBUG", "INFO", "WARNING", "ERROR"), message = '...')}}{Add a message to the object log. Documented in \link{DLmodel.log}.}
#'   \item{\code{add_to_history(epoch = 0, subepoch = 0, time = Sys.time(), loss = NA, val_loss = NA)}}{Adds a new entry to the loss history of the model. Documented in \link{DLmodel.add_to_history}.}
#'   \item{\code{plot_history()}}{Plots this object's loss history. Documented in \link{DLmodel.plot_history}.}
#'   \item{\code{reset_history()}}{Deletes the loss history of this object. Documented in \link{DLmodel.reset_history}.}
#'   \item{\code{render_history(initialize = TRUE)}}{Uses the RStudio Viewer pane to plot the loss history. Documented in \link{DLmodel.render_history}.}
#'   \item{\code{update_render()}}{Updates the rendered loss history in the Viewer pane. Documented in \link{DLmodel.update_render}.}
#'   \item{\code{get_model()}}{Returns the keras model. Documented in \link{DLmodel.get_model}.}
#'   \item{\code{get_width()}}{Returns the window width used to train the model. Documented in \link{DLmodel.get_width}.}
#'   \item{\code{set_width(width)}}{Sets window width for the model. Documented in \link{DLmodel.set_width}.}
#'   \item{\code{get_loss()}}{Returns the best loss achieved by the model, or \code{Inf} if it hasn't been trained yet. Documented in \link{DLmodel.get_loss}.}
#'   \item{\code{set_loss(loss)}}{Sets the loss of the model. Documented in \link{DLmodel.set_loss}.}
#'   \item{\code{get_encoder()}}{If the model is an autoencoder, returns the encoder part. Documented in \link{DLmodel.get_encoder}.}
#'   \item{\code{get_decoder()}}{If the model is an autoencoder, returns the decoder part. Documented in \link{DLmodel.get_decoder}.}
#'   \item{\code{get_config()}}{Returns the configuration of the model. From it, one can re-build the whole model. Documented in \link{DLmodel.get_config}.}
#'   \item{\code{get_history()}}{Returns the loss history as a data.frame. Documented in \link{DLmodel.get_history}.}
#'   \item{\code{print_log(level = c("DEBUG", "WARNING", "INFO", "ERROR"))}}{Prints this object's log. Documented in \link{DLmodel.print_log}.}
#'   \item{\code{errors()}}{Prints errors produced when using this object. Documented in \link{DLmodel.errors}.}
#'   \item{\code{warnings()}}{Prints warnings produced when using this object. Documented in \link{DLmodel.warnings}.}
#'   \item{\code{save_log(filename, level = c("DEBUG", "WARNING", "INFO", "ERROR"))}}{Saves log to a file. Documented in \link{DLmodel.save_log}.}
#'   \item{\code{check_memory()}}{Checks whether the model can be trained given current memory limits. Documented in \link{DLmodel.check_memory}.}
#'   \item{\code{graph()}}{Returns the graph of the model. Documented in \link{DLmodel.graph}.}
#'   \item{\code{plot(to_file)}}{Plots this model's graph. Documented in \link{DLmodel.plot}.}
#'   \item{\code{load(path, prefix)}}{Loads a model stored in a given path, with a given file prefix. Documented in \link{DLmodel.load}.}
#'   \item{\code{save(path, prefix, comment)}}{Saves the model to a given path, with a given file prefix. Documented in \link{DLmodel.save}.}
#'   \item{\code{use_data(use = c("train", "test"), x_files, y_files = NULL, target_windows_per_file = 1024)}}{Assigns data for training or testing, to be used when fitting the model. Documented in \link{DLmodel.use_data}.}
#'   \item{\code{fit(epochs = 10, keep_best = TRUE, metrics_viewer = FALSE, ...)}}{Trains the model. Documented in \link{DLmodel.fit}.}
#'   \item{\code{reset()}}{Resets the model to the original (untrained) state. Documented in \link{DLmodel.reset}.}
#'   \item{\code{infer(V = NULL, speed = c("faster", "medium", "slower"))}}{Runs inference on a volume. Documented in \link{DLmodel.infer}.}
#'   \item{\code{clone(deep = FALSE)}}{Clones the model. Documented in \link{DLmodel.clone}.}
#'  }
DLmodel <- R6::R6Class(
  
  classname = "DLmodel",
  
  public = list(
    
    initialize = function(...) {
      
      load_keras()
      suppressPackageStartupMessages(require(tidyverse))
      suppressPackageStartupMessages(require(ggvis))
      
      args <- list(...)
      
      if (!is.null(args$model)) {
        
        private$model <- args$model
        
      }
      
      if (!is.null(args$width)) {
        
        private$width <- args$width
        
      }
      
      if (!is.null(args$best_loss)) {
        
        private$best_loss <- args$best_loss
        
      }
      
      if (!is.null(args$encoder)) {
        
        private$encoder <- args$encoder
        
      }
      
      if (!is.null(args$decoder)) {
        
        private$decoder <- args$decoder
        
      }
      
      if (!is.null(args$hyperparameters)) {
        
        private$hyperparameters <- args$hyperparameters
        
      }
      
    },
    
    update = function(...) {
      
      self$initialize(...)
      
    },
    
    summary = function() {
      
      summary(private$model)
      
    },

    log = function(level = c("DEBUG", "INFO", "WARNING", "ERROR"), message = "...") {
      
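      # Build a timestamped entry such as "(2019-05-03 17:47:00) [INFO] ...".
      # Only the first element of 'level' is used.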
      line_to_add <- paste0("(", format(Sys.time(), "%Y-%m-%d %H:%M:%S"), ") [",
                            level[1], "] ",
                            message)
      
      private$log_lines <- c(private$log_lines, line_to_add)
      
    },
    
    add_to_history = function(epoch = 0, 
                              subepoch = 0, 
                              time = Sys.time(), 
                              loss = NA, 
                              val_loss = NA) {
      
      private$history <- rbind(private$history,
                               data.frame(epoch = epoch,
                                          subepoch = subepoch,
                                          time = time,
                                          loss = loss,
                                          val_loss = val_loss))
      
    },
    
    plot_history = function() {
      
      require(ggvis)
      
      h <- private$reformat_history()
      
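      # Smooth with a wider span when the history has few points.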
      span <- ifelse(nrow(h) < 25, 0.5, 0.2)
      
      # A single validation point is drawn as a point; more are drawn as a line.
      # Count with na.rm = TRUE, since the full_join in reformat_history can
      # introduce NAs that the original logical subset would miscount.
      if (sum(h$val_loss >= 0, na.rm = TRUE) == 1) {
        
        h %>% 
          ggvis(~epochs, ~loss) %>% 
          layer_lines(stroke := "darkgray") %>% 
          layer_smooths(span = span, se = TRUE, fill := "blue", stroke := "darkblue") %>% 
          layer_points(x = ~epochs[val_loss >= 0],  y = ~val_loss[val_loss >= 0], fill := "red", stroke := "darkred") %>% 
          add_axis("x", title = "epochs") %>% 
          add_axis("y", title = "loss") %>% 
          private$add_title(title = "Losses History")
        
      } else {
        
        h %>% 
          ggvis(~epochs, ~loss) %>% 
          layer_lines(stroke := "darkgray") %>% 
          layer_smooths(span = span, se = TRUE, fill := "blue", stroke := "darkblue") %>% 
          layer_lines(x = ~epochs[val_loss >= 0],  y = ~val_loss[val_loss >= 0], stroke := "darkred") %>% 
          add_axis("x", title = "epochs") %>% 
          add_axis("y", title = "loss") %>% 
          private$add_title(title = "Losses History")
        
      }
      
    },
    
    reset_history = function() {
      
      private$history <- data.frame(epoch = NULL, subepoch = NULL, time = NULL, loss = NULL, val_loss = NULL)
      
    },
    
    render_history = function(initialize = TRUE) {
      
      suppressPackageStartupMessages({
        require(googleVis)
        require(rstudioapi)
      })
      
      h <- private$reformat_history()
      
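      # Build a googleVis line chart and write its HTML to a temporary file
      # that the RStudio Viewer can display.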
      foo <- gvisLineChart(h, 
                           xvar = "epochs", 
                           yvar = c("loss", "val_loss"),
                           options = list(width = 480,
                                          height = 640,
                                          title = "Loss History",
                                          titleTextStyle = "{color:'red',fontSize:16}"))
      
      cat(renderGvis(foo)(), file = private$render_file)
      
      if (initialize)
        viewer(url = private$render_file)
      
    },
    
    update_render = function() {
      
      self$render_history(initialize = FALSE)
      
    },
    
    get_model = function() {
      
      return(private$model)
      
    },
    
    is_autoencoder = function() {
      
      return(!(is.null(private$encoder)) && !(is.null(private$decoder)))
      
    },

    get_width = function() {
      
      return(private$width)
      
    },
    
    set_width = function(width) {
      
      private$width <- width
      
    },
    
    get_loss = function() {
      
      return(private$best_loss)
      
    },
    
    set_loss = function(loss) {
      
      private$best_loss <- loss
      
    },
    
    get_encoder = function() {
      
      return(private$encoder)
      
    },
    
    get_decoder = function() {
      
      return(private$decoder)
      
    },
    
    get_config = function() {
      
      return(private$hyperparameters)
      
    },
    
    get_history = function() {
      
      return(private$history)
      
    },
    
    print_log = function(level = c("DEBUG", "WARNING", "INFO", "ERROR")) {
      
      lines <- private$get_log_lines(level)
      
      cat(lines, sep = "\n")
      
    },
    
    errors = function() {
      
      self$print_log(level = "ERROR")
      
    },
    
    warnings = function() {
      
      self$print_log(level = "WARNING")
      
    },
    
    save_log = function(filename, level = c("DEBUG", "WARNING", "INFO", "ERROR")) {
      
      lines <- private$get_log_lines(level)
      
      cat(lines, sep = "\n", file = filename)
      
    },
    
    check_memory = function() {
      
      batch_size <- self %>% compute_batch_size()
      
      # If batch_size == 0, there is no possibility of training with the specified memory limit.
      if (batch_size < 1) {
        
        required_memory <- prettyunits::pretty_bytes(unclass(self %>% model_size() * 4))
        
        # Not enough memory to train even 1 batch at a time
        error_message <- paste0("Not enough memory to train this model. Optimal batch size is 0 for the memory limit: ", 
                                prettyunits::pretty_bytes(private$hyperparameters$memory_limit), "\n",
                                "This model requires at least ", required_memory, " to be trained.\n",
                                "We suggest to increase this limit by adding: memory_limit = ", 
                                required_memory, " to the scheme.\n")
        
        self$log("ERROR", message = error_message)
        
        stop(error_message)
        
      } else {
        
        return(invisible(batch_size))
        
      }
      
    },
    
    raise_memory_limit = function(memory_limit = "4G") {
      
      private$hyperparameters$memory_limit <- convert_to_bytes(memory_limit)
      
    },
    
    graph = function(mode = "graphviz") {
      
      self %>% graph_from_model(mode = mode)
      
    },
    
    plot = function(to_file) {
      
      self %>% plot_model(to_file = to_file)
      
    },
    
    load = function(path, prefix) {
      
      # Note: 'self' cannot be rebound from inside an R6 method, so return
      # the loaded model instead of assigning it to 'self'.
      invisible(load_model(path = path, prefix = prefix))
      
    },
    
    save = function(path, prefix, comment = "") {
      
      self %>% save_model(path, prefix, comment)
      
    },
    
    use_data = function(use = c("train", "test"),
                        x_files, 
                        y_files = NULL,
                        target_windows_per_file = 1024) {
      
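      # Build a sampling generator over the given files; it will feed data
      # windows to the model when fitting or evaluating.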
      gen <- self %>% create_generator(x_files = x_files, 
                                       y_files = y_files,
                                       mode = "sampling",
                                       target_windows_per_file = target_windows_per_file)
      
      if (use[1] == "train") {
        
        private$train_config <- gen
        private$train_files <- list(x = x_files, y = y_files)
        self$has_train_data <- TRUE
        
      } else {
        
        private$test_config <- gen
        private$test_files <- list(x = x_files, y = y_files)
        
      }
      
    },
    
    fit = function(epochs = 10, 
                   keep_best = TRUE,
                   metrics_viewer = FALSE,
                   ...) {
      
      args <- list(...)
      
      verbose <- FALSE
      if (!is.null(args$verbose)) verbose <- args$verbose

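      # Fall back to a temporary path and prefix for model checkpoints when
      # none are supplied via '...'.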
      tmp_path <- tempfile()
      
      if (keep_best) {
        
        # Path and prefix must be provided
        warn <- FALSE
        
        if (is.null(args$path)) {
          
          warn <- TRUE
          path <- dirname(tmp_path)
          
        } else {
          
          path <- args$path
          
        }
        
        if (!file.exists(path)) {
          
          dir.create(path, showWarnings = FALSE, recursive = TRUE)
          
        }
        
        if (is.null(args$prefix)) {
          
          prefix <- basename(tmp_path)
          
          
        } else {
          
          prefix <- args$prefix
          
        }
          
        if (warn) {
          
          self$log("WARNING", message = paste0("No path provided when training. Using ", path, "."))
          
          message <- "When training with keep_best == TRUE, at least path must be provided."
          warning(message)
          
        }
        
      } else {
        
        path <- dirname(tmp_path)
        prefix <- basename(tmp_path)
        
      }
      
      if (!self$has_train_data) {
        
        self$log("ERROR", message = "No training configuration provided.")
        stop("No training configuration provided. Employ 'use_data' function to provide a training configuration.")
        
      }
      
      self %>% fit_with_generator(train_config = private$train_config,
                                  validation_config = private$test_config,
                                  epochs = epochs,
                                  starting_epoch = private$last_epoch + 1,
                                  keep_best = keep_best,
                                  path = path,
                                  prefix = prefix,
                                  metrics_viewer = metrics_viewer,
                                  verbose = verbose)
      
      private$last_epoch <- max(private$history$epoch)
      
    },
    
    reset = function() {
      
      self %>% reset_model()
      
    },
    
    infer = function(V = NULL, 
                     speed = c("faster", "medium", "slower"), ...) {
      
      self %>% infer_on_volume(V = V, speed = speed[1], ...)
      
    },
    
    has_train_data = FALSE
    
  ),
  
  private = list(
    
    model = NULL,
    train_config = NULL,
    test_config = NULL,
    
    train_files = NULL,
    test_files = NULL,
    
    last_epoch = 0,
    
    width = NULL,
    
    best_loss = Inf,
    
    encoder = NULL,
    decoder = NULL,
    
    hyperparameters = NULL,
    
    log_lines = c(),
    
    history = data.frame(epoch = NULL, subepoch = NULL, time = NULL, loss = NULL, val_loss = NULL),
    
    render_file = tempfile(tmpdir = normalizePath("~"), 
                           pattern = "dl4ni-metrics_", 
                           fileext = ".html"),
    
    get_log_lines = function(level = c("DEBUG", "WARNING", "INFO", "ERROR")) {
      
      lines <- c()
      
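      # Requesting DEBUG implies all other levels.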
      if ("DEBUG" %in% level)
        level <- c("DEBUG", "INFO", "WARNING", "ERROR")
      
      for (i in seq_along(level)) {
        
        lines <- c(lines, grep(private$log_lines, pattern = level[i]))
        
      }
      
      lines <- sort(lines)
      
      lines <- private$log_lines[lines]
      
      return(lines)
      
    },
    
    # ggvis lacks a plot title function, so add one.
    # based on clever hack by tonytonov
    # http://stackoverflow.com/a/25030002/1135316
    add_title = function(vis, ..., properties = NULL, title = "Plot Title") {
      
      # recursively merge lists by name
      # http://stackoverflow.com/a/13811666/1135316
      merge.lists <- function(a, b) {
        a.names <- names(a)
        b.names <- names(b)
        m.names <- sort(unique(c(a.names, b.names)))
        sapply(m.names, function(i) {
          if (is.list(a[[i]]) & is.list(b[[i]])) merge.lists(a[[i]], b[[i]])
          else if (i %in% b.names) b[[i]]
          else a[[i]]
        }, simplify = FALSE)
      }
      
      # default properties make title 'axis' invisible
      default.props <- axis_props(
        ticks = list(strokeWidth = 0),
        axis = list(strokeWidth = 0),
        labels = list(fontSize = 0),
        grid = list(strokeWidth = 0)
      )
      
      # merge the default properties with user-supplied props.
      axis.props <- do.call(axis_props, merge.lists(default.props, properties))
      
      # don't step on existing scales.
      vis <- scale_numeric(vis, "title", domain = c(0,1), range = 'width')
      axis <- ggvis:::create_axis('x', 'title', orient = "top",  title = title, properties = axis.props, ...)
      ggvis:::append_ggvis(vis, "axes", axis)
      
    },
    
    reformat_history = function() {
      
      suppressPackageStartupMessages(require(tidyverse))
      
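      # Convert (epoch, subepoch) pairs into a single fractional epoch axis.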
      max_sub_epoch <- max(private$history %>% select(subepoch) %>% unlist())
      
      history2 <- private$history %>% 
        mutate(epochs = (epoch - 1) + (subepoch - 1) / max_sub_epoch)
      
      losses <- history2 %>% select(epochs, loss) %>% filter(loss >= 0)
      val_losses <- history2 %>% select(epochs, val_loss) %>% filter(val_loss >= 0)
      h <- full_join(losses, val_losses, by = "epochs")
      if (nrow(h) >= 2)
        if (isTRUE(all(h[1, ] == h[2, ])))
          h <- h[-1, ]
      
      return(h)
      
    }
    
  )
  
)