xgb.train.surv: Train a survival xgboost including baseline hazard

View source: R/xgb.train.surv.R

xgb.train.surv    R Documentation

Train a survival xgboost including baseline hazard

Description

xgb.train.surv is a thin wrapper around xgboost::xgb.train that produces an xgb.Booster.surv object. This object has a predict method that enables prediction of full survival curves in addition to the usual relative risk scores.

Usage

xgb.train.surv(
  params = list(),
  data,
  label,
  weight = NULL,
  nrounds,
  watchlist = list(),
  verbose = 1,
  print_every_n = 1L,
  early_stopping_rounds = NULL,
  save_period = NULL,
  save_name = "xgboost_surv.model",
  xgb_model = NULL,
  callbacks = list(),
  ...
)

Arguments

params

same as in xgb.train. If provided, objective must be set to "survival:cox" and eval_metric must be set to "cox-nloglik".

data

must be a **matrix** of covariates

label

survival label. These are survival times, with censored cases coded as negative values (for example, a subject who has been followed for 10 days and has not died is coded as -10).

weight

(optional) weight vector for training samples

nrounds

same as in xgb.train

watchlist

same as in xgb.train. This can be tricky; see the Examples section.

verbose

same as in xgb.train

print_every_n

same as in xgb.train

early_stopping_rounds

same as in xgb.train

save_period

same as in xgb.train

save_name

same as in xgb.train, defaults to "xgboost_surv.model"

xgb_model

same as in xgb.train

callbacks

same as in xgb.train

...

additional arguments passed to xgb.train
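The coding described for the label argument above can be sketched as follows. The time and status vectors are hypothetical values made up for illustration; they are not part of the package.

```r
# Hypothetical follow-up data, for illustration only
time   <- c(10, 25, 40, 7)  # observed follow-up time in days
status <- c(0, 1, 1, 0)     # 1 = event (death), 0 = censored

# Events keep their positive time; censored cases get the negated time
label <- ifelse(status == 1, time, -time)
label
# -10 25 40 -7

# The original time and status can be recovered from the label
recovered_time   <- abs(label)
recovered_status <- as.numeric(label > 0)
```

Because the sign carries the censoring indicator, this coding presumably requires strictly positive follow-up times.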

Details

The xgboost package supports the Cox proportional hazards model, but its predict method returns only the risk score (which is equivalent to exp(X\beta), or type = "risk" in survival::coxph). This function returns an xgb.Booster.surv object, which enables prediction of both the risk score and the entire survival curve. The baseline hazard is obtained using the survival::survfit function with stype = 2 to obtain the Breslow estimator.
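To illustrate how a baseline hazard and a risk score combine into a survival curve, here is a minimal base-R sketch of the Breslow estimator on toy data. It is not the package's implementation (which, as noted above, relies on survival::survfit); the vectors time, status and risk are hypothetical, and the sketch assumes the usual Cox relation S(t | x) = exp(-H0(t) * exp(X\beta)).

```r
# Toy data (hypothetical values, for illustration only)
time   <- c(5, 8, 12, 20, 30)
status <- c(1, 0, 1, 1, 0)             # 1 = event, 0 = censored
risk   <- c(1.5, 0.8, 1.0, 2.0, 0.5)   # relative risk scores, exp(X beta)

# Breslow increments of the baseline cumulative hazard at each event time:
# number of events at t divided by the summed risk of subjects still at risk
event_times <- sort(unique(time[status == 1]))
hazard_increments <- sapply(event_times, function(t) {
  n_events <- sum(status == 1 & time == t)
  n_events / sum(risk[time >= t])
})
baseline_cumhaz <- cumsum(hazard_increments)

# Survival probability per subject (rows) at each event time (columns)
surv <- exp(-outer(risk, baseline_cumhaz))
dimnames(surv) <- list(paste0("subject_", seq_along(risk)),
                       paste0("t=", event_times))
round(surv, 3)
```

Subjects with a higher risk score get a survival curve that drops faster, while all curves share the same baseline shape.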

Value

an object of class xgb.Booster.surv

See Also

predict.xgb.Booster.surv

Examples

library(survival)
data("lung")
library(survXgboost)
library(xgboost)
lung <- lung[complete.cases(lung), ] # doesn't handle missing values at the moment
lung$status <- lung$status - 1 # format status variable correctly such that 1 is event/death and 0 is censored/alive
label <- ifelse(lung$status == 1, lung$time, -lung$time)

set.seed(1) # make the train/validation split reproducible
val_ind <- sample.int(nrow(lung), floor(0.1 * nrow(lung)))
x_train <- as.matrix(lung[-val_ind, !names(lung) %in% c("time", "status")])
x_label <- label[-val_ind]
x_val <- xgb.DMatrix(as.matrix(lung[val_ind, !names(lung) %in% c("time", "status")]),
                     label = label[val_ind])

# train surv_xgboost
surv_xgboost_model <- xgb.train.surv(
  params = list(
    objective = "survival:cox",
    eval_metric = "cox-nloglik",
    eta = 0.05 # larger eta leads to algorithm not converging, resulting in NaN predictions
  ), data = x_train, label = x_label,
  watchlist = list(val2 = x_val),
  nrounds = 1000, early_stopping_rounds = 30
)

# predict survival curves
times <- seq(10, 1000, 50)
survival_curves <- predict(object = surv_xgboost_model, newdata = x_train, type = "surv", times = times)
matplot(times, t(survival_curves[1:5, ]), type = "l")

# predict risk score
risk_scores <- predict(object = surv_xgboost_model, newdata = x_train, type = "risk")
hist(risk_scores)

IyarLin/survXgboost documentation built on Feb. 4, 2024, 5:38 p.m.