R/train.RSF.R

Defines functions train_RSF

Documented in train_RSF

#' @title RSF: Random Survival Forest 
#' @description Train the random survival forest through the ranger package. The optimal RSF tuning parameters: 
#' min.node.size,mtry, and splitrule can be selected through grid search.   
#
#' @name RSF  
#' @param  form survival formula 
#' @param dat  data frame 
#' @param predict.times survival prediction times 
#' @param trControl  list of control parameters:  
#'  \enumerate{
#'  \item ntrees: number of trees 
#'  \item number: number of cross-validations 
#'  \item tuneLength: tuning parameter grid size  
#'  \item importance: ranger variable importance  
#' }  
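#' @param parallel run cross-validation in parallel? Intended to use mclapply (Unix-alikes only); the mclapply call is currently commented out, so folds run sequentially and this flag only selects the "L'Ecuyer" RNG kind when seeding 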
#' @param tuneLength  same as tuneLength in the caret package  
#' @param \dots further arguments passed to caret or other methods. 
#' @return returns a list with items: 
#' \itemize{
#' \item{finalModel: }{ final model trained on the complete data (dat) using optimal tuning parameters}
#' \item{fitted: }{ predictions on complete data (dat)} 
#' \item{threshold: }{ optimal classification threshold}
#' \item{resamples: }{ cross-validation results: predictions on resampled data } 
#' \item{predict.times: }{ survival prediction times}
#' \item{bestTune: }{ optimal tuning parameters} 
#' }

NULL 
#' @rdname RSF
#' @param newdata optional external data frame. When supplied, every
#'   cross-validation fold model also predicts on it; it must contain the
#'   time and status columns named in \code{form}.
#' @param seed integer seed passed to \code{set.seed} and to every
#'   \code{ranger} call.
#' @param mc.cores number of cores intended for \code{mclapply}; currently
#'   unused because the folds are processed sequentially.
#' @export
train_RSF <- function(form, dat, newdata = NULL, predict.times, trControl = NULL,
                      seed = 123, parallel = FALSE, mc.cores = 2, ...) {
  if (is.null(trControl)) {
    trControl <- list(number = 5, ntrees = 100, importance = "none", tuneLength = 3)
  }

  bestTune <- prob.ext <- NULL
  predict.times <- sort(predict.times)
  # Column labels T1, T2, ... - one per (sorted) prediction time.
  nme <- paste0("T", seq_along(predict.times))

  # The cross-validation loop below is sequential (the mclapply variant is
  # disabled), so `parallel` only decides which RNG kind is seeded.
  if (parallel) {
    set.seed(seed, "L'Ecuyer")
  } else {
    set.seed(seed)
  }

  # The response side of the formula is expected to name the time variable
  # first and the status variable second, e.g. Surv(time, status) ~ x1 + x2.
  resp.vars <- all.vars(form[[2]])
  if (length(resp.vars) < 2) {
    stop("`form` must name a time and a status variable on its left-hand side",
         call. = FALSE)
  }
  time <- resp.vars[1]
  status <- resp.vars[2]
  rhs.vars <- all.vars(form[[3]])
  p <- length(rhs.vars)

  # Index of the best entry, or NA when every entry is missing.
  which.max1 <- function(x) {
    if (all(is.na(x))) NA else which.max(x)
  }

  # Bind status, time and a matrix of event probabilities (one column per
  # prediction time) into a data frame named status/time/T1..Tk.
  as_prob_frame <- function(d, risk) {
    out <- cbind.data.frame(class = d[, status], time = d[, time], risk)
    names(out) <- c(status, time, nme)
    out
  }

  # Fold membership for every row of `dat`.
  ix.cv <- createFolds(dat[, status], k = trControl$number, list = FALSE)

  run_fold <- function(kk) {
    dat.trn <- dat[kk != ix.cv, ]
    dat.tst <- dat[kk == ix.cv, ]
    ts <- dat.trn[, time]
    Ts <- dat.tst[, time]
    fold.ext <- NULL

    if (is.null(trControl$tuneLength)) {
      ## ---- no tuning: one forest with mtry = floor(p / 3) ----
      model <- ranger(
        data = dat.trn[, c(status, time, rhs.vars)],
        seed = seed,
        mtry = floor(p / 3),
        scale.permutation.importance = TRUE,
        verbose = FALSE,
        num.trees = trControl$ntrees,
        write.forest = TRUE,
        oob.error = FALSE,
        dependent.variable.name = time,
        status.variable.name = status
      )

      # Event probability = 1 - predicted survival probability.
      prob <- as_prob_frame(
        dat.tst,
        1 - predictSurvProb_ranger(object = model, newdata = dat.tst, times = predict.times)
      )
      if (!is.null(newdata)) {
        fold.ext <- as_prob_frame(
          newdata,
          1 - predictSurvProb_ranger(object = model, newdata = newdata, times = predict.times)
        )
      }

      # Threshold and AUC are both computed on the training part of the fold.
      pp <- predictSurvProb_ranger(object = model, newdata = dat.trn, times = predict.times)
      perf <- lapply(seq_len(ncol(pp)), function(xx) {
        y <- dat.trn[, status]
        # Subjects still under observation after the horizon count as
        # non-events at that horizon.
        y[predict.times[xx] < ts] <- 0
        thrsh <- opt.thresh(prob = 1 - pp[, xx], obs = y)
        auc <- Performance.measures(pred = 1 - pp[, xx], obs = y, threshold = thrsh)$AUC
        list(thresh = thrsh, AUC = auc)
      })

      fold.tune <- NULL
      threshold <- sapply(perf, function(xx) xx$thresh)
      AUC <- sapply(perf, function(xx) xx$AUC)
      AUC.ave <- NULL
    } else {
      ## ---- random grid search over min.node.size, mtry and splitrule ----
      len <- trControl$tuneLength
      srule <- c("logrank", "extratrees", "C", "maxstat")
      grid <- data.frame(
        min.node.size = sample(1:(min(20, nrow(dat.trn))), size = len, replace = TRUE),
        mtry = sample(1:p, size = len, replace = TRUE),
        splitrule = sample(srule, size = len, replace = TRUE)
      )

      n.grid <- nrow(grid)
      pp <- vector("list", n.grid)      # survival probs, training part
      PP <- vector("list", n.grid)      # survival probs, held-out part
      PP.ext <- vector("list", n.grid)  # survival probs, external data
      thresh <- AUC <- matrix(0, nrow = n.grid, ncol = length(predict.times))

      for (ii in seq_len(n.grid)) {
        model <- ranger(
          data = dat.trn[, c(status, time, rhs.vars)],
          seed = seed,
          importance = trControl$importance,
          mtry = grid$mtry[ii],
          scale.permutation.importance = TRUE,
          min.node.size = grid$min.node.size[ii],
          splitrule = grid$splitrule[ii],
          verbose = FALSE,
          num.trees = trControl$ntrees,
          oob.error = FALSE,
          num.threads = trControl$cores,
          write.forest = TRUE,
          dependent.variable.name = time,
          status.variable.name = status
        )

        pp[[ii]] <- predictSurvProb_ranger(object = model, newdata = dat.trn, times = predict.times)
        PP[[ii]] <- predictSurvProb_ranger(object = model, newdata = dat.tst, times = predict.times)
        if (!is.null(newdata)) {
          # Kept per candidate so the external predictions can later be taken
          # from the same candidate that wins for each prediction time.
          PP.ext[[ii]] <- predictSurvProb_ranger(object = model, newdata = newdata, times = predict.times)
        }

        perf <- lapply(seq_len(ncol(pp[[ii]])), function(xx) {
          y <- dat.trn[, status]
          y[predict.times[xx] < ts] <- 0
          Y <- dat.tst[, status]
          Y[predict.times[xx] < Ts] <- 0
          # Threshold is chosen on the training part, AUC is measured on the
          # held-out part.
          thsh <- opt.thresh(prob = 1 - pp[[ii]][, xx], obs = y)
          auc <- Performance.measures(pred = 1 - PP[[ii]][, xx], obs = Y, threshold = thsh)$AUC
          list(AUC = auc, thresh = thsh)
        })

        AUC[ii, ] <- sapply(perf, function(xx) xx$AUC)
        thresh[ii, ] <- sapply(perf, function(xx) xx$thresh)
      }

      # Best grid row for each prediction time, and the distinct winners.
      ix <- apply(AUC, 2, which.max1)
      ix1 <- unique(ix)

      AUC.ave <- mean(AUC[ix1, ])
      threshold <- thresh[ix1, , drop = FALSE]
      AUC <- AUC[ix, ]
      fold.tune <- grid[ix1, , drop = FALSE]

      # Indicator columns T1..Tk: 1 where the candidate in that row is the
      # winner for that prediction time.
      mat <- as.data.frame(matrix(0, nrow = length(ix1), ncol = length(predict.times)))
      rownames(mat) <- as.character(ix1)
      colnames(mat) <- nme
      for (jj in seq_len(ncol(mat))) mat[as.character(ix[jj]), nme[jj]] <- 1
      fold.tune <- cbind(fold.tune, mat)

      # For each prediction time take the column from its winning candidate.
      # cbind() keeps a matrix even for a single row or a single time point.
      pick_risk <- function(preds) {
        do.call(cbind, lapply(seq_along(predict.times), function(xx) 1 - preds[[ix[xx]]][, xx]))
      }
      prob <- as_prob_frame(dat.tst, pick_risk(PP))
      if (!is.null(newdata)) {
        fold.ext <- as_prob_frame(newdata, pick_risk(PP.ext))
      }
    }

    list(AUC = AUC, AUC.ave = AUC.ave, prob = prob, prob.ext = fold.ext,
         threshold = threshold, bestTune = fold.tune)
  }

  resamples <- lapply(seq_len(trControl$number), run_fold)

  if (!is.null(trControl$tuneLength)) {
    # Take the tuning result of the fold with the highest average AUC, then
    # the candidate within it that wins for the most prediction times.
    AUC.ave <- sapply(resamples, function(xx) xx$AUC.ave)
    kk <- which.max(AUC.ave)
    bestTune <- resamples[[kk]]$bestTune
    threshold <- resamples[[kk]]$threshold

    mat <- as.matrix(bestTune[, nme])
    ix <- which.max(apply(mat, 1, sum))
    bestTune <- bestTune[ix, ]
    threshold <- threshold[ix, ]
  } else {
    # Rows = prediction times, columns = folds; cbind() keeps the matrix
    # shape even when there is only one prediction time.
    threshold <- do.call(cbind, lapply(resamples, function(xx) xx$threshold))
    AUC <- do.call(cbind, lapply(resamples, function(xx) xx$AUC))
    # For each prediction time use the threshold of the best-AUC fold.
    ix <- as.integer(apply(AUC, 1, which.max1))
    threshold <- vapply(seq_len(nrow(AUC)), function(xx) as.numeric(threshold[xx, ix[xx]]), numeric(1))
  }

  prob <- lapply(resamples, function(xx) xx$prob)
  names(prob) <- paste0("cv", seq_len(trControl$number))

  if (!is.null(newdata)) {
    prob.ext <- lapply(resamples, function(xx) xx$prob.ext)
    names(prob.ext) <- paste0("cv", seq_len(trControl$number))
  }

  ### train final model on the complete data
  if (!is.null(trControl$tuneLength)) {
    finalModel <- ranger(
      data = dat[, c(status, time, rhs.vars)],
      seed = seed,
      mtry = bestTune$mtry,
      scale.permutation.importance = TRUE,
      importance = trControl$importance,
      min.node.size = bestTune$min.node.size,
      splitrule = bestTune$splitrule,
      verbose = FALSE,
      num.trees = trControl$ntrees,
      write.forest = TRUE,
      oob.error = FALSE,
      dependent.variable.name = time,
      status.variable.name = status
    )
  } else {
    finalModel <- ranger(
      data = dat[, c(status, time, rhs.vars)],
      seed = seed,
      importance = trControl$importance,
      mtry = floor(p / 3),
      scale.permutation.importance = TRUE,
      verbose = FALSE,
      num.trees = trControl$ntrees,
      write.forest = TRUE,
      oob.error = FALSE,
      dependent.variable.name = time,
      status.variable.name = status
    )
  }

  fitted <- 1 - predictSurvProb_ranger(object = finalModel, newdata = dat, times = predict.times)
  colnames(fitted) <- nme
  names(threshold) <- nme
  fitted <- cbind(obs = dat[, status], fitted)

  # `resamples` holds the per-fold held-out predictions; `resamples.ext` the
  # per-fold predictions on `newdata` (NULL when no `newdata` was given).
  list(finalModel = finalModel, form = form, fitted = fitted, threshold = threshold,
       resamples = prob, resamples.ext = prob.ext,
       predict.times = predict.times, bestTune = bestTune)
}
nguforche/MLSurvival documentation built on July 28, 2019, 1:59 p.m.