# tests/testthat/test-mvn-algebra.R

# Tests for MVN algebra: conditional, affine_transform

# =============================================================================
# conditional.mvn — Schur complement
# =============================================================================

test_that("conditional.mvn with given_indices returns correct type", {
  # Fixing one coordinate of a 3-D MVN should leave a 2-D MVN behind.
  joint <- mvn(mu = c(1, 2, 3), sigma = diag(3))
  cond <- conditional(joint, given_indices = 3, given_values = 5)

  expect_true(is_mvn(cond))
  expect_equal(dim(cond), 2)
})

test_that("conditional.mvn independent components: conditioning doesn't affect mean", {
  # An identity covariance makes the components independent, so fixing X3
  # (here at 50) must leave the law of (X1, X2) untouched.
  joint <- mvn(mu = c(10, 20, 30), sigma = diag(3))
  cond <- conditional(joint, given_indices = 3, given_values = 50)

  expect_equal(mean(cond), c(10, 20))
  expect_equal(vcov(cond), diag(2))
})

test_that("conditional.mvn correlated: Schur complement correct", {
  # Bivariate normal, unit variances, cov(X1, X2) = rho.
  rho <- 0.5
  joint <- mvn(mu = c(0, 0), sigma = matrix(c(1, rho, rho, 1), 2, 2))

  # X1 | X2 = 2:
  #   mean = 0 + rho * 1 * (2 - 0) = 1
  #   var  = 1 - rho * 1 * rho     = 0.75
  cond <- conditional(joint, given_indices = 2, given_values = 2)

  expect_true(is_normal(cond))
  expect_equal(mean(cond), 1)
  expect_equal(vcov(cond), 0.75)
})

test_that("conditional.mvn 4D: condition on 2 variables", {
  # Two correlated pairs, (X1, X3) with cov 0.8 and (X2, X4) with cov 0.6.
  # All variances are 1 and every other covariance is 0.
  mu <- c(1, 2, 3, 4)
  sigma <- diag(4)
  sigma[1, 3] <- sigma[3, 1] <- 0.8
  sigma[2, 4] <- sigma[4, 2] <- 0.6
  m <- mvn(mu = mu, sigma = sigma)

  # Condition on X3 = 5, X4 = 6
  result <- conditional(m, given_indices = c(3, 4), given_values = c(5, 6))
  expect_true(is_mvn(result))
  expect_equal(dim(result), 2)

  # Sigma_22 = I, so Sigma_12 %*% solve(Sigma_22) = Sigma_12 = diag(0.8, 0.6).
  # mu_cond for X1: 1 + 0.8 * 1 * (5 - 3) = 1 + 1.6 = 2.6
  # mu_cond for X2: 2 + 0.6 * 1 * (6 - 4) = 2 + 1.2 = 3.2
  expect_equal(mean(result)[1], 2.6, tolerance = 1e-10)
  expect_equal(mean(result)[2], 3.2, tolerance = 1e-10)

  # The Schur complement itself:
  #   Sigma_11 - Sigma_12 %*% solve(Sigma_22) %*% Sigma_21
  #   = I - diag(0.8^2, 0.6^2) = diag(0.36, 0.64)
  # The two remaining variables stay uncorrelated.
  expect_equal(vcov(result), diag(c(0.36, 0.64)), tolerance = 1e-10)
})

test_that("conditional.mvn condition on all-but-one returns normal", {
  sigma <- matrix(c(1, 0.5, 0.5, 1), 2, 2)
  cond <- conditional(
    mvn(mu = c(0, 0), sigma = sigma),
    given_indices = 1,
    given_values = 1
  )

  # Only X2 is left, so the result collapses to a univariate normal:
  #   mean = 0 + 0.5 * 1 * (1 - 0) = 0.5
  #   var  = 1 - 0.5^2             = 0.75
  expect_true(is_normal(cond))
  expect_equal(mean(cond), 0.5)
  expect_equal(vcov(cond), 0.75)
})

test_that("conditional.mvn errors: condition on all variables", {
  # Fixing every coordinate would leave an empty distribution: reject it.
  joint <- mvn(mu = c(1, 2), sigma = diag(2))
  expect_error(
    conditional(joint, given_indices = c(1, 2), given_values = c(3, 4)),
    "cannot condition on all variables"
  )
})

test_that("conditional.mvn errors: mismatched indices/values", {
  # Two indices but only one value to pin them to.
  joint <- mvn(mu = c(1, 2, 3), sigma = diag(3))
  expect_error(
    conditional(joint, given_indices = c(1, 2), given_values = 3),
    "same length"
  )
})

test_that("conditional.mvn errors: out-of-range indices", {
  # A 2-D distribution has no third coordinate.
  joint <- mvn(mu = c(1, 2), sigma = diag(2))
  expect_error(
    conditional(joint, given_indices = 3, given_values = 1),
    "given_indices"
  )
})

test_that("conditional.mvn P fallback works", {
  # Supplying a predicate P instead of given_indices should produce an
  # empirical_dist. set.seed keeps the draw reproducible.
  joint <- mvn(mu = c(0, 0), sigma = diag(2))
  first_positive <- function(x) x[1] > 0
  set.seed(42)
  expect_s3_class(conditional(joint, P = first_positive), "empirical_dist")
})

test_that("conditional.mvn errors: no P or given_indices", {
  # With neither a predicate nor indices there is nothing to condition on.
  expect_error(
    conditional(mvn(mu = c(1, 2), sigma = diag(2))),
    "must provide either"
  )
})

# ---- MC validation of Schur complement ----

test_that("conditional.mvn: MC validation of closed form", {
  sigma <- matrix(c(4, 2, 2, 3), 2, 2)
  m <- mvn(mu = c(5, 10), sigma = sigma)

  # Closed form for X1 | X2 = 12:
  #   mean = 5 + (2 / 3) * (12 - 10) = 19 / 3
  #   var  = 4 - 2 * (1 / 3) * 2     = 8 / 3
  result <- conditional(m, given_indices = 2, given_values = 12)
  expect_equal(mean(result), 19 / 3)
  expect_equal(vcov(result), 8 / 3)

  # MC validation: sample from the joint, keep draws whose X2 lies within 0.5
  # of 12, then compare the empirical moments of X1 with the closed form.
  set.seed(42)
  samp <- sampler(m)(100000)
  in_window <- abs(samp[, 2] - 12) < 0.5
  x1_given <- samp[in_window, 1]

  # expect_equal()'s `tolerance` is relative, so the former tolerance = 0.3
  # accepted roughly a 30% error. Use explicit absolute bounds instead. They
  # leave room for MC noise and for the small bias caused by conditioning on
  # a finite window rather than the exact point X2 = 12.
  expect_lt(abs(mean(result) - mean(x1_given)), 0.15)
  expect_lt(abs(vcov(result) - var(x1_given)), 0.2)
})


# =============================================================================
# affine_transform
# =============================================================================

test_that("affine_transform: identity matrix on mvn returns same distribution", {
  mu <- c(1, 2)
  sigma <- matrix(c(4, 1, 1, 3), 2, 2)
  # I %*% X is X itself, so both moments must come back unchanged.
  out <- affine_transform(mvn(mu = mu, sigma = sigma), A = diag(2))

  expect_true(is_mvn(out))
  expect_equal(mean(out), mu)
  expect_equal(vcov(out), sigma)
})

test_that("affine_transform: scale matrix", {
  scale_mat <- diag(c(2, 3))
  out <- affine_transform(mvn(mu = c(1, 2), sigma = diag(2)), A = scale_mat)

  # A mu = (2 * 1, 3 * 2) and A I t(A) = diag(2^2, 3^2).
  expect_true(is_mvn(out))
  expect_equal(mean(out), c(2, 6))
  expect_equal(vcov(out), matrix(c(4, 0, 0, 9), 2, 2))
})

test_that("affine_transform: with offset b", {
  # A pure shift moves the mean by b and leaves the covariance alone.
  shift <- c(5, 10)
  out <- affine_transform(
    mvn(mu = c(0, 0), sigma = diag(2)),
    A = diag(2),
    b = shift
  )

  expect_true(is_mvn(out))
  expect_equal(mean(out), shift)
  expect_equal(vcov(out), diag(2))
})

test_that("affine_transform: projection (2D -> 1D)", {
  joint <- mvn(mu = c(3, 7), sigma = matrix(c(4, 1, 1, 2), 2, 2))
  # The 1 x 2 row vector (1, 0) selects X1, giving a univariate normal
  # with X1's marginal mean (3) and variance (4).
  pick_first <- matrix(c(1, 0), nrow = 1)
  out <- affine_transform(joint, A = pick_first)

  expect_true(is_normal(out))
  expect_equal(mean(out), 3)
  expect_equal(vcov(out), 4)
})

test_that("affine_transform: sum of components (1xd matrix)", {
  joint <- mvn(mu = c(1, 2), sigma = matrix(c(1, 0.5, 0.5, 1), 2, 2))
  ones_row <- matrix(c(1, 1), nrow = 1)
  out <- affine_transform(joint, A = ones_row)

  # X1 + X2: mean = 1 + 2 = 3, var = 1 + 1 + 2 * 0.5 = 3.
  expect_true(is_normal(out))
  expect_equal(mean(out), 3)
  expect_equal(vcov(out), 3)
})

test_that("affine_transform: works on univariate normal", {
  # Y = 2 X + 3 with X ~ N(5, 4): E[Y] = 2 * 5 + 3 = 13, Var[Y] = 2^2 * 4 = 16.
  y <- affine_transform(normal(mu = 5, var = 4), A = 2, b = 3)

  expect_true(is_normal(y))
  expect_equal(mean(y), 13)
  expect_equal(vcov(y), 16)
})

test_that("affine_transform errors: wrong dimensions", {
  # A 3 x 3 matrix cannot act on a 2-D random vector.
  joint <- mvn(mu = c(1, 2), sigma = diag(2))
  expect_error(
    affine_transform(joint, A = diag(3)),
    "ncol.*must equal dim"
  )
})

test_that("affine_transform errors: wrong b length", {
  # A is 2 x 2, so a length-3 offset does not fit the output dimension.
  joint <- mvn(mu = c(1, 2), sigma = diag(2))
  expect_error(
    affine_transform(joint, A = diag(2), b = c(1, 2, 3)),
    "length.*must equal nrow"
  )
})

test_that("affine_transform errors: non-dist input", {
  # A bare number is not a distribution object and must be rejected.
  expect_error(
    affine_transform(42, A = 1),
    "'x' must be a 'dist' object"
  )
})

test_that("affine_transform errors: unsupported dist type", {
  # Only 'normal' and 'mvn' inputs are accepted, so an exponential must fail.
  expect_error(
    affine_transform(exponential(1), A = 1),
    "'x' must be a 'normal' or 'mvn'"
  )
})

Try the algebraic.dist package in your browser

Any scripts or data that you put into this service are public.

algebraic.dist documentation built on Feb. 27, 2026, 5:06 p.m.