# tests/testthat/test-dist-dirichlet.R

test_that("Dirichlet distribution", {
  # Checks formatting, moments, density, support and sampling of a single
  # three-component Dirichlet distribution against closed-form expressions.
  alpha <- c(2, 5, 3)
  k <- length(alpha)
  dist <- dist_dirichlet(alpha = list(alpha))

  # format
  expect_equal(format(dist), "Dirichlet[3]")

  # stats
  alpha_0 <- sum(alpha)
  # Shared denominator of the variance and covariance formulas.
  denom <- alpha_0^2 * (alpha_0 + 1)

  expect_equal(
    mean(dist),
    matrix(alpha / alpha_0, nrow = 1)
  )

  expect_equal(
    variance(dist),
    matrix(alpha * (alpha_0 - alpha) / denom, nrow = 1)
  )

  # Cov(X_i, X_j) = (alpha_i * alpha_0 * [i == j] - alpha_i * alpha_j) / denom,
  # i.e. alpha_i * (alpha_0 - alpha_i) / denom on the diagonal and
  # -alpha_i * alpha_j / denom elsewhere.
  expected_cov <- (diag(alpha * alpha_0, nrow = k) - outer(alpha, alpha)) /
    denom
  expect_equal(covariance(dist), list(expected_cov))

  # pdf
  x <- c(0.2, 0.5, 0.3)
  x_row <- cbind(x[1], x[2], x[3])
  # Closed-form log-density, computed in base R so the check does not depend
  # on an optional package being installed:
  # log f(x) = lgamma(alpha_0) - sum(lgamma(alpha)) + sum((alpha - 1) * log(x))
  expected_log_pdf <- lgamma(alpha_0) - sum(lgamma(alpha)) +
    sum((alpha - 1) * log(x))
  expect_equal(density(dist, x_row), exp(expected_log_pdf))
  expect_equal(density(dist, x_row, log = TRUE), expected_log_pdf)

  # Cross-check against an independent implementation when it is available.
  if (requireNamespace("LaplacesDemon", quietly = TRUE)) {
    expect_equal(
      density(dist, x_row),
      LaplacesDemon::ddirichlet(x, alpha)
    )
    expect_equal(
      density(dist, x_row, log = TRUE),
      LaplacesDemon::ddirichlet(x, alpha, log = TRUE)
    )
  }

  # support
  s <- support(dist)
  expect_s3_class(s, "support_region")
  expect_equal(format(s), "[0,1]^3")

  # generate
  n_draws <- 10L
  set.seed(42)
  samples <- generate(dist, n_draws)[[1]]
  expect_equal(nrow(samples), n_draws)
  expect_equal(ncol(samples), 3L)
  # rows should sum to 1
  expect_equal(rowSums(samples), rep(1, n_draws), tolerance = 1e-10)
  # all values non-negative
  expect_true(all(samples >= 0))
})

test_that("Dirichlet distribution - symmetric case", {
  # All concentration parameters are equal, so the moments must be identical
  # across components.
  alpha <- c(1, 1, 1)
  dist <- dist_dirichlet(alpha = list(alpha))

  # uniform Dirichlet: mean is (1/3, 1/3, 1/3)
  expect_equal(mean(dist), matrix(rep(1 / 3, 3), nrow = 1), tolerance = 1e-10)

  # variance should be equal across components; comparing with expect_equal
  # shows the differing values if the expectation fails
  vars <- as.numeric(variance(dist))
  expect_equal(vars, rep(vars[1], length(vars)), tolerance = 1e-10)
})

test_that("Dirichlet distribution - two-component reduces to Beta", {
  # With two components, X1 of Dir(a, b) follows Beta(a, b), so its mean and
  # variance must agree with the Beta closed forms.
  shape <- c(3, 4)
  total <- sum(shape)
  two_part <- dist_dirichlet(alpha = list(shape))

  # Component means: a / (a + b) and b / (a + b).
  expect_equal(as.numeric(mean(two_part)), shape / total)

  # Beta(a, b) variance: a * b / ((a + b)^2 * (a + b + 1)).
  beta_var <- prod(shape) / (total^2 * (total + 1))
  expect_equal(as.numeric(variance(two_part))[1], beta_var)
})

# Try the distributional package in your browser
#
# Any scripts or data that you put into this service are public.
#
# distributional documentation built on June 11, 2026, 9:07 a.m.