Nothing
# Tests for MVN algebra: conditional, affine_transform
# =============================================================================
# conditional.mvn — Schur complement
# =============================================================================
test_that("conditional.mvn with given_indices returns correct type", {
  joint <- mvn(mu = c(1, 2, 3), sigma = diag(3))
  # Fix the third coordinate at 5; two free coordinates remain.
  cond <- conditional(joint, given_indices = 3, given_values = 5)
  expect_true(is_mvn(cond))
  expect_equal(dim(cond), 2)
})
test_that("conditional.mvn independent components: conditioning doesn't affect mean", {
  # With an identity covariance the coordinates are uncorrelated, so observing
  # X3 (even at a value far from its mean) leaves (X1, X2) unchanged.
  cond <- conditional(
    mvn(mu = c(10, 20, 30), sigma = diag(3)),
    given_indices = 3,
    given_values = 50
  )
  expect_equal(mean(cond), c(10, 20))
  expect_equal(vcov(cond), diag(2))
})
test_that("conditional.mvn correlated: Schur complement correct", {
  # Unit variances with cov(X1, X2) = rho; observe X2 = 2.
  rho <- 0.5
  joint <- mvn(mu = c(0, 0), sigma = matrix(c(1, rho, rho, 1), nrow = 2))
  cond <- conditional(joint, given_indices = 2, given_values = 2)
  # Closed form: mean = 0 + rho * (2 - 0) = 1, variance = 1 - rho^2 = 0.75.
  expect_true(is_normal(cond))
  expect_equal(mean(cond), 1)
  expect_equal(vcov(cond), 0.75)
})
test_that("conditional.mvn 4D: condition on 2 variables", {
  # X1 correlates only with X3 (0.8) and X2 only with X4 (0.6). The
  # conditioning block {X3, X4} is uncorrelated with unit variances, so
  # Sigma_22^{-1} = I and the Schur complement reduces to simple products.
  mu <- c(1, 2, 3, 4)
  sigma <- diag(4)
  sigma[1, 3] <- sigma[3, 1] <- 0.8
  sigma[2, 4] <- sigma[4, 2] <- 0.6
  m <- mvn(mu = mu, sigma = sigma)
  # Condition on X3 = 5, X4 = 6
  result <- conditional(m, given_indices = c(3, 4), given_values = c(5, 6))
  expect_true(is_mvn(result))
  expect_equal(dim(result), 2)
  # mu_cond = mu_1 + Sigma_12 Sigma_22^{-1} (x_2 - mu_2):
  #   X1: 1 + 0.8 * (5 - 3) = 2.6
  #   X2: 2 + 0.6 * (6 - 4) = 3.2
  expect_equal(mean(result), c(2.6, 3.2), tolerance = 1e-10)
  # Sigma_cond = Sigma_11 - Sigma_12 Sigma_22^{-1} Sigma_21
  #            = diag(1 - 0.8^2, 1 - 0.6^2) = diag(0.36, 0.64).
  # X1 and X2 stay uncorrelated: sigma[1, 2] = 0 and the cross term vanishes.
  expect_equal(vcov(result), diag(c(0.36, 0.64)), tolerance = 1e-10)
})
test_that("conditional.mvn condition on all-but-one returns normal", {
  bivariate <- mvn(mu = c(0, 0), sigma = matrix(c(1, 0.5, 0.5, 1), nrow = 2))
  # Observing X1 = 1 leaves only X2, which comes back as a univariate normal.
  x2_given_x1 <- conditional(bivariate, given_indices = 1, given_values = 1)
  expect_true(is_normal(x2_given_x1))
  expect_equal(mean(x2_given_x1), 0.5)  # 0 + 0.5 * (1 - 0)
  expect_equal(vcov(x2_given_x1), 0.75) # 1 - 0.5^2
})
test_that("conditional.mvn errors: condition on all variables", {
  # Conditioning on every coordinate would leave nothing to describe.
  bvn <- mvn(mu = c(1, 2), sigma = diag(2))
  every_index <- c(1, 2)
  expect_error(
    conditional(bvn, given_indices = every_index, given_values = c(3, 4)),
    "cannot condition on all variables"
  )
})
test_that("conditional.mvn errors: mismatched indices/values", {
  trivariate <- mvn(mu = c(1, 2, 3), sigma = diag(3))
  # Two indices are supplied but only one observed value.
  expect_error(
    conditional(trivariate, given_indices = c(1, 2), given_values = 3),
    "same length"
  )
})
test_that("conditional.mvn errors: out-of-range indices", {
  bvn <- mvn(mu = c(1, 2), sigma = diag(2))
  # Index 3 does not exist in a 2-dimensional distribution.
  expect_error(
    conditional(bvn, given_indices = 3, given_values = 1),
    "given_indices"
  )
})
test_that("conditional.mvn P fallback works", {
  std_bvn <- mvn(mu = c(0, 0), sigma = diag(2))
  # Conditioning on a predicate (rather than indices) should yield an
  # empirical distribution; seed the RNG so the result is reproducible.
  set.seed(42)
  positive_x1 <- function(x) x[1] > 0
  cond <- conditional(std_bvn, P = positive_x1)
  expect_s3_class(cond, "empirical_dist")
})
test_that("conditional.mvn errors: no P or given_indices", {
  # With neither a predicate nor indices there is nothing to condition on.
  bvn <- mvn(mu = c(1, 2), sigma = diag(2))
  expect_error(
    conditional(bvn),
    "must provide either"
  )
})
# ---- MC validation of Schur complement ----
test_that("conditional.mvn: MC validation of closed form", {
  sigma <- matrix(c(4, 2, 2, 3), 2, 2)
  m <- mvn(mu = c(5, 10), sigma = sigma)
  # Closed form for X1 | X2 = 12:
  #   mean     = 5 + (2 / 3) * (12 - 10) = 19 / 3
  #   variance = 4 - 2^2 / 3            = 8 / 3
  result <- conditional(m, given_indices = 2, given_values = 12)
  expect_equal(mean(result), 19 / 3)
  expect_equal(vcov(result), 8 / 3)

  # MC validation: sample the joint and keep draws whose X2 lies within 0.5
  # of 12 (about 12% of draws, i.e. roughly 12,000 samples).
  set.seed(42)
  samp <- sampler(m)(100000)
  in_window <- abs(samp[, 2] - 12) < 0.5
  x1_given <- samp[in_window, 1]

  # expect_equal()'s `tolerance` is relative, so a value of 0.3 would accept
  # an absolute error of about 1.9 here. Bound the absolute error instead.
  # The MC standard error is about 0.015 for the mean and 0.035 for the
  # variance. The finite window adds a small bias (about 0.04 for each).
  expect_lt(abs(mean(result) - mean(x1_given)), 0.15)
  expect_lt(abs(vcov(result) - var(x1_given)), 0.25)
})
# =============================================================================
# affine_transform
# =============================================================================
test_that("affine_transform: identity matrix on mvn returns same distribution", {
  cov_mat <- matrix(c(4, 1, 1, 3), nrow = 2)
  source_dist <- mvn(mu = c(1, 2), sigma = cov_mat)
  # x -> I x must reproduce the input's mean and covariance.
  mapped <- affine_transform(source_dist, A = diag(2))
  expect_true(is_mvn(mapped))
  expect_equal(mean(mapped), c(1, 2))
  expect_equal(vcov(mapped), cov_mat)
})
test_that("affine_transform: scale matrix", {
  scaling <- diag(c(2, 3))
  mapped <- affine_transform(mvn(mu = c(1, 2), sigma = diag(2)), A = scaling)
  expect_true(is_mvn(mapped))
  # Mean scales elementwise, (2 * 1, 3 * 2); covariance becomes A I A^T.
  expect_equal(mean(mapped), c(2, 6))
  expect_equal(vcov(mapped), diag(c(4, 9)))
})
test_that("affine_transform: with offset b", {
  shift <- c(5, 10)
  std_bvn <- mvn(mu = c(0, 0), sigma = diag(2))
  # A pure translation moves the mean by b and leaves the covariance alone.
  mapped <- affine_transform(std_bvn, A = diag(2), b = shift)
  expect_true(is_mvn(mapped))
  expect_equal(mean(mapped), shift)
  expect_equal(vcov(mapped), diag(2))
})
test_that("affine_transform: projection (2D -> 1D)", {
  joint <- mvn(mu = c(3, 7), sigma = matrix(c(4, 1, 1, 2), nrow = 2))
  # The 1 x 2 row (1, 0) selects X1, whose marginal has mean 3, variance 4.
  pick_first <- matrix(c(1, 0), nrow = 1)
  marginal <- affine_transform(joint, A = pick_first)
  expect_true(is_normal(marginal))
  expect_equal(mean(marginal), 3)
  expect_equal(vcov(marginal), 4)
})
test_that("affine_transform: sum of components (1xd matrix)", {
  joint <- mvn(mu = c(1, 2), sigma = matrix(c(1, 0.5, 0.5, 1), nrow = 2))
  # The row of ones maps (X1, X2) to X1 + X2.
  ones_row <- matrix(1, nrow = 1, ncol = 2)
  total <- affine_transform(joint, A = ones_row)
  expect_true(is_normal(total))
  # E = 1 + 2; Var = Var(X1) + Var(X2) + 2 Cov(X1, X2) = 1 + 1 + 1.
  expect_equal(mean(total), 3)
  expect_equal(vcov(total), 3)
})
test_that("affine_transform: works on univariate normal", {
  # Y = 2 X + 3 with X ~ N(5, 4): E[Y] = 2 * 5 + 3, Var(Y) = 2^2 * 4.
  y <- affine_transform(normal(mu = 5, var = 4), A = 2, b = 3)
  expect_true(is_normal(y))
  expect_equal(mean(y), 13)
  expect_equal(vcov(y), 16)
})
test_that("affine_transform errors: wrong dimensions", {
  # A 3 x 3 matrix cannot act on a 2-dimensional distribution.
  bvn <- mvn(mu = c(1, 2), sigma = diag(2))
  expect_error(
    affine_transform(bvn, A = diag(3)),
    "ncol.*must equal dim"
  )
})
test_that("affine_transform errors: wrong b length", {
  # A 2 x 2 map produces 2 outputs, so a 3-element offset is invalid.
  bvn <- mvn(mu = c(1, 2), sigma = diag(2))
  expect_error(
    affine_transform(bvn, A = diag(2), b = c(1, 2, 3)),
    "length.*must equal nrow"
  )
})
test_that("affine_transform errors: non-dist input", {
  # A bare number is not a 'dist' object and must be rejected.
  expect_error(
    affine_transform(42, A = 1),
    "'x' must be a 'dist' object"
  )
})
test_that("affine_transform errors: unsupported dist type", {
  # Only 'normal' and 'mvn' inputs are accepted; an exponential is rejected.
  expect_error(
    affine_transform(exponential(1), A = 1),
    "'x' must be a 'normal' or 'mvn'"
  )
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.