library(luz)
torch::torch_manual_seed(1)
# a minimal module: flatten the input and apply a single linear layer
get_model <- function() {
  torch::nn_module(
    initialize = function(input_size, output_size) {
      self$fc <- torch::nn_linear(prod(input_size), prod(output_size))
      self$output_size <- output_size
    },
    forward = function(x) {
      out <- x %>%
        torch::torch_flatten(start_dim = 2) %>%
        self$fc()
      out$view(c(x$shape[1], self$output_size))
    }
  )
}

# attach the loss, optimizer and metrics, then set the model and optimizer
# hyper-parameters
model <- get_model()
model <- model %>%
  setup(
    loss = torch::nn_mse_loss(),
    optimizer = torch::optim_adam,
    metrics = list(
      luz_metric_mae(),
      luz_metric_mse(),
      luz_metric_rmse()
    )
  ) %>%
  set_hparams(input_size = 10, output_size = 1) %>%
  set_opt_hparams(lr = 0.001)

# simulated data: a list holding the input tensor and the target tensor
x <- list(torch::torch_randn(100, 10), torch::torch_randn(100, 1))

# train for a single epoch on the simulated data
fitted <- model %>% fit(
  x,
  epochs = 1,
  verbose = FALSE,
  dataloader_options = list(batch_size = 2, shuffle = FALSE)
)

Once a model has been trained, you might want to evaluate its performance on a different dataset. For that purpose, luz provides the evaluate() function, which takes a fitted model and a dataset and computes the metrics attached to the model.

evaluate() returns a luz_module_evaluation object that you can query for metrics with the get_metrics() function, or simply print to see the results.

For example:

# evaluate on the data defined above; in practice `data` would be a held-out
# validation dataset or dataloader
evaluation <- fitted %>% evaluate(data = x)
metrics <- get_metrics(evaluation)
print(evaluation)
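
Since evaluate() accepts data in the same forms as fit() (here, a list of input and target tensors), you can pass a genuinely held-out set the same way. The sketch below is illustrative only: valid_x is freshly simulated and is not part of the original example.

# simulate a held-out set with the same shapes as the training data; in a
# real workflow this would be your validation split or a torch dataloader
valid_x <- list(torch::torch_randn(50, 10), torch::torch_randn(50, 1))

evaluation_valid <- fitted %>% evaluate(data = valid_x)
print(evaluation_valid)
get_metrics(evaluation_valid)

The printed results should list the loss alongside the mae, mse and rmse metrics that were attached in setup().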

