tests/testthat/test-cotClass.R

testthat::test_that("COT objects form", {
  testthat::skip_on_cran()
  causalOT:::torch_check()
  set.seed(234808)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n*d), n, d)
  y <- matrix(stats::rnorm(m*d), m, d)
  a <- rep(1/n, n)
  b <- rep(1/m, m)
  delta <- 0.5
  
  #cot defaults
  testthat::expect_silent(causalOT:::COT$new(source = x, target = y))
  
  #setup
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y))
  testthat::expect_silent(cot$weights <- c(1, rep(0,n-1)))
  testthat::expect_true(length(cot$.__enclos_env__$private$optimizer$penalty$lambda) == 8)
  testthat::expect_true(is.na(cot$.__enclos_env__$private$optimizer$penalty$delta))
  testthat::expect_equal(as_numeric(cot$weights), c(1, rep(0,n-1)))
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(balance.formula = "~X1+X2")))
  testthat::expect_true(inherits(cot$.__enclos_env__$private$optimizer$.__enclos_env__$private$nn_holder, "cotDualBfOpt"))
  
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(lambda = 0,
                                                                   debias = TRUE)))
  
  # entropy not debiased
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(
                                                                   debias = FALSE)))
  testthat::expect_true(all(class(cot$.__enclos_env__$private$optimizer) !=
                                 "cotDualTrain"))
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(
                                                                   debias = TRUE,
                                                                   balance.formula = "~X1*X3")))
  testthat::expect_true("bf" %in% names(cot$.__enclos_env__$private$optimizer$.__enclos_env__$private))
  testthat::expect_true(inherits(cot$.__enclos_env__$private$optimizer,
                              "cotDualTrain"))
  
  # debias without torch optim should throw error
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(
                                                                   debias = TRUE,
                                                                   torch.optimizer = NULL)
                                                    ))
  
  
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(
                                                      debias = TRUE,
                                                      opt.direction = "primal")
  ))
  testthat::expect_true(inherits(cot$.__enclos_env__$private$optimizer, "OTProblem"))
  
})

testthat::test_that("ent works", {
  testthat::skip_on_cran()
  testthat::skip_on_ci()
  causalOT:::torch_check()
  
  set.seed(234808)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n*d), n, d)
  y <- matrix(stats::rnorm(m*d), m, d)
  a <- rep(1/n, n)
  b <- rep(1/m, m)
  delta <- 0.5
  lambda <- 10
  wlist <- list(a,a)
  
  #debias = FALSE
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(debias = FALSE, 
                                                                   niter = 4,
                                                                   opt.direction = "primal")))
  ot <- cot$.__enclos_env__$private$optimizer$.__enclos_env__$private$ot_objects
  testthat::expect_true(inherits(ot[[ls(ot)]]$C_xy, "costTensor"))
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$torch_optim, "optim_lbfgs")
  )
  testthat::expect_warning(cot$solve())
  cot <- causalOT:::COT$new(source = x, target = y,
                            options = list(debias = FALSE, 
                                           niter = 1,
                                           opt.direction = "primal",
                                           line_search_fn = "strong_wolfe"))
  testthat::expect_equivalent(cot$.__enclos_env__$private$torch_optim_args, list(line_search_fn = "strong_wolfe"))
  testthat::expect_silent(cot$solve())
  
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(niter = 1,
                                                                   balance.formula = "~X1",
                                                                   debias = FALSE)))
  
  osqpout <- testthat::capture_output(testthat::expect_warning(cot$solve()))
  
  testthat::expect_true(as_numeric((cot$.__enclos_env__$private$source$balance_functions$transpose(2,1)$matmul(cot$.__enclos_env__$private$source$weights) - cot$.__enclos_env__$private$source$balance_target)$abs()$max()$to(device = "cpu")$item()) < 1.5)
  
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(debias = FALSE,
                                                                   balance.formula = "~X1+X2", niter = 1, verbose = FALSE, delta = 0.1)))
  osqpout <- testthat::capture_output(testthat::expect_warning(cot$solve() ))
  testthat::expect_true(as_numeric((cot$.__enclos_env__$private$source$balance_functions$transpose(2,1)$matmul(cot$.__enclos_env__$private$source$weights) - cot$.__enclos_env__$private$source$balance_target)$abs()$max()$item()) < 0.9)
  
})

testthat::test_that("ent debiased works, online", {
  testthat::skip_on_cran()
  testthat::skip_on_ci()
  causalOT:::torch_check()
  causalOT:::rkeops_check()
  
  testthat::skip_if_not_installed("rkeops")
  
  set.seed(234808)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n*d), n, d)
  y <- matrix(stats::rnorm(m*d), m, d)
  a <- rep(1/n, n)
  b <- rep(1/m, m)
  delta <- 0.5
  lambda <- 10
  wlist <- list(a,a)
  
  #debias = FALSE
  mess <- testthat::capture_output(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(debias = FALSE, 
                                                                   niter = 4,
                                                                   opt.direction = "primal",
                                                                   cost.online = "online")))
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$torch_optim, "optim_lbfgs")
  )
  testthat::expect_warning(cot$solve())
  cot <- causalOT:::COT$new(source = x, target = y,
                            options = list(debias = FALSE, 
                                           niter = 1,
                                           opt.direction = "primal",
                                           line_search_fn = "strong_wolfe",
                                           cost.online = "online"))
  testthat::expect_equivalent(cot$.__enclos_env__$private$torch_optim_args, list(line_search_fn = "strong_wolfe"))
  testthat::expect_silent(cot$solve())
  
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(niter = 1,
                                                                   balance.formula = "~X1",
                                                                   debias = FALSE,
                                                                   cost.online = "online")))
  
  osqpout <- testthat::capture_output(testthat::expect_warning(cot$solve()))
  
  testthat::expect_true(as_numeric(max(abs(cot$.__enclos_env__$private$source$balance_functions$transpose(2,1)$matmul(cot$.__enclos_env__$private$source$weights) - cot$.__enclos_env__$private$source$balance_target))) < 1.5)
  
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(debias = FALSE,
                                                                   balance.formula = "~X1+X2", niter = 1, verbose = FALSE, delta = 0.1,
                                                                   cost.online = "online")))
  osqpout <- testthat::capture_output(testthat::expect_warning(cot$solve() ))
  testthat::expect_true(as_numeric(max(abs(cot$.__enclos_env__$private$source$balance_functions$transpose(2,1)$matmul(cot$.__enclos_env__$private$source$weights) - cot$.__enclos_env__$private$source$balance_target))) < 0.9)
  
  
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(debias = TRUE,
                                                                   balance.formula = "~X1+X2", niter = 1, verbose = FALSE, delta = 0.1,
                                                                   cost.online = "online")))
  
  osqpout <- testthat::capture_output(
    cot$solve() )

  testthat::expect_true(as_numeric(max(abs(cot$.__enclos_env__$private$source$balance_functions$transpose(2,1)$matmul(cot$.__enclos_env__$private$source$weights) - cot$.__enclos_env__$private$source$balance_target))) < 0.9)
  
})

testthat::test_that("grid_search function works",{
  testthat::skip_on_cran()
  testthat::skip_on_ci()
  causalOT:::torch_check()
  
  set.seed(234808)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n*d), n, d)
  y <- matrix(stats::rnorm(m*d), m, d)
  a <- rep(1/n, n)
  b <- rep(1/m, m)
  delta <- 0.5
  
  #setup
  torch::torch_manual_seed(123123)
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(verbose = FALSE, 
                                                                   opt.direction = "dual",
                                                                   # torch.scheduler = NULL,
                                                                   niter = 1L,
                                                                   nboot = 10L
                                                                   )))
  cot$solve()
  output <- cot$grid_search()
  testthat::expect_true(output$penalty[1] < 2704243.27850342 + 1 ) 
  testthat::expect_warning(cot$grid_search())
})

testthat::test_that("NNM works",{
  testthat::skip_on_cran()
  testthat::skip_on_ci()
  causalOT:::torch_check()
  
  
  set.seed(234808)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n*d), n, d)
  y <- matrix(stats::rnorm(m*d), m, d)
  a <- rep(1/n, n)
  b <- rep(1/m, m)
  delta <- 0.5
  lambda <- 10
  wlist <- list(a,a)
  
  #### tensor ####
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(lambda = 0,
                                                                   debias = FALSE, 
                                                                   niter = 4,
                                                                   opt.direction = "primal")))
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$optimizer, "NNM")
  )
  
  ot <- cot$.__enclos_env__$private$optimizer$.__enclos_env__$private$ot_objects
  testthat::expect_true(inherits(ot[[ls(ot)]]$C_xy, "costTensor"))
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$torch_optim, "optim_lbfgs")
  )
  testthat::expect_true(isFALSE(ot[[ls(ot)]]$debias))
  testthat::expect_invisible(cot$solve())
  res <- cot$grid_search()
  testthat::expect_true(all(names(res) %in% c("weight", "penalty", "metric","penalty.grid")))
  
  cot <- causalOT:::COT$new(source = x, target = y,
                            options = list(debias = TRUE, lambda = 0))
  ot <- cot$.__enclos_env__$private$optimizer$.__enclos_env__$private$ot_objects
  testthat::expect_true(isFALSE(ot[[ls(ot)]]$debias))
  
  cot <- causalOT:::COT$new(source = x, target = y,
                            options = list(balance.formula = "~.", lambda = 0))
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$optimizer, "cotDualTrain")
  )
  
  
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(niter = 1,
                                                                   balance.formula = "~X1",
                                                                   debias = FALSE)))
  testthat::expect_equal(class(cot$.__enclos_env__$private$optimizer), c("OTProblem", "R6"))
  osqpout <- testthat::capture_output(testthat::expect_warning(cot$solve()))
  
  #### keops ####
  causalOT:::rkeops_check()
  
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(lambda = 0,
                                                                   debias = FALSE, 
                                                                   niter = 4,
                                                                   cost.online = "online",
                                                                   opt.direction = "primal")))
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$optimizer, "NNM")
  )
  
  ot <- cot$.__enclos_env__$private$optimizer$.__enclos_env__$private$ot_objects
  testthat::expect_true(inherits(ot[[ls(ot)]]$C_xy, "costOnline"))
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$torch_optim, "optim_lbfgs")
  )
  testthat::expect_true(isFALSE(ot[[ls(ot)]]$debias))
  testthat::expect_invisible(cot$solve())
  res <- cot$grid_search()
  testthat::expect_true(all(names(res) %in% c("weight", "penalty", "metric","penalty.grid")))
  
  cot <- causalOT:::COT$new(source = x, target = y,
                            options = list(debias = TRUE, 
                                           cost.online = "online",lambda = 0))
  ot <- cot$.__enclos_env__$private$optimizer$.__enclos_env__$private$ot_objects
  testthat::expect_true(isFALSE(ot[[ls(ot)]]$debias))
  
  cot <- causalOT:::COT$new(source = x, target = y,
                            options = list(balance.formula = "~.", 
                                           cost.online = "online",lambda = 0))
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$optimizer, "cotDualTrain")
  )
  
  
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(niter = 1,
                                                                   cost.online = "online",
                                                                   balance.formula = "~X1",
                                                                   debias = FALSE)))
  testthat::expect_equal(class(cot$.__enclos_env__$private$optimizer), c("OTProblem", "R6"))
  osqpout <- testthat::capture_output(testthat::expect_warning(cot$solve()))
})

testthat::test_that("weights function works",{
  testthat::skip_on_cran()
  testthat::skip_on_ci()
  causalOT:::torch_check()
  
  set.seed(234808)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n*d), n, d)
  y <- matrix(stats::rnorm(m*d), m, d)
  a <- rep(1/n, n)
  b <- rep(1/m, m)
  delta <- 0.5
  
  #setup
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(verbose = FALSE, 
                                                                   opt.direction = "dual",
                                                                   # torch.scheduler = NULL,
                                                                   niter = 1L,
                                                                   nboot = 10L
                                                    )))
  cot$solve()
  testthat::expect_silent(w <- cot$weights)
  testthat::expect_true(all(w[1] == w))
  
  # run rkeops version
  causalOT:::rkeops_check() #skips if rkeops fails or is not installed
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
                                                    options = list(verbose = FALSE, 
                                                                   opt.direction = "dual",
                                                                   cost.online = "online",
                                                                   # torch.scheduler = NULL,
                                                                   niter = 1L,
                                                                   nboot = 10L
                                                    )))
  cot$solve()
  testthat::expect_silent(w <- cot$weights)
  testthat::expect_true(all(w[1] == w))
})

testthat::test_that("cotOptions error checking works", {
  testthat::skip_on_cran()
  testthat::skip_on_ci()
  causalOT:::torch_check()
  set.seed(234808)
  n <- 15
  m <- 13
  d <- 3
  penalty <- 10
  x <- matrix(stats::rnorm(n*d), n, d)
  y <- matrix(stats::rnorm(m*d), m, d)
  a <- rep(1/n, n)
  b <- rep(1/m, m)
  delta <- 0.5
  
  testthat::expect_silent(opt <- cotOptions())
  
  testthat::expect_warning(opt <- cotOptions(opt.direction = "dual", torch.optimizer = torch::optim_lbfgs))
  testthat::expect_true(inherits(opt$torch.optimizer, "optim_rmsprop"))
  testthat::expect_error(opt <- cotOptions(torch.optimizer = sum))
  testthat::expect_error(opt <- cotOptions(torch.scheduler = sum))
  opt <- cotOptions(debias = FALSE, opt.direction = "dual")
  testthat::expect_true(opt$opt.direction == "primal")
})

Try the causalOT package in your browser

Any scripts or data that you put into this service are public.

causalOT documentation built on May 29, 2024, 6:16 a.m.