#' @title GLM with Elastic Net Regularization Survival Learner
#'
#' @usage NULL
#' @aliases mlr_learners_surv.glmnet
#' @format [R6::R6Class()] inheriting from [LearnerSurv].
#' @include LearnerSurv.R
#'
#' @section Construction:
#' ```
#' LearnerSurvGlmnet$new()
#' mlr_learners$get("surv.glmnet")
#' lrn("surv.glmnet")
#' ```
#'
#' @description
#' Generalized linear models with elastic net regularization.
#' Calls [glmnet::cv.glmnet()] from package \CRANpkg{glmnet}.
#'
#' @references
#' Jerome Friedman, Trevor Hastie, Robert Tibshirani (2010).
#' Regularization Paths for Generalized Linear Models via Coordinate Descent.
#' Journal of Statistical Software, 33(1), 1-22.
#' \doi{10.18637/jss.v033.i01}.
#'
#' @export
#' @template seealso_learner
#' @examples
#' library(mlr3)
#' task = tgen("simsurv")$generate(200)
#' learner = lrn("surv.glmnet")
#' resampling = rsmp("cv", folds = 3)
#' resample(task, learner, resampling)
LearnerSurvGlmnet = R6Class("LearnerSurvGlmnet", inherit = LearnerSurv,
  public = list(
    # Constructor: registers the learner id, its hyperparameter set, the
    # supported feature types, the "weights" property and the required package.
    # Parameters tagged "train" are collected in train_internal(); the one
    # tagged "predict" (`s`) is collected in predict_internal().
    initialize = function() {
      super$initialize(
        id = "surv.glmnet",
        param_set = ParamSet$new(
          params = list(
            ParamDbl$new(id = "alpha", default = 1, lower = 0, upper = 1, tags = "train"),
            ParamInt$new(id = "nfolds", lower = 3L, default = 10L, tags = "train"),
            # "C" (Harrell's concordance) is the non-deviance measure that
            # cv.glmnet documents for Cox models. The other non-deviance levels
            # are kept only so that existing configurations remain valid.
            ParamFct$new(id = "type.measure", levels = c("deviance", "C", "class", "auc", "mse", "mae"), default = "deviance", tags = "train"),
            ParamDbl$new(id = "s", lower = 0, special_vals = list("lambda.1se", "lambda.min"), default = "lambda.1se", tags = "predict"),
            ParamInt$new(id = "nlambda", default = 100L, lower = 1L, tags = "train"),
            ParamDbl$new(id = "lambda.min.ratio", lower = 0, upper = 1, tags = "train"),
            ParamUty$new(id = "lambda", tags = "train"),
            ParamLgl$new(id = "standardize", default = TRUE, tags = "train"),
            ParamLgl$new(id = "intercept", default = TRUE, tags = "train"),
            ParamDbl$new(id = "thresh", default = 1e-07, lower = 0, tags = "train"),
            ParamInt$new(id = "dfmax", lower = 0L, tags = "train"),
            ParamInt$new(id = "pmax", lower = 0L, tags = "train"),
            # NOTE(review): glmnet documents `exclude` as a vector of column
            # indices; a scalar integer parameter can only exclude one column.
            ParamInt$new(id = "exclude", lower = 1L, tags = "train"),
            # glmnet expects one non-negative penalty factor per feature and
            # values above 1 are legal, so a scalar double bounded to [0, 1]
            # would reject valid settings. An untyped parameter accepts both
            # the old scalar values and full per-feature vectors.
            ParamUty$new(id = "penalty.factor", tags = "train"),
            ParamUty$new(id = "lower.limits", tags = "train"),
            ParamUty$new(id = "upper.limits", tags = "train"),
            ParamInt$new(id = "maxit", default = 100000L, lower = 1L, tags = "train"),
            ParamFct$new(id = "type.logistic", levels = c("Newton", "modified.Newton"), tags = "train"),
            ParamFct$new(id = "type.multinomial", levels = c("ungrouped", "grouped"), tags = "train"),
            # The remaining parameters are routed to glmnet::glmnet.control()
            # in train_internal() whenever their names match its settings.
            ParamDbl$new(id = "fdev", default = 1.0e-5, lower = 0, upper = 1, tags = "train"),
            ParamDbl$new(id = "devmax", default = 0.999, lower = 0, upper = 1, tags = "train"),
            ParamDbl$new(id = "eps", default = 1.0e-6, lower = 0, upper = 1, tags = "train"),
            ParamDbl$new(id = "big", default = 9.9e35, tags = "train"),
            ParamInt$new(id = "mnlam", default = 5L, lower = 1L, tags = "train"),
            ParamDbl$new(id = "pmin", default = 1.0e-9, lower = 0, upper = 1, tags = "train"),
            ParamDbl$new(id = "exmx", default = 250.0, tags = "train"),
            ParamDbl$new(id = "prec", default = 1e-10, tags = "train"),
            ParamInt$new(id = "mxit", default = 100L, lower = 1L, tags = "train")
          )
        ),
        feature_types = c("integer", "numeric"),
        properties = "weights",
        packages = "glmnet"
      )
    },

    # Fits a cross-validated Cox model (family = "cox") on the task's features.
    #
    # task: the training task; its features are converted to a matrix and
    #   task$truth() is used as the response.
    # Returns the object produced by glmnet::cv.glmnet().
    train_internal = function(task) {
      pars = self$param_set$get_values(tags = "train")
      data = as.matrix(task$data(cols = task$feature_names))
      target = task$truth()
      if ("weights" %in% task$properties) {
        pars$weights = task$weights$weight
      }

      # glmnet.control() changes package-global settings. Snapshot the current
      # values and restore them on exit so training leaves no side effects.
      # add = TRUE keeps any other exit handler registered in this frame.
      saved_ctrl = glmnet::glmnet.control()
      on.exit(invoke(glmnet::glmnet.control, .args = saved_ctrl), add = TRUE)
      # Start from factory settings so earlier global changes do not leak in.
      glmnet::glmnet.control(factory = TRUE)

      # Hyperparameters named like a control setting are applied through
      # glmnet.control() and removed from the arguments passed to cv.glmnet().
      is_ctrl_pars = (names(pars) %in% names(saved_ctrl))
      if (any(is_ctrl_pars)) {
        do.call(glmnet::glmnet.control, pars[is_ctrl_pars])
        pars = pars[!is_ctrl_pars]
      }

      invoke(glmnet::cv.glmnet, x = data, y = target, family = "cox", .args = pars)
    },

    # Predicts with the stored model using type = "link" and wraps the result
    # as the `risk` of a PredictionSurv.
    #
    # task: the task whose features are converted to a matrix and passed as
    #   `newx`.
    predict_internal = function(task) {
      pars = self$param_set$get_values(tags = "predict")
      newdata = as.matrix(task$data(cols = task$feature_names))
      response = invoke(predict, self$model, newx = newdata, type = "link", .args = pars)
      # drop() removes length-one dimensions from the predict() result.
      PredictionSurv$new(task = task, risk = drop(response))
    }
  )
)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.