#' @title Survival Cox Model with Likelihood Based Boosting Learner
#'
#' @name mlr_learners_surv.coxboost
#'
#' @description
#' A [mlr3proba::LearnerSurv] implementing coxboost from package
#' \CRANpkg{CoxBoost}.
#' Calls [CoxBoost::CoxBoost()].
#'
#' @details
#' Use [LearnerSurvCoxboost] and [LearnerSurvCVCoxboost] for Cox boosting without and with internal
#' cross-validation of boosting step number, respectively. Tuning using the internal optimizer in
#' [LearnerSurvCVCoxboost] may be more efficient when tuning `stepno` only. However, for tuning
#' multiple hyperparameters, \CRANpkg{mlr3tuning} and [LearnerSurvCoxboost] will likely give better
#' results.
#'
#' @templateVar id surv.coxboost
#' @template section_dictionary_learner
#'
#' @references
#' Binder, H., Allignol, A., Schumacher, M., and Beyersmann, J. (2009).
#' Boosting for high-dimensional time-to-event data with competing risks.
#' Bioinformatics, 25:890-896.
#'
#' @template seealso_learner
#' @template example
#' @export
# <Adapt the name to your learner. For regression learners inherit = LearnerRegr>
LearnerSurvCoxboost = R6Class("LearnerSurvCoxboost",
  inherit = LearnerSurv,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      # Hyperparameters tagged "train" are forwarded to CoxBoost::CoxBoost();
      # those tagged "predict" are forwarded to predict() on the fitted model.
      ps = ParamSet$new(
        params = list(
          ParamUty$new(id = "unpen.index", tags = "train"),
          ParamLgl$new(id = "standardize", default = TRUE, tags = "train"),
          ParamInt$new(id = "stepno", default = 100, lower = 0, tags = "train"),
          ParamDbl$new(id = "penalty", tags = "train"),
          ParamFct$new(
            id = "criterion", default = "pscore",
            levels = c("pscore", "score", "hpscore", "hscore"), tags = "train"),
          ParamDbl$new(id = "stepsize.factor", default = 1, tags = "train"),
          ParamFct$new(
            id = "sf.scheme", default = "sigmoid", levels = c("sigmoid", "linear"),
            tags = "train"),
          ParamUty$new(id = "pendistmat", tags = "train"),
          ParamUty$new(id = "connected.index", tags = "train"),
          ParamLgl$new(id = "x.is.01", default = FALSE, tags = "train"),
          ParamLgl$new(id = "return.score", default = TRUE, tags = "train"),
          ParamLgl$new(id = "trace", default = FALSE, tags = "train"),
          ParamUty$new(id = "at.step", tags = "predict")
        ))

      super$initialize(
        # see the mlr3book for a description: https://mlr3book.mlr-org.com/extending-mlr3.html
        id = "surv.coxboost",
        packages = "CoxBoost",
        feature_types = c("integer", "numeric", "factor", "logical"),
        predict_types = c("distr", "crank", "lp"),
        param_set = ps,
        properties = "weights",
        # the help file name is the one used as @name in the roxygen2 block
        man = "mlr3learners.coxboost::mlr_learners_surv.coxboost"
      )
    }
  ),
  private = list(
    # Builds the covariate matrix passed to CoxBoost from the task's features.
    # `model.matrix(~ .)` expands the feature columns; the first column (the
    # intercept of the `~ .` formula) is dropped, and `drop = FALSE` keeps the
    # result a matrix even when a single column remains.
    .design_matrix = function(task) {
      features = as.data.frame(task$data(cols = task$feature_names))
      model.matrix(~., features)[, -1, drop = FALSE]
    },

    # Fits the model by calling CoxBoost::CoxBoost() with the "train"
    # hyperparameters; observation weights are passed on when the task
    # declares the "weights" property.
    .train = function(task) {
      pars = self$param_set$get_values(tags = "train")
      if ("weights" %in% task$properties) {
        pars$weights = as.numeric(task$weights$weight)
      }

      # Fetch the target once; column 1 is used as time, column 2 as status.
      truth = task$truth()

      with_package("CoxBoost", {
        mlr3misc::invoke(
          CoxBoost::CoxBoost,
          time = truth[, 1],
          status = truth[, 2],
          x = private$.design_matrix(task),
          .args = pars
        )
      })
    },

    # Predicts the linear predictor (`type = "lp"`) and a per-observation
    # distribution built from `type = "CIF"` predictions evaluated at the
    # sorted unique times stored in the fitted model. Returns a
    # mlr3proba::PredictionSurv where `crank` and `lp` are both the linear
    # predictor.
    .predict = function(task) {
      pars = self$param_set$get_values(tags = "predict")
      # Shared by both predict() calls below, so build them only once.
      newdata = private$.design_matrix(task)
      times = sort(unique(self$model$time))

      lp = as.numeric(mlr3misc::invoke(predict,
        self$model,
        newdata = newdata,
        .args = pars,
        type = "lp"))

      cdf = mlr3misc::invoke(predict,
        self$model,
        newdata = newdata,
        .args = pars,
        type = "CIF",
        times = times)

      # One WeightedDiscrete parameter list per observation: support points
      # are the evaluation times, cdf values are the matching row of `cdf`.
      # seq_len() yields an empty sequence for a zero-row task, whereas
      # 1:task$nrow would iterate over c(1, 0) and index out of bounds.
      x = lapply(seq_len(task$nrow), function(i) {
        list(x = times, cdf = cdf[i, ])
      })

      distr = distr6::VectorDistribution$new(
        distribution = "WeightedDiscrete", params = x,
        decorators = c("CoreStatistics", "ExoticStatistics"))

      mlr3proba::PredictionSurv$new(task = task, crank = lp, distr = distr, lp = lp)
    }
  )
)
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.