R/TEClassifierRegular.R

# This file is part of the R package "aifeducation".
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as published by
# the Free Software Foundation.
#
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>

#' @title Text embedding classifier with a neural net
#' @description Abstract class for neural nets with 'keras'/'tensorflow' and 'pytorch'.
#'
#' @return Objects of this class are used for assigning texts to classes/categories. For the creation and training of a
#'   classifier an object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] on the one hand and a [factor] on
#'   the other hand are necessary.
#'
#'   The object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] contains the numerical text representations
#'   (text embeddings) of the raw texts generated by an object of class [TextEmbeddingModel]. For supporting large data
#'   sets it is recommended to use [LargeDataSetForTextEmbeddings] instead of [EmbeddedText].
#'
#'   The `factor` contains the classes/categories for every text. Missing values (unlabeled cases) are supported and can
#'   be used for pseudo labeling.
#'
#'   For predictions an object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] has to be used which was
#'   created with the same [TextEmbeddingModel] as for training.
#'
#' @family Classification
#' @export
TEClassifierRegular <- R6::R6Class(
  classname = "TEClassifierRegular",
  inherit = AIFEBaseModel,
  public = list(
    #' @field feature_extractor ('list()')\cr
    #'   List for storing information and objects about the feature extractor.
    feature_extractor = list(),
    # NOTE(review): declared as an empty list, but train() and predict() call
    # self$feature_extractor$extract_features_large(), so once a feature extractor is set this field
    # presumably holds a [TEFeatureExtractor] object rather than a plain list -- confirm in set_feature_extractor().

    #' @field reliability ('list()')\cr
    #'
    #'   List for storing central reliability measures of the last training.
    #'
    #'   * `reliability$test_metric`: Array containing the reliability measures for the test data for
    #'   every fold and step (in case of pseudo-labeling).
    #'   * `reliability$test_metric_mean`: Array containing the reliability measures for the test data.
    #'   The values represent the mean values for every fold.
    #'   * `reliability$raw_iota_objects`: List containing all iota objects generated with the package `iotarelr`
    #'   for every fold at the end of the last training for the test data.
    #'
    #'
    #'   * `reliability$raw_iota_objects$iota_objects_end`: List of objects with class `iotarelr_iota2` containing the
    #'   estimated iota reliability of the second generation for the final model for every fold for the test data.
    #'   * `reliability$raw_iota_objects$iota_objects_end_free`: List of objects with class `iotarelr_iota2` containing
    #'   the estimated iota reliability of the second generation for the final model for every fold for the test data.
    #'   Please note that the model is estimated without forcing the Assignment Error Matrix to be in line with the
    #'   assumption of weak superiority.
    #'   * `reliability$iota_object_end`: Object of class `iotarelr_iota2` as a mean of the individual objects
    #'   for every fold for the test data.
    #'   * `reliability$iota_object_end_free`: Object of class `iotarelr_iota2` as a mean of the individual objects
    #'   for every fold. Please note that the model is estimated without forcing the Assignment Error Matrix to be in
    #'   line with the assumption of weak superiority.
    #'
    #'
    #'   * `reliability$standard_measures_end`: Object of class `list` containing the final measures for precision,
    #'   recall, and f1 for every fold.
    #'   * `reliability$standard_measures_mean`: `matrix` containing the mean measures for precision, recall, and f1.
    #'
    reliability = list(
      # Every entry starts as NULL; presumably filled in by the training routine.
      test_metric = NULL,
      test_metric_mean = NULL,
      raw_iota_objects = list(
        iota_objects_end = NULL,
        iota_objects_end_free = NULL
      ),
      iota_object_end = NULL,
      iota_object_end_free = NULL,
      standard_measures_end = NULL,
      standard_measures_mean = NULL
    ),

    # New-----------------------------------------------------------------------
    #' @description Creating a new instance of this class. The object must not have been configured before.
    #' @param ml_framework `string` Framework to use for training and inference. `ml_framework="tensorflow"` for
    #'   'tensorflow' and `ml_framework="pytorch"` for 'pytorch'.
    #' @param name `string` Name of the new classifier. Please refer to common name conventions. Free text can be used
    #'   with parameter `label`.
    #' @param label `string` Label for the new classifier. Here you can use free text.
    #' @param text_embeddings An object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings].
    #' @param feature_extractor Object of class [TEFeatureExtractor] which should be used in order to reduce the number
    #'   of dimensions of the text embeddings. If no feature extractor should be applied set `NULL`.
    #' @param target_levels `vector` containing the levels (categories or classes) within the target data. Please note
    #'   that order matters. For ordinal data please ensure that the levels are sorted correctly with later levels
    #'   indicating a higher category/class. For nominal data the order does not matter.
    #' @param dense_layers `int` Number of dense layers.
    #' @param dense_size `int` Number of neurons for each dense layer. Must be at least 1 if `dense_layers > 0`.
    #' @param rec_layers `int` Number of recurrent layers.
    #' @param rec_size `int` Number of neurons for each recurrent layer. Must be at least 1 if `rec_layers > 0`.
    #' @param rec_type `string` Type of the recurrent layers. `rec_type="gru"` for Gated Recurrent Unit and
    #'   `rec_type="lstm"` for Long Short-Term Memory.
    #' @param rec_bidirectional `bool` If `TRUE` a bidirectional version of the recurrent layers is used.
    #' @param attention_type `string` Choose the relevant attention type. Possible values are `fourier` and
    #'   `multihead`. Please note that you may see different values for a case for different input orders if you
    #'   choose `fourier` on linux.
    #' @param self_attention_heads `int` determining the number of attention heads for a self-attention layer. Only
    #'   relevant if `attention_type="multihead"`. In that case it must be at least 1 if `repeat_encoder > 0`.
    #' @param repeat_encoder `int` determining how many times the encoder should be added to the network.
    #' @param intermediate_size `int` determining the size of the projection layer within each transformer encoder.
    #' @param add_pos_embedding `bool` `TRUE` if positional embedding should be used.
    #' @param encoder_dropout `double` ranging from 0 to below 1, determining the dropout for the dense projection
    #'   within the encoder layers.
    #' @param dense_dropout `double` ranging from 0 to below 1, determining the dropout between dense layers.
    #' @param rec_dropout `double` ranging from 0 to below 1, determining the dropout between bidirectional recurrent
    #'   layers.
    #' @param recurrent_dropout `double` ranging from 0 to below 1, determining the recurrent dropout for each
    #'   recurrent layer. Only relevant for keras models.
    #' @param optimizer `string` `"adam"` or `"rmsprop"`.
    #' @return Returns an object of class [TEClassifierRegular] which is ready for training.
    configure = function(ml_framework = "pytorch",
                         name = NULL,
                         label = NULL,
                         text_embeddings = NULL,
                         feature_extractor = NULL,
                         target_levels = NULL,
                         dense_size = 4,
                         dense_layers = 0,
                         rec_size = 4,
                         rec_layers = 2,
                         rec_type = "gru",
                         rec_bidirectional = FALSE,
                         self_attention_heads = 0,
                         intermediate_size = NULL,
                         attention_type = "fourier",
                         add_pos_embedding = TRUE,
                         rec_dropout = 0.1,
                         repeat_encoder = 1,
                         dense_dropout = 0.4,
                         recurrent_dropout = 0.4,
                         encoder_dropout = 0.1,
                         optimizer = "adam") {
      # Check if already configured
      private$check_config_for_FALSE()

      # Checking of parameters--------------------------------------------------
      if (!(ml_framework %in% c("tensorflow", "pytorch"))) {
        stop("ml_framework must be 'tensorflow' or 'pytorch'.")
      }
      check_type(name, type = "string", FALSE)
      check_type(label, type = "string", FALSE)
      check_type(target_levels, c("vector"), FALSE)

      check_type(dense_size, type = "int", FALSE)
      check_type(dense_layers, type = "int", FALSE)
      if (dense_layers > 0 && dense_size < 1) {
        stop("Dense layers added. Size for dense layers must be at least 1.")
      }

      check_type(rec_size, type = "int", FALSE)
      check_type(rec_layers, type = "int", FALSE)
      if (rec_layers > 0 && rec_size < 1) {
        stop("Recurrent  layers added. Size for recurrent layers must be at least 1.")
      }

      check_type(self_attention_heads, type = "int", FALSE)
      # repeat_encoder is compared numerically below, so validate it like the other counts
      check_type(repeat_encoder, type = "int", FALSE)
      if (!(optimizer %in% c("adam", "rmsprop"))) {
        stop("Optimizer must be 'adam' or 'rmsprop'.")
      }
      if (!(attention_type %in% c("fourier", "multihead"))) {
        # The message must name attention_type, not the optimizer
        stop("attention_type must be 'fourier' or 'multihead'.")
      }
      if (repeat_encoder > 0 && attention_type == "multihead" && self_attention_heads <= 0) {
        stop("Encoder layer is set to 'multihead'. This requires self_attention_heads>=1.")
      }
      private$check_embeddings_object_type(text_embeddings, strict = TRUE)

      #------------------------------------------------------------------------
      # Set ML framework
      private$ml_framework <- ml_framework

      # Setting Label and Name
      private$set_model_info(
        model_name_root = name,
        model_id = generate_id(16),
        label = label,
        model_date = date()
      )

      # Store the meta data of the model that created the text embeddings
      private$set_text_embedding_model(
        model_info = text_embeddings$get_model_info(),
        feature_extractor_info = text_embeddings$get_feature_extractor_info(),
        times = text_embeddings$get_times(),
        features = text_embeddings$get_features()
      )

      # Saving Configuration
      config <- list(
        use_fe = FALSE,
        features = private$text_embedding_model[["features"]],
        times = private$text_embedding_model[["times"]],
        dense_size = dense_size,
        dense_layers = dense_layers,
        rec_size = rec_size,
        rec_layers = rec_layers,
        rec_type = rec_type,
        rec_bidirectional = rec_bidirectional,
        intermediate_size = intermediate_size,
        attention_type = attention_type,
        repeat_encoder = repeat_encoder,
        dense_dropout = dense_dropout,
        rec_dropout = rec_dropout,
        recurrent_dropout = recurrent_dropout,
        encoder_dropout = encoder_dropout,
        add_pos_embedding = add_pos_embedding,
        optimizer = optimizer,
        act_fct = "gelu",
        rec_act_fct = "tanh",
        self_attention_heads = self_attention_heads
      )

      # Basic Information of Input and Target Data.
      # list() wrappers keep the entries even when their value is NULL.
      variable_name_order <- dimnames(text_embeddings$embeddings)[[3]]
      config["target_levels"] <- list(target_levels)
      config["n_categories"] <- list(length(target_levels))
      config["input_variables"] <- list(variable_name_order)

      # More than two levels: multi-class task, otherwise binary classification
      is_multi_class <- length(target_levels) > 2
      if (is_multi_class) {
        config["act_fct_last"] <- "softmax"
        config["err_fct"] <- "categorical_crossentropy"
        config["metric"] <- "categorical_accuracy"
      } else {
        config["act_fct_last"] <- "sigmoid"
        config["err_fct"] <- "binary_crossentropy"
        config["metric"] <- "binary_accuracy"
      }
      config["balanced_metric"] <- "balanced_accuracy"

      # Only binary 'tensorflow' models are trained without one-hot encoded targets
      config["require_one_hot"] <- list(ml_framework != "tensorflow" || is_multi_class)

      # A matrix map is required only when the net has neither recurrent nor encoder layers
      config["require_matrix_map"] <- list(!(rec_layers > 0 || repeat_encoder > 0))

      self$model_config <- config

      # Adjust configuration
      private$adjust_configuration()

      # Set FeatureExtractor and adapt config
      self$check_feature_extractor_object_type(feature_extractor)
      private$set_feature_extractor(feature_extractor)

      # Set package versions
      private$set_package_versions()

      # Finalize configuration
      private$set_configuration_to_TRUE()

      # Create_Model
      private$create_reset_model()
    },

    #-------------------------------------------------------------------------
    #' @description Method for training a neural net.
    #'
    #' Training includes a routine for early stopping. In the case that loss<0.0001
    #' and Accuracy=1.00 and Average Iota=1.00 training stops. The history uses the values
    #' of the last trained epoch for the remaining epochs.
    #'
    #' After training the model with the best values for Average Iota, Accuracy, and Loss
    #' on the validation data set is used as the final model.
    #'
    #' @param data_embeddings Object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings].
    #' @param data_targets `factor` containing the labels for cases stored in `data_embeddings`. Factor must be named
    #'   and has to use the same names used in `data_embeddings`.
    #' @param data_folds `int` determining the number of cross-fold samples. Must be at least 2.
    #' @param data_val_size `double` between 0 and 1, indicating the proportion of cases of each class which should be
    #'   used for the validation sample during the estimation of the model. The remaining cases are part of the training
    #'   data.
    #' @param balance_class_weights `bool` If `TRUE` class weights are generated based on the frequencies of the
    #'   training data with the method 'Inverse Class Frequency'. If `FALSE` each class has the weight 1.
    #' @param balance_sequence_length `bool` If `TRUE` sample weights are generated for the length of sequences based on
    #'   the frequencies of the training data with the method 'Inverse Class Frequency'. If `FALSE` each sequence
    #'   length has the weight 1.
    #' @param use_sc `bool` `TRUE` if the estimation should integrate synthetic cases. `FALSE` if not.
    #' @param sc_method `string` containing the method for generating synthetic cases. Possible are
    #'   `sc_method="adas"`, `sc_method="smote"`, and `sc_method="dbsmote"`.
    #' @param sc_min_k `int` determining the minimal number of k which is used for creating synthetic units.
    #' @param sc_max_k `int` determining the maximal number of k which is used for creating synthetic units.
    #' @param use_pl `bool` `TRUE` if the estimation should integrate pseudo-labeling. `FALSE` if not.
    #' @param pl_max_steps `int` determining the maximum number of steps during pseudo-labeling.
    #' @param pl_anchor `double` between 0 and 1 indicating the reference point for sorting the new cases of every
    #'   label. Must lie between `pl_min` and `pl_max`. See details for more information.
    #' @param pl_max `double` between 0 and 1, setting the maximal level of confidence for considering a case for
    #'   pseudo-labeling.
    #' @param pl_min `double` between 0 and 1, setting the minimal level of confidence for considering a case for
    #'   pseudo-labeling.
    #' @param sustain_track `bool` If `TRUE` energy consumption is tracked during training via the python library
    #'   'codecarbon'.
    #' @param sustain_iso_code `string` ISO code (Alpha-3-Code) for the country. This variable must be set if
    #'   sustainability should be tracked. A list can be found on Wikipedia:
    #'   <https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes>.
    #' @param sustain_region Region within a country. Only available for USA and Canada. See the documentation of
    #'   codecarbon for more information. <https://mlco2.github.io/codecarbon/parameters.html>
    #' @param sustain_interval `int` Interval in seconds for measuring power usage.
    #' @param epochs `int` Number of training epochs.
    #' @param batch_size `int` Size of the batches for training.
    #' @param dir_checkpoint `string` Path to the directory where the checkpoint during training should be saved. If the
    #'   directory does not exist, it is created.
    #' @param log_dir `string` Path to the directory where the log files should be saved. If no logging is desired set
    #'   this argument to `NULL`.
    #' @param log_write_interval `int` Time in seconds determining the interval in which the logger should try to update
    #'   the log files. Only relevant if `log_dir` is not `NULL`.
    #' @param trace `bool` `TRUE`, if information about the estimation phase should be printed to the console.
    #' @param ml_trace `int` `ml_trace=0` does not print any information about the training process from pytorch on the
    #'   console.
    #' @param n_cores `int` Number of cores which should be used during the calculation of synthetic cases. Only relevant
    #'   if `use_sc=TRUE`.
    #' @return Function does not return a value. It changes the object into a trained classifier.
    #' @details
    #'
    #' * `sc_max_k`: All values from sc_min_k up to sc_max_k are successively used. If
    #' the number of sc_max_k is too high, the value is reduced to a number that allows the calculating of synthetic
    #' units.
    #' * `pl_anchor`: With the help of this value, the new cases are sorted. For
    #' this aim, the distance from the anchor is calculated and all cases are arranged into an ascending order.
    #'
    train = function(data_embeddings,
                     data_targets,
                     data_folds = 5,
                     data_val_size = 0.25,
                     balance_class_weights = TRUE,
                     balance_sequence_length = TRUE,
                     use_sc = TRUE,
                     sc_method = "dbsmote",
                     sc_min_k = 1,
                     sc_max_k = 10,
                     use_pl = TRUE,
                     pl_max_steps = 3,
                     pl_max = 1.00,
                     pl_anchor = 1.00,
                     pl_min = 0.00,
                     sustain_track = TRUE,
                     sustain_iso_code = NULL,
                     sustain_region = NULL,
                     sustain_interval = 15,
                     epochs = 40,
                     batch_size = 32,
                     dir_checkpoint,
                     trace = TRUE,
                     ml_trace = 1,
                     log_dir = NULL,
                     log_write_interval = 10,
                     n_cores = auto_n_cores()) {
      # Checking Arguments------------------------------------------------------
      check_type(data_folds, type = "int", FALSE)
      check_type(data_val_size, type = "double", FALSE)
      check_type(balance_class_weights, type = "bool", FALSE)
      check_type(balance_sequence_length, type = "bool", FALSE)
      check_type(use_sc, type = "bool", FALSE)
      check_type(sc_method, type = "string", FALSE)
      check_type(sc_min_k, type = "int", FALSE)
      check_type(sc_max_k, type = "int", FALSE)
      check_type(use_pl, type = "bool", FALSE)
      check_type(pl_max_steps, type = "int", FALSE)
      check_type(pl_max, type = "double", FALSE)
      check_type(pl_anchor, type = "double", FALSE)
      check_type(pl_min, type = "double", FALSE)
      check_type(sustain_track, type = "bool", FALSE)
      check_type(sustain_iso_code, type = "string", TRUE)
      check_type(sustain_region, type = "string", TRUE)
      check_type(sustain_interval, type = "int", FALSE)
      check_type(epochs, type = "int", FALSE)
      check_type(batch_size, type = "int", FALSE)
      check_type(dir_checkpoint, type = "string", FALSE)
      check_type(trace, type = "bool", FALSE)
      check_type(ml_trace, type = "int", FALSE)
      check_type(log_dir, type = "string", TRUE)
      check_type(log_write_interval, type = "int", FALSE)
      check_type(n_cores, type = "int", FALSE)

      check_class(data_embeddings, c("EmbeddedText", "LargeDataSetForTextEmbeddings"), FALSE)
      self$check_embedding_model(data_embeddings, require_compressed = FALSE)

      check_class(data_targets, c("factor"), FALSE)
      data_targets <- private$check_and_adjust_target_levels(data_targets)

      if (pl_anchor < pl_min) {
        stop("pl_anchor must be at least pl_min.")
      }
      if (pl_anchor > pl_max) {
        stop("pl_anchor must be lower or equal to pl_max.")
      }

      if (data_folds < 2) {
        stop("data_folds must be at least 2.")
      }

      # Fail fast: tracking needs the country code, so reject the call before any
      # expensive data preparation or feature extraction starts.
      if (sustain_track == TRUE && is.null(sustain_iso_code)) {
        stop(paste(
          "Sustainability tracking is activated but iso code for the",
          "country is missing. Add iso code or deactivate tracking."
        ))
      }

      # Saving training configuration-------------------------------------------
      self$last_training$config$balance_class_weights <- balance_class_weights
      self$last_training$config$balance_sequence_length <- balance_sequence_length
      self$last_training$config$data_val_size <- data_val_size
      self$last_training$config$use_sc <- use_sc
      self$last_training$config$sc_method <- sc_method
      self$last_training$config$sc_min_k <- sc_min_k
      self$last_training$config$sc_max_k <- sc_max_k
      self$last_training$config$use_pl <- use_pl
      self$last_training$config$pl_max_steps <- pl_max_steps
      self$last_training$config$pl_max <- pl_max
      self$last_training$config$pl_anchor <- pl_anchor
      self$last_training$config$pl_min <- pl_min
      self$last_training$config$sustain_track <- sustain_track
      self$last_training$config$sustain_iso_code <- sustain_iso_code
      self$last_training$config$sustain_region <- sustain_region
      self$last_training$config$sustain_interval <- sustain_interval
      self$last_training$config$epochs <- epochs
      self$last_training$config$batch_size <- batch_size
      self$last_training$config$dir_checkpoint <- dir_checkpoint
      self$last_training$config$trace <- trace
      self$last_training$config$ml_trace <- ml_trace
      self$last_training$config$n_cores <- n_cores

      private$log_config$log_dir <- log_dir
      private$log_config$log_state_file <- paste0(private$log_config$log_dir, "/aifeducation_state.log")
      private$log_config$log_write_interval <- log_write_interval

      # Start-------------------------------------------------------------------
      if (self$last_training$config$trace == TRUE) {
        message(paste(
          date(),
          "Start"
        ))
      }

      # Set up data
      if (inherits(data_embeddings, "EmbeddedText")) {
        data <- data_embeddings$convert_to_LargeDataSetForTextEmbeddings()
      } else {
        data <- data_embeddings
      }

      # Create DataManager------------------------------------------------------
      # With a feature extractor, the manager works on the compressed embeddings
      if (self$model_config$use_fe == TRUE) {
        manager_embeddings <- self$feature_extractor$extract_features_large(
          data_embeddings = data,
          batch_size = as.integer(self$last_training$config$batch_size),
          trace = self$last_training$config$trace
        )
      } else {
        manager_embeddings <- data
      }
      data_manager <- DataManagerClassifier$new(
        data_embeddings = manager_embeddings,
        data_targets = data_targets,
        folds = data_folds,
        val_size = self$last_training$config$data_val_size,
        class_levels = self$model_config$target_levels,
        one_hot_encoding = self$model_config$require_one_hot,
        # Synthetic cases are generated on the matrix representation
        add_matrix_map = (self$model_config$require_matrix_map == TRUE || self$last_training$config$use_sc == TRUE),
        sc_method = sc_method,
        sc_min_k = sc_min_k,
        sc_max_k = sc_max_k,
        trace = trace,
        n_cores = self$last_training$config$n_cores
      )

      # Save Data Statistics
      self$last_training$data <- data_manager$get_statistics()

      # Save the number of folds
      self$last_training$config$n_folds <- data_manager$get_n_folds()

      # Init Training------------------------------------------------------------
      private$init_train()

      # config datasets
      datasets$disable_progress_bars()

      # Start Sustainability Tracking-------------------------------------------
      # If training aborts with an error, stop a tracker that was already started
      # so that it does not keep running in the background.
      tracker_running <- FALSE
      on.exit(
        if (tracker_running) {
          # Best effort only: a failure here must not mask the original error
          try(sustainability_tracker$stop(), silent = TRUE)
        },
        add = TRUE
      )

      if (sustain_track == TRUE) {
        # codecarbon >= 2.8.0 uses a lock file; remove a stale one left by an earlier run
        tmp_code_carbon <- reticulate::import("codecarbon")
        codecarbon_version <- as.character(tmp_code_carbon["__version__"])
        if (utils::compareVersion(codecarbon_version, "2.8.0") >= 0) {
          path_look_file <- codecarbon$lock$LOCKFILE
          if (file.exists(path_look_file)) {
            unlink(path_look_file)
          }
        }

        sustainability_tracker <- codecarbon$OfflineEmissionsTracker(
          country_iso_code = sustain_iso_code,
          region = sustain_region,
          tracking_mode = "machine",
          log_level = "warning",
          measure_power_secs = sustain_interval,
          save_to_file = FALSE,
          save_to_api = FALSE
        )
        sustainability_tracker$start()
        tracker_running <- TRUE
      }

      # Start Training----------------------------------------------------------
      # Load Custom Model Scripts
      private$load_reload_python_scripts()

      # One iteration per fold plus a final training run
      for (iter in seq_len(self$last_training$config$n_folds + 1)) {
        base::gc(verbose = FALSE, full = TRUE)

        if (self$last_training$config$use_pl == TRUE) {
          private$train_with_pseudo_labels(
            init_train = TRUE,
            iteration = iter,
            data_manager = data_manager,
            inc_synthetic = self$last_training$config$use_sc
          )
        } else {
          private$train_standard(
            iteration = iter,
            data_manager = data_manager,
            inc_synthetic = self$last_training$config$use_sc
          )
        }

        # Calculate measures on categorical level
        private$calculate_measures_on_categorical_level(
          data_manager = data_manager,
          iteration = iter
        )
        gc()
      }

      # Finalize Training
      private$finalize_train()

      # Stop sustainability tracking if requested
      if (sustain_track == TRUE) {
        sustainability_tracker$stop()
        tracker_running <- FALSE
        private$sustainability <- summarize_tracked_sustainability(sustainability_tracker)
      } else {
        private$sustainability <- list(
          sustainability_tracked = FALSE,
          date = NA,
          sustainability_data = list(
            duration_sec = NA,
            co2eq_kg = NA,
            cpu_energy_kwh = NA,
            gpu_energy_kwh = NA,
            ram_energy_kwh = NA,
            total_energy_kwh = NA
          )
        )
      }

      if (self$last_training$config$trace == TRUE) {
        message(paste(
          date(),
          "Training Complete"
        ))
      }
    },
    #-------------------------------------------------------------------------
    #' @description Method for predicting new data with a trained neural net.
    #' @param newdata Object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] for which predictions
    #'   should be made. In addition, this method allows to use objects of class `array` and
    #'   `datasets.arrow_dataset.Dataset`. However, these should be used only by developers.
    #' @param ml_trace `int` `ml_trace=0` does not print any information on the process from the machine learning
    #'   framework.
    #' @param batch_size `int` Size of batches.
    #' @return Returns a `data.frame` containing the predictions and the probabilities of the different labels for each
    #'   case.
    predict = function(newdata,
                       batch_size = 32,
                       ml_trace = 1) {
      # Check arguments
      check_type(batch_size, "int", FALSE)
      check_type(ml_trace, "int", FALSE)

      # Check if the embeddings must be compressed before passing to the model
      requires_compression <- self$requires_compression(newdata)

      # Check input for compatible text embedding models and feature extractors
      if ("EmbeddedText" %in% class(newdata) |
        "LargeDataSetForTextEmbeddings" %in% class(newdata)) {
        self$check_embedding_model(text_embeddings = newdata, require_compressed = FALSE)
      } else {
        private$check_embeddings_object_type(newdata, strict = FALSE)
        if (requires_compression == TRUE) {
          stop("Objects of class datasets.arrow_dataset.Dataset must be provided in
               compressed form.")
        }
      }


      # Apply feature extractor if it is part of the model
      if (requires_compression == TRUE) {
        if ("EmbeddedText" %in% class(newdata)) {
          newdata <- newdata$convert_to_LargeDataSetForTextEmbeddings()
        } else {
          newdata <- newdata
        }

        # Returns a data set
        newdata <- self$feature_extractor$extract_features_large(
          data_embeddings = newdata,
          batch_size = as.integer(batch_size)
        )
      }

      # Load Custom Model Scripts
      private$load_reload_python_scripts()

      # Check number of cases in the data
      single_prediction <- private$check_single_prediction(newdata)

      # Get current row names/name of the cases
      current_row_names <- private$get_rownames_from_embeddings(newdata)

      # If at least two cases are part of the data set---------------------------
      if (single_prediction == FALSE) {
        # Returns a data set object
        prediction_data <- private$prepare_embeddings_as_dataset(newdata)

        # Tensorflow----------------------------------------------------------
        if (private$ml_framework == "tensorflow") {
          # Prepare data set for tensorflow
          prediction_data <- prediction_data$rename_column("input", "input_embeddings")
          if ("labels" %in% prediction_data$column_names) {
            prediction_data <- prediction_data$remove_columns("labels")
          }
          prediction_data$set_format("tf")
          tf_dataset_predict <- prediction_data$to_tf_dataset(
            columns = c("input_embeddings"),
            batch_size = as.integer(batch_size),
            shuffle = FALSE
          )


          if (length(self$model_config$target_levels) > 2) {
            # Multi Class
            predictions_prob <- self$model$predict(
              x = tf_dataset_predict,
              verbose = as.integer(ml_trace)
            )
            predictions <- max.col(predictions_prob) - 1
          } else {
            predictions_prob <- self$model$predict(
              x = tf_dataset_predict,
              verbose = as.integer(ml_trace)
            )

            # Add Column for the second characteristic
            predictions <- vector(length = length(predictions_prob))
            predictions_binary_prob <- matrix(
              ncol = 2,
              nrow = length(predictions_prob)
            )

            for (i in 1:length(predictions_prob)) {
              if (predictions_prob[i] >= 0.5) {
                predictions_binary_prob[i, 1] <- 1 - predictions_prob[i]
                predictions_binary_prob[i, 2] <- predictions_prob[i]
                predictions[i] <- 1
              } else {
                predictions_binary_prob[i, 1] <- 1 - predictions_prob[i]
                predictions_binary_prob[i, 2] <- predictions_prob[i]
                predictions[i] <- 0
              }
            }
            predictions_prob <- predictions_binary_prob
          }
          # Pytorch----------------------------------------------------------
        } else if (private$ml_framework == "pytorch") {
          prediction_data$set_format("torch")
          predictions_prob <- py$TeClassifierBatchPredict(
            model = self$model,
            dataset = prediction_data,
            batch_size = as.integer(batch_size)
          )
          predictions_prob <- private$detach_tensors(predictions_prob)
          predictions <- max.col(predictions_prob) - 1
        }

        # In the case the data has one single row-------------------------------
      } else {
        prediction_data <- private$prepare_embeddings_as_np_array(newdata)

        # Tensorflow------------------------------------------------------------
        if (private$ml_framework == "tensorflow") {
          if (length(self$model_config$target_levels) > 2) {
            # Multy Class
            predictions_prob <- self$model$predict(
              x = prediction_data,
              batch_size = as.integer(batch_size),
              verbose = as.integer(ml_trace)
            )
            predictions <- max.col(predictions_prob) - 1
          } else {
            predictions_prob <- self$model$predict(
              x = prediction_data,
              batch_size = as.integer(batch_size),
              verbose = as.integer(ml_trace)
            )

            # Add Column for the second characteristic
            predictions <- vector(length = length(predictions_prob))
            predictions_binary_prob <- matrix(
              ncol = 2,
              nrow = length(predictions_prob)
            )

            for (i in 1:length(predictions_prob)) {
              if (predictions_prob[i] >= 0.5) {
                predictions_binary_prob[i, 1] <- 1 - predictions_prob[i]
                predictions_binary_prob[i, 2] <- predictions_prob[i]
                predictions[i] <- 1
              } else {
                predictions_binary_prob[i, 1] <- 1 - predictions_prob[i]
                predictions_binary_prob[i, 2] <- predictions_prob[i]
                predictions[i] <- 0
              }
            }
            predictions_prob <- predictions_binary_prob
          }
        } else if (private$ml_framework == "pytorch") {
          if (torch$cuda$is_available()) {
            device <- "cuda"
            dtype <- torch$double
            self$model$to(device, dtype = dtype)
            self$model$eval()
            input <- torch$from_numpy(prediction_data)
            predictions_prob <- self$model(input$to(device, dtype = dtype),
              predication_mode = TRUE
            )
            predictions_prob <- private$detach_tensors(predictions_prob)
          } else {
            device <- "cpu"
            dtype <- torch$float
            self$model$to(device, dtype = dtype)
            self$model$eval()
            input <- torch$from_numpy(prediction_data)
            predictions_prob <- self$model(input$to(device, dtype = dtype),
              predication_mode = TRUE
            )
            predictions_prob <- private$detach_tensors(predictions_prob)
          }
          predictions <- max.col(predictions_prob) - 1
        }
      }


      # Transforming predictions to target levels------------------------------
      predictions <- as.character(as.vector(predictions))
      for (i in 0:(length(self$model_config$target_levels) - 1)) {
        predictions <- replace(
          x = predictions,
          predictions == as.character(i),
          values = self$model_config$target_levels[i + 1]
        )
      }

      # Transforming to a factor
      predictions <- factor(predictions, levels = self$model_config$target_levels)

      colnames(predictions_prob) <- self$model_config$target_levels
      predictions_prob <- as.data.frame(predictions_prob)
      predictions_prob$expected_category <- predictions
      rownames(predictions_prob) <- current_row_names
      return(predictions_prob)
    },
    # Check Embedding Model compatibility of the text embedding
    #' @description Method for checking if the provided text embeddings are created with the same [TextEmbeddingModel]
    #'   as the classifier.
    #' @param text_embeddings Object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings].
    #' @param require_compressed `TRUE` if a compressed version of the embeddings are necessary. Compressed embeddings
    #'   are created by an object of class [TEFeatureExtractor].
    #' @return Returns `TRUE` invisibly if the text embeddings are compatible with the classifier. Raises an error if
    #'   * the underlying [TextEmbeddingModel] differs from the one used when creating the classifier,
    #'   * compressed embeddings are provided but the classifier does not use a feature extractor,
    #'   * the feature extractor of the compressed embeddings differs from the classifier's feature extractor, or
    #'   * compressed embeddings are provided although uncompressed embeddings are required.
    check_embedding_model = function(text_embeddings, require_compressed = FALSE) {
      # Check Embeddings Object Type
      private$check_embeddings_object_type(text_embeddings, strict = TRUE)

      # Check original text embedding model. The comparison is only performed
      # if both names are available.
      embedding_model_config <- text_embeddings$get_model_info()
      check <- "model_name"
      provided_model_name <- embedding_model_config[[check]]
      expected_model_name <- private$text_embedding_model$model[[check]]

      if (!is.null_or_na(provided_model_name) && !is.null_or_na(expected_model_name)) {
        if (provided_model_name != expected_model_name) {
          stop("The TextEmbeddingModel that generated the data_embeddings is not
               the same as the TextEmbeddingModel when generating the classifier.")
        }
      }

      # Check if a compressed version is necessary and if true if the feature extractor is
      # compatible. A non-NULL feature extractor name marks compressed embeddings.
      provided_fe_name <- text_embeddings$get_feature_extractor_info()$model_name
      is_compressed_input <- !is.null(provided_fe_name)

      if (require_compressed == TRUE) {
        if (is_compressed_input && self$model_config$use_fe == FALSE) {
          stop("Compressed embeddings provided but the classifier does not support
             compressed embeddings.")
        } else if (is_compressed_input && self$model_config$use_fe == TRUE) {
          # identical() instead of != : a missing (NULL) stored name must produce the
          # informative error below rather than a zero-length condition error.
          if (!identical(private$text_embedding_model$feature_extractor$model_name, provided_fe_name)) {
            stop("The feature extractor of the compressed embeddings is not the same as
               the feature extractor during the creation of the classifier.")
          }
        }
      } else if (is_compressed_input) {
        stop("Compressed embeddings are provided but uncompressed are needed.")
      }

      invisible(TRUE)
    },
    #--------------------------------------------------------------------------
    #' @description Method for checking an object of class [TEFeatureExtractor].
    #' @param feature_extractor Object of class [TEFeatureExtractor] or `NULL`. `NULL` is accepted and skips all
    #'   checks.
    #' @return This method does not return a value. It raises an error if
    #'
    #' * the object is neither `NULL` nor an object of class [TEFeatureExtractor]
    #' * the object does not rely on the same machine learning framework as the classifier
    #' * the object is not trained.
    #'
    check_feature_extractor_object_type = function(feature_extractor) {
      # NULL stands for "no feature extractor" and is always valid
      if (is.null(feature_extractor)) {
        return(invisible(NULL))
      }

      if (!inherits(feature_extractor, "TEFeatureExtractor")) {
        stop("Object passed to feature_extractor must be an object of class
               TEFeatureExtractor or NULL.")
      }

      if (feature_extractor$get_ml_framework() != self$get_ml_framework()) {
        stop("The machine learning framework of the feature extractor and
                 classifier do not match. Please provide a feature extractor
                 with the same machine learning framework as the classifier.")
      }

      if (!feature_extractor$is_trained()) {
        stop("The supplied feature extractor is not trained. Please
                provide trained feature extractor and try again.")
      }

      invisible(NULL)
    },
    #--------------------------------------------------------------------------
    #' @description Method for checking if provided text embeddings must be compressed via a [TEFeatureExtractor] before
    #'   processing.
    #' @param text_embeddings Object of class [EmbeddedText], [LargeDataSetForTextEmbeddings], `array` or
    #'   `datasets.arrow_dataset.Dataset`.
    #' @return Return `TRUE` if a compression is necessary and `FALSE` if not.
    requires_compression = function(text_embeddings) {
      # Check arguments
      check_class(text_embeddings, c(
        "EmbeddedText", "LargeDataSetForTextEmbeddings",
        "array", "datasets.arrow_dataset.Dataset"
      ), FALSE)

      if (inherits(text_embeddings, c("EmbeddedText", "LargeDataSetForTextEmbeddings"))) {
        # Embedding objects report themselves whether they are already compressed.
        # is_compressed() is only queried when the classifier uses a feature extractor.
        if (self$model_config$use_fe == TRUE && text_embeddings$is_compressed() == FALSE) {
          return(TRUE)
        }
        return(FALSE)
      }

      if (inherits(text_embeddings, "array")) {
        # Plain arrays: compare the size of the third (feature) dimension with
        # the number of features the classifier expects.
        if (dim(text_embeddings)[3] > self$model_config$features) {
          return(TRUE)
        }
        return(FALSE)
      }

      if (inherits(text_embeddings, "datasets.arrow_dataset.Dataset")) {
        # The dataset is switched to numpy format (kept from the previous
        # implementation because callers may rely on this format change).
        text_embeddings$set_format("np")

        # Only one case is needed to read the feature dimension. Selecting it
        # before accessing the column avoids converting the complete "input"
        # column into an R array.
        first_case <- text_embeddings$select(list(0L))
        first_case$set_format("np")
        tensors <- first_case["input"][1, , , drop = FALSE]
        if (dim(tensors)[3] > self$model_config$features) {
          return(TRUE)
        }
        return(FALSE)
      }
    },
    #-------------------------------------------------------------------------
    #' @description Method for saving a model.
    #' @param dir_path `string` Path of the directory where the model should be saved.
    #' @param folder_name `string` Name of the folder that should be created within the directory.
    #' @return Function does not return a value. It saves the model to disk.
    save = function(dir_path, folder_name) {
      # The parts common to all models are written by the parent class
      super$save(dir_path = dir_path, folder_name = folder_name)

      # A classifier built on top of a feature extractor additionally stores the
      # extractor in the sub-folder "feature_extractor" of its model folder
      uses_feature_extractor <- self$model_config$use_fe == TRUE
      if (uses_feature_extractor) {
        model_folder <- paste0(dir_path, "/", folder_name)
        save_to_disk(
          object = self$feature_extractor,
          folder_name = "feature_extractor",
          dir_path = model_folder
        )
      }
    },
    #--------------------------------------------------------------------------
    #' @description loads an object from disk and updates the object to the current version of the package.
    #' @param dir_path Path where the object set is stored.
    #' @return Method does not return anything. It loads an object from disk.
    load_from_disk = function(dir_path) {
      # Call the core method which loads data common for all models
      private$load_config_and_docs(dir_path = dir_path)

      # Add classifier specific data
      # Load R file
      config_file <- load_R_config_state(dir_path)
      saved_reliability <- config_file$public$reliability

      # Exact matching via [[ ]] is used throughout: `$` partially matches list
      # names, so e.g. a missing "test_metric" would silently return
      # "test_metric_mean" and a missing "iota_object_end" would return
      # "iota_object_end_free".
      #
      # The raw per-fold iota objects are written to reliability$iota_objects_end
      # and reliability$iota_objects_end_free during training, whereas the loaded
      # structure keeps them below reliability$raw_iota_objects. Prefer the nested
      # location and fall back to the top-level one so they are not lost.
      get_raw_iota_objects <- function(field) {
        nested <- saved_reliability[["raw_iota_objects"]][[field]]
        if (is.null(nested)) {
          saved_reliability[[field]]
        } else {
          nested
        }
      }

      # Set Reliability measures
      self$reliability <- list(
        test_metric = saved_reliability[["test_metric"]],
        test_metric_mean = saved_reliability[["test_metric_mean"]],
        raw_iota_objects = list(
          iota_objects_end = get_raw_iota_objects("iota_objects_end"),
          iota_objects_end_free = get_raw_iota_objects("iota_objects_end_free")
        ),
        iota_object_end = saved_reliability[["iota_object_end"]],
        iota_object_end_free = saved_reliability[["iota_object_end_free"]],
        standard_measures_end = saved_reliability[["standard_measures_end"]],
        standard_measures_mean = saved_reliability[["standard_measures_mean"]]
      )

      # Set FeatureExtractor (stored in the sub-folder "feature_extractor")
      if (self$model_config$use_fe == TRUE) {
        feature_extractor <- TEFeatureExtractor$new()
        feature_extractor$load_from_disk(paste0(dir_path, "/feature_extractor"))
        self$feature_extractor <- feature_extractor
      }

      # Create and load AI model
      private$create_reset_model()
      self$load(dir_path = dir_path)
    }
  ),
  private = list(
    #--------------------------------------------------------------------------
    # Runs the python scripts shipped in the "python" folder of the package that
    # belong to the current machine learning framework. The scripts are executed
    # one after another with reticulate::py_run_file(). Nothing is run for other
    # framework values.
    load_reload_python_scripts = function() {
      if (private$ml_framework == "tensorflow") {
        script_names <- c("keras_te_classifier.py", "keras_callbacks.py")
      } else if (private$ml_framework == "pytorch") {
        script_names <- c("pytorch_te_classifier.py", "pytorch_autoencoder.py", "py_log.py")
      } else {
        script_names <- character(0)
      }

      for (script_name in script_names) {
        script_path <- system.file(paste0("python/", script_name), package = "aifeducation")
        reticulate::py_run_file(script_path)
      }
    },
    #--------------------------------------------------------------------------
    # Builds a new, untrained model from self$model_config and stores it in
    # self$model, replacing any existing model.
    # * tensorflow: a keras functional model is assembled layer by layer:
    #   input -> [masking -> (positional embedding) -> layer normalization]
    #   -> [encoders] -> [recurrent layers] -> global average pooling
    #   -> [dense layers] -> output layer.
    # * pytorch: the python class TextEmbeddingClassifier_PT (made available by
    #   load_reload_python_scripts()) is instantiated with the same configuration.
    # NOTE(review): layer names are presumably relevant when saved weights are
    # loaded again; keep them stable.
    create_reset_model = function() {
      private$check_config_for_TRUE()

      if (private$ml_framework == "tensorflow") {
        # Load custom layers
        private$load_reload_python_scripts()

        # layer_list collects the output of every layer in order. Each new layer
        # is applied to the last element, i.e. layer_list[[length(layer_list)]].
        layer_list <- NULL

        # Adding Input Layer with shape (times, features)
        model_input <- keras$layers$Input(
          shape = list(as.integer(self$model_config$times), as.integer(self$model_config$features)),
          name = "input_embeddings"
        )
        layer_list[1] <- list(model_input)

        # Adding a Mask-Layer, an optional positional embedding and a layer
        # normalization. These are only added if recurrent layers or encoders
        # follow.
        if (self$model_config$rec_layers > 0 | self$model_config$repeat_encoder > 0) {
          masking_layer <- keras$layers$Masking(
            mask_value = 0.0,
            name = "masking_layer",
            input_shape = c(self$model_config$times, self$model_config$features),
            trainable = FALSE
          )(layer_list[[length(layer_list)]])
          layer_list[length(layer_list) + 1] <- list(masking_layer)

          if (self$model_config$add_pos_embedding == TRUE) {
            positional_embedding <- py$AddPositionalEmbedding(
              sequence_length = as.integer(self$model_config$times),
              name = "add_positional_embedding"
            )(layer_list[[length(layer_list)]])
            layer_list[length(layer_list) + 1] <- list(positional_embedding)
          }

          norm_layer <- keras$layers$LayerNormalization(
            name = "normalizaion_layer"
          )(layer_list[[length(layer_list)]])
          layer_list[length(layer_list) + 1] <- list(norm_layer)
        }

        # Adding repeat_encoder encoder blocks of the configured attention type
        if (self$model_config$repeat_encoder > 0) {
          for (r in 1:self$model_config$repeat_encoder) {
            if (self$model_config$attention_type == "multihead") {
              layer_list[length(layer_list) + 1] <- list(
                py$TransformerEncoder(
                  embed_dim = as.integer(self$model_config$features),
                  dense_dim = as.integer(self$model_config$intermediate_size),
                  num_heads = as.integer(self$model_config$self_attention_heads),
                  dropout_rate = self$model_config$encoder_dropout,
                  name = paste0("encoder_", r)
                )(layer_list[[length(layer_list)]])
              )
            } else if (self$model_config$attention_type == "fourier") {
              layer_list[length(layer_list) + 1] <- list(
                py$FourierEncoder(
                  dense_dim = as.integer(self$model_config$intermediate_size),
                  dropout_rate = self$model_config$encoder_dropout,
                  name = paste0("encoder_", r)
                )(layer_list[[length(layer_list)]])
              )
            }
          }
        }

        # Adding rec layer. All recurrent layers return full sequences because a
        # global average pooling follows. A dropout layer is placed between
        # consecutive recurrent layers (not after the last one).
        if (self$model_config$rec_layers > 0) {
          if (self$model_config$rec_bidirectional == TRUE) {
            for (i in 1:self$model_config$rec_layers) {
              if (self$model_config$rec_type == "gru") {
                layer_list[length(layer_list) + 1] <- list(
                  keras$layers$Bidirectional(
                    layer = keras$layers$GRU(
                      units = as.integer(self$model_config$rec_size),
                      input_shape = list(self$model_config$times, self$model_config$features),
                      return_sequences = TRUE,
                      dropout = 0,
                      recurrent_dropout = self$model_config$recurrent_dropout,
                      activation = "tanh",
                      name = paste0("gru_", i)
                    ),
                    name = paste0("bidirectional_", i)
                  )(layer_list[[length(layer_list)]])
                )
                if (i != self$model_config$rec_layers) {
                  layer_list[length(layer_list) + 1] <- list(
                    keras$layers$Dropout(
                      rate = self$model_config$rec_dropout,
                      name = paste0("gru_dropout_", i)
                    )(layer_list[[length(layer_list)]])
                  )
                }
              } else if (self$model_config$rec_type == "lstm") {
                layer_list[length(layer_list) + 1] <- list(
                  keras$layers$Bidirectional(
                    layer = keras$layers$LSTM(
                      units = as.integer(self$model_config$rec_size),
                      input_shape = list(self$model_config$times, self$model_config$features),
                      return_sequences = TRUE,
                      dropout = 0,
                      recurrent_dropout = self$model_config$recurrent_dropout,
                      activation = "tanh",
                      name = paste0("lstm", i)
                    ),
                    name = paste0("bidirectional_", i)
                  )(layer_list[[length(layer_list)]])
                )
                if (i != self$model_config$rec_layers) {
                  layer_list[length(layer_list) + 1] <- list(
                    keras$layers$Dropout(
                      rate = self$model_config$rec_dropout,
                      name = paste0("lstm_dropout_", i)
                    )(layer_list[[length(layer_list)]])
                  )
                }
              }
            }
          } else {
            # Unidirectional recurrent layers
            for (i in 1:self$model_config$rec_layers) {
              if (self$model_config$rec_type == "gru") {
                layer_list[length(layer_list) + 1] <- list(
                  layer = keras$layers$GRU(
                    units = as.integer(self$model_config$rec_size),
                    input_shape = list(self$model_config$times, self$model_config$features),
                    return_sequences = TRUE,
                    dropout = 0,
                    recurrent_dropout = self$model_config$recurrent_dropout,
                    activation = "tanh",
                    name = paste0("uni_directional_gru_", i)
                  )(layer_list[[length(layer_list)]])
                )
                if (i != self$model_config$rec_layers) {
                  layer_list[length(layer_list) + 1] <- list(
                    keras$layers$Dropout(
                      rate = self$model_config$rec_dropout,
                      name = paste0("gru_dropout_", i)
                    )(layer_list[[length(layer_list)]])
                  )
                }
              } else if (self$model_config$rec_type == "lstm") {
                layer_list[length(layer_list) + 1] <- list(
                  layer = keras$layers$LSTM(
                    units = as.integer(self$model_config$rec_size),
                    input_shape = list(self$model_config$times, self$model_config$features),
                    return_sequences = TRUE,
                    dropout = 0,
                    recurrent_dropout = self$model_config$recurrent_dropout,
                    activation = "tanh",
                    name = paste0("unidirectional_lstm", i)
                  )(layer_list[[length(layer_list)]])
                )
                if (i != self$model_config$rec_layers) {
                  layer_list[length(layer_list) + 1] <- list(
                    keras$layers$Dropout(
                      rate = self$model_config$rec_dropout,
                      name = paste0("lstm_dropout_", i)
                    )(layer_list[[length(layer_list)]])
                  )
                }
              }
            }
          }
        }


        # Collapse the time dimension by averaging over all time steps
        layer_list[length(layer_list) + 1] <- list(
          keras$layers$GlobalAveragePooling1D(
            name = "global_average_pooling"
          )(layer_list[[length(layer_list)]])
        )


        # Adding standard layer: dense layers with gelu activation and a dropout
        # layer between consecutive dense layers (not after the last one)
        if (self$model_config$dense_layers > 0) {
          for (i in 1:self$model_config$dense_layers) {
            layer_list[length(layer_list) + 1] <- list(
              keras$layers$Dense(
                units = as.integer(self$model_config$dense_size),
                activation = "gelu",
                name = paste0("dense_", i)
              )(layer_list[[length(layer_list)]])
            )

            if (i != self$model_config$dense_layers) {
              # Add Dropout_Layer
              layer_list[length(layer_list) + 1] <- list(
                keras$layers$Dropout(
                  rate = self$model_config$dense_dropout,
                  name = paste0("dense_dropout_", i)
                )(layer_list[[length(layer_list)]])
              )
            }
          }
        }

        # Adding final Layer: one unit per category for more than two
        # categories, a single unit for two categories. The single-unit output
        # is expanded to two columns in predict().
        if (length(self$model_config$target_levels) > 2) {
          # Multi Class
          layer_list[length(layer_list) + 1] <- list(
            keras$layers$Dense(
              units = as.integer(length(self$model_config$target_levels)),
              activation = self$model_config$act_fct_last,
              name = "output_categories"
            )(layer_list[[length(layer_list)]])
          )
        } else {
          # Binary Class
          layer_list[length(layer_list) + 1] <- list(
            keras$layers$Dense(
              units = as.integer(1),
              activation = self$model_config$act_fct_last,
              name = "output_categories"
            )(layer_list[[length(layer_list)]])
          )
        }

        # Creating Model from the input tensor and the output of the last layer
        model <- keras$Model(
          inputs = model_input,
          outputs = layer_list[length(layer_list)],
          name = self$model_config$name
        )

        self$model <- model
      } else {
        #--------------------------------------------------------------------------
        # Load Custom Pytorch Objects and Functions
        private$load_reload_python_scripts()

        # The pytorch model receives the complete configuration; integer values
        # are converted explicitly because python expects int, not float.
        self$model <- py$TextEmbeddingClassifier_PT(
          features = as.integer(self$model_config$features),
          times = as.integer(self$model_config$times),
          dense_size=as.integer(self$model_config$dense_size),
          dense_layers=as.integer(self$model_config$dense_layers),
          rec_size=as.integer(self$model_config$rec_size),
          rec_layers=as.integer(self$model_config$rec_layers),
          rec_type = self$model_config$rec_type,
          rec_bidirectional = self$model_config$rec_bidirectional,
          intermediate_size = as.integer(self$model_config$intermediate_size),
          attention_type = self$model_config$attention_type,
          repeat_encoder = as.integer(self$model_config$repeat_encoder),
          dense_dropout = self$model_config$dense_dropout,
          rec_dropout = self$model_config$rec_dropout,
          encoder_dropout = self$model_config$encoder_dropout,
          add_pos_embedding = self$model_config$add_pos_embedding,
          self_attention_heads = as.integer(self$model_config$self_attention_heads),
          target_levels = self$model_config$target_levels
        )
      }
    },
    #--------------------------------------------------------------------------
    # Prepares the classifier for a new training run: assigns a new model name
    # with a fresh id, resets all reliability measures (so nothing from a
    # previous run leaks into the new one) and records the start time.
    init_train = function() {
      # Setting a new ID for the classifier
      private$model_info$model_name <- paste0(
        private$model_info$model_name_root,
        "_id_",
        generate_id(16)
      )

      # Initializing Objects for Saving Performance
      metric_names <- get_coder_metrics(
        true_values = NULL,
        predicted_values = NULL,
        return_names_only = TRUE
      )

      # One row per fold and one column per metric. Rows of folds for which no
      # metric is calculated remain NA.
      self$reliability$test_metric <- matrix(
        nrow = self$last_training$config$n_folds,
        ncol = length(metric_names),
        dimnames = list(
          iterations = NULL,
          metrics = metric_names
        )
      )

      self$reliability$test_metric_mean <- NULL

      # Per-fold results. standard_measures_end is reset as well; otherwise
      # entries of an earlier training run could survive and be averaged in
      # finalize_train().
      self$reliability$iota_objects_end <- NULL
      self$reliability$iota_objects_end_free <- NULL
      self$reliability$standard_measures_end <- NULL

      self$reliability$iota_object_end <- NULL
      self$reliability$iota_object_end_free <- NULL

      # Accumulator for the per-category measures (summed in finalize_train)
      standard_measures_mean_table <- matrix(
        nrow = length(self$model_config$target_levels),
        ncol = 3,
        data = 0
      )
      colnames(standard_measures_mean_table) <- c("precision", "recall", "f1")
      rownames(standard_measures_mean_table) <- self$model_config$target_levels

      self$reliability$standard_measures_mean <- standard_measures_mean_table

      # Save start time of training
      self$last_training$start_time <- Sys.time()
    },
    #--------------------------------------------------------------------------
    # Predicts the cases of `test_data`, compares the predictions with the true
    # labels and stores the resulting coder metrics in row `iteration` of
    # self$reliability$test_metric. The argument `type` is not used.
    calculate_test_metric = function(test_data, iteration, type) {
      training_config <- self$last_training$config
      prediction_table <- self$predict(
        newdata = test_data,
        ml_trace = training_config$ml_trace,
        batch_size = training_config$batch_size
      )

      # Bring the predicted categories into the order of the ids in test_data
      predicted_categories <- prediction_table$expected_category
      names(predicted_categories) <- rownames(prediction_table)
      predicted_categories <- predicted_categories[test_data["id"]]

      # Labels are stored as integer codes 0..(k-1); map them to the category names
      test_data$set_format("np")
      category_names <- self$model_config$target_levels
      observed_categories <- factor(
        x = test_data["labels"],
        levels = 0:(length(category_names) - 1),
        labels = category_names
      )
      names(observed_categories) <- test_data["id"]

      fold_metrics <- get_coder_metrics(
        true_values = observed_categories,
        predicted_values = predicted_categories
      )

      # Save results
      self$reliability$test_metric[iteration, ] <- fold_metrics
    },
    #--------------------------------------------------------------------------
    # For the fold `iteration` of `data_manager`: predicts the test data and
    # stores the standard classification measures and the iota objects (with
    # free_aem FALSE and TRUE) at position `iteration` of the corresponding
    # reliability lists. Warns if a cross-validation fold has no test data.
    calculate_measures_on_categorical_level = function(data_manager, iteration) {
      # Get test data
      data_manager$set_state(
        iteration = iteration,
        step = NULL
      )
      test_data <- data_manager$get_test_dataset()

      if (!is.null(test_data)) {
        # Predict labels and order them like the ids in test_data
        prediction_table <- self$predict(
          newdata = test_data,
          ml_trace = self$last_training$config$ml_trace,
          batch_size = self$last_training$config$batch_size
        )
        predicted_categories <- prediction_table$expected_category
        names(predicted_categories) <- rownames(prediction_table)
        predicted_categories <- predicted_categories[test_data["id"]]

        # Labels are stored as integer codes 0..(k-1); map them to the category names
        test_data$set_format("np")
        category_names <- self$model_config$target_levels
        observed_categories <- factor(
          x = test_data["labels"],
          levels = 0:(length(category_names) - 1),
          labels = category_names
        )
        named_observed_categories <- observed_categories
        names(named_observed_categories) <- test_data["id"]

        # Calculate standard measures
        self$reliability$standard_measures_end[iteration] <- list(
          calc_standard_classification_measures(
            true_values = named_observed_categories,
            predicted_values = predicted_categories
          )
        )

        # Calculate iota objects (the iota functions receive the unnamed labels)
        iota_standard <- iotarelr::check_new_rater(
          true_values = observed_categories,
          assigned_values = predicted_categories,
          free_aem = FALSE
        )
        self$reliability$iota_objects_end[iteration] <- list(iota_standard)

        iota_free <- iotarelr::check_new_rater(
          true_values = observed_categories,
          assigned_values = predicted_categories,
          free_aem = TRUE
        )
        self$reliability$iota_objects_end_free[iteration] <- list(iota_free)
      } else if (iteration <= data_manager$get_n_folds()) {
        warning("Unable to calculate test scores. There is no test data.")
      }
    },
    #--------------------------------------------------------------------------
    # Summarizes a finished training run:
    # * stores the date and the learning time in minutes,
    # * averages every test metric over the folds (missing values are skipped),
    # * combines the per-fold iota objects into mean iota objects,
    # * averages the per-category precision, recall and f1 over the folds.
    finalize_train = function() {
      # Save Final Information
      self$last_training$date <- date()

      # Mean of each metric over the folds. NA values are skipped; a metric that
      # is missing in every fold gives NaN, which is reported as NA.
      test_metric_mean <- colMeans(self$reliability$test_metric, na.rm = TRUE)
      test_metric_mean[is.nan(test_metric_mean)] <- NA
      self$reliability$test_metric_mean <- test_metric_mean

      self$last_training$learning_time <- as.numeric(
        difftime(Sys.time(),
          self$last_training$start_time,
          units = "mins"
        )
      )

      # Finalize iota objects. [[ ]] is used because `$` would partially match
      # "iota_objects_end" to "iota_objects_end_free" if the former is missing.
      if (!is.null(self$reliability[["iota_objects_end"]])) {
        self$reliability$iota_object_end <- create_iota2_mean_object(
          iota2_list = self$reliability[["iota_objects_end"]],
          original_cat_labels = self$model_config$target_levels,
          free_aem = FALSE,
          call = "aifeducation::te_classifier_neuralnet"
        )
      } else {
        # No per-fold iota objects: no mean object can be reported
        self$reliability$iota_object_end <- NULL
      }

      if (!is.null(self$reliability[["iota_objects_end_free"]])) {
        self$reliability$iota_object_end_free <- create_iota2_mean_object(
          iota2_list = self$reliability[["iota_objects_end_free"]],
          original_cat_labels = self$model_config$target_levels,
          free_aem = TRUE,
          call = "aifeducation::te_classifier_neuralnet"
        )
      } else {
        self$reliability$iota_object_end_free <- NULL
      }

      # Finalize standard measures: sum the per-fold values for every category
      # and divide by the number of folds
      standard_measures <- self$reliability$standard_measures_mean
      for (i in seq_len(self$last_training$config$n_folds)) {
        fold_measures <- self$reliability$standard_measures_end[[i]]
        for (tmp_cat in self$model_config$target_levels) {
          standard_measures[tmp_cat, "precision"] <- standard_measures[tmp_cat, "precision"] +
            fold_measures[tmp_cat, "precision"]
          standard_measures[tmp_cat, "recall"] <- standard_measures[tmp_cat, "recall"] +
            fold_measures[tmp_cat, "recall"]
          standard_measures[tmp_cat, "f1"] <- standard_measures[tmp_cat, "f1"] +
            fold_measures[tmp_cat, "f1"]
        }
      }
      self$reliability$standard_measures_mean <- standard_measures / self$last_training$config$n_folds
    },
    #--------------------------------------------------------------------------
    train_standard = function(iteration = NULL,
                              data_manager = NULL,
                              inc_synthetic = FALSE) {
      # Runs one standard training (no pseudo labels) for `iteration`, which is
      # either a fold number or "final". Optionally adds synthetic cases,
      # stores the training history under `iteration` and, for fold
      # iterations, computes the test metric on the fold's test data.

      # Emit a trace message for the current iteration. Further message parts
      # can be appended via `...`.
      report_status <- function(...) {
        if (self$last_training$config$trace == TRUE) {
          if (iteration <= self$last_training$config$n_folds) {
            stage <- c("Iteration", iteration, "from", self$last_training$config$n_folds)
          } else {
            stage <- "Final training"
          }
          message(paste(c(date(), "|", stage, ...), collapse = " "))
        }
      }

      report_status()

      # Point the DataManager at the requested iteration (no pseudo labeling step)
      data_manager$set_state(iteration = iteration, step = NULL)

      # Generate synthetic cases if requested
      if (inc_synthetic == TRUE) {
        data_manager$create_synthetic(
          trace = self$last_training$config$trace,
          inc_pseudo_data = FALSE
        )
      }

      # Collect training, validation and (for folds only) test data
      train_data <- data_manager$get_dataset(
        inc_labeled = TRUE,
        inc_synthetic = inc_synthetic,
        inc_pseudo_data = FALSE,
        inc_unlabeled = FALSE
      )
      val_data <- data_manager$get_val_dataset()
      test_data <- if (iteration != "final") data_manager$get_test_dataset() else NULL

      report_status("|", "Training")

      train_history <- private$basic_train(
        train_data = train_data,
        val_data = val_data,
        test_data = test_data,
        reset_model = TRUE,
        use_callback = TRUE,
        log_dir = private$log_config$log_dir,
        log_write_interval = private$log_config$log_write_interval,
        log_top_value = iteration,
        log_top_total = self$last_training$config$n_folds + 1,
        log_top_message = "Overall"
      )

      # Keep the history of this iteration
      self$last_training$history[iteration] <- list(train_history)

      # Test metric type: 1 = without synthetic cases, 2 = with synthetic cases
      if (!is.null(test_data)) {
        private$calculate_test_metric(
          test_data = test_data,
          iteration = iteration,
          type = as.numeric(inc_synthetic) + 1
        )
      }
    },
    #--------------------------------------------------------------------------
    train_with_pseudo_labels = function(init_train = TRUE,
                                        iteration = NULL,
                                        data_manager = NULL,
                                        inc_synthetic = FALSE) {
      # Trains the classifier for `iteration` (a fold number or "final") with
      # pseudo labeling. After an optional initial standard training, each of
      # the `pl_max_steps` steps estimates pseudo labels for the unlabeled
      # data, adds them (and optionally synthetic cases) to the training data
      # and re-trains the model. The histories of all steps are stored under
      # `iteration`. For fold iterations the test metric (type 3) is computed.

      # Emit a trace message for the current iteration. Further message parts
      # can be appended via `...`.
      report_status <- function(...) {
        if (self$last_training$config$trace == TRUE) {
          if (iteration <= self$last_training$config$n_folds) {
            stage <- c("Iteration", iteration, "from", self$last_training$config$n_folds)
          } else {
            stage <- "Final training"
          }
          message(paste(c(date(), "|", stage, ...), collapse = " "))
        }
      }

      # If the model is not trained yet, train it once first. A trained
      # model is necessary for estimating pseudo labels.
      if (init_train == TRUE) {
        private$train_standard(
          iteration = iteration,
          data_manager = data_manager,
          inc_synthetic = inc_synthetic
        )
      }

      # Get validation and test data for training loop
      val_data <- data_manager$get_val_dataset()
      if (iteration != "final") {
        test_data <- data_manager$get_test_dataset()
      } else {
        test_data <- NULL
      }

      # Start training loop with pseudo labels
      data_manager$set_state(
        iteration = iteration,
        step = NULL
      )

      n_steps <- self$last_training$config$pl_max_steps
      # One training history per pseudo labeling step
      step_histories <- vector(mode = "list", length = n_steps)

      for (step in seq_len(n_steps)) {
        report_status("|", "Pseudo labeling", "step", step, "from", n_steps)

        # Set correct state for the data_manager
        data_manager$set_state(
          iteration = iteration,
          step = step
        )

        # Generate pseudo labels
        pseudo_data <- private$estimate_pseudo_labels(
          unlabeled_data = data_manager$get_unlabeled_data(),
          val_data = val_data,
          current_step = step
        )

        # Save pseudo labels in the data_manager
        data_manager$add_replace_pseudo_data(
          inputs = pseudo_data$input,
          labels = pseudo_data$labels
        )

        # Remove old pseudo data
        rm(pseudo_data)

        # Generate synthetic data if requested
        if (inc_synthetic == TRUE) {
          data_manager$create_synthetic(
            trace = self$last_training$config$trace,
            inc_pseudo_data = TRUE
          )
        }

        # Request training data
        train_data <- data_manager$get_dataset(
          inc_labeled = TRUE,
          inc_synthetic = inc_synthetic,
          inc_pseudo_data = TRUE,
          inc_unlabeled = FALSE
        )

        report_status("|", "Training")

        # Start training. The log directory is passed exactly as in
        # train_standard(); basic_train() forwards it as `log_dir`.
        train_history <- private$basic_train(
          train_data = train_data,
          val_data = val_data,
          test_data = test_data,
          reset_model = TRUE,
          use_callback = TRUE,
          log_dir = private$log_config$log_dir,
          log_write_interval = private$log_config$log_write_interval,
          log_top_value = iteration,
          log_top_total = self$last_training$config$n_folds + 1,
          log_top_message = "Overall"
        )

        # Save history (`[<-` with list() keeps NULL histories as elements)
        step_histories[step] <- list(train_history)
      }

      # Save the histories for the complete iteration
      self$last_training$history[iteration] <- list(step_histories)

      # Calculate test metric
      if (!is.null(test_data)) {
        private$calculate_test_metric(
          test_data = test_data,
          iteration = iteration,
          type = 3
        )
      }
    },
    #--------------------------------------------------------------------------
    estimate_pseudo_labels = function(unlabeled_data,
                                      val_data,
                                      current_step) {
      # Estimates pseudo labels for `unlabeled_data` and returns the subset to
      # be used in pseudo labeling step `current_step`.
      #
      # The steps are:
      # 1. Predict category probabilities for the unlabeled data.
      # 2. Correct them with the reliability observed on `val_data`
      #    (private$pseudo_labels_correct_prob()).
      # 3. Transform the highest probability into an information index
      #    |pl_anchor - rescaled probability|.
      # 4. Keep cases whose index lies within [pl_min, pl_max].
      # 5. Select the share current_step / pl_max_steps of them with the
      #    smallest index.
      #
      # Returns list(input = embeddings of the selected cases (array, cases in
      # the first dimension), labels = factor of their pseudo labels).
      n_levels <- length(self$model_config$target_levels)

      # Predict pseudo labels for unlabeled data
      predicted_labels <- self$predict(
        newdata = unlabeled_data,
        ml_trace = self$last_training$config$ml_trace,
        batch_size = self$last_training$config$batch_size
      )

      # Create Matrix for saving the results
      new_categories <- matrix(
        nrow = nrow(predicted_labels),
        ncol = 2
      )
      rownames(new_categories) <- rownames(predicted_labels)
      colnames(new_categories) <- c("cat", "prob")

      # Correct probabilities for reliability on the validation data
      predicted_labels <- private$pseudo_labels_correct_prob(
        predictions = predicted_labels,
        val_data = val_data
      )

      # For every case, keep the 0-based index of the category with the highest
      # probability and that probability. The last column of the predictions
      # is not a probability and is excluded.
      n_prob_cols <- ncol(predicted_labels) - 1
      for (i in seq_len(nrow(predicted_labels))) {
        tmp_est_prob <- predicted_labels[i, 1:n_prob_cols]
        new_categories[i, 1] <- which.max(tmp_est_prob) - 1
        new_categories[i, 2] <- max(tmp_est_prob)
      }
      new_categories <- as.data.frame(new_categories)

      # Transforming the probabilities to an information index. The maximal
      # probability is rescaled so that chance level (1 / n_levels) maps to 0
      # and certainty maps to 1.
      new_categories[, 2] <- abs(
        self$last_training$config$pl_anchor -
          (as.numeric(new_categories[, 2]) - 1 / n_levels) / (1 - 1 / n_levels)
      )

      # Reducing the new categories to the desired range
      condition <- (new_categories[, 2] >= self$last_training$config$pl_min &
        new_categories[, 2] <= self$last_training$config$pl_max)
      new_categories <- subset(new_categories, condition)

      # Number of cases to include grows linearly with the step. Multiplying
      # before dividing keeps the integer case exact. At least one case is
      # taken whenever candidates exist. When no candidate survives the
      # filter, nothing is selected: 1:0 would have yielded an NA case.
      n_candidates <- nrow(new_categories)
      n_cases_to_include <- floor(
        n_candidates * current_step / self$last_training$config$pl_max_steps
      )
      n_cases_to_include <- min(n_candidates, max(1, n_cases_to_include))

      # Order cases with increasing distance from maximal information
      new_categories <- new_categories[order(new_categories$prob, decreasing = FALSE), ]

      # Select the best cases
      names_final_new_categories <- rownames(new_categories)[seq_len(n_cases_to_include)]

      # Get the labels for these cases
      targets_pseudo_labeled <- new_categories[names_final_new_categories, 1]
      targets_pseudo_labeled <- as.numeric(targets_pseudo_labeled)
      names(targets_pseudo_labeled) <- names_final_new_categories

      # Transform pseudo labels to a factor
      targets_pseudo_labeled <- factor(
        x = targets_pseudo_labeled,
        levels = 0:(n_levels - 1),
        labels = self$model_config$target_levels
      )

      # Get the corresponding input. drop = FALSE keeps the 3-dimensional
      # shape even if exactly one case is selected.
      unlabeled_data$set_format("np")
      embeddings <- unlabeled_data["input"]
      rownames(embeddings) <- unlabeled_data["id"]
      embeddings <- embeddings[names_final_new_categories, , , drop = FALSE]

      # Return results
      pseudo_data <- list(
        input = embeddings,
        labels = targets_pseudo_labeled
      )

      return(pseudo_data)
    },
    #--------------------------------------------------------------------------
    pseudo_labels_correct_prob = function(predictions, val_data) {
      # Corrects predicted category probabilities for the classifier's
      # reliability. The reliability is measured on the labeled `val_data`
      # with iotarelr::check_new_rater().
      #
      # For every category the probability that an assignment to it is correct
      # is estimated. Each row of `predictions` is reweighted with these
      # probabilities and re-normalized, and the last column (the expected
      # category) is updated to the category with the highest corrected
      # probability.
      #
      # predictions: output of self$predict(). All columns but the last are
      #   treated as category probabilities in the order of target_levels, and
      #   the last column as the expected category (as used below).
      # val_data: labeled validation dataset providing "id" and "labels".
      # Returns `predictions` with corrected probabilities and categories.

      # Predict on val data
      val_predictions <- self$predict(
        newdata = val_data,
        ml_trace = self$last_training$config$ml_trace,
        batch_size = self$last_training$config$batch_size
      )
      val_pred_cat <- val_predictions$expected_category
      names(val_pred_cat) <- rownames(val_predictions)
      val_pred_cat <- val_pred_cat[val_data["id"]]

      # Calculate Assignment Error Matrix
      val_data$set_format("np")
      val_iota_object <- iotarelr::check_new_rater(
        true_values = factor(
          x = val_data["labels"],
          levels = 0:(length(self$model_config$target_levels) - 1),
          labels = self$model_config$target_levels
        ),
        assigned_values = val_pred_cat,
        free_aem = TRUE
      )

      # Estimate probability of each category being assigned
      aem <- val_iota_object$categorical_level$raw_estimates$assignment_error_matrix
      class_sizes <- val_iota_object$information$est_true_cat_sizes
      p_cat <- class_sizes %*% aem

      # Estimate probability that the assigned category is the true category.
      # 0/0 results (NaN) are set to weight 0.
      p_cat_true <- class_sizes * diag(aem) / p_cat
      p_cat_true <- replace(p_cat_true, list = is.nan(p_cat_true), values = 0)

      # Correct probabilities row by row
      number_columns <- ncol(predictions)
      col <- number_columns - 1
      for (i in seq_len(nrow(predictions))) {
        weighted_probs <- predictions[i, 1:col] * p_cat_true
        normalizer <- sum(weighted_probs)
        # If all categories with non-zero probability have weight 0, the
        # normalizer is 0 and the row would become NaN. which.max() would then
        # return integer(0) and the assignment below would fail. Such rows
        # keep their uncorrected probabilities.
        if (is.finite(normalizer) && normalizer > 0) {
          predictions[i, 1:col] <- weighted_probs / normalizer
        }
        predictions[i, number_columns] <- self$model_config$target_levels[
          which.max(as.numeric(predictions[i, 1:col]))
        ]
      }
      return(predictions)
    },
    #--------------------------------------------------------------------------
    basic_train = function(train_data = NULL,
                           val_data = NULL,
                           test_data = NULL,
                           reset_model = FALSE,
                           use_callback = TRUE,
                           log_dir = NULL,
                           log_write_interval = 10,
                           log_top_value = NULL,
                           log_top_total = NULL,
                           log_top_message = NULL) {
      # Performs one complete training run of the model and returns the
      # training history as prepared by private$prepare_history_data().
      #
      # train_data, val_data: datasets used for training and validation.
      # test_data: optional dataset. It is only forwarded to the PyTorch
      #   training loop; the TensorFlow branch does not use it.
      # reset_model: if TRUE, the model is re-created before training.
      # use_callback: if TRUE, the best weights are checkpointed to
      #   <dir_checkpoint>/checkpoints and restored after training.
      # log_dir, log_write_interval, log_top_*: logging settings forwarded to
      #   the PyTorch training loop.

      # Clear session to provide enough resources for computations
      if (private$ml_framework == "tensorflow") {
        keras$backend$clear_session()
      } else if (private$ml_framework == "pytorch") {
        if (torch$cuda$is_available()) {
          torch$cuda$empty_cache()
        }
      }

      # Generating class weights. The result is an unnamed vector ordered by
      # class index.
      if (self$last_training$config$balance_class_weights == TRUE) {
        abs_freq_classes <- table(train_data["labels"])
        class_weights <- as.vector(sum(abs_freq_classes) / (length(abs_freq_classes) * abs_freq_classes))
      } else {
        class_weights <- rep(x = 1, times = length(self$model_config$target_levels))
      }

      # Generating weights for sequence length
      sequence_length <- train_data["length"]
      if (self$last_training$config$balance_sequence_length == TRUE) {
        abs_freq_length <- table(sequence_length)
        sample_weight_per_sequence_length <- as.vector(sum(abs_freq_length) / (length(abs_freq_length) * abs_freq_length))
        # Look up each case's weight via its sequence length in one vectorized
        # match instead of a per-case which() scan
        sample_weights <- sample_weight_per_sequence_length[
          match(sequence_length, names(abs_freq_length))
        ]
      } else {
        sample_weights <- rep.int(x = 1, times = length(sequence_length))
      }

      # Reset model if requested
      if (reset_model == TRUE) {
        private$create_reset_model()
      }

      # Set Optimizer. TensorFlow compiles the model here. The PyTorch
      # training loop receives the optimizer name from model_config directly.
      if (private$ml_framework == "tensorflow") {
        balanced_metric <- py$BalancedAccuracy(n_classes = as.integer(length(self$model_config$target_levels)))
        if (self$model_config$optimizer == "adam") {
          self$model$compile(
            loss = self$model_config$err_fct,
            optimizer = keras$optimizers$Adam(),
            metrics = c(self$model_config$metric, balanced_metric)
          )
        } else if (self$model_config$optimizer == "rmsprop") {
          self$model$compile(
            loss = self$model_config$err_fct,
            optimizer = keras$optimizers$RMSprop(),
            metrics = c(self$model_config$metric, balanced_metric)
          )
        }
      } else if (private$ml_framework == "pytorch") {
        loss_fct_name <- "CrossEntropyLoss"
      }

      # Check directory for checkpoints
      create_dir(
        dir_path = paste0(self$last_training$config$dir_checkpoint, "/checkpoints"),
        trace = self$last_training$config$trace,
        msg = "Creating Checkpoint Directory"
      )

      # Set target column
      if (self$model_config$require_one_hot == FALSE) {
        target_column <- "labels"
      } else {
        target_column <- "one_hot_encoding"
      }

      # Tensorflow - Callbacks and training
      if (private$ml_framework == "tensorflow") {
        if (use_callback == TRUE) {
          callback <- keras$callbacks$ModelCheckpoint(
            filepath = paste0(self$last_training$config$dir_checkpoint, "/checkpoints/best_weights.h5"),
            monitor = paste0("val_", self$model_config$balanced_metric),
            verbose = as.integer(min(self$last_training$config$ml_trace, 1)),
            mode = "auto",
            save_best_only = TRUE,
            save_weights_only = TRUE
          )
        } else {
          callback <- reticulate::py_none()
        }

        data_set_weights <- datasets$Dataset$from_dict(
          reticulate::dict(list(
            sample_weights = sample_weights
          ))
        )
        # inputs, targets, sample_weights
        dataset_tf <- train_data$add_column("sample_weights", data_set_weights["sample_weights"])
        dataset_tf <- dataset_tf$rename_column("input", "input_embeddings")

        # Choose correct target column and rename
        dataset_tf <- dataset_tf$rename_column(target_column, "targets")

        dataset_tf$set_format("tf")
        tf_dataset_train <- dataset_tf$to_tf_dataset(
          columns = c("input_embeddings", "sample_weights"),
          batch_size = as.integer(self$last_training$config$batch_size),
          shuffle = TRUE,
          label_cols = "targets"
        )
        # Add sample weights
        tf_dataset_train <- tf_dataset_train$map(py$extract_sample_weight)

        dataset_tf_val <- val_data$rename_column("input", "input_embeddings")
        # Choose correct target column and rename
        dataset_tf_val <- dataset_tf_val$rename_column(target_column, "targets")

        tf_dataset_val <- dataset_tf_val$to_tf_dataset(
          columns = c("input_embeddings"),
          batch_size = as.integer(self$last_training$config$batch_size),
          shuffle = FALSE,
          label_cols = "targets"
        )

        # Keras expects a dict mapping each integer class index (0, 1, ...) to
        # its weight. class_weights carries no names (as.vector()/rep() drop
        # them), so the keys are built from the positions.
        keras_class_weights <- reticulate::py_dict(
          keys = as.list(seq_along(class_weights) - 1L),
          values = as.list(class_weights)
        )

        history <- self$model$fit(
          verbose = as.integer(self$last_training$config$ml_trace),
          x = tf_dataset_train,
          validation_data = tf_dataset_val,
          epochs = as.integer(self$last_training$config$epochs),
          callbacks = callback,
          class_weight = keras_class_weights
        )$history

        if (self$model_config$n_categories == 2) {
          history <- list(
            loss = rbind(history$loss, history$val_loss),
            accuracy = rbind(history$binary_accuracy, history$val_binary_accuracy),
            balanced_accuracy = rbind(history$balanced_accuracy, history$val_balanced_accuracy)
          )
        } else {
          history <- list(
            loss = rbind(history$loss, history$val_loss),
            accuracy = rbind(history$categorical_accuracy, history$val_categorical_accuracy),
            balanced_accuracy = rbind(history$balanced_accuracy, history$val_balanced_accuracy)
          )
        }

        if (use_callback == TRUE) {
          self$model$load_weights(paste0(self$last_training$config$dir_checkpoint, "/checkpoints/best_weights.h5"))
        }

        # PyTorch - Callbacks and training
      } else if (private$ml_framework == "pytorch") {
        data_set_weights <- datasets$Dataset$from_dict(
          reticulate::dict(list(
            sample_weights = sample_weights
          ))
        )

        dataset_train <- train_data$add_column("sample_weights", data_set_weights["sample_weights"])
        dataset_train <- dataset_train$select_columns(c("input", target_column, "sample_weights"))
        if (self$model_config$require_one_hot == TRUE) {
          dataset_train <- dataset_train$rename_column(target_column, "labels")
        }

        pytorch_train_data <- dataset_train$with_format("torch")

        pytorch_val_data <- val_data$select_columns(c("input", target_column))
        if (self$model_config$require_one_hot == TRUE) {
          pytorch_val_data <- pytorch_val_data$rename_column(target_column, "labels")
        }
        pytorch_val_data <- pytorch_val_data$with_format("torch")

        if (!is.null(test_data)) {
          pytorch_test_data <- test_data$select_columns(c("input", target_column))
          if (self$model_config$require_one_hot == TRUE) {
            pytorch_test_data <- pytorch_test_data$rename_column(target_column, "labels")
          }
          pytorch_test_data <- pytorch_test_data$with_format("torch")
        } else {
          pytorch_test_data <- NULL
        }

        history <- py$TeClassifierTrain_PT_with_Datasets(
          model = self$model,
          loss_fct_name = loss_fct_name,
          optimizer_method = self$model_config$optimizer,
          epochs = as.integer(self$last_training$config$epochs),
          trace = as.integer(self$last_training$config$ml_trace),
          use_callback = use_callback,
          batch_size = as.integer(self$last_training$config$batch_size),
          train_data = pytorch_train_data,
          val_data = pytorch_val_data,
          test_data = pytorch_test_data,
          filepath = paste0(self$last_training$config$dir_checkpoint, "/checkpoints/best_weights.pt"),
          n_classes = as.integer(length(self$model_config$target_levels)),
          class_weights = torch$tensor(np$array(class_weights)),
          log_dir = log_dir,
          log_write_interval = log_write_interval,
          log_top_value = log_top_value,
          log_top_total = log_top_total,
          log_top_message = log_top_message
        )
      }

      # Provide rownames and replace -100
      history <- private$prepare_history_data(history)
      return(history)
    },
    #--------------------------------------------------------------------------
    set_feature_extractor = function(feature_extractor) {
      # Attaches a feature extractor to the classifier, or disables its use.
      #
      # feature_extractor: a trained TEFeatureExtractor that uses the same
      #   machine learning framework as the classifier, or NULL.
      #
      # With an extractor, a deep clone is stored, use_fe is set to TRUE and
      # the number of features is taken from the extractor. With NULL, use_fe
      # is set to FALSE and the number of features is taken from the text
      # embedding model.

      # The TRUE argument presumably permits NULL; the else-branch below relies
      # on NULL passing this check.
      check_class(feature_extractor, "TEFeatureExtractor", TRUE)
      if (!is.null(feature_extractor)) {
        if (feature_extractor$get_ml_framework() != private$ml_framework) {
          stop(paste(
            "The machine learning framework of the feature extractor and",
            "classifier do not match. Please provide a feature extractor",
            "with the same machine learning framework as the classifier."
          ))
        }

        if (!feature_extractor$is_trained()) {
          stop(paste(
            "The supplied feature extractor is not trained.",
            "Please train it and try again."
          ))
        }

        self$model_config$use_fe <- TRUE
        self$model_config$features <- feature_extractor$model_config$features
        self$feature_extractor <- feature_extractor$clone(deep = TRUE)
      } else {
        self$model_config$use_fe <- FALSE
        self$model_config$features <- private$text_embedding_model[["features"]]
      }
    },
    #--------------------------------------------------------------------------
    adjust_configuration = function() {
      # Validates and completes the architecture settings in self$model_config.
      # - When recurrent layers and self-attention heads are combined, the
      #   number of features must be even.
      # - When intermediate_size is not set, a default is derived from the
      #   attention type (fourier/multihead) and the recurrent layers.
      # - Dropout and size settings are zeroed where the number of layers makes
      #   them irrelevant.

      if (self$model_config$rec_layers != 0 && self$model_config$self_attention_heads > 0) {
        if (self$model_config$features %% 2 != 0) {
          stop("The number of features of the TextEmbeddingModel is not a multiple of 2.")
        }
      }

      if (is.null(self$model_config$intermediate_size)) {
        attention_type <- self$model_config$attention_type
        rec_layers <- self$model_config$rec_layers
        attention_heads <- self$model_config$self_attention_heads

        if (attention_type == "fourier" && rec_layers > 0) {
          self$model_config$intermediate_size <- 2 * self$model_config$rec_size
        } else if (attention_type == "fourier" && rec_layers == 0) {
          self$model_config$intermediate_size <- 2 * self$model_config$features
        } else if (attention_type == "multihead" && rec_layers >= 0 && attention_heads > 0) {
          # Covers both rec_layers > 0 and rec_layers == 0, which previously
          # were two branches with the same result.
          self$model_config$intermediate_size <- 2 * self$model_config$features
        } else {
          self$model_config$intermediate_size <- NULL
        }
      }

      # Dropout between recurrent layers needs at least two of them
      if (self$model_config$rec_layers <= 1) {
        self$model_config$rec_dropout <- 0.0
      }
      if (self$model_config$rec_layers <= 0) {
        self$model_config$rec_size <- 0
      }

      # Dropout between dense layers needs at least two of them
      if (self$model_config$dense_layers <= 1) {
        self$model_config$dense_dropout <- 0.0
      }
      if (self$model_config$dense_layers <= 0) {
        self$model_config$dense_size <- 0
      }
    },
    #--------------------------------------------------------------------------
    check_and_adjust_target_levels = function(data_targets) {
      # Re-encodes `data_targets` as an ordered factor whose levels are the
      # classifier's target_levels, keeping the names of `data_targets`.
      # Values with a level unknown to the classifier become NA. A warning
      # lists such levels so that this does not happen silently.
      undefined_levels <- setdiff(
        levels(as.factor(data_targets)),
        self$model_config$target_levels
      )
      if (length(undefined_levels) > 0) {
        warning(
          paste0(
            "data_targets contains levels that are not defined for the classifier (",
            paste(undefined_levels, collapse = ", "), "). ",
            "Defined levels are ",
            paste(self$model_config$target_levels, collapse = ", "), ". ",
            "Please check your data or create a new classifier and pass ",
            "all levels to the classifier's configuration."
          ),
          call. = FALSE
        )
      }

      tmp_data <- factor(
        x = as.character(data_targets),
        levels = self$model_config$target_levels,
        ordered = TRUE
      )
      names(tmp_data) <- names(data_targets)
      return(tmp_data)
    }
  )
)

Try the aifeducation package in your browser

Any scripts or data that you put into this service are public.

aifeducation documentation built on April 4, 2025, 2:01 a.m.