# tests/testthat/test-math.R

# Test all math functions against numerical central differences.
# Each test verifies that dual-mode derivative matches the numerical estimate.

# Tolerance passed to expect_equal() when comparing the dual-mode derivative
# with its numerical estimate.
tol_deriv <- 1e-6

# Helper: compare the derivative of a unary function at a single point.
#
# f     - function under test; it is called both on a dual number and via
#         central_difference(), so it must accept either.
# x     - point at which the derivative is checked.
# label - name shown in the failure message; defaults to the expression the
#         caller supplied for `f`.
check_deriv <- function(f, x, label = deparse(substitute(f))) {
  # dual(x, 1): the second argument is the derivative component, so the
  # result's derivative is f'(x).
  ad <- deriv(f(dual(x, 1)))
  num <- central_difference(f, x)
  failure_label <- sprintf(
    "%s'(%g): AD=%g, numerical=%g", label, x, ad, num
  )
  expect_equal(ad, num, tolerance = tol_deriv, label = failure_label)
}

# -- Exponential / Logarithm --------------------------------------------------

# Each test walks over a handful of evaluation points and compares the
# dual-mode derivative with the central-difference estimate.

test_that("exp derivative", {
  for (x0 in c(1, -2, 0)) {
    check_deriv(exp, x0)
  }
})

test_that("expm1 derivative", {
  for (x0 in c(0.01, -0.5)) {
    check_deriv(expm1, x0)
  }
})

test_that("log derivative", {
  for (x0 in c(1, 2.5, 0.1)) {
    check_deriv(log, x0)
  }
})

test_that("log2 derivative", {
  for (x0 in c(1, 4)) {
    check_deriv(log2, x0)
  }
})

test_that("log10 derivative", {
  for (x0 in c(1, 100)) {
    check_deriv(log10, x0)
  }
})

test_that("log1p derivative", {
  for (x0 in c(0.001, 1)) {
    check_deriv(log1p, x0)
  }
})

test_that("log with base argument", {
  # Wrap log() so the two-argument form (explicit base) is exercised.
  log_base3 <- function(x) log(x, base = 3)
  for (x0 in c(2, 9)) {
    check_deriv(log_base3, x0, "log_base3")
  }
})

# -- Square root ---------------------------------------------------------------

test_that("sqrt derivative", {
  # One point above 1 and one below.
  for (x0 in c(4, 0.25)) {
    check_deriv(sqrt, x0)
  }
})

# -- Trigonometric --------------------------------------------------------------

test_that("sin derivative", {
  for (angle in c(0, pi / 4, pi)) {
    check_deriv(sin, angle)
  }
})

test_that("cos derivative", {
  for (angle in c(0, pi / 4, pi)) {
    check_deriv(cos, angle)
  }
})

test_that("tan derivative", {
  for (angle in c(0, pi / 4, 0.5)) {
    check_deriv(tan, angle)
  }
})

# -- Inverse trigonometric ------------------------------------------------------

test_that("asin derivative", {
  # All points lie strictly inside (-1, 1).
  for (x0 in c(0, 0.5, -0.5)) {
    check_deriv(asin, x0)
  }
})

test_that("acos derivative", {
  # All points lie strictly inside (-1, 1).
  for (x0 in c(0, 0.5, -0.5)) {
    check_deriv(acos, x0)
  }
})

test_that("atan derivative", {
  for (x0 in c(0, 1, -3)) {
    check_deriv(atan, x0)
  }
})

# -- Hyperbolic ----------------------------------------------------------------

test_that("sinh derivative", {
  for (x0 in c(0, 1)) {
    check_deriv(sinh, x0)
  }
})

test_that("cosh derivative", {
  for (x0 in c(0, 1)) {
    check_deriv(cosh, x0)
  }
})

test_that("tanh derivative", {
  for (x0 in c(0, 1, -2)) {
    check_deriv(tanh, x0)
  }
})

# -- Inverse hyperbolic --------------------------------------------------------

test_that("asinh derivative", {
  for (x0 in c(0, 1, -2)) {
    check_deriv(asinh, x0)
  }
})

test_that("acosh derivative", {
  # Both points are greater than 1.
  for (x0 in c(1.5, 3)) {
    check_deriv(acosh, x0)
  }
})

test_that("atanh derivative", {
  # All points lie strictly inside (-1, 1).
  for (x0 in c(0, 0.5, -0.3)) {
    check_deriv(atanh, x0)
  }
})

# -- Gamma-related -------------------------------------------------------------

test_that("gamma derivative", {
  for (x0 in c(1, 2.5, 0.5)) {
    check_deriv(gamma, x0)
  }
})

test_that("lgamma derivative", {
  for (x0 in c(1, 2.5, 5)) {
    check_deriv(lgamma, x0)
  }
})

test_that("digamma derivative", {
  for (x0 in c(1, 2.5, 5)) {
    check_deriv(digamma, x0)
  }
})

test_that("trigamma derivative", {
  for (x0 in c(1, 2.5, 5)) {
    check_deriv(trigamma, x0)
  }
})

# -- Absolute value / sign ------------------------------------------------------

test_that("abs derivative", {
  # One point on each side of zero.
  for (x0 in c(3, -3)) {
    check_deriv(abs, x0)
  }
})

test_that("sign derivative is zero", {
  # Checked directly at a positive point rather than via check_deriv().
  sgn <- sign(dual(3, 1))
  expect_equal(value(sgn), 1)
  expect_equal(deriv(sgn), 0)
})

# -- Floor / ceiling / round ---------------------------------------------------

test_that("floor derivative is zero", {
  # Evaluated at a non-integer point; the derivative is expected to be 0.
  floored <- floor(dual(3.7, 1))
  expect_equal(value(floored), 3)
  expect_equal(deriv(floored), 0)
})

test_that("ceiling derivative is zero", {
  # Evaluated at a non-integer point; the derivative is expected to be 0.
  ceiled <- ceiling(dual(3.2, 1))
  expect_equal(value(ceiled), 4)
  expect_equal(deriv(ceiled), 0)
})

# -- atan2 ---------------------------------------------------------------------

test_that("atan2 derivative", {
  # Only the first argument is dual. With the second argument fixed at 2,
  # d/dy atan2(y, 2) = 2 / (y^2 + 4), which is 2/5 at y = 1.
  angle <- atan2(dual(1, 1), 2)
  expect_equal(value(angle), atan2(1, 2))
  expect_equal(deriv(angle), 2 / 5, tolerance = 1e-10)
})

test_that("atan2 both dual", {
  # The same dual is passed for both arguments. atan2(x, x) equals pi/4 at
  # x = 1 and the two partial contributions cancel:
  # (x * 1 - x * 1) / (x^2 + x^2) = 0.
  seed <- dual(1, 1)
  angle <- atan2(seed, seed)
  expect_equal(value(angle), pi / 4)
  expect_equal(deriv(angle), 0, tolerance = 1e-10)
})

# -- Compositions ---------------------------------------------------------------

# Composite functions: each is defined inline and handed to check_deriv()
# with an explicit label, since the default label would just be the local
# variable name.

test_that("log(sin(x) + 2) derivative", {
  composite <- function(x) log(sin(x) + 2)
  check_deriv(composite, 1.0, "log(sin(x)+2)")
})

test_that("exp(cos(x) * x) derivative", {
  composite <- function(x) exp(cos(x) * x)
  check_deriv(composite, 0.5, "exp(cos(x)*x)")
})

test_that("sqrt(1 + x^2) derivative", {
  composite <- function(x) sqrt(1 + x^2)
  for (x0 in c(2, 0)) {
    check_deriv(composite, x0, "sqrt(1+x^2)")
  }
})

test_that("lgamma(exp(x)) derivative", {
  composite <- function(x) lgamma(exp(x))
  check_deriv(composite, 0.5, "lgamma(exp(x))")
})

# -- max / min ------------------------------------------------------------------

# The operands carry distinct derivative components (1 and 2) so the
# assertions show which operand was selected.

test_that("max selects correct branch", {
  largest <- max(dual(3, 1), dual(5, 2))
  expect_equal(value(largest), 5)
  expect_equal(deriv(largest), 2)
})

test_that("min selects correct branch", {
  smallest <- min(dual(3, 1), dual(5, 2))
  expect_equal(value(smallest), 3)
  expect_equal(deriv(smallest), 1)
})

# -- max / min with 3+ arguments (bug fix regression tests) ------------------

# Each operand has a unique derivative component (10, 20, 30), so the
# expected derivative identifies exactly which operand was returned.

test_that("max with 3 dual arguments selects the largest", {
  largest <- max(dual(3, 10), dual(7, 20), dual(5, 30))
  expect_equal(value(largest), 7)
  expect_equal(deriv(largest), 20)
})

test_that("max with 3 dual arguments when last is largest", {
  largest <- max(dual(1, 10), dual(2, 20), dual(9, 30))
  expect_equal(value(largest), 9)
  expect_equal(deriv(largest), 30)
})

test_that("min with 3 dual arguments selects the smallest", {
  smallest <- min(dual(3, 10), dual(7, 20), dual(5, 30))
  expect_equal(value(smallest), 3)
  expect_equal(deriv(smallest), 10)
})

test_that("min with 3 dual arguments when last is smallest", {
  smallest <- min(dual(5, 10), dual(3, 20), dual(1, 30))
  expect_equal(value(smallest), 1)
  expect_equal(deriv(smallest), 30)
})

test_that("max with mixed dual and numeric args", {
  # A plain numeric wins: the result should still be a dual, with
  # derivative 0.
  largest <- max(dual(3, 1), 7, 5)
  expect_true(is_dual(largest))
  expect_equal(value(largest), 7)
  expect_equal(deriv(largest), 0)
})

test_that("min with mixed dual and numeric args", {
  # A plain numeric wins: the result should still be a dual, with
  # derivative 0.
  smallest <- min(dual(3, 1), 1, 5)
  expect_true(is_dual(smallest))
  expect_equal(value(smallest), 1)
  expect_equal(deriv(smallest), 0)
})

# -- NaN / Inf propagation ----------------------------------------------------

# Only the value component is inspected in these tests; the derivative
# component is not asserted.

test_that("NaN propagates through dual arithmetic", {
  total <- dual(NaN, 1) + dual(1, 1)
  expect_true(is.nan(value(total)))
})

test_that("Inf propagates through exp of large dual", {
  # exp(1000) overflows double precision.
  overflowed <- exp(dual(1000, 1))
  expect_true(is.infinite(value(overflowed)))
})

test_that("log of negative dual produces NaN", {
  out_of_domain <- log(dual(-1, 1))
  expect_true(is.nan(value(out_of_domain)))
})

# -- cumprod now errors --------------------------------------------------------

test_that("cumprod on dual throws an error", {
  # The error message must mention cumprod and "not supported".
  expect_error(
    cumprod(dual(3, 1)),
    regexp = "cumprod.*not supported"
  )
})

# Try the nabla package in your browser
#
# Any scripts or data that you put into this service are public.
#
# nabla documentation built on Feb. 11, 2026, 1:06 a.m.