survdnn: Fit a Deep Neural Network for Survival Analysis

View source: R/survdnn.R

survdnn    R Documentation

Fit a Deep Neural Network for Survival Analysis

Description

Trains a deep neural network (DNN) to model right-censored survival data using one of the predefined loss functions: Cox ('"cox"' or '"cox_l2"'), AFT ('"aft"'), or CoxTime ('"coxtime"').
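For orientation: with a Cox-type loss, the network output is typically read as a log-risk score and fitted through the partial likelihood. The base-R sketch below computes the textbook negative Cox partial log-likelihood (Breslow handling of ties), averaged over observed events. It illustrates that assumption only; it is not the survdnn implementation, which runs in torch and may normalize or handle ties differently. The function name 'cox_neg_partial_loglik' is hypothetical.

```r
# Hypothetical illustration (not survdnn source): negative Cox partial
# log-likelihood with Breslow ties, averaged over observed events.
cox_neg_partial_loglik <- function(risk, time, status) {
  stopifnot(length(risk) == length(time), length(time) == length(status))
  events <- which(status == 1)
  stopifnot(length(events) > 0)
  ll <- 0
  for (i in events) {
    # Risk set: subjects still under observation at time[i]
    r <- risk[time >= time[i]]
    m <- max(r)  # log-sum-exp shift for numerical stability
    ll <- ll + risk[i] - (m + log(sum(exp(r - m))))
  }
  -ll / length(events)
}

# Toy check: four uncensored subjects with distinct times.
# Giving the earliest failure the highest risk score lowers the loss.
times <- 1:4
events <- rep(1, 4)
cox_neg_partial_loglik(c(3, 2, 1, 0), times, events)  # smaller
cox_neg_partial_loglik(c(0, 1, 2, 3), times, events)  # larger
```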

Usage

survdnn(
  formula,
  data,
  hidden = c(32L, 16L),
  activation = "relu",
  lr = 1e-04,
  epochs = 300L,
  loss = c("cox", "cox_l2", "aft", "coxtime"),
  optimizer = c("adam", "adamw", "sgd", "rmsprop", "adagrad"),
  optim_args = list(),
  verbose = TRUE,
  dropout = 0.3,
  batch_norm = TRUE,
  callbacks = NULL,
  .seed = NULL,
  .device = c("auto", "cpu", "cuda"),
  na_action = c("omit", "fail")
)
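A minimal fitting sketch built only from the arguments documented on this page. It assumes the survdnn, survival, and torch packages are installed with a working torch backend, and it uses the 'veteran' data set from the survival package. The predictors, epoch count, and AdamW settings are illustrative choices, not recommendations; the guard lets the snippet run harmlessly where those packages are missing.

```r
# Illustrative only: assumes survdnn, survival, and torch are installed
# and that the torch backend has been set up (see torch::install_torch()).
if (requireNamespace("survdnn", quietly = TRUE) &&
    requireNamespace("survival", quietly = TRUE) &&
    requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()) {
  library(survival)  # Surv() and the 'veteran' lung cancer data
  library(survdnn)

  fit <- survdnn(
    Surv(time, status) ~ age + karno + diagtime + prior,
    data       = veteran,
    hidden     = c(32L, 16L),
    loss       = "cox",
    optimizer  = "adamw",
    optim_args = list(weight_decay = 1e-4),  # forwarded to the torch optimizer
    lr         = 1e-3,
    epochs     = 100L,
    verbose    = FALSE,
    .seed      = 123
  )

  print(fit$final_loss)            # final training loss
  print(length(fit$loss_history))  # one value per epoch
}
```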

Arguments

formula

A survival formula of the form 'Surv(time, status) ~ predictors'.

data

A data frame containing the variables in the model.

hidden

Integer vector giving the number of units in each hidden layer, one element per layer (default: 'c(32L, 16L)').

activation

Character string specifying the activation function used in each hidden layer (default: '"relu"'). Supported options: '"relu"', '"leaky_relu"', '"tanh"', '"sigmoid"', '"gelu"', '"elu"', '"softplus"'.

lr

Learning rate for the optimizer (default: '1e-4').

epochs

Number of training epochs (default: 300).

loss

Character string naming the loss function. One of '"cox"', '"cox_l2"', '"aft"', or '"coxtime"'. Defaults to '"cox"'.

optimizer

Character string specifying the optimizer to use. One of '"adam"', '"adamw"', '"sgd"', '"rmsprop"', or '"adagrad"'. Defaults to '"adam"'.

optim_args

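Optional named list of additional arguments passed to the underlying torch optimizer, e.g. 'list(weight_decay = 1e-4)'. The arguments must be valid for the chosen optimizer: 'momentum', for instance, applies to '"sgd"' and '"rmsprop"', but torch's Adam optimizer has no 'momentum' argument.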

verbose

Logical; whether to print loss progress every 50 epochs (default: TRUE).

dropout

Numeric between 0 and 1. Dropout rate applied after each hidden layer (default: 0.3). Set to 0 to disable dropout.

batch_norm

Logical; whether to add batch normalization after each hidden linear layer (default: TRUE).

callbacks

Optional list of callback functions. Each callback should have the signature 'function(epoch, current)' and return TRUE if training should stop, FALSE otherwise. Used, for example, with 'callback_early_stopping()'.

.seed

Optional integer. If provided, sets both R and torch random seeds for reproducible weight initialization, shuffling, and dropout.

.device

Character string indicating the computation device. One of '"auto"', '"cpu"', or '"cuda"'. '"auto"' uses CUDA if available, otherwise falls back to CPU.

na_action

Character. How to handle missing values in the model variables: '"omit"' drops incomplete rows (and reports how many were removed when 'verbose=TRUE'); '"fail"' stops with an error if any missing values are present.
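The 'callbacks' contract is simple enough to implement by hand. The sketch below builds a patience-style stopping rule as a closure. 'make_patience_callback' is a hypothetical name, and the sketch assumes 'current' is the monitored loss for that epoch, where lower is better; 'callback_early_stopping()' remains the packaged option.

```r
# Hypothetical helper: returns a callback with the documented signature
# function(epoch, current); it returns TRUE to request that training stop.
# Assumes 'current' is the epoch's monitored loss (lower is better).
make_patience_callback <- function(patience = 10L, min_delta = 0) {
  best <- Inf
  wait <- 0L
  function(epoch, current) {
    if (current < best - min_delta) {
      best <<- current    # improvement: remember it, reset the counter
      wait <<- 0L
    } else {
      wait <<- wait + 1L  # no improvement this epoch
    }
    wait >= patience
  }
}

# Simulated loss trace: improves twice, then stalls for two epochs
cb <- make_patience_callback(patience = 2L)
losses <- c(1.00, 0.90, 0.95, 0.96)
stops <- vapply(seq_along(losses), function(e) cb(e, losses[e]), logical(1))
stops  # FALSE FALSE FALSE TRUE
```

In a fit it would be supplied as 'callbacks = list(make_patience_callback(patience = 20L))'.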

Value

An object of class '"survdnn"' containing:

model

Trained 'nn_module' object.

formula

Original survival formula.

data

Training data used for fitting.

xnames

Predictor variable names.

x_center

Column means of the predictors, used for standardization.

x_scale

Column standard deviations of the predictors, used for standardization.

loss_history

Numeric vector of training loss values, one per completed epoch.

final_loss

Final training loss.

loss

Loss function name used ("cox", "aft", etc.).

activation

Activation function used.

hidden

Hidden layer sizes.

lr

Learning rate.

epochs

Number of training epochs.

optimizer

Optimizer name used.

optim_args

List of optimizer arguments used.

device

Torch device used for training ('torch_device').

aft_log_sigma

Learned global log(sigma) for 'loss="aft"'; 'NA_real_' otherwise.

aft_loc

AFT log-time location offset used for centering when 'loss="aft"'; 'NA_real_' otherwise.

coxtime_time_center

Mean used to center time when 'loss="coxtime"'; 'NA_real_' otherwise.

coxtime_time_scale

Standard deviation used to scale time when 'loss="coxtime"'; 'NA_real_' otherwise.
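The 'x_center' and 'x_scale' components record how predictors were standardized, so new predictors fed to the returned 'model' directly would need the same transformation. The sketch assumes numeric predictors and that both components are vectors named by predictor (as 'colMeans()' and 'apply(..., sd)' produce); 'standardize_new' is a hypothetical helper. Under the same assumption, CoxTime time inputs would be transformed as '(time - coxtime_time_center) / coxtime_time_scale'.

```r
# Hypothetical helper: standardize new predictors with stored training
# statistics, assuming training used (x - x_center) / x_scale.
standardize_new <- function(newx, center, scale) {
  stopifnot(identical(names(center), names(scale)))
  x <- as.matrix(newx)[, names(center), drop = FALSE]  # match column order
  x <- sweep(x, 2, center, "-")
  sweep(x, 2, scale, "/")
}

# Toy statistics standing in for fit$x_center and fit$x_scale
train <- data.frame(age = c(50, 60, 70), karno = c(40, 60, 80))
x_center <- colMeans(train)
x_scale <- apply(train, 2, sd)

newx <- data.frame(karno = 60, age = 80)  # columns in a different order
standardize_new(newx, x_center, x_scale)
# age: (80 - 60) / 10 = 2; karno: (60 - 60) / 20 = 0
```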


survdnn documentation built on Jan. 8, 2026, 9:07 a.m.