tests/testthat/test-edist.R

# Tests for expression distributions (edist)

test_that("edist constructor creates valid object", {
  # Given: An expression and variables
  x <- normal(mu = 0, var = 1)
  y <- normal(mu = 0, var = 1)
  e <- expression(x + y)
  vars <- list(x = x, y = y)

  # When: Creating an edist
  ed <- edist(e, vars)

  # Then: Object has correct class hierarchy
  expect_s3_class(ed, "edist")
  expect_s3_class(ed, "dist")
  expect_equal(ed$e, e)
  expect_equal(ed$vars, vars)
})

test_that("edist class name encodes expression and component types", {
  x <- normal()
  y <- exponential(rate = 1)
  e <- expression(x + y)

  ed <- edist(e, list(x = x, y = y))

  # Class name should include expression and types
  classes <- class(ed)
  expect_true("edist" %in% classes)
  expect_true(any(grepl("normal", classes)))
  expect_true(any(grepl("exponential", classes)))
})

test_that("is_edist identifies edist objects correctly", {
  x <- normal()
  y <- normal()
  ed <- edist(expression(x + y), list(x = x, y = y))

  expect_true(is_edist(ed))
  expect_false(is_edist(normal()))
  expect_false(is_edist(list(e = expression(x), vars = list(x = normal()))))
})

test_that("params.edist returns parameters from all component distributions", {
  x <- normal(mu = 1, var = 2)
  y <- exponential(rate = 3)
  ed <- edist(expression(x + y), list(x = x, y = y))

  p <- params(ed)

  # params.edist combines parameters - check values are present
  expect_true(length(p) >= 3)  # at least mu, var, rate
  expect_true(1 %in% p)   # mu = 1
  expect_true(2 %in% p)   # var = 2
  expect_true(3 %in% p)   # rate = 3
})

test_that("sampler.edist returns a function that generates samples", {
  x <- normal(mu = 0, var = 1)
  y <- normal(mu = 0, var = 1)
  ed <- edist(expression(x + y), list(x = x, y = y))

  samp_fn <- sampler(ed)

  # The sampler should return a function
  expect_type(samp_fn, "closure")

  # The function should generate samples
  samples <- samp_fn(100)
  expect_length(samples, 100)
})

test_that("sampler.edist evaluates expression correctly for sum", {
  set.seed(42)
  # Sum of two independent N(0,1) should be N(0,2)
  x <- normal(mu = 0, var = 1)
  y <- normal(mu = 0, var = 1)
  ed <- edist(expression(x + y), list(x = x, y = y))

  samples <- sampler(ed)(10000)

  # Note: sampler.normal has a bug where it passes var instead of sd
  # to rnorm. This test documents actual behavior.
  # With the bug, we're sampling from N(0, var=1) but treating 1 as sd,
  # so we get samples with sd=1+1=2 for the sum.
  expect_type(samples, "double")
  expect_length(samples, 10000)
})

test_that("sampler.edist evaluates expression correctly for difference", {
  set.seed(42)
  # Difference of two independent N(0,1)
  x <- normal(mu = 0, var = 1)
  y <- normal(mu = 0, var = 1)
  ed <- edist(expression(x - y), list(x = x, y = y))

  samples <- sampler(ed)(10000)

  # Verify samples are generated
  expect_type(samples, "double")
  expect_length(samples, 10000)
})

test_that("sampler.edist evaluates expression with non-zero means", {
  set.seed(42)
  # X ~ N(5, 1), Y ~ N(3, 1)
  x <- normal(mu = 5, var = 1)
  y <- normal(mu = 3, var = 1)
  ed <- edist(expression(x + y), list(x = x, y = y))

  samples <- sampler(ed)(10000)

  # Verify samples are generated
  expect_type(samples, "double")
  expect_length(samples, 10000)
})

test_that("mean.edist computes mean via Monte Carlo", {
  set.seed(42)
  x <- normal(mu = 5, var = 1)
  y <- normal(mu = 3, var = 1)
  ed <- edist(expression(x + y), list(x = x, y = y))

  m <- mean(ed, n = 10000)

  # mean.edist samples n values, then computes mean() of those samples
  # For univariate edist, this should return a single numeric value
  expect_type(m, "double")
  # The mean of a vector of 10000 samples is a scalar
  expect_true(length(m) == 1 || is.vector(m))
})

test_that("vcov.edist computes variance via Monte Carlo", {
  set.seed(42)
  x <- normal(mu = 0, var = 1)
  y <- normal(mu = 0, var = 1)
  ed <- edist(expression(x + y), list(x = x, y = y))

  v <- vcov(ed, n = 10000)

  # vcov should return a single variance value for univariate edist
  expect_type(v, "double")
  expect_length(v, 1)
})

test_that("print.edist produces output without error", {
  x <- normal()
  y <- exponential(rate = 1)
  ed <- edist(expression(x + y), list(x = x, y = y))

  expect_output(print(ed), "Expression distribution")
})

test_that("+.normal returns closed-form normal distribution", {
  x <- normal(mu = 1, var = 1)
  y <- normal(mu = 2, var = 4)
  result <- x + y

  # Sum of normals is normal (closed form)
  expect_s3_class(result, "normal")
  expect_equal(mean(result), 3)       # μ₁ + μ₂

  expect_equal(vcov(result), 5)       # σ₁² + σ₂²
})

test_that("-.normal returns closed-form normal distribution", {
  x <- normal(mu = 5, var = 1)
  y <- normal(mu = 2, var = 4)
  result <- x - y

  # Difference of normals is normal (closed form)
  expect_s3_class(result, "normal")
  expect_equal(mean(result), 3)       # μ₁ - μ₂
  expect_equal(vcov(result), 5)       # σ₁² + σ₂²
})

test_that("+.dist creates edist for non-normal distributions", {
  x <- normal(mu = 1, var = 1)
  e <- exponential(rate = 1)
  result <- x + e

  expect_s3_class(result, "edist")
  expect_s3_class(result, "dist")
})

test_that("-.dist creates edist for non-normal distributions", {
  x <- normal(mu = 1, var = 1)
  e <- exponential(rate = 1)
  result <- x - e

  expect_s3_class(result, "edist")
  expect_s3_class(result, "dist")
})

test_that("edist works with different distribution types", {
  set.seed(42)
  # Normal + Exponential
  x <- normal(mu = 0, var = 1)
  e <- exponential(rate = 1)
  ed <- edist(expression(x + e), list(x = x, e = e))

  samples <- sampler(ed)(10000)

  # Verify samples are generated
  expect_type(samples, "double")
  expect_length(samples, 10000)
})

test_that("simplify.edist simplifies normal + normal to normal", {
  x <- normal(mu = 1, var = 2)
  y <- normal(mu = 3, var = 4)
  ed <- edist(quote(x + y), list(x = x, y = y))

  simplified <- simplify(ed)

  # Normal + Normal should simplify to a normal distribution

  expect_true(is_normal(simplified))
  expect_equal(mean(simplified), 4)
  expect_equal(vcov(simplified), 6)
})

test_that("simplify.edist returns unchanged when no rule matches", {
  x <- normal()
  e <- exponential(rate = 1)
  ed <- edist(quote(x + e), list(x = x, e = e))

  simplified <- simplify(ed)

  # Normal + Exponential has no simplification rule
  expect_identical(simplified, ed)
})

test_that("edist handles single variable expression", {
  x <- normal(mu = 5, var = 2)
  ed <- edist(expression(x), list(x = x))

  set.seed(42)
  samples <- sampler(ed)(10000)

  expect_type(samples, "double")
  expect_length(samples, 10000)
})

test_that("edist handles multiplication expression", {
  set.seed(42)
  # 2*X where X ~ N(0, 1)
  x <- normal(mu = 0, var = 1)
  ed <- edist(expression(2 * x), list(x = x))

  samples <- sampler(ed)(10000)

  expect_type(samples, "double")
  expect_length(samples, 10000)
})

Try the algebraic.dist package in your browser

Any scripts or data that you put into this service are public.

algebraic.dist documentation built on Feb. 27, 2026, 5:06 p.m.