skip_if_no_pycox()
set.seed(1)
np <- reticulate::import("numpy")
np$random$seed(1L)
torch <- reticulate::import("torch")
torch$manual_seed(1L)
test_that("get_pycox_optim", {
net <- build_pytorch_net(1L, 1L, 1L)
expect_is(get_pycox_optim("adadelta", net), "torch.optim.adadelta.Adadelta")
expect_is(get_pycox_optim("adagrad", net), "torch.optim.adagrad.Adagrad")
expect_is(get_pycox_optim("adamax", net), "torch.optim.adamax.Adamax")
expect_is(get_pycox_optim("adam", net), "torch.optim.adam.Adam")
expect_is(get_pycox_optim("adamw", net), "torch.optim.adamw.AdamW")
expect_is(get_pycox_optim("asgd", net), "torch.optim.asgd.ASGD")
expect_is(get_pycox_optim("rmsprop", net), "torch.optim.rmsprop.RMSprop")
expect_is(get_pycox_optim("rprop", net), "torch.optim.rprop.Rprop")
expect_is(get_pycox_optim("sgd", net), "torch.optim.sgd.SGD")
# expect_is(get_pycox_optim("sparse_adam", net), "torch.optim.sparse_adam.SparseAdam")
})
test_that("get_pycox_init", {
a <- 0; b <- 1; mean <- 0; std <- 1; val <- 0; gain <- 1; mode <- "fan_in"
non_linearity <- "leaky_relu"
expect_equal(get_pycox_init("uniform"),
paste0("torch.nn.init.uniform_(m.weight, ", a, ", ", b, ")"))
expect_equal(get_pycox_init("normal"),
paste0("torch.nn.init.normal_(m.weight, ", mean, ", ", std, ")"))
expect_equal(get_pycox_init("constant", val = val),
paste0("torch.nn.init.constant_(m.weight, ", val, ")"))
expect_equal(get_pycox_init("xavier_uniform"),
paste0("torch.nn.init.xavier_uniform_(m.weight, ", gain, ")"))
expect_equal(get_pycox_init("xavier_normal"),
paste0("torch.nn.init.xavier_normal_(m.weight, ", gain, ")"))
expect_equal(get_pycox_init("kaiming_uniform"),
paste0("torch.nn.init.kaiming_uniform_(m.weight, ", a, ", '",
mode, "', '", non_linearity, "')"))
expect_equal(get_pycox_init("kaiming_normal"),
paste0("torch.nn.init.kaiming_normal_(m.weight, ", a, ", '", mode, "', '",
non_linearity, "')"))
expect_equal(get_pycox_init("orthogonal"),
paste0("torch.nn.init.orthogonal_(m.weight, ", gain, ")"))
})
fit <- coxtime(Surv(time, status) ~ ., data = rats[1:50, ], verbose = FALSE)
test_that("predict", {
p <- predict(fit, type = "all", distr6 = FALSE)
expect_is(p, "list")
expect_is(p$surv, "matrix")
expect_is(p$risk, "numeric")
expect_equal(length(p$risk), 50)
expect_equal(dim(p$surv), c(50, 20))
})
test_that("predict distr6", {
if (!requireNamespace("distr6", quietly = TRUE)) {
skip("distr6 not installed.")
}
p <- predict(fit, type = "all", distr6 = TRUE)
expect_is(p, "list")
expect_is(p$surv, "Matdist")
expect_equal(nrow(distr6::gprm(p$surv, "cdf")), 50)
p <- predict(fit, type = "survival")
expect_is(p, "matrix")
})
test_that("build_pytorch_net", {
expect_silent(build_pytorch_net(2L, 2L, c(2, 4, 8), activation = c("relu", "elu", "glu"),
dropout = c(0.1, 1, 0.62)))
})
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.