tests/testthat/test_losses.R

test_that("nn_mixup_loss works when used with cross entropy", {

  # 1: manual check against a hand-computed loss for a constant mixing weight of 0.9
  target1 <- torch::torch_ones(7, dtype = torch::torch_long())
  target2 <- torch::torch_empty_like(target1)$fill_(4)
  weight <- torch::torch_empty_like(target1, dtype = torch::torch_float())$fill_(0.9)
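  # a mixup target bundles the two label tensors together with the per-sample mixing weights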
  target <- list(list(target1, target2), weight)
  input <- torch::torch_randn(c(7, 4))

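  # ignore_index is set to a class (222) that never occurs in the targets,
  # so it should not affect the result; this checks that constructor arguments
  # are passed through to the wrapped criterion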
  mixup_loss <- nn_mixup_loss(torch::nn_cross_entropy_loss(ignore_index = 222))
  t1_loss <- torch::nnf_cross_entropy(input, target1, reduction = "none")
  t2_loss <- torch::nnf_cross_entropy(input, target2, reduction = "none")
  expect_equal_to_tensor(mixup_loss(input, target), torch::torch_mean(t1_loss + 0.9 * (t2_loss - t1_loss)))

  # 2: a mixing weight of 1 yields the same loss as using target2 alone
  weight <- torch::torch_empty_like(target1, dtype = torch::torch_float())$fill_(1)
  target <- list(list(target1, target2), weight)
  mixup_loss <- nn_mixup_loss(torch::nn_cross_entropy_loss())
  t2_loss <- torch::nnf_cross_entropy(input, target2)
  expect_equal_to_tensor(mixup_loss(input, target), t2_loss)

})
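
# For reference: the expectations above assume nn_mixup_loss() linearly
# interpolates the two per-element losses and then mean-reduces, i.e.
# lerp(l1, l2, w) = l1 + w * (l2 - l1) = (1 - w) * l1 + w * l2.
# A minimal sketch of that formula (a hypothetical helper for illustration,
# not the package implementation):
mixup_cross_entropy_sketch <- function(input, target1, target2, weight) {
  # per-sample losses against each of the two targets
  l1 <- torch::nnf_cross_entropy(input, target1, reduction = "none")
  l2 <- torch::nnf_cross_entropy(input, target2, reduction = "none")
  # interpolate elementwise by the mixing weights, then mean-reduce
  torch::torch_mean(torch::torch_lerp(l1, l2, weight))
}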