predict.xgb.Booster.surv: Prediction method for xgb.Booster.surv model

View source: R/predict.xgb.Booster.surv.R

Description

predict.xgb.Booster.surv is a method for xgb.Booster.surv objects that predicts either risk scores (as also implemented in the xgboost package) or the full survival curve.
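This page does not say how the survival curve is derived from the risk score. For a Cox-type model, one standard route is to estimate the baseline cumulative hazard H0(t) from the training data with the Breslow estimator, then compute S(t | x) = exp(-H0(t) * r(x)), where r(x) is the hazard ratio. xgboost's survival:cox objective returns its predictions on the hazard-ratio scale. The base-R sketch below assumes this approach; breslow_cumhaz() and surv_curves() are hypothetical helpers, not survXgboost functions, and the package's actual implementation may differ.

```r
# Sketch only: assumes a Cox proportional hazards model with a Breslow baseline.
# breslow_cumhaz() and surv_curves() are hypothetical helpers, not survXgboost API.

# Breslow estimate of the baseline cumulative hazard H0(t).
#   time   : observed times
#   status : 1 = death/event, 0 = censored
#   hr     : hazard ratios exp(f(x)) for the same rows
breslow_cumhaz <- function(time, status, hr) {
  death_times <- sort(unique(time[status == 1]))
  # Increment at each death time t: (number of deaths at t) / (sum of hr over the risk set time >= t)
  increments <- vapply(death_times, function(t) {
    sum(status[time == t]) / sum(hr[time >= t])
  }, numeric(1))
  data.frame(time = death_times, cumhaz = cumsum(increments))
}

# Survival curves S(t | x) = exp(-H0(t) * hr(x)):
# one row per element of hr_new, one column per requested time.
surv_curves <- function(hr_new, base_haz, times = base_haz$time) {
  # H0 is a right-continuous step function, equal to 0 before the first death time
  h0 <- c(0, base_haz$cumhaz)[findInterval(times, base_haz$time) + 1]
  s <- exp(-outer(hr_new, h0))
  colnames(s) <- times
  s
}

# Toy usage with made-up numbers
base_haz <- breslow_cumhaz(time = c(1, 2, 3, 4), status = c(1, 1, 0, 1), hr = c(1, 1, 1, 1))
s <- surv_curves(hr_new = c(1, 2), base_haz = base_haz, times = c(0.5, 1, 3))
print(round(s, 3))
```

A row with a larger hazard ratio gets a uniformly lower survival curve, which is the proportional hazards assumption at work.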

Usage

## S3 method for class 'xgb.Booster.surv'
predict(object, newdata, type = "risk", times = NULL)

Arguments

object

an xgb.Booster.surv object obtained by xgb.train.surv

newdata

a data.frame/matrix to make predictions for

type

either "risk" (risk scores) or "surv" (survival curves)

times

times at which to estimate the survival curve. Defaults to the unique death times in the original training dataset.
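The example on this page encodes the label the way xgboost's survival:cox objective expects: positive for an observed death, negative for a censored time. With that encoding, the default death times can be recovered directly from the label. death_times_from_label() is a hypothetical helper for illustration; survXgboost may store or compute these times differently.

```r
# Hypothetical helper: unique death times from a survival:cox style label,
# where positive values are observed deaths and negative values are censored times.
death_times_from_label <- function(label) {
  sort(unique(label[label > 0]))
}

label <- c(5, -3, 5, 7, -7, 2)
death_times_from_label(label)  # c(2, 5, 7)
```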

Value

For type = "risk", a vector of risk scores. For type = "surv", a matrix with one column per entry of times and one row per row of newdata.
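Because the type = "surv" result has one row per observation and one column per time, row-wise summaries are straightforward. As one illustration, the hypothetical helper below (not a survXgboost function) finds, for each row, the first supplied time at which the predicted survival probability drops to 0.5 or below. This is a coarse median survival time, only as fine as the times grid.

```r
# Hypothetical helper: for a survival matrix (rows = observations, columns = times),
# return the first time at which S(t) <= 0.5, or NA if the curve never reaches 0.5
# within the supplied times.
median_survival <- function(surv_mat, times) {
  apply(surv_mat, 1, function(s) {
    idx <- which(s <= 0.5)
    if (length(idx) == 0) NA_real_ else times[idx[1]]
  })
}

m <- rbind(c(0.90, 0.60, 0.40),
           c(0.95, 0.90, 0.80))
median_survival(m, times = c(10, 20, 30))  # 30 NA
```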

See Also

xgb.train.surv

Examples

library(survival)
data("lung")
library(survXgboost)
library(xgboost)
lung <- lung[complete.cases(lung), ] # doesn't handle missing values at the moment
lung$status <- lung$status - 1 # recode status so that 1 is event/death and 0 is censored/alive (lung codes them as 2 and 1)
label <- ifelse(lung$status == 1, lung$time, -lung$time) # survival:cox labels: positive = observed death, negative = censored

set.seed(1) # make the validation split reproducible
val_ind <- sample.int(nrow(lung), floor(0.1 * nrow(lung))) # hold out 10% of rows for validation
x_train <- as.matrix(lung[-val_ind, !names(lung) %in% c("time", "status")])
x_label <- label[-val_ind]
x_val <- xgb.DMatrix(as.matrix(lung[val_ind, !names(lung) %in% c("time", "status")]),
                     label = label[val_ind])

# train surv_xgboost
surv_xgboost_model <- xgb.train.surv(
  params = list(
    objective = "survival:cox",
    eval_metric = "cox-nloglik",
    eta = 0.05 # a larger eta can keep the algorithm from converging, producing NaN predictions
  ), data = x_train, label = x_label,
  watchlist = list(val2 = x_val),
  nrounds = 1000, early_stopping_rounds = 30
)

# predict survival curves
times <- seq(10, 1000, 50)
survival_curves <- predict(object = surv_xgboost_model, newdata = x_train, type = "surv", times = times)
matplot(times, t(survival_curves[1:5, ]), type = "l")

# predict risk score
risk_scores <- predict(object = surv_xgboost_model, newdata = x_train, type = "risk")
hist(risk_scores)
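The example notes that a larger eta can keep the algorithm from converging and produce NaN predictions. A quick base-R check on a predicted survival matrix before plotting can catch that, along with values outside [0, 1] or curves that increase over time. check_surv_matrix() is a hypothetical helper, not part of survXgboost, and the tolerance is an arbitrary choice.

```r
# Hypothetical sanity check for a survival matrix (rows = observations,
# columns = times in increasing order). Returns a named logical vector;
# all TRUE means no problem was found.
check_surv_matrix <- function(s, tol = 1e-12) {
  c(
    no_missing     = !anyNA(s),  # anyNA() also flags NaN
    in_unit_range  = all(s >= -tol & s <= 1 + tol, na.rm = TRUE),
    non_increasing = all(apply(s, 1, function(r) all(diff(r) <= tol)), na.rm = TRUE)
  )
}

good <- rbind(c(1.0, 0.8, 0.5), c(0.9, 0.9, 0.2))
bad  <- rbind(c(1.0, NaN, 0.5), c(0.5, 0.7, 1.2))
check_surv_matrix(good)
check_surv_matrix(bad)
# With the example above: check_surv_matrix(survival_curves)
```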

IyarLin/survXgboost documentation built on Feb. 4, 2024, 5:38 p.m.