# tests/testthat/test-tune_survdnn.R

test_that("tune_survdnn returns correct structure for all modes", {
  skip_on_cran()
  # Check for the package first: calling `torch::` when torch is not
  # installed raises an error instead of skipping the test.
  skip_if_not_installed("torch")
  skip_if_not(torch::torch_is_installed())

  data <- survival::veteran
  param_grid <- list(
    hidden     = list(c(8), c(8, 4)),
    lr         = 1e-3,
    activation = "relu",
    epochs     = 5,
    loss       = "cox"
  )
  times <- c(30, 90)

  # Every mode is exercised with identical inputs (same grid, folds and seed);
  # only the `return` mode differs between calls.
  run_tune <- function(mode) {
    tune_survdnn(
      Surv(time, status) ~ age + karno + celltype,
      data = data,
      times = times,
      metrics = "cindex",
      param_grid = param_grid,
      folds = 2,
      .seed = 123,
      refit = FALSE,
      return = mode
    )
  }

  # "all": per-fold results.
  all_res <- run_tune("all")
  expect_s3_class(all_res, "data.frame")
  expect_true(all(c("fold", "metric", "value") %in% names(all_res)))

  # "summary": aggregated mean / sd per metric.
  summary_res <- run_tune("summary")
  expect_s3_class(summary_res, "data.frame")
  expect_true(all(c("metric", "mean", "sd") %in% names(summary_res)))

  # "best_model" without refit: the winning hyperparameter configuration.
  best_cfg <- run_tune("best_model")
  expect_s3_class(best_cfg, "data.frame")
  expect_true(
    all(c("hidden", "lr", "activation", "epochs", "loss") %in% names(best_cfg))
  )
})

test_that("tune_survdnn works with refit = TRUE and returns survdnn model", {
  skip_on_cran()
  # Check for the package first: calling `torch::` when torch is not
  # installed raises an error instead of skipping the test.
  skip_if_not_installed("torch")
  skip_if_not(torch::torch_is_installed())

  data <- survival::veteran
  # A single-configuration grid keeps the tuning step as cheap as possible;
  # this test only cares about the refit output type.
  param_grid <- list(
    hidden     = list(c(8)),
    lr         = 1e-3,
    activation = "relu",
    epochs     = 3,
    loss       = "cox"
  )
  times <- c(30, 90)

  mod <- tune_survdnn(
    Surv(time, status) ~ age + karno + celltype,
    data = data,
    times = times,
    metrics = "cindex",
    param_grid = param_grid,
    folds = 2,
    .seed = 42,
    refit = TRUE,
    return = "best_model"
  )

  # With refit = TRUE, "best_model" yields a fitted model, not a config table.
  expect_s3_class(mod, "survdnn")
  expect_s3_class(mod$model, "nn_module")
})

test_that("summarize_tune_survdnn aggregates correctly", {
  skip_on_cran()
  # Check for the package first: calling `torch::` when torch is not
  # installed raises an error instead of skipping the test.
  skip_if_not_installed("torch")
  skip_if_not(torch::torch_is_installed())

  data <- survival::veteran
  param_grid <- list(
    hidden     = list(c(8)),
    lr         = 1e-3,
    activation = "relu",
    epochs     = 3,
    loss       = "cox"
  )
  times <- c(30, 90)

  all_res <- tune_survdnn(
    Surv(time, status) ~ age + karno + celltype,
    data = data,
    times = times,
    metrics = "brier",
    param_grid = param_grid,
    folds = 2,
    .seed = 123,
    refit = FALSE,
    return = "all"
  )

  sm <- summarize_tune_survdnn(all_res, by_time = TRUE)
  expect_s3_class(sm, "data.frame")

  # When the raw results carry a per-time column, by_time = TRUE must keep it
  # in the summary.
  if ("time" %in% names(all_res)) {
    expect_true("time" %in% names(sm))
  }
})

test_that("summarize_tune_survdnn throws on bad input", {
  # Validating a plain data.frame should not require torch or any model
  # training (the expected error comes from the input check), so this runs
  # without the CRAN/torch skips.
  expect_error(
    summarize_tune_survdnn(data.frame(a = 1)),
    "Input must be the result"
  )
})

# Try the survdnn package in your browser.
#
# Any scripts or data that you put into this service are public.
#
# survdnn documentation built on Aug. 8, 2025, 6:05 p.m.