# R/TEClassifierProtoNet.R

# This file is part of the R package "aifeducation".
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as published by
# the Free Software Foundation.
#
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>

#' @title Text embedding classifier with a ProtoNet
#' @description Text embedding classifier based on a prototypical network ('pytorch' only).
#'
#'   This object represents an implementation of a prototypical network for few-shot learning as described by Snell,
#'   Swersky, and Zemel (2017). The network uses a multi-way contrastive loss as described by Zhang et al. (2019) and
#'   learns to scale the metric as described by Oreshkin, Rodriguez, and Lacoste (2018).
#'
#' @return Objects of this class are used for assigning texts to classes/categories. For the creation and training of
#'   a classifier, an object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] and a `factor` are necessary.
#'   The object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] contains the numerical text representations
#'   (text embeddings) of the raw texts generated by an object of class [TextEmbeddingModel]. The `factor` contains the
#'   classes/categories for every text. Missing values (unlabeled cases) are supported. For predictions, an object of
#'   class [EmbeddedText] or [LargeDataSetForTextEmbeddings] has to be used that was created with the same
#'   [TextEmbeddingModel] as for training.
#'
#' @references Oreshkin, B. N., Rodriguez, P. & Lacoste, A. (2018). TADAM: Task dependent adaptive metric for improved
#'   few-shot learning. https://doi.org/10.48550/arXiv.1805.10123
#' @references Snell, J., Swersky, K. & Zemel, R. S. (2017). Prototypical Networks for Few-shot Learning.
#'   https://doi.org/10.48550/arXiv.1703.05175
#' @references Zhang, X., Nie, J., Zong, L., Yu, H. & Liang, W. (2019). One Shot Learning with Margin. In Q. Yang, Z.-H.
#'   Zhou, Z. Gong, M.-L. Zhang & S.-J. Huang (Eds.), Lecture Notes in Computer Science. Advances in Knowledge Discovery
#'   and Data Mining (Vol. 11440, pp. 305–317). Springer International Publishing.
#'   https://doi.org/10.1007/978-3-030-16145-3_24
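#' @examples
#' # Minimal sketch of the intended workflow, not a definitive recipe. It assumes
#' # that `text_embeddings` is an object of class EmbeddedText or
#' # LargeDataSetForTextEmbeddings created with a TextEmbeddingModel and that
#' # `labels` is a named factor containing the categories of the corresponding texts.
#' \dontrun{
#' classifier <- TEClassifierProtoNet$new()
#' classifier$configure(
#'   ml_framework = "pytorch",
#'   name = "protonet_classifier_example",
#'   label = "Example ProtoNet Classifier",
#'   text_embeddings = text_embeddings,
#'   target_levels = levels(labels),
#'   embedding_dim = 2
#' )
#' }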
#' @family Classification
#' @export
TEClassifierProtoNet <- R6::R6Class(
  classname = "TEClassifierProtoNet",
  inherit = TEClassifierRegular,
  public = list(
    # New-----------------------------------------------------------------------
    #' @description Creating a new instance of this class.
    #' @param ml_framework `string` Currently only pytorch is supported (`ml_framework="pytorch"`).
    #' @param name `string` Name of the new classifier. Please refer to common name conventions. Free text can be used
    #'   with parameter `label`.
    #' @param label `string` Label for the new classifier. Here you can use free text.
    #' @param text_embeddings An object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings].
    #' @param target_levels `vector` containing the levels (categories or classes) of the target data. Please note
    #'   that the order matters. For ordinal data, please ensure that the levels are sorted correctly, with later levels
    #'   indicating a higher category/class. For nominal data the order does not matter.
    #' @param feature_extractor Object of class [TEFeatureExtractor] which should be used in order to reduce the number
    #'   of dimensions of the text embeddings. Set to `NULL` if no feature extractor should be applied.
    #' @param dense_layers `int` Number of dense layers.
    #' @param dense_size `int` Number of neurons for each dense layer.
    #' @param rec_layers `int` Number of recurrent layers.
    #' @param rec_size `int` Number of neurons for each recurrent layer.
    #' @param rec_type `string` Type of the recurrent layers. `rec_type="gru"` for Gated Recurrent Unit and
    #'   `rec_type="lstm"` for Long Short-Term Memory.
    #' @param rec_bidirectional `bool` If `TRUE` a bidirectional version of the recurrent layers is used.
    #' @param embedding_dim `int` determining the number of dimensions for the text embedding.
    #' @param attention_type `string` Choose the relevant attention type. Possible values are `"fourier"` and
    #'   `"multihead"`. Please note that with `"fourier"` on Linux, the values for a case may differ depending on the
    #'   order of the cases in the input.
    #' @param self_attention_heads `int` determining the number of attention heads for a self-attention layer. Only
    #'   relevant if `attention_type="multihead"`.
    #' @param repeat_encoder `int` determining how many times the encoder should be added to the network.
    #' @param intermediate_size `int` determining the size of the projection layer within each transformer encoder.
    #' @param add_pos_embedding `bool` `TRUE` if positional embedding should be used.
    #' @param encoder_dropout `double` ranging between 0 and less than 1, determining the dropout for the dense
    #'   projection within the encoder layers.
    #' @param dense_dropout `double` ranging between 0 and less than 1, determining the dropout between dense layers.
    #' @param rec_dropout `double` ranging between 0 and less than 1, determining the dropout between bidirectional
    #'   recurrent layers.
    #' @param recurrent_dropout `double` ranging between 0 and less than 1, determining the recurrent dropout for each
    #'   recurrent layer. Only relevant for keras models.
    #' @param optimizer `string` `"adam"` or `"rmsprop"`.
    #' @return Returns an object of class [TEClassifierProtoNet] which is ready for training.
    configure = function(ml_framework = "pytorch",
                         name = NULL,
                         label = NULL,
                         text_embeddings = NULL,
                         feature_extractor = NULL,
                         target_levels = NULL,
                         dense_size = 4,
                         dense_layers = 0,
                         rec_size = 4,
                         rec_layers = 2,
                         rec_type = "gru",
                         rec_bidirectional = FALSE,
                         embedding_dim = 2,
                         self_attention_heads = 0,
                         intermediate_size = NULL,
                         attention_type = "fourier",
                         add_pos_embedding = TRUE,
                         rec_dropout = 0.1,
                         repeat_encoder = 1,
                         dense_dropout = 0.4,
                         recurrent_dropout = 0.4,
                         encoder_dropout = 0.1,
                         optimizer = "adam") {
      # check if already configured
      private$check_config_for_FALSE()

      # Checking of parameters--------------------------------------------------
      if ((ml_framework %in% c("pytorch")) == FALSE) {
        stop("ml_framework must be 'pytorch'.")
      }
      private$check_embeddings_object_type(text_embeddings, strict = TRUE)
      check_type(name, type = "string", FALSE)
      check_type(label, type = "string", FALSE)
      check_type(target_levels, c("vector"), FALSE)
      check_type(dense_size, type = "int", FALSE)
      check_type(dense_layers, type = "int", FALSE)
      if (dense_layers > 0) {
        if (dense_size < 1) {
          stop("Dense layers added. Size for dense layers must be at least 1.")
        }
      }

      check_type(rec_size, type = "int", FALSE)
      check_type(rec_layers, type = "int", FALSE)
      if (rec_layers > 0) {
        if (rec_size < 1) {
          stop("Recurrent layers added. Size for recurrent layers must be at least 1.")
        }
      }
      check_type(self_attention_heads, type = "int", FALSE)
      if (optimizer %in% c("adam", "rmsprop") == FALSE) {
        stop("Optimzier must be 'adam' oder 'rmsprop'.")
      }
      check_type(embedding_dim, "int", FALSE)
      if (embedding_dim <= 0) {
        stop("Embedding_dim must be an integer greater 0.")
      }
      if (attention_type %in% c("fourier", "multihead") == FALSE) {
        stop("Optimzier must be 'fourier' oder 'multihead'.")
      }
      if (repeat_encoder > 0 & attention_type == "multihead" & self_attention_heads <= 0) {
        stop("Encoder layer is set to 'multihead'. This requires self_attention_heads>=1.")
      }

      #------------------------------------------------------------------------
      # Setting ML Framework
      private$ml_framework <- ml_framework

      # Setting Label and Name
      private$set_model_info(
        model_name_root = name,
        model_id = generate_id(16),
        label = label,
        model_date = date()
      )

      # Set TextEmbeddingModel
      private$set_text_embedding_model(
        model_info = text_embeddings$get_model_info(),
        feature_extractor_info = text_embeddings$get_feature_extractor_info(),
        times = text_embeddings$get_times(),
        features = text_embeddings$get_features()
      )

      # Saving Configuration
      config <- list(
        use_fe = FALSE,
        features = private$text_embedding_model[["features"]],
        times = private$text_embedding_model[["times"]],
        dense_size = dense_size,
        dense_layers = dense_layers,
        rec_size = rec_size,
        rec_layers = rec_layers,
        rec_type = rec_type,
        rec_bidirectional = rec_bidirectional,
        intermediate_size = intermediate_size,
        attention_type = attention_type,
        repeat_encoder = repeat_encoder,
        dense_dropout = dense_dropout,
        rec_dropout = rec_dropout,
        recurrent_dropout = recurrent_dropout,
        encoder_dropout = encoder_dropout,
        add_pos_embedding = add_pos_embedding,
        optimizer = optimizer,
        act_fct = "gelu",
        rec_act_fct = "tanh",
        embedding_dim = embedding_dim,
        self_attention_heads = self_attention_heads
      )

      # Basic Information of Input and Target Data
      variable_name_order <- dimnames(text_embeddings$embeddings)[[3]]
      config["target_levels"] <- list(target_levels)
      config["n_categories"] <- list(length(target_levels))
      config["input_variables"] <- list(variable_name_order)
      if (length(target_levels) > 2) {
        # Multi Class
        config["act_fct_last"] <- "softmax"
        config["err_fct"] <- "categorical_crossentropy"
        config["metric"] <- "categorical_accuracy"
        config["balanced_metric"] <- "balanced_accuracy"
      } else {
        # Binary Classification
        config["act_fct_last"] <- "sigmoid"
        config["err_fct"] <- "binary_crossentropy"
        config["metric"] <- "binary_accuracy"
        config["balanced_metric"] <- "balanced_accuracy"
      }

      config["require_one_hot"] <- list(FALSE)

      # if(length(rec)>0|repeat_encoder>0){
      #  config["require_matrix_map"]=list(FALSE)
      # } else {
      config["require_matrix_map"] <- list(TRUE)
      # }

      self$model_config <- config

      # Adjust configuration
      private$adjust_configuration()

      # Set FeatureExtractor and adapt config
      self$check_feature_extractor_object_type(feature_extractor)
      private$set_feature_extractor(feature_extractor)

      # Set package versions
      private$set_package_versions()

      # Finalize configuration
      private$set_configuration_to_TRUE()

      # Create_Model
      private$create_reset_model()
    },

    #-------------------------------------------------------------------------
    #' @description Method for training a neural net.
    #'
    #' Training includes a routine for early stopping: training stops when loss < 0.0001, Accuracy = 1.00, and
    #' Average Iota = 1.00. In that case, the history repeats the values of the last trained epoch for the remaining
    #' epochs.
    #'
    #' After training, the model with the best values for Average Iota, Accuracy, and Loss on the validation data set
    #' is used as the final model.
    #'
    #' @param data_embeddings Object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings].
    #' @param data_targets `factor` containing the labels for the cases stored in `data_embeddings`. The factor must be
    #'   named and must use the same names as in `data_embeddings`.
    #' @param data_folds `int` determining the number of cross-fold samples.
    #' @param data_val_size `double` between 0 and 1, indicating the proportion of cases of each class which should be
    #'   used for the validation sample during the estimation of the model. The remaining cases are part of the training
    #'   data.
    #' @param balance_class_weights `bool` If `TRUE` class weights are generated based on the frequencies of the
    #'   training data with the method 'Inverse Class Frequency'. If `FALSE` each class has the weight 1.
    #' @param balance_sequence_length `bool` If `TRUE` sample weights are generated for the length of sequences based on
    #'   the frequencies of the training data with the method 'Inverse Class Frequency'. If `FALSE` each sequence length
    #'   has the weight 1.
    #' @param use_sc `bool` `TRUE` if the estimation should integrate synthetic cases. `FALSE` if not.
    #' @param sc_method `vector` containing the method for generating synthetic cases. Possible are `sc_method="adas"`,
    #'   `sc_method="smote"`, and `sc_method="dbsmote"`.
    #' @param sc_min_k `int` determining the minimal number of k which is used for creating synthetic units.
    #' @param sc_max_k `int` determining the maximal number of k which is used for creating synthetic units.
    #' @param use_pl `bool` `TRUE` if the estimation should integrate pseudo-labeling. `FALSE` if not.
    #' @param pl_max_steps `int` determining the maximum number of steps during pseudo-labeling.
    #' @param pl_anchor `double` between 0 and 1 indicating the reference point for sorting the new cases of every
    #'   label. See notes for more details.
    #' @param pl_max `double` between 0 and 1, setting the maximal level of confidence for considering a case for
    #'   pseudo-labeling.
    #' @param pl_min `double` between 0 and 1, setting the minimal level of confidence for considering a case for
    #'   pseudo-labeling.
    #' @param sustain_track `bool` If `TRUE` energy consumption is tracked during training via the python library
    #'   'codecarbon'.
    #' @param sustain_iso_code `string` ISO code (Alpha-3-Code) for the country. This variable must be set if
    #'   sustainability should be tracked. A list can be found on Wikipedia:
    #'   <https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes>.
    #' @param sustain_region Region within a country. Only available for USA and Canada. See the documentation of
    #'   codecarbon for more information: <https://mlco2.github.io/codecarbon/parameters.html>
    #' @param sustain_interval `int` Interval in seconds for measuring power usage.
    #' @param batch_size `int` Size of the batches for training.
    #' @param epochs `int` Number of training epochs.
    #' @param Ns `int` Number of cases for every class in the sample.
    #' @param Nq `int` Number of cases for every class in the query.
    #' @param loss_alpha `double` Value between 0 and 1 indicating how strongly the loss should focus on pulling cases to
    #'   their corresponding prototypes or pushing cases away from other prototypes. The higher the value, the more the
    #'   loss concentrates on pulling cases to their corresponding prototypes.
    #' @param loss_margin `double` Value greater than 0 indicating the minimal distance of every case from prototypes
    #'   of other classes.
    #' @param sampling_separate `bool` If `TRUE` the cases for every class are divided into a data set for sample and for query.
    #'    These are never mixed. If `FALSE` sample and query cases are drawn from the same data pool. That is, a case can be
    #'    part of the sample in one epoch and part of the query in another epoch. It is ensured that a case is never part of
    #'    the sample and the query at the same time. In addition, it is ensured that every case exists only once during
    #'    a training step.
    #' @param sampling_shuffle `bool` If `TRUE` cases are randomly drawn from the data during every step. If `FALSE`
    #'    the cases are not shuffled.
    #' @param dir_checkpoint `string` Path to the directory where the checkpoint during training should be saved. If the
    #'   directory does not exist, it is created.
    #' @param log_dir `string` Path to the directory where the log files should be saved. If no logging is desired set
    #'   this argument to `NULL`.
    #' @param log_write_interval `int` Time in seconds determining the interval in which the logger should try to update
    #'   the log files. Only relevant if `log_dir` is not `NULL`.
    #' @param trace `bool` `TRUE`, if information about the estimation phase should be printed to the console.
    #' @param ml_trace `int` `ml_trace=0` does not print any information about the training process from pytorch on the
    #'   console.
    #' @param n_cores `int` Number of cores which should be used during the calculation of synthetic cases. Only relevant if
    #'  `use_sc=TRUE`.
    #' @return Function does not return a value. It changes the object into a trained classifier.
    #' @details
    #'
    #' * `sc_max_k`: All values from `sc_min_k` up to `sc_max_k` are successively used. If
    #' the value of `sc_max_k` is too high, it is reduced to a number that allows the calculation of synthetic
    #' units.
    #' * `pl_anchor`: With the help of this value, the new cases are sorted. For
    #' this aim, the distance from the anchor is calculated and all cases are arranged in ascending order.
    #'
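    #' @examples
    #' # Illustrative sketch of a training call, not a definitive recipe. It assumes
    #' # that `classifier` is a configured TEClassifierProtoNet, `embedded_texts` and
    #' # `labels` contain the training embeddings and the corresponding named factor,
    #' # and that "checkpoints" is a writable directory.
    #' \dontrun{
    #' classifier$train(
    #'   data_embeddings = embedded_texts,
    #'   data_targets = labels,
    #'   data_folds = 5,
    #'   epochs = 40,
    #'   batch_size = 35,
    #'   Ns = 5,
    #'   Nq = 3,
    #'   loss_alpha = 0.5,
    #'   loss_margin = 0.5,
    #'   sustain_iso_code = "DEU",
    #'   dir_checkpoint = "checkpoints"
    #' )
    #' }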
    train = function(data_embeddings,
                     data_targets,
                     data_folds = 5,
                     data_val_size = 0.25,
                     use_sc = TRUE,
                     sc_method = "dbsmote",
                     sc_min_k = 1,
                     sc_max_k = 10,
                     use_pl = TRUE,
                     pl_max_steps = 3,
                     pl_max = 1.00,
                     pl_anchor = 1.00,
                     pl_min = 0.00,
                     sustain_track = TRUE,
                     sustain_iso_code = NULL,
                     sustain_region = NULL,
                     sustain_interval = 15,
                     epochs = 40,
                     batch_size = 35,
                     Ns = 5,
                     Nq = 3,
                     loss_alpha = 0.5,
                     loss_margin = 0.5,
                     sampling_separate = FALSE,
                     sampling_shuffle = TRUE,
                     dir_checkpoint,
                     trace = TRUE,
                     ml_trace = 1,
                     log_dir = NULL,
                     log_write_interval = 10,
                     n_cores = auto_n_cores()) {
      # Checking Arguments------------------------------------------------------
      check_type(data_folds, type = "int", FALSE)
      check_type(data_val_size, type = "double", FALSE)
      check_type(use_sc, type = "bool", FALSE)
      check_type(sc_method, type = "string", FALSE)
      check_type(sc_min_k, type = "int", FALSE)
      check_type(sc_max_k, type = "int", FALSE)
      check_type(use_pl, type = "bool", FALSE)
      check_type(pl_max_steps, type = "int", FALSE)
      check_type(pl_max, type = "double", FALSE)
      check_type(pl_anchor, type = "double", FALSE)
      check_type(pl_min, type = "double", FALSE)
      check_type(sustain_track, type = "bool", FALSE)
      check_type(sustain_iso_code, type = "string", TRUE)
      check_type(sustain_region, type = "string", TRUE)
      check_type(sustain_interval, type = "int", FALSE)
      check_type(epochs, type = "int", FALSE)
      check_type(batch_size, type = "int", FALSE)
      check_type(dir_checkpoint, type = "string", FALSE)
      check_type(trace, type = "bool", FALSE)
      check_type(n_cores, type = "int", FALSE)

      check_type(Ns, type = "int", FALSE)
      check_type(Nq, type = "int", FALSE)
      check_type(loss_alpha, type = "double", FALSE)
      check_type(loss_margin, type = "double", FALSE)

      check_class(data_embeddings, c("EmbeddedText", "LargeDataSetForTextEmbeddings"), FALSE)
      self$check_embedding_model(data_embeddings)

      if (is.null(names(data_targets))) {
        stop("data_targets must be a named factor.")
      }
      data_targets <- private$check_and_adjust_target_levels(data_targets)

      if (pl_anchor < pl_min) {
        stop("pl_anchor must be at least pl_min.")
      }
      if (pl_anchor > pl_max) {
        stop("pl_anchor must be lower or equal to pl_max.")
      }

      if (data_folds < 2) {
        stop("data_folds must be at least 2.")
      }

      # Saving training configuration-------------------------------------------
      self$last_training$config$data_val_size <- data_val_size
      self$last_training$config$use_sc <- use_sc
      self$last_training$config$sc_method <- sc_method
      self$last_training$config$sc_min_k <- sc_min_k
      self$last_training$config$sc_max_k <- sc_max_k
      self$last_training$config$use_pl <- use_pl
      self$last_training$config$pl_max_steps <- pl_max_steps
      self$last_training$config$pl_max <- pl_max
      self$last_training$config$pl_anchor <- pl_anchor
      self$last_training$config$pl_min <- pl_min
      self$last_training$config$sustain_track <- sustain_track
      self$last_training$config$sustain_iso_code <- sustain_iso_code
      self$last_training$config$sustain_region <- sustain_region
      self$last_training$config$sustain_interval <- sustain_interval
      self$last_training$config$epochs <- epochs
      self$last_training$config$batch_size <- batch_size
      self$last_training$config$Ns <- Ns
      self$last_training$config$Nq <- Nq
      self$last_training$config$loss_alpha <- loss_alpha
      self$last_training$config$loss_margin <- loss_margin
      self$last_training$config$dir_checkpoint <- dir_checkpoint
      self$last_training$config$trace <- trace
      self$last_training$config$ml_trace <- ml_trace

      self$last_training$config$n_cores <- n_cores

      self$last_training$config$sampling_separate <- sampling_separate
      self$last_training$config$sampling_shuffle <- sampling_shuffle

      private$log_config$log_dir <- log_dir
      private$log_config$log_state_file <- paste0(private$log_config$log_dir, "/aifeducation_state.log")
      private$log_config$log_write_interval <- log_write_interval

      # Start-------------------------------------------------------------------
      if (self$last_training$config$trace == TRUE) {
        message(paste(
          date(),
          "Start"
        ))
      }

      # Set up data
      if ("EmbeddedText" %in% class(data_embeddings)) {
        data <- data_embeddings$convert_to_LargeDataSetForTextEmbeddings()
      } else {
        data <- data_embeddings
      }

      # Create DataManager------------------------------------------------------
      if (self$model_config$use_fe == TRUE) {
        compressed_embeddings <- self$feature_extractor$extract_features_large(
          data_embeddings = data,
          as.integer(self$last_training$config$batch_size),
          trace = self$last_training$config$trace
        )
        data_manager <- DataManagerClassifier$new(
          data_embeddings = compressed_embeddings,
          data_targets = data_targets,
          folds = data_folds,
          val_size = self$last_training$config$data_val_size,
          class_levels = self$model_config$target_levels,
          one_hot_encoding = self$model_config$require_one_hot,
          add_matrix_map = (self$model_config$require_matrix_map == TRUE | self$last_training$config$use_sc == TRUE),
          sc_method = sc_method,
          sc_min_k = sc_min_k,
          sc_max_k = sc_max_k,
          trace = trace,
          self$last_training$config$n_cores
        )
      } else {
        data_manager <- DataManagerClassifier$new(
          data_embeddings = data,
          data_targets = data_targets,
          folds = data_folds,
          val_size = self$last_training$config$data_val_size,
          class_levels = self$model_config$target_levels,
          one_hot_encoding = self$model_config$require_one_hot,
          add_matrix_map = (self$model_config$require_matrix_map == TRUE | self$last_training$config$use_sc == TRUE),
          sc_method = sc_method,
          sc_min_k = sc_min_k,
          sc_max_k = sc_max_k,
          trace = trace,
          self$last_training$config$n_cores
        )
      }

      # Save Data Statistics
      self$last_training$data <- data_manager$get_statistics()

      # Save the number of folds
      self$last_training$config$n_folds <- data_manager$get_n_folds()

      # Init Training------------------------------------------------------------
      private$init_train()

      # disable Progress bars
      datasets$disable_progress_bars()

      # Start Sustainability Tracking-------------------------------------------
      if (sustain_track == TRUE) {
        if (is.null(sustain_iso_code) == TRUE) {
          stop("Sustainability tracking is activated but iso code for the
               country is missing. Add iso code or deactivate tracking.")
        }

        tmp_code_carbon <- reticulate::import("codecarbon")
        codecarbon_version <- as.character(tmp_code_carbon["__version__"])
        if (utils::compareVersion(codecarbon_version, "2.8.0") >= 0) {
          path_lock_file <- codecarbon$lock$LOCKFILE
          if (file.exists(path_lock_file)) {
            unlink(path_lock_file)
          }
        }

        sustainability_tracker <- codecarbon$OfflineEmissionsTracker(
          country_iso_code = sustain_iso_code,
          region = sustain_region,
          tracking_mode = "machine",
          log_level = "warning",
          measure_power_secs = sustain_interval,
          save_to_file = FALSE,
          save_to_api = FALSE
        )
        sustainability_tracker$start()
      }

      # Start Training----------------------------------------------------------
      # Load Custom Model Scripts
      private$load_reload_python_scripts()

      # Start Loop inclusive final training
      for (iter in 1:(self$last_training$config$n_folds + 1)) {
        base::gc(verbose = FALSE, full = TRUE)

        if (self$last_training$config$use_pl == FALSE) {
          private$train_standard(
            iteration = iter,
            data_manager = data_manager,
            inc_synthetic = self$last_training$config$use_sc
          )
        } else if (self$last_training$config$use_pl == TRUE) {
          private$train_with_pseudo_labels(
            init_train = TRUE,
            iteration = iter,
            data_manager = data_manager,
            inc_synthetic = self$last_training$config$use_sc
          )
        }

        # Calculate measures on categorical level
        private$calculate_measures_on_categorical_level(
          data_manager = data_manager,
          iteration = iter
        )
      }

      # Finalize Training
      private$finalize_train()

      # Stop sustainability tracking if requested
      if (sustain_track == TRUE) {
        sustainability_tracker$stop()
        private$sustainability <- summarize_tracked_sustainability(sustainability_tracker)
      } else {
        private$sustainability <- list(
          sustainability_tracked = FALSE,
          date = NA,
          sustainability_data = list(
            duration_sec = NA,
            co2eq_kg = NA,
            cpu_energy_kwh = NA,
            gpu_energy_kwh = NA,
            ram_energy_kwh = NA,
            total_energy_kwh = NA
          )
        )
      }

      if (self$last_training$config$trace == TRUE) {
        message(paste(
          date(),
          "Training Complete"
        ))
      }
    },
    #---------------------------------------------------------------------------
    #' @description Method for embedding documents. Please do not confuse this type of embeddings with the embeddings of
    #'   texts created by an object of class [TextEmbeddingModel]. These embeddings embed documents according to their
    #'   similarity to specific classes.
    #' @param embeddings_q Object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] containing the text
    #'   embeddings for all cases which should be embedded into the classification space.
    #' @param batch_size `int` batch size.
    #' @return Returns a `list` containing the following elements:
    #'
    #' * `embeddings_q`: embeddings for the cases (query sample).
    #' * `distances_q`: distances of every case to all prototypes.
    #' * `embeddings_prototypes`: embeddings of the prototypes which were learned during training. They represent the
    #' centers of the different classes.
    #'
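    #' @examples
    #' # Illustrative sketch only. It assumes that `classifier` is a trained
    #' # TEClassifierProtoNet and that `new_embeddings` contains text embeddings
    #' # created with the same TextEmbeddingModel as used for training.
    #' \dontrun{
    #' embedded_cases <- classifier$embed(
    #'   embeddings_q = new_embeddings,
    #'   batch_size = 32
    #' )
    #' embedded_cases$embeddings_q # embeddings of the cases in the classification space
    #' embedded_cases$distances_q # distances of the cases to all prototypes
    #' embedded_cases$embeddings_prototypes # embeddings of the class prototypes
    #' }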
    embed = function(embeddings_q = NULL, batch_size = 32) {
      check_class(embeddings_q, c("EmbeddedText", "LargeDataSetForTextEmbeddings"), FALSE)
      check_type(batch_size, "int", FALSE)

      # Check input for compatible text embedding models and feature extractors
      if ("EmbeddedText" %in% class(embeddings_q)) {
        self$check_embedding_model(text_embeddings = embeddings_q)
        requires_compression <- self$requires_compression(embeddings_q)
      } else if ("array" %in% class(embeddings_q)) {
        requires_compression <- self$requires_compression(embeddings_q)
      } else {
        requires_compression <- FALSE
      }

      # Load Custom Model Scripts
      private$load_reload_python_scripts()

      # Check number of cases in the data
      single_prediction <- private$check_single_prediction(embeddings_q)

      # Get current row names/name of the cases
      current_row_names <- private$get_rownames_from_embeddings(embeddings_q)

      # Apply feature extractor if it is part of the model
      if (requires_compression == TRUE) {
        # Returns a data set
        embeddings_q <- self$feature_extractor$extract_features(
          data_embeddings = embeddings_q,
          batch_size = as.integer(batch_size)
        )
      }


      # If at least two cases are part of the data set---------------------------
      if (single_prediction == FALSE) {
        # Returns a data set object
        prediction_data_q_embeddings <- private$prepare_embeddings_as_dataset(embeddings_q)

        if (private$ml_framework == "pytorch") {
          prediction_data_q_embeddings$set_format("torch")
          embeddings_and_distances <- py$TeProtoNetBatchEmbedDistance(
            model = self$model,
            dataset_q = prediction_data_q_embeddings,
            batch_size = as.integer(batch_size)
          )
          embeddings_tensors_q <- private$detach_tensors(embeddings_and_distances[[1]])
          distances_tensors_q <- private$detach_tensors(embeddings_and_distances[[2]])
        }
      } else {
        prediction_data_q_embeddings <- private$prepare_embeddings_as_np_array(embeddings_q)

        # Apply feature extractor if it is part of the model
        if (requires_compression == TRUE) {
          # Returns a data set
          prediction_data_q <- np$array(self$feature_extractor$extract_features(
            data_embeddings = prediction_data_q_embeddings,
            batch_size = as.integer(batch_size)
          )$embeddings)
        }

        if (private$ml_framework == "pytorch") {
          if (torch$cuda$is_available()) {
            device <- "cuda"
            dtype <- torch$double
            self$model$to(device, dtype = dtype)
            self$model$eval()
            input <- torch$from_numpy(prediction_data_q_embeddings)
            embeddings_tensors_q <- self$model$embed(input$to(device, dtype = dtype))
            embeddings_tensors_q <- private$detach_tensors(embeddings_tensors_q)
            distances_tensors_q <- self$model$get_distances(input$to(device, dtype = dtype))
            distances_tensors_q <- private$detach_tensors(distances_tensors_q)
          } else {
            device <- "cpu"
            dtype <- torch$float
            self$model$to(device, dtype = dtype)
            self$model$eval()
            input <- torch$from_numpy(prediction_data_q_embeddings)
            embeddings_tensors_q <- self$model$embed(input$to(device, dtype = dtype))
            embeddings_tensors_q <- private$detach_tensors(embeddings_tensors_q)
            distances_tensors_q <- self$model$get_distances(input$to(device, dtype = dtype))
            distances_tensors_q <- private$detach_tensors(distances_tensors_q)
          }
        }
      }

      if (private$ml_framework == "pytorch") {
        embeddings_prototypes <- private$detach_tensors(
          self$model$get_trained_prototypes()
        )
      }

      # Post processing
      rownames(embeddings_tensors_q) <- current_row_names
      rownames(distances_tensors_q) <- current_row_names
      rownames(embeddings_prototypes) <- self$model_config$target_levels

      return(list(
        embeddings_q = embeddings_tensors_q,
        distances_q = distances_tensors_q,
        embeddings_prototypes = embeddings_prototypes
      ))
    },
    #---------------------------------------------------------------------------
    #' @description Method for creating a plot to visualize embeddings and their corresponding centers (prototypes).
    #' @param embeddings_q Object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] containing the text
    #'   embeddings for all cases which should be embedded into the classification space.
    #' @param classes_q Named `factor` containing the true classes for every case. Please note that the names must
    #'   match the names/ids in `embeddings_q`.
    #' @param inc_unlabeled `bool` If `TRUE` the plot includes unlabeled cases as data points.
    #' @param size_points `int` Size of the points excluding the points for prototypes.
    #' @param size_points_prototypes `int` Size of the points representing prototypes.
    #' @param alpha `float` Value indicating how transparent the points should be (important
    #'   if many points overlap). Does not apply to points representing prototypes.
    #' @param batch_size `int` batch size.
    #' @return Returns a plot of class `ggplot` visualizing the embeddings.
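    #' @examples
    #' # Illustrative sketch only. It assumes that `classifier` is a trained
    #' # TEClassifierProtoNet, `new_embeddings` contains compatible text embeddings,
    #' # and `labels` is a named factor with the known classes of some of the cases.
    #' \dontrun{
    #' plot <- classifier$plot_embeddings(
    #'   embeddings_q = new_embeddings,
    #'   classes_q = labels,
    #'   inc_unlabeled = TRUE
    #' )
    #' plot
    #' }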
    plot_embeddings = function(embeddings_q,
                               classes_q = NULL,
                               batch_size = 12,
                               alpha = 0.5,
                               size_points = 3,
                               size_points_prototypes = 8,
                               inc_unlabeled = TRUE) {
      # Argument checking-------------------------------------------------------
      check_type(batch_size, type = "int", FALSE)
      check_type(alpha, type = "double", FALSE)
      check_type(size_points, type = "int", FALSE)
      check_type(size_points_prototypes, type = "int", FALSE)
      check_type(inc_unlabeled, type = "bool", FALSE)

      embeddings <- self$embed(
        embeddings_q = embeddings_q,
        batch_size = batch_size
      )

      prototypes <- as.data.frame(embeddings$embeddings_prototypes)
      prototypes$class <- rownames(embeddings$embeddings_prototypes)
      prototypes$type <- rep("prototype", nrow(embeddings$embeddings_prototypes))
      colnames(prototypes) <- c("x", "y", "class", "type")


      if (!is.null(classes_q)) {
        true_values_names <- intersect(
          x = names(na.omit(classes_q)),
          y = private$get_rownames_from_embeddings(embeddings_q)
        )
        true_values <- as.data.frame(embeddings$embeddings_q[true_values_names, , drop = FALSE])
        true_values$class <- classes_q[true_values_names]
        true_values$type <- rep("labeled", length(true_values_names))
        colnames(true_values) <- c("x", "y", "class", "type")
      } else {
        true_values_names <- NULL
        true_values <- NULL
      }


      if (inc_unlabeled == TRUE) {
        estimated_values_names <- setdiff(
          x = private$get_rownames_from_embeddings(embeddings_q),
          y = true_values_names
        )

        if (length(estimated_values_names) > 0) {
          estimated_values <- as.data.frame(embeddings$embeddings_q[estimated_values_names, , drop = FALSE])
          estimated_values$class <- private$calc_classes_on_distance(
            distance_matrix = embeddings$distances_q[estimated_values_names, , drop = FALSE],
            prototypes = embeddings$embeddings_prototypes
          )
          estimated_values$type <- rep("unlabeled", length(estimated_values_names))
          colnames(estimated_values) <- c("x", "y", "class", "type")
        }
      } else {
        estimated_values_names <- NULL
      }


      plot_data <- prototypes
      if (!is.null(true_values)) {
        plot_data <- rbind(plot_data, true_values)
      }
      if (length(estimated_values_names) > 0) {
        plot_data <- rbind(plot_data, estimated_values)
      }

      plot <- ggplot2::ggplot(data = plot_data) +
        ggplot2::geom_point(
          mapping = ggplot2::aes(
            x = x,
            y = y,
            color = class,
            shape = type,
            size = type,
            alpha = type
          )#,
          #position = ggplot2::position_jitter(h = 0.1, w = 0.1)
        ) +
        ggplot2::scale_size_manual(values = c(
          "prototype" = size_points_prototypes,
          "labeled" = size_points,
          "unlabeled" = size_points
        )) +
        ggplot2::scale_alpha_manual(
          values = c(
            "prototype" = 1,
            "labeled" = alpha,
            "unlabeled" = alpha
          )
        ) +
        ggplot2::theme_classic()
      return(plot)
    }
  ),
  private = list(
    #-------------------------------------------------------------------------
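    # Assigns every case to the class of its nearest prototype. distance_matrix contains the distances of the cases
    # (rows) to the prototypes (columns); prototypes provides the class labels via its row names.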
    calc_classes_on_distance = function(distance_matrix, prototypes) {
      index_vector <- vector(length = nrow(distance_matrix))

      for (i in 1:length(index_vector)) {
        index_vector[i] <- which.min(distance_matrix[i, ])
      }

      classes <- factor(index_vector,
        levels = 1:nrow(prototypes),
        labels = rownames(prototypes)
      )
      return(classes)
    },
    #--------------------------------------------------------------------------
    load_reload_python_scripts = function() {
      if (private$ml_framework == "tensorflow") {

      } else if (private$ml_framework == "pytorch") {
        reticulate::py_run_file(system.file("python/pytorch_te_classifier.py",
          package = "aifeducation"
        ))
        reticulate::py_run_file(system.file("python/pytorch_te_protonet.py",
          package = "aifeducation"
        ))
        reticulate::py_run_file(system.file("python/py_log.py",
          package = "aifeducation"
        ))
      }
    },
    #--------------------------------------------------------------------------
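    # Creates (or re-creates) the underlying python model from the stored configuration. Called once after
    # configuration and again whenever basic_train() is asked to reset the model before a new training run.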
    create_reset_model = function() {
      private$check_config_for_TRUE()
      if (private$ml_framework == "tensorflow") {

      } else {
        #--------------------------------------------------------------------------
        # Load Custom Pytorch Objects and Functions
        private$load_reload_python_scripts()
        self$model <- py$TextEmbeddingClassifierProtoNet_PT(
          features = as.integer(self$model_config$features),
          times = as.integer(self$model_config$times),
          dense_size = as.integer(self$model_config$dense_size),
          dense_layers = as.integer(self$model_config$dense_layers),
          rec_size = as.integer(self$model_config$rec_size),
          rec_layers = as.integer(self$model_config$rec_layers),
          rec_type = self$model_config$rec_type,
          rec_bidirectional = self$model_config$rec_bidirectional,
          intermediate_size = as.integer(self$model_config$intermediate_size),
          attention_type = self$model_config$attention_type,
          repeat_encoder = as.integer(self$model_config$repeat_encoder),
          dense_dropout = self$model_config$dense_dropout,
          rec_dropout = self$model_config$rec_dropout,
          encoder_dropout = self$model_config$encoder_dropout,
          add_pos_embedding = self$model_config$add_pos_embedding,
          self_attention_heads = as.integer(self$model_config$self_attention_heads),
          embedding_dim = as.integer(self$model_config$embedding_dim),
          target_levels = reticulate::np_array(seq(from = 0, to = (length(self$model_config$target_levels) - 1)))
        )
      }
    },
    #--------------------------------------------------------------------------
    basic_train = function(train_data = NULL,
                           val_data = NULL,
                           test_data = NULL,
                           reset_model = FALSE,
                           use_callback = TRUE,
                           log_dir = NULL,
                           log_write_interval = 10,
                           log_top_value = NULL,
                           log_top_total = NULL,
                           log_top_message = NULL) {
      # Clear session to provide enough resources for computations
      if (private$ml_framework == "tensorflow") {
        keras$backend$clear_session()
      } else if (private$ml_framework == "pytorch") {
        if (torch$cuda$is_available()) {
          torch$cuda$empty_cache()
        }
      }

      # Reset model if requested
      if (reset_model == TRUE) {
        private$create_reset_model()
      }

      # Set Optimizer
      if (private$ml_framework == "tensorflow") {
        balanced_metric <- py$BalancedAccuracy(n_classes = as.integer(length(self$model_config$target_levels)))
        if (self$model_config$optimizer == "adam") {
          self$model$compile(
            loss = self$model_config$err_fct,
            optimizer = keras$optimizers$Adam(),
            metrics = c(self$model_config$metric, balanced_metric)
          )
        } else if (self$model_config$optimizer == "rmsprop") {
          self$model$compile(
            loss = self$model_config$err_fct,
            optimizer = keras$optimizers$RMSprop(),
            metrics = c(self$model_config$metric, balanced_metric)
          )
        }
      } else if (private$ml_framework == "pytorch") {
        loss_fct_name <- "CrossEntropyLoss"
        if (self$model_config$optimizer == "adam") {
          optimizer <- "adam"
        } else if (self$model_config$optimizer == "rmsprop") {
          optimizer <- "rmsprop"
        }
      }

      # Check directory for checkpoints
      create_dir(
        dir_path = paste0(self$last_training$config$dir_checkpoint, "/checkpoints"),
        trace = self$last_training$config$trace,
        msg = "Creating Checkpoint Directory")

      # Set target column
      if (self$model_config$require_one_hot == FALSE) {
        target_column <- "labels"
      } else {
        target_column <- "one_hot_encoding"
      }

      # Tensorflow - Callbacks and training
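      # Note: configure() only accepts ml_framework = "pytorch", so the TensorFlow branch below is not expected to be
      # reached for this classifier.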
      if (private$ml_framework == "tensorflow") {
        if (use_callback == TRUE) {
          callback <- keras$callbacks$ModelCheckpoint(
            filepath = paste0(self$last_training$config$dir_checkpoint, "/checkpoints/best_weights.h5"),
            monitor = paste0("val_", self$model_config$balanced_metric),
            verbose = as.integer(min(self$last_training$config$ml_trace, 1)),
            mode = "auto",
            save_best_only = TRUE,
            save_weights_only = TRUE
          )
        } else {
          callback <- reticulate::py_none()
        }

        if (private$gui$shiny_app_active == TRUE) {
          private$load_reload_python_scripts()

          callback <- list(callback, py$ReportAiforeducationShiny())
        }

        data_set_weights <- datasets$Dataset$from_dict(
          reticulate::dict(list(
            sample_weights = sample_weights
          ))
        )
        # inputs, targets, sample_weights
        dataset_tf <- train_data$add_column("sample_weights", data_set_weights["sample_weights"])
        dataset_tf <- dataset_tf$rename_column("input", "input_embeddings")

        # Choose correct target column and rename
        dataset_tf <- dataset_tf$rename_column(target_column, "targets")

        dataset_tf$with_format("tf")
        tf_dataset_train <- dataset_tf$to_tf_dataset(
          columns = c("input_embeddings", "sample_weights"),
          batch_size = as.integer(self$last_training$config$batch_size),
          shuffle = TRUE,
          label_cols = "targets"
        )
        # Add sample weights
        tf_dataset_train <- tf_dataset_train$map(py$extract_sample_weight)

        dataset_tf_val <- val_data$rename_column("input", "input_embeddings")
        # Choose correct target column and rename
        dataset_tf_val <- dataset_tf_val$rename_column(target_column, "targets")

        tf_dataset_val <- dataset_tf_val$to_tf_dataset(
          columns = c("input_embeddings"),
          batch_size = as.integer(self$last_training$config$batch_size),
          shuffle = FALSE,
          label_cols = "targets"
        )

        history <- self$model$fit(
          verbose = as.integer(self$last_training$config$ml_trace),
          x = tf_dataset_train,
          validation_data = tf_dataset_val,
          epochs = as.integer(self$last_training$config$epochs),
          callbacks = callback,
          class_weight = reticulate::py_dict(keys = names(class_weights), values = class_weights)
        )$history

        if (self$model_config$n_categories == 2) {
          history <- list(
            loss = rbind(history$loss, history$val_loss),
            accuracy = rbind(history$binary_accuracy, history$val_binary_accuracy),
            balanced_accuracy = rbind(history$balanced_accuracy, history$val_balanced_accuracy)
          )
        } else {
          history <- list(
            loss = rbind(history$loss, history$val_loss),
            accuracy = rbind(history$categorical_accuracy, history$val_categorical_accuracy),
            balanced_accuracy = rbind(history$balanced_accuracy, history$val_balanced_accuracy)
          )
        }

        if (use_callback == TRUE) {
          self$model$load_weights(paste0(self$last_training$config$dir_checkpoint, "/checkpoints/best_weights.h5"))
        }

        # PyTorch - Callbacks and training
      } else if (private$ml_framework == "pytorch") {
        dataset_train <- train_data$select_columns(c("input", target_column))
        if (self$model_config$require_one_hot == TRUE) {
          dataset_train <- dataset_train$rename_column(target_column, "labels")
        }

        pytorch_train_data <- dataset_train$with_format("torch")

        pytorch_val_data <- val_data$select_columns(c("input", target_column))
        if (self$model_config$require_one_hot == TRUE) {
          pytorch_val_data <- pytorch_val_data$rename_column(target_column, "labels")
        }
        pytorch_val_data <- pytorch_val_data$with_format("torch")

        if (!is.null(test_data)) {
          pytorch_test_data <- test_data$select_columns(c("input", target_column))
          if (self$model_config$require_one_hot == TRUE) {
            pytorch_test_data <- pytorch_test_data$rename_column(target_column, "labels")
          }
          pytorch_test_data <- pytorch_test_data$with_format("torch")
        } else {
          pytorch_test_data <- NULL
        }

        history <- py$TeClassifierProtoNetTrain_PT_with_Datasets(
          model = self$model,
          loss_fct_name = loss_fct_name,
          optimizer_method = self$model_config$optimizer,
          Ns = as.integer(self$last_training$config$Ns),
          Nq = as.integer(self$last_training$config$Nq),
          loss_alpha = self$last_training$config$loss_alpha,
          loss_margin = self$last_training$config$loss_margin,
          trace = as.integer(self$last_training$config$ml_trace),
          use_callback = use_callback,
          train_data = pytorch_train_data,
          val_data = pytorch_val_data,
          test_data = pytorch_test_data,
          epochs = as.integer(self$last_training$config$epochs),
          sampling_separate = self$last_training$config$sampling_separate,
          sampling_shuffle = self$last_training$config$sampling_shuffle,
          filepath = paste0(self$last_training$config$dir_checkpoint, "/checkpoints/best_weights.pt"),
          n_classes = as.integer(length(self$model_config$target_levels)),
          log_dir = log_dir,
          log_write_interval = log_write_interval,
          log_top_value = log_top_value,
          log_top_total = log_top_total,
          log_top_message = log_top_message
        )
      }

      # provide rownames and replace -100
      history <- private$prepare_history_data(history)

      return(history)
    }
  )
)
