mlr_learners_surv.xgboost.aft: Extreme Gradient Boosting AFT Survival Learner

Description

eXtreme Gradient Boosting regression using an Accelerated Failure Time objective. Calls xgboost::xgb.train() from package xgboost with objective set to survival:aft and eval_metric to aft-nloglik.
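The aft-nloglik metric scores predictions under the AFT model log(Y) = f(x) + sigma * Z, where Z follows the distribution chosen by aft_loss_distribution and sigma is aft_loss_distribution_scale. The base-R sketch below uses a hypothetical helper, aft_nloglik(), which is not part of xgboost or mlr3extralearners. It follows the AFT likelihood as we understand it from the xgboost AFT tutorial, handles only uncensored and right-censored times, and averages over observations; xgboost's own implementation may differ in numerical details such as clipping and weighting.

```r
# Hypothetical helper, not part of xgboost or mlr3extralearners:
# mean negative log-likelihood of the AFT model log(y) = pred + sigma * Z.
aft_nloglik = function(y, status, pred, sigma = 1,
                       dist = c("normal", "logistic", "extreme")) {
  dist = match.arg(dist)
  # Density of the standardized error Z
  pdf = switch(dist,
    normal   = dnorm,
    logistic = dlogis,
    extreme  = function(z) exp(z - exp(z))
  )
  # Cumulative distribution function of Z
  cdf = switch(dist,
    normal   = pnorm,
    logistic = plogis,
    extreme  = function(z) 1 - exp(-exp(z))
  )
  z = (log(y) - pred) / sigma
  ll = ifelse(
    status == 1,
    log(pdf(z)) - log(sigma * y),  # event observed: density of Y at y
    log(1 - cdf(z))                # right-censored: P(Y > y)
  )
  mean(-ll)
}

# Toy data: two events and one right-censored time; pred is on the log-time scale
y      = c(5, 8, 12)
status = c(1, 0, 1)
pred   = log(c(6, 10, 9))
aft_nloglik(y, status, pred, sigma = 1, dist = "normal")
```

A smaller sigma makes the loss sharper around the predicted log-time, which is why aft_loss_distribution_scale is usually tuned together with the distribution.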

Prediction types

This learner returns three prediction types:

  1. response: the estimated survival time T for each test observation.

  2. lp: a vector of linear predictors (relative risk scores), one per observation, estimated as -log(T). A longer predicted survival time therefore means a lower risk.

  3. crank: same as lp.
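Because -log() is strictly decreasing, ranking observations by lp (and hence crank) is exactly the reverse of ranking them by the predicted survival time. The base-R sketch below illustrates this with made-up predicted times; it does not call the learner.

```r
# Made-up predicted survival times for four test observations
response = c(2.5, 10, 0.8, 4)

# Linear predictor as described above: lp = -log(T); crank is the same vector
lp = -log(response)
crank = lp

# Highest risk first: the observation with the shortest predicted time comes first
risk_order = order(crank, decreasing = TRUE)
risk_order
```

Rank-based measures such as the concordance index only depend on this ordering, so they give the same result whether computed from crank or from the (negated) response.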

Initial parameter values

  • nrounds is initialized to 1000.

  • nthread is initialized to 1 to avoid conflicts with parallelization via future.

  • verbose is initialized to 0.
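The initial values live in the learner's param_set and can be inspected or overridden when the learner is constructed. A minimal sketch, assuming mlr3, mlr3proba and mlr3extralearners are installed; the override values (200 rounds, 4 threads) are arbitrary illustrations, and raising nthread may conflict with future-based parallelization as noted above.

```r
library(mlr3)
library(mlr3proba)
library(mlr3extralearners)

# Learner with its initial values
default_learner = lrn("surv.xgboost.aft")
default_learner$param_set$values[c("nrounds", "nthread", "verbose")]

# Override initial values at construction time
learner = lrn("surv.xgboost.aft", nrounds = 200, nthread = 4)
learner$param_set$values[c("nrounds", "nthread")]
```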

Early stopping

Early stopping can be used to find the optimal number of boosting rounds. The early_stopping_set parameter controls which set is used to monitor performance. By default, early_stopping_set = "none", which disables early stopping. Set early_stopping_set = "test" to monitor the performance of the model on the test set while training; this test set is defined by the "test" row role of the mlr3::Task. In addition, set early_stopping_rounds to the number of consecutive rounds without improvement after which training stops, and nrounds to the maximum number of boosting rounds. During resampling, the test set is taken automatically from the mlr3::Resampling. Note that using the test set for early stopping can bias the performance scores.
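A minimal early-stopping sketch, assuming a version of mlr3extralearners in which internal validation data is configured through the validate field listed under Active bindings (early_stopping_set does not appear in the parameter table on this page). The patience of 20 rounds and the 30% validation ratio are arbitrary illustrations.

```r
library(mlr3)
library(mlr3proba)
library(mlr3extralearners)

task = tsk("grace")

# Allow up to 1000 rounds, but stop once the validation score
# has not improved for 20 consecutive rounds
learner = lrn("surv.xgboost.aft", nrounds = 1000, early_stopping_rounds = 20)

# Hold out 30% of the training rows as internal validation data
learner$validate = 0.3

learner$train(task)

# Boosting iterations selected by early stopping
learner$internal_tuned_values
# Last validation scores, taken from model$evaluation_log
learner$internal_valid_scores
```

Using a ratio keeps the outer test set untouched, which avoids the bias mentioned above for validate = "test".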

Dictionary

This Learner can be instantiated via lrn():

lrn("surv.xgboost.aft")

Parameters

Id                           Type       Default    Range    Levels
aft_loss_distribution        character  normal     -        normal, logistic, extreme
aft_loss_distribution_scale  numeric    -          (-∞, ∞)
alpha                        numeric    0          [0, ∞)
base_score                   numeric    0.5        (-∞, ∞)
booster                      character  gbtree     -        gbtree, gblinear, dart
callbacks                    untyped    list()     -
colsample_bylevel            numeric    1          [0, 1]
colsample_bynode             numeric    1          [0, 1]
colsample_bytree             numeric    1          [0, 1]
disable_default_eval_metric  logical    FALSE      -        TRUE, FALSE
early_stopping_rounds        integer    NULL       [1, ∞)
eta                          numeric    0.3        [0, 1]
feature_selector             character  cyclic     -        cyclic, shuffle, random, greedy, thrifty
feval                        untyped    NULL       -
gamma                        numeric    0          [0, ∞)
grow_policy                  character  depthwise  -        depthwise, lossguide
interaction_constraints      untyped    -          -
iterationrange               untyped    -          -
lambda                       numeric    1          [0, ∞)
lambda_bias                  numeric    0          [0, ∞)
max_bin                      integer    256        [2, ∞)
max_delta_step               numeric    0          [0, ∞)
max_depth                    integer    6          [0, ∞)
max_leaves                   integer    0          [0, ∞)
maximize                     logical    NULL       -        TRUE, FALSE
min_child_weight             numeric    1          [0, ∞)
missing                      numeric    NA         (-∞, ∞)
monotone_constraints         integer    0          [-1, 1]
normalize_type               character  tree       -        tree, forest
nrounds                      integer    -          [1, ∞)
nthread                      integer    1          [1, ∞)
ntreelimit                   integer    -          [1, ∞)
num_parallel_tree            integer    1          [1, ∞)
one_drop                     logical    FALSE      -        TRUE, FALSE
print_every_n                integer    1          [1, ∞)
process_type                 character  default    -        default, update
rate_drop                    numeric    0          [0, 1]
refresh_leaf                 logical    TRUE       -        TRUE, FALSE
sampling_method              character  uniform    -        uniform, gradient_based
sample_type                  character  uniform    -        uniform, weighted
save_name                    untyped    -          -
save_period                  integer    -          [0, ∞)
scale_pos_weight             numeric    1          (-∞, ∞)
seed_per_iteration           logical    FALSE      -        TRUE, FALSE
skip_drop                    numeric    0          [0, 1]
strict_shape                 logical    FALSE      -        TRUE, FALSE
subsample                    numeric    1          [0, 1]
top_k                        integer    0          [0, ∞)
tree_method                  character  auto       -        auto, exact, approx, hist, gpu_hist
tweedie_variance_power       numeric    1.5        [1, 2]
updater                      untyped    -          -
verbose                      integer    1          [0, 2]
watchlist                    untyped    NULL       -
xgb_model                    untyped    -          -
device                       untyped    -          -

Super classes

mlr3::Learner -> mlr3proba::LearnerSurv -> LearnerSurvXgboostAFT

Active bindings

internal_valid_scores

The last observation of the validation scores for all metrics, extracted from model$evaluation_log.

internal_tuned_values

Returns the number of boosting iterations selected by early stopping, if early_stopping_rounds was set during training.

validate

How to construct the internal validation data. This parameter can be either NULL, a ratio, "test", or "predefined".

Methods

Method new()

Creates a new instance of this R6 class.

Usage
LearnerSurvXgboostAFT$new()

Method importance()

The importance scores are calculated with xgboost::xgb.importance().

Usage
LearnerSurvXgboostAFT$importance()
Returns

Named numeric().


Method clone()

The objects of this class are cloneable with this method.

Usage
LearnerSurvXgboostAFT$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.

Note

To compute on GPUs, you first need to compile xgboost yourself and link against CUDA. See https://xgboost.readthedocs.io/en/stable/build.html#building-with-gpu-support.

Author(s)

bblodfon

References

Chen, Tianqi, Guestrin, Carlos (2016). “XGBoost: A scalable tree boosting system.” In Proceedings of the 22nd ACM SIGKDD Conference on Knowledge Discovery and Data Mining, 785–794. ACM. doi:10.1145/2939672.2939785.

Barnwal, Avinash, Cho, Hyunsu, Hocking, Toby (2022). “Survival Regression with Accelerated Failure Time Model in XGBoost.” Journal of Computational and Graphical Statistics. ISSN 1537-2715, doi:10.1080/10618600.2022.2067548.

Examples


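# Load the packages that register the learner and the "grace" task
library(mlr3)
library(mlr3proba)
library(mlr3extralearners)

# Define the Learner
learner = mlr3::lrn("surv.xgboost.aft")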
print(learner)

# Define a Task
task = mlr3::tsk("grace")

# Create train and test set
ids = mlr3::partition(task)

# Train the learner on the training ids
learner$train(task, row_ids = ids$train)

print(learner$model)
print(learner$importance())

# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)

# Score the predictions
predictions$score()


mlr-org/mlr3extralearners documentation built on Nov. 11, 2024, 11:11 a.m.