R/xgb.train.surv.R

Defines functions xgb.train.surv

Documented in xgb.train.surv

#' @title Train a survival xgboost including baseline hazard
#'
#' @description \code{xgb.train.surv} is a thin wrapper around \code{xgboost::xgb.train} that
#' produces a \code{xgb.Booster.surv} object. This object has a predict method that enables full
#' survival curve prediction in addition to the usual relative risk score predictions.
#' @param params same as in xgb.train. If \code{objective} is provided it must be set to
#' "survival:cox" and if \code{eval_metric} is provided it must be set to "cox-nloglik".
#' Whichever of the two is omitted is filled in with that value.
#' @param data must be a covariates **matrix**
#' @param label survival label. These are survival times with negative magnitude for censored
#' cases (so a case where someone survived 10 days and hasn't died yet would be coded as -10)
#' @param weight (optional) weight vector for training samples. Defaults to a weight of 1 for
#' every row of \code{data}
#' @param nrounds same as in xgb.train
#' @param watchlist same as in xgb.train. This can be tricky, see example
#' @param verbose same as in xgb.train
#' @param print_every_n same as in xgb.train
#' @param early_stopping_rounds same as in xgb.train
#' @param save_period same as in xgb.train
#' @param save_name same as in xgb.train, defaults to "xgboost_surv.model"
#' @param xgb_model same as in xgb.train
#' @param callbacks same as in xgb.train
#' @param ... additional arguments passed to xgb.train
#' @return an object of class \code{xgb.Booster.surv}. On top of the fitted booster it carries
#' \code{mean_prediction} (mean training-set margin) and \code{baseline_hazard} (a data.frame
#' with columns \code{hazard} and \code{time})
#' @example inst/train_predict_xgb.Booster.surv.R
#' @seealso \code{\link{predict.xgb.Booster.surv}}
#' @details The xgboost package supports the cox proportional hazards model but the predict method
#' returns only the risk score (which is equivalent to \eqn{exp(X\beta)} or \code{type = "risk"} in \code{survival::coxph}).
#' This function returns a \code{xgb.Booster.surv} object which enables prediction of both the risk score as well
#' as the entire survival curve. Baseline hazard rate is obtained using the \code{survival::survfit} function with
#' stype = 2 to obtain the Breslow estimator
#' @export
xgb.train.surv <- function(params = list(), data, label, weight = NULL, nrounds,
                           watchlist = list(), verbose = 1, print_every_n = 1L,
                           early_stopping_rounds = NULL, save_period = NULL,
                           save_name = "xgboost_surv.model", xgb_model = NULL,
                           callbacks = list(), ...) {
  # Normalise to a list so a NULL `params` behaves like the empty default.
  params <- as.list(params)

  # Fill in whichever cox setting is missing. Previously a non-empty `params`
  # lacking one of them (e.g. list(eta = 0.1)) crashed with
  # "argument is of length zero". `[[` is used to avoid `$` partial matching.
  if (is.null(params[["objective"]])) {
    params[["objective"]] <- "survival:cox"
  }
  if (is.null(params[["eval_metric"]])) {
    params[["eval_metric"]] <- "cox-nloglik"
  }
  if (!identical(params[["objective"]], "survival:cox")) {
    stop("params objective must be set to survival:cox")
  }
  if (!identical(params[["eval_metric"]], "cox-nloglik")) {
    stop("params eval_metric must be set to cox-nloglik")
  }

  if (is.null(weight)) {
    weight <- rep(1, nrow(data))
  }

  data_DMatrix <- xgboost::xgb.DMatrix(data = data, label = label, weight = weight)

  # `save_name` is forwarded as given; it used to be overridden by a
  # hard-coded file name, so the caller's value was ignored.
  xgboost_model <- xgboost::xgb.train(
    params = params, data = data_DMatrix, nrounds = nrounds,
    watchlist = watchlist, verbose = verbose, print_every_n = print_every_n,
    early_stopping_rounds = early_stopping_rounds, save_period = save_period,
    save_name = save_name, xgb_model = xgb_model, callbacks = callbacks, ...
  )

  # Generate baseline hazard. Decode the signed label: magnitude is the
  # follow-up time, a positive sign marks an observed event (status 1) and
  # anything else is censored (status 0). Only these two columns are needed
  # for an intercept-only fit; leaving the covariates out also prevents a
  # covariate named "time" or "status" from shadowing them.
  surv_data <- data.frame(time = abs(label), status = as.integer(label > 0))

  breslow_estimate <- survival::survfit(
    formula = survival::Surv(time, status) ~ 1, data = surv_data, stype = 2
  )

  # Save mean prediction to center linear predictors to 0 (so HRs to 1)
  xgboost_model$mean_prediction <- mean(
    predict(xgboost_model, data_DMatrix, outputmargin = TRUE)
  )

  # The `hazard` column holds the cumulative hazard from survfit; a (0, 0)
  # row is prepended to the estimate.
  xgboost_model$baseline_hazard <- data.frame(
    hazard = c(0, breslow_estimate$cumhaz),
    time = c(0, breslow_estimate$time)
  )
  class(xgboost_model) <- c("xgb.Booster.surv", "xgb.Booster")
  xgboost_model
}
IyarLin/survXgboost documentation built on Feb. 4, 2024, 5:38 p.m.