# Nothing
test_that("tune_survdnn returns correct structure for all modes", {
  skip_on_cran()
  # `torch::torch_is_installed()` errors (rather than skips) when the torch
  # R package itself is missing, so check the package before the backend.
  skip_if_not_installed("torch")
  skip_if_not(torch::torch_is_installed())

  data <- survival::veteran
  param_grid <- list(
    hidden = list(c(8), c(8, 4)),
    lr = c(1e-3),
    activation = c("relu"),
    epochs = c(5),
    loss = c("cox")
  )
  times <- c(30, 90)
  # Built here (not inside the helper) so the formula's environment is the
  # test environment, exactly as when it was written inline in each call.
  fml <- Surv(time, status) ~ age + karno + celltype

  # The three modes share one tuning setup and seed; only `return` differs.
  run_tune <- function(return_mode) {
    tune_survdnn(
      fml,
      data = data,
      times = times,
      metrics = "cindex",
      param_grid = param_grid,
      folds = 2,
      .seed = 123,
      refit = FALSE,
      return = return_mode
    )
  }

  # "all": per-fold, per-metric results.
  all_res <- run_tune("all")
  expect_s3_class(all_res, "data.frame")
  expect_true(all(c("fold", "metric", "value") %in% names(all_res)))

  # "summary": fold results aggregated into mean / sd per metric.
  summary_res <- run_tune("summary")
  expect_s3_class(summary_res, "data.frame")
  expect_true(all(c("metric", "mean", "sd") %in% names(summary_res)))

  # "best_model" without refit: the winning hyperparameter configuration.
  best_cfg <- run_tune("best_model")
  expect_s3_class(best_cfg, "data.frame")
  expect_true(
    all(c("hidden", "lr", "activation", "epochs", "loss") %in% names(best_cfg))
  )
})
test_that("tune_survdnn works with refit = TRUE and returns survdnn model", {
  skip_on_cran()
  # `torch::torch_is_installed()` errors (rather than skips) when the torch
  # R package itself is missing, so check the package before the backend.
  skip_if_not_installed("torch")
  skip_if_not(torch::torch_is_installed())

  data <- survival::veteran
  # A single configuration keeps the refit cheap; only the refit path is
  # under test here.
  param_grid <- list(
    hidden = list(c(8)),
    lr = c(1e-3),
    activation = c("relu"),
    epochs = c(3),
    loss = c("cox")
  )
  times <- c(30, 90)

  mod <- tune_survdnn(
    Surv(time, status) ~ age + karno + celltype,
    data = data,
    times = times,
    metrics = "cindex",
    param_grid = param_grid,
    folds = 2,
    .seed = 42,
    refit = TRUE,
    return = "best_model"
  )

  # With refit = TRUE, "best_model" yields a fitted survdnn object wrapping
  # a torch module rather than a configuration data frame.
  expect_s3_class(mod, "survdnn")
  expect_s3_class(mod$model, "nn_module")
})
test_that("summarize_tune_survdnn aggregates correctly and throws on bad input", {
  skip_on_cran()
  # `torch::torch_is_installed()` errors (rather than skips) when the torch
  # R package itself is missing, so check the package before the backend.
  skip_if_not_installed("torch")
  skip_if_not(torch::torch_is_installed())

  data <- survival::veteran
  param_grid <- list(
    hidden = list(c(8)),
    lr = c(1e-3),
    activation = c("relu"),
    epochs = c(3),
    loss = c("cox")
  )
  times <- c(30, 90)

  # The Brier score is evaluated at each requested time, so the raw results
  # may carry a `time` column for by_time summarisation.
  all_res <- tune_survdnn(
    Surv(time, status) ~ age + karno + celltype,
    data = data,
    times = times,
    metrics = "brier",
    param_grid = param_grid,
    folds = 2,
    .seed = 123,
    refit = FALSE,
    return = "all"
  )

  sm <- summarize_tune_survdnn(all_res, by_time = TRUE)
  expect_s3_class(sm, "data.frame")
  # Only require a per-time summary when the raw results are time-indexed.
  if ("time" %in% names(all_res)) {
    expect_true("time" %in% names(sm))
  }

  # Arbitrary data frames that are not tuning output must be rejected.
  expect_error(
    summarize_tune_survdnn(data.frame(a = 1)),
    "Input must be the result"
  )
})
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.