#' @title RSF: Random Survival Forest
#' @description Train a random survival forest through the ranger package. The optimal RSF tuning parameters
#' (min.node.size, mtry, and splitrule) can be selected through a random grid search.
#'
#' @name RSF
#' @param form survival formula
#' @param dat data frame
#' @param predict.times survival prediction times
#' @param trControl list of control parameters:
#' \enumerate{
#' \item ntrees: number of trees
#' \item number: number of cross-validations
#' \item tuneLength: tuning parameter grid size
#' \item importance: ranger variable importance
#' }
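#' @param parallel logical; if TRUE the random seed is set with the L'Ecuyer generator. Note that cross-validation itself currently runs sequentially with lapply (the mclapply call, which works only on linux, is commented out)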
#' @param tuneLength same as tuneLength in the caret package
#' @param \dots further arguments passed to caret or other methods.
#' @return returns a list with items:
#' \itemize{
#' \item{finalModel: }{ final model trained on the complete data (dat) using optimal tuning parameters}
#' \item{fitted: }{ predictions on complete data (dat)}
#' \item{threshold: }{ optimal classification threshold}
#' \item{resamples: }{ cross-validation results: predictions on resampled data }
#' \item{predict.times: }{ survival prediction times}
#' \item{bestTune: }{ optimal tuning parameters}
#' }
NULL
#' @rdname RSF
#' @param newdata optional external data frame; when supplied, every
#'   cross-validation model also predicts on it and the per-fold predictions
#'   are returned in \code{resamples.ext}. It must contain the time and status
#'   columns named in \code{form}.
#' @param seed integer seed used for fold assignment, tuning-grid sampling and
#'   every ranger fit.
#' @param mc.cores currently unused (cross-validation runs sequentially).
#' @export
train_RSF <- function(form, dat, newdata = NULL, predict.times, trControl = NULL,
                      seed = 123, parallel = FALSE, mc.cores = 2, ...) {
  if (is.null(trControl)) {
    trControl <- list(number = 5, ntrees = 100, importance = "none", tuneLength = 3)
  }
  # Complete a partially specified trControl. tuneLength is deliberately NOT
  # defaulted here: omitting it is how a caller switches the grid search off.
  if (is.null(trControl$number)) trControl$number <- 5
  if (is.null(trControl$ntrees)) trControl$ntrees <- 100
  if (is.null(trControl$importance)) trControl$importance <- "none"

  bestTune <- NULL
  prob.ext <- NULL
  tune <- !is.null(trControl$tuneLength)
  predict.times <- sort(predict.times)
  nme <- paste0("T", seq_along(predict.times))

  # Cross-validation is run sequentially (see the lapply below); `parallel`
  # only selects the RNG kind used for fold assignment and grid sampling.
  if (parallel) {
    set.seed(seed, "L'Ecuyer")
  } else {
    set.seed(seed)
  }

  # The first two variables of the response are taken as time and status.
  resp.vars <- all.vars(form[[2]])
  if (length(resp.vars) < 2) {
    stop("`form` must have a survival response naming a time and a status ",
         "variable, e.g. Surv(time, status) ~ x", call. = FALSE)
  }
  time <- resp.vars[1]
  status <- resp.vars[2]
  rhs.vars <- all.vars(form[[3]])
  p <- length(rhs.vars)
  model.cols <- c(status, time, rhs.vars)

  which.max1 <- function(x) {
    if (all(is.na(x))) NA else which.max(x)
  }

  # Status at horizon `t`: subjects whose time exceeds `t` are set to 0.
  status_at <- function(d, t) {
    y <- d[, status]
    y[t < d[, time]] <- 0
    y
  }

  # Observed status/time next to predicted event probabilities (1 - the value
  # returned by predictSurvProb_ranger, presumably a rows x times matrix of
  # survival probabilities), one column per prediction time.
  event_frame <- function(d, surv) {
    out <- cbind.data.frame(class = d[, status], time = d[, time], 1 - surv)
    names(out) <- c(status, time, nme)
    out
  }

  ix.cv <- createFolds(dat[, status], k = trControl$number, list = FALSE)

  ## cv
  resamples <- lapply(seq_len(trControl$number), function(kk) {
    dat.trn <- dat[kk != ix.cv, ]
    dat.tst <- dat[kk == ix.cv, ]
    fold.ext <- NULL

    if (!tune) {
      model <- ranger(data = dat.trn[, model.cols],
                      seed = seed,
                      mtry = floor(p / 3),
                      scale.permutation.importance = TRUE,
                      verbose = FALSE,
                      num.trees = trControl$ntrees,
                      write.forest = TRUE,
                      oob.error = FALSE,
                      dependent.variable.name = time,
                      status.variable.name = status)
      prob <- event_frame(
        dat.tst,
        predictSurvProb_ranger(object = model, newdata = dat.tst, times = predict.times)
      )
      if (!is.null(newdata)) {
        fold.ext <- event_frame(
          newdata,
          predictSurvProb_ranger(object = model, newdata = newdata, times = predict.times)
        )
      }
      # Threshold and AUC are both computed on the training part of the fold.
      surv.trn <- predictSurvProb_ranger(object = model, newdata = dat.trn, times = predict.times)
      perf <- lapply(seq_len(ncol(surv.trn)), function(tt) {
        y <- status_at(dat.trn, predict.times[tt])
        cutoff <- opt.thresh(prob = 1 - surv.trn[, tt], obs = y)
        auc <- Performance.measures(pred = 1 - surv.trn[, tt], obs = y, threshold = cutoff)$AUC
        list(thresh = cutoff, AUC = auc)
      })
      fold.tune <- NULL
      threshold <- vapply(perf, function(z) z$thresh, numeric(1))
      AUC <- vapply(perf, function(z) z$AUC, numeric(1))
      AUC.ave <- NULL
    } else {
      len <- trControl$tuneLength
      srule <- c("logrank", "extratrees", "C", "maxstat")
      # Random grid; the three sample() calls keep their original order so the
      # candidates drawn for a given seed do not change.
      grid <- data.frame(
        min.node.size = sample(seq_len(min(20, nrow(dat.trn))), size = len, replace = TRUE),
        mtry = sample(seq_len(p), size = len, replace = TRUE),
        splitrule = sample(srule, size = len, replace = TRUE),
        stringsAsFactors = FALSE
      )
      n.grid <- nrow(grid)
      surv.tst <- vector("list", n.grid)
      surv.ext <- vector("list", n.grid)
      thresh <- AUC <- matrix(0, nrow = n.grid, ncol = length(predict.times))

      ### grid search
      for (ii in seq_len(n.grid)) {
        model <- ranger(data = dat.trn[, model.cols],
                        seed = seed,
                        importance = trControl$importance,
                        mtry = grid$mtry[ii],
                        scale.permutation.importance = TRUE,
                        min.node.size = grid$min.node.size[ii],
                        splitrule = grid$splitrule[ii],
                        verbose = FALSE,
                        num.trees = trControl$ntrees,
                        oob.error = FALSE,
                        num.threads = trControl$cores,
                        write.forest = TRUE,
                        dependent.variable.name = time,
                        status.variable.name = status)
        surv.trn <- predictSurvProb_ranger(object = model, newdata = dat.trn, times = predict.times)
        surv.tst[[ii]] <- predictSurvProb_ranger(object = model, newdata = dat.tst, times = predict.times)
        if (!is.null(newdata)) {
          # Kept per candidate so the external predictions can later be taken
          # from the same candidate that wins each horizon.
          surv.ext[[ii]] <- predictSurvProb_ranger(object = model, newdata = newdata, times = predict.times)
        }
        # Threshold is chosen on the training part, AUC measured on the held-out part.
        perf <- lapply(seq_len(ncol(surv.trn)), function(tt) {
          y <- status_at(dat.trn, predict.times[tt])
          Y <- status_at(dat.tst, predict.times[tt])
          cutoff <- opt.thresh(prob = 1 - surv.trn[, tt], obs = y)
          auc <- Performance.measures(pred = 1 - surv.tst[[ii]][, tt], obs = Y, threshold = cutoff)$AUC
          list(AUC = auc, thresh = cutoff)
        })
        AUC[ii, ] <- vapply(perf, function(z) z$AUC, numeric(1))
        thresh[ii, ] <- vapply(perf, function(z) z$thresh, numeric(1))
      }

      # Best candidate per prediction time, and the distinct winners.
      best.by.time <- apply(AUC, 2, which.max1)
      best.rows <- unique(best.by.time)
      AUC.ave <- mean(AUC[best.rows, ])
      threshold <- thresh[best.rows, , drop = FALSE]
      AUC <- AUC[best.by.time, ]

      # Indicator columns T1..Tk flag the horizons each winning candidate won.
      won <- as.data.frame(matrix(0, nrow = length(best.rows), ncol = length(predict.times)))
      rownames(won) <- as.character(best.rows)
      colnames(won) <- nme
      for (jj in seq_len(ncol(won))) {
        won[as.character(best.by.time[jj]), nme[jj]] <- 1
      }
      fold.tune <- cbind(grid[best.rows, , drop = FALSE], won)

      # Column jj comes from the candidate that won horizon jj. Filling a
      # matrix keeps the shape intact even for a one-row data set.
      pick_best <- function(surv.list) {
        out <- surv.list[[best.by.time[1]]]
        for (jj in seq_along(predict.times)) {
          out[, jj] <- surv.list[[best.by.time[jj]]][, jj]
        }
        out
      }
      prob <- event_frame(dat.tst, pick_best(surv.tst))
      if (!is.null(newdata)) {
        fold.ext <- event_frame(newdata, pick_best(surv.ext))
      }
    }

    list(AUC = AUC, AUC.ave = AUC.ave, prob = prob, prob.ext = fold.ext,
         threshold = threshold, bestTune = fold.tune)
  })

  if (tune) {
    # Take the fold with the highest average AUC; within it keep the candidate
    # that won the most prediction times.
    AUC.ave <- vapply(resamples, function(xx) xx$AUC.ave, numeric(1))
    kk <- which.max(AUC.ave)
    if (length(kk) == 0) {
      stop("AUC could not be computed in any cross-validation fold", call. = FALSE)
    }
    bestTune <- resamples[[kk]][["bestTune"]]
    threshold <- resamples[[kk]][["threshold"]]
    ix <- which.max(rowSums(as.matrix(bestTune[, nme, drop = FALSE])))
    bestTune <- bestTune[ix, ]
    threshold <- threshold[ix, ]
  } else {
    # times x folds matrices; cbind keeps two dimensions even when there is a
    # single prediction time.
    thr.mat <- do.call(cbind, lapply(resamples, function(xx) xx[["threshold"]]))
    auc.mat <- do.call(cbind, lapply(resamples, function(xx) xx[["AUC"]]))
    best.fold <- apply(auc.mat, 1, which.max1)
    threshold <- vapply(seq_len(nrow(auc.mat)),
                        function(tt) thr.mat[tt, best.fold[tt]],
                        numeric(1))
  }

  cv.names <- paste0("cv", seq_len(trControl$number))
  prob <- lapply(resamples, function(xx) xx[["prob"]])
  names(prob) <- cv.names
  if (!is.null(newdata)) {
    prob.ext <- lapply(resamples, function(xx) xx[["prob.ext"]])
    names(prob.ext) <- cv.names
  }

  ### train final model
  if (tune) {
    finalModel <- ranger(data = dat[, model.cols],
                         seed = seed,
                         mtry = bestTune$mtry,
                         scale.permutation.importance = TRUE,
                         importance = trControl$importance,
                         min.node.size = bestTune$min.node.size,
                         splitrule = bestTune$splitrule,
                         verbose = FALSE,
                         num.trees = trControl$ntrees,
                         write.forest = TRUE,
                         oob.error = FALSE,
                         dependent.variable.name = time,
                         status.variable.name = status)
  } else {
    finalModel <- ranger(data = dat[, model.cols],
                         seed = seed,
                         importance = trControl$importance,
                         mtry = floor(p / 3),
                         scale.permutation.importance = TRUE,
                         verbose = FALSE,
                         num.trees = trControl$ntrees,
                         write.forest = TRUE,
                         oob.error = FALSE,
                         dependent.variable.name = time,
                         status.variable.name = status)
  }

  fitted <- 1 - predictSurvProb_ranger(object = finalModel, newdata = dat, times = predict.times)
  colnames(fitted) <- nme
  names(threshold) <- nme
  fitted <- cbind(obs = dat[, status], fitted)

  list(finalModel = finalModel, form = form, fitted = fitted, threshold = threshold,
       resamples = prob, resamples.ext = prob.ext,
       predict.times = predict.times, bestTune = bestTune)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.