# tests/testthat/test-einsum.R

test_that("basic matrix operations work", {
  # Each einsum() result is compared against an equivalent base-R expression
  # and against the function produced by einsum_generator() for the same
  # equation. c() strips array attributes so einsum() results can be compared
  # with plain base-R vectors.
  mat1 <- matrix(rnorm(n = 4 * 8), nrow = 4, ncol = 8)
  mat2 <- matrix(rnorm(n = 8 * 3), nrow = 8, ncol = 3)

  # Matrix multiply
  expect_equal(einsum("ij,jk->ik", mat1, mat2), mat1 %*% mat2)
  expect_equal(einsum("ij,jk->ik", mat1, mat2), tcrossprod(mat1, t(mat2)))
  expect_equal(
    einsum("ij,jk->ik", mat1, mat2),
    einsum_generator("ij,jk->ik")(mat1, mat2)
  )

  # Diagonal: "ii->i" extracts it, "ii->ii" rebuilds a diagonal matrix
  mat_sq <- matrix(rnorm(n = 4 * 4), nrow = 4, ncol = 4)
  expect_equal(c(einsum("ii->i", mat_sq)), diag(mat_sq))
  expect_equal(einsum("ii->i", mat_sq), einsum_generator("ii->i")(mat_sq))

  expect_equal(einsum("ii->ii", mat_sq), diag(diag(mat_sq)))
  expect_equal(einsum("ii->ii", mat_sq), einsum_generator("ii->ii")(mat_sq))

  # 3-d diagonal: only checked for agreement with einsum_generator()
  expect_equal(einsum("ii->iii", mat_sq), einsum_generator("ii->iii")(mat_sq))

  # Trace
  expect_equal(c(einsum("ii->", mat_sq)), sum(diag(mat_sq)))
  expect_equal(einsum("ii->", mat_sq), einsum_generator("ii->")(mat_sq))

  # Row sums
  expect_equal(c(einsum("ij->i", mat1)), rowSums(mat1))
  expect_equal(einsum("ij->i", mat1), einsum_generator("ij->i")(mat1))

  # Column sums
  expect_equal(c(einsum("ij->j", mat1)), colSums(mat1))
  expect_equal(einsum("ij->j", mat1), einsum_generator("ij->j")(mat1))

  # Element-wise (Hadamard) product -- not a scalar/dot product
  mat3 <- matrix(rnorm(n = 4 * 8), nrow = 4, ncol = 8)
  expect_equal(einsum("ij,ij->ij", mat3, mat1), mat3 * mat1)
  expect_equal(
    einsum("ij,ij->ij", mat3, mat1),
    einsum_generator("ij,ij->ij")(mat3, mat1)
  )

  # Transpose
  expect_equal(einsum("ij->ji", mat1), t(mat1))
  expect_equal(einsum("ij->ji", mat1), einsum_generator("ij->ji")(mat1))

  # Matrix times vector
  vec <- rnorm(8)
  expect_equal(c(einsum("ij,j->i", mat3, vec)), c(mat3 %*% vec))
  expect_equal(
    einsum("ij,j->i", mat3, vec),
    einsum_generator("ij,j->i")(mat3, vec)
  )

  # Batched squared L2 (Frobenius) norm: one sum of squares per slice b.
  # arr1 stacks mat1 and mat3 along a third dimension (like abind(along = 3)).
  arr1 <- array(c(mat1, mat3), dim = c(dim(mat1), 2))
  expect_equal(
    c(einsum("ijb,ijb->b", arr1, arr1)),
    c(sum(mat1^2), sum(mat3^2))
  )
  expect_equal(
    einsum("ijb,ijb->b", arr1, arr1),
    einsum_generator("ijb,ijb->b")(arr1, arr1)
  )

  # Chain of contractions: for each i, the sum over j, k, l of
  # mat1[i, j] * mat3[k, j] * mat_sq[k, l] * vec_l[l]
  vec_l <- 1:4
  expect_equal(
    c(einsum("ij,kj,kl,l->i", mat1, mat3, mat_sq, vec_l)),
    c((mat1 %*% t(mat3)) %*% mat_sq %*% vec_l)
  )
  expect_equal(
    einsum("ij,kj,kl,l->i", mat1, mat3, mat_sq, vec_l),
    einsum_generator("ij,kj,kl,l->i")(mat1, mat3, mat_sq, vec_l)
  )
})


test_that("einsum can handle whitespace in equation_string", {
  # Same matrix product as "ij,jk->ik", written with stray spaces inside the
  # index list and around the arrow; the spaces must not change the result.
  lhs <- matrix(rnorm(n = 4 * 8), nrow = 4, ncol = 8)
  rhs <- matrix(rnorm(n = 8 * 3), nrow = 8, ncol = 3)
  expected <- lhs %*% rhs

  expect_equal(einsum("ij,j k -> ik", lhs, rhs), expected)
})



test_that("einsum gives appropriate error messages", {
  # Each call below is constructed so that exactly one thing is wrong with it.
  # Otherwise an expect_error() could pass because of a different, unintended
  # error (e.g. a result index that does not appear on the left-hand side).
  mat1 <- matrix(rnorm(n = 4 * 8), nrow = 4, ncol = 8)
  mat2 <- matrix(rnorm(n = 8 * 3), nrow = 8, ncol = 3)

  # Index j has inconsistent sizes: 8 in mat1, but 8 and 3 within mat2
  expect_error(einsum("ij,jj -> i", mat1, mat2))
  # Index j is 8 in mat1 but 3 in t(mat2)
  expect_error(einsum("ij,jk -> ik", mat1, t(mat2)))

  # More arrays than terms on the left-hand side
  expect_error(einsum("ij,jk -> ik", mat1, mat2, mat1))
  # Fewer arrays than terms on the left-hand side
  expect_error(einsum("ij,jk -> ik", mat1))

  # length(dim(array)) does not match the number of indices: 1:8 is a plain
  # vector (one dimension, size 8 so j stays consistent) but "jk" names two
  expect_error(einsum("ij,jk -> ik", mat1, 1:8))

  # Invalid character in equation_string
  expect_error(einsum("ij,jk -> i3k", mat1, mat2))
  expect_error(einsum("ij,j$k -> ik", mat1, mat2))

  # Missing ->
  expect_error(einsum("ij,jk", mat1, mat2))

  # Multiple ->
  expect_error(einsum("ij,jk->k->a", mat1, mat2))

  # Result index "a" does not occur on the left-hand side
  expect_error(einsum("ij,jk -> a", mat1, mat2))
})

test_that("einsum_generator gives appropriate error messages", {
  # Malformed equation strings must already be rejected when the function is
  # generated, i.e. before any arrays are supplied.
  bad_equations <- c(
    "ij,jk",        # missing "->"
    "ij,jk->k->a",  # more than one "->"
    "ij,jk -> a"    # result index "a" does not occur on the left-hand side
  )
  for (eq in bad_equations) {
    expect_error(einsum_generator(eq))
  }
})



# Manual benchmark, kept for reference (not run by testthat; requires the
# bench package):
#
# quad_mm <- einsum_generator("ij,kj,kl,l->i")
#
# i <- 3
# j <- 500
# k <- 70
# l <- 200
#
# mat1 <- array(rnorm(n = i * j), dim = c(i,j))
# mat2 <- array(rnorm(n = k * j), dim = c(k,j))
# mat3 <- array(rnorm(n = k * l), dim = c(k,l))
# mat4 <- array(rnorm(n = l), dim = c(l))
#
# bench::mark(base_r = c((mat1 %*% t(mat2)) %*% mat3 %*% mat4),
#             einsum = c(einsum("ij,kj,kl,l->i", mat1, mat2, mat3, mat4)),
#             einsum_compiled = c(quad_mm(mat1, mat2, mat3, mat4)))
# const-ae/einsum documentation built on Sept. 3, 2023, 3:50 a.m.