# This file is part of the R package "aifeducation".
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as published by
# the Free Software Foundation.
#
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>
#' @title Base `R6` class for creation and definition of `.AIFE*Transformer-like` classes
#'
#' @description This base class is used to create and define `.AIFE*Transformer-like` classes. It serves as a skeleton
#' for a future concrete transformer and cannot be used to create an object of itself (an attempt to call the
#' `new` method will produce an error).
#'
#' See p.1 Base Transformer Class in
#' [Transformers for Developers](https://fberding.github.io/aifeducation/articles/transformers.html) for details.
#'
#' @section Create: The `create`-method is a basic algorithm that is used to create a new transformer, but cannot be
#' called directly.
#'
#' @section Train: The `train`-method is a basic algorithm that is used to train and tune the transformer but cannot be
#' called directly.
#'
#' @section Concrete transformer implementation: Concrete (child) transformers (e.g. `BERT`, `DeBERTa-V2`, etc.) are
#' already implemented. To implement a new one, see p.4 Implement A Custom Transformer in
#' [Transformers for Developers](https://fberding.github.io/aifeducation/articles/transformers.html) and the sketch
#' below.
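#'
#' The following schematic sketch only illustrates the intended wiring. All names (e.g. `.AIFEMyTransformer`) are
#' hypothetical and the step bodies are placeholders; it is not the package's actual `BERT` implementation:
#' ```r
#' .AIFEMyTransformer <- R6::R6Class(
#'   classname = ".AIFEMyTransformer",
#'   inherit = .AIFEBaseTransformer,
#'   public = list(
#'     initialize = function() {
#'       self$set_title("My Transformer Model")
#'       # Register the five required creation steps (see set_required_SFC)
#'       self$set_required_SFC(list(
#'         create_tokenizer_draft = function(self) NULL,
#'         calculate_vocab = function(self) NULL,
#'         save_tokenizer_draft = function(self) NULL,
#'         create_final_tokenizer = function(self) NULL,
#'         create_transformer_model = function(self) NULL
#'       ))
#'       # Register the required training step
#'       self$set_SFT_load_existing_model(function(self) NULL)
#'     }
#'   )
#' )
#' ```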
#'
#' @param ml_framework `r paramDesc.ml_framework()`
#' @param text_dataset `r paramDesc.text_dataset()`
#' @param sustain_track `r paramDesc.sustain_track()`
#' @param sustain_iso_code `r paramDesc.sustain_iso_code()`
#' @param sustain_region `r paramDesc.sustain_region()`
#' @param sustain_interval `r paramDesc.sustain_interval()`
#' @param trace `r paramDesc.trace()`
#' @param pytorch_safetensors `r paramDesc.pytorch_safetensors()`
#' @param log_dir `r paramDesc.log_dir()`
#' @param log_write_interval `r paramDesc.log_write_interval()`
#'
#' @references Hugging Face transformers documentation:
#' * [BERT](https://huggingface.co/docs/transformers/model_doc/bert)
#' * [DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta-v2)
#' * [Funnel](https://huggingface.co/docs/transformers/model_doc/funnel)
#' * [Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)
#' * [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)
#' * [MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)
#'
#' @family Transformers for developers
#' @export
.AIFEBaseTransformer <- R6::R6Class( # nolint
classname = ".AIFEBaseTransformer",
private = list(
# == Attributes ====================================================================================================
# Transformer's title
title = "Transformer Model",
# This object is used to track sustainability on demand
# It can be created with the private `create_sustain_tracker()` method.
sustainability_tracker = NULL,
# Required and optional steps to create a new transformer
steps_for_creation = list(
# required in each transformer
create_tokenizer_draft = NULL,
calculate_vocab = NULL,
save_tokenizer_draft = NULL,
create_final_tokenizer = NULL,
create_transformer_model = NULL,
# optional
check_max_pos_emb = NULL,
# required and already defined for all transformers
save_transformer_model = NULL
),
# Required steps to train the transformer
steps_for_training = list(
# required
load_existing_model = NULL,
# optional
cuda_empty_cache = NULL,
# required and already defined steps
check_chunk_size = NULL,
create_chunks_for_training = NULL,
prepare_train_tune = NULL,
start_training = NULL,
save_model = NULL,
# required, already defined steps. Can be overwritten with a custom version
create_data_collator = NULL
),
# == Methods =======================================================================================================
# Clear methods ----
# Clears the variables of the class: `sustainability_tracker`, `params` and `temp`
clear_variables = function() {
private$sustainability_tracker <- NULL
self$params <- list()
self$temp <- list()
},
# Check methods ----
# Checks whether a required creation step is missing
check_required_SFC = function() {
if (!is.function(private$steps_for_creation$create_tokenizer_draft)) {
stop("'steps_for_creation$create_tokenizer_draft' is not a function")
}
if (!is.function(private$steps_for_creation$calculate_vocab)) {
stop("'steps_for_creation$calculate_vocab' is not a function")
}
if (!is.function(private$steps_for_creation$save_tokenizer_draft)) {
stop("'steps_for_creation$save_tokenizer_draft' is not a function")
}
if (!is.function(private$steps_for_creation$create_final_tokenizer)) {
stop("'steps_for_creation$create_final_tokenizer' is not a function")
}
if (!is.function(private$steps_for_creation$create_transformer_model)) {
stop("'steps_for_creation$create_transformer_model' is not a function")
}
},
# Checks whether a required training step is missing
check_required_SFT = function() {
if (!is.function(private$steps_for_training$load_existing_model)) {
stop("'steps_for_training$load_existing_model' is not a function")
}
},
# Other ----
# Initializes the 'static' parameters of the transformer in the `params` list
init_common_model_params = function(ml_framework,
sustain_track,
sustain_iso_code,
sustain_region,
sustain_interval,
trace,
pytorch_safetensors,
log_dir,
log_write_interval,
text_dataset) {
self$set_model_param("ml_framework", ml_framework)
self$set_model_param("sustain_track", sustain_track)
self$set_model_param("sustain_iso_code", sustain_iso_code)
self$set_model_param("sustain_region", sustain_region)
self$set_model_param("sustain_interval", sustain_interval)
self$set_model_param("trace", trace)
self$set_model_param("pytorch_safetensors", pytorch_safetensors)
self$set_model_param("log_dir", log_dir)
self$set_model_param("log_write_interval", log_write_interval)
self$set_model_param("text_dataset", text_dataset)
},
# -------------------------------------------------------------------------------
define_required_SFC_functions = function() {
if (!is.function(private$steps_for_creation$save_transformer_model)) {
private$steps_for_creation$save_transformer_model <- function(self) {
if (self$params$ml_framework == "tensorflow") {
self$temp$model$build()
self$temp$model$save_pretrained(save_directory = self$params$model_dir)
} else {
self$temp$model$save_pretrained(
save_directory = self$params$model_dir,
safe_serialization = self$temp$pt_safe_save
)
}
}
}
},
# -------------------------------------------------------------------------------
define_required_SFT_functions = function() {
# SFT: check_chunk_size -------------------------------------------------------
if (!is.function(private$steps_for_training$check_chunk_size)) {
private$steps_for_training$check_chunk_size <- function(self) {
if (self$params$chunk_size > (self$temp$model$config$max_position_embeddings)) {
stop(
paste0(
"Chunk size is ", self$params$chunk_size, ". This value is not allowed to exceed ",
self$temp$model$config$max_position_embeddings, "."
)
)
}
if (self$params$chunk_size < 3) {
stop("Chunk size must be at least 3.")
}
# Adjust the chunk size: two elements are reserved for the beginning- and end-of-sequence tokens
self$params$chunk_size <- self$params$chunk_size - 2
}
}
# SFT: create_chunks_for_training ---------------------------------------------
if (!is.function(private$steps_for_training$create_chunks_for_training)) {
private$steps_for_training$create_chunks_for_training <- function(self) {
tokenized_texts_raw <- tokenize_dataset(
dataset = self$temp$raw_text_dataset,
tokenizer = self$temp$tokenizer,
max_length = self$params$chunk_size,
log_file = self$temp$log_file,
write_interval = self$temp$write_interval,
value_top = self$temp$value_top,
total_top = self$temp$total_top,
message_top = self$temp$message_top
)
length_vector <- tokenized_texts_raw["length"]
if (self$params$full_sequences_only) {
relevant_indices <- which(length_vector == self$params$chunk_size)
} else {
relevant_indices <- which(
length_vector <= self$params$chunk_size &
length_vector >= self$params$min_seq_len
)
}
if (length(relevant_indices) == 0) {
self$temp$tokenized_dataset <- tokenized_texts_raw
} else {
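# The datasets library expects zero-based indices in select(), hence the shift by -1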
self$temp$tokenized_dataset <- tokenized_texts_raw$select(as.integer(relevant_indices - 1))
}
private$update_tokenizer_statistics(self$temp$tokenized_dataset, "training")
}
}
# SFT: create_data_collator ---------------------------------------------------
if (!is.function(private$steps_for_training$create_data_collator)) {
# Create default data collator
# For mpnet transformer see AIFEMpnetTransformer -> SFT -> create_data_collator
private$steps_for_training$create_data_collator <- function(self) {
if (self$params$whole_word) {
self$temp$data_collator <- transformers$DataCollatorForWholeWordMask(
tokenizer = self$temp$tokenizer,
mlm = TRUE,
mlm_probability = self$params$p_mask,
return_tensors = self$temp$return_tensors
)
} else {
self$temp$data_collator <- transformers$DataCollatorForLanguageModeling(
tokenizer = self$temp$tokenizer,
mlm = TRUE,
mlm_probability = self$params$p_mask,
return_tensors = self$temp$return_tensors
)
}
}
}
# SFT: prepare_train_tune -----------------------------------------------------
if (!is.function(private$steps_for_training$prepare_train_tune)) {
private$steps_for_training$prepare_train_tune <- function(self) {
msg <- ifelse(self$params$whole_word, "Using Whole Word Masking", "Using Token Masking")
print_message(msg, self$params$trace)
self$temp$return_tensors <- ifelse(self$params$ml_framework == "tensorflow", "tf", "pt")
# Create data collator
private$steps_for_training$create_data_collator(self)
# ----
format_type <- ifelse(self$params$ml_framework == "tensorflow", "tensorflow", "torch")
self$temp$tokenized_dataset$set_format(type = format_type)
self$temp$tokenized_dataset <- self$temp$tokenized_dataset$train_test_split(
test_size = self$params$val_size
)
print_message("Preparing Training of the Model", self$params$trace)
# Create Custom Callbacks ----
if (self$params$ml_framework == "tensorflow") {
run_py_file("keras_callbacks.py")
create_logger <- py$create_AIFETransformerCSVLogger_TF
} else {
run_py_file("pytorch_transformer_callbacks.py")
create_logger <- py$create_AIFETransformerCSVLogger_PT
}
logger_args <- list(
loss_file = self$temp$loss_file,
log_file = self$temp$log_file,
value_top = 6,
total_top = 10,
message_top = "Overall: Training...",
min_step = 1
)
logger <- do.call(create_logger, logger_args)
if (self$params$ml_framework == "tensorflow") { # TENSORFLOW --------------
self$temp$tf_train_dataset <- self$temp$model$prepare_tf_dataset(
dataset = self$temp$tokenized_dataset$train,
batch_size = as.integer(self$params$batch_size),
collate_fn = self$temp$data_collator,
shuffle = TRUE
)
self$temp$tf_test_dataset <- self$temp$model$prepare_tf_dataset(
dataset = self$temp$tokenized_dataset$test,
batch_size = as.integer(self$params$batch_size),
collate_fn = self$temp$data_collator,
shuffle = TRUE
)
adam <- tf$keras$optimizers$Adam
# Create Callbacks ---------------------------------------------------------
callback_checkpoint <- tf$keras$callbacks$ModelCheckpoint(
filepath = paste0(self$params$output_dir, "/checkpoints/best_weights.h5"),
monitor = "val_loss",
verbose = as.integer(min(self$params$keras_trace, 1)),
mode = "auto",
save_best_only = TRUE,
save_freq = "epoch",
save_weights_only = TRUE
)
callback_history <- tf$keras$callbacks$CSVLogger(
filename = paste0(self$params$output_dir, "/checkpoints/history.log"),
separator = ",",
append = FALSE
)
# Add Callbacks
self$temp$callbacks <- list(callback_checkpoint, callback_history, logger)
print_message("Compile Model", self$params$trace)
self$temp$model$compile(optimizer = adam(self$params$learning_rate), loss = "auto")
# Clear session to provide enough resources for computations ---------------
tf$keras$backend$clear_session()
} else { # PYTORCH ------------------
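# For transformers >= 4.46 the evaluation argument is named 'eval_strategy'; older versions expect 'evaluation_strategy'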
if (utils::compareVersion(transformers["__version__"], "4.46.0") >= 0) {
training_args <- transformers$TrainingArguments(
output_dir = paste0(self$params$output_dir, "/checkpoints"),
overwrite_output_dir = TRUE,
eval_strategy = "epoch",
num_train_epochs = as.integer(self$params$n_epoch),
logging_strategy = "epoch",
save_strategy = "epoch",
save_total_limit = as.integer(1),
load_best_model_at_end = TRUE,
optim = "adamw_torch",
learning_rate = self$params$learning_rate,
per_device_train_batch_size = as.integer(self$params$batch_size),
per_device_eval_batch_size = as.integer(self$params$batch_size),
save_safetensors = TRUE,
auto_find_batch_size = FALSE,
report_to = "none",
log_level = "error",
disable_tqdm = !self$params$pytorch_trace
)
} else {
training_args <- transformers$TrainingArguments(
output_dir = paste0(self$params$output_dir, "/checkpoints"),
overwrite_output_dir = TRUE,
evaluation_strategy = "epoch",
num_train_epochs = as.integer(self$params$n_epoch),
logging_strategy = "epoch",
save_strategy = "epoch",
save_total_limit = as.integer(1),
load_best_model_at_end = TRUE,
optim = "adamw_torch",
learning_rate = self$params$learning_rate,
per_device_train_batch_size = as.integer(self$params$batch_size),
per_device_eval_batch_size = as.integer(self$params$batch_size),
save_safetensors = TRUE,
auto_find_batch_size = FALSE,
report_to = "none",
log_level = "error",
disable_tqdm = !self$params$pytorch_trace
)
}
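# For transformers >= 4.46 the tokenizer is passed via 'processing_class'; older versions use the 'tokenizer' argument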
if (utils::compareVersion(transformers["__version__"], "4.46.0") >= 0) {
self$temp$trainer <- transformers$Trainer(
model = self$temp$model,
train_dataset = self$temp$tokenized_dataset$train,
eval_dataset = self$temp$tokenized_dataset$test,
args = training_args,
data_collator = self$temp$data_collator,
processing_class = self$temp$tokenizer
)
} else {
self$temp$trainer <- transformers$Trainer(
model = self$temp$model,
train_dataset = self$temp$tokenized_dataset$train,
eval_dataset = self$temp$tokenized_dataset$test,
args = training_args,
data_collator = self$temp$data_collator,
tokenizer = self$temp$tokenizer
)
}
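# Emissions are tracked by this class's own codecarbon tracker (see create_sustain_tracker), so remove the Trainer's built-in CodeCarbonCallback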
self$temp$trainer$remove_callback(transformers$integrations$CodeCarbonCallback)
if (!as.logical(self$params$pytorch_trace)) {
self$temp$trainer$remove_callback(transformers$PrinterCallback)
self$temp$trainer$remove_callback(transformers$ProgressCallback)
}
# Add Callbacks
self$temp$trainer$add_callback(logger)
}
}
}
# SFT: start_training ---------------------------------------------------------
if (!is.function(private$steps_for_training$start_training)) {
private$steps_for_training$start_training <- function(self) {
if (self$params$ml_framework == "tensorflow") { # TENSORFLOW --------------
self$temp$model$fit(
x = self$temp$tf_train_dataset,
validation_data = self$temp$tf_test_dataset,
epochs = as.integer(self$params$n_epoch),
workers = as.integer(self$params$n_workers),
use_multiprocessing = self$params$multi_process,
callbacks = list(self$temp$callbacks),
verbose = as.integer(self$params$keras_trace)
)
print_message("Load Weights From Best Checkpoint", self$params$trace)
self$temp$model$load_weights(paste0(self$params$output_dir, "/checkpoints/best_weights.h5"))
} else { # PYTORCH -----------------
if (is.function(private$steps_for_training$cuda_empty_cache)) {
private$steps_for_training$cuda_empty_cache()
}
self$temp$trainer$train()
}
}
}
# SFT: save_model -------------------------------------------------------------
if (!is.function(private$steps_for_training$save_model)) {
private$steps_for_training$save_model <- function(self) {
if (self$params$ml_framework == "tensorflow") { # TENSORFLOW --------------
self$temp$model$save_pretrained(
save_directory = self$params$output_dir
)
history_log <- read.csv(file = paste0(self$params$output_dir, "/checkpoints/history.log"))
} else { # PYTORCH -----------------
self$temp$model$save_pretrained(
save_directory = self$params$output_dir,
safe_serialization = self$temp$pt_safe_save
)
history_log <- pandas$DataFrame(self$temp$trainer$state$log_history)
history_log <- clean_pytorch_log_transformers(history_log)
}
# Write history log
write.csv2(
history_log,
file = paste0(self$params$output_dir, "/history.log"),
row.names = FALSE,
quote = FALSE
)
}
}
},
# Creates a sustainability tracker and stores it in the private `sustainability_tracker` attribute
create_sustain_tracker = function() {
tmp_code_carbon <- reticulate::import("codecarbon")
codecarbon_version <- as.character(tmp_code_carbon["__version__"])
if (utils::compareVersion(codecarbon_version, "2.8.0") >= 0) {
# codecarbon >= 2.8.0 uses a lock file; remove a stale one so a new tracker can be started
path_lock_file <- codecarbon$lock$LOCKFILE
if (file.exists(path_lock_file)) {
unlink(path_lock_file)
}
}
private$sustainability_tracker <- codecarbon$OfflineEmissionsTracker(
country_iso_code = self$params$sustain_iso_code,
region = self$params$sustain_region,
tracking_mode = "machine",
log_level = "warning",
measure_power_secs = self$params$sustain_interval,
save_to_file = FALSE,
save_to_api = FALSE
)
},
# Save data csv (for sustainability data and tokenizer statistics)
write_data_csv = function(data, csv_file_name, quote = TRUE, mode = "create") {
data_matrix <- t(as.matrix(unlist(data)))
x_value <- data_matrix
if (mode == "create") {
data_output <- paste0(self$params$model_dir, "/", csv_file_name)
} else if (mode == "train") {
data_output <- paste0(self$params$output_dir, "/", csv_file_name)
data_input <- paste0(self$params$model_dir_path, "/", csv_file_name)
if (file.exists(data_input)) {
data_chronic <- as.matrix(read.csv(data_input))
data_chronic <- rbind(data_chronic, data_matrix)
x_value <- data_chronic
}
} else {
stop(paste("Mode", mode, "is invalid. Allowed: create or train"))
}
write.csv(
x = x_value,
file = data_output,
row.names = FALSE,
quote = quote
)
},
# Save Sustainability tracker
save_sustainability_data = function(mode = "create") {
sustainability_data <- summarize_tracked_sustainability(private$sustainability_tracker)
private$write_data_csv(
data = sustainability_data,
csv_file_name = "sustainability.csv",
mode = mode
)
},
# Save Tokenizer Statistics on disk
save_tokenizer_statistics = function(mode = "create") {
private$write_data_csv(
data = self$temp$tokenizer_statistics,
csv_file_name = "tokenizer_statistics.csv",
quote = FALSE,
mode = mode
)
},
# Update Tokenizer Statistics (calculate)
update_tokenizer_statistics = function(dataset, step = "creation") {
self$temp$tokenizer_statistics[length(self$temp$tokenizer_statistics) + 1] <- list(
calc_tokenizer_statistics(
dataset = dataset,
step = step
)
)
}
),
public = list(
# == Attributes ====================================================================================================
#' @field params A list containing the transformer's parameters ('static', 'dynamic' and 'dependent' parameters).
#'
#' `list()` containing all the transformer parameters. Can be set with `set_model_param()`; see the sketch at the end
#' of this section.
#'
#' ### **'Static' parameters**
#'
#' Regardless of the transformer, the following parameters are always included:
#' * `ml_framework`
#' * `text_dataset`
#' * `sustain_track`
#' * `sustain_iso_code`
#' * `sustain_region`
#' * `sustain_interval`
#' * `trace`
#' * `pytorch_safetensors`
#' * `log_dir`
#' * `log_write_interval`
#'
#' ### **'Dynamic' parameters**
#'
#' In the case of **create** it also contains (see `create`-method for details):
#' * `model_dir`
#' * `vocab_size`
#' * `max_position_embeddings`
#' * `hidden_size`
#' * `hidden_act`
#' * `hidden_dropout_prob`
#' * `attention_probs_dropout_prob`
#' * `intermediate_size`
#' * `num_attention_heads`
#'
#' In the case of **train** it also contains (see `train`-method for details):
#' * `output_dir`
#' * `model_dir_path`
#' * `p_mask`
#' * `whole_word`
#' * `val_size`
#' * `n_epoch`
#' * `batch_size`
#' * `chunk_size`
#' * `min_seq_len`
#' * `full_sequences_only`
#' * `learning_rate`
#' * `n_workers`
#' * `multi_process`
#' * `keras_trace`
#' * `pytorch_trace`
#'
#' ### **'Dependent' parameters**
#'
#' Depending on the transformer and the method used, the class may contain different parameters:
#' * `vocab_do_lower_case`
#' * `num_hidden_layer`
#' * `add_prefix_space`
#' * etc.
#'
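#' A short illustrative sketch (here `transformer` stands for an object of a concrete child class, since this base
#' class cannot be instantiated; the values are arbitrary):
#' ```r
#' transformer$set_model_param("vocab_size", 30522)
#' transformer$params$vocab_size
#' #> [1] 30522
#' ```
#'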
params = list(),
#' @field temp A list containing the transformer's temporary parameters.
#'
#' `list()` containing all the temporary local variables that need to be accessed between the step functions. Can
#' be set with `set_model_temp()`.
#'
#' For example, it can be a variable `tok_new` that stores the tokenizer from
#' `steps_for_creation$create_tokenizer_draft`. To train the tokenizer, access the variable `tok_new` in
#' `steps_for_creation$calculate_vocab` through the `temp` list of this class.
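#'
#' A minimal sketch of this pattern (the helpers `create_draft_tokenizer()` and `train_vocabulary()` are hypothetical
#' placeholders, and `transformer` is an object of a concrete child class):
#' ```r
#' transformer$set_SFC_create_tokenizer_draft(function(self) {
#'   self$temp$tok_new <- create_draft_tokenizer()
#' })
#' transformer$set_SFC_calculate_vocab(function(self) {
#'   train_vocabulary(self$temp$tok_new, self$params$vocab_size)
#' })
#' ```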
#
temp = list(),
# == Methods =======================================================================================================
# New ----
#' @description An object of this class cannot be created. Thus, calling this method will produce an error.
#' @return This method returns an error.
initialize = function() {
stop("Cannot create .AIFEBaseTransformer objects.")
},
# Setters ----
#' @description Setter for the title. Sets a new value for the `title` private attribute.
#' @param title `string` A new title.
#' @return This method returns nothing.
set_title = function(title) private$title <- title,
#' @description Setter for the parameters. Adds a new parameter and its value to the `params` list.
#' @param param_name `string` Parameter's name.
#' @param param_value `any` Parameter's value.
#' @return This method returns nothing.
set_model_param = function(param_name, param_value) self$params[[param_name]] <- param_value,
#' @description Setter for the temporary model's parameters. Adds a new temporary parameter and its value to the
#' `temp` list.
#' @param temp_name `string` Parameter's name.
#' @param temp_value `any` Parameter's value.
#' @return This method returns nothing.
set_model_temp = function(temp_name, temp_value) self$temp[[temp_name]] <- temp_value,
# Setters for SFC ----
#' @description Setter for the `check_max_pos_emb` element of the private `steps_for_creation` list. Sets a new
#' `fun` function as the `check_max_pos_emb` step.
#' @param fun `function()` A new function.
#' @return This method returns nothing.
set_SFC_check_max_pos_emb = function(fun) private$steps_for_creation$check_max_pos_emb <- fun,
#' @description Setter for the `create_tokenizer_draft` element of the private `steps_for_creation` list. Sets a
#' new `fun` function as the `create_tokenizer_draft` step.
#' @param fun `function()` A new function.
#' @return This method returns nothing.
set_SFC_create_tokenizer_draft = function(fun) private$steps_for_creation$create_tokenizer_draft <- fun,
#' @description Setter for the `calculate_vocab` element of the private `steps_for_creation` list. Sets a new `fun`
#' function as the `calculate_vocab` step.
#' @param fun `function()` A new function.
#' @return This method returns nothing.
set_SFC_calculate_vocab = function(fun) private$steps_for_creation$calculate_vocab <- fun,
#' @description Setter for the `save_tokenizer_draft` element of the private `steps_for_creation` list. Sets a new
#' `fun` function as the `save_tokenizer_draft` step.
#' @param fun `function()` A new function.
#' @return This method returns nothing.
set_SFC_save_tokenizer_draft = function(fun) private$steps_for_creation$save_tokenizer_draft <- fun,
#' @description Setter for the `create_final_tokenizer` element of the private `steps_for_creation` list. Sets a new
#' `fun` function as the `create_final_tokenizer` step.
#' @param fun `function()` A new function.
#' @return This method returns nothing.
set_SFC_create_final_tokenizer = function(fun) private$steps_for_creation$create_final_tokenizer <- fun,
#' @description Setter for the `create_transformer_model` element of the private `steps_for_creation` list. Sets a
#' new `fun` function as the `create_transformer_model` step.
#' @param fun `function()` A new function.
#' @return This method returns nothing.
set_SFC_create_transformer_model = function(fun) private$steps_for_creation$create_transformer_model <- fun,
#' @description Setter for all required elements of the private `steps_for_creation` list. Executes setters for all
#' required creation steps.
#' @param required_SFC `list()` A list of all new required steps.
#' @return This method returns nothing.
set_required_SFC = function(required_SFC) { # nolint
self$set_SFC_create_tokenizer_draft(required_SFC$create_tokenizer_draft)
self$set_SFC_calculate_vocab(required_SFC$calculate_vocab)
self$set_SFC_save_tokenizer_draft(required_SFC$save_tokenizer_draft)
self$set_SFC_create_final_tokenizer(required_SFC$create_final_tokenizer)
self$set_SFC_create_transformer_model(required_SFC$create_transformer_model)
},
# Setters for SFT ----
#' @description Setter for the `load_existing_model` element of the private `steps_for_training` list. Sets a new
#' `fun` function as the `load_existing_model` step.
#' @param fun `function()` A new function.
#' @return This method returns nothing.
set_SFT_load_existing_model = function(fun) private$steps_for_training$load_existing_model <- fun,
#' @description Setter for the `cuda_empty_cache` element of the private `steps_for_training` list. Sets a new
#' `fun` function as the `cuda_empty_cache` step.
#' @param fun `function()` A new function.
#' @return This method returns nothing.
set_SFT_cuda_empty_cache = function(fun) private$steps_for_training$cuda_empty_cache <- fun,
#' @description Setter for the `create_data_collator` element of the private `steps_for_training` list. Sets a new
#' `fun` function as the `create_data_collator` step. Use this method to make a custom data collator for a
#' transformer.
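#'
#' For example, the following sketch registers a plain language-modelling collator, mirroring the default step
#' defined in this class (here `transformer` is an object of a concrete child class):
#' ```r
#' transformer$set_SFT_create_data_collator(function(self) {
#'   self$temp$data_collator <- transformers$DataCollatorForLanguageModeling(
#'     tokenizer = self$temp$tokenizer,
#'     mlm = TRUE,
#'     mlm_probability = self$params$p_mask,
#'     return_tensors = self$temp$return_tensors
#'   )
#' })
#' ```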
#' @param fun `function()` A new function.
#' @return This method returns nothing.
set_SFT_create_data_collator = function(fun) private$steps_for_training$create_data_collator <- fun,
# Main methods ----
# Create ----
#' @description This method creates a transformer configuration based on the child-transformer architecture and a
#' vocabulary using the python libraries `transformers` and `tokenizers`.
#'
#' This method **adds** the following parameters to the `temp` list:
#' * `log_file`
#' * `raw_text_dataset`
#' * `pt_safe_save`
#' * `value_top`
#' * `total_top`
#' * `message_top`
#'
#' This method **uses** the following parameters from the `temp` list:
#' * `log_file`
#' * `raw_text_dataset`
#' * `tokenizer`
#'
#' @param model_dir `r paramDesc.model_dir()`
#' @param vocab_size `r paramDesc.vocab_size()`
#' @param max_position_embeddings `r paramDesc.max_position_embeddings()`
#' @param hidden_size `r paramDesc.hidden_size()`
#' @param num_attention_heads `r paramDesc.num_attention_heads()`
#' @param intermediate_size `r paramDesc.intermediate_size()`
#' @param hidden_act `r paramDesc.hidden_act()`
#' @param hidden_dropout_prob `r paramDesc.hidden_dropout_prob()`
#' @param attention_probs_dropout_prob `r paramDesc.attention_probs_dropout_prob()`
#'
#' @return This method does not return an object. Instead, it saves the configuration and vocabulary of the new
#' model to disk.
create = function(ml_framework,
model_dir,
text_dataset,
vocab_size,
max_position_embeddings,
hidden_size,
num_attention_heads,
intermediate_size,
hidden_act,
hidden_dropout_prob,
attention_probs_dropout_prob,
sustain_track,
sustain_iso_code,
sustain_region,
sustain_interval,
trace,
pytorch_safetensors,
log_dir,
log_write_interval) {
tryCatch({
private$define_required_SFC_functions()
# Init model parameters -----------------------------------------------------
# Each transformer has these parameters ----
private$init_common_model_params(
ml_framework,
sustain_track, sustain_iso_code, sustain_region, sustain_interval,
trace, pytorch_safetensors,
log_dir, log_write_interval,
text_dataset
)
# Each transformer has these parameters in the case of creation ----
self$set_model_param("model_dir", model_dir)
self$set_model_param("vocab_size", vocab_size)
self$set_model_param("max_position_embeddings", max_position_embeddings)
self$set_model_param("hidden_size", hidden_size)
self$set_model_param("hidden_act", hidden_act)
self$set_model_param("hidden_dropout_prob", hidden_dropout_prob)
self$set_model_param("attention_probs_dropout_prob", attention_probs_dropout_prob)
self$set_model_param("intermediate_size", intermediate_size)
self$set_model_param("num_attention_heads", num_attention_heads)
# Check definition of required functions ----
private$check_required_SFC()
# Logging file
run_py_file("py_log.py")
self$temp$log_file <- NULL
if (!is.null(self$params$log_dir) && dir.exists(self$params$log_dir)) {
self$temp$log_file <- paste0(self$params$log_dir, "/aifeducation_state.log")
}
total <- 10
write_interval <- self$params$log_write_interval
last_log <- NULL
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 0, total_top = total, message_top = paste(private$title, "Overall: Checking"),
last_log = last_log, write_interval = write_interval
)
# argument checking ---------------------------------------------------------
# optional function
if (is.function(private$steps_for_creation$check_max_pos_emb)) {
private$steps_for_creation$check_max_pos_emb(self)
}
check_class(text_dataset, "LargeDataSetForText", FALSE)
self$temp$raw_text_dataset <- text_dataset$get_dataset()
if (is.null(self$temp$raw_text_dataset$features$text)) {
stop("Dataset does not contain a column 'text' storing the raw texts.")
}
check.ml_framework(ml_framework)
check.hidden_act(hidden_act)
check.sustain_iso_code(sustain_iso_code, sustain_track)
# Check possible save formats
self$temp$pt_safe_save <- check.possible_save_formats(ml_framework, pytorch_safetensors)
# Start Sustainability Tracking ---------------------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 1, total_top = total, message_top = paste(private$title, "Overall: Starting"),
last_log = last_log, write_interval = write_interval
)
track_msg <- ifelse(
sustain_track,
"Start Sustainability Tracking",
"Start without Sustainability Tracking"
)
print_message(track_msg, trace)
if (sustain_track) {
private$create_sustain_tracker()
private$sustainability_tracker$start()
}
# Creating a new Tokenizer for Computing Vocabulary -------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 2, total_top = total, message_top = paste(private$title, "Overall: Tokenizer Draft"),
last_log = last_log, write_interval = write_interval
)
print_message("Creating Tokenizer Draft", trace)
private$steps_for_creation$create_tokenizer_draft(self)
# Calculating Vocabulary ----------------------------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 3, total_top = total, message_top = paste(private$title, "Overall: Computing Vocabulary"),
last_log = last_log, write_interval = write_interval
)
self$temp$value_top <- 3
self$temp$total_top <- total
self$temp$message_top <- paste(private$title, "Overall: Computing Vocabulary")
print_message("Start Computing Vocabulary", trace)
private$steps_for_creation$calculate_vocab(self)
print_message("Start Computing Vocabulary - Done", trace)
# Saving Tokenizer Draft ----------------------------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 4, total_top = total, message_top = paste(private$title, "Overall: Saving Tokenizer Draft"),
last_log = last_log, write_interval = write_interval
)
print_message("Saving Draft", trace)
create_dir(model_dir, trace, "Creating Model Directory")
private$steps_for_creation$save_tokenizer_draft(self)
# Final Tokenizer -----------------------------------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 5, total_top = total,
message_top = paste(private$title, "Overall: Creating Final Tokenizer & Calculating Statistics"),
last_log = last_log, write_interval = write_interval
)
print_message("Creating Tokenizer", trace)
private$steps_for_creation$create_final_tokenizer(self)
if (!("tokenizer" %in% names(self$temp))) {
stop("The final tokenizer must be stored in the 'tokenizer' parameter of the 'temp' list.")
}
print_message("Creating Tokenizer - Done", trace)
# Calculate tokenizer statistics -------------------------------------------
tokenized_texts_raw <- tokenize_dataset(
dataset = self$temp$raw_text_dataset,
tokenizer = self$temp$tokenizer,
max_length = 2048,
log_file = self$temp$log_file,
write_interval = write_interval,
value_top = 5, total_top = total,
message_top = paste(private$title, "Overall: Creating Final Tokenizer & Calculating Statistics")
)
private$update_tokenizer_statistics(tokenized_texts_raw, "creation")
# Creating Transformer Model ------------------------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 6, total_top = total, message_top = paste(private$title, "Overall: Creating Transformer Model"),
last_log = last_log, write_interval = write_interval
)
print_message("Creating Transformer Model", trace)
private$steps_for_creation$create_transformer_model(self)
if (!("model" %in% names(self$temp))) {
stop("The transformer model must be stored in the 'model' parameter of the 'temp' list.")
}
# Saving Model --------------------------------------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 7, total_top = total, message_top = paste(private$title, "Overall: Saving Model"),
last_log = last_log, write_interval = write_interval
)
print_message(paste("Saving", private$title), trace)
private$steps_for_creation$save_transformer_model(self)
private$save_tokenizer_statistics()
# Saving Tokenizer ----------------------------------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 8, total_top = total, message_top = paste(private$title, "Overall: Saving Tokenizer"),
last_log = last_log, write_interval = write_interval
)
print_message("Saving Tokenizer Model", trace)
self$temp$tokenizer$save_pretrained(model_dir)
# Stop Sustainability Tracking if requested ----------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 9, total_top = total, message_top = paste(private$title, "Overall: Stopping"),
last_log = last_log, write_interval = write_interval
)
if (sustain_track) {
private$sustainability_tracker$stop()
print_message("Saving Sustainability Data", trace)
private$save_sustainability_data()
}
# Finish --------------------------------------------------------------------
py$write_log_py(
self$temp$log_file,
value_top = 10, total_top = total, message_top = paste(private$title, "Overall: Done"),
last_log = NULL, write_interval = write_interval
)
print_message("Done", trace)
# Clear variables -----------------------------------------------------------
private$clear_variables()
}, finally = {
if (!is.null(private$sustainability_tracker)) {
private$sustainability_tracker$stop()
}
})
},
# Train ----
#' @description This method can be used to train or fine-tune a transformer based on the `BERT` architecture with the
#' help of the python libraries `transformers`, `datasets`, and `tokenizers`.
#'
#' This method **adds** the following parameters to the `temp` list:
#' * `log_file`
#' * `loss_file`
#' * `from_pt`
#' * `from_tf`
#' * `load_safe`
#' * `raw_text_dataset`
#' * `pt_safe_save`
#' * `value_top`
#' * `total_top`
#' * `message_top`
#'
#' This method **uses** the following parameters from the `temp` list:
#' * `log_file`
#' * `raw_text_dataset`
#' * `tokenized_dataset`
#' * `tokenizer`
#'
#' @param output_dir `r paramDesc.output_dir()`
#' @param model_dir_path `r paramDesc.model_dir_path()`
#' @param p_mask `r paramDesc.p_mask()`
#' @param whole_word `r paramDesc.whole_word()`
#' @param val_size `r paramDesc.val_size()`
#' @param n_epoch `r paramDesc.n_epoch()`
#' @param batch_size `r paramDesc.batch_size()`
#' @param chunk_size `r paramDesc.chunk_size()`
#' @param full_sequences_only `r paramDesc.full_sequences_only()`
#' @param min_seq_len `r paramDesc.min_seq_len()`
#' @param learning_rate `r paramDesc.learning_rate()`
#' @param n_workers `r paramDesc.n_workers()`
#' @param multi_process `r paramDesc.multi_process()`
#' @param keras_trace `r paramDesc.keras_trace()`
#' @param pytorch_trace `r paramDesc.pytorch_trace()`
#'
#' @return This method does not return an object. Instead, it saves the configuration and vocabulary of the new
#' model to disk.
#'
#' @importFrom utils write.csv
#' @importFrom utils read.csv
train = function(ml_framework,
output_dir,
model_dir_path,
text_dataset,
p_mask,
whole_word,
val_size,
n_epoch,
batch_size,
chunk_size,
full_sequences_only,
min_seq_len,
learning_rate,
n_workers,
multi_process,
sustain_track,
sustain_iso_code,
sustain_region,
sustain_interval,
trace,
keras_trace,
pytorch_trace,
pytorch_safetensors,
log_dir,
log_write_interval) {
tryCatch({
private$define_required_SFT_functions()
# Init model parameters -----------------------------------------------------
# Each transformer has these parameters ----
private$init_common_model_params(
ml_framework,
sustain_track, sustain_iso_code, sustain_region, sustain_interval,
trace, pytorch_safetensors,
log_dir, log_write_interval,
text_dataset
)
# Each transformer has these parameters in the case of training ----
self$set_model_param("output_dir", output_dir)
self$set_model_param("model_dir_path", model_dir_path)
self$set_model_param("p_mask", p_mask)
self$set_model_param("whole_word", whole_word)
self$set_model_param("val_size", val_size)
self$set_model_param("n_epoch", n_epoch)
self$set_model_param("batch_size", batch_size)
self$set_model_param("chunk_size", chunk_size)
self$set_model_param("full_sequences_only", full_sequences_only)
self$set_model_param("min_seq_len", min_seq_len)
self$set_model_param("learning_rate", learning_rate)
self$set_model_param("n_workers", n_workers)
self$set_model_param("multi_process", multi_process)
self$set_model_param("keras_trace", keras_trace)
self$set_model_param("pytorch_trace", pytorch_trace)
# Check definition of required functions ----
private$check_required_SFT()
# Logging file
run_py_file("py_log.py")
self$temp$log_file <- NULL
self$temp$loss_file <- NULL
if (!is.null(self$params$log_dir) && dir.exists(self$params$log_dir)) {
self$temp$log_file <- paste0(self$params$log_dir, "/aifeducation_state.log")
self$temp$loss_file <- paste0(self$params$log_dir, "/aifeducation_loss.log")
} else {
self$temp$loss_file <- paste0(self$params$output_dir, "/aifeducation_loss.log")
}
total <- 10
write_interval <- self$params$log_write_interval
last_log <- NULL
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 0, total_top = total, message_top = paste(private$title, "Overall: Checking"),
last_log = last_log, write_interval = write_interval
)
# argument checking ---------------------------------------------------------
check.ml_framework(ml_framework)
model_files_check <- check.model_files(ml_framework, model_dir_path)
self$temp$from_pt <- model_files_check$from_pt
self$temp$from_tf <- model_files_check$from_tf
self$temp$load_safe <- model_files_check$load_safe
check_class(text_dataset, "LargeDataSetForText", FALSE)
self$temp$raw_text_dataset <- text_dataset$get_dataset()
if (is.null(self$temp$raw_text_dataset$features$text)) {
stop("Dataset does not contain a column 'text' storing the texts for training.")
}
check.sustain_iso_code(sustain_iso_code, sustain_track)
# Check possible save formats
self$temp$pt_safe_save <- check.possible_save_formats(ml_framework, pytorch_safetensors)
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 1, total_top = total, message_top = paste(private$title, "Overall: Starting"),
last_log = last_log, write_interval = write_interval
)
# Start Sustainability Tracking ---------------------------------------------
track_msg <- ifelse(
sustain_track,
"Start Sustainability Tracking",
"Start without Sustainability Tracking"
)
print_message(track_msg, trace)
if (sustain_track) {
private$create_sustain_tracker()
private$sustainability_tracker$start()
}
# Loading existing model ----------------------------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 2, total_top = total, message_top = paste(private$title, "Overall: Loading Existing Model"),
last_log = last_log, write_interval = write_interval
)
print_message("Loading Existing Model", trace)
private$steps_for_training$load_existing_model(self)
if (!("tokenizer" %in% names(self$temp))) {
stop("The tokenizer must be stored in the 'tokenizer' parameter of the 'temp' list.")
}
if (!("model" %in% names(self$temp))) {
stop("The transformer model must be stored in the 'model' parameter of the 'temp' list.")
}
# argument checking------------------------------------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 3, total_top = total, message_top = paste(private$title, "Overall: Checking"),
last_log = last_log, write_interval = write_interval
)
private$steps_for_training$check_chunk_size(self)
# creating chunks of sequences ----------------------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 4, total_top = total,
message_top = paste(private$title, "Overall: Creating Chunks of Sequences & Calculating Tokenizer Statistics"),
last_log = last_log, write_interval = write_interval
)
self$temp$write_interval <- write_interval
self$temp$value_top <- 4
self$temp$total_top <- total
self$temp$message_top <- paste(private$title, "Overall: Creating Chunks of Sequences & Calculating Tokenizer Statistics")
print_message("Creating Chunks of Sequences for Training", trace)
private$steps_for_training$create_chunks_for_training(self)
n_chunks <- self$temp$tokenized_dataset$num_rows
print_message(paste(n_chunks, "Chunks Created"), trace)
# Setting up DataCollator and Dataset ----------------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 5, total_top = total,
message_top = paste(private$title, "Overall: Seeting up DataCollator and Dataset"),
last_log = last_log, write_interval = write_interval
)
create_dir(output_dir, trace, "Creating Output Directory")
create_dir(paste0(output_dir, "/checkpoints"), trace, "Creating Checkpoint Directory")
private$steps_for_training$prepare_train_tune(self)
# Start Training -------------------------------------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 6, total_top = total, message_top = paste(private$title, "Overall: Start Training"),
last_log = last_log, write_interval = write_interval
)
print_message("Start Fine Tuning", trace)
private$steps_for_training$start_training(self)
# Saving Model --------------------------------------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 7, total_top = total, message_top = paste(private$title, "Overall: Saving Model"),
last_log = last_log, write_interval = write_interval
)
print_message(paste("Saving", private$title), trace)
private$steps_for_training$save_model(self)
private$save_tokenizer_statistics("train")
# Saving Tokenizer -----------------------------------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 8, total_top = total, message_top = paste(private$title, "Overall: Saving Tokenizer"),
last_log = last_log, write_interval = write_interval
)
print_message("Saving Tokenizer", trace)
self$temp$tokenizer$save_pretrained(output_dir)
# Stop Sustainability Tracking if requested ----------------------------------
last_log <- py$write_log_py(
self$temp$log_file,
value_top = 9, total_top = total, message_top = paste(private$title, "Overall: Stopping"),
last_log = last_log, write_interval = write_interval
)
if (sustain_track) {
private$sustainability_tracker$stop()
print_message("Saving Sustainability Data", trace)
private$save_sustainability_data("train")
}
# Finish --------------------------------------------------------------------
py$write_log_py(
self$temp$log_file,
value_top = 10, total_top = total, message_top = paste(private$title, "Overall: Done"),
last_log = NULL, write_interval = write_interval
)
print_message("Done", trace)
# Clear variables -----------------------------------------------------------
private$clear_variables()
}, finally = {
if (!is.null(private$sustainability_tracker)) {
private$sustainability_tracker$stop()
}
})
}
)
)