#' @title Survival DeepSurv Learner
#'
#' @name mlr_learners_surv.deepsurv
#'
#' @description
#' A [mlr3proba::LearnerSurv] implementing DeepSurv from the Python package
#' \href{https://pypi.org/project/pycox/}{pycox}.
#'
#' Calls `pycox.models.CoxPH`.
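#'
#' @details
#' Prediction calls `predict_surv_df` after computing the baseline hazards on
#' the training data. The resulting survival matrix is composed into a discrete
#' survival distribution (`distr`), and `crank` is the mean of that
#' distribution.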
#'
#' @templateVar id surv.deepsurv
#' @template section_dictionary_learner
#'
#' @references
#' Katzman, J. L., Shaham, U., Cloninger, A., Bates, J., Jiang, T., & Kluger, Y. (2018).
#' DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural
#' network.
#' BMC Medical Research Methodology, 18(1), 24. https://doi.org/10.1186/s12874-018-0482-1
#'
#' @template seealso_learner
#' @template example
#' @export
LearnerSurvDeepsurv = R6::R6Class("LearnerSurvDeepsurv",
inherit = mlr3proba::LearnerSurv,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
ps = ParamSet$new(
params = list(
ParamDbl$new("frac", default = 0, lower = 0, upper = 1, tags = c("train", "prep")),
ParamUty$new("num_nodes", tags = c("train", "net", "required")),
ParamLgl$new("batch_norm", default = TRUE, tags = c("train", "net")),
ParamDbl$new("dropout",
default = "None", special_vals = list("None"),
lower = 0, upper = 1, tags = c("train", "net")),
ParamFct$new("activation",
default = "relu", levels = activations,
tags = c("train", "act")),
ParamDbl$new("alpha", default = 1, lower = 0, tags = c("train", "opt")),
ParamDbl$new("lambd", default = 0.5, lower = 0, tags = c("train", "opt")),
ParamUty$new("device", tags = c("train", "mod")),
ParamUty$new("loss", tags = c("train", "mod")),
ParamFct$new("optimizer",
default = "adam", levels = optimizers,
tags = c("train", "opt")),
ParamDbl$new("rho", default = 0.9, tags = c("train", "opt")),
ParamDbl$new("eps", default = 1e-8, tags = c("train", "opt")),
ParamDbl$new("lr", default = 1, tags = c("train", "opt")),
ParamDbl$new("weight_decay", default = 0, tags = c("train", "opt")),
ParamDbl$new("learning_rate", default = 1e-2, tags = c("train", "opt")),
ParamDbl$new("lr_decay", default = 0, tags = c("train", "opt")),
ParamUty$new("betas", default = c(0.9, 0.999), tags = c("train", "opt")),
ParamLgl$new("amsgrad", default = FALSE, tags = c("train", "opt")),
ParamDbl$new("t0", default = 1e6, tags = c("train", "opt")),
ParamDbl$new("momentum", default = 0, tags = c("train", "opt")),
ParamLgl$new("centered", default = TRUE, tags = c("train", "opt")),
ParamUty$new("etas", default = c(0.5, 1.2), tags = c("train", "opt")),
ParamUty$new("step_sizes", default = c(1e-6, 50), tags = c("train", "opt")),
ParamDbl$new("dampening", default = 0, tags = c("train", "opt")),
ParamLgl$new("nesterov", default = FALSE, tags = c("train", "opt")),
ParamLgl$new("lr_finder", default = FALSE, tags = c("train", "lrf")),
ParamInt$new("batch_size", default = 256, tags = c("train", "lrf", "fit", "predict")),
ParamDbl$new("tolerance",
lower = 0, upper = Inf, default = Inf,
tags = c("train", "lrf")),
ParamInt$new("epochs", lower = 1, upper = Inf, default = 1, tags = c("train", "fit")),
ParamLgl$new("verbose", default = TRUE, tags = c("train", "fit")),
ParamInt$new("num_workers", default = 0L, tags = c("train", "fit", "predict")),
ParamLgl$new("shuffle", default = TRUE, tags = c("train", "fit")),
ParamLgl$new("best_weights", default = FALSE, tags = c("train", "callbacks")),
ParamLgl$new("early_stopping", default = FALSE, tags = c("train", "callbacks")),
ParamDbl$new("min_delta", default = 0, tags = c("train", "early")),
ParamInt$new("patience", default = 10, tags = c("train", "early"))
)
)
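# Declare which hyperparameters apply to which optimizer; the dependencies
# mirror the argument lists of the corresponding torch.optim optimizers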
ps$add_dep("rho", "optimizer", CondEqual$new("adadelta"))
ps$add_dep("eps", "optimizer", CondAnyOf$new(setdiff(optimizers, c("asgd", "rprop", "sgd"))))
ps$add_dep("lr", "optimizer", CondEqual$new("adadelta"))
ps$add_dep(
"weight_decay", "optimizer",
CondAnyOf$new(setdiff(optimizers, c("rprop", "sparse_adam"))))
ps$add_dep("learning_rate", "optimizer", CondAnyOf$new(setdiff(optimizers, "adadelta")))
ps$add_dep("lr_decay", "optimizer", CondEqual$new("adadelta"))
ps$add_dep("betas", "optimizer", CondAnyOf$new(c("adam", "adamax", "adamw", "sparse_adam")))
ps$add_dep("amsgrad", "optimizer", CondAnyOf$new(c("adam", "adamw")))
ps$add_dep("lambd", "optimizer", CondEqual$new("asgd"))
ps$add_dep("t0", "optimizer", CondEqual$new("asgd"))
ps$add_dep("momentum", "optimizer", CondAnyOf$new(c("sgd", "rmsprop")))
ps$add_dep("centered", "optimizer", CondEqual$new("rmsprop"))
ps$add_dep("etas", "optimizer", CondEqual$new("rprop"))
ps$add_dep("step_sizes", "optimizer", CondEqual$new("rprop"))
ps$add_dep("dampening", "optimizer", CondEqual$new("sgd"))
ps$add_dep("nesterov", "optimizer", CondEqual$new("sgd"))
ps$add_dep("min_delta", "early_stopping", CondEqual$new(TRUE))
ps$add_dep("patience", "early_stopping", CondEqual$new(TRUE))
super$initialize(
id = "surv.deepsurv",
feature_types = c("integer", "numeric"),
predict_types = c("crank", "distr"),
param_set = ps,
man = "mlr3learners.pycox::surv.deepsurv",
packages = "mlr3learners.pycox"
)
}
),
private = list(
.train = function(task) {
# Prepare training data, optionally reserving a fraction (`frac`) for validation
pars = self$param_set$get_values(tags = "prep")
data = mlr3misc::invoke(
prepare_train_data,
task = task,
.args = pars
)
x_train = data$x_train
y_train = data$y_train
# Set up network architecture
pars = self$param_set$get_values(tags = "net")
net = mlr3misc::invoke(
torchtuples$practical$MLPVanilla,
in_features = x_train$shape[1],
num_nodes = reticulate::r_to_py(as.integer(pars$num_nodes)),
out_features = 1L,
activation = mlr3misc::invoke(get_activation,
construct = FALSE,
.args = self$param_set$get_values(tags = "act")),
output_bias = FALSE,
.args = pars[names(pars) %nin% "num_nodes"]
)
# Get the optimizer and set up the model
pars = self$param_set$get_values(tags = "mod")
model = mlr3misc::invoke(
pycox$models$CoxPH,
net = net,
optimizer = mlr3misc::invoke(get_optim,
net = net,
.args = self$param_set$get_values(tags = "opt")),
.args = pars
)
# Optionally tune the learning rate internally; skipped for Adadelta, which
# adapts its own learning rate. Note: `optimizer` is tagged "opt", not "lrf",
# so it must be read from the full set of parameter values.
pars = self$param_set$get_values(tags = "lrf")
optimizer = self$param_set$values$optimizer
if (is.null(optimizer) || optimizer != "adadelta") {
if (!is.null(pars$lr_finder) && pars$lr_finder) {
lrfinder = mlr3misc::invoke(
model$lr_finder,
input = x_train,
target = y_train,
.args = pars[names(pars) %nin% "lr_finder"]
)
model$optimizer$set_lr(lrfinder$get_best_lr())
}
}
# Optionally get callbacks
pars = self$param_set$get_values(tags = "callbacks")
early_stopping = !is.null(pars$early_stopping) && pars$early_stopping
if (!early_stopping && !is.null(pars$best_weights) && pars$best_weights) {
callbacks = reticulate::r_to_py(list(torchtuples$callbacks$BestWeights()))
} else if (early_stopping) {
callbacks = reticulate::r_to_py(list(
mlr3misc::invoke(torchtuples$callbacks$EarlyStopping,
.args = self$param_set$get_values(tags = "early"))
))
} else {
callbacks = NULL
}
# Fit model
pars = self$param_set$get_values(tags = "fit")
mlr3misc::invoke(
model$fit,
input = x_train,
target = y_train,
callbacks = callbacks,
val_data = data$val,
.args = pars
)
# Return the fitted pycox model wrapped in a list; `.predict` accesses it
# via self$model$model
list(model = model)
},
.predict = function(task) {
# compute baseline hazards (required before predicting survival)
self$model$model$compute_baseline_hazards()
# get test data
x_test = task$data(cols = task$feature_names)
x_test = reticulate::r_to_py(x_test)$values$astype("float32")
# predict survival probabilities
pars = self$param_set$get_values(tags = "predict")
surv = mlr3misc::invoke(
self$model$model$predict_surv_df,
x_test,
.args = pars
)
# cast to a distr6 VectorDistribution: the support is the event times
# (rownames of the survival matrix), the pdf the increments of 1 - S(t)
x = rep(list(list(x = round(as.numeric(rownames(surv)), 5), pdf = 0)), task$nrow)
for (i in seq_len(task$nrow)) {
cdf = round(1 - surv[, i], 6) # round to mitigate floating-point errors
x[[i]]$pdf = c(cdf[1], diff(cdf))
# zero out increments below 1e-6 (rounding artifacts, which may be negative)
x[[i]]$pdf[x[[i]]$pdf < 1e-6] = 0
}
distr = distr6::VectorDistribution$new(
distribution = "WeightedDiscrete", params = x,
decorators = c("CoreStatistics", "ExoticStatistics"))
# return prediction object
mlr3proba::PredictionSurv$new(
task = task,
distr = distr,
crank = distr$mean()
)
}
)
)
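
# A minimal usage sketch, not run as part of the package: it assumes the
# Python packages pycox and torchtuples are installed and reachable via
# reticulate, and that the learner is registered in the mlr3 learner
# dictionary. `num_nodes` is required and sets the hidden layer sizes.
if (FALSE) {
library(mlr3)
library(mlr3proba)
task = tsk("rats")
learner = lrn("surv.deepsurv", num_nodes = c(32L, 32L), epochs = 10L)
learner$train(task, row_ids = 1:250)
prediction = learner$predict(task, row_ids = 251:300)
}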