# R/obj_BaseModelMPNet.R

# This file is part of the R package "aifeducation".
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as published by
# the Free Software Foundation.
#
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>

#' @title MPNet
#' @description Represents models based on MPNet.
#' @references Song, K., Tan, X., Qin, T., Lu, J. & Liu, T.-Y. (2020). MPNet: Masked and Permuted Pre-training for
#'   Language Understanding. \doi{10.48550/arXiv.2004.09297}
#' @return `r get_description("return_object")`
#' @family Base Model
#' @export
BaseModelMPNet <- R6::R6Class(
  classname = "BaseModelMPNet",
  inherit = BaseModelCore,
  private = list(
    # Identifier of the model architecture handled by this class.
    model_type = "mpnet",
    # NOTE(review): presumably the number of positions reserved (e.g. for
    # special tokens) when deriving the usable sequence length -- confirm in
    # BaseModelCore.
    adjust_max_sequence_length = 2L,
    # Token type ids are not requested for this model type.
    return_token_type_ids = FALSE,
    #--------------------------------------------------------------------------
    # Builds a new, untrained MPNet model from the user configuration in
    # `args` and stores it in `private$model`. The model is moved to "cuda"
    # when torch reports CUDA as available, otherwise to "cpu".
    create_model = function(args) {
      configuration <- transformers$MPNetConfig(
        # The vocabulary size is the number of entries in the tokenizer's
        # vocabulary. Alternative that additionally counts the special tokens
        # (kept for reference):
        # vocab_size = as.integer(length(args$tokenizer$get_tokenizer()$get_vocab())+length(unique(args$tokenizer$get_tokenizer()$special_tokens_map))),
        vocab_size = as.integer(length(args$tokenizer$get_tokenizer()$get_vocab())),
        hidden_size = as.integer(args$hidden_size),
        num_hidden_layers = as.integer(args$num_hidden_layers),
        num_attention_heads = as.integer(args$num_attention_heads),
        intermediate_size = as.integer(args$intermediate_size),
        # The Python side expects lower-case activation names (e.g. "gelu").
        hidden_act = tolower(args$hidden_act),
        hidden_dropout_prob = args$hidden_dropout_prob,
        attention_probs_dropout_prob = args$attention_probs_dropout_prob,
        max_position_embeddings = as.integer(args$max_position_embeddings),
        # Fixed values; not exposed through configure().
        initializer_range = 0.02,
        layer_norm_eps = 1e-12
      )

      # Presumably makes MPNetForMPLM_PT available via `py`.
      run_py_file("MPNetForMPLM_PT.py")
      # A plain if/else is used because the condition is a single value.
      device <- if (torch$cuda$is_available()) "cuda" else "cpu"
      private$model <- py$MPNetForMPLM_PT(configuration)$to(device)
    },
    #--------------------------------------------------------------------------
    # Creates the data collator used during training and returns its
    # `collate_batch` function. The masking and permutation settings are read
    # from `self$last_training$config` (p_mask, p_perm, whole_word).
    create_data_collator = function() {
      # Presumably makes CollatorMaker_PT available via `py`.
      run_py_file("DataCollatorForMPLM_PT.py")
      collator_maker <- py$CollatorMaker_PT(
        tokenizer = self$Tokenizer$get_tokenizer(),
        mlm = TRUE,
        mlm_probability = self$last_training$config$p_mask,
        plm_probability = self$last_training$config$p_perm,
        mask_whole_words = self$last_training$config$whole_word
      )
      collator_maker$collator$collate_batch
    },
    #--------------------------------------------------------------------------
    # Loads a previously saved MPNet model from `dir_path` (PyTorch weights in
    # safetensors format) and stores it in `private$model`.
    # NOTE(review): unlike create_model(), the loaded model is not moved to a
    # device here; presumably handled elsewhere -- confirm.
    load_BaseModel = function(dir_path) {
      run_py_file("MPNetForMPLM_PT.py")
      private$model <- py$MPNetForMPLM_PT$from_pretrained(
        dir_path,
        from_tf = FALSE,
        use_safetensors = TRUE
      )
    },
    #---------------------------------------------------------------------------
    # Validates combinations of configuration arguments that cannot be
    # checked individually. Stops with an error if hidden_size is not a
    # multiple of num_attention_heads.
    check_arg_combinations = function(args) {
      # isTRUE() also rejects NA/NaN results (e.g. num_attention_heads = 0),
      # which would otherwise make `if` fail with an unhelpful message.
      if (!isTRUE(args$hidden_size %% args$num_attention_heads == 0L)) {
        stop(
          "hidden_size must be a multiple of num_attention_heads.",
          call. = FALSE
        )
      }
    }
  ),
  public = list(
    #---------------------------------------------------------------------------
    #' @description Configures a new object of this class.
    #' Please ensure that your chosen configuration complies with the following
    #' guidelines:
    #' * hidden_size is a multiple of num_attention_heads.
    #'
    #' @param tokenizer `r get_param_doc_desc("tokenizer")`
    #' @param max_position_embeddings `r get_param_doc_desc("max_position_embeddings")`
    #' @param hidden_size `r get_param_doc_desc("hidden_size")`
    #' @param num_hidden_layers `r get_param_doc_desc("num_hidden_layers")`
    #' @param num_attention_heads `r get_param_doc_desc("num_attention_heads")`
    #' @param intermediate_size `r get_param_doc_desc("intermediate_size")`
    #' @param hidden_act `r get_param_doc_desc("hidden_act")`
    #' @param hidden_dropout_prob `r get_param_doc_desc("hidden_dropout_prob")`
    #' @param attention_probs_dropout_prob `r get_param_doc_desc("attention_probs_dropout_prob")`
    #' @return `r get_description("return_nothing")`
    configure = function(tokenizer,
                         max_position_embeddings = 512L,
                         hidden_size = 768L,
                         num_hidden_layers = 12L,
                         num_attention_heads = 12L,
                         intermediate_size = 3072L,
                         hidden_act = "GELU",
                         hidden_dropout_prob = 0.1,
                         attention_probs_dropout_prob = 0.1) {
      # Collect the arguments of this call and hand them to the shared
      # configuration routine of the parent class.
      arguments <- get_called_args(n = 1L)
      private$do_configuration(args = arguments)
    },
    #--------------------------------------------------------------------------
    #' @description Trains a BaseModel.
    #' @param text_dataset `r get_param_doc_desc("text_dataset")`
    #' @param p_mask `r get_param_doc_desc("p_mask")`
    #' @param p_perm `r get_param_doc_desc("p_perm")`
    #' @param whole_word `r get_param_doc_desc("whole_word")`
    #' @param val_size `r get_param_doc_desc("val_size")`
    #' @param n_epoch `r get_param_doc_desc("n_epoch")`
    #' @param batch_size `r get_param_doc_desc("batch_size")`
    #' @param max_sequence_length `r get_param_doc_desc("max_sequence_length")`
    #' @param full_sequences_only `r get_param_doc_desc("full_sequences_only")`
    #' @param min_seq_len `r get_param_doc_desc("min_seq_len")`
    #' @param learning_rate `r get_param_doc_desc("learning_rate")`
    #' @param sustain_track `r get_param_doc_desc("sustain_track")`
    #' @param sustain_iso_code `r get_param_doc_desc("sustain_iso_code")`
    #' @param sustain_region `r get_param_doc_desc("sustain_region")`
    #' @param sustain_interval `r get_param_doc_desc("sustain_interval")`
    #' @param sustain_log_level `r get_param_doc_desc("sustain_log_level")`
    #' @param trace `r get_param_doc_desc("trace")`
    #' @param pytorch_trace `r get_param_doc_desc("pytorch_trace")`
    #' @param log_dir `r get_param_doc_desc("log_dir")`
    #' @param log_write_interval `r get_param_doc_desc("log_write_interval")`
    #' @return `r get_description("return_nothing")`
    train = function(text_dataset,
                     p_mask = 0.15,
                     p_perm = 0.15,
                     whole_word = TRUE,
                     val_size = 0.1,
                     n_epoch = 1L,
                     batch_size = 12L,
                     max_sequence_length = 250L,
                     full_sequences_only = FALSE,
                     min_seq_len = 50L,
                     learning_rate = 3e-3,
                     sustain_track = FALSE,
                     sustain_iso_code = NULL,
                     sustain_region = NULL,
                     sustain_interval = 15L,
                     sustain_log_level = "warning",
                     trace = TRUE,
                     pytorch_trace = 1L,
                     log_dir = NULL,
                     log_write_interval = 2L) {
      # NOTE(review): the MPLM collator itself is sourced from
      # DataCollatorForMPLM_PT.py in create_data_collator(); confirm that
      # data_collator.py is still needed here.
      run_py_file("data_collator.py")
      private$do_training(args = get_called_args(n = 1L))
    }
  )
)

# Register this class in BaseModelsIndex (the user-facing list of base
# models) under the key "MPNet", storing the class name as a string.
BaseModelsIndex[["MPNet"]] <- "BaseModelMPNet"

# Try the aifeducation package in your browser
#
# Any scripts or data that you put into this service are public.
#
# aifeducation documentation built on Nov. 19, 2025, 5:08 p.m.