# Nothing
# Construction of COT objects: defaults, weight assignment, balance formulas,
# debiasing and the primal optimization direction.
testthat::test_that("COT objects form", {
  testthat::skip_on_cran()
  causalOT:::torch_check()
  set.seed(234808)
  n <- 15
  m <- 13
  d <- 3
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)

  # Shortcut to an R6 object's private environment.
  private_of <- function(obj) obj$.__enclos_env__$private

  # cot defaults
  testthat::expect_silent(causalOT:::COT$new(source = x, target = y))

  # setup: default construction, then assign the weights directly
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y))
  testthat::expect_silent(cot$weights <- c(1, rep(0, n - 1)))
  # default penalty: 8 lambda values and no delta
  testthat::expect_true(length(private_of(cot)$optimizer$penalty$lambda) == 8)
  testthat::expect_true(is.na(private_of(cot)$optimizer$penalty$delta))
  testthat::expect_equal(as_numeric(cot$weights), c(1, rep(0, n - 1)))

  # a balance formula changes the class of the optimizer's nn_holder
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(balance.formula = "~X1+X2")))
  nn_holder <- private_of(private_of(cot)$optimizer)$nn_holder
  testthat::expect_true(inherits(nn_holder, "cotDualBfOpt"))

  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(lambda = 0, debias = TRUE)))

  # entropy not debiased: the optimizer is not a cotDualTrain
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(debias = FALSE)))
  testthat::expect_false(inherits(private_of(cot)$optimizer, "cotDualTrain"))

  # debiased entropy with a balance formula: cotDualTrain holding a "bf" slot
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(debias = TRUE, balance.formula = "~X1*X3")))
  testthat::expect_true(
    "bf" %in% names(private_of(private_of(cot)$optimizer))
  )
  testthat::expect_true(inherits(private_of(cot)$optimizer, "cotDualTrain"))

  # debias with torch.optimizer = NULL is accepted silently (no error raised)
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(debias = TRUE, torch.optimizer = NULL)))

  # the primal direction builds an OTProblem optimizer
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(debias = TRUE, opt.direction = "primal")))
  testthat::expect_true(inherits(private_of(cot)$optimizer, "OTProblem"))
})
# Entropic COT with the tensor cost: primal solves, forwarding of extra
# optimizer arguments, and balance-constrained solutions.
testthat::test_that("ent works", {
  testthat::skip_on_cran()
  testthat::skip_on_ci()
  causalOT:::torch_check()
  set.seed(234808)
  n <- 15
  m <- 13
  d <- 3
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)

  # Largest absolute difference between the weighted source balance
  # functions and the balance target, returned as an R number.
  max_balance_gap <- function(cot) {
    src <- cot$.__enclos_env__$private$source
    gap <- src$balance_functions$transpose(2, 1)$matmul(src$weights) -
      src$balance_target
    as_numeric(gap$abs()$max()$to(device = "cpu")$item())
  }

  # debias = FALSE, primal direction, tensor cost
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(debias = FALSE, niter = 4, opt.direction = "primal")))
  opt_private <- cot$.__enclos_env__$private$optimizer$.__enclos_env__$private
  ot <- opt_private$ot_objects
  testthat::expect_true(inherits(ot[[ls(ot)]]$C_xy, "costTensor"))
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$torch_optim, "optim_lbfgs")
  )
  # NOTE(review): presumably warns because niter = 4 is too few — confirm
  testthat::expect_warning(cot$solve())

  # line_search_fn is stored in torch_optim_args
  cot <- causalOT:::COT$new(source = x, target = y,
    options = list(debias = FALSE, niter = 1, opt.direction = "primal",
                   line_search_fn = "strong_wolfe"))
  testthat::expect_equivalent(cot$.__enclos_env__$private$torch_optim_args,
                              list(line_search_fn = "strong_wolfe"))
  testthat::expect_silent(cot$solve())

  # balance formula on X1
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(niter = 1, balance.formula = "~X1", debias = FALSE)))
  osqpout <- testthat::capture_output(testthat::expect_warning(cot$solve()))
  testthat::expect_true(max_balance_gap(cot) < 1.5)

  # balance formula on X1 and X2 with delta = 0.1
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(debias = FALSE, balance.formula = "~X1+X2", niter = 1,
                   verbose = FALSE, delta = 0.1)))
  osqpout <- testthat::capture_output(testthat::expect_warning(cot$solve()))
  testthat::expect_true(max_balance_gap(cot) < 0.9)
})
# Entropic COT with the online (keops) cost: primal solves, forwarding of
# extra optimizer arguments, and balance-constrained solutions with and
# without debiasing.
testthat::test_that("ent debiased works, online", {
  testthat::skip_on_cran()
  testthat::skip_on_ci()
  causalOT:::torch_check()
  # skip on a missing package before calling rkeops_check()
  testthat::skip_if_not_installed("rkeops")
  causalOT:::rkeops_check()
  set.seed(234808)
  n <- 15
  m <- 13
  d <- 3
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)

  # Largest absolute difference between the weighted source balance
  # functions and the balance target, returned as an R number.
  max_balance_gap <- function(cot) {
    src <- cot$.__enclos_env__$private$source
    gap <- src$balance_functions$transpose(2, 1)$matmul(src$weights) -
      src$balance_target
    as_numeric(gap$abs()$max()$to(device = "cpu")$item())
  }

  # debias = FALSE, primal direction, online cost
  mess <- testthat::capture_output(
    cot <- causalOT:::COT$new(source = x, target = y,
      options = list(debias = FALSE, niter = 4, opt.direction = "primal",
                     cost.online = "online"))
  )
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$torch_optim, "optim_lbfgs")
  )
  testthat::expect_warning(cot$solve())

  # line_search_fn is stored in torch_optim_args
  cot <- causalOT:::COT$new(source = x, target = y,
    options = list(debias = FALSE, niter = 1, opt.direction = "primal",
                   line_search_fn = "strong_wolfe", cost.online = "online"))
  testthat::expect_equivalent(cot$.__enclos_env__$private$torch_optim_args,
                              list(line_search_fn = "strong_wolfe"))
  testthat::expect_silent(cot$solve())

  # balance formula on X1
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(niter = 1, balance.formula = "~X1", debias = FALSE,
                   cost.online = "online")))
  osqpout <- testthat::capture_output(testthat::expect_warning(cot$solve()))
  testthat::expect_true(max_balance_gap(cot) < 1.5)

  # balance formula on X1 and X2 with delta = 0.1
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(debias = FALSE, balance.formula = "~X1+X2", niter = 1,
                   verbose = FALSE, delta = 0.1, cost.online = "online")))
  osqpout <- testthat::capture_output(testthat::expect_warning(cot$solve()))
  testthat::expect_true(max_balance_gap(cot) < 0.9)

  # same balance problem with debiasing enabled
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(debias = TRUE, balance.formula = "~X1+X2", niter = 1,
                   verbose = FALSE, delta = 0.1, cost.online = "online")))
  osqpout <- testthat::capture_output(cot$solve())
  testthat::expect_true(max_balance_gap(cot) < 0.9)
})
# grid_search() on a solved dual-direction COT object.
testthat::test_that("grid_search function works", {
  testthat::skip_on_cran()
  testthat::skip_on_ci()
  causalOT:::torch_check()
  set.seed(234808)
  n <- 15
  m <- 13
  d <- 3
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)
  # Upper limit for the first selected penalty (reference value + 1).
  # NOTE(review): presumably a value recorded for these seeds — confirm.
  penalty_limit <- 2704243.27850342 + 1

  # setup
  torch::torch_manual_seed(123123)
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(verbose = FALSE,
                   opt.direction = "dual",
                   # torch.scheduler = NULL,
                   niter = 1L,
                   nboot = 10L)))
  cot$solve()
  output <- cot$grid_search()
  testthat::expect_true(output$penalty[1] < penalty_limit)
  # a repeated grid search on the same object is expected to warn
  testthat::expect_warning(cot$grid_search())
})
# lambda = 0 configurations (NNM), checked with both the tensor cost and the
# online (keops) cost.
testthat::test_that("NNM works", {
  testthat::skip_on_cran()
  testthat::skip_on_ci()
  causalOT:::torch_check()
  set.seed(234808)
  n <- 15
  m <- 13
  d <- 3
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)

  # OT object stored in the optimizer's private `ot_objects` container.
  # NOTE(review): assumes the container holds exactly one entry.
  ot_object <- function(cot) {
    opt_private <- cot$.__enclos_env__$private$optimizer$.__enclos_env__$private
    ot <- opt_private$ot_objects
    ot[[ls(ot)]]
  }
  # Names allowed in the value returned by grid_search().
  grid_names <- c("weight", "penalty", "metric", "penalty.grid")

  #### tensor ####
  # lambda = 0 with the primal direction gives an NNM optimizer
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(lambda = 0, debias = FALSE, niter = 4,
                   opt.direction = "primal")))
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$optimizer, "NNM")
  )
  testthat::expect_true(inherits(ot_object(cot)$C_xy, "costTensor"))
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$torch_optim, "optim_lbfgs")
  )
  testthat::expect_true(isFALSE(ot_object(cot)$debias))
  testthat::expect_invisible(cot$solve())
  res <- cot$grid_search()
  testthat::expect_true(all(names(res) %in% grid_names))
  # with lambda = 0 the OT object is not debiased even if debias = TRUE
  cot <- causalOT:::COT$new(source = x, target = y,
    options = list(debias = TRUE, lambda = 0))
  testthat::expect_true(isFALSE(ot_object(cot)$debias))
  cot <- causalOT:::COT$new(source = x, target = y,
    options = list(balance.formula = "~.", lambda = 0))
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$optimizer, "cotDualTrain")
  )
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(niter = 1, balance.formula = "~X1", debias = FALSE)))
  testthat::expect_equal(class(cot$.__enclos_env__$private$optimizer),
                         c("OTProblem", "R6"))
  osqpout <- testthat::capture_output(testthat::expect_warning(cot$solve()))

  #### keops ####
  causalOT:::rkeops_check()
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(lambda = 0, debias = FALSE, niter = 4,
                   cost.online = "online", opt.direction = "primal")))
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$optimizer, "NNM")
  )
  testthat::expect_true(inherits(ot_object(cot)$C_xy, "costOnline"))
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$torch_optim, "optim_lbfgs")
  )
  testthat::expect_true(isFALSE(ot_object(cot)$debias))
  testthat::expect_invisible(cot$solve())
  res <- cot$grid_search()
  testthat::expect_true(all(names(res) %in% grid_names))
  cot <- causalOT:::COT$new(source = x, target = y,
    options = list(debias = TRUE, cost.online = "online", lambda = 0))
  testthat::expect_true(isFALSE(ot_object(cot)$debias))
  cot <- causalOT:::COT$new(source = x, target = y,
    options = list(balance.formula = "~.", cost.online = "online",
                   lambda = 0))
  testthat::expect_true(
    inherits(cot$.__enclos_env__$private$optimizer, "cotDualTrain")
  )
  testthat::expect_silent(cot <- causalOT:::COT$new(source = x, target = y,
    options = list(niter = 1, cost.online = "online",
                   balance.formula = "~X1", debias = FALSE)))
  testthat::expect_equal(class(cot$.__enclos_env__$private$optimizer),
                         c("OTProblem", "R6"))
  osqpout <- testthat::capture_output(testthat::expect_warning(cot$solve()))
})
# After a one-iteration dual solve, all entries of `weights` are equal; this
# is checked for the tensor cost and, when rkeops is usable, the online cost.
testthat::test_that("weights function works", {
  testthat::skip_on_cran()
  testthat::skip_on_ci()
  causalOT:::torch_check()
  set.seed(234808)
  n <- 15
  m <- 13
  d <- 3
  x <- matrix(stats::rnorm(n * d), n, d)
  y <- matrix(stats::rnorm(m * d), m, d)

  # Build a dual-direction COT (extra options are spliced in after
  # opt.direction), solve it, and check every weight equals the first one.
  expect_uniform_weights <- function(extra_options = list()) {
    opts <- c(
      list(verbose = FALSE, opt.direction = "dual"),
      extra_options,
      list(niter = 1L, nboot = 10L)
    )
    testthat::expect_silent(
      cot <- causalOT:::COT$new(source = x, target = y, options = opts)
    )
    cot$solve()
    testthat::expect_silent(w <- cot$weights)
    testthat::expect_true(all(w[1] == w))
  }

  # tensor cost
  expect_uniform_weights()

  # run rkeops version
  causalOT:::rkeops_check() # skips if rkeops fails or is not installed
  expect_uniform_weights(list(cost.online = "online"))
})
# Validation performed by cotOptions(). No data is needed: only the option
# object is constructed.
testthat::test_that("cotOptions error checking works", {
  testthat::skip_on_cran()
  testthat::skip_on_ci()
  causalOT:::torch_check()

  # defaults construct silently
  testthat::expect_silent(opt <- cotOptions())

  # lbfgs with the dual direction warns and is replaced by rmsprop
  testthat::expect_warning(
    opt <- cotOptions(opt.direction = "dual",
                      torch.optimizer = torch::optim_lbfgs)
  )
  testthat::expect_true(inherits(opt$torch.optimizer, "optim_rmsprop"))

  # arbitrary functions are rejected as optimizer or scheduler
  testthat::expect_error(opt <- cotOptions(torch.optimizer = sum))
  testthat::expect_error(opt <- cotOptions(torch.scheduler = sum))

  # without debiasing, the dual direction is switched to primal
  opt <- cotOptions(debias = FALSE, opt.direction = "dual")
  testthat::expect_true(opt$opt.direction == "primal")
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.