# tests/testthat/test-plot.survdnn.R

test_that("plot.survdnn returns a ggplot object (default)", {
  # Skipped on CRAN and whenever torch reports that it is not installed.
  skip_on_cran()
  skip_if_not(torch::torch_is_installed())

  # Fit a model on the veteran data; only the class of the resulting plot
  # is checked below, not the fitted values.
  veteran_df <- survival::veteran
  fit <- survdnn(
    Surv(time, status) ~ age + karno + celltype,
    data = veteran_df,
    hidden = c(8),
    epochs = 5,
    verbose = FALSE
  )

  # plot() with no extra arguments exercises the method's defaults.
  default_plot <- plot(fit)
  expect_s3_class(default_plot, "ggplot")
})

test_that("plot.survdnn supports group_by and plot_mean_only", {
  # Skipped on CRAN and whenever torch reports that it is not installed.
  skip_on_cran()
  skip_if_not(torch::torch_is_installed())

  veteran_df <- survival::veteran
  fit <- survdnn(
    Surv(time, status) ~ age + karno + celltype,
    data = veteran_df,
    hidden = c(8),
    epochs = 5,
    verbose = FALSE
  )

  # Both plots below are requested on the same time grid (1, 2, ..., 100).
  time_grid <- seq_len(100L)

  # Grouped by celltype, mean curves only.
  mean_only_plot <- plot(
    fit,
    group_by = "celltype",
    plot_mean_only = TRUE,
    times = time_grid
  )
  expect_s3_class(mean_only_plot, "ggplot")

  # Grouped by celltype, individual curves with the mean added.
  with_mean_plot <- plot(
    fit,
    group_by = "celltype",
    plot_mean_only = FALSE,
    add_mean = TRUE,
    times = time_grid
  )
  expect_s3_class(with_mean_plot, "ggplot")
})

test_that("plot.survdnn fails with wrong inputs", {
  # Skipped on CRAN and whenever torch reports that it is not installed.
  skip_on_cran()
  skip_if_not(torch::torch_is_installed())

  # Call the unexported method directly on a plain string and require an
  # error whose message matches "inherits".
  expect_error(
    survdnn:::plot.survdnn("not a model"),
    regexp = "inherits"
  )
})

# Try the survdnn package in your browser
#
# Any scripts or data that you put into this service are public.
#
# survdnn documentation built on Aug. 8, 2025, 6:05 p.m.