# tests/testthat/test_entity_embedding.R

# Checks that make_embedding() builds one keras input per factor-like feature
# (plus one shared input when the task also has other features) for every
# task registered in mlr3::mlr_tasks, and that inputs/layers are tensors.
test_that("entity embedding works for all tasks", {
  skip_on_os("solaris")
  # Feature types treated as categorical by this test.
  fct_types = c("character", "factor", "ordered")
  for (k in mlr3::mlr_tasks$keys()) {
    task = mlr3::mlr_tasks$get(k)
    embds = make_embedding(task)
    # `info = k` reports which task failed, since all tasks share one loop.
    expect_list(embds, len = 2L, names = "named", info = k)
    n_fct = nrow(task$feature_types[type %in% fct_types, ])
    has_nonfactors =
      nrow(task$feature_types[!(type %in% fct_types), ]) > 0
    # Expected: one input per categorical feature, plus one extra input
    # when any non-categorical features are present.
    expect_equal(length(embds$inputs), n_fct + has_nonfactors, info = k)
    expect_class(embds$layers, "tensorflow.tensor", info = k)
    map(embds$inputs, expect_class, classes = "tensorflow.tensor", info = k)
  }
  k_clear_session()
})

# Checks that reshape_task_embedding() returns one set of factor levels per
# factor-like feature, and that `data` has exactly one more element than
# `fct_levels` when the task also has non-categorical features.
# (Previously this test reused the exact description of the make_embedding()
# test, making the two indistinguishable in test reports.)
test_that("reshape_task_embedding works for all tasks", {
  # Feature types treated as categorical by this test.
  fct_types = c("character", "factor", "ordered")
  for (k in mlr3::mlr_tasks$keys()) {
    task = mlr3::mlr_tasks$get(k)
    embds = reshape_task_embedding(task)
    # `info = k` reports which task failed, since all tasks share one loop.
    expect_list(embds, len = 2L, names = "named", info = k)
    n_fct = nrow(task$feature_types[type %in% fct_types, ])
    has_nonfactors =
      nrow(task$feature_types[!(type %in% fct_types), ]) > 0
    expect_equal(length(embds$fct_levels), n_fct, info = k)
    # `data` holds the factor entries plus one extra entry when any
    # non-categorical features are present.
    expect_equal(
      length(embds$data),
      length(embds$fct_levels) + has_nonfactors,
      info = k
    )
  }
})
# mlr-org/mlr3keras documentation built on April 12, 2022, 11:35 a.m.