R/obj_TEClassifierProtoNet.R

# This file is part of the R package "aifeducation".
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as published by
# the Free Software Foundation.
#
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>

#' @title Text embedding classifier with a ProtoNet
#' @description Abstract class for neural nets with 'pytorch'.
#'
#'    This class is **deprecated**. Please use an Object of class [TEClassifierSequentialPrototype] instead.
#'
#'   This object represents in implementation of a prototypical network for few-shot learning as described by Snell,
#'   Swersky, and Zemel (2017). The network uses a multi way contrastive loss described by Zhang et al. (2019). The
#'   network learns to scale the metric as described by Oreshkin, Rodriguez, and Lacoste (2018)
#'
#' @return Objects of this class are used for assigning texts to classes/categories. For the creation and training of a
#'   classifier an object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] and a `factor` are necessary. The
#'   object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] contains the numerical text representations (text
#'   embeddings) of the raw texts generated by an object of class [TextEmbeddingModel]. The `factor` contains the
#'   classes/categories for every text. Missing values (unlabeled cases) are supported. For predictions an object of
#'   class [EmbeddedText] or [LargeDataSetForTextEmbeddings] has to be used which was created with the same
#'   [TextEmbeddingModel] as for training.
#'
#'
#' @references Oreshkin, B. N., Rodriguez, P. & Lacoste, A. (2018). TADAM: Task dependent adaptive metric for improved
#'   few-shot learning. https://doi.org/10.48550/arXiv.1805.10123
#' @references Snell, J., Swersky, K. & Zemel, R. S. (2017). Prototypical Networks for Few-shot Learning.
#'   https://doi.org/10.48550/arXiv.1703.05175
#' @references Zhang, X., Nie, J., Zong, L., Yu, H. & Liang, W. (2019). One Shot Learning with Margin. In Q. Yang, Z.-H.
#'   Zhou, Z. Gong, M.-L. Zhang & S.-J. Huang (Eds.), Lecture Notes in Computer Science. Advances in Knowledge Discovery
#'   and Data Mining (Vol. 11440, pp. 305–317). Springer International Publishing.
#'   https://doi.org/10.1007/978-3-030-16145-3_24
#' @family Classification
#' @export
TEClassifierProtoNet <- R6::R6Class(
  classname = "TEClassifierProtoNet",
  inherit = TEClassifiersBasedOnProtoNet,
  public = list(
    #' @description Creating a new instance of this class.
    #' @return Returns an object of class [TEClassifierProtoNet] which is ready for configuration.
    initialize = function() {
      message("TEClassifierProtoNet is deprecated. Please use TEClassifierSequentialPrototype.")
    },
    # New-----------------------------------------------------------------------
    #' @description Creating a new instance of this class.
    #' @param name `r get_param_doc_desc("name")`
    #' @param label `r get_param_doc_desc("label")`
    #' @param text_embeddings `r get_param_doc_desc("text_embeddings")`
    #' @param feature_extractor `r get_param_doc_desc("feature_extractor")`
    #' @param bias `r get_param_doc_desc("bias")`
    #' @param target_levels `r get_param_doc_desc("target_levels")`
    #' @param dense_layers `r get_param_doc_desc("dense_layers")`
    #' @param dense_size `r get_param_doc_desc("dense_size")`
    #' @param rec_layers `r get_param_doc_desc("rec_layers")`
    #' @param rec_size `r get_param_doc_desc("rec_size")`
    #' @param rec_type `r get_param_doc_desc("rec_type")`
    #' @param rec_bidirectional `r get_param_doc_desc("rec_bidirectional")`
    #' @param attention_type `r get_param_doc_desc("attention_type")`
    #' @param self_attention_heads `r get_param_doc_desc("self_attention_heads")`
    #' @param repeat_encoder `r get_param_doc_desc("repeat_encoder")`
    #' @param intermediate_size `r get_param_doc_desc("intermediate_size")`
    #' @param add_pos_embedding `r get_param_doc_desc("add_pos_embedding")`
    #' @param act_fct `r get_param_doc_desc("act_fct")`
    #' @param parametrizations `r get_param_doc_desc("parametrizations")`
    #' @param encoder_dropout `r get_param_doc_desc("encoder_dropout")`
    #' @param dense_dropout `r get_param_doc_desc("dense_dropout")`
    #' @param rec_dropout `r get_param_doc_desc("rec_dropout")`
    #' @param embedding_dim `r get_param_doc_desc("embedding_dim")`
    #' @note This model requires `pad_value=0`. If this condition is not met the
    #' padding value is switched automatically.
    configure = function(name = NULL,
                         label = NULL,
                         text_embeddings = NULL,
                         feature_extractor = NULL,
                         target_levels = NULL,
                         dense_size = 4L,
                         dense_layers = 0L,
                         rec_size = 4L,
                         rec_layers = 2L,
                         rec_type = "GRU",
                         rec_bidirectional = FALSE,
                         embedding_dim = 2L,
                         self_attention_heads = 0L,
                         intermediate_size = NULL,
                         attention_type = "Fourier",
                         add_pos_embedding = TRUE,
                         act_fct = "ELU",
                         parametrizations = "None",
                         rec_dropout = 0.1,
                         repeat_encoder = 1L,
                         dense_dropout = 0.4,
                         encoder_dropout = 0.1) {
      private$do_configuration(args = get_called_args(n = 1L), one_hot_encoding = FALSE)
    },
    #---------------------------------------------------------------------------
    #' @description Method for embedding documents. Please do not confuse this type of embeddings with the embeddings of
    #'   texts created by an object of class [TextEmbeddingModel]. These embeddings embed documents according to their
    #'   similarity to specific classes.
    #' @param embeddings_q Object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] containing the text
    #'   embeddings for all cases which should be embedded into the classification space.
    #' @param batch_size `int` batch size.
    #' @return Returns a `list` containing the following elements
    #'
    #' * `embeddings_q`: embeddings for the cases (query sample).
    #' * `embeddings_prototypes`: embeddings of the prototypes which were learned during training. They represents the
    #' center for the different classes.
    #'
    embed = function(embeddings_q = NULL, batch_size = 32L) {
      check_class(embeddings_q, object_name = "embeddings_q", c("EmbeddedText", "LargeDataSetForTextEmbeddings"), FALSE)
      check_type(batch_size, object_name = "batch_size", "int", FALSE)

      # Check input for compatible text embedding models and feature extractors
      if (inherits(embeddings_q, "EmbeddedText")) {
        self$check_embedding_model(text_embeddings = embeddings_q)
        requires_compression <- self$requires_compression(embeddings_q)
      } else if (inherits(embeddings_q, "array")) {
        requires_compression <- self$requires_compression(embeddings_q)
      } else {
        requires_compression <- FALSE
      }

      # Load Custom Model Scripts
      private$load_reload_python_scripts()

      # Check number of cases in the data
      single_prediction <- private$check_single_prediction(embeddings_q)

      # Get current row names/name of the cases
      current_row_names <- private$get_rownames_from_embeddings(embeddings_q)

      # Apply feature extractor if it is part of the model
      if (requires_compression) {
        # Returns a data set
        embeddings_q <- self$feature_extractor$extract_features(
          data_embeddings = embeddings_q,
          batch_size = as.integer(batch_size)
        )
      }


      # If at least two cases are part of the data set---------------------------
      if (!single_prediction) {
        # Returns a data set object
        prediction_data_q_embeddings <- private$prepare_embeddings_as_dataset(embeddings_q)

        if (private$ml_framework == "pytorch") {
          prediction_data_q_embeddings$set_format("torch")
          embeddings_and_distances <- py$TeProtoNetBatchEmbedDistance(
            model = private$model,
            dataset_q = prediction_data_q_embeddings,
            batch_size = as.integer(batch_size)
          )
          embeddings_tensors_q <- tensor_to_numpy(embeddings_and_distances[[1L]])
          distances_tensors_q <- tensor_to_numpy(embeddings_and_distances[[2L]])
        }
      } else {
        prediction_data_q_embeddings <- private$prepare_embeddings_as_np_array(embeddings_q)

        # Apply feature extractor if it is part of the model
        if (requires_compression) {
          # Returns a data set
          prediction_data_q <- np$array(self$feature_extractor$extract_features(
            data_embeddings = prediction_data_q_embeddings,
            batch_size = as.integer(batch_size)
          )$embeddings)
        }

        if (private$ml_framework == "pytorch") {
          if (torch$cuda$is_available()) {
            device <- "cuda"
            dtype <- torch$double
            private$model$to(device, dtype = dtype)
            private$model$eval()
            input <- torch$from_numpy(prediction_data_q_embeddings)
            embeddings_tensors_q <- private$model$embed(input$to(device, dtype = dtype))
            embeddings_tensors_q <- tensor_to_numpy(embeddings_tensors_q)
            distances_tensors_q <- private$model$get_distances(input$to(device, dtype = dtype))
            distances_tensors_q <- tensor_to_numpy(distances_tensors_q)
          } else {
            device <- "cpu"
            dtype <- torch$float
            private$model$to(device, dtype = dtype)
            private$model$eval()
            input <- torch$from_numpy(prediction_data_q_embeddings)
            embeddings_tensors_q <- private$model$embed(input$to(device, dtype = dtype))
            embeddings_tensors_q <- tensor_to_numpy(embeddings_tensors_q)
            distances_tensors_q <- private$model$get_distances(input$to(device, dtype = dtype))
            distances_tensors_q <- tensor_to_numpy(distances_tensors_q)
          }
        }
      }

      if (private$ml_framework == "pytorch") {
        embeddings_prototypes <- tensor_to_numpy(
          private$model$get_trained_prototypes()
        )
      }

      # Post processing
      rownames(embeddings_tensors_q) <- current_row_names
      rownames(distances_tensors_q) <- current_row_names
      rownames(embeddings_prototypes) <- private$model_config$target_levels

      return(list(
        embeddings_q = embeddings_tensors_q,
        distances_q = distances_tensors_q,
        embeddings_prototypes = embeddings_prototypes
      ))
    },
    #---------------------------------------------------------------------------
    #' @description Method for creating a plot to visualize embeddings and their corresponding centers (prototypes).
    #' @param embeddings_q Object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] containing the text
    #'   embeddings for all cases which should be embedded into the classification space.
    #' @param classes_q Named `factor` containg the true classes for every case. Please note that the names must match
    #'   the names/ids in `embeddings_q`.
    #' @param inc_unlabeled `bool` If `TRUE` plot includes unlabeled cases as data points.
    #' @param size_points `int` Size of the points excluding the points for prototypes.
    #' @param size_points_prototypes `int` Size of points representing prototypes.
    #' @param alpha `float` Value indicating how transparent the points should be (important
    #'   if many points overlap). Does not apply to points representing prototypes.
    #' @param batch_size `int` batch size.
    #' @return Returns a plot of class `ggplot`visualizing embeddings.
    plot_embeddings = function(embeddings_q,
                               classes_q = NULL,
                               batch_size = 12L,
                               alpha = 0.5,
                               size_points = 3L,
                               size_points_prototypes = 8L,
                               inc_unlabeled = TRUE) {
      # Argument checking-------------------------------------------------------


      embeddings <- self$embed(
        embeddings_q = embeddings_q,
        batch_size = batch_size
      )

      prototypes <- as.data.frame(embeddings$embeddings_prototypes)
      prototypes$class <- rownames(embeddings$embeddings_prototypes)
      prototypes$type <- rep("prototype", nrow(embeddings$embeddings_prototypes))
      colnames(prototypes) <- c("x", "y", "class", "type")


      if (!is.null(classes_q)) {
        true_values_names <- intersect(
          x = names(na.omit(classes_q)),
          y = private$get_rownames_from_embeddings(embeddings_q)
        )
        true_values <- as.data.frame(embeddings$embeddings_q[true_values_names, , drop = FALSE])
        true_values$class <- classes_q[true_values_names]
        true_values$type <- rep("labeled", length(true_values_names))
        colnames(true_values) <- c("x", "y", "class", "type")
      } else {
        true_values_names <- NULL
        true_values <- NULL
      }


      if (inc_unlabeled) {
        estimated_values_names <- setdiff(
          x = private$get_rownames_from_embeddings(embeddings_q),
          y = true_values_names
        )

        if (length(estimated_values_names) > 0L) {
          estimated_values <- as.data.frame(embeddings$embeddings_q[estimated_values_names, , drop = FALSE])
          estimated_values$class <- private$calc_classes_on_distance(
            distance_matrix = embeddings$distances_q[estimated_values_names, , drop = FALSE],
            prototypes = embeddings$embeddings_prototypes
          )
          estimated_values$type <- rep("unlabeled", length(estimated_values_names))
          colnames(estimated_values) <- c("x", "y", "class", "type")
        }
      } else {
        estimated_values_names <- NULL
      }


      plot_data <- prototypes
      if (length(true_values) > 0L) {
        plot_data <- rbind(plot_data, true_values)
      }
      if (length(estimated_values_names) > 0L) {
        plot_data <- rbind(plot_data, estimated_values)
      }

      tmp_plot <- ggplot2::ggplot(data = plot_data) +
        ggplot2::geom_point(
          mapping = ggplot2::aes(
            x = x,
            y = y,
            color = class,
            shape = type,
            size = type,
            alpha = type
          ) # ,
          # position = ggplot2::position_jitter(h = 0.1, w = 0.1)
        ) +
        ggplot2::scale_size_manual(values = c(
          prototype = size_points_prototypes,
          labeled = size_points,
          unlabeled = size_points
        )) +
        ggplot2::scale_alpha_manual(
          values = c(
            prototype = 1L,
            labeled = alpha,
            unlabeled = alpha
          )
        ) +
        ggplot2::theme_classic()
      return(tmp_plot)
    }
  ),
  private = list(
    # Private--------------------------------------------------------------------------
    create_reset_model = function() {
      private$check_config_for_TRUE()

      private$load_reload_python_scripts()

      private$model <- py$TextEmbeddingClassifierProtoNet_PT(
        features = as.integer(private$model_config$features),
        times = as.integer(private$model_config$times),
        dense_size = as.integer(private$model_config$dense_size),
        dense_layers = as.integer(private$model_config$dense_layers),
        rec_size = as.integer(private$model_config$rec_size),
        rec_layers = as.integer(private$model_config$rec_layers),
        rec_type = private$model_config$rec_type,
        rec_bidirectional = private$model_config$rec_bidirectional,
        intermediate_size = as.integer(private$model_config$intermediate_size),
        attention_type = private$model_config$attention_type,
        repeat_encoder = as.integer(private$model_config$repeat_encoder),
        dense_dropout = private$model_config$dense_dropout,
        rec_dropout = private$model_config$rec_dropout,
        encoder_dropout = private$model_config$encoder_dropout,
        add_pos_embedding = private$model_config$add_pos_embedding,
        pad_value = as.integer(private$text_embedding_model$pad_value),
        self_attention_heads = as.integer(private$model_config$self_attention_heads),
        embedding_dim = as.integer(private$model_config$embedding_dim),
        target_levels = reticulate::np_array(seq(from = 0L, to = (length(private$model_config$target_levels) - 1L))),
        act_fct = private$model_config$act_fct,
        parametrizations = private$model_config$parametrizations
      )
    },
    #--------------------------------------------------------------------------
    basic_train = function(train_data = NULL,
                           val_data = NULL,
                           test_data = NULL,
                           reset_model = FALSE,
                           use_callback = TRUE,
                           log_dir = NULL,
                           log_write_interval = 10L,
                           log_top_value = NULL,
                           log_top_total = NULL,
                           log_top_message = NULL) {
      # Clear session to provide enough resources for computations
      if (torch$cuda$is_available()) {
        torch$cuda$empty_cache()
      }

      # Reset model if requested
      if (reset_model) {
        private$create_reset_model()
      }

      # Set loss function
      loss_cls_fct_name <- "ProtoNetworkMargin"

      # Set target column
      if (!private$model_config$require_one_hot) {
        target_column <- "labels"
      } else {
        target_column <- "one_hot_encoding"
      }

      dataset_train <- train_data$select_columns(c("input", target_column))
      if (private$model_config$require_one_hot) {
        dataset_train <- dataset_train$rename_column(target_column, "labels")
      }

      pytorch_train_data <- dataset_train$with_format("torch")

      pytorch_val_data <- val_data$select_columns(c("input", target_column))
      if (private$model_config$require_one_hot) {
        pytorch_val_data <- pytorch_val_data$rename_column(target_column, "labels")
      }
      pytorch_val_data <- pytorch_val_data$with_format("torch")

      if (!is.null(test_data)) {
        pytorch_test_data <- test_data$select_columns(c("input", target_column))
        if (private$model_config$require_one_hot) {
          pytorch_test_data <- pytorch_test_data$rename_column(target_column, "labels")
        }
        pytorch_test_data <- pytorch_test_data$with_format("torch")
      } else {
        pytorch_test_data <- NULL
      }

      tmp_history <- py$TeClassifierProtoNetTrain_PT_with_Datasets(
        model = private$model,
        loss_fct_name = self$last_training$config$loss_pt_fct_name,
        optimizer_method = self$last_training$config$optimizer,
        lr_rate = self$last_training$config$lr_rate,
        lr_warm_up_ratio = self$last_training$config$lr_warm_up_ratio,
        Ns = as.integer(self$last_training$config$Ns),
        Nq = as.integer(self$last_training$config$Nq),
        loss_alpha = self$last_training$config$loss_alpha,
        loss_margin = self$last_training$config$loss_margin,
        trace = as.integer(self$last_training$config$ml_trace),
        use_callback = use_callback,
        train_data = pytorch_train_data,
        val_data = pytorch_val_data,
        test_data = pytorch_test_data,
        epochs = as.integer(self$last_training$config$epochs),
        sampling_separate = self$last_training$config$sampling_separate,
        sampling_shuffle = self$last_training$config$sampling_shuffle,
        filepath = file.path(private$dir_checkpoint, "best_weights.pt"),
        n_classes = as.integer(length(private$model_config$target_levels)),
        log_dir = log_dir,
        log_write_interval = log_write_interval,
        log_top_value = log_top_value,
        log_top_total = log_top_total,
        log_top_message = log_top_message
      )

      # provide rownames and replace -100
      tmp_history <- private$prepare_history_data(tmp_history)

      return(tmp_history)
    },
    #--------------------------------------------------------------------------
    load_reload_python_scripts = function() {
      super$load_reload_python_scripts()
      load_py_scripts("pytorch_old_scripts.py")
    },
    #--------------------------------------------------------------------------
    check_param_combinations_configuration = function() {
      if (private$model_config$dense_layers > 0L) {
        if (private$model_config$dense_size < 1L) {
          stop("Dense layers added. Size for dense layers must be at least 1.")
        }
      }

      if (private$model_config$rec_layers > 0L) {
        if (private$model_config$rec_size < 1L) {
          stop("Recurrent  layers added. Size for recurrent layers must be at least 1.")
        }
      }

      if (private$model_config$repeat_encoder > 0L &
        private$model_config$attention_type == "MultiHead" &
        private$model_config$self_attention_heads <= 0L) {
        stop("Encoder layer is set to 'multihead'. This requires self_attention_heads>=1.")
      }

      if (private$model_config$rec_layers != 0L & private$model_config$self_attention_heads > 0L) {
        if (private$model_config$features %% 2L != 0L) {
          stop("The number of features of the TextEmbeddingmodel is
               not a multiple of 2.")
        }
      }

      if (private$model_config$rec_layers == 1L && private$model_config$rec_dropout > 0.0) {
        print_message(
          msg = "Dropout for recurrent layers requires at least two layers. Setting rec_dropout to 0.0.",
          trace = TRUE
        )
        private$model_config$rec_dropout <- 0.0
      }
    },
    #--------------------------------------------------------------------------
    adjust_configuration = function() {
      if (is.null(private$model_config$intermediate_size)) {
        if (private$model_config$attention_type == "Fourier" & private$model_config$rec_layers > 0L) {
          private$model_config$intermediate_size <- 2L * private$model_config$rec_size
        } else if (private$model_config$attention_type == "Fourier" & private$model_config$rec_layers == 0L) {
          private$model_config$intermediate_size <- 2L * private$model_config$features
        } else if (
          private$model_config$attention_type == "MultiHead" &
            private$model_config$rec_layers > 0L &
            private$model_config$self_attention_heads > 0L
        ) {
          private$model_config$intermediate_size <- 2L * private$model_config$features
        } else if (
          private$model_config$attention_type == "MultiHead" &
            private$model_config$rec_layers == 0L &
            private$model_config$self_attention_heads > 0L
        ) {
          private$model_config$intermediate_size <- 2L * private$model_config$features
        } else {
          private$model_config$intermediate_size <- NULL
        }
      }

      if (private$model_config$rec_layers <= 1L) {
        private$model_config$rec_dropout <- 0.0
      }
      if (private$model_config$rec_layers <= 0L) {
        private$model_config$rec_size <- 0L
      }

      if (private$model_config$dense_layers <= 1L) {
        private$model_config$dense_dropout <- 0.0
      }
      if (private$model_config$dense_layers <= 0L) {
        private$model_config$dense_size <- 0L
      }
    },
    #-------------------------------------------------------------------------
    calc_classes_on_distance = function(distance_matrix, prototypes) {
      index_vector <- vector(length = nrow(distance_matrix))

      for (i in seq_along(index_vector)) {
        index_vector[i] <- which.min(distance_matrix[i, ])
      }

      classes <- factor(index_vector,
        levels = seq_len(nrow(prototypes)),
        labels = rownames(prototypes)
      )
      return(classes)
    }
  )
)

# Add Classifier to central index
TEClassifiers_class_names <- append(x = TEClassifiers_class_names, values = "TEClassifierProtoNet")

Try the aifeducation package in your browser

Any scripts or data that you put into this service are public.

aifeducation documentation built on Nov. 19, 2025, 5:08 p.m.