# tests/testthat/test-balanceFunctions.R

# Smoke test: each balance-function class generator can be instantiated on a
# small random source/target sample with uniform weights.
testthat::test_that("balanceFunction class forms", {
  set.seed(20348)
  n <- 15
  m <- 13
  d <- 3
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)
  delta <- 0.5

  # The base class is the only constructor in this test that is given `delta`.
  testthat::expect_silent(
    bf <- causalOT:::balanceFunction$new(
      source = x, target = y,
      a = a, b = b,
      delta = delta
    )
  )

  # Gate EntropyBW on lbfgsb3c, matching the other EntropyBW tests in this
  # file, so a missing optional package is reported as a skip.
  testthat::skip_if_not_installed("lbfgsb3c")
  testthat::expect_silent(
    ebw <- causalOT:::EntropyBW$new(source = x, target = y, a = a, b = b)
  )

  # Unlike the two constructors above, SBW is expected to emit console output.
  testthat::skip_if_not_installed("osqp")
  testthat::expect_output(
    sbw <- causalOT:::SBW$new(source = x, target = y, a = a, b = b)
  )
})

# Solves the SBW and EntropyBW problems at the first and last points of a
# 7-point grid returned by `gridInit()` and checks the resulting weights.
testthat::test_that("balanceFunction optimizes", {
  set.seed(20348)
  n <- 15
  m <- 13
  d <- 3
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)

  # L2 (SBW)
  testthat::skip_if_not_installed("osqp")
  testthat::expect_output(
    sbw <- causalOT:::SBW$new(source = x, target = y, a = a, b = b)
  )

  sbw_grid <- sbw$gridInit(NULL, 7L)

  # First grid point: the solution should match the uniform weights `a`.
  # `tolerance` is spelled out in full: it follows `...` in expect_equal(), so
  # the abbreviated `tol` is never partially matched to it.
  testthat::expect_output(w0 <- sbw$solve(sbw_grid[1], NULL))
  testthat::expect_equal(w0, a, tolerance = 1e-2)

  # Last grid point: the solution should differ from the uniform weights.
  testthat::expect_output(w1 <- sbw$solve(sbw_grid[7], NULL))
  testthat::expect_false(isTRUE(all.equal(w1, a)))

  # entropy
  testthat::skip_if_not_installed("lbfgsb3c")
  testthat::expect_silent(
    ebw <- causalOT:::EntropyBW$new(source = x, target = y, a = a, b = b)
  )
  ebw_grid <- ebw$gridInit(NULL, 7L)

  # First grid point: the solve is expected to warn, and the weights should
  # equal both the uniform weights and the SBW solution at its first grid point.
  testthat::expect_warning(w0_e <- ebw$solve(ebw_grid[1], NULL))
  testthat::expect_equal(c(w0_e), a)
  testthat::expect_equal(c(w0_e), w0)

  # Last grid point: every entry of t(A) %*% w must be below 1e-3.
  # NOTE(review): this bound is one-sided, so large negative values pass;
  # confirm whether abs() was intended.
  testthat::expect_silent(w1_e <- ebw$solve(ebw_grid[7], NULL))
  testthat::expect_true(all(c(crossprod(ebw$A, w1_e)) < 1e-3))
})

# Constructors must reject an all-zero sample passed as `source`, but accept
# the same sample when it is passed as `target`.
testthat::test_that("functions catch errors", {
  set.seed(20348)
  n <- 15
  m <- 13
  d <- 3
  x <- matrix(0, n, d) # all-zero sample used to trigger the error
  y <- matrix(stats::rnorm(m * d), m, d)
  a <- rep(1 / n, n)
  b <- rep(1 / m, m)

  testthat::skip_if_not_installed("lbfgsb3c")
  # All-zero source: construction must fail.
  testthat::expect_error(
    causalOT:::EntropyBW$new(source = x, target = y, a = a, b = b)
  )
  # Roles swapped (all-zero target): construction must succeed silently.
  testthat::expect_silent(
    ebw <- causalOT:::EntropyBW$new(source = y, target = x, a = b, b = a)
  )

  testthat::skip_if_not_installed("osqp")
  # Same pair of checks for SBW, whose successful construction prints output.
  testthat::expect_error(
    causalOT:::SBW$new(source = x, target = y, a = a, b = b)
  )
  testthat::expect_output(
    sbw <- causalOT:::SBW$new(source = y, target = x, a = b, b = a)
  )
})
# ericdunipace/causalOT documentation built on Aug. 8, 2024, 6:14 p.m.