tests/testthat/test-algebra-extended.R

# Tests for extended algebraic operations and simplification rules

# ---- Scalar multiplication ----

test_that("scalar * normal simplifies", {
  d <- 3 * normal(2, 4)
  expect_true(is_normal(d))
  expect_equal(mean(d), 6)
  expect_equal(vcov(d), 36)
})

test_that("normal * scalar simplifies", {
  d <- normal(2, 4) * 3
  expect_true(is_normal(d))
  expect_equal(mean(d), 6)
  expect_equal(vcov(d), 36)
})

test_that("scalar * gamma simplifies", {
  d <- 2 * gamma_dist(3, 4)
  expect_true(is_gamma_dist(d))
  expect_equal(d$shape, 3)
  expect_equal(d$rate, 2)  # 4/2
})

test_that("scalar * exponential -> gamma", {
  d <- 2 * exponential(3)
  expect_true(is_gamma_dist(d))
  expect_equal(d$shape, 1)
  expect_equal(d$rate, 1.5)  # 3/2
})

test_that("negative scalar * gamma stays edist", {
  d <- (-2) * gamma_dist(3, 4)
  expect_true(is_edist(d))
})

test_that("dist * dist stays edist", {
  d <- normal(0, 1) * normal(0, 1)
  expect_true(is_edist(d))
})

# ---- Location shift (dist + scalar, scalar + dist, dist - scalar) ----

test_that("normal + scalar simplifies", {
  d <- normal(2, 4) + 3
  expect_true(is_normal(d))
  expect_equal(mean(d), 5)
  expect_equal(vcov(d), 4)
})

test_that("scalar + normal simplifies", {
  d <- 3 + normal(2, 4)
  expect_true(is_normal(d))
  expect_equal(mean(d), 5)
  expect_equal(vcov(d), 4)
})

test_that("normal - scalar simplifies", {
  d <- normal(5, 4) - 3
  expect_true(is_normal(d))
  expect_equal(mean(d), 2)
  expect_equal(vcov(d), 4)
})

test_that("exp + scalar stays edist for non-normal", {
  d <- exponential(1) + 5
  expect_true(is_edist(d))
})

# ---- Gamma / Exponential addition rules ----

test_that("gamma + gamma same rate simplifies", {
  d <- gamma_dist(2, 3) + gamma_dist(4, 3)
  expect_true(is_gamma_dist(d))
  expect_equal(d$shape, 6)
  expect_equal(d$rate, 3)
})

test_that("gamma + gamma different rate stays edist", {
  d <- gamma_dist(2, 3) + gamma_dist(4, 5)
  expect_true(is_edist(d))
})

test_that("exp + exp same rate -> gamma(2, rate)", {
  d <- exponential(3) + exponential(3)
  expect_true(is_gamma_dist(d))
  expect_equal(d$shape, 2)
  expect_equal(d$rate, 3)
})

test_that("exp + exp different rate stays edist", {
  d <- exponential(3) + exponential(5)
  expect_true(is_edist(d))
})

test_that("gamma + exp same rate simplifies", {
  d <- gamma_dist(3, 2) + exponential(2)
  expect_true(is_gamma_dist(d))
  expect_equal(d$shape, 4)
  expect_equal(d$rate, 2)
})

test_that("exp + gamma same rate simplifies (reverse order)", {
  d <- exponential(2) + gamma_dist(3, 2)
  expect_true(is_gamma_dist(d))
  expect_equal(d$shape, 4)
  expect_equal(d$rate, 2)
})

# ---- Chi-squared addition ----

test_that("chi_squared + chi_squared simplifies", {
  d <- chi_squared(3) + chi_squared(5)
  expect_true(is_chi_squared(d))
  expect_equal(d$df, 8)
})

# ---- Poisson addition ----

test_that("poisson + poisson simplifies", {
  d <- poisson_dist(3) + poisson_dist(5)
  expect_true(is_poisson_dist(d))
  expect_equal(d$lambda, 8)
})

# ---- Power operator ----

test_that("N(0,1)^2 -> chi_squared(1)", {
  d <- normal(0, 1) ^ 2
  expect_true(is_chi_squared(d))
  expect_equal(d$df, 1)
})

test_that("N(0,1)^2 + N(0,1)^2 -> chi_squared(2) (chaining)", {
  d <- normal(0, 1)^2 + normal(0, 1)^2
  expect_true(is_chi_squared(d))
  expect_equal(d$df, 2)
})

test_that("non-standard normal ^2 stays edist", {
  d <- normal(1, 1) ^ 2
  expect_true(is_edist(d))
})

# ---- Math group generic ----

test_that("exp(normal) -> lognormal", {
  d <- exp(normal(2, 3))
  expect_true(is_lognormal(d))
  expect_equal(d$meanlog, 2)
  expect_equal(d$sdlog, sqrt(3))
})

test_that("log(lognormal) -> normal", {
  d <- log(lognormal(2, 3))
  expect_true(is_normal(d))
  expect_equal(mean(d), 2)
  expect_equal(vcov(d), 9)  # 3^2
})

test_that("sqrt(normal) stays edist", {
  d <- sqrt(normal(4, 1))
  expect_true(is_edist(d))
})

test_that("abs(normal) stays edist", {
  d <- abs(normal(0, 1))
  expect_true(is_edist(d))
})

# ---- Summary group generic ----

test_that("sum of normals simplifies via Summary.dist", {
  d <- sum(normal(1, 1), normal(2, 3), normal(3, 5))
  expect_true(is_normal(d))
  expect_equal(mean(d), 6)
  expect_equal(vcov(d), 9)
})

test_that("min of exponentials simplifies", {
  d <- min(exponential(1), exponential(2), exponential(3))
  expect_true(is_exponential(d))
  expect_equal(d$rate, 6)
})

test_that("min of mixed stays edist", {
  d <- min(exponential(1), normal(1, 1))
  expect_true(is_edist(d))
})

test_that("max stays edist", {
  d <- max(exponential(1), exponential(2))
  expect_true(is_edist(d))
})

test_that("prod creates edist", {
  d <- prod(normal(1, 1), normal(2, 1))
  expect_true(is_edist(d))
})

test_that("single element sum returns self", {
  n <- normal(1, 1)
  d <- sum(n)
  expect_true(is_normal(d))
  expect_equal(mean(d), 1)
})

test_that("single element min returns self", {
  e <- exponential(2)
  d <- min(e)
  expect_true(is_exponential(d))
  expect_equal(d$rate, 2)
})

# ---- Sampler n=1 bug fix ----

test_that("sampler.edist works with n=1", {
  x <- normal(0, 1)
  y <- normal(0, 1)
  ed <- edist(quote(x + y), list(x = x, y = y))
  samp <- sampler(ed)(1)
  expect_type(samp, "double")
  expect_length(samp, 1)
})

# ---- Chaining / composition ----

test_that("3 * normal(0,1)^2 works via chaining", {
  # normal(0,1)^2 = chi_squared(1)
  # 3 * chi_squared(1) = ?  -- no rule for scalar * chi_squared, stays edist
  # but this tests that the chain works without errors
  d <- 3 * normal(0, 1)^2
  # normal(0,1)^2 simplifies to chi_squared(1), then 3 * chi_squared(1) is edist
  expect_true(is_edist(d) || is_chi_squared(d) || is_gamma_dist(d))
})

test_that("sum of 4 exponentials (same rate) via Summary", {
  d <- sum(exponential(2), exponential(2), exponential(2), exponential(2))
  # Reduce with + chains: ((exp+exp)+exp)+exp
  # exp(2)+exp(2) -> gamma(2,2)
  # gamma(2,2)+exp(2) -> gamma(3,2)
  # gamma(3,2)+exp(2) -> gamma(4,2)
  expect_true(is_gamma_dist(d))
  expect_equal(d$shape, 4)
  expect_equal(d$rate, 2)
})

test_that("zero scalar multiplication of normal", {
  d <- 0 * normal(5, 3)
  expect_true(is_normal(d))
  expect_equal(mean(d), 0)
  expect_equal(vcov(d), 0)
})

test_that("identity scalar multiplication of normal", {
  d <- 1 * normal(5, 3)
  expect_true(is_normal(d))
  expect_equal(mean(d), 5)
  expect_equal(vcov(d), 3)
})

test_that("zero location shift of normal", {
  d <- normal(5, 3) + 0
  expect_true(is_normal(d))
  expect_equal(mean(d), 5)
  expect_equal(vcov(d), 3)
})

# ---- Uniform location shift rules ----

test_that("uniform + scalar simplifies", {
  d <- uniform_dist(2, 5) + 3
  expect_true(is_uniform_dist(d))
  expect_equal(d$min, 5)
  expect_equal(d$max, 8)
})

test_that("scalar + uniform simplifies", {
  d <- 3 + uniform_dist(2, 5)
  expect_true(is_uniform_dist(d))
  expect_equal(d$min, 5)
  expect_equal(d$max, 8)
})

test_that("uniform - scalar simplifies", {
  d <- uniform_dist(2, 5) - 1
  expect_true(is_uniform_dist(d))
  expect_equal(d$min, 1)
  expect_equal(d$max, 4)
})

# ---- Uniform scale rules ----

test_that("positive scalar * uniform simplifies", {
  d <- 2 * uniform_dist(1, 4)
  expect_true(is_uniform_dist(d))
  expect_equal(d$min, 2)
  expect_equal(d$max, 8)
})

test_that("negative scalar * uniform flips bounds", {
  d <- (-2) * uniform_dist(1, 4)
  expect_true(is_uniform_dist(d))
  expect_equal(d$min, -8)
  expect_equal(d$max, -2)
})

test_that("uniform / scalar simplifies via multiplication", {
  d <- uniform_dist(2, 6) / 2
  expect_true(is_uniform_dist(d))
  expect_equal(d$min, 1)
  expect_equal(d$max, 3)
})

# ---- Unary negation of uniform ----

test_that("-uniform flips bounds", {
  d <- -uniform_dist(1, 4)
  expect_true(is_uniform_dist(d))
  expect_equal(d$min, -4)
  expect_equal(d$max, -1)
})

# ---- Weibull scale rule ----

test_that("positive scalar * weibull simplifies", {
  d <- 3 * weibull_dist(shape = 2, scale = 5)
  expect_true(is_weibull_dist(d))
  expect_equal(d$shape, 2)
  expect_equal(d$scale, 15)
})

test_that("weibull / scalar simplifies", {
  d <- weibull_dist(shape = 2, scale = 6) / 3
  expect_true(is_weibull_dist(d))
  expect_equal(d$shape, 2)
  expect_equal(d$scale, 2)
})

test_that("negative scalar * weibull stays edist", {
  d <- (-2) * weibull_dist(shape = 2, scale = 5)
  expect_true(is_edist(d))
})

# ---- ChiSq scale rule ----

test_that("positive scalar * chi_squared -> gamma", {
  d <- 3 * chi_squared(4)
  expect_true(is_gamma_dist(d))
  expect_equal(d$shape, 2)    # df/2 = 4/2
  expect_equal(d$rate, 1/6)   # 1/(2*3)
})

test_that("chi_squared / scalar -> gamma", {
  d <- chi_squared(6) / 2
  expect_true(is_gamma_dist(d))
  expect_equal(d$shape, 3)    # 6/2
  expect_equal(d$rate, 1)     # 1/(2*0.5)
})

# ---- LogNormal scale rule ----

test_that("positive scalar * lognormal simplifies", {
  d <- 2 * lognormal(3, 1)
  expect_true(is_lognormal(d))
  expect_equal(d$meanlog, 3 + log(2))
  expect_equal(d$sdlog, 1)
})

test_that("lognormal / scalar simplifies", {
  d <- lognormal(3, 1) / 2
  expect_true(is_lognormal(d))
  expect_equal(d$meanlog, 3 + log(0.5))
  expect_equal(d$sdlog, 1)
})

test_that("negative scalar * lognormal stays edist", {
  d <- (-2) * lognormal(3, 1)
  expect_true(is_edist(d))
})

# ---- LogNormal * LogNormal rule ----

test_that("lognormal * lognormal simplifies", {
  d <- lognormal(1, 2) * lognormal(3, 4)
  expect_true(is_lognormal(d))
  expect_equal(d$meanlog, 4)
  expect_equal(d$sdlog, sqrt(4 + 16))
})

test_that("prod of lognormals chains correctly", {
  d <- prod(lognormal(1, 1), lognormal(2, 1), lognormal(3, 1))
  expect_true(is_lognormal(d))
  expect_equal(d$meanlog, 6)
  expect_equal(d$sdlog, sqrt(3))
})

# ---- Division operator ----

test_that("scalar / dist creates edist", {
  d <- 1 / normal(1, 1)
  expect_true(is_edist(d))
})

test_that("dist / dist creates edist", {
  d <- normal(1, 1) / normal(2, 1)
  expect_true(is_edist(d))
})

# ---- MC validation of new rules ----

test_that("uniform shift: MC moments match closed form", {
  d_unsimplified <- edist(quote(x + 3), list(x = uniform_dist(0, 1)))
  d_simplified <- uniform_dist(3, 4)

  set.seed(42)
  s1 <- sampler(d_unsimplified)(50000)
  expect_equal(mean(s1), mean(d_simplified), tolerance = 0.05)
  expect_equal(var(s1), vcov(d_simplified), tolerance = 0.05)
})

test_that("weibull scale: MC moments match closed form", {
  d_simplified <- 3 * weibull_dist(shape = 2, scale = 5)

  set.seed(42)
  # Sample from unsimplified version
  raw <- 3 * rweibull(50000, shape = 2, scale = 5)
  expect_equal(mean(raw), mean(d_simplified), tolerance = 0.2)
})

test_that("lognormal product: MC moments match closed form", {
  d_simplified <- lognormal(1, 2) * lognormal(3, 4)

  set.seed(42)
  raw <- rlnorm(50000, 1, 2) * rlnorm(50000, 3, 4)
  expect_equal(mean(log(raw)), d_simplified$meanlog, tolerance = 0.1)
})

Try the algebraic.dist package in your browser

Any scripts or data that you put into this service are public.

algebraic.dist documentation built on Feb. 27, 2026, 5:06 p.m.