DeepHit Survival Neural Network
DeepHit fits a neural network based on the PMF of a discrete Cox model. This is
the single (non-competing) event implementation.
deephit(
  formula = NULL,
  data = NULL,
  reverse = FALSE,
  time_variable = "time",
  status_variable = "status",
  x = NULL,
  y = NULL,
  frac = 0,
  cuts = 10,
  cutpoints = NULL,
  scheme = c("equidistant", "quantiles"),
  cut_min = 0,
  activation = "relu",
  custom_net = NULL,
  num_nodes = c(32L, 32L),
  batch_norm = TRUE,
  dropout = NULL,
  device = NULL,
  mod_alpha = 0.2,
  sigma = 0.1,
  early_stopping = FALSE,
  best_weights = FALSE,
  min_delta = 0,
  patience = 10L,
  batch_size = 256L,
  epochs = 1L,
  verbose = FALSE,
  num_workers = 0L,
  shuffle = TRUE,
  ...
)
formula |
Object specifying the model fit; the left-hand side of the formula should describe a
survival::Surv() object.
data |
Training data as a data.frame-like object, coerced internally with stats::model.matrix().
reverse |
If TRUE, fits the estimator on the censoring distribution; otherwise (default) on the
survival distribution.
time_variable |
Alternative method to call the function. Name of the 'time' variable, required if formula
or x and y are not given.
status_variable |
Alternative method to call the function. Name of the 'status' variable, required if formula
or x and y are not given.
x |
Alternative method to call the function. Required if formula, time_variable and
status_variable are not given. Data-frame-like object of features, coerced internally
with model.matrix.
y |
Alternative method to call the function. Required if formula, time_variable and
status_variable are not given. Survival outcome of right-censored observations.
frac |
Fraction of the data to hold out as a validation dataset. The default is 0, meaning no
separate validation dataset is used.
cuts |
Number of cut-points used to discretise the time scale.
cutpoints |
Alternative to cuts: the exact cut-points to use for discretisation.
cuts is ignored if cutpoints is non-NULL.
scheme |
Method of discretisation, either "equidistant" (default) or "quantiles".
See reticulate::py_help(pycox$models$LogisticHazard$label_transform) for more detail.
cut_min |
Starting duration for discretisation, see
reticulate::py_help(pycox$models$LogisticHazard$label_transform) for more detail.
activation |
See get_pycox_activation.
custom_net |
Optional custom network built with build_pytorch_net; otherwise the default architecture is
used. Note that when building a custom network, the number of output channels depends on
cuts or cutpoints.
num_nodes, batch_norm, dropout |
See build_pytorch_net.
device |
Passed to pycox.models.DeepHitSingle; specifies the device on which to compute the model.
mod_alpha |
Weighting in (0,1) for combining the likelihood (L1) and the rank loss (L2); values closer
to 1 put more weight on the likelihood. See reference and
py_help(pycox$models$DeepHitSingle) for more detail.
sigma |
Parameter of the function eta in the rank loss (L2) of the reference. See reference and
py_help(pycox$models$DeepHitSingle) for more detail.
early_stopping, best_weights, min_delta, patience |
See get_pycox_callbacks.
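batch_size |
Passed to pycox.models.DeepHitSingle.fit, elements in each batch.
epochs |
Passed to pycox.models.DeepHitSingle.fit, number of epochs.
verbose |
Passed to pycox.models.DeepHitSingle.fit, should information be displayed during fitting?
num_workers |
Passed to pycox.models.DeepHitSingle.fit, number of workers used in the dataloader.
shuffle |
Passed to pycox.models.DeepHitSingle.fit, should order of dataset be shuffled?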
... |
Passed to get_pycox_optim.
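The effect of cuts, scheme and cut_min can be illustrated in base R. This is only a rough
sketch: make_cuts is a hypothetical helper, not part of survivalmodels or pycox, and its
"quantiles" branch uses plain empirical quantiles as a simplification. The actual
label_transform in pycox may place its cut-points differently.

```r
# Hypothetical helper for illustration only; not part of survivalmodels or pycox.
make_cuts <- function(time, cuts = 10,
                      scheme = c("equidistant", "quantiles"),
                      cut_min = 0) {
  scheme <- match.arg(scheme)
  if (scheme == "equidistant") {
    # evenly spaced grid from cut_min to the largest observed time
    seq(cut_min, max(time), length.out = cuts)
  } else {
    # simplification: empirical quantiles of the observed times
    probs <- seq(0, 1, length.out = cuts)
    unique(c(cut_min, as.numeric(stats::quantile(time, probs[-1]))))
  }
}

set.seed(1)
time <- rexp(200, rate = 0.1)

eq <- make_cuts(time, cuts = 5)                        # equidistant grid
qu <- make_cuts(time, cuts = 5, scheme = "quantiles")  # quantile-based grid

# index of the interval each observed time falls into
idx <- findInterval(time, eq, rightmost.closed = TRUE)
```

With skewed durations the equidistant grid leaves the later intervals sparsely populated,
whereas the quantile-based grid spreads observations more evenly across intervals.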
Implemented from the pycox Python package via reticulate.
Calls pycox.models.DeepHitSingle.
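The roles of mod_alpha and sigma can be sketched in a few lines of R. The functions
combine_loss and eta below are hypothetical illustrations, not part of survivalmodels; in
practice both loss terms are computed inside pycox. The sketch assumes the weighting
described for pycox.models.DeepHitSingle (alpha on the likelihood, 1 - alpha on the rank
loss) and the form of eta given in the reference.

```r
# Hypothetical illustration; the real loss terms are computed by pycox.
combine_loss <- function(nll, rank_loss, alpha = 0.2) {
  stopifnot(alpha >= 0, alpha <= 1)
  # alpha weights the likelihood (L1), 1 - alpha weights the rank loss (L2)
  alpha * nll + (1 - alpha) * rank_loss
}

# eta from the rank loss: a smaller sigma penalises wrongly ordered pairs more sharply
eta <- function(diff, sigma = 0.1) exp(-diff / sigma)

# example values chosen arbitrarily
total <- combine_loss(nll = 2, rank_loss = 1, alpha = 0.2)
penalty_wrong_order <- eta(-0.1)  # the pair is ordered incorrectly
penalty_right_order <- eta(0.1)   # the pair is ordered correctly
```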
An object inheriting from class deephit
An object of class survivalmodel
Changhee Lee, William R Zame, Jinsung Yoon, and Mihaela van der Schaar.
Deephit: A deep learning approach to survival analysis with competing risks.
In Thirty-Second AAAI Conference on Artificial Intelligence, 2018.
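A minimal usage sketch follows, assuming that survivalmodels, reticulate and the Python
package pycox are all installed. It uses simsurvdata() from survivalmodels to simulate
data whose outcome columns match the default time_variable and status_variable. The
hyperparameter values are arbitrary choices for illustration, not recommendations.

```r
# Run only when the optional dependencies are available.
pycox_ready <- requireNamespace("survivalmodels", quietly = TRUE) &&
  requireNamespace("reticulate", quietly = TRUE) &&
  isTRUE(reticulate::py_module_available("pycox"))

if (pycox_ready) {
  # simulated right-censored data with columns "time" and "status"
  d <- survivalmodels::simsurvdata(50)

  fit <- survivalmodels::deephit(
    data = d,
    frac = 0.3,              # hold out 30% of rows for validation
    cuts = 10,               # number of cut-points for discretisation
    num_nodes = c(4L, 8L),   # two small hidden layers
    dropout = 0.1,
    early_stopping = TRUE,
    epochs = 10L,
    batch_size = 32L
  )
  print(fit)
}
```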