# tests/testthat/test-distributions-bernoulli.R

#' Note: consider adopting a PyTorch-like test schema.
#' See: https://github.com/pytorch/pytorch/blob/master/test/distributions/test_distributions.py
#' TODO: add more unit tests.

test_that("Bernoulli distribution - basic tests", {
  # Three parameterisations: a batched tensor with grad, a length-1 tensor
  # with grad, and a plain R number.
  probs_vec <- torch_tensor(c(0.7, 0.2, 0.4), requires_grad = TRUE)
  probs_one <- torch_tensor(0.3, requires_grad = TRUE)
  probs_num <- 0.3

  # Sampling: `sample(n)` prepends n to the batch shape, and draws must not
  # track gradients even when the parameters do.
  expect_equal(distr_bernoulli(probs_vec)$sample(8)$size(), c(8, 3))
  expect_false(distr_bernoulli(probs_vec)$sample()$requires_grad)
  expect_equal(distr_bernoulli(probs_one)$sample(8)$size(), c(8, 1))
  expect_equal(distr_bernoulli(probs_one)$sample()$size(), 1)
  expect_equal(distr_bernoulli(probs_num)$sample()$size(), 1)

  # Reference log-probability: log(p) for a nonzero outcome, log(1 - p)
  # for a zero outcome.
  ref_log_prob <- function(idx, val, log_prob) {
    is_success <- as.logical(val != 0)
    p_idx <- probs_vec[idx]
    outcome_prob <- if (is_success) p_idx else 1 - p_idx
    expect_equal(log_prob, log(outcome_prob))
  }

  # The probs and logits parameterisations must both match the reference.
  logits_vec <- probs_vec$log() - (-probs_vec)$log1p()
  check_log_prob(distr_bernoulli(probs_vec), ref_log_prob)
  check_log_prob(distr_bernoulli(logits = logits_vec), ref_log_prob)

  # rsample() is expected to error for this distribution.
  expect_error(distr_bernoulli(probs_one)$rsample())

  # Entropy -p*log(p) - (1-p)*log(1-p), against precomputed values.
  expect_equal(
    distr_bernoulli(probs_vec)$entropy(),
    torch_tensor(c(0.6108, 0.5004, 0.6730))
  )
  expect_equal(distr_bernoulli(0)$entropy(), torch_tensor(0))
  expect_equal(distr_bernoulli(probs_num)$entropy(), torch_tensor(0.6108))
})

test_that("Bernoulli Distribution - enumerate support", {
  # Each case is list(<constructor arguments>, <expected support values>).
  support_2x1 <- matrix(c(0, 1), nrow = 2, ncol = 1)
  # NOTE(review): PyTorch's equivalent 2x2-batch case expects shape
  # (2, 1, 1); confirm how check_enumerate_support compares against this
  # (1, 2, 1) array.
  support_3d <- array(c(0, 1), dim = c(1, 2, 1))
  probs_2x2 <- matrix(c(0.1, 0.1, 0.3, 0.4), nrow = 2, ncol = 2)

  cases <- list(
    list(list(probs = 0.1), support_2x1),
    list(list(probs = c(0.1, 0.9)), support_2x1),
    list(list(probs = probs_2x2), support_3d)
  )

  check_enumerate_support(distr_bernoulli, cases)
})

test_that("Bernoulli Distribution 3D", {
  batch_dims <- c(2, 3, 5)
  half_probs <- torch_full(batch_dims, 0.5)$requires_grad_()

  # A single draw has exactly the batch shape; any requested sample
  # dimensions are prepended to it.
  expect_equal(distr_bernoulli(half_probs)$sample()$size(), batch_dims)
  expect_equal(
    distr_bernoulli(half_probs)$sample(sample_shape = c(2, 5))$size(),
    c(c(2, 5), batch_dims)
  )
  expect_equal(distr_bernoulli(half_probs)$sample(2)$size(), c(2, batch_dims))
})

test_that("Bernoulli distribution - expand", {
  # Extra batch dimensions to prepend when expanding: none, 1-D and 2-D.
  shapes <- list(NULL, 2, c(2, 1))

  d <- distr_bernoulli(torch_tensor(c(0.7, 0.2, 0.4), requires_grad = TRUE))

  for (shape in shapes) {
    # `shape` is used whole. Indexing it (e.g. `shape[[1]]`) would truncate
    # c(2, 1) to 2 and silently leave the 2-D case untested.
    expanded_shape <- c(shape, d$batch_shape)
    original_shape <- c(d$batch_shape, d$event_shape)
    expected_shape <- c(shape, original_shape)

    expanded <- d$expand(batch_shape = expanded_shape)
    draw <- expanded$sample()

    expect_equal(class(expanded), class(d))
    expect_equal(d$sample()$shape, original_shape)
    # Expanding only broadcasts the parameters, so both distributions must
    # assign the same log-probabilities to the same draw.
    expect_equal(expanded$log_prob(draw), d$log_prob(draw))
    expect_equal(draw$shape, expected_shape)
    expect_equal(expanded$batch_shape, expanded_shape)
  }
})

test_that("Bernoulli distribution - enumerate_support", {
  d <- distr_bernoulli(0.7)
  support_values <- as.numeric(as.array(d$enumerate_support()))
  # Both outcomes must be present and nothing else. A subset check alone
  # (`all(x %in% c(0, 1))`) would also pass for a support of just 0 or 1.
  expect_setequal(support_values, c(0, 1))
})

test_that("log prob is correct", {
  success_probs <- torch_rand(10)
  bern <- distr_bernoulli(probs = success_probs)

  outcomes <- torch_tensor(sample(c(0, 1), 10, replace = TRUE))

  # Reference: a Bernoulli variable is a binomial with a single trial.
  reference <- dbinom(
    as.numeric(outcomes),
    size = 1,
    prob = as.numeric(success_probs),
    log = TRUE
  )

  expect_equal_to_r(bern$log_prob(outcomes), reference, tolerance = 1e-6)
})

test_that("gradients are correct", {
  p_in <- torch_tensor(c(0.5, 0.2), requires_grad = TRUE)
  bern <- distr_bernoulli(probs = p_in)

  # Five rows of ones stacked on five rows of zeros: a 10 x 2 tensor.
  observations <- torch_cat(list(torch_ones(5, 2), torch_zeros(5, 2)))
  bern$log_prob(observations)$mean()$backward()

  # Per column, d/dp mean(log_prob) = (5 / p - 5 / (1 - p)) / 20,
  # i.e. 0 at p = 0.5 and 0.9375 at p = 0.2 (matches PyTorch).
  expect_equal_to_r(p_in$grad, c(0, 0.9375))
})

# Try the torch package in your browser
#
# Any scripts or data that you put into this service are public.
#
# torch documentation built on May 29, 2024, 9:54 a.m.