testthat::test_that("gridSearch class forms", {
causalOT:::torch_check()
set.seed(12312)
hain <- causalOT:::Hainmueller$new(128)
hain$gen_data()
data <- dataHolder(hain)
# by raw class
testthat::expect_silent(gs <- new("gridSearch", penalty_list = 0, nboot = 100L,
solver = hain,
method = "COT",
estimand = "ATE"))
testthat::expect_true(inherits(gs, "gridSearch"))
})
testthat::test_that("gridSearch SBW", {
testthat::skip_if_not_installed("osqp")
set.seed(12312)
hain <- causalOT:::Hainmueller$new(n=64)
hain$gen_data()
data <- dataHolder(hain)
method <- "SBW"
# by function SBW
mess <- testthat::capture_output(gs <- causalOT:::gridSearch(data = data,
estimand = "ATT",
method = method,
options = list()))
testthat::expect_true(inherits(gs, "gridSearch"))
testthat::expect_equal(gs@nboot, 1000L)
testthat::expect_true(inherits(gs@penalty_list, "numeric"))
testthat::expect_true(length(gs@penalty_list) == 20)
testthat::expect_true(inherits(gs@solver, "SBW"))
testthat::expect_true(gs@method == "SBW")
testthat::expect_true(gs@estimand == "ATT")
mess <- testthat::capture_output(suppressWarnings(cw <- causalOT:::cot_solve(gs)))
testthat::expect_true(inherits(cw, "causalWeights"))
mess <- testthat::capture_output(gs <- causalOT:::gridSearch(data = data,
estimand = "ATC",
method = method,
options = list()))
testthat::expect_true(inherits(gs, "gridSearch"))
testthat::expect_true(gs@estimand == "ATC")
mess <- testthat::capture_output(suppressWarnings(cw <- causalOT:::cot_solve(gs)))
testthat::expect_true(inherits(cw, "causalWeights"))
mess <- testthat::capture_output(gs <- causalOT:::gridSearch(data = data,
estimand = "ATE",
method = method,
options = list()))
testthat::expect_true(inherits(gs, "ateClass"))
mess <- testthat::capture_output(cw <- causalOT:::cot_solve(gs))
testthat::expect_true(inherits(cw, "causalWeights"))
})
testthat::test_that("gridSearch EBW", {
testthat::skip_if_not_installed("lbfgsb3c")
set.seed(12312)
hain <- causalOT:::Hainmueller$new(n=64)
hain$gen_data()
data <- dataHolder(hain)
method <- "EntropyBW"
# by function SBW
gs <- causalOT:::gridSearch(data = data,
estimand = "ATT",
method = method,
options = list())
testthat::expect_true(inherits(gs, "gridSearch"))
testthat::expect_equal(gs@nboot, 1000L)
testthat::expect_true(inherits(gs@penalty_list, "numeric"))
testthat::expect_true(length(gs@penalty_list) == 20)
testthat::expect_true(inherits(gs@solver, "EntropyBW"))
testthat::expect_true(gs@method == "EntropyBW")
testthat::expect_true(gs@estimand == "ATT")
mess <- testthat::capture_output(testthat::expect_warning(cw <- causalOT:::cot_solve(gs)))
testthat::expect_true(inherits(cw, "causalWeights"))
mess <- testthat::capture_output(gs <- causalOT:::gridSearch(data = data,
estimand = "ATC",
method = method,
options = list()))
testthat::expect_true(inherits(gs, "gridSearch"))
testthat::expect_true(gs@estimand == "ATC")
mess <- testthat::capture_output(testthat::expect_warning(cw <- causalOT:::cot_solve(gs)))
testthat::expect_true(inherits(cw, "causalWeights"))
mess <- testthat::capture_output(gs <- causalOT:::gridSearch(data = data,
estimand = "ATE",
method = method,
options = list()))
testthat::expect_true(inherits(gs, "ateClass"))
mess <- testthat::capture_output(testthat::expect_warning(cw <- causalOT:::cot_solve(gs)))
testthat::expect_true(inherits(cw, "causalWeights"))
})
testthat::test_that("gridSearch COT", {
causalOT:::torch_check()
set.seed(12312)
hain <- causalOT:::Hainmueller$new(n=64)
hain$gen_data()
data <- dataHolder(hain)
method <- "COT"
options = list(niter = 2L,
nboot = 2,
debias = TRUE,
torch.optimizer = torch::optim_lbfgs)
# by function
gs <- testthat::expect_warning(causalOT:::gridSearch(data = data,
estimand = "ATT",
method = method,
options = options))
testthat::expect_true(inherits(gs, "gridSearch"))
testthat::expect_equal(gs@nboot, 2)
testthat::expect_true(inherits(gs@penalty_list, "numeric"))
testthat::expect_true(length(gs@penalty_list) == 8)
testthat::expect_true(inherits(gs@solver, "COT"))
testthat::expect_true(gs@method == "COT")
testthat::expect_true(gs@estimand == "ATT")
cw <- causalOT:::cot_solve(gs)
testthat::expect_true(inherits(cw, "causalWeights"))
mess <- testthat::capture_output(testthat::expect_warning(gs <- causalOT:::gridSearch(data = data,
estimand = "ATC",
method = method,
options = options)))
testthat::expect_true(inherits(gs, "gridSearch"))
testthat::expect_true(gs@estimand == "ATC")
mess <- testthat::capture_output(cw <- causalOT:::cot_solve(gs))
testthat::expect_true(inherits(cw, "causalWeights"))
mess <- testthat::capture_output(testthat::expect_warning(gs <- causalOT:::gridSearch(data = data,
estimand = "ATE",
method = method,
options = options)))
testthat::expect_true(inherits(gs, "ateClass"))
mess <- testthat::capture_output(cw <- causalOT:::cot_solve(gs))
testthat::expect_true(inherits(cw, "causalWeights"))
})
testthat::test_that("gridSearch NNM", {
causalOT:::torch_check()
set.seed(12312)
hain <- causalOT:::Hainmueller$new(n=64)
hain$gen_data()
data <- dataHolder(hain)
method <- "NNM"
options = list(niter = 2L,
nboot = 2,
debias = TRUE,
torch.optimizer = torch::optim_lbfgs)
# by function
gs <- testthat::expect_warning(causalOT:::gridSearch(data = data,
estimand = "ATT",
method = method,
options = options))
testthat::expect_true(inherits(gs, "gridSearch"))
testthat::expect_equal(gs@nboot, 2)
testthat::expect_true(inherits(gs@penalty_list, "numeric"))
testthat::expect_true(length(gs@penalty_list) == 1)
testthat::expect_true(inherits(gs@solver, "COT"))
testthat::expect_true(gs@method == "NNM")
testthat::expect_true(gs@estimand == "ATT")
cw <- causalOT:::cot_solve(gs)
testthat::expect_true(inherits(cw, "causalWeights"))
mess <- testthat::capture_output(testthat::expect_warning(gs <- causalOT:::gridSearch(data = data,
estimand = "ATC",
method = method,
options = options)))
testthat::expect_true(inherits(gs, "gridSearch"))
testthat::expect_true(gs@estimand == "ATC")
mess <- testthat::capture_output(cw <- causalOT:::cot_solve(gs))
testthat::expect_true(inherits(cw, "causalWeights"))
mess <- testthat::capture_output(testthat::expect_warning(gs <- causalOT:::gridSearch(data = data,
estimand = "ATE",
method = method,
options = options)))
testthat::expect_true(inherits(gs, "ateClass"))
mess <- testthat::capture_output(cw <- causalOT:::cot_solve(gs))
testthat::expect_true(inherits(cw, "causalWeights"))
})
testthat::test_that("gridSearch SCM", {
testthat::skip_if_not_installed("osqp")
set.seed(12312)
hain <- causalOT:::Hainmueller$new(n=64)
hain$gen_data()
data <- dataHolder(hain)
method <- "SCM"
# by function
mess <- testthat::capture_output(gs <- causalOT:::gridSearch(data = data,
estimand = "ATT",
method = method,
options = list(NULL)))
testthat::expect_true(inherits(gs, "gridSearch"))
testthat::expect_equal(gs@nboot, 1000)
testthat::expect_true(inherits(gs@penalty_list, "numeric"))
testthat::expect_true(is.na(gs@penalty_list))
testthat::expect_true(length(gs@penalty_list) == 1)
testthat::expect_true(inherits(gs@solver, "SCM"))
testthat::expect_true(gs@method == "SCM")
testthat::expect_true(gs@estimand == "ATT")
mess <- testthat::capture_output(cw <- causalOT:::cot_solve(gs))
testthat::expect_true(inherits(cw, "causalWeights"))
mess <- testthat::capture_output(gs <- causalOT:::gridSearch(data = data,
estimand = "ATC",
method = method,
options = options))
testthat::expect_true(inherits(gs, "gridSearch"))
testthat::expect_true(gs@estimand == "ATC")
mess <- testthat::capture_output(cw <- causalOT:::cot_solve(gs))
testthat::expect_true(inherits(cw, "causalWeights"))
mess <- testthat::capture_output(gs <- causalOT:::gridSearch(data = data,
estimand = "ATE",
method = method,
options = options))
testthat::expect_true(inherits(gs, "ateClass"))
mess <- testthat::capture_output(cw <- causalOT:::cot_solve(gs))
testthat::expect_true(inherits(cw, "causalWeights"))
})
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.