# tests/testthat/helper-luz.R

# Build the training and validation dataloaders used by the luz tests.
#
# `example` must provide `train_x` and `train_y`; both are converted to torch
# tensors. `train_x` is indexed as `x[i, ]`, so it is presumably a 2-D
# predictor matrix with one row per entry of `train_y` -- confirm with callers.
#
# The observations are split at random into a training subset holding
# `round(0.8 * n)` of them and a validation subset holding the rest.
#
# Returns a list with two `torch::dataloader`s:
#   train - batch size 5, shuffled
#   valid - batch size 5, not shuffled
#
# NOTE: `set.seed(42)` makes the split reproducible, but it also resets the
# caller's global R RNG state.
luz_test_data <- function(example) {
  set.seed(42)

  nn2poly_dataset <- torch::dataset(
    name = "nn2poly_dataset",
    initialize = function(x, y) {
      self$x <- torch::torch_tensor(x)
      self$y <- torch::torch_tensor(y)
    },
    .getitem = function(i) {
      # One observation: the i-th row of x and the i-th entry of y.
      list(
        x = self$x[i, ],
        y = self$y[i]
      )
    },
    .length = function() {
      self$y$size()[[1]]
    }
  )

  data_full <- nn2poly_dataset(example$train_x, example$train_y)

  n_obs <- length(data_full)

  # seq_len() stays empty for an empty dataset, whereas 1:0 would give c(1, 0).
  all_indices <- seq_len(n_obs)

  # Round *after* scaling so `size` is always a whole number; the previous
  # `round(n) * 0.8` left a fractional size that sample() silently truncated.
  n_train <- round(n_obs * 0.8)
  train_indices <- sample(all_indices, size = n_train)
  val_indices <- setdiff(all_indices, train_indices)

  data_train <- torch::dataset_subset(data_full, train_indices)
  data_val <- torch::dataset_subset(data_full, val_indices)

  list(
    train = torch::dataloader(data_train, batch_size = 5, shuffle = TRUE),
    valid = torch::dataloader(data_val, batch_size = 5)
  )
}

# Build the small fixed network used by the luz tests.
#
# Architecture: linear(2 -> 2), softplus, linear(2 -> 3), softplus,
# linear(3 -> 1), wrapped with `luz_model_sequential()`.
#
# The torch RNG is seeded with 42 first so the layers are initialised
# reproducibly; the layers are created in network order.
luz_test_model <- function() {
  torch::torch_manual_seed(42)

  input_layer <- torch::nn_linear(2, 2)
  input_activation <- torch::nn_softplus()
  hidden_layer <- torch::nn_linear(2, 3)
  hidden_activation <- torch::nn_softplus()
  output_layer <- torch::nn_linear(3, 1)

  luz_model_sequential(
    input_layer,
    input_activation,
    hidden_layer,
    hidden_activation,
    output_layer
  )
}

# Try the nn2poly package in your browser
#
# Any scripts or data that you put into this service are public.
#
# nn2poly documentation built on May 29, 2024, 5:08 a.m.