gridsearch_survdnn: Grid Search for survdnn Hyperparameters

View source: R/gridsearch_survdnn.R

Description

Performs a grid search over user-specified hyperparameter values: each candidate configuration is fit on the training set and evaluated on the validation set with the requested metrics.

Usage

gridsearch_survdnn(
  formula,
  train,
  valid,
  times,
  metrics = c("cindex", "ibs"),
  param_grid,
  .seed = 42,
  .device = c("auto", "cpu", "cuda")
)

Arguments

formula

A survival formula (e.g., 'Surv(time, status) ~ .')

train

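Training data set (a data frame) used to fit each candidate model.

valid

Validation data set (a data frame) on which each fitted model is evaluated.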

times

Numeric vector of time points at which the models are evaluated.
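If you are unsure which values to pass as times, one common heuristic (an assumption here, not a survdnn requirement) is to use quantiles of the observed event times in the training data, which keeps every evaluation point inside the follow-up range. The helper choose_eval_times() below is hypothetical and uses only base R:

```r
# Hypothetical helper (not part of survdnn): pick evaluation times as
# quantiles of the observed (uncensored) event times.
choose_eval_times <- function(data, time_col = "time", status_col = "status",
                              probs = c(0.25, 0.5, 0.75)) {
  event_times <- data[[time_col]][data[[status_col]] == 1]
  unname(stats::quantile(event_times, probs = probs))
}

# Toy data: events at times 2, 4, 8 and 10; times 6 and 12 are censored
toy <- data.frame(time = c(2, 4, 6, 8, 10, 12), status = c(1, 1, 0, 1, 1, 0))
choose_eval_times(toy)
#> [1] 3.5 6.0 8.5
```

With the simulated data in the Examples, a call such as choose_eval_times(train) could stand in for the hard-coded c(10, 20, 30).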

metrics

Character vector of evaluation metrics: any subset of "cindex" (concordance index) and "ibs" (integrated Brier score).
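For intuition about the "cindex" metric, here is a minimal base-R sketch of Harrell's concordance index, assuming that a higher risk score means shorter expected survival. survdnn's own implementation may differ (for example in how tied times or censoring are handled), and harrell_cindex() is a hypothetical name:

```r
# Hypothetical helper (not survdnn's implementation): Harrell's C-index.
# A pair (i, j) is comparable when subject i has an observed event and a
# strictly shorter time than subject j; it is concordant when i also has
# the higher predicted risk.
harrell_cindex <- function(time, status, risk) {
  concordant <- 0
  comparable <- 0
  for (i in seq_along(time)) {
    if (status[i] != 1) next  # censored subjects cannot anchor a pair
    for (j in seq_along(time)) {
      if (time[i] < time[j]) {
        comparable <- comparable + 1
        if (risk[i] > risk[j]) {
          concordant <- concordant + 1
        } else if (risk[i] == risk[j]) {
          concordant <- concordant + 0.5  # tied risk scores count as half
        }
      }
    }
  }
  if (comparable == 0) return(NA_real_)
  concordant / comparable
}

harrell_cindex(time   = c(5, 3, 8, 1),
               status = c(1, 1, 0, 1),
               risk   = c(0.2, 0.5, 0.1, 0.9))
#> [1] 1
```

A value of 0.5 corresponds to random ordering and 1 to perfect ranking. The double loop is O(n^2), which is fine for a validation set of a few hundred rows.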

param_grid

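A named list of hyperparameter values to search over. Currently supported entries are hidden, lr, activation, epochs, and loss. Each entry holds the candidate values for that hyperparameter; hidden is a list of numeric vectors of hidden-layer sizes, as in the Examples.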
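Grid search conventionally tries every combination of the candidate values. Assuming survdnn follows that convention, the Examples grid (2 architectures x 2 losses) yields 4 configurations. The helper expand_param_grid() below is hypothetical and only illustrates that expansion in base R:

```r
# Hypothetical helper (not part of survdnn): expand a named param_grid into
# a list of configurations, one per combination of candidate values.
expand_param_grid <- function(param_grid) {
  # Expand over the indices of each entry so that list-valued entries such
  # as `hidden` (a list of numeric vectors) are kept intact.
  idx <- expand.grid(lapply(param_grid, seq_along), KEEP.OUT.ATTRS = FALSE)
  lapply(seq_len(nrow(idx)), function(r) {
    # Pick the r-th combination's value from each entry, keeping entry names
    mapply(function(values, k) values[[k]], param_grid, idx[r, , drop = FALSE],
           SIMPLIFY = FALSE)
  })
}

param_grid <- list(
  hidden     = list(c(16, 8), c(32, 16)),
  lr         = c(1e-3),
  activation = c("relu"),
  epochs     = c(100),
  loss       = c("cox", "coxtime")
)
configs <- expand_param_grid(param_grid)
length(configs)
#> [1] 4
configs[[3]]$hidden
#> [1] 16  8
```

Because expand.grid() varies its first column fastest, the architectures cycle before the loss changes.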

.seed

Random seed for reproducibility (default 42).

.device

Character string indicating the computation device used when fitting all models in the grid search. One of "auto", "cpu", or "cuda". This is a runtime setting and is not part of the hyperparameter grid.
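The default c("auto", "cpu", "cuda") follows the usual R idiom in which match.arg() selects the first choice when the argument is not supplied. The sketch below shows that idiom; the rule that "auto" means "CUDA if available, otherwise CPU" is an assumption about survdnn's behaviour, and resolve_device() is a hypothetical helper. torch::cuda_is_available() is from the torch package and is only consulted when torch is installed:

```r
# Hypothetical sketch (not survdnn's code): resolve the .device argument.
# match.arg() returns the first choice, "auto", when .device is left at
# its default vector.
resolve_device <- function(.device = c("auto", "cpu", "cuda"),
                           cuda_available = requireNamespace("torch", quietly = TRUE) &&
                             torch::cuda_is_available()) {
  .device <- match.arg(.device)
  if (.device == "auto") {
    # Assumed rule: prefer CUDA when a GPU is usable, else fall back to CPU
    if (cuda_available) "cuda" else "cpu"
  } else {
    .device
  }
}

resolve_device(cuda_available = FALSE)
#> [1] "cpu"
resolve_device("cuda", cuda_available = FALSE)
#> [1] "cuda"
```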

Value

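A tibble in long format: the hyperparameter columns (hidden, lr, activation, epochs, loss) identify each configuration, and the metric and value columns hold the name and validation value of each evaluated metric (see Examples).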

Examples


library(survdnn)
library(survival)
set.seed(123)

# Simulate a small dataset
n <- 300
x1 <- rnorm(n); x2 <- rbinom(n, 1, 0.5)
time <- rexp(n, rate = 0.1)
status <- rbinom(n, 1, 0.7)
df <- data.frame(time, status, x1, x2)

# Split into training and validation
idx <- sample(seq_len(n), 0.7 * n)
train <- df[idx, ]
valid <- df[-idx, ]

# Define formula and param grid
formula <- Surv(time, status) ~ x1 + x2
param_grid <- list(
  hidden     = list(c(16, 8), c(32, 16)),
  lr         = c(1e-3),
  activation = c("relu"),
  epochs     = c(100),
  loss       = c("cox", "coxtime")
)

# Run grid search
results <- gridsearch_survdnn(
  formula = formula,
  train   = train,
  valid   = valid,
  times   = c(10, 20, 30),
  metrics = c("cindex", "ibs"),
  param_grid = param_grid
)

# Mean validation value of each metric per configuration
dplyr::group_by(results, hidden, lr, activation, epochs, loss, metric) |>
  dplyr::summarise(mean = mean(value, na.rm = TRUE), .groups = "drop")


survdnn documentation built on Jan. 8, 2026, 9:07 a.m.