View source: R/xgb.train.surv.R
| xgb.train.surv | R Documentation | 
xgb.train.surv is a thin wrapper around xgboost::xgb.train that
produces an xgb.Booster.surv object. In addition to the usual relative risk score predictions,
the predict method of this object can return full survival curves.
xgb.train.surv(
  params = list(),
  data,
  label,
  weight = NULL,
  nrounds,
  watchlist = list(),
  verbose = 1,
  print_every_n = 1L,
  early_stopping_rounds = NULL,
  save_period = NULL,
  save_name = "xgboost_surv.model",
  xgb_model = NULL,
  callbacks = list(),
  ...
)
| Argument | Description |
|---|---|
| params | same as in xgb.train. If supplied, objective must be "survival:cox" and eval_metric must be "cox-nloglik" | 
| data | covariates, which must be supplied as a **matrix** | 
| label | survival label: observed survival times, negated for censored cases (e.g. a subject who has survived 10 days and has not yet died is coded as -10) | 
| weight | (optional) vector of weights for the training samples | 
| nrounds | same as in xgb.train | 
| watchlist | same as in xgb.train. Unlike data, its elements must be xgb.DMatrix objects whose labels use the same signed-time coding; see the example | 
| verbose | same as in xgb.train | 
| print_every_n | same as in xgb.train | 
| early_stopping_rounds | same as in xgb.train | 
| save_period | same as in xgb.train | 
| save_name | same as in xgb.train; defaults to "xgboost_surv.model" | 
| xgb_model | same as in xgb.train | 
| callbacks | same as in xgb.train | 
| ... | additional arguments passed to xgb.train | 
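The signed-time coding of label can be built from separate time and status vectors, as the example does with ifelse. A minimal sketch follows, assuming status is 1 for an event and 0 for censoring; encode_surv_label and decode_surv_label are hypothetical helpers, not part of survXgboost. Because a time of exactly 0 cannot carry a sign, the sketch rejects non-positive times.

```r
# Hypothetical helpers (not part of survXgboost) for the signed-time label format.
# status: 1 = event/death, 0 = censored/alive (assumed coding)
encode_surv_label <- function(time, status) {
  stopifnot(length(time) == length(status), all(time > 0), all(status %in% c(0, 1)))
  ifelse(status == 1, time, -time)  # censored cases get a negative time
}

# Recover time and status from a signed-time label
decode_surv_label <- function(label) {
  data.frame(time = abs(label), status = as.integer(label > 0))
}

encode_surv_label(c(10, 5, 7), c(0, 1, 1))
# [1] -10   5   7
```

Round-tripping through decode_surv_label returns the original times and status codes.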
The xgboost package supports the Cox proportional hazards model, but its predict method
returns only the risk score (equivalent to exp(Xβ), i.e. type = "risk" in survival::coxph).
This function instead returns an xgb.Booster.surv object, which enables prediction of both the risk score and
the entire survival curve. The baseline hazard is obtained with the survival::survfit function using stype = 2, which gives the Breslow estimator.
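Under the Cox model, the survival curve for a sample with risk score exp(Xβ) is S(t | x) = exp(-H0(t) · exp(Xβ)), where H0 is the baseline cumulative hazard. The sketch below estimates H0 by hand with the Breslow estimator (at each distinct event time, the number of events divided by the summed risk of everyone still at risk) and then turns it into survival curves. It only illustrates the estimator under these assumptions; the package itself delegates this step to survival::survfit, and breslow_cumhaz and surv_curves are hypothetical names.

```r
# Breslow estimate of the baseline cumulative hazard H0(t), evaluated at `times`.
# time: observed times; status: 1 = event, 0 = censored; risk: exp(X beta) per sample
breslow_cumhaz <- function(time, status, risk, times) {
  event_times <- sort(unique(time[status == 1]))
  # hazard increment at each distinct event time: d_k / sum of risk over the risk set
  increments <- vapply(event_times, function(tk) {
    d_k <- sum(time == tk & status == 1)
    d_k / sum(risk[time >= tk])
  }, numeric(1))
  cumhaz <- cumsum(increments)
  # step function: H0(t) sums the increments at event times <= t (0 before the first event)
  idx <- findInterval(times, event_times)
  c(0, cumhaz)[idx + 1]
}

# Survival curves S(t | x) = exp(-H0(t) * risk(x)): one row per sample, one column per time
surv_curves <- function(cumhaz, risk) {
  exp(-outer(risk, cumhaz))
}

# Toy data with equal risk, where Breslow reduces to Nelson-Aalen
h0 <- breslow_cumhaz(time = c(1, 2, 3), status = c(1, 1, 1),
                     risk = c(1, 1, 1), times = c(0.5, 1, 2.5, 3))
# h0 holds 0, 1/3, 5/6 and 11/6
```

Arranging the curves with samples as rows mirrors how the example below plots survival_curves with matplot.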
Returns an object of class xgb.Booster.surv.
See also: predict.xgb.Booster.surv
library(survival)
library(survXgboost)
library(xgboost)
data(cancer, package = "survival") # provides the lung data set
lung <- lung[complete.cases(lung), ] # missing values are not handled at the moment
lung$status <- lung$status - 1 # recode status so that 1 = event/death and 0 = censored/alive
label <- ifelse(lung$status == 1, lung$time, -lung$time) # negative times mark censored cases
set.seed(1) # make the validation split reproducible
val_ind <- sample.int(nrow(lung), floor(0.1 * nrow(lung)))
covariates <- !names(lung) %in% c("time", "status")
x_train <- as.matrix(lung[-val_ind, covariates])
x_label <- label[-val_ind]
# watchlist entries must be xgb.DMatrix objects carrying the signed-time label
x_val <- xgb.DMatrix(as.matrix(lung[val_ind, covariates]),
                     label = label[val_ind])
# train surv_xgboost
surv_xgboost_model <- xgb.train.surv(
  params = list(
    objective = "survival:cox",
    eval_metric = "cox-nloglik",
    eta = 0.05 # a larger eta can make training diverge and produce NaN predictions
  ), data = x_train, label = x_label,
  watchlist = list(val2 = x_val),
  nrounds = 1000, early_stopping_rounds = 30
)
# predict survival curves
times <- seq(10, 1000, 50)
survival_curves <- predict(object = surv_xgboost_model, newdata = x_train, type = "surv", times = times)
matplot(times, t(survival_curves[1:5, ]), type = "l")
# predict risk score
risk_scores <- predict(object = surv_xgboost_model, newdata = x_train, type = "risk")
hist(risk_scores)
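Since an overly large eta can yield NaN predictions, it may be worth checking the predicted curves before plotting them. The helper below is hypothetical (not part of survXgboost) and assumes, as the matplot call does, that rows of the returned matrix correspond to samples and columns to the requested times; it checks that every value is a probability and that each curve is non-increasing.

```r
# Hypothetical sanity check for a matrix of predicted survival curves.
# curves: rows = samples, columns = time points (assumed layout)
check_survival_curves <- function(curves, times) {
  stopifnot(is.matrix(curves), ncol(curves) == length(times))
  # anyNA() also catches NaN; && short-circuits so NA comparisons are never evaluated
  ok_values <- !anyNA(curves) && all(curves >= 0 & curves <= 1)
  # survival probabilities should not increase over time (small tolerance for rounding)
  ok_monotone <- all(apply(curves, 1, function(s) all(diff(s) <= 1e-12)))
  isTRUE(ok_values && ok_monotone)
}
```

With the objects from the example, check_survival_curves(survival_curves, times) should return TRUE when training converged.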