# inst/doc/get-started.R

## ---- include = FALSE---------------------------------------------------------
# Chunk defaults for the whole vignette: show code and its printed output in a
# single block (`collapse = TRUE`) and prefix every output line with "#>"
# (`comment`).
invisible(do.call(
  knitr::opts_chunk$set,
  list(collapse = TRUE, comment = "#>")
))

## ----setup--------------------------------------------------------------------
library(luz)
library(torch)

## ---- eval = FALSE------------------------------------------------------------
#  net <- nn_module(
#    "Net",
#    initialize = function(num_class) {
#      self$conv1 <- nn_conv2d(1, 32, 3, 1)
#      self$conv2 <- nn_conv2d(32, 64, 3, 1)
#      self$dropout1 <- nn_dropout2d(0.25)
#      self$dropout2 <- nn_dropout2d(0.5)
#      self$fc1 <- nn_linear(9216, 128)
#      self$fc2 <- nn_linear(128, num_class)
#    },
#    forward = function(x) {
#      x <- self$conv1(x)
#      x <- nnf_relu(x)
#      x <- self$conv2(x)
#      x <- nnf_relu(x)
#      x <- nnf_max_pool2d(x, 2)
#      x <- self$dropout1(x)
#      x <- torch_flatten(x, start_dim = 2)
#      x <- self$fc1(x)
#      x <- nnf_relu(x)
#      x <- self$dropout2(x)
#      x <- self$fc2(x)
#      x
#    }
#  )

## ---- eval = FALSE------------------------------------------------------------
#  fitted <- net %>%
#    setup(
#      loss = nn_cross_entropy_loss(),
#      optimizer = optim_adam,
#      metrics = list(
#        luz_metric_accuracy
#      )
#    ) %>%
#    set_hparams(num_class = 10) %>%
#    set_opt_hparams(lr = 0.003) %>%
#    fit(train_dl, epochs = 10, valid_data = test_dl)

## ---- eval = FALSE------------------------------------------------------------
#  predictions <- predict(fitted, test_dl)

## ---- eval = FALSE------------------------------------------------------------
#  # -> Initialize objects: model, optimizers.
#  # -> Select fitting device.
#  # -> Move data, model, optimizers to the selected device.
#  # -> Start training
#  for (epoch in 1:epochs) {
#    # -> Training procedure
#    for (batch in train_dl) {
#      # -> Calculate model `forward` method.
#      # -> Calculate the loss
#      # -> Update weights
#      # -> Update metrics and tracking loss
#    }
#    # -> Validation procedure
#    for (batch in valid_dl) {
#      # -> Calculate model `forward` method.
#      # -> Calculate the loss
#      # -> Update metrics and tracking loss
#    }
#  }
#  # -> End training

## ---- eval=FALSE--------------------------------------------------------------
#  fitted <- net %>%
#    setup(
#      ...,
#      metrics = list(
#        luz_metric_accuracy
#      )
#    ) %>%
#    fit(...)

## ---- eval = FALSE------------------------------------------------------------
#  luz_metric_accuracy <- luz_metric(
#    # An abbreviation to be shown in progress bars, or
#    # when printing progress
#    abbrev = "Acc",
#    # Initial setup for the metric. Metrics are initialized
#    # every epoch, for both training and validation
#    initialize = function() {
#      self$correct <- 0
#      self$total <- 0
#    },
#    # Runs at every training or validation step and updates
#    # the internal state. The update function takes `preds`
#    # and `target` as parameters.
#    update = function(preds, target) {
#      pred <- torch::torch_argmax(preds, dim = 2)
#      self$correct <- self$correct + (pred == target)$
#        to(dtype = torch::torch_float())$
#        sum()$
#        item()
#      self$total <- self$total + pred$numel()
#    },
#    # Uses the internal state to compute the metric value
#    compute = function() {
#      self$correct/self$total
#    }
#  )

## ----include=FALSE, eval = torch::torch_is_installed()------------------------
# Hidden chunk (`include = FALSE` in the vignette). It fits a small regression
# model for one epoch on random data, so that `evaluation` exists for the
# printing chunk further down.
# The vignette only runs this chunk when `torch::torch_is_installed()` is TRUE.
# That chunk option has no effect when this tangled script is sourced directly,
# so the same guard is applied explicitly here.
# `luz` is already attached by the setup chunk at the top of the script.
if (torch::torch_is_installed()) {
  torch::torch_manual_seed(1)

  # Returns an nn_module generator for a single linear layer. The input is
  # flattened from dimension 2 onwards, mapped to `prod(output_size)` values,
  # and reshaped to c(batch_size, output_size).
  get_model <- function() {
    torch::nn_module(
      initialize = function(input_size, output_size) {
        self$fc <- torch::nn_linear(prod(input_size), prod(output_size))
        self$output_size <- output_size
      },
      forward = function(x) {
        out <- x %>%
          torch::torch_flatten(start_dim = 2) %>%
          self$fc()
        out$view(c(x$shape[1], self$output_size))
      }
    )
  }

  # MSE loss optimised with Adam; MAE, MSE and RMSE are tracked as metrics.
  model <- get_model() %>%
    setup(
      loss = torch::nn_mse_loss(),
      optimizer = torch::optim_adam,
      metrics = list(
        luz_metric_mae(),
        luz_metric_mse(),
        luz_metric_rmse()
      )
    ) %>%
    set_hparams(input_size = 10, output_size = 1) %>%
    set_opt_hparams(lr = 0.001)

  # Two tensors passed as the training data: 100 rows of 10 features and 100
  # rows of 1 value. These shapes match input_size = 10 and output_size = 1.
  x <- list(torch::torch_randn(100, 10), torch::torch_randn(100, 1))

  # shuffle = FALSE keeps the batch order fixed between runs.
  fitted <- model %>% fit(
    x,
    epochs = 1,
    verbose = FALSE,
    dataloader_options = list(batch_size = 2, shuffle = FALSE)
  )

  evaluation <- fitted %>% evaluate(data = x)
}

## ---- eval = FALSE------------------------------------------------------------
#  evaluation <- fitted %>% evaluate(data = valid_dl)
#  metrics <- get_metrics(evaluation)
#  print(evaluation)

## ----echo=FALSE, eval=torch::torch_is_installed()-----------------------------
# Output-only chunk (`echo = FALSE` in the vignette): shows what printing an
# evaluation object looks like.
# The vignette evaluates it only when `torch::torch_is_installed()` is TRUE;
# that chunk option is not applied when this tangled script is sourced
# directly, so the guard is repeated here.
if (torch::torch_is_installed()) {
  # Turn off Unicode characters in cli output while printing. The previous
  # option value is restored afterwards rather than left changed for the rest
  # of the session. options() returns the old values invisibly, so restoring
  # them prints nothing.
  old_options <- options(cli.unicode = FALSE)
  metrics <- get_metrics(evaluation)
  print(evaluation)
  options(old_options)
}

## ---- eval = FALSE------------------------------------------------------------
#  print_callback <- luz_callback(
#    name = "print_callback",
#    initialize = function(message) {
#      self$message <- message
#    },
#    on_train_batch_end = function() {
#      cat("Iteration", ctx$iter, "\n")
#    },
#    on_epoch_end = function() {
#      cat(self$message, "\n")
#    }
#  )

## ---- eval = FALSE------------------------------------------------------------
#  fitted <- net %>%
#    setup(...) %>%
#    fit(..., callbacks = list(
#      print_callback(message = "Done!")
#    ))

# Try the luz package in your browser
#
# Any scripts or data that you put into this service are public.
#
# luz documentation built on April 17, 2023, 5:08 p.m.