tests/testthat/test-04_DataManager.R

testthat::skip_on_cran()
testthat::skip_if_not(
  condition = check_aif_py_modules(trace = FALSE),
  message = "Necessary python modules not available"
)

# SetUp-------------------------------------------------------------------------
root_path_general_data<-testthat::test_path("test_data_tmp/Embeddings")
root_path_data <- testthat::test_path("test_data/classifier")
# if(dir.exists(testthat::test_path("test_artefacts"))==FALSE){
#  dir.create(testthat::test_path("test_artefacts"))
# }
# root_path_results=testthat::test_path("test_artefacts/DataManager")
# if(dir.exists(root_path_results)==FALSE){
#  dir.create(root_path_results)
# }

# SetUp datasets
# Disable tqdm progressbar
transformers$logging$disable_progress_bar()
datasets$disable_progress_bars()

# Load test data
imdb_embeddings=load_from_disk(paste0(root_path_general_data, "/imdb_embeddings"))
current_embeddings <- imdb_embeddings$clone(deep = TRUE)
example_data <- imdb_movie_reviews

n_classes <- 2

example_data$label <- as.character(example_data$label)
example_data$label[c(201:300)] <- NA
if (n_classes > 2) {
  example_data$label[c(201:250)] <- "medium"
}
example_targets <- as.factor(example_data$label)
names(example_targets) <- example_data$id

table(example_targets)
data_targets <- example_targets
data_embeddings <- current_embeddings

# config test
folds <- c(2, 5)
methods <- c("dbsmote", "smote")
datasets$disable_progress_bars()

# Start Tests-------------------------------------------------------------------
for (method in methods) {
  for (fold in folds) {
    test_datamanager <- DataManagerClassifier$new(
      data_embeddings = data_embeddings,
      data_targets = data_targets,
      folds = fold,
      val_size = 0.25,
      class_levels = levels(data_targets),
      one_hot_encoding = TRUE,
      add_matrix_map = TRUE,
      sc_methods = method,
      sc_min_k = 1,
      sc_max_k = 10,
      trace = FALSE,
      n_cores=2
    )

    for (i in 1:(test_datamanager$get_n_folds() + 1)) {
      sample <- test_datamanager$get_samples()[[i]]
      #-----------------------------------------------------------------------------
      test_that(paste("DataManager - Valid Splits", "Fold:", i), {
        # Test if no case is missing
        expect_equal(length(sample$train) + length(sample$val) + length(sample$test), length(na.omit(example_targets)))

        # Test if the splits are disjunctive
        expect_equal(length(intersect(sample$train, sample$val)), 0)
        expect_equal(length(intersect(sample$train, sample$test)), 0)
        expect_equal(length(intersect(sample$val, sample$test)), 0)
        gc()

        # Test if every class is part of a split
        expect_true(length(table(test_datamanager$datasets$data_labeled[sample$train]["labels"])) == n_classes)
        expect_true(length(table(test_datamanager$datasets$data_labeled[sample$val]["labels"])) == n_classes)
        if (i <= test_datamanager$get_n_folds()) {
          expect_true(length(table(test_datamanager$datasets$data_labeled[sample$test]["labels"])) == n_classes)
        }
        gc()

        # Test if the splits have the minimal absolute frequency
        expect_true(min(table(test_datamanager$datasets$data_labeled[sample$train]["labels"])) > 2)
        expect_true(min(table(test_datamanager$datasets$data_labeled[sample$val]["labels"])) > 1)
        if (i <= test_datamanager$get_n_folds()) {
          expect_true(min(table(test_datamanager$datasets$data_labeled[sample$test]["labels"])) > 1)
        }
        gc()

        # Test if the ratio of the labels is correct (stratified sample)
        expect_identical(
          ignore_attr = TRUE,
          table(test_datamanager$datasets$data_labeled[sample$train]["labels"]) /
            sum(table(test_datamanager$datasets$data_labeled[sample$train]["labels"])),
          table(example_targets) / sum(table(example_targets))
        )
        gc()
        expect_identical(
          ignore_attr = TRUE,
          table(test_datamanager$datasets$data_labeled[sample$val]["labels"]) /
            sum(table(test_datamanager$datasets$data_labeled[sample$val]["labels"])),
          table(example_targets) / sum(table(example_targets))
        )
        gc()
        if (i <= test_datamanager$get_n_folds()) {
          expect_identical(
            ignore_attr = TRUE,
            table(test_datamanager$datasets$data_labeled[sample$test]["labels"]) /
              sum(table(test_datamanager$datasets$data_labeled[sample$test]["labels"])),
            table(example_targets) / sum(table(example_targets))
          )
        }
        gc()
      })

      #----------------------------------------------------------------------------
      test_that(paste("DataManager - Synthetic Cases", method, "Fold:", i), {
        test_datamanager$set_state(iteration = i, step = NULL)
        test_datamanager$create_synthetic(trace = FALSE, inc_pseudo_data = FALSE)
        if (!is.null(test_datamanager$datasets$data_labeled_synthetic)) {
          synthetic_cases_per_seq <- table(
            test_datamanager$datasets$data_labeled_synthetic["length"],
            test_datamanager$datasets$data_labeled_synthetic["labels"]
          )
          original_cases_per_seq <- table(
            test_datamanager$get_dataset()["length"],
            test_datamanager$get_dataset()["labels"]
          )
          for (r in intersect(rownames(original_cases_per_seq), rownames(synthetic_cases_per_seq))) {
            for (c in intersect(colnames(original_cases_per_seq), colnames(synthetic_cases_per_seq))) {
              if (original_cases_per_seq[r, c] > 3) {
                expect_equal(
                  object = original_cases_per_seq[r, c] + synthetic_cases_per_seq[r, c],
                  expected = max(original_cases_per_seq[r, ]),
                  tolerance = 1
                )
              }
            }
          }
        }
      })
      gc()
      #----------------------------------------------------------------------------
      test_that(paste("DataManager - Pseudo Data", "Fold:", i), {
        test_datamanager$add_replace_pseudo_data(
          inputs = data_embeddings$embeddings[1:10, , , drop = FALSE],
          labels = example_targets[1:10]
        )

        expect_equal(
          object = length(test_datamanager$datasets$data_labeled_pseudo),
          expected = 10
        )
      })
      gc()
      #----------------------------------------------------------------------------
      test_that(paste("DataManager - get_dataset()", "Fold:", i), {
        data_test <- test_datamanager$get_dataset(
          inc_labeled = TRUE,
          inc_unlabeled = FALSE,
          inc_synthetic = TRUE,
          inc_pseudo_data = FALSE
        )
        number_of_cases <- sum(table(data_test["length"]))
        true_number_of_cases <- length(test_datamanager$samples[[i]]$train) +
          length(test_datamanager$datasets$data_labeled_synthetic)
        expect_equal(number_of_cases, true_number_of_cases)

        data_test <- test_datamanager$get_dataset(
          inc_labeled = TRUE,
          inc_unlabeled = FALSE,
          inc_synthetic = TRUE,
          inc_pseudo_data = TRUE
        )
        number_of_cases <- sum(table(data_test["length"]))
        true_number_of_cases <- length(test_datamanager$samples[[i]]$train) +
          length(test_datamanager$datasets$data_labeled_synthetic) +
          length(test_datamanager$datasets$data_labeled_pseudo)
        expect_equal(number_of_cases, true_number_of_cases)

        data_test <- test_datamanager$get_dataset(
          inc_labeled = FALSE,
          inc_unlabeled = FALSE,
          inc_synthetic = TRUE,
          inc_pseudo_data = TRUE
        )
        if (!is.null(data_test)) {
          number_of_cases <- sum(table(data_test["length"]))
          true_number_of_cases <- length(test_datamanager$datasets$data_labeled_synthetic) +
            length(test_datamanager$datasets$data_labeled_pseudo)
          expect_equal(number_of_cases, true_number_of_cases)

          data_test <- test_datamanager$get_dataset(
            inc_labeled = FALSE,
            inc_unlabeled = FALSE,
            inc_synthetic = TRUE,
            inc_pseudo_data = FALSE
          )
          number_of_cases <- sum(table(data_test["length"]))
          true_number_of_cases <- length(test_datamanager$datasets$data_labeled_synthetic)
          expect_equal(number_of_cases, true_number_of_cases)
        } else {
          expect_equal(
            object = test_datamanager$datasets$data_labeled_synthetic,
            expected = NULL
          )
        }


        data_test <- test_datamanager$get_dataset(
          inc_labeled = FALSE,
          inc_unlabeled = FALSE,
          inc_synthetic = FALSE,
          inc_pseudo_data = TRUE
        )
        number_of_cases <- sum(table(data_test["length"]))
        true_number_of_cases <- length(test_datamanager$datasets$data_labeled_pseudo)
        expect_equal(number_of_cases, true_number_of_cases)

        data_test <- test_datamanager$get_dataset(
          inc_labeled = FALSE,
          inc_unlabeled = TRUE,
          inc_synthetic = FALSE,
          inc_pseudo_data = FALSE
        )
        number_of_cases <- sum(table(data_test["length"]))
        true_number_of_cases <- length(test_datamanager$datasets$data_unlabeled)
        expect_equal(number_of_cases, true_number_of_cases)
      })
      gc()
    }
  }
}

Try the aifeducation package in your browser

Any scripts or data that you put into this service are public.

aifeducation documentation built on April 4, 2025, 2:01 a.m.