# tests/testthat/test-02_07_BaseModels_from_hf.R

testthat::skip_on_cran()

testthat::skip_if_not(
  condition = check_aif_py_modules(trace = FALSE),
  message = "Necessary python modules not available"
)

# Start time
test_time_start <- Sys.time()

# Config transformer library
transformers$utils$logging$set_verbosity_error()
os$environ$setdefault("TOKENIZERS_PARALLELISM", "false")

# Disable tqdm progressbar
transformers$logging$disable_progress_bar()
datasets$disable_progress_bars()
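# Note: the settings above only silence console output from the Python side
# (transformers, tokenizers, datasets); they do not change training behaviour.
# TOKENIZERS_PARALLELISM = "false" also prevents the fork-related warning from
# the Hugging Face tokenizers backend.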

# Path Management
test_art_path <- testthat::test_path("test_artefacts")
test_art_tmp_path <- testthat::test_path("test_artefacts/base_models")
create_dir(test_art_path, FALSE)
create_dir(test_art_tmp_path, FALSE)

test_tmp_data_path <- testthat::test_path("test_data_tmp")
create_dir(test_tmp_data_path, FALSE)
test_tmp_data_base_model_path <- paste0(test_tmp_data_path, "/", "TEM")

create_dir(test_tmp_data_base_model_path, FALSE)

# Test Configuration
object_class_names <- BaseModelsIndex
# object_class_names <- c(
#  "BaseModelBert",
#  #"BaseModelFunnel",
#  #"BaseModelLongformer",
#  "BaseModelModernBert",
#  #"BaseModelRoberta",
#  "BaseModelMPNet"
#  )
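
# BaseModelsIndex is assumed to be a character vector of base model class names
# (the commented-out vector above shows the expected form); each name is passed
# to create_object() below to instantiate the corresponding base model object.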

for (object_class_name in object_class_names) {
  # Data Management
  example_data <- imdb_movie_reviews
  raw_texts <- LargeDataSetForText$new(example_data)

  raw_texts_training <- LargeDataSetForText$new(example_data[1:50, ])

  path_to_existing_base_model <- paste0(test_tmp_data_base_model_path, "/", object_class_name)

  train_args <- generate_args_for_tests(
    object_name = object_class_name,
    method = "train",
    var_objects = list(),
    necessary_objects = list(
      text_dataset = raw_texts_training
    )
  )
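  # generate_args_for_tests() is assumed to assemble a named argument list for
  # the train() method of the current class, combining the required text_dataset
  # with generated values for its remaining parameters; get_current_args_for_print()
  # is used in the test descriptions below to show which combination was used.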

  # Create and train model
  base_model <- create_object(object_class_name)
  suppressMessages(
    do.call(
      what = base_model$create_from_hf,
      args = list(
        model_dir = path_to_existing_base_model,
        tokenizer_dir = paste0(path_to_existing_base_model, "/", "tokenizer")
      )
    )
  )
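  # create_from_hf() is pointed at a local directory that already contains a
  # Hugging Face model and, in its "tokenizer" subfolder, the matching tokenizer.
  # These directories under test_data_tmp/TEM/<class name> are assumed to have
  # been prepared by an earlier setup step; they are not created in this file.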

  suppressMessages(
    do.call(
      what = base_model$train,
      args = train_args
    )
  )

  # Prepare directory
  dir_path_new <- paste0(test_art_tmp_path, "/", generate_id(10))
  tmp_dir <- paste0(dir_path_new, "/", object_class_name)
  # Remove any leftovers before (re)creating the directory
  unlink(paste0(tmp_dir, "/", object_class_name), recursive = TRUE)
  create_dir(tmp_dir, trace = FALSE)

  #--------------------------------------------------------------------------
  test_that(paste(
    "Save Model",
    object_class_name,
    get_current_args_for_print(train_args)
  ), {
    expect_no_error(
      save_to_disk(
        object = base_model,
        dir_path = dir_path_new,
        folder_name = object_class_name
      )
    )
  })
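  # save_to_disk() presumably stores the model under dir_path/folder_name, i.e.
  # the tmp_dir path built above, so load_from_disk(tmp_dir) further down can
  # reload exactly this saved copy for the comparison test.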

  test_that(paste(
    "Sustainability Tracking",
    object_class_name,
    get_current_args_for_print(train_args)
  ), {
    if (isTRUE(train_args$sustain_track)) {
      expect_gte(nrow(base_model$get_sustainability_data()), 1)
    } else {
      expect_gte(nrow(base_model$get_sustainability_data()), 0)
    }
  })

  test_that(paste(
    "History Plot",
    object_class_name,
    get_current_args_for_print(train_args)
  ), {
    history <- base_model$last_training$history
    expect_equal(nrow(history), 2)
    expect_equal(ncol(history), 3)
    expect_true("epoch" %in% colnames(history))
    expect_true("loss" %in% colnames(history))
    expect_true("val_loss" %in% colnames(history))

    expect_s3_class(object = base_model$plot_training_history(y_min = NULL, y_max = NULL), class = "ggplot")
    expect_s3_class(object = base_model$plot_training_history(y_min = 0, y_max = NULL), class = "ggplot")
    expect_s3_class(object = base_model$plot_training_history(y_min = 0, y_max = 10), class = "ggplot")
    expect_s3_class(object = base_model$plot_training_history(y_min = NULL, y_max = 10), class = "ggplot")
  })
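  # The training history is expected to contain one row per epoch (two in the
  # generated test configuration) with the columns epoch, loss and val_loss, and
  # plot_training_history() should return a ggplot object for every combination
  # of fixed and automatic axis limits.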

  test_that(paste(
    "Fill-Mask",
    object_class_name,
    get_current_args_for_print(train_args)
  ), {
    tokens <- base_model$get_special_tokens()
    mask_token <- tokens[which(tokens[, 1] == "mask_token"), 2]

    first_solution <- base_model$fill_mask(
      masked_text = paste("This is a", mask_token, "."),
      n_solutions = 5
    )

    expect_equal(length(first_solution), 1)
    expect_true(is.data.frame(first_solution[[1]]))
    expect_equal(nrow(first_solution[[1]]), 5)
    expect_equal(ncol(first_solution[[1]]), 3)

    second_solution <- base_model$fill_mask(
      masked_text = paste("This is a", mask_token, "."),
      n_solutions = 1
    )
    expect_equal(length(second_solution), 1)
    expect_true(is.data.frame(second_solution[[1]]))
    expect_equal(nrow(second_solution[[1]]), 1)
    expect_equal(ncol(second_solution[[1]]), 3)

    third_solution <- base_model$fill_mask(
      masked_text = paste(
        "This is a", mask_token, ".",
        "The weather is", mask_token, "."
      ),
      n_solutions = 5
    )
    expect_equal(length(third_solution), 2)
    for (i in 1:2) {
      expect_true(is.data.frame(third_solution[[i]]))
      expect_equal(nrow(third_solution[[i]]), 5)
      expect_equal(ncol(third_solution[[i]]), 3)
    }
  })
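  # fill_mask() is expected to return a list with one data.frame per mask token
  # in the input text; each data.frame holds n_solutions rows and three columns
  # (the column names themselves are not asserted here).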

  test_that(paste(
    "Sustainaility Inference Fill-Mask",
    object_class_name,
    get_current_args_for_print(train_args)
  ), {
    n_repeat <- 2
    start_values <- nrow(base_model$get_sustainability_data("inference"))
    for (j in 1:n_repeat) {
      suppressMessages(
        base_model$estimate_sustainability_inference_fill_mask(
          text_dataset = raw_texts_training,
          n = 30,
          sustain_iso_code = "DEU",
          sustain_region = NULL,
          sustain_interval = 15,
          trace = train_args$trace,
          sustain_log_level = "error"
        )
      )
      expect_equal(nrow(base_model$get_sustainability_data("inference")), j + start_values)
    }
  })
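  # Each call to estimate_sustainability_inference_fill_mask() is expected to
  # append exactly one row to the "inference" sustainability log, which is what
  # the comparison of nrow() against start_values checks.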

  test_that(paste(
    "Flops Estimates",
    object_class_name,
    get_current_args_for_print(train_args)
  ), {
    expect_equal(nrow(base_model$get_flops_estimates()), 1)
    expect_gt(base_model$get_flops_estimates()$flops_bp_1, 0)
    expect_gt(base_model$get_flops_estimates()$flops_bp_2, 0)
    expect_gt(base_model$get_flops_estimates()$flops_bp_3, 0)
    expect_gt(base_model$get_flops_estimates()$flops_bp_4, 0)
  })

  #---------------------------------------------------------------------------
  # Re-Load Base Model and compare with the initial model
  base_model_reloaded <- load_from_disk(
    dir_path = tmp_dir
  )

  test_that(paste(
    "Saving and Loading",
    object_class_name,
    get_current_args_for_print(train_args)
  ), {
    expect_equal(
      base_model$count_parameter(),
      base_model_reloaded$count_parameter()
    )

    # Not available for models from hugging face
    # expect_equal(
    #  base_model$get_sustainability_data(),
    #  base_model_reloaded$get_sustainability_data()
    # )

    # Not available for models from hugging face
    # expect_equal(
    #  base_model$Tokenizer$get_tokenizer_statistics(),
    #  base_model_reloaded$Tokenizer$get_tokenizer_statistics()
    # )

    expect_equal(
      base_model$Tokenizer$get_sustainability_data(),
      base_model_reloaded$Tokenizer$get_sustainability_data()
    )
  })

  # Clear directory for next test
  unlink(tmp_dir, recursive = TRUE)
}

# Monitor test time
monitor_test_time_on_CI(
  start_time = test_time_start,
  test_name = "02_07_BaseModels_from_hf"
)
