Deep Neural Networks for Survival Analysis Using torch
survdnn implements neural network-based models for right-censored survival analysis using the native torch backend in R. It supports multiple loss functions, including the Cox partial likelihood, an L2-penalized Cox loss, and Accelerated Failure Time (AFT) objectives, as well as time-dependent extensions such as Cox-Time. The package provides a formula interface and supports model evaluation with time-dependent metrics (e.g., C-index, Brier score, IBS), cross-validation, and hyperparameter tuning.
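The Brier score compares predicted survival probabilities with the observed status at a fixed time, and the IBS averages it over a time window. As a rough sketch of how these are usually computed with inverse-probability-of-censoring weights (Graf et al., 1999), here is a base R version. The helper names (km_censoring, brier_score, integrated_brier) are hypothetical, and whether survdnn weights, handles ties, or normalizes in exactly this way is an assumption.

```r
# Sketch only: a generic IPCW Brier score and IBS in base R.
# Assumption: this follows the textbook definition (Graf et al., 1999),
# not necessarily survdnn's internal implementation.

# Kaplan-Meier estimate of the censoring survival function G(t).
km_censoring <- function(time, status) {
  cens_times <- sort(unique(time[status == 0]))
  if (length(cens_times) == 0) {
    # no censoring: G(t) = 1 everywhere
    return(function(t, left_limit = FALSE) rep(1, length(t)))
  }
  g <- numeric(length(cens_times))
  s <- 1
  for (k in seq_along(cens_times)) {
    n_risk <- sum(time >= cens_times[k])
    n_cens <- sum(time == cens_times[k] & status == 0)
    s <- s * (1 - n_cens / n_risk)
    g[k] <- s
  }
  # step function; left_limit = TRUE returns G(t-) instead of G(t)
  function(t, left_limit = FALSE) {
    idx <- findInterval(t, cens_times, left.open = left_limit)
    out <- rep(1, length(t))
    out[idx > 0] <- g[idx[idx > 0]]
    out
  }
}

# Brier score at time t; surv_pred[i] is the predicted S(t | x_i).
brier_score <- function(t, surv_pred, time, status) {
  G <- km_censoring(time, status)
  event_by_t <- time <= t & status == 1  # observed death by time t
  alive_at_t <- time > t                 # still event-free after time t
  w_event <- ifelse(event_by_t, 1 / G(time, left_limit = TRUE), 0)
  w_alive <- ifelse(alive_at_t, 1 / G(t), 0)
  mean(surv_pred^2 * w_event + (1 - surv_pred)^2 * w_alive)
}

# Integrated Brier score: trapezoidal rule over `times`,
# divided by the width of the time window.
integrated_brier <- function(times, surv_mat, time, status) {
  bs <- vapply(seq_along(times), function(j)
    brier_score(times[j], surv_mat[, j], time, status), numeric(1))
  area <- sum(diff(times) * (head(bs, -1) + tail(bs, -1)) / 2)
  area / (max(times) - min(times))
}
```

Dividing by the width of the time window is one common normalization; some implementations divide by the largest evaluation time instead.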
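Key features:
- Formula interface for Surv() ~ . models
- Supported losses:
  - "cox": Cox partial likelihood
  - "cox_l2": penalized Cox
  - "aft": Accelerated Failure Time
  - "coxtime": deep time-dependent Cox (a time-dependent extension of DeepSurv-style models)
- Cross-validation and hyperparameter tuning with cv_survdnn() and tune_survdnn()
- Prediction and plotting of survival curves with predict() and plot()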
# Install from GitHub
# install.packages("remotes")
remotes::install_github("ielbadisy/survdnn")
# Or clone and install locally
# git clone https://github.com/ielbadisy/survdnn.git
# setwd("survdnn")
# devtools::install()
library(survdnn)
library(survival, quietly = TRUE)
library(ggplot2)
veteran <- survival::veteran  # Veterans' Administration lung cancer trial data
mod <- survdnn(
Surv(time, status) ~ age + karno + celltype,
data = veteran,
hidden = c(32, 16),
epochs = 100,
loss = "cox",
verbose = TRUE
)
## Epoch 50 - Loss: 3.987919
## Epoch 100 - Loss: 3.974391
summary(mod)
##
## ── Summary of survdnn model ──────────────────────────────
##
## Formula:
## Surv(time, status) ~ age + karno + celltype
## <environment: 0x5b3739336aa0>
##
## Model architecture:
## Hidden layers: 32 : 16
## Activation: relu
## Dropout: 0.3
## Final loss: 3.974391
##
## Training summary:
## Epochs: 100
## Learning rate: 1e-04
## Loss function: cox
##
## Data summary:
## Observations: 137
## Predictors: age, karno, celltypesmallcell, celltypeadeno, celltypelarge
## Time range: [ 1, 999 ]
## Event rate: 93.4%
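In the summary, celltype appears as three indicator columns (celltypesmallcell, celltypeadeno, celltypelarge), with the first level, squamous, serving as the baseline. This naming matches R's default treatment contrasts and can be reproduced with base model.matrix() on a small made-up data frame. That survdnn builds its design matrix this way is an assumption inferred from the column names.

```r
# Made-up rows; only the factor levels mirror survival::veteran$celltype.
df <- data.frame(
  age = c(69, 64, 38, 63),
  karno = c(60, 70, 60, 60),
  celltype = factor(
    c("squamous", "smallcell", "adeno", "large"),
    levels = c("squamous", "smallcell", "adeno", "large")
  )
)

# Treatment contrasts: one indicator column per non-baseline level.
X <- model.matrix(~ age + karno + celltype, data = df)[, -1]  # drop intercept
colnames(X)
# "age" "karno" "celltypesmallcell" "celltypeadeno" "celltypelarge"
```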
plot(mod, group_by = "celltype", times = 1:300)
# Cox partial likelihood
mod1 <- survdnn(
Surv(time, status) ~ age + karno,
data = veteran,
loss = "cox",
epochs = 100
)
## Epoch 50 - Loss: 4.216911
## Epoch 100 - Loss: 4.105076
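For intuition about the "cox" objective, here is a minimal base R sketch of the negative log partial likelihood for a vector of risk scores eta (for example, network outputs), using Breslow's handling of tied times. The function name and the division by the number of events are illustrative assumptions; survdnn's torch implementation may scale or batch the loss differently.

```r
# Negative log partial likelihood of the Cox model (Breslow ties), sketch only.
# eta: risk scores; time, status: observed right-censored outcome.
cox_neg_log_partial_lik <- function(eta, time, status) {
  m <- max(eta)  # shift for a numerically stable log-sum-exp
  # log of sum(exp(eta_j)) over the risk set {j : time_j >= time_i}
  log_risk <- vapply(time, function(t)
    m + log(sum(exp(eta[time >= t] - m))), numeric(1))
  # only observed events contribute; averaged over the number of events
  -sum((eta - log_risk)[status == 1]) / sum(status == 1)
}
```

Adding the same constant to every risk score leaves this loss unchanged, so only the relative ordering and spacing of the scores matter.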
# Accelerated Failure Time
mod2 <- survdnn(
Surv(time, status) ~ age + karno,
data = veteran,
loss = "aft",
epochs = 100
)
## Epoch 50 - Loss: 21.136486
## Epoch 100 - Loss: 20.663244
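The README does not state which AFT likelihood the "aft" loss uses. One common formulation, shown here purely as an assumption, is a log-normal AFT model, log T = mu(x) + sigma * epsilon with standard normal epsilon: observed events contribute the log density of log T and censored subjects contribute the log probability of surviving past their censoring time. The name aft_lognormal_nll and the default sigma = 1 are hypothetical.

```r
# Hypothetical log-normal AFT negative log-likelihood (sketch, not survdnn's).
# mu: predicted mean of log(T); the Jacobian term -log(t) is dropped because
# it does not depend on the model.
aft_lognormal_nll <- function(mu, time, status, sigma = 1) {
  z <- (log(time) - mu) / sigma
  ll_event <- dnorm(z, log = TRUE) - log(sigma)          # density of log(T)
  ll_cens <- pnorm(z, lower.tail = FALSE, log.p = TRUE)  # P(log T > log t)
  -mean(ifelse(status == 1, ll_event, ll_cens))
}
```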
# Deep time-dependent Cox (Coxtime)
mod3 <- survdnn(
Surv(time, status) ~ age + karno,
data = veteran,
loss = "coxtime",
epochs = 100
)
## Epoch 50 - Loss: 4.856084
## Epoch 100 - Loss: 5.289982
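Cox-Time (Kvamme, Borgan and Scheel, 2019) lets the network take time as an input, g(t, x), and approximates the partial likelihood by comparing each event with a small sample of controls drawn from its risk set. The sketch below implements that case-control loss in base R with a single covariate and a toy linear g_toy() standing in for the network; both names are hypothetical, and whether survdnn samples controls this way (and excludes the case itself from the candidates, as done here) is an assumption.

```r
# Hypothetical stand-in for the network g(t, x): a toy linear function with a
# time interaction, so the effect of x changes over time.
g_toy <- function(t, x) 0.5 * x - 0.1 * t * x

# Case-control approximation of the Cox-Time loss (sketch, single covariate x).
coxtime_loss <- function(g, time, status, x, n_controls = 1) {
  events <- which(status == 1)
  losses <- vapply(events, function(i) {
    # candidates: subjects still at risk at time[i], excluding i itself
    risk <- setdiff(which(time >= time[i]), i)
    if (length(risk) == 0) return(0)  # no control available for this event
    ctrl <- risk[sample.int(length(risk), min(n_controls, length(risk)))]
    # log(1 + sum_j exp(g(t_i, x_j) - g(t_i, x_i))) over the sampled controls
    log1p(sum(exp(g(time[i], x[ctrl]) - g(time[i], x[i]))))
  }, numeric(1))
  mean(losses)
}

set.seed(1)
coxtime_loss(g_toy, time = c(1, 2, 3), status = c(1, 1, 1), x = c(1, 0, 2))
```

Because g depends on t, the same covariate value can carry a different relative risk at different times, which is what relaxes the proportional-hazards assumption of the plain Cox loss.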
# 3-fold cross-validation; metrics evaluated at 30, 90 and 180 days
cv_results <- cv_survdnn(
Surv(time, status) ~ age + karno + celltype,
data = veteran,
times = c(30, 90, 180),
metrics = c("cindex", "ibs"),
folds = 3,
hidden = c(16, 8),
loss = "cox",
epochs = 100
)
print(cv_results)
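Harrell's C-index, one of the metrics requested above, is the share of comparable pairs whose predicted risks are ordered consistently with their observed times. A minimal base R sketch follows, assuming higher scores mean higher risk; harrell_cindex() is a hypothetical helper, not a survdnn function. Pairs with tied observed times are skipped and tied predictions count as one half, which is one of several conventions in use.

```r
# Harrell's concordance index for right-censored data (sketch).
# risk: predicted risk scores, higher = earlier expected event.
# If you have survival probabilities instead, pass 1 - S(t | x).
harrell_cindex <- function(risk, time, status) {
  concordant <- 0
  comparable <- 0
  for (i in which(status == 1)) {
    # i is comparable with every j observed strictly later than i's event
    j <- which(time > time[i])
    comparable <- comparable + length(j)
    # concordant if the earlier event has the higher predicted risk
    concordant <- concordant + sum(risk[i] > risk[j]) + 0.5 * sum(risk[i] == risk[j])
  }
  concordant / comparable
}

harrell_cindex(c(2.1, 0.3, 1.4, 0.8), time = c(5, 12, 7, 20), status = c(1, 0, 1, 1))
# 1
```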
grid <- list(
hidden = list(c(16), c(32, 16)),
lr = c(1e-3),
activation = c("relu"),
epochs = c(100, 300),
loss = c("cox", "aft", "coxtime")
)
tune_res <- tune_survdnn(
formula = Surv(time, status) ~ age + karno + celltype,
data = veteran,
times = c(90, 300),
metrics = "cindex",
param_grid = grid,
folds = 3,
refit = FALSE,
return = "summary"
)
print(tune_res)
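The size of this search follows from the grid: assuming tune_survdnn() evaluates the full Cartesian product of the listed values, there are 2 × 1 × 1 × 2 × 3 = 12 configurations, each cross-validated on 3 folds. The base R snippet below does that arithmetic and enumerates the combinations with expand.grid(); it does not call survdnn, and the one-fit-per-fold count is an assumption.

```r
# Same search space as above; the Cartesian-product reading is an assumption.
grid <- list(
  hidden = list(c(16), c(32, 16)),
  lr = c(1e-3),
  activation = c("relu"),
  epochs = c(100, 300),
  loss = c("cox", "aft", "coxtime")
)

n_configs <- prod(lengths(grid))  # 2 * 1 * 1 * 2 * 3 = 12
n_fits <- n_configs * 3           # assuming one fit per configuration per fold

# Enumerate the combinations; `hidden` is indexed because it holds vectors.
configs <- expand.grid(
  hidden_id = seq_along(grid$hidden),
  lr = grid$lr,
  activation = grid$activation,
  epochs = grid$epochs,
  loss = grid$loss,
  stringsAsFactors = FALSE
)
nrow(configs)  # 12
```

Each added value in any entry multiplies the cost, so widening the grid along several axes at once grows the number of fits quickly.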
plot(mod1, group_by = "celltype", times = 1:300)
plot(mod1, group_by = "celltype", times = 1:300, plot_mean_only = TRUE)
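With plot_mean_only = TRUE, the plot presumably shows one averaged curve per group instead of individual curves; that reading is an assumption based on the argument name. Averaging a matrix of predicted survival probabilities by group can be sketched in base R as follows, where surv_mat (subjects in rows, time points in columns) and mean_curves_by_group() are hypothetical.

```r
# Hypothetical helper: mean predicted survival curve per group.
# surv_mat: n x T matrix of S(t | x_i); group: length-n grouping variable.
mean_curves_by_group <- function(surv_mat, group) {
  # row indices for each observed group (unused factor levels dropped)
  idx_by_group <- split(seq_len(nrow(surv_mat)), group, drop = TRUE)
  # one row per group, one column per time point
  do.call(rbind, lapply(idx_by_group, function(idx)
    colMeans(surv_mat[idx, , drop = FALSE])))
}

toy <- rbind(c(0.9, 0.8), c(0.7, 0.6), c(0.5, 0.4))
mean_curves_by_group(toy, c("a", "a", "b"))
```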
help(package = "survdnn")
?survdnn
?tune_survdnn
?cv_survdnn
?plot.survdnn
# Run all tests
devtools::test()
The survdnn R package is available at: https://github.com/ielbadisy/survdnn
The package is currently under submission to CRAN.
Contributions, issues, and feature requests are welcome. Open an issue or submit a pull request!
MIT © Imad El Badisy