#' @title DNNSurv Neural Network for Conditional Survival Probabilities
#'
#' @description DNNSurv fits a neural network based on pseudo-conditional survival
#' probabilities.
#'
#' @details Code for generating the conditional probabilities and pre-processing data is taken from
#' \url{https://github.com/lilizhaoUM/DNNSurv}.
#'
#' Cut points are sorted and de-duplicated before fitting, so that the one-hot encoding of the
#' time intervals used in training lines up with the encoding used by [predict.dnnsurv()].
#'
#' @template param_traindata
#' @template return_train
#'
#' @param cutpoints `(numeric())`\cr
#' Points at which to cut survival time into discrete points. Sorted and de-duplicated
#' internally; must not contain missing values.
#' @param cuts `(integer(1))`\cr
#' If `cutpoints` not provided then number of equally spaced points at which to cut survival time.
#' Must be a single number greater than or equal to `1`.
#' @param custom_model `(keras.engine.training.Model(1))` \cr
#' Optional custom architecture built with [build_keras_net] or directly with \CRANpkg{keras}.
#' The output layer should be of length `1`; the input layer should be of length equal to the
#' number of features plus the number of cuts.
#' @param loss_weights,weighted_metrics See [keras::compile.keras.engine.training.Model].
#' @param optimizer `(character(1))`\cr
#' See [get_keras_optimizer].
#' @param early_stopping `(logical(1))`\cr
#' If `TRUE` then early stopping callback is included.
#' @param min_delta,patience,baseline,restore_best_weights See [keras::callback_early_stopping].
#' @param verbose `(integer(1))` \cr
#' Level of verbosity for printing, `0` or `1`.
#' @param batch_size,epochs,validation_split,shuffle,sample_weight,initial_epoch,steps_per_epoch,validation_steps See [keras::fit.keras.engine.training.Model]. # nolint
#' @param ... `ANY` \cr
#' Passed to [get_keras_optimizer].
#'
#' @references
#' Zhao, L., & Feng, D. (2020).
#' DNNSurv: Deep Neural Networks for Survival Analysis Using Pseudo Values.
#' https://arxiv.org/abs/1908.02337
#'
#' @examples
#' \donttest{
#' if (requireNamespaces(c("keras", "pseudo"))) {
#'   # all defaults
#'   dnnsurv(data = simsurvdata(10))
#'
#'   # setting common parameters
#'   dnnsurv(time_variable = "time", status_variable = "status", data = simsurvdata(10),
#'     early_stopping = TRUE, epochs = 100L, validation_split = 0.3)
#'
#'   # custom model
#'   library(keras)
#'   cuts <- 10
#'   df <- simsurvdata(50)
#'   # shape = features + cuts
#'   input <- layer_input(shape = c(3L + cuts), name = "input")
#'   output <- input %>%
#'     layer_dense(units = 4L, use_bias = TRUE) %>%
#'     layer_dense(units = 1L, use_bias = TRUE) %>%
#'     layer_activation(activation = "sigmoid")
#'
#'   model <- keras_model(input, output)
#'   class(model)
#'
#'   dnnsurv(custom_model = model, time_variable = "time",
#'     status_variable = "status", data = df, cuts = cuts)
#' }
#' }
#'
#' @export
dnnsurv <- function(formula = NULL, data = NULL, reverse = FALSE,
                    time_variable = "time", status_variable = "status",
                    x = NULL, y = NULL, cutpoints = NULL, cuts = 5L,
                    custom_model = NULL, loss_weights = NULL, weighted_metrics = NULL,
                    optimizer = "adam", early_stopping = FALSE, min_delta = 0, patience = 0L,
                    verbose = 0L, baseline = NULL, restore_best_weights = FALSE,
                    batch_size = 32L, epochs = 10L, validation_split = 0, shuffle = TRUE,
                    sample_weight = NULL, initial_epoch = 0L, steps_per_epoch = NULL,
                    validation_steps = NULL, ...) {
  if (!requireNamespace("keras", quietly = TRUE)) {
    stop("Package 'keras' required but not installed.") # nocov
  }
  if (!requireNamespace("pseudo", quietly = TRUE)) {
    stop("Package 'pseudo' required but not installed.") # nocov
  }

  call <- match.call()
  data <- clean_train_data(formula, data, time_variable, status_variable, x, y, reverse)
  # Column 1 of `data$y` is used as the survival time, column 2 as the event status.
  surv_time <- data$y[, 1L]
  surv_status <- data$y[, 2L]

  if (is.null(cutpoints)) {
    if (is.null(cuts)) {
      stop("One of 'cuts' or 'cutpoints' must be provided.")
    }
    if (length(cuts) != 1L || !is.numeric(cuts) || is.na(cuts) || cuts < 1) {
      stop("'cuts' must be a single number greater than or equal to 1.")
    }
    cutpoints <- seq.int(min(surv_time), max(surv_time), length.out = cuts)
  }
  if (!is.numeric(cutpoints) || length(cutpoints) == 0L || anyNA(cutpoints)) {
    stop("'cutpoints' must be a non-empty numeric vector without missing values.")
  }
  # The interval one-hot columns are built from factor levels, which are always
  # sorted, while the survival curve is accumulated in the order of `cutpoints`;
  # the two only agree when the cut points are unique and increasing.
  cutpoints <- sort(unique(cutpoints))

  pseudo_cond <- .get_pseudo_conditional(surv_time, surv_status, cutpoints)

  # Covariates repeated once per (subject, interval) record, plus a one-hot
  # encoding of the interval. `drop = FALSE` keeps a matrix with its column
  # name when there is a single covariate.
  # NOTE(review): assumes `pseudo_cond$id` indexes rows of `data$x` — confirm.
  x_train <- cbind(
    data$x[pseudo_cond$id, , drop = FALSE],
    model.matrix(~ as.factor(pseudo_cond$s) + 0)
  )
  y_train <- pseudo_cond$pseudost

  if (!is.null(custom_model)) {
    model <- custom_model
  } else {
    # Default architecture: two small tanh layers with L2 regularisation and a
    # sigmoid output. keras layer functions add layers to `model` in place.
    model <- keras::keras_model_sequential()
    keras::layer_dense(model,
      units = 8, kernel_regularizer = keras::regularizer_l2(0.0001),
      activation = "tanh",
      input_shape = dim(x_train)[[2]])
    keras::layer_dense(model,
      units = 4, kernel_regularizer = keras::regularizer_l2(0.01),
      activation = "tanh")
    keras::layer_dense(model, units = 1, activation = "sigmoid")
  }

  keras::compile(
    model,
    loss = "mse",
    metrics = "mae",
    loss_weights = loss_weights,
    weighted_metrics = weighted_metrics,
    optimizer = get_keras_optimizer(optimizer = optimizer, ...)
  )

  callbacks <- if (early_stopping) {
    list(
      keras::callback_early_stopping(
        min_delta = min_delta,
        patience = patience,
        verbose = verbose,
        baseline = baseline,
        restore_best_weights = restore_best_weights
      )
    )
  } else {
    NULL
  }

  keras::fit(
    model,
    x_train,
    y_train,
    callbacks = callbacks,
    batch_size = batch_size,
    epochs = epochs,
    verbose = verbose,
    validation_split = validation_split,
    shuffle = shuffle,
    sample_weight = sample_weight,
    initial_epoch = initial_epoch,
    steps_per_epoch = steps_per_epoch,
    validation_steps = validation_steps
  )

  structure(list(y = data$y, x = data$x,
                 xnames = colnames(data$x),
                 model = model,
                 call = call,
                 cutpoints = cutpoints),
            name = "DNNSurv Neural Network",
            class = c("dnnsurv", "survivalmodel")
  )
}
#' @title Predict Method for DNNSurv
#'
#' @description Predicted values from a fitted object of class dnnsurv.
#'
#' @template return_predict
#'
#' @param object `(dnnsurv(1))`\cr
#' Object of class inheriting from `"dnnsurv"`.
#' @param newdata `(data.frame(1))`\cr
#' Testing data of `data.frame` like object, internally is coerced with [stats::model.matrix()].
#' If missing then training data from fitted object is used.
#' @param batch_size `(integer(1))`\cr
#' Passed to [keras::predict.keras.engine.training.Model], elements in each batch.
#' @param verbose `(integer(1))`\cr
#' Level of verbosity for printing, `0` or `1`.
#' @param steps `(integer(1))`\cr
#' Number of batches before evaluation finished, see [keras::predict.keras.engine.training.Model].
#' @param callbacks `(list())`\cr
#' Optional callbacks to apply during prediction.
#' @param type (`character(1)`)\cr
#' Type of predicted value. Choices are survival probabilities over all time-points in training
#' data (`"survival"`) or a relative risk ranking (`"risk"`), which is the negative mean survival
#' time so higher rank implies higher risk of event, or both (`"all"`).
#' @param distr6 `(logical(1))`\cr
#' If `FALSE` (default) and `type` is `"survival"` or `"all"` returns matrix of survival
#' probabilities, otherwise returns a [distr6::Matdist()].
#' @param ... `ANY` \cr
#' Currently ignored.
#'
#' @examples
#' \donttest{
#' if (requireNamespaces(c("keras", "pseudo"))) {
#'   fit <- dnnsurv(data = simsurvdata(10))
#'
#'   # predict survival matrix and relative risks
#'   predict(fit, simsurvdata(10), type = "all")
#'
#'   # return as distribution
#'   if (requireNamespaces("distr6")) {
#'     predict(fit, simsurvdata(10), distr6 = TRUE)
#'   }
#' }
#' }
#'
#' @export
predict.dnnsurv <- function(object, newdata, batch_size = 32L, verbose = 0L,
                            steps = NULL, callbacks = NULL,
                            type = c("survival", "risk", "all"), distr6 = FALSE,
                            ...) {
  # Validate the requested output before running the network.
  type <- match.arg(type)

  newdata <- clean_test_data(object, newdata)
  n_obs <- nrow(newdata)
  cutpoints <- object$cutpoints
  n_cuts <- length(cutpoints)

  # Stack one copy of the covariates per cut point (all rows for the first cut
  # point, then all rows for the second, ...) and append a one-hot encoding of
  # the cut point.
  x_test_all <- do.call(rbind, replicate(n_cuts, newdata, simplify = FALSE))
  smatrix_test <- model.matrix(~ as.factor(rep(cutpoints, each = n_obs)) + 0)
  x_test_all <- cbind(x_test_all, smatrix_test)

  # predict test data
  pred <- predict(
    object$model,
    x = x_test_all,
    batch_size = batch_size,
    verbose = verbose,
    steps = steps,
    callbacks = callbacks
  )
  # Given the stacking above, column j holds the network output for cut point j.
  pred <- matrix(pred, nrow = n_obs, ncol = n_cuts)

  # Survival at cut point j is the product of the outputs for cut points 1..j.
  # A single cumulative pass keeps an n_obs x n_cuts matrix even when there is
  # only one cut point.
  surv <- pred
  for (j in seq_len(n_cuts)[-1L]) {
    surv[, j] <- surv[, j - 1L] * pred[, j]
  }
  colnames(surv) <- cutpoints
  stopifnot(n_obs == nrow(surv))
  surv <- fill_na(round(surv, 4))

  ret <- list()
  if (type %in% c("survival", "all")) {
    if (!distr6 || !requireNamespace("distr6", quietly = TRUE)) {
      if (distr6) {
        warning("'distr6' not installed, returning 'surv' as matrix.") # nocov
      }
      ret$surv <- surv
    } else {
      ret$surv <- distr6::as.Distribution(1 - surv, fun = "cdf",
        decorators = c("CoreStatistics", "ExoticStatistics")
      )
    }
  }
  if (type %in% c("risk", "all")) {
    ret$risk <- surv_to_risk(surv)
  }

  # A single requested output is returned unwrapped.
  if (length(ret) == 1) ret[[1]] else ret
}
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.