#' @title Cox proportional hazard model
#' @description Train the Cox model through cross-validation and select the optimal survival classification threshold.
#' A regularized Cox approach which performs feature selection is also implemented. For the regularized Cox model, the optimal set of
#' variables is selected through cross-validation and used to train the final model on the complete data.
#' @name train.cox
#' @param form survival formula
#' @param dat data frame
#' @param predict.times survival prediction times
#' @param trControl list of control parameters:
#' \enumerate{
#' \item number: number of cross-validations
#' \item regularize: train a regularized Cox model?
#' }
#' @param parallel run cross-validation in parallel? Uses mclapply, which relies on forking and is therefore not available on Windows
#' @param tuneLength same as tuneLength in the caret package
#' @param \dots further arguments passed to caret or other methods.
#' @return returns a list with items:
#' \itemize{
#' \item{finalModel: }{ final model trained on the complete data (dat) using optimal tuning parameters}
#' \item{fitted: }{ predictions on complete data (dat)}
#' \item{threshold: }{ optimal classification threshold}
#' \item{resamples: }{ cross-validation results: predictions on resampled data }
#' \item{predict.times: }{ survival prediction times}
#' \item{bestTune: }{ optimal tuning parameters}
#' }
NULL
#' @rdname train.cox
#' @param newdata optional external data frame. When supplied, the model of every
#' cross-validation fold also predicts on it; the predictions are returned in \code{resamples.ext}
#' @param mc.cores number of cores handed to mclapply when \code{parallel = TRUE}
#' @param seed random seed, set before the cross-validation folds are created
#' @export
train_cox <- function(form, dat, newdata = NULL, predict.times, trControl = NULL, parallel = FALSE, mc.cores = 2,
                      seed = 123, ...) {
  # NOTE(review): `...` is accepted but not forwarded to anything in this function.
  predict.times <- sort(predict.times)

  # Forked workers need the L'Ecuyer-CMRG generator to get reproducible
  # per-process random streams; the serial path keeps the current generator.
  if (parallel) {
    set.seed(seed, kind = "L'Ecuyer-CMRG")
  } else {
    set.seed(seed)
  }

  if (is.null(trControl)) {
    trControl <- list(number = 5, regularize = FALSE, lambda = seq(0.001, 0.1, by = 0.001),
                      maxit = 10000, alpha = 1, nfolds = 3)
  }
  # `number` is the only control entry read here; fall back to the default fold
  # count instead of failing inside createFolds() when the caller omitted it.
  if (is.null(trControl$number)) {
    trControl$number <- 5
  }

  resp.vars <- all.vars(form[[2]])
  time <- resp.vars[1]
  status <- resp.vars[2]
  horizon.names <- paste0("T", seq_along(predict.times))
  fold.ids <- seq_len(trControl$number)

  ### cox model requires time > 0 ... add a small number to times == 0
  dat[, time][dat[, time] <= 0] <- dat[, time][dat[, time] <= 0] + 0.01
  ix.cv <- createFolds(dat[, status], k = trControl$number, list = FALSE)

  # Fit and evaluate the model for a single cross-validation fold `kk`.
  cv_fold <- function(kk) {
    dat.trn <- dat[kk != ix.cv, ]
    dat.tst <- dat[kk == ix.cv, ]
    ts <- dat.trn[, time]
    Ts <- dat.tst[, time]

    model <- COX(form, dat.trn, trControl = trControl)
    fold.form <- model$formula
    pp <- predictSurvProb_Cox(object = model, newdata = dat.trn, times = predict.times)
    PP <- predictSurvProb_Cox(object = model, newdata = dat.tst, times = predict.times)

    # One threshold / AUC pair per prediction time: the threshold is tuned on
    # the training part, the AUC is measured on the held-out part.
    perf <- lapply(seq_len(ncol(pp)), function(xx) {
      # Status is reset to 0 for subjects whose observed time lies beyond the
      # prediction time.
      y <- dat.trn[, status]
      y[predict.times[xx] < ts] <- 0
      Y <- dat.tst[, status]
      Y[predict.times[xx] < Ts] <- 0
      thresh <- opt.thresh(prob = 1 - pp[, xx], obs = y)
      AUC <- Performance.measures(pred = 1 - PP[, xx], obs = Y, threshold = thresh)$AUC
      list(AUC = AUC, thresh = thresh)
    })
    AUC <- sapply(perf, function(xx) xx$AUC)
    threshold <- sapply(perf, function(xx) xx$thresh)

    # Held-out predictions are stored as 1 - survival probability.
    colnames(PP) <- horizon.names
    prob <- cbind.data.frame(status = dat.tst[, status], time = dat.tst[, time], 1 - PP)
    names(prob) <- c(status, time, horizon.names)

    prob.ext <- NULL
    if (!is.null(newdata)) {
      prob.ext <- predictSurvProb_Cox(object = model, newdata = newdata, times = predict.times)
      colnames(prob.ext) <- horizon.names
      prob.ext <- cbind.data.frame(status = newdata[, status], time = newdata[, time], 1 - prob.ext)
      names(prob.ext) <- c(status, time, horizon.names)
    }

    list(prob = prob, prob.ext = prob.ext, AUC = AUC, threshold = threshold,
         rhs.vars = all.vars(fold.form[[3]]))
  }

  ## cv -- honour the `parallel` / `mc.cores` arguments
  resamples <- if (parallel) {
    parallel::mclapply(fold.ids, cv_fold, mc.cores = mc.cores)
  } else {
    lapply(fold.ids, cv_fold)
  }
  # mclapply() hands back "try-error" objects (or NULL) for failed workers
  # instead of raising, so surface such failures explicitly.
  failed <- vapply(resamples, function(xx) is.null(xx) || inherits(xx, "try-error"), logical(1))
  if (any(failed)) {
    stop("cross-validation failed in fold(s): ", paste(which(failed), collapse = ", "), call. = FALSE)
  }

  prob <- lapply(resamples, function(xx) xx$prob)
  names(prob) <- paste0("cv", fold.ids)
  prob.ext <- NULL
  if (!is.null(newdata)) {
    prob.ext <- lapply(resamples, function(xx) xx$prob.ext)
    names(prob.ext) <- paste0("cv", fold.ids)
  }

  # Rows = folds, columns = prediction times.
  AUC <- do.call(rbind, lapply(resamples, function(xx) xx$AUC))
  threshold <- do.call(rbind, lapply(resamples, function(xx) xx$threshold))

  # Per prediction time, take the threshold of the fold with the highest AUC;
  # fall back to 0.5 when no fold produced an AUC for that time.
  which.max1 <- function(x) if (all(is.na(x))) NA else which.max(x)
  best.by.time <- as.numeric(apply(AUC, 2, which.max1))
  threshold <- vapply(seq_along(predict.times), function(ii) {
    if (is.na(best.by.time[ii])) 0.5 else threshold[best.by.time[ii], ii]
  }, numeric(1))
  names(threshold) <- horizon.names

  # The fold with the best mean AUC supplies the predictors of the final model.
  # which.max() is empty when every AUC is missing; use the first fold then.
  ix <- which.max(rowMeans(AUC, na.rm = TRUE))
  if (length(ix) == 0L) {
    ix <- 1L
  }
  rhs.vars <- resamples[[ix]]$rhs.vars

  ### train final model on all data
  form <- formula(paste("Surv(", time, ", ", status, ") ~", paste0(rhs.vars, collapse = " + ")))
  # The predictors are already chosen, so the final fit is not regularized.
  trControl$regularize <- FALSE
  finalModel <- COX(form, dat, trControl = trControl)
  fitted <- 1 - predictSurvProb_Cox(object = finalModel, newdata = dat, times = predict.times)
  colnames(fitted) <- horizon.names
  fitted <- cbind(obs = dat[, status], fitted)

  list(finalModel = finalModel, formula = form, fitted = fitted, threshold = threshold, resamples = prob,
       resamples.ext = prob.ext, predict.times = predict.times, bestTune = NULL)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.