#' @title Predict Survival with classification methods
#' @description Train machine learning classification models on time to event data using the caret package
#
#' @name train.classifier
#' @param form survival formula
#' @param dat data frame
#' @param method classification algorithm. The following algorithms have been implemented.
#' \enumerate{
#' \item glm logistic regression
#' \item glmnet elastic net
#' \item gbm gradient boosting machine
#' \item ranger random forest
#' \item svmRadial support vector machine with radial basis kernel
#' \item xgbTree extreme gradient boosting machine
#' }
#' @param predict.times survival prediction times
#' @param trControl control parameters for the caret train function. Set to NULL to use a default 5-fold cross-validation
#'
#' @param parallel run cross-validation in parallel? Uses mclapply which works only on linux
#' @param tuneLength same as tuneLength in the caret package
#' @param \dots further arguments passed to caret or other methods.
#' @return returns a list with items:
#' \itemize{
#' \item{finalModel: }{ final model trained on the complete data (dat) using optimal tuning parameters}
#' \item{fitted: }{ predictions on complete data (dat)}
#' \item{threshold: }{ optimal classification threshold}
#' \item{resamples: }{ cross-validation results: predictions on resampled data }
#' \item{predict.times: }{ survival prediction times}
#' \item{bestTune: }{ optimal tuning parameters}
#' \item{method: }{ classification algorithm}
#' }
#' @importFrom caret train trainControl
#' @importFrom gbm gbm gbm.fit
#' @importFrom ranger ranger
#'
#
NULL
#' @rdname train.classifier
#' @param mc.cores number of cores; accepted for interface compatibility but not used by this function
#' @param seed integer seed passed to \code{set.seed} before any model is trained
#' @export
train_classifier <- function(form, dat, method = "gbm", predict.times, trControl = NULL, parallel = FALSE,
                             mc.cores = 2, seed = 123, ...) {
  ## ---- validate inputs before touching the RNG or fitting anything ----
  if (missing(predict.times) || length(predict.times) == 0L) {
    stop("`predict.times` must contain at least one prediction time", call. = FALSE)
  }
  if (length(form) != 3L) {
    stop("`form` must be a two-sided survival formula", call. = FALSE)
  }
  ## the first two variables on the left-hand side are taken as time and status
  resp.vars <- all.vars(form[[2]])
  if (length(resp.vars) < 2L) {
    stop("the left-hand side of `form` must name a time and a status variable", call. = FALSE)
  }
  time <- resp.vars[1]
  status <- resp.vars[2]
  absent <- setdiff(c(time, status), colnames(dat))
  if (length(absent) > 0L) {
    stop("variable(s) not found in `dat`: ", paste(absent, collapse = ", "), call. = FALSE)
  }

  dots <- list(...)

  ## `parallel` only selects the RNG kind here; nothing below is run in parallel.
  if (parallel) {
    set.seed(seed, kind = "L'Ecuyer-CMRG")
  } else {
    set.seed(seed)
  }

  ## `tuneLength` is documented as an argument but arrives through `...`;
  ## honour it when supplied, otherwise fall back to the previous defaults.
  tuneLength <- dots[["tuneLength"]]
  if (is.null(trControl)) {
    if (is.null(tuneLength)) tuneLength <- 3
    cntrl <- trainControl(method = "cv", number = 5, summaryFunction = twoClassSummary,
                          classProbs = TRUE, savePredictions = "final", allowParallel = FALSE)
  } else {
    if (is.null(tuneLength)) tuneLength <- trControl$tuneLength
    ## start from caret's defaults and copy over only the recognised settings
    cntrl <- trainControl()
    nme <- intersect(names(trControl), methods::formalArgs(trainControl))
    cntrl[nme] <- trControl[nme]
  }

  ## reuse the predictors of the survival formula for the classification task
  rhs.vars <- all.vars(form[[3]])
  form <- formula(paste(status, " ~ ", paste0(rhs.vars, collapse = " + ")))

  extra.para <- Models(method)
  pkg <- getModelInfo(method, regex = FALSE)[[1]]$library
  for (p in pkg) requireNamespace(p)

  ## keep only those `...` arguments that the underlying fitting function accepts
  cfun <- get(classifier.name(method))
  csel <- dots[names(dots) %in% methods::formalArgs(cfun)]
  csel <- csel[unique(names(csel))]

  ## `[[` yields a plain vector for data.frame, tibble and data.table alike
  obs <- dat[[status]]
  ti <- dat[[time]]

  n.times <- length(predict.times)
  time.labels <- paste0("T", seq_len(n.times))
  pp <- vector("list", n.times)
  model <- vector("list", n.times)
  thresh <- numeric(n.times)

  for (ii in seq_len(n.times)) {
    ## subjects whose time exceeds the horizon are labelled as non-events at this horizon
    y <- obs
    y[(predict.times[ii] < ti)] <- 0
    dd <- dat
    dd[[status]] <- factor(ifelse(y == 1, "Yes", "No"), levels = c("No", "Yes"))

    def.args <- list(form = form, data = dd, method = method, trControl = cntrl,
                     tuneLength = tuneLength, metric = "ROC")
    args <- c(def.args, csel, extra.para)
    args[vapply(args, is.null, logical(1))] <- NULL
    ## on duplicated names the first occurrence (the defaults above) wins
    args <- args[unique(names(args))]

    ############ training
    model[[ii]] <- do.call(train, args)

    pred <- model[[ii]]$pred
    if (is.null(pred)) {
      stop("the fitted model holds no resampled predictions; `trControl` must set ",
           "`savePredictions` and `classProbs = TRUE`", call. = FALSE)
    }
    pred <- pred[, c("obs", "Yes", "rowIndex", "Resample")]
    ## attach the original time and status of each held-out row
    row.id <- pred$rowIndex
    pred$rowIndex <- NULL
    pred$time <- ti[row.id]
    pred$class <- obs[row.id]
    pred$time.id <- ii
    pp[[ii]] <- pred

    ## find thresholds: for each resample, tune the threshold on the other
    ## resamples' predictions and score it on the resample itself
    is.event <- ifelse(pred$obs == "Yes", 1, 0)
    perf <- lapply(unique(pred$Resample), function(fold) {
      held.out <- pred$Resample == fold
      thsh <- opt.thresh(prob = pred$Yes[!held.out], obs = is.event[!held.out])
      pm <- Performance.measures(pred = pred$Yes[held.out], obs = is.event[held.out],
                                 threshold = thsh)
      list(AUC = pm$AUC, threshold = thsh)
    })
    ## keep the threshold belonging to the resample with the highest AUC
    fold.auc <- vapply(perf, function(xx) as.numeric(xx$AUC), numeric(1))
    thresh[ii] <- perf[[which.max(fold.auc)]]$threshold
  }

  prob <- do.call(rbind, pp)
  fitted <- do.call(cbind, lapply(model, function(fit) {
    predict(fit, newdata = dat, type = "prob")[, "Yes"]
  }))
  colnames(fitted) <- time.labels
  fitted <- cbind(obs = obs, fitted)

  names(model) <- time.labels
  threshold <- thresh
  names(threshold) <- time.labels

  list(finalModel = model, fitted = fitted, threshold = threshold, resamples = prob,
       predict.times = predict.times, bestTune = NULL, method = method)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.