# tests/testthat/test-dual_opts.R

# Compares the compiled tensorized dual forward functions (calc_w1, calc_w2,
# cot_dual, cot_bf_dual) with the same quantities computed by hand from torch
# tensor operations, then repeats the comparison for the keops functions.
testthat::test_that("test forward functions", {
  causalOT:::torch_check()
  # Fix the RNG seed, as the other tests in this file do, so that any failure
  # can be reproduced.
  set.seed(1231)

  n <- 256

  z <- matrix(rnorm(n / 2 * 2), n / 2, 2) +
    matrix(c(0, .5), n / 2, 2, byrow = TRUE)
  x <- matrix(rnorm(n * 2), n, 2)

  m1 <- Measure(x, target.values = colMeans(z), adapt = "weights")
  mt <- Measure(z)

  gamma <- torch::torch_tensor(stats::rnorm(n),
                               device = m1$device,
                               dtype = m1$dtype)

  ot_tens <- causalOT:::OT$new(x = x, y = z, debias = TRUE,
                               tensorized = "tensorized", penalty = 10)

  C_xy <- ot_tens$C_xy$data
  C_xx <- ot_tens$C_xx$data

  a_log <- causalOT:::log_weights(ot_tens$a)
  b_log <- causalOT:::log_weights(ot_tens$b)
  lambda <- ot_tens$penalty
  delta <- 0.01

  dual_forwards <- torch::jit_compile(causalOT:::dual_forward_code_tensorized)

  a1_script <- dual_forwards$calc_w1(
    gamma$detach(), C_xy, a_log, b_log,
    torch::jit_scalar(lambda), torch::jit_scalar(as.integer(n))
  )

  a2_script <- dual_forwards$calc_w2(
    gamma$detach(), C_xx, a_log,
    torch::jit_scalar(lambda), torch::jit_scalar(as.integer(n))
  )

  # Hand-computed version of calc_w1, normalised to sum to one.
  g <- b_log -
    ((gamma$detach() + a_log)$detach()$view(c(n, 1)) - C_xy / lambda)$logsumexp(1)
  K <- (gamma$detach() + a_log)$view(c(n, 1)) + g - C_xy / lambda
  a1 <- K$logsumexp(2)$exp()$detach()
  a1 <- as.numeric((a1 / a1$sum())$to(device = "cpu"))

  testthat::expect_equal(a1, as.numeric(a1_script$to(device = "cpu")),
                         label = "calc_w1")

  # Hand-computed version of calc_w2, normalised over all n * n entries.
  f_star <- gamma$detach() + a_log
  K2 <- f_star$view(c(n, 1)) + f_star - C_xx / lambda
  norm <- K2$view(c(n * n, 1))$logsumexp(1)
  a2 <- as.numeric((K2 - norm)$logsumexp(1)$exp()$detach()$to(device = "cpu"))

  testthat::expect_equal(a2, as.numeric(a2_script$to(device = "cpu")),
                         label = "calc_w2")

  testthat::expect_equal(
    gamma$dot(a1_script - a2_script)$item() * -1,
    dual_forwards$cot_dual(
      gamma$detach(), C_xy, C_xx, a_log, b_log,
      torch::jit_scalar(lambda), torch::jit_scalar(as.integer(n))
    )$loss$item(),
    label = "loss calc"
  )

  beta1 <- torch::torch_tensor(stats::rnorm(2),
                               device = gamma$device,
                               dtype = gamma$dtype)

  # Same weights again, this time without the explicit renormalisation of a1
  # and with an extra log_softmax on a2.
  f_prime <- gamma$detach() + a_log
  g <- b_log - (f_prime$view(c(n, 1)) - C_xy / lambda)$logsumexp(1)
  K <- f_prime$view(c(n, 1)) + g - C_xy / lambda
  a1 <- K$logsumexp(2)$exp()$detach()
  a1 <- as.numeric(a1$to(device = "cpu"))

  f_star <- gamma$detach() + a_log
  K2 <- f_star$view(c(n, 1)) + f_star - C_xx / lambda
  norm <- K2$view(c(n * n, 1))$logsumexp(1)
  a2 <- as.numeric(
    (K2 - norm)$logsumexp(1)$log_softmax(1)$exp()$detach()$to(device = "cpu")
  )

  testthat::expect_equal(a1, as.numeric(a1_script$to(device = "cpu")),
                         label = "calc_w1")
  testthat::expect_equal(a2, as.numeric(a2_script$to(device = "cpu")),
                         label = "calc_w2")

  res <- dual_forwards$cot_dual(
    gamma$detach(), C_xy, C_xx, a_log, b_log,
    torch::jit_scalar(lambda), torch::jit_scalar(as.integer(n))
  )

  loss_gamma <- gamma$dot(a1_script - a2_script)$item()

  # `tolerance` is spelled out in full: it sits after `...` in expect_equal(),
  # so an abbreviated `tol` is never matched by testthat itself.
  testthat::expect_equal(loss_gamma * -1,
                         res$loss$item(), label = "loss calc",
                         tolerance = 1e-5)

  testthat::expect_equal((a1_script - a2_script)$norm()$item(),
                         res$avg_diff$item(),
                         tolerance = 1e-5)

  testthat::expect_equal(res$bf_diff$item(), 0.0,
                         tolerance = 1e-5)

  # Balance-function part of the objective.
  diff1 <- m1$balance_functions$transpose(2, 1)$matmul(a1_script) -
    m1$balance_target

  loss_beta <- diff1$dot(beta1) - delta * beta1$abs()$sum()

  loss <- loss_gamma + loss_beta

  loss$multiply_(-1.0) # to make min

  res2 <- dual_forwards$cot_bf_dual(
    gamma, C_xy, C_xx, a_log, b_log,
    torch::jit_scalar(lambda),
    torch::jit_scalar(as.integer(n)),
    beta1, m1$balance_functions,
    m1$balance_target,
    torch::jit_scalar(delta)
  )

  testthat::expect_equal(loss$item(), res2$loss$item(), tolerance = 1e-4)
  testthat::expect_equal(res2$avg_diff$item(), res$avg_diff$item())
  testthat::expect_equal(res2$bf_diff$item(), diff1$abs()$max()$item())

  # test keops versions
  causalOT:::rkeops_check()

  ot_keops <- causalOT:::OT$new(x = x, y = z, debias = TRUE,
                                tensorized = "online", penalty = 10)

  C_xy <- ot_keops$C_xy
  C_xx <- ot_keops$C_xx

  keops_fun <- causalOT:::dual_forwards_keops

  a1_script <- keops_fun$calc_w1(
    gamma$detach(), C_xy, a_log, b_log,
    torch::jit_scalar(lambda), torch::jit_scalar(as.integer(n))
  )

  a2_script <- keops_fun$calc_w2(
    gamma$detach(), C_xx, a_log,
    torch::jit_scalar(lambda), torch::jit_scalar(as.integer(n))
  )

  res_keops <- keops_fun$cot_dual(
    gamma, C_xy, C_xx, a_log, b_log,
    torch::jit_scalar(lambda), torch::jit_scalar(as.integer(n))
  )
  res_keops_2 <- keops_fun$cot_bf_dual(
    gamma, C_xy, C_xx, a_log, b_log,
    torch::jit_scalar(lambda), torch::jit_scalar(as.integer(n)),
    beta1, m1$balance_functions,
    m1$balance_target,
    torch::jit_scalar(delta)
  )

  testthat::expect_equal(as.numeric(a1_script$to(device = "cpu")), a1,
                         tolerance = 1e-3)
  testthat::expect_equal(as.numeric(a2_script$to(device = "cpu")), a2,
                         tolerance = 1e-3)

  testthat::expect_equal(loss_gamma * -1, res_keops$loss$item(),
                         tolerance = 1e-5)
  testthat::expect_equal(loss$item(), res_keops_2$loss$item(),
                         tolerance = 1e-5)
  testthat::expect_equal(diff1$abs()$max()$item(),
                         res_keops_2$bf_diff$item(),
                         tolerance = 1e-5)
  testthat::expect_equal(as.numeric(res2$beta_check$to(device = "cpu")),
                         as.numeric(res_keops_2$beta_check$to(device = "cpu")),
                         tolerance = 1e-5)
})

# Checks that the dual optimizer modules (cotDualOpt, cotDualBfOpt and their
# keops counterparts) return the same results from forward() as the
# underlying forward functions called directly, and exercises clone_param()
# and converged() on each module.
testthat::test_that("dual nn modules work as expected", {
  causalOT:::torch_check()
  set.seed(1231)

  n <- 256

  z <- matrix(rnorm(n / 2 * 2), n / 2, 2) +
    matrix(c(0, .5), n / 2, 2, byrow = TRUE)
  x <- matrix(rnorm(n * 2), n, 2)

  m1 <- Measure(x, target.values = colMeans(z), adapt = "weights")
  mt <- Measure(z)

  opt <- causalOT:::cotDualOpt$new(n, 2)

  gamma <- torch::torch_tensor(stats::rnorm(n),
                               device = m1$device,
                               dtype = m1$dtype)

  torch::with_no_grad(opt$gamma$copy_(gamma))

  ot_tens <- causalOT:::OT$new(x = x, y = z, debias = TRUE,
                               tensorized = "tensorized", penalty = 10)

  C_xy <- ot_tens$C_xy
  C_xx <- ot_tens$C_xx

  a_log <- causalOT:::log_weights(ot_tens$a)
  b_log <- causalOT:::log_weights(ot_tens$b)
  lambda <- ot_tens$penalty
  delta <- 0.01

  dual_forwards <- torch::jit_compile(causalOT:::dual_forward_code_tensorized)

  res <- dual_forwards$cot_dual(
    gamma$detach(), C_xy$data, C_xx$data, a_log, b_log,
    torch::jit_scalar(lambda), torch::jit_scalar(as.integer(n))
  )

  res_mod <- opt$forward(C_xy, C_xx, a_log, b_log, lambda)

  # Shared expectations: `res` is the direct function result, `res_mod` the
  # module result, `opt` the module under test and `gamma` the value its
  # gamma parameter is expected to hold.
  tests <- function(res, res_mod, opt, gamma) {
    testthat::expect_equal(res, res_mod)
    # `tolerance` is spelled out in full: it sits after `...` in
    # expect_equal(), so an abbreviated `tol` is never matched by testthat.
    testthat::expect_equal(res$loss$item(), res_mod$loss$item(),
                           tolerance = 1e-5)
    testthat::expect_equal(res$avg_diff$item(), res_mod$avg_diff$item(),
                           tolerance = 1e-5)
    testthat::expect_equal(res$bf_diff$item(), res_mod$bf_diff$item(),
                           tolerance = 1e-5)

    param <- opt$clone_param()
    testthat::expect_equal(as.numeric(param$gamma$to(device = "cpu")),
                           as.numeric(gamma$to(device = "cpu")),
                           tolerance = 1e-5)
    testthat::expect_true(param$gamma$requires_grad == FALSE)
    testthat::expect_true(opt$gamma$requires_grad == TRUE)

    # test convergence function
    # NOTE: `tol` below is an argument passed to opt$converged(), not to
    # testthat, so it is intentionally left as written.
    param$gamma <- param$gamma * 0.0
    testthat::expect_true(isFALSE(opt$converged(res_mod,
                                                1e-5, 1e-6, param,
                                                tol = 1e-8, lambda, delta)))

    testthat::expect_true(opt$converged(res_mod,
                                        1e-5, 1e-6, param,
                                        tol = 300, lambda, delta))
  }

  tests(res, res_mod, opt, gamma)

  # bf
  optbf <- causalOT:::cotDualBfOpt$new(n, 2)

  torch::with_no_grad({
    optbf$beta$copy_(c(1, 2))
    optbf$gamma$copy_(gamma)
  })
  beta1 <- optbf$beta$detach()$clone()
  res <- dual_forwards$cot_bf_dual(
    gamma$detach() - m1$balance_functions$matmul(beta1),
    C_xy$data, C_xx$data, a_log, b_log,
    torch::jit_scalar(lambda), torch::jit_scalar(as.integer(n)),
    beta1, m1$balance_functions,
    m1$balance_target,
    torch::jit_scalar(delta)
  )

  res_mod <- optbf$forward(C_xy, C_xx, a_log, b_log, lambda,
                           m1$balance_functions,
                           m1$balance_target,
                           torch::jit_scalar(delta))

  tests(res, res_mod, optbf, gamma - m1$balance_functions$matmul(beta1))

  #### check keops opt ####

  causalOT:::rkeops_check()

  ot_keops <- causalOT:::OT$new(x = x, y = z, debias = TRUE,
                                tensorized = "online", penalty = 10)

  C_xy <- ot_keops$C_xy
  C_xx <- ot_keops$C_xx

  opt <- causalOT:::cotDualOpt_keops$new(n, 2)
  torch::with_no_grad(opt$gamma$copy_(gamma))

  keops_fun <- causalOT:::dual_forwards_keops

  res <- keops_fun$cot_dual(
    gamma$detach(), C_xy, C_xx, a_log, b_log,
    torch::jit_scalar(lambda), torch::jit_scalar(as.integer(n))
  )

  res_mod <- opt$forward(C_xy, C_xx, a_log, b_log, lambda)
  tests(res, res_mod, opt, gamma)

  optbf <- causalOT:::cotDualBfOpt_keops$new(n, 2)
  torch::with_no_grad({
    optbf$gamma$copy_(gamma)
    optbf$beta$copy_(beta1)
  })

  res <- keops_fun$cot_bf_dual(
    gamma$detach() - m1$balance_functions$matmul(beta1),
    C_xy, C_xx, a_log, b_log,
    torch::jit_scalar(lambda), torch::jit_scalar(as.integer(n)),
    beta1, m1$balance_functions,
    m1$balance_target,
    torch::jit_scalar(delta)
  )

  res_mod <- optbf$forward(C_xy, C_xx, a_log, b_log, lambda,
                           m1$balance_functions,
                           m1$balance_target,
                           torch::jit_scalar(delta))
  tests(res, res_mod, optbf, gamma - m1$balance_functions$matmul(beta1))
  testthat::expect_true(
    all(as.logical((optbf$beta == c(1, 2))$to(device = "cpu")))
  )
})

# Exercises cotDualTrain end to end: argument setup and the module class it
# selects, the weights active binding, the penalty setters, optimizer setup
# and reset, a short optimization loop, parameters_get_set(), and finally
# solve() followed by choose_hyperparameters().
testthat::test_that("training function works for dual optimizer", {
  causalOT:::torch_check()
  set.seed(1231)

  n <- 256

  z <- matrix(rnorm(n / 2 * 2), n / 2, 2) +
    matrix(c(0, .5), n / 2, 2, byrow = TRUE)
  x <- matrix(rnorm(n * 2), n, 2)

  m1 <- Measure(x, target.values = colMeans(z), adapt = "weights")
  mt <- Measure(z)

  cot <- causalOT:::cotDualTrain$new(m1, mt)
  otp <- OTProblem(m1, mt)

  cot_names <- names(formals(cot$setup_arguments))
  otp_names <- names(formals(otp$setup_arguments))

  testthat::expect_equal(cot_names, otp_names)

  # test that setup arg makes correct nn_holder
  testthat::expect_silent(cot$setup_arguments())
  testthat::expect_silent(otp$setup_arguments())

  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$nn_holder, "cotDualBfOpt")
  )

  testthat::expect_true(
    length(cot$.__enclos_env__$private$nn_holder$beta) == ncol(z)
  )
  testthat::expect_true(
    length(cot$.__enclos_env__$private$nn_holder$beta) == ncol(x)
  )

  # no bf, tensor
  testthat::expect_true(inherits(
    causalOT:::cotDualTrain$new(
      Measure(x, adapt = "weights"), Measure(z)
    )$setup_arguments()$.__enclos_env__$private$nn_holder,
    "cotDualOpt"
  ))

  # no bf, keops
  causalOT:::rkeops_check()
  testthat::expect_true(inherits(
    causalOT:::cotDualTrain$new(
      Measure(x, adapt = "weights"), Measure(z)
    )$setup_arguments(cost.online = "online")$.__enclos_env__$private$nn_holder,
    "cotDualOpt_keops"
  ))

  # bf (target.values supplied), keops
  testthat::expect_true(inherits(
    causalOT:::cotDualTrain$new(
      Measure(x, adapt = "weights", target.values = colMeans(z)), Measure(z)
    )$setup_arguments(cost.online = "online")$.__enclos_env__$private$nn_holder,
    "cotDualBfOpt_keops"
  ))

  #### test weights function ###
  nnh <- cot$.__enclos_env__$private$nn_holder
  priv <- cot$.__enclos_env__$private
  a1 <- nnh$calc_w1(nnh$gamma, priv$C_xy$data, priv$a_log,
                    priv$b_log, torch::jit_scalar(priv$lambda),
                    torch::jit_scalar(as.integer(n)))
  a2 <- nnh$calc_w2(nnh$gamma, priv$C_xx$data, priv$a_log,
                    torch::jit_scalar(priv$lambda),
                    torch::jit_scalar(as.integer(n)))
  w <- cot$weights
  # testthat::expect_true(length(w) == 3)
  # testthat::expect_equal(as.numeric(w[[2]]), as.numeric(a1))
  # testthat::expect_equal(as.numeric(w[[3]]), as.numeric(a2))
  # testthat::expect_equal(as.numeric(w[[1]]), as.numeric(a2 + a1)*0.5)
  testthat::expect_equal(as.numeric(w$to(device = "cpu")),
                         as.numeric(((a2 + a1) * 0.5)$to(device = "cpu")))

  testthat::expect_equal(names(cot$.__enclos_env__$private$parameters),
                         c("gamma", "beta"))
  testthat::expect_equal(
    as.numeric(cot$.__enclos_env__$private$nn_holder$gamma$to(device = "cpu")),
    as.numeric(cot$.__enclos_env__$private$parameters$gamma$to(device = "cpu"))
  )
  testthat::expect_equal(
    rlang::obj_address(cot$.__enclos_env__$private$nn_holder$gamma),
    rlang::obj_address(cot$.__enclos_env__$private$parameters$gamma)
  )
  torch::with_no_grad(
    cot$.__enclos_env__$private$nn_holder$beta$copy_(c(1, 2))
  )
  testthat::expect_equal(
    as.numeric(cot$.__enclos_env__$private$nn_holder$beta$to(device = "cpu")),
    c(1, 2)
  )
  testthat::expect_equal(
    as.numeric(
      cot$.__enclos_env__$private$nn_holder$parameters$beta$to(device = "cpu")
    ),
    c(1, 2)
  )

  # test that set_lambda works
  testthat::expect_true(length(cot$penalty$lambda) > 1)
  priv <- cot$.__enclos_env__$private
  testthat::expect_true(priv$lambda == cot$penalty$lambda[1L])
  priv$set_lambda(4)
  testthat::expect_equal(priv$lambda, torch::jit_scalar(4))
  testthat::expect_error(priv$set_lambda(-1))

  # test that set_delta works
  testthat::expect_true(length(cot$penalty$delta) > 1)
  priv <- cot$.__enclos_env__$private
  # NOTE(review): comparing delta against the string "numeric" looks
  # unintended; confirm what priv$delta should hold at this point.
  testthat::expect_true(priv$delta == "numeric")
  priv$set_delta(.4)
  testthat::expect_equal(priv$delta, torch::jit_scalar(.4))
  # NOTE(review): this re-checks set_lambda(-1) inside the delta section;
  # possibly meant to be set_delta(-1) -- confirm before changing.
  testthat::expect_error(priv$set_lambda(-1))

  # test that set_penalties works
  # `tolerance` is spelled out in full below: it sits after `...` in
  # expect_equal(), so an abbreviated `tol` is never matched by testthat.
  priv <- cot$.__enclos_env__$private
  priv$set_penalties(c(lambda = Inf, delta = .4))
  testthat::expect_equal(priv$delta, torch::jit_scalar(.4))
  testthat::expect_equal(priv$lambda, torch::jit_scalar(359871.9312),
                         tolerance = 1e-5)

  testthat::expect_warning(priv$set_penalties(c(5, 5)))
  testthat::expect_silent(priv$set_penalties(5))
  testthat::expect_error(priv$set_penalties(c(steve = 5, 5)))

  priv$set_penalties(list(lambda = 50, delta = 5))
  testthat::expect_equal(priv$delta, torch::jit_scalar(5),
                         tolerance = 1e-5)
  testthat::expect_equal(priv$lambda, torch::jit_scalar(50),
                         tolerance = 1e-5)

  # make sure optimization setup works
  priv$torch_optim_setup(torch_optim = torch::optim_rmsprop,
                         torch_scheduler = torch::lr_multiplicative,
                         torch_args = NULL)
  testthat::expect_true(
    inherits(priv$opt, "optim_rmsprop")
  )
  testthat::expect_true(
    inherits(priv$sched, "lr_multiplicative")
  )
  # testthat::expect_equal(
  #   capture.output(print(priv$sched$lr_lambdas[[1]]))[1],
  #   "function(epoch) {0.99}"
  # )

  testthat::expect_equal(
    as.numeric(cot$.__enclos_env__$private$nn_holder$gamma$to(device = "cpu")),
    as.numeric(
      cot$.__enclos_env__$private$parameters$gamma$params$to(device = "cpu")
    ),
    tolerance = 1e-5
  )

  testthat::expect_equal(1e-2, # priv$lambda/100,
                         cot$.__enclos_env__$private$parameters$gamma$lr)

  testthat::expect_equal(
    as.numeric(cot$.__enclos_env__$private$nn_holder$beta$to(device = "cpu")),
    as.numeric(
      cot$.__enclos_env__$private$parameters$beta$params$to(device = "cpu")
    )
  )

  testthat::expect_equal(0.01,
                         cot$.__enclos_env__$private$parameters$beta$lr)

  # torch_optim_reset
  priv <- cot$.__enclos_env__$private
  old_add <- rlang::obj_address(priv$opt)
  priv$torch_optim_reset(0.44)
  testthat::expect_equal(0.44, # priv$lambda/100,
                         cot$.__enclos_env__$private$parameters$gamma$lr)
  testthat::expect_equal(0.44,
                         cot$.__enclos_env__$private$parameters$beta$lr)

  testthat::expect_true(rlang::obj_address(priv$opt) != old_add)
  testthat::expect_equal(rlang::obj_address(priv$nn_holder$gamma),
                         rlang::obj_address(priv$parameters$gamma$params))

  # optimization_loop
  out <- priv$optimization_loop(2, 1e-4)
  testthat::expect_true(out$iter == 2)

  testthat::expect_equal(rlang::obj_address(priv$nn_holder$gamma),
                         rlang::obj_address(priv$parameters$gamma$params))
  testthat::expect_equal(
    as.numeric(cot$.__enclos_env__$private$nn_holder$gamma$to(device = "cpu")),
    as.numeric(
      cot$.__enclos_env__$private$parameters$gamma$params$to(device = "cpu")
    )
  )
  testthat::expect_true(all(
    as.numeric(
      cot$.__enclos_env__$private$nn_holder$gamma$to(device = "cpu")
    ) != 0
  ))

  # test parameters get set
  pars <- priv$parameters
  testthat::expect_true(pars$gamma$params$requires_grad == TRUE)
  testthat::expect_equal(
    as.numeric(pars$gamma$params$to(device = "cpu")),
    as.numeric(cot$.__enclos_env__$private$nn_holder$gamma$to(device = "cpu"))
  )

  pars <- priv$parameters_get_set()
  ws <- pars[[ls(pars)]]
  w2 <- cot$weights
  # testthat::expect_equal(as.numeric(ws[[1]]), as.numeric(w2[[1]]))
  testthat::expect_equal(as.numeric(ws$to(device = "cpu")),
                         as.numeric(w2$to(device = "cpu")))

  # Find the first measure whose weights are being adapted.
  ms <- cot$.__enclos_env__$private$measures
  m <- NULL
  for (i in ls(ms)) {
    if (ms[[i]]$adapt == "weights") {
      m <- ms[[i]]
      break
    }
  }

  testthat::expect_error(priv$parameters_get_set(ws))
  testthat::expect_error(priv$parameters_get_set(list(ws, ws)))
  testthat::expect_silent(priv$parameters_get_set(list(ws)))
  testthat::expect_equal(as.numeric(m$weights$to(device = "cpu")),
                         as.numeric(ws$to(device = "cpu")),
                         tolerance = 1e-5)

  # testthat::expect_true(rlang::obj_address(cot$.__enclos_env__$private$nn_holder$gamma) == rlang::obj_address(pars$gamma))
  #
  # pars <- priv$parameters_get_set(clone = TRUE)
  # testthat::expect_true(pars$gamma$requires_grad == FALSE)
  # testthat::expect_equal(as.numeric(pars$gamma), as.numeric(cot$.__enclos_env__$private$nn_holder$gamma))
  # testthat::expect_true(rlang::obj_address(cot$.__enclos_env__$private$nn_holder$gamma) != rlang::obj_address(pars$gamma))
  #
  # pars$gamma <- pars$gamma * 0 + 1
  # priv$parameters_get_set(pars)
  # testthat::expect_equal(as.numeric(pars$gamma), as.numeric(priv$nn_holder$gamma))
  # testthat::expect_true(rlang::obj_address(cot$.__enclos_env__$private$nn_holder$gamma) != rlang::obj_address(pars$gamma))
  # testthat::expect_true(inherits(pars, "weightEnv"))

  # hyperparam
  cot <- causalOT:::cotDualTrain$new(m1, mt)
  cot$setup_arguments()
  cot$solve(niter = 1L, torch_optim = torch::optim_rmsprop,
            torch_scheduler = torch::lr_multiplicative)

  testthat::expect_silent(
    cot$choose_hyperparameters(n_boot_lambda = 10, n_boot_delta = 10)
  )
  testthat::expect_true(is.numeric(cot$selected_delta[[1]]))
  testthat::expect_true(cot$selected_lambda < 359871.93116805560749)
})
# Source: ericdunipace/causalOT, documentation built on Aug. 8, 2024, 6:14 p.m.