deephit: DeepHit Survival Neural Network

Description

DeepHit fits a neural network that directly models the probability mass function (PMF) of a discretised survival time. This is the single-event (non-competing risks) implementation.
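
As a minimal sketch of what modelling the PMF means (not the package's code), the snippet below takes a made-up vector of predicted probabilities for one subject over four discrete intervals. The cumulative incidence is the running sum of the PMF, and the survival probability is one minus that sum. The vector `pmf` is hypothetical example data.

```r
# Hypothetical predicted PMF for one subject over 4 discrete intervals.
# Mass not assigned to these intervals is treated as surviving beyond
# the last cut-point (an assumption for this illustration).
pmf <- c(0.10, 0.20, 0.30, 0.15)

cif  <- cumsum(pmf)  # F(t_k) = P(T <= t_k): probability of the event by interval k
surv <- 1 - cif      # S(t_k) = P(T > t_k): probability of surviving past interval k

print(round(surv, 2))
```

Because every PMF entry is non-negative, the resulting survival curve can only stay flat or decrease from one interval to the next.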

Usage

deephit(
  formula = NULL,
  data = NULL,
  reverse = FALSE,
  time_variable = "time",
  status_variable = "status",
  x = NULL,
  y = NULL,
  frac = 0,
  cuts = 10,
  cutpoints = NULL,
  scheme = c("equidistant", "quantiles"),
  cut_min = 0,
  activation = "relu",
  custom_net = NULL,
  num_nodes = c(32L, 32L),
  batch_norm = TRUE,
  dropout = NULL,
  device = NULL,
  mod_alpha = 0.2,
  sigma = 0.1,
  early_stopping = FALSE,
  best_weights = FALSE,
  min_delta = 0,
  patience = 10L,
  batch_size = 256L,
  epochs = 1L,
  verbose = FALSE,
  num_workers = 0L,
  shuffle = TRUE,
  ...
)

Arguments

formula

(formula(1))
Object specifying the model fit; the left-hand side of the formula should describe a survival::Surv() object.

data

(data.frame(1))
Training data as a data.frame-like object; internally coerced with stats::model.matrix().

reverse

(logical(1))
If TRUE, fits the estimator on the censoring distribution; otherwise (default) on the survival distribution.

time_variable

(character(1))
Alternative method to call the function. Name of the 'time' variable; required if neither formula nor x and y are given.

status_variable

(character(1))
Alternative method to call the function. Name of the 'status' variable; required if neither formula nor x and y are given.

x

(data.frame(1))
Alternative method to call the function. Required if formula, time_variable and status_variable are not given. Data-frame-like object of features, internally coerced with model.matrix().

y

(survival::Surv())
Alternative method to call the function. Required if formula, time_variable and status_variable are not given. Survival outcome of right-censored observations.
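
The three ways of supplying data described above can be sketched as follows. The toy data frame is hypothetical, the fits only run when survivalmodels is installed and reticulate::py_module_available() reports that the Python pycox module is present, and the claim that the column-name interface uses all remaining columns as features is an assumption.

```r
library(survival)  # provides Surv(); ships with R as a recommended package

# Hypothetical toy data using the default column names "time" and "status".
df <- data.frame(
  time   = c(5, 8, 3, 12, 7, 9, 4, 11),
  status = c(1, 0, 1, 1, 0, 1, 1, 0),
  age    = c(61, 54, 70, 48, 66, 59, 72, 50)
)

# Inputs for the x/y interface: a feature data frame and a Surv outcome.
x <- df[, "age", drop = FALSE]
y <- Surv(df$time, df$status)

# Only attempt fitting when survivalmodels and the Python pycox module exist.
if (requireNamespace("survivalmodels", quietly = TRUE) &&
    requireNamespace("reticulate", quietly = TRUE) &&
    reticulate::py_module_available("pycox")) {
  # 1. Formula interface: Surv() on the left-hand side.
  fit_formula <- survivalmodels::deephit(Surv(time, status) ~ age, data = df)
  # 2. Column-name interface (assumed: the other columns become features).
  fit_cols <- survivalmodels::deephit(data = df, time_variable = "time",
                                      status_variable = "status")
  # 3. x/y interface.
  fit_xy <- survivalmodels::deephit(x = x, y = y)
}
```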

frac

(numeric(1))
Fraction of the data to hold out as a validation set. The default, 0, uses no separate validation set.

cuts

(integer(1))
Number of cut-points used to discretise the survival time.

cutpoints

(numeric())
Alternative to cuts: exact cut-points for discretisation. cuts is ignored if cutpoints is non-NULL.

scheme

(character(1))
Method of discretisation, either "equidistant" (default) or "quantiles". See reticulate::py_help(pycox$models$LogisticHazard$label_transform) for more detail.

cut_min

(integer(1))
Starting duration for discretisation, see reticulate::py_help(pycox$models$LogisticHazard$label_transform) for more detail.
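
A minimal sketch of the "equidistant" scheme, assuming it places cuts evenly spaced points from cut_min up to the largest observed time. The helper equidistant_cuts() is hypothetical, and mapping times to intervals with findInterval() is a simplification: pycox's own rule, which treats events and censored times differently, is not reproduced. To our understanding, each cut-point corresponds to one output node of the network, which is why a custom_net must match cuts or cutpoints.

```r
# Hypothetical helper (not the package's internal code): `cuts` evenly
# spaced cut-points from `cut_min` up to the largest observed time.
equidistant_cuts <- function(time, cuts = 10L, cut_min = 0) {
  seq(cut_min, max(time), length.out = cuts)
}

time <- c(2, 10, 4, 7.6)
cp <- equidistant_cuts(time, cuts = 5L)   # 0.0 2.5 5.0 7.5 10.0

# Simplified mapping of each time to the interval [cp[k], cp[k + 1]) it
# falls in; a time equal to the last cut-point maps to the last index.
interval <- findInterval(time, cp)
```

With cuts = 5L this yields five cut-points, so under the assumption above a custom network would need five output nodes.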

activation

(character(1))
See get_pycox_activation.

custom_net

(torch.nn.modules.module.Module(1))
Optional custom network built with build_pytorch_net; otherwise the default architecture is used. If you build a custom network, its number of output nodes depends on cuts or cutpoints.

num_nodes, batch_norm, dropout

(integer()/logical(1)/numeric(1))
See build_pytorch_net.

device

(integer(1)|character(1))
Passed to pycox.models.DeepHitSingle; specifies the device on which to compute the model.

mod_alpha

(numeric(1))
Weighting in (0, 1) between the likelihood loss (L1) and the rank loss (L2); values closer to 1 put more weight on the likelihood. See the reference and py_help(pycox$models$DeepHitSingle) for more detail.

sigma

(numeric(1))
The σ parameter of the function η in the rank loss (L2) of the reference. See the reference and py_help(pycox$models$DeepHitSingle) for more detail.
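
How mod_alpha and sigma enter the objective can be illustrated with a simplified R sketch of a single-event DeepHit-style loss: mod_alpha * L1 + (1 - mod_alpha) * L2, where L2 penalises comparable pairs through η(a, b) = exp(-(a - b) / σ). The helper deephit_loss() and its inputs are hypothetical; averaging the rank terms over pairs and the censored-likelihood term are assumptions, and pycox's implementation differs in its normalisation details.

```r
# Hypothetical helper: single-event DeepHit-style loss for a small batch.
# pmf:   n x K matrix of predicted probabilities over K discrete intervals.
# idx:   interval index (1..K) of each subject's observed time.
# event: 1 if the event was observed, 0 if right-censored.
deephit_loss <- function(pmf, idx, event, mod_alpha = 0.2, sigma = 0.1) {
  n <- nrow(pmf)
  # Cumulative incidence F_i(t_k): running sum along each row of the PMF.
  cif <- pmf
  if (ncol(pmf) > 1) {
    for (k in 2:ncol(pmf)) cif[, k] <- cif[, k - 1] + pmf[, k]
  }

  # L1: negative log-likelihood. Events contribute -log p_i(t_i); censored
  # subjects contribute -log(1 - F_i(t_i)), assumed here to be positive.
  lik <- ifelse(event == 1,
                pmf[cbind(seq_len(n), idx)],
                1 - cif[cbind(seq_len(n), idx)])
  nll <- -mean(log(lik))

  # L2: rank loss over comparable pairs (i had the event and t_i < t_j),
  # using eta(a, b) = exp(-(a - b) / sigma), averaged over pairs.
  terms <- numeric(0)
  for (i in which(event == 1)) {
    for (j in seq_len(n)) {
      if (idx[i] < idx[j]) {
        terms <- c(terms, exp(-(cif[i, idx[i]] - cif[j, idx[i]]) / sigma))
      }
    }
  }
  rank <- if (length(terms) > 0) mean(terms) else 0

  # mod_alpha = 1 keeps only the likelihood; mod_alpha = 0 only the rank loss.
  mod_alpha * nll + (1 - mod_alpha) * rank
}

pmf <- rbind(c(0.7, 0.2, 0.1),
             c(0.1, 0.3, 0.6))
loss <- deephit_loss(pmf, idx = c(1, 2), event = c(1, 0))
```

A smaller sigma makes η sharper, so pairs that are ranked correctly but only by a narrow margin in predicted risk contribute more to the loss.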

early_stopping, best_weights, min_delta, patience

(logical(1)/logical(1)/numeric(1)/integer(1))
See get_pycox_callbacks.
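
Patience-based early stopping can be sketched as follows, applied after the fact to a vector of per-epoch validation losses. Because a validation loss is monitored, early stopping presumably needs frac > 0. The helper early_stopping_sketch() is hypothetical, and the exact improvement test used by the pycox callback may differ.

```r
# Hypothetical helper: patience-based early stopping on per-epoch
# validation losses. Returns the epoch at which training would stop and
# the epoch with the lowest loss.
early_stopping_sketch <- function(val_loss, min_delta = 0, patience = 10L) {
  best_loss <- Inf
  best_epoch <- NA_integer_
  wait <- 0L
  for (epoch in seq_along(val_loss)) {
    if (val_loss[epoch] < best_loss - min_delta) {
      # Improvement larger than min_delta: record it and reset the counter.
      best_loss <- val_loss[epoch]
      best_epoch <- epoch
      wait <- 0L
    } else {
      # No sufficient improvement: stop once `patience` epochs have passed.
      wait <- wait + 1L
      if (wait >= patience) {
        return(list(stop_epoch = epoch, best_epoch = best_epoch))
      }
    }
  }
  list(stop_epoch = length(val_loss), best_epoch = best_epoch)
}

res <- early_stopping_sketch(c(1.00, 0.80, 0.70, 0.71, 0.72, 0.73),
                             patience = 2L)
```

Here training would stop at epoch 5, and best_weights = TRUE would correspond to keeping the weights from epoch 3 rather than those from the final epoch.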

batch_size

(integer(1))
Passed to pycox.models.DeepHitSingle.fit, number of observations in each batch.

epochs

(integer(1))
Passed to pycox.models.DeepHitSingle.fit, number of epochs.
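
The arithmetic linking batch_size and epochs, using the numbers from the "common parameters" example below, is sketched here. The training-split size and the decision to keep the last partial batch are assumptions, and early stopping can end training before the upper bound is reached.

```r
# Numbers from the "common parameters" example:
# 50 observations with frac = 0.3, batch_size = 32L, epochs = 100L.
n <- 50
frac <- 0.3
batch_size <- 32L
epochs <- 100L

n_train <- round(n * (1 - frac))                    # assumed training-split size
batches_per_epoch <- ceiling(n_train / batch_size)  # assumes the partial batch is kept
max_updates <- batches_per_epoch * epochs           # upper bound on gradient steps
```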

verbose

(logical(1))
Passed to pycox.models.DeepHitSingle.fit, should information be displayed during fitting.

num_workers

(integer(1))
Passed to pycox.models.DeepHitSingle.fit, number of workers used in the dataloader.

shuffle

(logical(1))
Passed to pycox.models.DeepHitSingle.fit, should order of dataset be shuffled?

...

ANY
Passed to get_pycox_optim.

Details

Implemented from the pycox Python package via reticulate. Calls pycox.models.DeepHitSingle.

Value

An object of class survivalmodel that also inherits from class deephit.

References

Changhee Lee, William R Zame, Jinsung Yoon, and Mihaela van der Schaar. DeepHit: A deep learning approach to survival analysis with competing risks. In Thirty-Second AAAI Conference on Artificial Intelligence, 2018. http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit

Examples


if (requireNamespaces("reticulate")) {
  # all defaults
  deephit(data = simsurvdata(50))

  # common parameters
  deephit(data = simsurvdata(50), frac = 0.3, activation = "relu",
    num_nodes = c(4L, 8L, 4L, 2L), dropout = 0.1, early_stopping = TRUE, epochs = 100L,
    batch_size = 32L)
}



survivalmodels documentation built on March 24, 2022, 9:05 a.m.