test_that("Can create a simple metric", {
  MyMetric <- luz_metric(
    name = "MyMetric",
    initialize = function(y) {
      self$z <- 0
      self$y <- y
    },
    update = function(preds, target) {
      self$z <- self$z + self$y
    },
    compute = function() {
      self$z
    }
  )

  # Calling the generator with constructor arguments returns a new
  # R6 class generator with those arguments partialized in.
  partial <- MyMetric(y = 1)
  expect_s3_class(partial, "R6ClassGenerator")

  # `new()` then instantiates the metric itself.
  instance <- partial$new()
  expect_s3_class(instance, "MyMetric")

  # One update step adds `y` (= 1) to the accumulator `z`.
  instance$update(1, 1)
  expect_equal(instance$compute(), 1)
})
test_that("mae works", {
  device <- get_device()
  preds <- torch::torch_randn(100, 100, device = device)
  target <- torch::torch_randn(100, 100, device = device)

  metric <- luz_metric_mae()$new()$to(device = device)
  metric$update(preds, target)

  # Reference value computed on the CPU with plain R arithmetic.
  expected <- mean(abs(as.array(preds$cpu()) - as.array(target$cpu())))
  expect_equal(metric$compute(), expected, tolerance = 1e-5)
})
test_that("mse works", {
  device <- get_device()
  preds <- torch::torch_randn(100, 100, device = device)
  target <- torch::torch_randn(100, 100, device = device)

  metric <- luz_metric_mse()$new()$to(device = device)
  metric$update(preds, target)

  # Reference value computed on the CPU with plain R arithmetic.
  expected <- mean((as.array(preds$cpu()) - as.array(target$cpu()))^2)
  expect_equal(metric$compute(), expected, tolerance = 1e-5)
})
test_that("rmse works", {
  device <- get_device()
  preds <- torch::torch_randn(100, 100, device = device)
  target <- torch::torch_randn(100, 100, device = device)

  metric <- luz_metric_rmse()$new()$to(device = device)
  metric$update(preds, target)

  # Reference value computed on the CPU with plain R arithmetic.
  expected <- sqrt(mean((as.array(preds$cpu()) - as.array(target$cpu()))^2))
  expect_equal(metric$compute(), expected, tolerance = 1e-5)
})
test_that("binary accuracy with logits", {
  m <- luz_metric_binary_accuracy_with_logits(threshold = 0.5)
  m <- m$new()$to(device = get_device())
  x <- torch_randn(100, device = get_device())
  # `torch_randint()` excludes the upper bound, so `high` must be 2 to
  # actually yield both 0 and 1 labels. The previous `(0, 1)` bounds made
  # every label 0, turning this into a degenerate accuracy check.
  y <- torch_randint(0, 2, 100, device = get_device())
  m$update(x, y)
  # With threshold = 0.5 on probabilities, a logit is positive exactly
  # when it is > 0 (sigmoid(0) == 0.5), hence the `x > 0` reference.
  expect_equal(
    m$compute(),
    mean(as.array((x > 0)$cpu()) == as.array(y$cpu()))
  )
})
test_that("metrics works within models", {
  dl <- get_binary_dl()

  mod <- torch::nn_linear %>%
    setup(
      loss = torch::nn_bce_with_logits_loss(),
      optimizer = torch::optim_adam,
      metrics = list(
        luz_metric_binary_accuracy_with_logits(),
        luz_metric_binary_auroc(from_logits = TRUE),
        luz_metric_binary_accuracy()
      )
    )

  # Fitting with all three metrics attached should not raise.
  expect_error(
    output <- mod %>%
      set_hparams(in_features = 10, out_features = 1) %>%
      fit(dl, epochs = 1, verbose = FALSE),
    regexp = NA
  )

  # Recorded per epoch: the loss plus the three metrics above.
  expect_length(output$records$metrics$train[[1]], 4)
})
test_that("can specify metrics for training and validation", {
  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  fitted <- nn_linear %>%
    setup(
      optimizer = optim_sgd,
      loss = torch::nn_mse_loss(),
      metrics = luz_metric_set(
        metrics = list(luz_metric_mae()),        # shared by both sets
        valid_metrics = list(luz_metric_rmse()), # validation only
        train_metrics = list(luz_metric_mse())   # training only
      )
    ) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.001) %>%
    fit(list(x, y), epochs = 5, valid_data = list(x, y), verbose = FALSE)

  metrics <- get_metrics(fitted)

  # Each set reports the loss, the shared metric, and its own metric.
  expect_equal(
    unique(metrics[metrics$set == "valid", "metric"]),
    c("loss", "mae", "rmse")
  )
  expect_equal(
    unique(metrics[metrics$set == "train", "metric"]),
    c("loss", "mae", "mse")
  )

  # The recorded metrics should also plot without error.
  expect_error(regexp = NA, {
    plot(fitted)
  })
})
test_that("Get a nice error message if you pass the wrong metric type", {
  # `luz_metric_mae` (no `()`) is a bare function, not a metric
  # generator, and must be rejected with an informative message.
  expect_error(
    luz_metric_set(
      metrics = c(luz_metric_mae),
      valid_metrics = c(luz_metric_rmse()),
      train_metrics = c(luz_metric_mse())
    ),
    regexp = "Expected an object with"
  )

  # A properly constructed metric carries the generator class.
  expect_true(inherits(luz_metric_mse(), "luz_metric_generator"))
})
test_that("get a nice error message when metric fails updating", {
  # A metric whose `update()` always throws.
  failing_metric <- luz_metric(
    name = "hello",
    abbrev = "h",
    initialize = function() {
    },
    update = function(pred, save) {
      stop("error in metric!")
    }
  )

  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  model <- nn_linear %>%
    setup(
      optimizer = optim_sgd,
      loss = torch::nn_mse_loss(),
      metrics = list(failing_metric())
    ) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.001)

  # The snapshot captures luz's contextualized error message.
  expect_snapshot_error({
    res <- model %>%
      fit(list(x, y), epochs = 5, valid_data = list(x, y), verbose = FALSE)
  })
})
test_that("error gracefully on failed `compute`.", {
  # A metric whose `compute()` always throws.
  failing_metric <- luz_metric(
    name = "hello",
    abbrev = "h",
    initialize = function() {
    },
    update = function(pred, save) {
    },
    compute = function() {
      stop("Error computing metric.")
    }
  )

  x <- torch_randn(1000, 10)
  y <- torch_randn(1000, 1)

  model <- nn_linear %>%
    setup(
      optimizer = optim_sgd,
      loss = torch::nn_mse_loss(),
      metrics = list(failing_metric())
    ) %>%
    set_hparams(in_features = 10, out_features = 1) %>%
    set_opt_hparams(lr = 0.001)

  # The snapshot captures luz's contextualized error message.
  expect_snapshot_error({
    res <- model %>%
      fit(list(x, y), epochs = 5, valid_data = list(x, y), verbose = FALSE)
  })
})
test_that("metrics run in the environments they are created", {
  x <- 1

  # A child of the global env holding a different binding for `x`.
  env <- new.env(parent = globalenv())
  env$x <- 2

  # Define the metric *inside* `env`, so its methods should resolve
  # `x` there (x == 2), not in the test frame (x == 1).
  base::eval(expr = rlang::expr(
    metric <- luz_metric(
      name = "hello",
      abbrev = "h",
      initialize = function() {
      },
      update = function(pred, save) {
      },
      compute = function() {
        x
      }
    )
  ), envir = env)

  generator <- env$metric()
  m <- generator$new()
  expect_equal(m$compute(), 2)
})
# NOTE(review): removed two trailing lines of scraped-webpage boilerplate
# ("Add the following code to your website..." / "Embedding Snippets") —
# extraction residue, not valid R and not part of this test file.