tune_survdnn: Tune Hyperparameters for a survdnn Model via Cross-Validation

View source: R/tune_survdnn.R

Description

Performs k-fold cross-validation over a user-defined hyperparameter grid and selects the best configuration according to the first metric in 'metrics'.
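The selection procedure can be illustrated with a minimal base-R sketch. It is not survdnn's implementation: random fold assignment, averaging the metric over folds, and the toy scoring function are assumptions, and 'cv_grid_search' and 'toy_score' are hypothetical names. In survdnn, the scoring step would correspond to fitting a model on the training folds and evaluating it on the held-out fold.

```r
# Hypothetical sketch of k-fold grid search; NOT survdnn's internal code.
# score_fn(config, train_idx, test_idx) stands in for "fit a model on the
# training rows with this configuration, then score it on the held-out rows".
cv_grid_search <- function(n, grid, folds, score_fn,
                           higher_is_better = TRUE, seed = 42) {
  set.seed(seed)
  # Assumption: rows are assigned to folds at random, in near-equal sizes
  fold_id <- sample(rep_len(seq_len(folds), n))
  mean_scores <- vapply(seq_len(nrow(grid)), function(i) {
    fold_scores <- vapply(seq_len(folds), function(k) {
      score_fn(grid[i, , drop = FALSE],
               train_idx = which(fold_id != k),
               test_idx  = which(fold_id == k))
    }, numeric(1))
    mean(fold_scores)  # average the metric over the folds
  }, numeric(1))
  best <- if (higher_is_better) which.max(mean_scores) else which.min(mean_scores)
  list(summary = cbind(grid, mean_score = mean_scores),
       best    = grid[best, , drop = FALSE])
}

# Toy usage: a made-up score that peaks at lr = 1e-3 and ignores the data
grid <- expand.grid(lr = c(1e-2, 1e-3, 1e-4), epochs = c(100, 300))
toy_score <- function(config, train_idx, test_idx) -abs(config$lr - 1e-3)
res <- cv_grid_search(n = 30, grid = grid, folds = 3, score_fn = toy_score)
res$best  # lr = 0.001, epochs = 100
```

Because 'which.max()' returns the first maximum, the tie between epochs = 100 and epochs = 300 resolves to the earlier row of the grid.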

Usage

tune_survdnn(
  formula,
  data,
  times,
  metrics = "cindex",
  param_grid,
  folds = 3,
  .seed = 42,
  .device = c("auto", "cpu", "cuda"),
  na_action = c("omit", "fail"),
  refit = FALSE,
  return = c("all", "summary", "best_model")
)

Arguments

formula

A survival formula, e.g., 'Surv(time, status) ~ x1 + x2'.

data

A data frame containing the variables used in 'formula'.

times

A numeric vector of evaluation time points.

metrics

A character vector containing one or more of "cindex", "brier", and "ibs". Only the first metric is used for model selection.

param_grid

A named list defining hyperparameter combinations to evaluate. Required names: 'hidden', 'lr', 'activation', 'epochs', 'loss'.

folds

Number of cross-validation folds (default: 3).

.seed

Optional random seed for reproducibility (default: 42).

.device

Character string indicating the computation device used when fitting models during cross-validation and refitting. One of '"auto"', '"cpu"', or '"cuda"'. '"auto"' uses CUDA if available, otherwise falls back to CPU.

na_action

Character. How to handle missing values: '"omit"' drops incomplete rows; '"fail"' raises an error if any value is missing.

refit

Logical. If TRUE, refits a model with the best hyperparameter configuration on the full dataset.

return

One of "all", "summary", or "best_model":

"all"

Returns the complete cross-validation results for all hyperparameter configurations.

"summary"

Returns the results averaged across folds for each configuration.

"best_model"

Returns the refitted model if 'refit = TRUE'; otherwise, returns the best hyperparameter configuration.
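For 'metrics', the first element decides both what is optimized and in which direction: a higher C-index is better, while lower Brier score and IBS values are better. The base-R sketch below illustrates that rule. 'select_best_config', the column layout of 'summary_tbl', and the scores are hypothetical, and it is an assumption (not stated on this page) that survdnn applies the same directions.

```r
# Hypothetical helper (not part of survdnn): choose the best row of a
# per-configuration summary using only the first element of `metrics`.
select_best_config <- function(summary_tbl, metrics) {
  primary <- metrics[[1]]
  # Higher C-index is better; lower Brier score and IBS are better
  higher_is_better <- identical(primary, "cindex")
  scores <- summary_tbl[[primary]]
  best_row <- if (higher_is_better) which.max(scores) else which.min(scores)
  summary_tbl[best_row, , drop = FALSE]
}

# Made-up scores for three configurations
summary_tbl <- data.frame(
  config = c("A", "B", "C"),
  cindex = c(0.70, 0.74, 0.72),
  ibs    = c(0.15, 0.17, 0.14)
)
select_best_config(summary_tbl, c("cindex", "ibs"))$config  # "B"
select_best_config(summary_tbl, c("ibs", "cindex"))$config  # "C"
```

Reordering 'metrics' changes the selected configuration even though the underlying scores are identical.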
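For 'param_grid', one plausible reading of "hyperparameter combinations" is the Cartesian product of the list elements. The base-R sketch below expands a grid that way. The expansion rule, the list-of-vectors form for 'hidden', and the placeholder values "relu" and "cox" are assumptions for illustration, not a verified description of what survdnn accepts; 'expand_param_grid' is a hypothetical helper.

```r
# Hypothetical sketch: expand a named list into every combination of its values.
# Values are placeholders, not a verified list of survdnn's accepted options.
param_grid <- list(
  hidden     = list(c(16), c(32, 16)),  # each entry: one vector of layer widths
  lr         = c(1e-3, 1e-4),
  activation = c("relu"),
  epochs     = c(100, 300),
  loss       = c("cox")
)

expand_param_grid <- function(param_grid) {
  # Build the Cartesian product over *indices*, so list-valued entries such as
  # `hidden` are kept intact instead of being flattened
  idx <- expand.grid(lapply(param_grid, seq_along))
  lapply(seq_len(nrow(idx)), function(r) {
    Map(function(values, j) values[[j]], param_grid, idx[r, ])
  })
}

configs <- expand_param_grid(param_grid)
length(configs)      # 8 = 2 * 2 * 1 * 2 * 1
configs[[8]]$hidden  # 32 16
```

Expanding indices rather than values is a design choice that keeps each 'hidden' vector as a single hyperparameter value.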
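The two 'na_action' behaviors can be sketched in base R as follows. 'apply_na_action' is a hypothetical helper; restricting the check to a set of variables 'vars' (for example, those in 'formula') is an assumption, as is the wording of the error message.

```r
# Hypothetical sketch of the two `na_action` behaviors; not survdnn's code.
# Assumption: only the columns named in `vars` are checked for missing values.
apply_na_action <- function(data, vars, na_action = c("omit", "fail")) {
  na_action <- match.arg(na_action)  # default: first choice, "omit"
  complete <- stats::complete.cases(data[, vars, drop = FALSE])
  if (na_action == "fail" && !all(complete)) {
    stop("missing values found in ", sum(!complete), " row(s)")
  }
  data[complete, , drop = FALSE]     # "omit": keep only complete rows
}

df <- data.frame(time = c(5, 8, NA), status = c(1, 0, 1), x1 = c(0.2, NA, 1.1))
nrow(apply_na_action(df, c("time", "status", "x1")))  # 1
nrow(apply_na_action(df, c("time", "status")))        # 2
```

Calling 'apply_na_action(df, c("time", "status", "x1"), "fail")' stops with an error instead of dropping rows.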
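The relationship between the "all" and "summary" values of 'return' can be illustrated by collapsing per-fold scores into one row per configuration with base R's 'aggregate()'. The column names ('config', 'fold', 'cindex') and the scores are made up and need not match the tibble that survdnn returns.

```r
# Hypothetical sketch: collapse per-fold results (like "all") into one row per
# configuration (like "summary"). Column names and scores are made up.
cv_results <- data.frame(
  config = rep(c("A", "B"), each = 3),
  fold   = rep(1:3, times = 2),
  cindex = c(0.70, 0.72, 0.68, 0.75, 0.73, 0.74)
)

# Mean C-index across the 3 folds for each configuration
cv_summary <- aggregate(cindex ~ config, data = cv_results, FUN = mean)
cv_summary
```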

Value

A tibble or model object depending on the 'return' value.


survdnn documentation built on Jan. 8, 2026, 9:07 a.m.