tests/testthat/test_pycox.R

# Skip this whole file when the Python pycox stack is unavailable
# (skip_if_no_pycox() is a helper defined outside this file).
skip_if_no_pycox()

# Seed every random number generator the tests may draw from -- R's, NumPy's
# global generator and PyTorch's -- aiming for reproducible fits across runs.
set.seed(1)
np <- reticulate::import("numpy")
torch <- reticulate::import("torch")
np$random$seed(1L)
torch$manual_seed(1L)

test_that("get_pycox_optim", {
  # get_pycox_optim() should return the torch.optim optimiser named by its
  # first argument for the network `net`; check each result's Python class.
  # expect_s3_class() replaces expect_is(), which is deprecated in testthat's
  # third edition. reticulate exposes the Python class hierarchy as the S3
  # class vector, so the same inherits()-style check still applies.
  net <- build_pytorch_net(1L, 1L, 1L)
  expect_s3_class(get_pycox_optim("adadelta", net),
                  "torch.optim.adadelta.Adadelta")
  expect_s3_class(get_pycox_optim("adagrad", net),
                  "torch.optim.adagrad.Adagrad")
  expect_s3_class(get_pycox_optim("adamax", net),
                  "torch.optim.adamax.Adamax")
  expect_s3_class(get_pycox_optim("adam", net),
                  "torch.optim.adam.Adam")
  expect_s3_class(get_pycox_optim("adamw", net),
                  "torch.optim.adamw.AdamW")
  expect_s3_class(get_pycox_optim("asgd", net),
                  "torch.optim.asgd.ASGD")
  expect_s3_class(get_pycox_optim("rmsprop", net),
                  "torch.optim.rmsprop.RMSprop")
  expect_s3_class(get_pycox_optim("rprop", net),
                  "torch.optim.rprop.Rprop")
  expect_s3_class(get_pycox_optim("sgd", net),
                  "torch.optim.sgd.SGD")
  # NOTE(review): SparseAdam was already disabled here, presumably because it
  # cannot be built for the dense parameters of `net` -- confirm before
  # re-enabling.
  # expect_s3_class(get_pycox_optim("sparse_adam", net),
  #                 "torch.optim.sparse_adam.SparseAdam")
})

test_that("get_pycox_init", {
  # Argument values expected in the generated Python code. Every call below
  # except the "constant" one passes no extra arguments, so these values
  # also pin what get_pycox_init() must use by default.
  lower <- 0
  upper <- 1
  mu <- 0
  sigma <- 1
  fill <- 0
  gain <- 1
  slope <- 0
  fan_mode <- "fan_in"
  nonlin <- "leaky_relu"

  # Expected Python call for initialiser `fn` applied to m.weight, followed
  # by the remaining argument text `args`.
  init_call <- function(fn, args) {
    sprintf("torch.nn.init.%s_(m.weight, %s)", fn, args)
  }
  # Both Kaiming variants share the same trailing arguments.
  kaiming_args <- sprintf("%s, '%s', '%s'", slope, fan_mode, nonlin)

  expect_equal(get_pycox_init("uniform"),
               init_call("uniform", paste(lower, upper, sep = ", ")))
  expect_equal(get_pycox_init("normal"),
               init_call("normal", paste(mu, sigma, sep = ", ")))
  expect_equal(get_pycox_init("constant", val = fill),
               init_call("constant", fill))
  expect_equal(get_pycox_init("xavier_uniform"),
               init_call("xavier_uniform", gain))
  expect_equal(get_pycox_init("xavier_normal"),
               init_call("xavier_normal", gain))
  expect_equal(get_pycox_init("kaiming_uniform"),
               init_call("kaiming_uniform", kaiming_args))
  expect_equal(get_pycox_init("kaiming_normal"),
               init_call("kaiming_normal", kaiming_args))
  expect_equal(get_pycox_init("orthogonal"),
               init_call("orthogonal", gain))
})

# Shared fixture for the prediction tests: a model fitted quietly with
# coxtime() on the first 50 rows of `rats` (presumably the survival
# package's dataset).
fit <- coxtime(
  Surv(time, status) ~ .,
  data = rats[seq_len(50L), ],
  verbose = FALSE
)

test_that("predict", {
  # With type = "all" and distr6 = FALSE, predict() should give a list with a
  # survival matrix (one row per observation) and a numeric risk vector.
  # expect_type()/is.matrix()/expect_length() replace expect_is(), which is
  # deprecated in testthat's third edition and gives vaguer failures.
  p <- predict(fit, type = "all", distr6 = FALSE)
  expect_type(p, "list")
  expect_true(is.matrix(p$surv))
  expect_type(p$risk, "double")
  expect_length(p$risk, 50)
  # 50 observations by 20 columns; the column count presumably equals the
  # number of distinct event times in the fitted data -- TODO confirm.
  expect_equal(dim(p$surv), c(50, 20))
})

test_that("predict distr6", {
  # distr6 is an optional dependency: skip, rather than fail, without it.
  skip_if_not_installed("distr6")

  # With distr6 = TRUE the survival predictions come back as a distr6
  # Matdist whose "cdf" parameter has one row per observation.
  p <- predict(fit, type = "all", distr6 = TRUE)
  expect_type(p, "list")
  expect_s3_class(p$surv, "Matdist")
  expect_equal(nrow(distr6::gprm(p$surv, "cdf")), 50)

  # Requesting only survival, without the distr6 argument, is expected to
  # return a plain matrix.
  p <- predict(fit, type = "survival")
  expect_true(is.matrix(p))
})

test_that("build_pytorch_net", {
  # A network with per-layer activations and dropout rates should build
  # without printing output, messages or warnings. The positional arguments
  # are presumably (inputs, outputs, hidden-layer sizes) -- confirm against
  # build_pytorch_net()'s signature.
  expect_silent(
    build_pytorch_net(
      2L, 2L, c(2, 4, 8),
      activation = c("relu", "elu", "glu"),
      dropout = c(0.1, 1, 0.62)
    )
  )
})

Try the survivalmodels package in your browser

Any scripts or data that you put into this service are public.

survivalmodels documentation built on March 24, 2022, 9:05 a.m.