# This file is part of the R package "aifeducation".
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as published by
# the Free Software Foundation.
#
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>
#' @title Text embedding classifier with a ProtoNet
#' @description Classifier for text embeddings based on a prototypical network, implemented as a neural net with 'pytorch'.
#'
#' This object represents an implementation of a prototypical network for few-shot learning as described by Snell,
#' Swersky, and Zemel (2017). The network uses a multi-way contrastive loss as described by Zhang et al. (2019) and
#' learns to scale the metric as described by Oreshkin, Rodriguez, and Lacoste (2018).
#'
#' @return Objects of this class are used for assigning texts to classes/categories. For the creation and training of a
#' classifier an object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] and a `factor` are necessary. The
#' object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] contains the numerical text representations (text
#' embeddings) of the raw texts generated by an object of class [TextEmbeddingModel]. The `factor` contains the
#' classes/categories for every text. Missing values (unlabeled cases) are supported. For predictions an object of
#' class [EmbeddedText] or [LargeDataSetForTextEmbeddings] has to be used which was created with the same
#' [TextEmbeddingModel] as for training.
#'
#' @references Oreshkin, B. N., Rodriguez, P. & Lacoste, A. (2018). TADAM: Task dependent adaptive metric for improved
#' few-shot learning. https://doi.org/10.48550/arXiv.1805.10123
#' @references Snell, J., Swersky, K. & Zemel, R. S. (2017). Prototypical Networks for Few-shot Learning.
#' https://doi.org/10.48550/arXiv.1703.05175
#' @references Zhang, X., Nie, J., Zong, L., Yu, H. & Liang, W. (2019). One Shot Learning with Margin. In Q. Yang, Z.-H.
#' Zhou, Z. Gong, M.-L. Zhang & S.-J. Huang (Eds.), Lecture Notes in Computer Science. Advances in Knowledge Discovery
#' and Data Mining (Vol. 11440, pp. 305–317). Springer International Publishing.
#' https://doi.org/10.1007/978-3-030-16145-3_24
#' @family Classification
#' @export
TEClassifierProtoNet <- R6::R6Class(
classname = "TEClassifierProtoNet",
inherit = TEClassifierRegular,
public = list(
# New-----------------------------------------------------------------------
#' @description Creating a new instance of this class.
#' @param ml_framework `string` Currently only pytorch is supported (`ml_framework="pytorch"`).
#' @param name `string` Name of the new classifier. Please refer to common name conventions. Free text can be used
#' with parameter `label`.
#' @param label `string` Label for the new classifier. Here you can use free text.
#' @param text_embeddings An object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings].
#' @param target_levels `vector` containing the levels (categories or classes) within the target data. Please note
#' that the order matters. For ordinal data please ensure that the levels are sorted correctly with later levels
#' indicating a higher category/class. For nominal data the order does not matter.
#' @param feature_extractor Object of class [TEFeatureExtractor] which should be used in order to reduce the number
#' of dimensions of the text embeddings. If no feature extractor should be applied set `NULL`.
#' @param dense_layers `int` Number of dense layers.
#' @param dense_size `int` Number of neurons for each dense layer.
#' @param rec_layers `int` Number of recurrent layers.
#' @param rec_size `int` Number of neurons for each recurrent layer.
#' @param rec_type `string` Type of the recurrent layers. `rec_type="gru"` for Gated Recurrent Unit and
#' `rec_type="lstm"` for Long Short-Term Memory.
#' @param rec_bidirectional `bool` If `TRUE` a bidirectional version of the recurrent layers is used.
#' @param embedding_dim `int` determining the number of dimensions for the text embedding.
#' @param attention_type `string` Choose the relevant attention type. Possible values are `"fourier"` and
#' `"multihead"`. Please note that with `attention_type="fourier"` on Linux, the values for a case may differ
#' depending on the order of the input.
#' @param self_attention_heads `int` determining the number of attention heads for a self-attention layer. Only
#' relevant if `attention_type="multihead"`.
#' @param repeat_encoder `int` determining how many times the encoder should be added to the network.
#' @param intermediate_size `int` determining the size of the projection layer within each transformer encoder.
#' @param add_pos_embedding `bool` `TRUE` if positional embedding should be used.
#' @param encoder_dropout `double` ranging between 0 and lower than 1, determining the dropout for the dense
#' projection within the encoder layers.
#' @param dense_dropout `double` ranging between 0 and lower than 1, determining the dropout between dense layers.
#' @param rec_dropout `double` ranging between 0 and lower than 1, determining the dropout between bidirectional
#' recurrent layers.
#' @param recurrent_dropout `double` ranging between 0 and lower than 1, determining the recurrent dropout for each
#' recurrent layer. Only relevant for keras models.
#' @param optimizer `string` `"adam"` or `"rmsprop"`.
#' @return Returns an object of class [TEClassifierProtoNet] which is ready for training.
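#' @examples
#' \dontrun{
#' # A minimal configuration sketch. `embedded_texts` is assumed to be an existing
#' # object of class EmbeddedText or LargeDataSetForTextEmbeddings created with a
#' # TextEmbeddingModel; the names and values are illustrative only.
#' classifier <- TEClassifierProtoNet$new()
#' classifier$configure(
#'   ml_framework = "pytorch",
#'   name = "protonet_classifier",
#'   label = "Example ProtoNet Classifier",
#'   text_embeddings = embedded_texts,
#'   target_levels = c("negative", "neutral", "positive"),
#'   embedding_dim = 2
#' )
#' }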
configure = function(ml_framework = "pytorch",
name = NULL,
label = NULL,
text_embeddings = NULL,
feature_extractor = NULL,
target_levels = NULL,
dense_size = 4,
dense_layers=0,
rec_size = 4,
rec_layers=2,
rec_type = "gru",
rec_bidirectional = FALSE,
embedding_dim = 2,
self_attention_heads = 0,
intermediate_size = NULL,
attention_type = "fourier",
add_pos_embedding = TRUE,
rec_dropout = 0.1,
repeat_encoder = 1,
dense_dropout = 0.4,
recurrent_dropout = 0.4,
encoder_dropout = 0.1,
optimizer = "adam") {
# check if already configured
private$check_config_for_FALSE()
# Checking of parameters--------------------------------------------------
if ((ml_framework %in% c("pytorch")) == FALSE) {
stop("ml_framework must be 'pytorch'.")
}
private$check_embeddings_object_type(text_embeddings, strict = TRUE)
check_type(name, type = "string", FALSE)
check_type(label, type = "string", FALSE)
check_type(target_levels, c("vector"), FALSE)
check_type(dense_size, type = "int", FALSE)
check_type(dense_layers, type = "int", FALSE)
if(dense_layers>0){
if(dense_size<1){
stop("Dense layers added. Size for dense layers must be at least 1.")
}
}
check_type(rec_size, type = "int", FALSE)
check_type(rec_layers, type = "int", FALSE)
if(rec_layers>0){
if(rec_size<1){
stop("Recurrent layers added. Size for recurrent layers must be at least 1.")
}
}
check_type(self_attention_heads, type = "int", FALSE)
if (optimizer %in% c("adam", "rmsprop") == FALSE) {
stop("Optimzier must be 'adam' oder 'rmsprop'.")
}
check_type(embedding_dim, "int", FALSE)
if (embedding_dim <= 0) {
stop("Embedding_dim must be an integer greater 0.")
}
if (attention_type %in% c("fourier", "multihead") == FALSE) {
stop("Optimzier must be 'fourier' oder 'multihead'.")
}
if (repeat_encoder > 0 & attention_type == "multihead" & self_attention_heads <= 0) {
stop("Encoder layer is set to 'multihead'. This requires self_attention_heads>=1.")
}
#------------------------------------------------------------------------
# Setting ML Framework
private$ml_framework <- ml_framework
# Setting Label and Name
private$set_model_info(
model_name_root = name,
model_id = generate_id(16),
label = label,
model_date = date()
)
# Set TextEmbeddingModel
private$set_text_embedding_model(
model_info = text_embeddings$get_model_info(),
feature_extractor_info = text_embeddings$get_feature_extractor_info(),
times = text_embeddings$get_times(),
features = text_embeddings$get_features()
)
# Saving Configuration
config <- list(
use_fe = FALSE,
features = private$text_embedding_model[["features"]],
times = private$text_embedding_model[["times"]],
dense_size = dense_size,
dense_layers=dense_layers,
rec_size = rec_size,
rec_layers=rec_layers,
rec_type = rec_type,
rec_bidirectional = rec_bidirectional,
intermediate_size = intermediate_size,
attention_type = attention_type,
repeat_encoder = repeat_encoder,
dense_dropout = dense_dropout,
rec_dropout = rec_dropout,
recurrent_dropout = recurrent_dropout,
encoder_dropout = encoder_dropout,
add_pos_embedding = add_pos_embedding,
optimizer = optimizer,
act_fct = "gelu",
rec_act_fct = "tanh",
embedding_dim = embedding_dim,
self_attention_heads = self_attention_heads
)
# Basic Information of Input and Target Data
variable_name_order <- dimnames(text_embeddings$embeddings)[[3]]
config["target_levels"] <- list(target_levels)
config["n_categories"] <- list(length(target_levels))
config["input_variables"] <- list(variable_name_order)
if (length(target_levels) > 2) {
# Multi Class
config["act_fct_last"] <- "softmax"
config["err_fct"] <- "categorical_crossentropy"
config["metric"] <- "categorical_accuracy"
config["balanced_metric"] <- "balanced_accuracy"
} else {
# Binary Classification
config["act_fct_last"] <- "sigmoid"
config["err_fct"] <- "binary_crossentropy"
config["metric"] <- "binary_accuracy"
config["balanced_metric"] <- "balanced_accuracy"
}
config["require_one_hot"] <- list(FALSE)
# if(length(rec)>0|repeat_encoder>0){
# config["require_matrix_map"]=list(FALSE)
# } else {
config["require_matrix_map"] <- list(TRUE)
# }
self$model_config <- config
# Adjust configuration
private$adjust_configuration()
# Set FeatureExtractor and adapt config
self$check_feature_extractor_object_type(feature_extractor)
private$set_feature_extractor(feature_extractor)
# Set package versions
private$set_package_versions()
# Finalize configuration
private$set_configuration_to_TRUE()
# Create_Model
private$create_reset_model()
},
#-------------------------------------------------------------------------
#' @description Method for training a neural net.
#'
#' Training includes a routine for early stopping: if loss < 0.0001, accuracy = 1.00,
#' and Average Iota = 1.00, training stops. The history uses the values
#' of the last trained epoch for the remaining epochs.
#'
#' After training, the model with the best values for Average Iota, Accuracy, and Loss
#' on the validation data set is used as the final model.
#'
#' @param data_embeddings Object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings].
#' @param data_targets `factor` containing the labels for the cases stored in `data_embeddings`. The factor must be
#' named and has to use the same names as used in `data_embeddings`.
#' @param data_folds `int` determining the number of cross-fold samples.
#' @param data_val_size `double` between 0 and 1, indicating the proportion of cases of each class which should be
#' used for the validation sample during the estimation of the model. The remaining cases are part of the training
#' data.
#' @param balance_class_weights `bool` If `TRUE` class weights are generated based on the frequencies of the
#' training data with the method 'Inverse Class Frequency'. If `FALSE` each class has the weight 1.
#' @param balance_sequence_length `bool` If `TRUE` sample weights are generated for the length of sequences based on
#' the frequencies of the training data with the method 'Inverse Class Frequency'. If `FALSE` each sequence length
#' has the weight 1.
#' @param use_sc `bool` `TRUE` if the estimation should integrate synthetic cases. `FALSE` if not.
#' @param sc_method `string` determining the method for generating synthetic cases. Possible values are
#' `sc_method="adas"`, `sc_method="smote"`, and `sc_method="dbsmote"`.
#' @param sc_min_k `int` determining the minimal number of k which is used for creating synthetic units.
#' @param sc_max_k `int` determining the maximal number of k which is used for creating synthetic units.
#' @param use_pl `bool` `TRUE` if the estimation should integrate pseudo-labeling. `FALSE` if not.
#' @param pl_max_steps `int` determining the maximum number of steps during pseudo-labeling.
#' @param pl_anchor `double` between 0 and 1 indicating the reference point for sorting the new cases of every
#' label. See notes for more details.
#' @param pl_max `double` between 0 and 1, setting the maximal level of confidence for considering a case for
#' pseudo-labeling.
#' @param pl_min `double` between 0 and 1, setting the minimal level of confidence for considering a case for
#' pseudo-labeling.
#' @param sustain_track `bool` If `TRUE` energy consumption is tracked during training via the python library
#' 'codecarbon'.
#' @param sustain_iso_code `string` ISO code (Alpha-3-Code) for the country. This variable must be set if
#' sustainability should be tracked. A list can be found on Wikipedia:
#' <https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes>.
#' @param sustain_region Region within a country. Only available for USA and Canada. See the documentation of
#' codecarbon for more information: <https://mlco2.github.io/codecarbon/parameters.html>
#' @param sustain_interval `int` Interval in seconds for measuring power usage.
#' @param batch_size `int` Size of the batches for training.
#' @param epochs `int` Number of training epochs.
#' @param Ns `int` Number of cases for every class in the sample.
#' @param Nq `int` Number of cases for every class in the query.
#' @param loss_alpha `double` Value between 0 and 1 indicating how strongly the loss should focus on pulling cases
#' towards their corresponding prototypes versus pushing them away from the prototypes of other classes. The higher
#' the value, the more the loss concentrates on pulling cases towards their corresponding prototypes (see Details).
#' @param loss_margin `double` Value greater than 0 indicating the minimal distance of every case from the
#' prototypes of the other classes.
#' @param sampling_separate `bool` If `TRUE` the cases of every class are divided into a data set for the sample and
#' a data set for the query; these are never mixed. If `FALSE` sample and query cases are drawn from the same data
#' pool, i.e. a case can be part of the sample in one epoch and part of the query in another epoch. It is ensured
#' that a case is never part of sample and query at the same time and that every case occurs only once during a
#' training step.
#' @param sampling_shuffle `bool` If `TRUE` cases are randomly drawn from the data during every step. If `FALSE`
#' the cases are not shuffled.
#' @param dir_checkpoint `string` Path to the directory where the checkpoint during training should be saved. If the
#' directory does not exist, it is created.
#' @param log_dir `string` Path to the directory where the log files should be saved. If no logging is desired set
#' this argument to `NULL`.
#' @param log_write_interval `int` Time in seconds determining the interval in which the logger should try to update
#' the log files. Only relevant if `log_dir` is not `NULL`.
#' @param trace `bool` `TRUE`, if information about the estimation phase should be printed to the console.
#' @param ml_trace `int` `ml_trace=0` does not print any information about the training process from pytorch on the
#' console.
#' @param n_cores `int` Number of cores which should be used during the calculation of synthetic cases. Only relevant if
#' `use_sc=TRUE`.
#' @return Function does not return a value. It changes the object into a trained classifier.
#' @details
#'
#' * `sc_max_k`: All values from `sc_min_k` up to `sc_max_k` are successively used. If
#' the number of `sc_max_k` is too high, the value is reduced to a number that allows the calculation of synthetic
#' units.
#' * `pl_anchor`: With the help of this value, the new cases are sorted. For
#' this aim, the distance from the anchor is calculated and all cases are arranged in ascending order.
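#' * `loss_alpha` and `loss_margin`: Conceptually, the multi-way contrastive loss combines a term that pulls every
#' case towards its own prototype with a term that pushes it at least `loss_margin` away from the prototypes of the
#' other classes, weighted roughly as `loss_alpha * pull + (1 - loss_alpha) * push`. This is a sketch of the role of
#' the two parameters, not the exact implementation; see Zhang et al. (2019) for the loss itself.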
#'
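#' @examples
#' \dontrun{
#' # A training sketch. `classifier` is assumed to be a configured TEClassifierProtoNet,
#' # `embedded_texts` the corresponding embeddings, and `labels` a named factor whose
#' # names match the case names in `embedded_texts` (unlabeled cases may be NA).
#' classifier$train(
#'   data_embeddings = embedded_texts,
#'   data_targets = labels,
#'   data_folds = 5,
#'   epochs = 40,
#'   Ns = 5,
#'   Nq = 3,
#'   sustain_iso_code = "DEU",
#'   dir_checkpoint = tempdir()
#' )
#' }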
train = function(data_embeddings,
data_targets,
data_folds = 5,
data_val_size = 0.25,
use_sc = TRUE,
sc_method = "dbsmote",
sc_min_k = 1,
sc_max_k = 10,
use_pl = TRUE,
pl_max_steps = 3,
pl_max = 1.00,
pl_anchor = 1.00,
pl_min = 0.00,
sustain_track = TRUE,
sustain_iso_code = NULL,
sustain_region = NULL,
sustain_interval = 15,
epochs = 40,
batch_size = 35,
Ns = 5,
Nq = 3,
loss_alpha = 0.5,
loss_margin = 0.5,
sampling_separate = FALSE,
sampling_shuffle = TRUE,
dir_checkpoint,
trace = TRUE,
ml_trace = 1,
log_dir = NULL,
log_write_interval = 10,
n_cores=auto_n_cores()) {
# Checking Arguments------------------------------------------------------
check_type(data_folds, type = "int", FALSE)
check_type(data_val_size, type = "double", FALSE)
check_type(use_sc, type = "bool", FALSE)
check_type(sc_method, type = "string", FALSE)
check_type(sc_min_k, type = "int", FALSE)
check_type(sc_max_k, type = "int", FALSE)
check_type(use_pl, type = "bool", FALSE)
check_type(pl_max_steps, type = "int", FALSE)
check_type(pl_max, type = "double", FALSE)
check_type(pl_anchor, type = "double", FALSE)
check_type(pl_min, type = "double", FALSE)
check_type(sustain_track, type = "bool", FALSE)
check_type(sustain_iso_code, type = "string", TRUE)
check_type(sustain_region, type = "string", TRUE)
check_type(sustain_interval, type = "int", FALSE)
check_type(epochs, type = "int", FALSE)
check_type(batch_size, type = "int", FALSE)
check_type(dir_checkpoint, type = "string", FALSE)
check_type(trace, type = "bool", FALSE)
check_type(n_cores, type = "int", FALSE)
check_type(Ns, type = "int", FALSE)
check_type(Nq, type = "int", FALSE)
check_type(loss_alpha, type = "double", FALSE)
check_type(loss_margin, type = "double", FALSE)
check_class(data_embeddings, c("EmbeddedText", "LargeDataSetForTextEmbeddings"), FALSE)
self$check_embedding_model(data_embeddings)
if (is.null(names(data_targets))) {
stop("data_targets must be a named factor.")
}
data_targets <- private$check_and_adjust_target_levels(data_targets)
if (pl_anchor < pl_min) {
stop("pl_anchor must be at least pl_min.")
}
if (pl_anchor > pl_max) {
stop("pl_anchor must be lower or equal to pl_max.")
}
if (data_folds < 2) {
stop("data_folds must be at least 2.")
}
# Saving training configuration-------------------------------------------
self$last_training$config$data_val_size <- data_val_size
self$last_training$config$use_sc <- use_sc
self$last_training$config$sc_method <- sc_method
self$last_training$config$sc_min_k <- sc_min_k
self$last_training$config$sc_max_k <- sc_max_k
self$last_training$config$use_pl <- use_pl
self$last_training$config$pl_max_steps <- pl_max_steps
self$last_training$config$pl_max <- pl_max
self$last_training$config$pl_anchor <- pl_anchor
self$last_training$config$pl_min <- pl_min
self$last_training$config$sustain_track <- sustain_track
self$last_training$config$sustain_iso_code <- sustain_iso_code
self$last_training$config$sustain_region <- sustain_region
self$last_training$config$sustain_interval <- sustain_interval
self$last_training$config$epochs <- epochs
self$last_training$config$batch_size <- batch_size
self$last_training$config$Ns <- Ns
self$last_training$config$Nq <- Nq
self$last_training$config$loss_alpha <- loss_alpha
self$last_training$config$loss_margin <- loss_margin
self$last_training$config$dir_checkpoint <- dir_checkpoint
self$last_training$config$trace <- trace
self$last_training$config$ml_trace <- ml_trace
self$last_training$config$n_cores<-n_cores
self$last_training$config$sampling_separate <- sampling_separate
self$last_training$config$sampling_shuffle <- sampling_shuffle
private$log_config$log_dir <- log_dir
private$log_config$log_state_file <- paste0(private$log_config$log_dir, "/aifeducation_state.log")
private$log_config$log_write_interval <- log_write_interval
# Start-------------------------------------------------------------------
if (self$last_training$config$trace == TRUE) {
message(paste(
date(),
"Start"
))
}
# Set up data
if ("EmbeddedText" %in% class(data_embeddings)) {
data <- data_embeddings$convert_to_LargeDataSetForTextEmbeddings()
} else {
data <- data_embeddings
}
# Create DataManager------------------------------------------------------
if (self$model_config$use_fe == TRUE) {
compressed_embeddings <- self$feature_extractor$extract_features_large(
data_embeddings = data,
as.integer(self$last_training$config$batch_size),
trace = self$last_training$config$trace
)
data_manager <- DataManagerClassifier$new(
data_embeddings = compressed_embeddings,
data_targets = data_targets,
folds = data_folds,
val_size = self$last_training$config$data_val_size,
class_levels = self$model_config$target_levels,
one_hot_encoding = self$model_config$require_one_hot,
add_matrix_map = (self$model_config$require_matrix_map == TRUE | self$last_training$config$use_sc == TRUE),
sc_method = sc_method,
sc_min_k = sc_min_k,
sc_max_k = sc_max_k,
trace = trace,
self$last_training$config$n_cores
)
} else {
data_manager <- DataManagerClassifier$new(
data_embeddings = data,
data_targets = data_targets,
folds = data_folds,
val_size = self$last_training$config$data_val_size,
class_levels = self$model_config$target_levels,
one_hot_encoding = self$model_config$require_one_hot,
add_matrix_map = (self$model_config$require_matrix_map == TRUE | self$last_training$config$use_sc == TRUE),
sc_method = sc_method,
sc_min_k = sc_min_k,
sc_max_k = sc_max_k,
trace = trace,
self$last_training$config$n_cores
)
}
# Save Data Statistics
self$last_training$data <- data_manager$get_statistics()
# Save the number of folds
self$last_training$config$n_folds <- data_manager$get_n_folds()
# Init Training------------------------------------------------------------
private$init_train()
# disable Progress bars
datasets$disable_progress_bars()
# Start Sustainability Tracking-------------------------------------------
if (sustain_track == TRUE) {
if (is.null(sustain_iso_code) == TRUE) {
stop("Sustainability tracking is activated but iso code for the
country is missing. Add iso code or deactivate tracking.")
}
tmp_code_carbon <- reticulate::import("codecarbon")
codecarbon_version <- as.character(tmp_code_carbon["__version__"])
if (utils::compareVersion(codecarbon_version, "2.8.0") >= 0) {
path_lock_file <- codecarbon$lock$LOCKFILE
if (file.exists(path_lock_file)) {
unlink(path_lock_file)
}
}
sustainability_tracker <- codecarbon$OfflineEmissionsTracker(
country_iso_code = sustain_iso_code,
region = sustain_region,
tracking_mode = "machine",
log_level = "warning",
measure_power_secs = sustain_interval,
save_to_file = FALSE,
save_to_api = FALSE
)
sustainability_tracker$start()
}
# Start Training----------------------------------------------------------
# Load Custom Model Scripts
private$load_reload_python_scripts()
# Start Loop inclusive final training
for (iter in 1:(self$last_training$config$n_folds + 1)) {
base::gc(verbose = FALSE, full = TRUE)
if (self$last_training$config$use_pl == FALSE) {
private$train_standard(
iteration = iter,
data_manager = data_manager,
inc_synthetic = self$last_training$config$use_sc
)
} else if (self$last_training$config$use_pl == TRUE) {
private$train_with_pseudo_labels(
init_train = TRUE,
iteration = iter,
data_manager = data_manager,
inc_synthetic = self$last_training$config$use_sc
)
}
# Calculate measures on categorical level
private$calculate_measures_on_categorical_level(
data_manager = data_manager,
iteration = iter
)
}
# Finalize Training
private$finalize_train()
# Stop sustainability tracking if requested
if (sustain_track == TRUE) {
sustainability_tracker$stop()
private$sustainability <- summarize_tracked_sustainability(sustainability_tracker)
} else {
private$sustainability <- list(
sustainability_tracked = FALSE,
date = NA,
sustainability_data = list(
duration_sec = NA,
co2eq_kg = NA,
cpu_energy_kwh = NA,
gpu_energy_kwh = NA,
ram_energy_kwh = NA,
total_energy_kwh = NA
)
)
}
if (self$last_training$config$trace == TRUE) {
message(paste(
date(),
"Training Complete"
))
}
},
#---------------------------------------------------------------------------
#' @description Method for embedding documents. Please do not confuse this type of embeddings with the embeddings of
#' texts created by an object of class [TextEmbeddingModel]. These embeddings embed documents according to their
#' similarity to specific classes.
#' @param embeddings_q Object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] containing the text
#' embeddings for all cases which should be embedded into the classification space.
#' @param batch_size `int` batch size.
#' @return Returns a `list` containing the following elements
#'
#' * `embeddings_q`: embeddings for the cases (query sample).
#' * `distances_q`: distances of every case to each of the trained prototypes.
#' * `embeddings_prototypes`: embeddings of the prototypes which were learned during training. They represent the
#' centers of the different classes.
#'
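#' @examples
#' \dontrun{
#' # A sketch for embedding cases into the classification space. `classifier` is
#' # assumed to be a trained TEClassifierProtoNet and `embedded_texts` embeddings
#' # created with the same TextEmbeddingModel as used for training.
#' result <- classifier$embed(embeddings_q = embedded_texts, batch_size = 32)
#' result$embeddings_q          # embeddings of the cases in the classification space
#' result$embeddings_prototypes # learned prototypes (class centers)
#' }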
embed = function(embeddings_q = NULL, batch_size = 32) {
check_class(embeddings_q, c("EmbeddedText", "LargeDataSetForTextEmbeddings"), FALSE)
check_type(batch_size, "int", FALSE)
# Check input for compatible text embedding models and feature extractors
if ("EmbeddedText" %in% class(embeddings_q)) {
self$check_embedding_model(text_embeddings = embeddings_q)
requires_compression <- self$requires_compression(embeddings_q)
} else if ("array" %in% class(embeddings_q)) {
requires_compression <- self$requires_compression(embeddings_q)
} else {
requires_compression <- FALSE
}
# Load Custom Model Scripts
private$load_reload_python_scripts()
# Check number of cases in the data
single_prediction <- private$check_single_prediction(embeddings_q)
# Get current row names/name of the cases
current_row_names <- private$get_rownames_from_embeddings(embeddings_q)
# Apply feature extractor if it is part of the model
if (requires_compression == TRUE) {
# Returns a data set
embeddings_q <- self$feature_extractor$extract_features(
data_embeddings = embeddings_q,
batch_size = as.integer(batch_size)
)
}
# If at least two cases are part of the data set---------------------------
if (single_prediction == FALSE) {
# Returns a data set object
prediction_data_q_embeddings <- private$prepare_embeddings_as_dataset(embeddings_q)
if (private$ml_framework == "pytorch") {
prediction_data_q_embeddings$set_format("torch")
embeddings_and_distances <- py$TeProtoNetBatchEmbedDistance(
model = self$model,
dataset_q = prediction_data_q_embeddings,
batch_size = as.integer(batch_size)
)
embeddings_tensors_q <- private$detach_tensors(embeddings_and_distances[[1]])
distances_tensors_q <- private$detach_tensors(embeddings_and_distances[[2]])
}
} else {
prediction_data_q_embeddings <- private$prepare_embeddings_as_np_array(embeddings_q)
# Apply feature extractor if it is part of the model
if (requires_compression == TRUE) {
# Returns a data set
prediction_data_q <- np$array(self$feature_extractor$extract_features(
data_embeddings = prediction_data_q_embeddings,
batch_size = as.integer(batch_size)
)$embeddings)
}
if (private$ml_framework == "pytorch") {
if (torch$cuda$is_available()) {
device <- "cuda"
dtype <- torch$double
self$model$to(device, dtype = dtype)
self$model$eval()
input <- torch$from_numpy(prediction_data_q_embeddings)
embeddings_tensors_q <- self$model$embed(input$to(device, dtype = dtype))
embeddings_tensors_q <- private$detach_tensors(embeddings_tensors_q)
distances_tensors_q <- self$model$get_distances(input$to(device, dtype = dtype))
distances_tensors_q <- private$detach_tensors(distances_tensors_q)
} else {
device <- "cpu"
dtype <- torch$float
self$model$to(device, dtype = dtype)
self$model$eval()
input <- torch$from_numpy(prediction_data_q_embeddings)
embeddings_tensors_q <- self$model$embed(input$to(device, dtype = dtype))
embeddings_tensors_q <- private$detach_tensors(embeddings_tensors_q)
distances_tensors_q <- self$model$get_distances(input$to(device, dtype = dtype))
distances_tensors_q <- private$detach_tensors(distances_tensors_q)
}
}
}
if (private$ml_framework == "pytorch") {
embeddings_prototypes <- private$detach_tensors(
self$model$get_trained_prototypes()
)
}
# Post processing
rownames(embeddings_tensors_q) <- current_row_names
rownames(distances_tensors_q) <- current_row_names
rownames(embeddings_prototypes) <- self$model_config$target_levels
return(list(
embeddings_q = embeddings_tensors_q,
distances_q = distances_tensors_q,
embeddings_prototypes = embeddings_prototypes
))
},
#---------------------------------------------------------------------------
#' @description Method for creating a plot to visualize embeddings and their corresponding centers (prototypes).
#' @param embeddings_q Object of class [EmbeddedText] or [LargeDataSetForTextEmbeddings] containing the text
#' embeddings for all cases which should be embedded into the classification space.
#' @param classes_q Named `factor` containing the true classes for every case. Please note that the names must match
#' the names/ids in `embeddings_q`.
#' @param inc_unlabeled `bool` If `TRUE` the plot includes unlabeled cases as data points.
#' @param size_points `int` Size of the points excluding the points for prototypes.
#' @param size_points_prototypes `int` Size of the points representing prototypes.
#' @param alpha `float` Value indicating how transparent the points should be (important
#' if many points overlap). Does not apply to points representing prototypes.
#' @param batch_size `int` batch size.
#' @return Returns a plot of class `ggplot` visualizing the embeddings.
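#' @examples
#' \dontrun{
#' # A plotting sketch. `classifier` is assumed to be a trained TEClassifierProtoNet,
#' # `embedded_texts` matching embeddings, and `labels` a named factor with the true
#' # classes of the labeled cases.
#' plot <- classifier$plot_embeddings(
#'   embeddings_q = embedded_texts,
#'   classes_q = labels,
#'   inc_unlabeled = TRUE
#' )
#' plot
#' }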
plot_embeddings = function(embeddings_q,
classes_q=NULL,
batch_size = 12,
alpha = 0.5,
size_points = 3,
size_points_prototypes = 8,
inc_unlabeled = TRUE) {
# Argument checking-------------------------------------------------------
embeddings <- self$embed(
embeddings_q = embeddings_q,
batch_size = batch_size
)
prototypes <- as.data.frame(embeddings$embeddings_prototypes)
prototypes$class <- rownames(embeddings$embeddings_prototypes)
prototypes$type <- rep("prototype", nrow(embeddings$embeddings_prototypes))
colnames(prototypes) <- c("x", "y", "class", "type")
if(!is.null(classes_q)){
true_values_names <- intersect(
x = names(na.omit(classes_q)),
y = private$get_rownames_from_embeddings(embeddings_q)
)
true_values <- as.data.frame(embeddings$embeddings_q[true_values_names, , drop = FALSE])
true_values$class <- classes_q[true_values_names]
true_values$type <- rep("labeled", length(true_values_names))
colnames(true_values) <- c("x", "y", "class", "type")
} else {
true_values_names=NULL
true_values=NULL
}
if (inc_unlabeled == TRUE) {
estimated_values_names <- setdiff(
x = private$get_rownames_from_embeddings(embeddings_q),
y = true_values_names
)
if (length(estimated_values_names) > 0) {
estimated_values <- as.data.frame(embeddings$embeddings_q[estimated_values_names, , drop = FALSE])
estimated_values$class <- private$calc_classes_on_distance(
distance_matrix = embeddings$distances_q[estimated_values_names, , drop = FALSE],
prototypes = embeddings$embeddings_prototypes
)
estimated_values$type <- rep("unlabeled", length(estimated_values_names))
colnames(estimated_values) <- c("x", "y", "class", "type")
}
} else {
estimated_values_names <- NULL
}
plot_data=prototypes
if (length(true_values) > 0) {
plot_data=rbind(plot_data,true_values)
}
if (length(estimated_values_names) > 0) {
plot_data <- rbind(plot_data,estimated_values)
}
plot <- ggplot2::ggplot(data = plot_data) +
ggplot2::geom_point(
mapping = ggplot2::aes(
x = x,
y = y,
color = class,
shape = type,
size = type,
alpha = type
)#,
#position = ggplot2::position_jitter(h = 0.1, w = 0.1)
) +
ggplot2::scale_size_manual(values = c(
"prototype" = size_points_prototypes,
"labeled" = size_points,
"unlabeled" = size_points
)) +
ggplot2::scale_alpha_manual(
values = c(
"prototype" = 1,
"labeled" = alpha,
"unlabeled" = alpha
)
) +
ggplot2::theme_classic()
return(plot)
}
),
private = list(
#-------------------------------------------------------------------------
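# Assigns every case to the class of its nearest prototype by taking the
# row-wise minimum of the distance matrix returned by the model.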
calc_classes_on_distance = function(distance_matrix, prototypes) {
index_vector <- vector(length = nrow(distance_matrix))
for (i in 1:length(index_vector)) {
index_vector[i] <- which.min(distance_matrix[i, ])
}
classes <- factor(index_vector,
levels = 1:nrow(prototypes),
labels = rownames(prototypes)
)
return(classes)
},
#--------------------------------------------------------------------------
load_reload_python_scripts = function() {
if (private$ml_framework == "tensorflow") {
} else if (private$ml_framework == "pytorch") {
reticulate::py_run_file(system.file("python/pytorch_te_classifier.py",
package = "aifeducation"
))
reticulate::py_run_file(system.file("python/pytorch_te_protonet.py",
package = "aifeducation"
))
reticulate::py_run_file(system.file("python/py_log.py",
package = "aifeducation"
))
}
},
#--------------------------------------------------------------------------
create_reset_model = function() {
private$check_config_for_TRUE()
if (private$ml_framework == "tensorflow") {
} else {
#--------------------------------------------------------------------------
# Load Custom Pytorch Objects and Functions
private$load_reload_python_scripts()
self$model <- py$TextEmbeddingClassifierProtoNet_PT(
features = as.integer(self$model_config$features),
times = as.integer(self$model_config$times),
dense_size = as.integer(self$model_config$dense_size),
dense_layers=as.integer(self$model_config$dense_layers),
rec_size = as.integer(self$model_config$rec_size),
rec_layers=as.integer(self$model_config$rec_layers),
rec_type = self$model_config$rec_type,
rec_bidirectional = self$model_config$rec_bidirectional,
intermediate_size = as.integer(self$model_config$intermediate_size),
attention_type = self$model_config$attention_type,
repeat_encoder = as.integer(self$model_config$repeat_encoder),
dense_dropout = self$model_config$dense_dropout,
rec_dropout = self$model_config$rec_dropout,
encoder_dropout = self$model_config$encoder_dropout,
add_pos_embedding = self$model_config$add_pos_embedding,
self_attention_heads = as.integer(self$model_config$self_attention_heads),
embedding_dim = as.integer(self$model_config$embedding_dim),
target_levels = reticulate::np_array(seq(from = 0, to = (length(self$model_config$target_levels) - 1)))
)
}
},
#--------------------------------------------------------------------------
basic_train = function(train_data = NULL,
val_data = NULL,
test_data = NULL,
reset_model = FALSE,
use_callback = TRUE,
log_dir = NULL,
log_write_interval = 10,
log_top_value = NULL,
log_top_total = NULL,
log_top_message = NULL) {
# Clear session to provide enough resources for computations
if (private$ml_framework == "tensorflow") {
keras$backend$clear_session()
} else if (private$ml_framework == "pytorch") {
if (torch$cuda$is_available()) {
torch$cuda$empty_cache()
}
}
# Reset model if requested
if (reset_model == TRUE) {
private$create_reset_model()
}
# Set Optimizer
if (private$ml_framework == "tensorflow") {
balanced_metric <- py$BalancedAccuracy(n_classes = as.integer(length(self$model_config$target_levels)))
if (self$model_config$optimizer == "adam") {
self$model$compile(
loss = self$model_config$err_fct,
optimizer = keras$optimizers$Adam(),
metrics = c(self$model_config$metric, balanced_metric)
)
} else if (self$model_config$optimizer == "rmsprop") {
self$model$compile(
loss = self$model_config$err_fct,
optimizer = keras$optimizers$RMSprop(),
metrics = c(self$model_config$metric, balanced_metric)
)
}
} else if (private$ml_framework == "pytorch") {
loss_fct_name <- "CrossEntropyLoss"
if (self$model_config$optimizer == "adam") {
optimizer <- "adam"
} else if (self$model_config$optimizer == "rmsprop") {
optimizer <- "rmsprop"
}
}
# Check directory for checkpoints
create_dir(
dir_path = paste0(self$last_training$config$dir_checkpoint, "/checkpoints"),
trace = self$last_training$config$trace,
msg = "Creating Checkpoint Directory")
# Set target column
if (self$model_config$require_one_hot == FALSE) {
target_column <- "labels"
} else {
target_column <- "one_hot_encoding"
}
# Tensorflow - Callbacks and training
if (private$ml_framework == "tensorflow") {
if (use_callback == TRUE) {
callback <- keras$callbacks$ModelCheckpoint(
filepath = paste0(self$last_training$config$dir_checkpoint, "/checkpoints/best_weights.h5"),
monitor = paste0("val_", self$model_config$balanced_metric),
verbose = as.integer(min(self$last_training$config$ml_trace, 1)),
mode = "auto",
save_best_only = TRUE,
save_weights_only = TRUE
)
} else {
callback <- reticulate::py_none()
}
if (private$gui$shiny_app_active == TRUE) {
private$load_reload_python_scripts()
callback <- list(callback, py$ReportAiforeducationShiny())
}
data_set_weights <- datasets$Dataset$from_dict(
reticulate::dict(list(
sample_weights = sample_weights
))
)
# inputs, targets, sample_weights
dataset_tf <- train_data$add_column("sample_weights", data_set_weights["sample_weights"])
dataset_tf <- dataset_tf$rename_column("input", "input_embeddings")
# Choose correct target column and rename
dataset_tf <- dataset_tf$rename_column(target_column, "targets")
dataset_tf$with_format("tf")
tf_dataset_train <- dataset_tf$to_tf_dataset(
columns = c("input_embeddings", "sample_weights"),
batch_size = as.integer(self$last_training$config$batch_size),
shuffle = TRUE,
label_cols = "targets"
)
# Add sample weights
tf_dataset_train <- tf_dataset_train$map(py$extract_sample_weight)
dataset_tf_val <- val_data$rename_column("input", "input_embeddings")
# Choose correct target column and rename
dataset_tf_val <- dataset_tf_val$rename_column(target_column, "targets")
tf_dataset_val <- dataset_tf_val$to_tf_dataset(
columns = c("input_embeddings"),
batch_size = as.integer(self$last_training$config$batch_size),
shuffle = FALSE,
label_cols = "targets"
)
history <- self$model$fit(
verbose = as.integer(self$last_training$config$ml_trace),
x = tf_dataset_train,
validation_data = tf_dataset_val,
epochs = as.integer(self$last_training$config$epochs),
callbacks = callback,
class_weight = reticulate::py_dict(keys = names(class_weights), values = class_weights)
)$history
if (self$model_config$n_categories == 2) {
history <- list(
loss = rbind(history$loss, history$val_loss),
accuracy = rbind(history$binary_accuracy, history$val_binary_accuracy),
balanced_accuracy = rbind(history$balanced_accuracy, history$val_balanced_accuracy)
)
} else {
history <- list(
loss = rbind(history$loss, history$val_loss),
accuracy = rbind(history$categorical_accuracy, history$val_categorical_accuracy),
balanced_accuracy = rbind(history$balanced_accuracy, history$val_balanced_accuracy)
)
}
if (use_callback == TRUE) {
self$model$load_weights(paste0(self$last_training$config$dir_checkpoint, "/checkpoints/best_weights.h5"))
}
# PyTorch - Callbacks and training
} else if (private$ml_framework == "pytorch") {
dataset_train <- train_data$select_columns(c("input", target_column))
if (self$model_config$require_one_hot == TRUE) {
dataset_train <- dataset_train$rename_column(target_column, "labels")
}
pytorch_train_data <- dataset_train$with_format("torch")
pytorch_val_data <- val_data$select_columns(c("input", target_column))
if (self$model_config$require_one_hot == TRUE) {
pytorch_val_data <- pytorch_val_data$rename_column(target_column, "labels")
}
pytorch_val_data <- pytorch_val_data$with_format("torch")
if (!is.null(test_data)) {
pytorch_test_data <- test_data$select_columns(c("input", target_column))
if (self$model_config$require_one_hot == TRUE) {
pytorch_test_data <- pytorch_test_data$rename_column(target_column, "labels")
}
pytorch_test_data <- pytorch_test_data$with_format("torch")
} else {
pytorch_test_data <- NULL
}
history <- py$TeClassifierProtoNetTrain_PT_with_Datasets(
model = self$model,
loss_fct_name = loss_fct_name,
optimizer_method = self$model_config$optimizer,
Ns = as.integer(self$last_training$config$Ns),
Nq = as.integer(self$last_training$config$Nq),
loss_alpha = self$last_training$config$loss_alpha,
loss_margin = self$last_training$config$loss_margin,
trace = as.integer(self$last_training$config$ml_trace),
use_callback = use_callback,
train_data = pytorch_train_data,
val_data = pytorch_val_data,
test_data = pytorch_test_data,
epochs = as.integer(self$last_training$config$epochs),
sampling_separate = self$last_training$config$sampling_separate,
sampling_shuffle = self$last_training$config$sampling_shuffle,
filepath = paste0(self$last_training$config$dir_checkpoint, "/checkpoints/best_weights.pt"),
n_classes = as.integer(length(self$model_config$target_levels)),
log_dir = log_dir,
log_write_interval = log_write_interval,
log_top_value = log_top_value,
log_top_total = log_top_total,
log_top_message = log_top_message
)
}
# provide rownames and replace -100
history <- private$prepare_history_data(history)
return(history)
}
)
)