# tests/testthat/test-nn-utils-clip-grad.R

# For several finite norm types and the infinity norm, checks that
# nn_utils_clip_grad_norm_() reports a value matching the total gradient norm
# measured before the call, and that the norm measured afterwards is close to
# `max_norm` and strictly smaller than before.
test_that("clip_grad_norm_", {
  # Reference computation of the total gradient norm across `parameters`.
  #
  # parameters: list of tensors that each carry a `$grad`.
  # norm_type:  numeric scalar; `Inf` selects the max-abs branch.
  # Returns a tensor holding the norm.
  compute_norm <- function(parameters, norm_type) {
    if (is.finite(norm_type)) {
      # p-norm: (sum over all gradient entries of |g|^p)^(1/p).
      total_norm <- 0
      for (p in parameters) {
        total_norm <- total_norm + p$grad$data()$abs()$pow(norm_type)$sum()
      }
      total_norm^(1 / norm_type)
    } else {
      # Infinity norm: largest absolute gradient entry over all parameters.
      # vapply() pins the per-parameter maxima to a double vector, whereas
      # sapply() would let the result type depend on the input.
      max_abs <- vapply(
        parameters,
        function(p) p$grad$data()$abs()$max()$item(),
        numeric(1)
      )
      torch_tensor(max(max_abs))
    }
  }

  # Fixture gradients: a 10x10 tensor counting 1..100 and a length-10 tensor
  # of 0.001. They are assigned to the layer's parameters by position.
  grads <- list(
    torch_arange(start = 1, end = 100)$view(c(10, 10)),
    torch_ones(10)$div(1000)
  )

  l <- nn_linear(10, 10)
  max_norm <- 2

  for (norm_type in c(0.5, 1.5, 2, 4, Inf)) {
    # Reinstall fresh copies on every iteration: the previous iteration's
    # clipping call changed the gradients stored on the parameters.
    for (i in seq_along(l$parameters)) {
      l$parameters[[i]]$set_grad_(grads[[i]]$clone())
    }

    norm_before <- compute_norm(l$parameters, norm_type)
    norm <- nn_utils_clip_grad_norm_(l$parameters, max_norm, norm_type)
    norm_after <- compute_norm(l$parameters, norm_type)

    # The returned value should match the norm measured before clipping.
    expect_equal_to_tensor(norm_before, norm, tolerance = 1e-2)
    # After clipping the total norm should sit at the requested maximum.
    expect_equal_to_tensor(norm_after, torch_tensor(max_norm), tolerance = 1e-2)
    # Clipping must have reduced the norm.
    expect_equal_to_r(norm_after < norm_before, TRUE)
  }
})

# Checks that nn_utils_clip_grad_value_() leaves every gradient entry inside
# the closed interval [-clip_value, clip_value].
test_that("clip_grad_value_", {
  # Fixture gradients: a 10x10 tensor counting 1..100 (most entries exceed the
  # threshold) and a length-10 tensor of 0.001 (all entries below it).
  grads <- list(
    torch_arange(start = 1, end = 100)$view(c(10, 10)),
    torch_ones(10)$div(1000)
  )

  l <- nn_linear(10, 10)
  # Single named threshold, used both for the call and for the assertions.
  clip_value <- 2.5

  for (i in seq_along(l$parameters)) {
    # Install copies, as the clip_grad_norm_ test does, so the fixtures stay
    # unchanged in case set_grad_() keeps a reference to its argument.
    l$parameters[[i]]$set_grad_(grads[[i]]$clone())
  }

  nn_utils_clip_grad_value_(l$parameters, clip_value)

  for (p in l$parameters) {
    # Value clipping is two-sided, so check both bounds.
    expect_equal_to_r(p$grad$max() <= clip_value, TRUE)
    expect_equal_to_r(p$grad$min() >= -clip_value, TRUE)
  }
})

# Try the torch package in your browser

# Any scripts or data that you put into this service are public.

# torch documentation built on June 7, 2023, 6:19 p.m.