test_that("Fully managed", {
model <- get_model()
dl <- get_dl()
mod <- model %>%
setup(
loss = torch::nn_mse_loss(),
optimizer = torch::optim_adam
)
expect_s3_class(mod, "luz_module_generator")
output <- mod %>%
set_hparams(input_size = 10, output_size = 1) %>%
fit(dl, valid_data = dl, verbose = FALSE)
expect_s3_class(output, "luz_module_fitted")
})
test_that("Custom optimizer", {
model <- get_model()
model <- torch::nn_module(
inherit = model,
set_optimizers = function() {
torch::optim_adam(self$parameters, lr = 0.01)
}
)
dl <- get_dl()
mod <- model %>%
setup(
loss = torch::nn_mse_loss(),
)
expect_s3_class(mod, "luz_module_generator")
output <- mod %>%
set_hparams(input_size = 10, output_size = 1) %>%
fit(dl, valid_data = dl, verbose = FALSE)
expect_s3_class(output, "luz_module_fitted")
})
test_that("Multiple optimizers", {
model <- get_model()
module <- torch::nn_module(
initialize = function(input_size = 10, output_size = 1) {
self$model1 = model(input_size, output_size)
self$model2 <- model(input_size, output_size)
},
forward = function(x) {
self$model1(x) + self$model2(x)
},
set_optimizers = function() {
list(
one = torch::optim_adam(self$model1$parameters, lr = 0.01),
two = torch::optim_adam(self$model2$parameters, lr = 0.01)
)
}
)
dl <- get_dl()
mod <- module %>%
setup(
loss = torch::nn_mse_loss(),
)
expect_s3_class(mod, "luz_module_generator")
output <- mod %>%
set_hparams(input_size = 10, output_size = 1) %>%
fit(dl, valid_data = dl, verbose = FALSE)
expect_s3_class(output, "luz_module_fitted")
})
test_that("can train without a validation dataset", {
model <- get_model()
dl <- get_dl()
mod <- model %>%
setup(
loss = torch::nn_mse_loss(),
optimizer = torch::optim_adam
)
expect_s3_class(mod, "luz_module_generator")
output <- mod %>%
set_hparams(input_size = 10, output_size = 1) %>%
fit(dl, verbose = FALSE)
expect_s3_class(output, "luz_module_fitted")
})
test_that("predict works for modules", {
model <- get_model()
dl <- get_dl()
mod <- model %>%
setup(
loss = torch::nn_mse_loss(),
optimizer = torch::optim_adam
)
output <- mod %>%
set_hparams(input_size = 10, output_size = 1) %>%
fit(dl, verbose = FALSE)
pred <- predict(output, dl)
pred2 <- predict(output, dl)
expect_equal(pred$shape, c(100, 1))
expect_equal(as.array(pred$to(device = "cpu")), as.array(pred2$to(device="cpu")))
# try with a different dataloader
dl <- get_dl()
pred <- predict(output, dl)
pred2 <- predict(output, dl)
expect_equal(as.array(pred$to(device = "cpu")), as.array(pred2$to(device="cpu")))
})
test_that("predict can use a progress bar", {
model <- get_model()
dl <- get_dl()
mod <- model %>%
setup(
loss = torch::nn_mse_loss(),
optimizer = torch::optim_adam
)
output <- mod %>%
set_hparams(input_size = 10, output_size = 1) %>%
set_opt_hparams(lr = 0.001) %>%
fit(dl, epochs = 1, verbose = FALSE)
dl <- get_dl(len = 500)
withr::with_options(
list(luz.force_progress_bar = TRUE,
luz.show_progress_bar_eta = FALSE,
width = 80),
{
expect_snapshot(
pred <- predict(output, dl, verbose=TRUE)
)
}
)
expect_equal(output$ctx$hparams$input_size, 10)
expect_equal(output$ctx$opt_hparams$lr, 0.001)
})
test_that("valid_data works", {
model <- get_model()
model <- model %>%
setup(
loss = torch::nn_mse_loss(),
optimizer = torch::optim_adam
) %>%
set_hparams(input_size = 10, output_size = 1) %>%
set_opt_hparams(lr = 0.001)
fitted <- model %>% fit(
list(torch::torch_randn(100,10), torch::torch_randn(100, 1)),
epochs = 10,
valid_data = 0.1,
verbose = FALSE
)
expect_true("valid" %in% get_metrics(fitted)$set)
expect_error(class= "value_error", regexp = "2", {
model %>% fit(
list(torch::torch_randn(100,10), torch::torch_randn(100, 1)),
epochs = 10,
valid_data = 2,
verbose = FALSE
)
})
expect_error(class= "value_error", regexp = "-1", {
model %>% fit(
list(torch::torch_randn(100,10), torch::torch_randn(100, 1)),
epochs = 10,
valid_data = -1,
verbose = FALSE
)
})
dl <- get_dl()
expect_error(class= "value_error", regexp = "dataloader", {
model %>% fit(
dl,
epochs = 10,
valid_data = 0.2,
verbose = FALSE
)
})
})
test_that("we can pass dataloader_options", {
iter_callback <- luz_callback(
initialize = function() {
self$iter <- 0
},
on_train_batch_end = function() {
self$iter <- self$iter + 1
}
)
model <- get_model()
model <- model %>%
setup(
loss = torch::nn_mse_loss(),
optimizer = torch::optim_adam
) %>%
set_hparams(input_size = 10, output_size = 1) %>%
set_opt_hparams(lr = 0.001)
x <- list(torch::torch_randn(100,10), torch::torch_randn(100, 1))
iter <- iter_callback()
fitted <- model %>% fit(
x,
epochs = 1,
valid_data = 0.1,
verbose = FALSE,
dataloader_options = list(batch_size = 2, shuffle = FALSE),
callbacks = iter
)
expect_equal(iter$iter, 45)
dl <- get_dl()
expect_error(regexp = "already a dataloader", {
model %>% fit(
dl,
epochs = 1,
verbose = FALSE,
dataloader_options = list(batch_size = 2, shuffle = FALSE)
)
})
expect_warning(regexp = "already a dataloader", {
model %>% fit(
x,
epochs = 1,
verbose = FALSE,
valid_data = dl,
dataloader_options = list(batch_size = 2, shuffle = FALSE)
)
})
pred <- predict(fitted, x, dataloader_options = list(batch_size = 45, drop_last = TRUE))
expect_tensor_shape(pred, c(90, 1))
expect_warning(regexp = "already a dataloader", {
predict(fitted, dl, dataloader_options = list(batch_size = 45, drop_last = TRUE))
})
expect_warning(regexp = "ignored for predictions", {
predict(fitted, x, dataloader_options = list(shuffle = TRUE))
})
})
test_that("evaluate works", {
set.seed(1)
torch_manual_seed(1)
model <- get_model()
model <- model %>%
setup(
loss = torch::nn_mse_loss(),
optimizer = torch::optim_adam,
metrics = list(
luz_metric_mae(),
luz_metric_mse(),
luz_metric_rmse()
)
) %>%
set_hparams(input_size = 10, output_size = 1) %>%
set_opt_hparams(lr = 0.001)
x <- list(torch::torch_randn(100,10), torch::torch_randn(100, 1))
fitted <- model %>% fit(
x,
epochs = 1,
verbose = FALSE,
dataloader_options = list(batch_size = 2, shuffle = FALSE)
)
e <- evaluate(fitted, x)
expect_equal(nrow(get_metrics(e)), 4)
expect_equal(ncol(get_metrics(e)), 2)
expect_snapshot(print(e))
})
test_that("cutom backward", {
model <- get_model()
dl <- get_dl()
mod <- model %>%
setup(
loss = torch::nn_mse_loss(),
optimizer = torch::optim_adam,
backward = function(x) {
x$backward()
if (ctx$iter == 1 && ctx$epoch == 1) {
print("hello")
}
}
)
expect_s3_class(mod, "luz_module_generator")
expect_output(regexp = "hello", {
output <- mod %>%
set_hparams(input_size = 10, output_size = 1) %>%
fit(dl, valid_data = dl, verbose = FALSE)
})
expect_s3_class(output, "luz_module_fitted")
})
test_that("luz module has a device arg", {
mod <- get_model() %>%
setup(
loss = torch::nn_mse_loss(),
optimizer = torch::optim_adam
)
modul <- mod(1,1)
expect_true(
modul$device == torch_device("cpu")
)
mod <- nn_module(
initialize = function() {
self$par <- torch::nn_parameter(torch_randn(10, 10))
},
forward = function(x) {
self$par
},
active = list(
device = function() {
"hello"
}
)
)
model <- mod %>%
setup(
loss = torch::nn_mse_loss(),
optimizer = torch::optim_adam
)
modul <- mod()
expect_equal(modul$device, "hello")
})
test_that("evaluate allows setting metrics for it", {
model <- get_model()
model <- model %>%
setup(
loss = torch::nn_mse_loss(),
optimizer = torch::optim_adam,
metrics = list(
luz_metric_mae(),
luz_metric_mse(),
luz_metric_rmse()
)
) %>%
set_hparams(input_size = 10, output_size = 1) %>%
set_opt_hparams(lr = 0.001)
x <- list(torch::torch_randn(100,10), torch::torch_randn(100, 1))
fitted <- model %>% fit(
x,
epochs = 1,
verbose = FALSE,
dataloader_options = list(batch_size = 2, shuffle = FALSE)
)
e1 <- get_metrics(evaluate(fitted, x))
e2 <- get_metrics(evaluate(fitted, x, metrics = list(luz_metric_mae())))
e3 <- get_metrics(evaluate(fitted, x))
expect_equal(e1, e3)
expect_equal(e2, e1[e1$metric %in% c("loss", "mae"),])
})
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.