# Nothing
test_that("transpose_metrics is not adding an unnamed entry on top of the list", {
  # Four records that all carry the same name, "loss".
  step_losses <- list(loss = 1, loss = 2, loss = 3, loss = 4)
  # Expected shape: one "loss" element holding the four values in order,
  # with no extra unnamed element in the list.
  collapsed <- list(loss = c(1, 2, 3, 4))

  # The call itself must not raise.
  expect_no_error(tabnet:::transpose_metrics(step_losses))
  # The result must match the single-entry list exactly.
  expect_equal(tabnet:::transpose_metrics(step_losses), collapsed)
})
test_that("Unsupervised training with default config, data.frame and formula", {
  # Structural checks shared by both interfaces below: the returned object is
  # a 3-element "tabnet_pretrain" list whose `fit` element has 5 named
  # components and exactly one metrics entry (one per epoch trained here).
  expect_pretrain_structure <- function(fit) {
    expect_s3_class(fit, "tabnet_pretrain")
    expect_length(fit, 3)
    expect_named(fit, c("fit", "serialized_net", "blueprint"))
    expect_length(fit$fit, 5)
    expect_named(
      fit$fit,
      c("network", "metrics", "config", "checkpoints", "importances")
    )
    expect_length(fit$fit$metrics, 1)
  }

  # data.frame interface; `x` and `y` are not defined in this file
  # (presumably provided by a test helper -- verify).
  expect_no_error(
    fit <- tabnet_pretrain(x, y, epochs = 1)
  )
  expect_pretrain_structure(fit)

  # formula interface; `ames` is likewise defined outside this file.
  expect_no_error(
    fit <- tabnet_pretrain(Sale_Price ~ ., data = ames, epochs = 1)
  )
  expect_pretrain_structure(fit)
})
test_that("Unsupervised training with pretraining_ratio", {
  # Pretraining must complete when `pretraining_ratio` is given explicitly.
  # `attrix` / `attriy` are not defined in this file.
  expect_no_error(
    pretrained <- tabnet_pretrain(
      attrix,
      attriy,
      epochs = 1,
      pretraining_ratio = 0.2
    )
  )
})
test_that("Unsupervised training prevent predict with an explicit message", {
  pretrained <- tabnet_pretrain(
    attrix,
    attriy,
    epochs = 1,
    pretraining_ratio = 0.2
  )

  # Calling predict() on a pretrain-only object must fail, and the error
  # message must mention "tabnet_pretrain" -- with an explicit `type` ...
  expect_error(
    predict(pretrained, attrix, type = "prob"),
    regexp = "tabnet_pretrain"
  )
  # ... and with `type` left unspecified.
  expect_error(
    predict(pretrained, attrix),
    regexp = "tabnet_pretrain"
  )
})
test_that("pretraining with `tabnet_model= ` parameter raise a warning", {
  # Passing an existing model through `tabnet_model` must emit a warning
  # (any warning -- no message pattern is checked here).
  # `ames_pretrain` is not defined in this file.
  expect_warning(
    continued <- tabnet_pretrain(
      x,
      y,
      epochs = 1,
      tabnet_model = ames_pretrain
    )
  )

  # A usable pretrain object is still returned, with as many top-level
  # elements as the model that was passed in and a single metrics entry.
  expect_s3_class(continued, "tabnet_pretrain")
  expect_equal(length(continued), length(ames_pretrain))
  expect_equal(length(continued$fit$metrics), 1)
})
test_that("errors when using an argument that do not exist", {
  # `pretraining_ratiosas` is not an argument name used anywhere else in this
  # file; the call must fail with an "unused argument" error.
  # NOTE(review): `1 - 1e5` evaluates to -99999; `1 - 1e-5` may have been
  # intended, but the assertion only checks the error message -- confirm.
  expect_error(
    pretrained <- tabnet_pretrain(x, y, pretraining_ratiosas = 1 - 1e5),
    regexp = "unused argument"
  )
})
test_that("works with validation split", {
  # Pretraining with a 20% validation split must run without error ...
  expect_no_error(
    pretrained <- tabnet_pretrain(
      attrix,
      attriy,
      epochs = 1,
      valid_split = 0.2
    )
  )
  # ... and also when `verbose = TRUE` is passed alongside it.
  expect_no_error(
    pretrained <- tabnet_pretrain(
      attrix,
      attriy,
      epochs = 1,
      valid_split = 0.2,
      verbose = TRUE
    )
  )
})
test_that("works with categorical embedding dimension as list", {
  # A 14-element vector of embedding dimensions.
  # NOTE(review): presumably one entry per categorical predictor in
  # `attrix` -- confirm against the fixture.
  emb_dims <- c(1, 1, 2, 2, 1, 1, 1, 2, 1, 1, 1, 2, 2, 2)
  emb_config <- tabnet_config(cat_emb_dim = emb_dims)

  expect_no_error(
    pretrained <- tabnet_pretrain(
      attrix,
      attriy,
      epochs = 1,
      valid_split = 0.2,
      config = emb_config
    )
  )
})
test_that("explicit error message when categorical embedding dimension vector has wrong size", {
  # Only four embedding dimensions are given; the call must fail with a
  # message mentioning the "number of categorical predictors".
  short_config <- tabnet_config(cat_emb_dim = c(1, 1, 2, 2))

  expect_error(
    pretrained <- tabnet_pretrain(
      attrix,
      attriy,
      epochs = 1,
      valid_split = 0.2,
      config = short_config
    ),
    regexp = "number of categorical predictors"
  )
})
test_that("can train from a recipe", {
  # Build the recipe in two explicit steps: declare the model formula on the
  # `attrition` data, then normalise every numeric column except the outcome.
  base_recipe <- recipe(Attrition ~ ., data = attrition)
  normalized_recipe <- step_normalize(
    base_recipe,
    all_numeric(),
    -all_outcomes()
  )

  # Pretraining must accept the recipe together with the data.
  expect_no_error(
    pretrained <- tabnet_pretrain(
      normalized_recipe,
      attrition,
      epochs = 1,
      verbose = TRUE
    )
  )
})
test_that("lr scheduler step works", {
  # Scheduler selected by name.
  expect_no_error(
    fit <- tabnet_pretrain(
      x,
      y,
      epochs = 3,
      lr_scheduler = "step",
      lr_decay = 0.1,
      step_size = 1
    )
  )

  # Scheduler supplied as a function that takes the optimizer and returns a
  # torch step scheduler.
  make_step_scheduler <- function(optimizer) {
    torch::lr_step(optimizer, step_size = 1, gamma = 0.1)
  }
  expect_no_error(
    fit <- tabnet_pretrain(
      x,
      y,
      epochs = 3,
      lr_scheduler = make_step_scheduler,
      lr_decay = 0.1,
      step_size = 1
    )
  )
})
test_that("lr scheduler reduce_on_plateau works", {
  # Scheduler selected by name.
  expect_no_error(
    fit <- tabnet_pretrain(
      x,
      y,
      epochs = 3,
      lr_scheduler = "reduce_on_plateau",
      lr_decay = 0.1,
      step_size = 1
    )
  )

  # Scheduler supplied as a function that takes the optimizer and returns a
  # torch reduce-on-plateau scheduler.
  make_plateau_scheduler <- function(optimizer) {
    torch::lr_reduce_on_plateau(optimizer, factor = 0.1, patience = 10)
  }
  expect_no_error(
    fit <- tabnet_pretrain(
      x,
      y,
      epochs = 3,
      lr_scheduler = make_plateau_scheduler,
      lr_decay = 0.1,
      step_size = 1
    )
  )
})
test_that("checkpoints works", {
  # Three epochs with `checkpoint_epochs = 1` ...
  expect_no_error(
    pretrained <- tabnet_pretrain(x, y, epochs = 3, checkpoint_epochs = 1)
  )
  # ... must leave three stored checkpoints.
  expect_length(pretrained$fit$checkpoints, 3)
  # Disabled check, kept for reference:
  # expect_equal( pretrain$fit$checkpoints[[3]], pretrain$serialized_net )
})
test_that("print module works", {
  # Snapshot expectations require testthat 3rd edition.
  testthat::local_edition(3)
  # The snapshot comparison is skipped on Linux and on Windows.
  testthat::skip_on_os("linux")
  testthat::skip_on_os("windows")

  expect_no_error(
    fit <- tabnet_pretrain(x, y, epochs = 1)
  )

  # Pin the cli width to 50 while the snapshot output is captured.
  withr::with_options(
    new = c(cli.width = 50),
    expect_snapshot_output(fit)
  )
})
test_that("num_independent_decoder and num_shared_decoder change the network number of parameters", {
  # Pretrain with explicit decoder settings.
  expect_no_error(
    pretrained <- tabnet_pretrain(
      attrix,
      attriy,
      epochs = 1,
      num_independent_decoder = 3,
      num_shared_decoder = 2
    )
  )

  # The resulting network must have strictly more parameters than the
  # reference `attr_pretrained` network (defined outside this file).
  custom_count <- torch:::get_parameter_count(pretrained$fit$network)
  reference_count <- torch:::get_parameter_count(attr_pretrained$fit$network)
  expect_gt(custom_count, reference_count)
})
test_that("num_independent_decoder and num_shared_decoder do not change the network number of parameters for fit", {
  # Building a config that carries the decoder settings must not raise.
  expect_no_error(
    decoder_config <- tabnet_config(
      epochs = 1,
      num_independent_decoder = 3,
      num_shared_decoder = 2
    )
  )
  # Fitting with that config must not raise either.
  expect_no_error(
    fitted_model <- tabnet_fit(attrix, attriy, config = decoder_config)
  )

  # The fitted network must have the same parameter count as the reference
  # `attr_fitted` network (defined outside this file).
  custom_count <- torch:::get_parameter_count(fitted_model$fit$network)
  reference_count <- torch:::get_parameter_count(attr_fitted$fit$network)
  expect_equal(custom_count, reference_count)
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.