Nothing
# Test all math functions against numerical central differences.
# Each test verifies that dual-mode derivative matches the numerical estimate.
tol_deriv <- 1e-6 # tolerance for derivative comparison
# Helper: check a unary function's derivative at a point
check_deriv <- function(f, x, label = deparse(substitute(f))) {
d <- f(dual(x, 1))
num <- central_difference(f, x)
expect_equal(
deriv(d), num,
tolerance = tol_deriv,
label = sprintf("%s'(%g): AD=%g, numerical=%g", label, x, deriv(d), num)
)
}
# -- Exponential / Logarithm --------------------------------------------------
test_that("exp derivative", {
check_deriv(exp, 1)
check_deriv(exp, -2)
check_deriv(exp, 0)
})
test_that("expm1 derivative", {
check_deriv(expm1, 0.01)
check_deriv(expm1, -0.5)
})
test_that("log derivative", {
check_deriv(log, 1)
check_deriv(log, 2.5)
check_deriv(log, 0.1)
})
test_that("log2 derivative", {
check_deriv(log2, 1)
check_deriv(log2, 4)
})
test_that("log10 derivative", {
check_deriv(log10, 1)
check_deriv(log10, 100)
})
test_that("log1p derivative", {
check_deriv(log1p, 0.001)
check_deriv(log1p, 1)
})
test_that("log with base argument", {
f <- function(x) log(x, base = 3)
check_deriv(f, 2, "log_base3")
check_deriv(f, 9, "log_base3")
})
# -- Square root ---------------------------------------------------------------
test_that("sqrt derivative", {
check_deriv(sqrt, 4)
check_deriv(sqrt, 0.25)
})
# -- Trigonometric --------------------------------------------------------------
test_that("sin derivative", {
check_deriv(sin, 0)
check_deriv(sin, pi / 4)
check_deriv(sin, pi)
})
test_that("cos derivative", {
check_deriv(cos, 0)
check_deriv(cos, pi / 4)
check_deriv(cos, pi)
})
test_that("tan derivative", {
check_deriv(tan, 0)
check_deriv(tan, pi / 4)
check_deriv(tan, 0.5)
})
# -- Inverse trigonometric ------------------------------------------------------
test_that("asin derivative", {
check_deriv(asin, 0)
check_deriv(asin, 0.5)
check_deriv(asin, -0.5)
})
test_that("acos derivative", {
check_deriv(acos, 0)
check_deriv(acos, 0.5)
check_deriv(acos, -0.5)
})
test_that("atan derivative", {
check_deriv(atan, 0)
check_deriv(atan, 1)
check_deriv(atan, -3)
})
# -- Hyperbolic ----------------------------------------------------------------
test_that("sinh derivative", {
check_deriv(sinh, 0)
check_deriv(sinh, 1)
})
test_that("cosh derivative", {
check_deriv(cosh, 0)
check_deriv(cosh, 1)
})
test_that("tanh derivative", {
check_deriv(tanh, 0)
check_deriv(tanh, 1)
check_deriv(tanh, -2)
})
# -- Inverse hyperbolic --------------------------------------------------------
test_that("asinh derivative", {
check_deriv(asinh, 0)
check_deriv(asinh, 1)
check_deriv(asinh, -2)
})
test_that("acosh derivative", {
check_deriv(acosh, 1.5)
check_deriv(acosh, 3)
})
test_that("atanh derivative", {
check_deriv(atanh, 0)
check_deriv(atanh, 0.5)
check_deriv(atanh, -0.3)
})
# -- Gamma-related -------------------------------------------------------------
test_that("gamma derivative", {
check_deriv(gamma, 1)
check_deriv(gamma, 2.5)
check_deriv(gamma, 0.5)
})
test_that("lgamma derivative", {
check_deriv(lgamma, 1)
check_deriv(lgamma, 2.5)
check_deriv(lgamma, 5)
})
test_that("digamma derivative", {
check_deriv(digamma, 1)
check_deriv(digamma, 2.5)
check_deriv(digamma, 5)
})
test_that("trigamma derivative", {
check_deriv(trigamma, 1)
check_deriv(trigamma, 2.5)
check_deriv(trigamma, 5)
})
# -- Absolute value / sign ------------------------------------------------------
test_that("abs derivative", {
check_deriv(abs, 3)
check_deriv(abs, -3)
})
test_that("sign derivative is zero", {
d <- sign(dual(3, 1))
expect_equal(value(d), 1)
expect_equal(deriv(d), 0)
})
# -- Floor / ceiling / round ---------------------------------------------------
test_that("floor derivative is zero", {
d <- floor(dual(3.7, 1))
expect_equal(value(d), 3)
expect_equal(deriv(d), 0)
})
test_that("ceiling derivative is zero", {
d <- ceiling(dual(3.2, 1))
expect_equal(value(d), 4)
expect_equal(deriv(d), 0)
})
# -- atan2 ---------------------------------------------------------------------
test_that("atan2 derivative", {
# d/dx atan2(x, 2) at x=1 is 2/(1+4) = 2/5
y <- dual(1, 1)
r <- atan2(y, 2)
expect_equal(value(r), atan2(1, 2))
expect_equal(deriv(r), 2 / 5, tolerance = 1e-10)
})
test_that("atan2 both dual", {
# d/dx atan2(x, x) = d/dx atan2(1,1) = 0 (constant pi/4)
# More precisely: (x*1 - x*1)/(x^2+x^2) = 0
x <- dual(1, 1)
r <- atan2(x, x)
expect_equal(value(r), pi / 4)
expect_equal(deriv(r), 0, tolerance = 1e-10)
})
# -- Compositions ---------------------------------------------------------------
test_that("log(sin(x) + 2) derivative", {
f <- function(x) log(sin(x) + 2)
x <- 1.0
check_deriv(f, x, "log(sin(x)+2)")
})
test_that("exp(cos(x) * x) derivative", {
f <- function(x) exp(cos(x) * x)
check_deriv(f, 0.5, "exp(cos(x)*x)")
})
test_that("sqrt(1 + x^2) derivative", {
f <- function(x) sqrt(1 + x^2)
check_deriv(f, 2, "sqrt(1+x^2)")
check_deriv(f, 0, "sqrt(1+x^2)")
})
test_that("lgamma(exp(x)) derivative", {
f <- function(x) lgamma(exp(x))
check_deriv(f, 0.5, "lgamma(exp(x))")
})
# -- max / min ------------------------------------------------------------------
test_that("max selects correct branch", {
a <- dual(3, 1)
b <- dual(5, 2)
r <- max(a, b)
expect_equal(value(r), 5)
expect_equal(deriv(r), 2)
})
test_that("min selects correct branch", {
a <- dual(3, 1)
b <- dual(5, 2)
r <- min(a, b)
expect_equal(value(r), 3)
expect_equal(deriv(r), 1)
})
# -- max / min with 3+ arguments (bug fix regression tests) ------------------
test_that("max with 3 dual arguments selects the largest", {
a <- dual(3, 10)
b <- dual(7, 20)
c <- dual(5, 30)
r <- max(a, b, c)
expect_equal(value(r), 7)
expect_equal(deriv(r), 20)
})
test_that("max with 3 dual arguments when last is largest", {
a <- dual(1, 10)
b <- dual(2, 20)
c <- dual(9, 30)
r <- max(a, b, c)
expect_equal(value(r), 9)
expect_equal(deriv(r), 30)
})
test_that("min with 3 dual arguments selects the smallest", {
a <- dual(3, 10)
b <- dual(7, 20)
c <- dual(5, 30)
r <- min(a, b, c)
expect_equal(value(r), 3)
expect_equal(deriv(r), 10)
})
test_that("min with 3 dual arguments when last is smallest", {
a <- dual(5, 10)
b <- dual(3, 20)
c <- dual(1, 30)
r <- min(a, b, c)
expect_equal(value(r), 1)
expect_equal(deriv(r), 30)
})
test_that("max with mixed dual and numeric args", {
a <- dual(3, 1)
r <- max(a, 7, 5)
expect_true(is_dual(r))
expect_equal(value(r), 7)
expect_equal(deriv(r), 0)
})
test_that("min with mixed dual and numeric args", {
a <- dual(3, 1)
r <- min(a, 1, 5)
expect_true(is_dual(r))
expect_equal(value(r), 1)
expect_equal(deriv(r), 0)
})
# -- NaN / Inf propagation ----------------------------------------------------
test_that("NaN propagates through dual arithmetic", {
x <- dual(NaN, 1)
r <- x + dual(1, 1)
expect_true(is.nan(value(r)))
})
test_that("Inf propagates through exp of large dual", {
x <- dual(1000, 1)
r <- exp(x)
expect_true(is.infinite(value(r)))
})
test_that("log of negative dual produces NaN", {
x <- dual(-1, 1)
r <- log(x)
expect_true(is.nan(value(r)))
})
# -- cumprod now errors --------------------------------------------------------
test_that("cumprod on dual throws an error", {
expect_error(cumprod(dual(3, 1)), "cumprod.*not supported")
})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.