test_that("callback lr scheduler", {
  skip_on_os("windows")

  # Seed both the torch and the base R RNGs so the snapshotted output is
  # reproducible across runs.
  torch::torch_manual_seed(1)
  set.seed(1)

  model <- get_model()
  dl <- get_dl()

  # No trailing comma after the last argument: `setup(..., )` would pass an
  # empty positional argument.
  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam
    )

  # Multiplicative schedule with a constant factor of 0.5 per epoch.
  # `expect_message()` requires the verbose scheduler to emit a message.
  # The code inside `expect_snapshot()` is recorded in the snapshot file, so
  # it is kept token-for-token.
  expect_snapshot({
    expect_message({
      output <- mod %>%
        set_hparams(input_size = 10, output_size = 1) %>%
        fit(dl, verbose = FALSE, epochs = 5, callbacks = list(
          luz_callback_lr_scheduler(torch::lr_multiplicative, verbose = TRUE,
                                    lr_lambda = function(epoch) 0.5)
        ))
    })
  })

  # Plateau schedule (patience = 2, threshold = 1e-1) over 20 epochs; as
  # above, the verbose scheduler must emit a message.
  expect_snapshot({
    expect_message({
      output <- mod %>%
        set_hparams(input_size = 10, output_size = 1) %>%
        fit(dl, verbose = FALSE, epochs = 20, callbacks = list(
          luz_callback_lr_scheduler(
            torch::lr_reduce_on_plateau,
            verbose = TRUE,
            patience = 2,
            threshold = 1e-1
          )
        ))
    })
  })
})
test_that("csv callback", {
  model <- get_model()
  dl <- get_dl()

  # No trailing comma after the last argument: `setup(..., )` would pass an
  # empty positional argument.
  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam
    )

  # Delete the CSV log when the test finishes instead of leaking it into the
  # session temp directory.
  tmp <- tempfile()
  withr::defer(unlink(tmp))

  # Training data only: 5 epochs -> 5 logged rows.
  output <- mod %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    fit(dl, verbose = FALSE, epochs = 5, callbacks = list(
      luz_callback_csv_logger(tmp)
    ))
  x <- read.table(tmp, header = TRUE, sep = ",")
  expect_equal(nrow(x), 5)
  expect_equal(names(x), c("epoch", "set", "loss"))

  # Training + validation, logged to the same path. The expected count of 10
  # (5 epochs x 2 sets) implies the logger replaces the file written by the
  # first fit rather than appending to it.
  output <- mod %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    fit(dl, verbose = FALSE, epochs = 5, valid_data = dl, callbacks = list(
      luz_callback_csv_logger(tmp)
    ))
  x <- read.table(tmp, header = TRUE, sep = ",")
  expect_equal(nrow(x), 10)
  expect_equal(names(x), c("epoch", "set", "loss"))
})
test_that("progressbar appears with training and validation", {
  # Seed both RNGs so the snapshotted progress output is reproducible.
  torch::torch_manual_seed(1)
  set.seed(1)

  model <- get_model()
  dl <- get_test_dl(len = 500)

  # No trailing comma after the last argument: `setup(..., )` would pass an
  # empty positional argument.
  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam
    )

  # Options are set for the rest of this test and restored automatically when
  # it ends. Presumably, `luz.force_progress_bar` shows the bar in
  # non-interactive runs, and hiding the (time-dependent) ETA plus pinning the
  # width keeps the snapshot stable.
  withr::local_options(list(
    luz.force_progress_bar = TRUE,
    luz.show_progress_bar_eta = FALSE,
    width = 80
  ))

  # The code inside `expect_snapshot()` is recorded in the snapshot file, so
  # it is kept token-for-token.
  expect_snapshot({
    expect_message({
      output <- mod %>%
        set_hparams(input_size = 10, output_size = 1) %>%
        fit(dl, verbose = TRUE, epochs = 2, valid_data = dl)
    })
  })
})
test_that("gradient clip works correctly", {
  model <- get_model()
  dl <- get_test_dl(len = 500)

  # No trailing comma after the last argument: `setup(..., )` would pass an
  # empty positional argument.
  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam
    )

  # With `max_norm = 0`, the clipped gradients have zero norm, so the
  # optimizer should never move the weights.
  output <- mod %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    fit(dl, verbose = FALSE, epochs = 2, valid_data = dl,
        callbacks = list(luz_callback_gradient_clip(max_norm = 0)))

  # We expect that no learning happened, so the loss is identical across all
  # recorded metrics. Four values are expected (presumably train and valid
  # for each of the 2 epochs).
  metric_values <- get_metrics(output)$value
  expect_length(unique(metric_values), 1)
  expect_length(metric_values, 4)

  # Non-numeric arguments are rejected with an error naming the argument.
  expect_error(luz_callback_gradient_clip(max_norm = "a"), "max_norm")
  expect_error(luz_callback_gradient_clip(norm_type = "a"), "norm_type")
})
# NOTE: the (misspelled) test name is left unchanged because testthat keys
# snapshots by test name; renaming it would orphan the recorded snapshot.
test_that("improve error message when you provide a unitinitilized callback", {
  skip_on_os("windows")

  # torch functions are namespace-qualified, consistent with the other tests.
  x <- torch::torch_randn(1000, 10)
  y <- torch::torch_randn(1000, 1)
  model <- torch::nn_linear %>%
    setup(optimizer = torch::optim_sgd, loss = torch::nnf_mse_loss) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.01)

  # `luz_callback_auto_resume` is passed uncalled (the generator itself, not
  # `luz_callback_auto_resume()`); the resulting error message is snapshotted.
  expect_snapshot_error({
    model %>% fit(list(x, y), callbacks = list(luz_callback_auto_resume))
  })
})
test_that("can get progress when using iterable datasets", {
  # Seed both RNGs so the snapshotted progress output is reproducible.
  torch::torch_manual_seed(1)
  set.seed(1)
  model <- get_model()

  # Minimal iterable dataset. Each call to the closure returned by `.iter`
  # yields the next (x, y) row pair; after `len` rows it returns
  # `coro::exhausted()` to end the iteration.
  get_iterable_ds <- torch::iterable_dataset(
    "iterable_ds",
    initialize = function(len = 100, x_size = 10, y_size = 1) {
      self$len <- len
      self$x <- torch::torch_randn(size = c(len, x_size))
      self$y <- torch::torch_randn(size = c(len, y_size))
    },
    .iter = function() {
      i <- 0
      function() {
        # `<<-` advances the counter captured by this iterator closure.
        i <<- i + 1
        if (i > self$len) {
          return(coro::exhausted())
        }
        list(
          x = self$x[i, ..],
          y = self$y[i, ..]
        )
      }
    }
  )

  # NOTE(review): `ds` and `dl` are not used by the fit below. Creating `ds`
  # draws from the torch RNG, which presumably changes the data (and losses)
  # recorded in the snapshot, so removing them would likely require
  # re-recording it.
  ds <- get_iterable_ds()
  dl <- torch::dataloader(ds, batch_size = 32)

  # No trailing comma after the last argument: `setup(..., )` would pass an
  # empty positional argument.
  mod <- model %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam
    )

  # Options are restored automatically when the test ends.
  withr::local_options(list(
    luz.force_progress_bar = TRUE,
    luz.show_progress_bar_eta = FALSE,
    width = 80
  ))

  # The code inside `expect_snapshot()` is recorded in the snapshot file, so
  # it is kept token-for-token (including the trailing comma after
  # `valid_data`).
  expect_snapshot({
    expect_message({
      output <- mod %>%
        set_hparams(input_size = 10, output_size = 1) %>%
        fit(
          get_iterable_ds(),
          verbose = TRUE,
          epochs = 2,
          valid_data = get_iterable_ds(),
        )
    })
  })
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.