| BaseModelCore | R Documentation |
This class contains all methods shared by all BaseModels.
Returns a new object of this class.
aifeducation::AIFEMaster -> aifeducation::AIFEBaseModel -> BaseModelCore
Tokenizer ('TokenizerBase')
Objects of class TokenizerBase.
Inherited methods:
aifeducation::AIFEMaster$get_all_fields()
aifeducation::AIFEMaster$get_documentation_license()
aifeducation::AIFEMaster$get_ml_framework()
aifeducation::AIFEMaster$get_model_config()
aifeducation::AIFEMaster$get_model_description()
aifeducation::AIFEMaster$get_model_info()
aifeducation::AIFEMaster$get_model_license()
aifeducation::AIFEMaster$get_package_versions()
aifeducation::AIFEMaster$get_private()
aifeducation::AIFEMaster$get_publication_info()
aifeducation::AIFEMaster$get_sustainability_data()
aifeducation::AIFEMaster$is_configured()
aifeducation::AIFEMaster$is_trained()
aifeducation::AIFEMaster$set_documentation_license()
aifeducation::AIFEMaster$set_model_description()
aifeducation::AIFEMaster$set_model_license()
create_from_hf(): Creates a BaseModel from a pretrained model.
BaseModelCore$create_from_hf(model_dir = NULL, tokenizer_dir = NULL)
model_dir string Path to the directory where the model is stored.
tokenizer_dir string Path to the directory where the tokenizer is saved. Allowed values: any
Returns a new object of this class.
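A minimal usage sketch. The directory paths are placeholders for local copies of a pretrained model and its tokenizer, and the call to the R6 constructor `$new()` is an assumption, since the constructor is not documented on this page:

```r
library(aifeducation)

# Create an empty object of the class (assumed R6 constructor)
base_model <- BaseModelCore$new()

# Build the BaseModel from a pretrained model; both paths are
# hypothetical and must point to your local directories
base_model$create_from_hf(
  model_dir = "models/my_pretrained_model",
  tokenizer_dir = "tokenizers/my_pretrained_tokenizer"
)
```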
train(): Trains a BaseModel.
BaseModelCore$train(
  text_dataset,
  p_mask = 0.15,
  whole_word = TRUE,
  val_size = 0.1,
  n_epoch = 1L,
  batch_size = 12L,
  max_sequence_length = 250L,
  full_sequences_only = FALSE,
  min_seq_len = 50L,
  learning_rate = 0.003,
  sustain_track = FALSE,
  sustain_iso_code = NULL,
  sustain_region = NULL,
  sustain_interval = 15L,
  sustain_log_level = "warning",
  trace = TRUE,
  pytorch_trace = 1L,
  log_dir = NULL,
  log_write_interval = 2L
)
text_dataset LargeDataSetForText Object storing textual data.
p_mask double Ratio that determines the number of tokens used for masking. Allowed values: 0.05 <= x <= 0.6
whole_word bool
* TRUE: whole word masking is applied. Only relevant if a WordPieceTokenizer is used.
* FALSE: token masking is used.
val_size double Proportion of cases (between 0 and 1) used for the validation sample during the estimation of the model. The remaining cases are part of the training data. Allowed values: 0 < x < 1
n_epoch int Number of training epochs. Allowed values: 1 <= x
batch_size int Size of the batches for training. Allowed values: 1 <= x
max_sequence_length int Maximum number of tokens for every sequence. Allowed values: 20 <= x
full_sequences_only bool If TRUE, only chunks with a sequence length equal to chunk_size are used.
min_seq_len int Only relevant if full_sequences_only = FALSE. Determines the minimum sequence length included in the training process. Allowed values: 10 <= x
learning_rate double Initial learning rate for the training. Allowed values: 0 < x <= 1
sustain_track bool If TRUE, energy consumption is tracked during training via the python library 'codecarbon'.
sustain_iso_code string ISO code (Alpha-3-Code) for the country. This variable must be set if sustainability is to be tracked. A list can be found on Wikipedia: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes. Allowed values: any
sustain_region string Region within a country. Only available for the USA and Canada. See the documentation of 'codecarbon' for more information: https://mlco2.github.io/codecarbon/parameters.html Allowed values: any
sustain_interval int Interval in seconds for measuring power usage. Allowed values: 1 <= x
sustain_log_level string Level for printing information to the console. Allowed values: 'debug', 'info', 'warning', 'error', 'critical'
trace bool If TRUE, information about the estimation phase is printed to the console.
pytorch_trace int pytorch_trace = 0 suppresses all information about the training process from pytorch on the console. Allowed values: 0 <= x <= 1
log_dir string Path to the directory where the log files should be saved. If no logging is desired, set this argument to NULL. Allowed values: any
log_write_interval int Time in seconds determining the interval at which the logger tries to update the log files. Only relevant if log_dir is not NULL. Allowed values: 1 <= x
Returns nothing.
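A hedged sketch of a small training run using the signature above. The object `text_data` stands in for a LargeDataSetForText object you have already created; all argument values are illustrative, not recommendations:

```r
# 'text_data' is assumed to be a LargeDataSetForText object
# holding the raw texts for training
base_model$train(
  text_dataset = text_data,
  p_mask = 0.15,             # mask ~15% of the tokens
  whole_word = TRUE,         # whole word masking (WordPieceTokenizer only)
  val_size = 0.1,            # 10% of cases for validation
  n_epoch = 2L,
  batch_size = 12L,
  max_sequence_length = 250L,
  sustain_track = TRUE,      # track energy consumption via codecarbon
  sustain_iso_code = "DEU",  # Alpha-3 country code, here Germany
  log_dir = NULL             # disable file logging
)
```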
count_parameter(): Method for counting the trainable parameters of a model.
BaseModelCore$count_parameter()
Returns the number of trainable parameters of the model.
plot_training_history(): Method for requesting a plot of the training history. This method requires the R package 'ggplot2' to work.
BaseModelCore$plot_training_history(
  x_min = NULL,
  x_max = NULL,
  y_min = NULL,
  y_max = NULL,
  ind_best_model = TRUE,
  text_size = 10L
)
x_min int Minimal value for the x-axis. Set to NULL for an automatic adjustment.
x_max int Maximal value for the x-axis. Set to NULL for an automatic adjustment.
y_min int Minimal value for the y-axis. Set to NULL for an automatic adjustment.
y_max int Maximal value for the y-axis. Set to NULL for an automatic adjustment.
ind_best_model bool If TRUE, the plot indicates the best states of the model according to the chosen measure.
text_size int Size of text elements. Allowed values: 1 <= x
Returns a plot of class ggplot visualizing the training process.
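Since the method returns a ggplot object, the plot can be stored and customized before printing. A brief sketch, assuming a trained `base_model` and an installed 'ggplot2':

```r
# Request the training history plot; the returned ggplot
# object is only rendered when printed
p <- base_model$plot_training_history(
  ind_best_model = TRUE,  # mark the best model states
  text_size = 12L
)
print(p)
```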
get_special_tokens(): Method for retrieving the special tokens of the model.
BaseModelCore$get_special_tokens()
Returns a matrix containing the special tokens in the rows
and their type, token, and id in the columns.
get_tokenizer_statistics(): Tokenizer statistics.
BaseModelCore$get_tokenizer_statistics()
Returns a data.frame containing the tokenizer's statistics.
fill_mask(): Method for predicting the tokens behind mask tokens.
BaseModelCore$fill_mask(masked_text, n_solutions = 5L)
masked_text string Text with mask tokens. Allowed values: any
n_solutions int Number of solutions the model should predict. Allowed values: 1 <= x
Returns a list containing a data.frame for every
mask. The data.frame contains the solutions in the rows and reports
the score, token id, and token string in the columns.
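A short sketch of a fill-mask call. The mask token string depends on the tokenizer; "[MASK]" is assumed here as the WordPiece/BERT-style convention and may differ for your model:

```r
# Predict candidate tokens for each mask in the text
solutions <- base_model$fill_mask(
  masked_text = "The capital of France is [MASK].",
  n_solutions = 5L
)

# One data.frame per mask token; inspect the candidates
# (score, token id, token string) for the first mask
solutions[[1]]
```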
save(): Method for saving a model on disk.
BaseModelCore$save(dir_path, folder_name)
dir_path string Path to the directory where the object should be saved.
folder_name string Name of the folder where the model should be saved. Allowed values: any
Returns nothing. It is used to save an object on disk.
load_from_disk(): Loads an object from disk and updates the object to the current version of the package.
BaseModelCore$load_from_disk(dir_path)
dir_path string Path to the directory where the object is stored.
Returns nothing. It loads an object from disk.
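A hedged save/restore sketch. The paths are placeholders, and the assumption that save() writes into dir_path/folder_name (so that load_from_disk() receives the combined path) is inferred from the parameter descriptions, not stated on this page:

```r
# Persist the model: creates the folder "my_base_model"
# inside "saved_models" (assumed layout)
base_model$save(
  dir_path = "saved_models",
  folder_name = "my_base_model"
)

# Later session: restore the object from disk
restored <- BaseModelCore$new()
restored$load_from_disk(dir_path = "saved_models/my_base_model")
```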
get_model(): Get the 'PyTorch' model.
BaseModelCore$get_model()
Returns the underlying 'PyTorch' model.
get_model_type(): Type of the underlying model.
BaseModelCore$get_model_type()
Returns a string describing the model's architecture.
get_final_size(): Size of the final layer.
BaseModelCore$get_final_size()
Returns an int describing the number of dimensions of the last
hidden layer.
get_n_layers(): Number of layers.
BaseModelCore$get_n_layers()
Returns an int describing the number of layers available for
embedding.
get_flops_estimates(): FLOP estimates.
BaseModelCore$get_flops_estimates()
Returns a data.frame containing statistics about the FLOPs.
set_publication_info(): Method for setting the bibliographic information of the model.
BaseModelCore$set_publication_info(type, authors, citation, url = NULL)
type string Type of information which should be changed/added. Possible values are developer and modifier.
authors List of people.
citation string Citation in free text.
url string Corresponding URL if applicable.
Function does not return a value. It is used to set the private members for publication information of the model.
estimate_sustainability_inference_fill_mask(): Calculates the energy consumption for inference of the given task.
BaseModelCore$estimate_sustainability_inference_fill_mask(
  text_dataset = NULL,
  n_samples = NULL,
  sustain_iso_code = NULL,
  sustain_region = NULL,
  sustain_interval = 15L,
  sustain_log_level = "warning",
  trace = TRUE
)
text_dataset LargeDataSetForText Object storing textual data.
n_samples int Number of samples. Allowed values: 1 <= x
sustain_iso_code string ISO code (Alpha-3-Code) for the country. This variable must be set if sustainability is to be tracked. A list can be found on Wikipedia: https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes. Allowed values: any
sustain_region string Region within a country. Only available for the USA and Canada. See the documentation of 'codecarbon' for more information: https://mlco2.github.io/codecarbon/parameters.html Allowed values: any
sustain_interval int Interval in seconds for measuring power usage. Allowed values: 1 <= x
sustain_log_level string Level for printing information to the console. Allowed values: 'debug', 'info', 'warning', 'error', 'critical'
trace bool If TRUE, information about the estimation phase is printed to the console.
Returns nothing. The method saves the statistics internally.
The statistics can be accessed with the method get_sustainability_data("inference").
calc_flops_architecture_based(): Calculates FLOPs based on the model's architecture.
BaseModelCore$calc_flops_architecture_based(batch_size, n_batches, n_epoch)
batch_size int Size of the batches for training. Allowed values: 1 <= x
n_batches int Number of batches. Allowed values: 1 <= x
n_epoch int Number of training epochs. Allowed values: 1 <= x
Returns a data.frame storing the estimates.
clone(): The objects of this class are cloneable with this method.
BaseModelCore$clone(deep = FALSE)
deep Whether to make a deep clone.
Other R6 Classes for Developers:
AIFEBaseModel,
AIFEMaster,
ClassifiersBasedOnTextEmbeddings,
DataManagerClassifier,
LargeDataSetBase,
ModelsBasedOnTextEmbeddings,
TEClassifiersBasedOnProtoNet,
TEClassifiersBasedOnRegular,
TokenizerBase