# R/train.classifier.R


#' @title Predict Survival with Classification Methods
#' @description Train machine learning classification models on time-to-event data using the caret package.
#'
#' @name train.classifier   
#' @param  form survival formula 
#' @param dat  data frame 
#' @param method classification algorithm. The following algorithms have been implemented:
#'  \enumerate{
#'  \item \code{glm}: logistic regression
#'  \item \code{glmnet}: elastic net
#'  \item \code{gbm}: gradient boosting machine
#'  \item \code{ranger}: random forest
#'  \item \code{svmRadial}: support vector machine with a radial basis kernel
#'  \item \code{xgbTree}: extreme gradient boosting
#' }  
#' @param predict.times survival prediction times (horizons) at which the outcome is dichotomized
#' @param trControl a named list of control parameters for the caret \code{trainControl} function; it may also contain a \code{tuneLength} element (as in the caret package; defaults to 3). Set to NULL to use a default 5-fold cross-validation
#' @param parallel run cross-validation in parallel? Uses \code{parallel::mclapply}, which is not available on Windows
#' @param mc.cores number of cores to use when \code{parallel = TRUE}
#' @param seed random seed for reproducibility
#' @param \dots further arguments passed to \code{caret::train} or the underlying classification method.
#' @return returns a list with items: 
#' \itemize{
#' \item{finalModel: }{ a list of caret models, one per prediction time, each refit on the complete data (dat) using the optimal tuning parameters}
#' \item{fitted: }{ observed status and predicted event probabilities on the complete data (dat)}
#' \item{threshold: }{ optimal classification threshold for each prediction time}
#' \item{resamples: }{ cross-validation results: predictions on the resampled data}
#' \item{predict.times: }{ survival prediction times}
#' \item{bestTune: }{ optimal tuning parameters for each prediction time}
#' \item{method: }{ classification algorithm}
#' }
#'
#' @importFrom caret train trainControl twoClassSummary getModelInfo
#' @importFrom gbm gbm gbm.fit 
#' @importFrom ranger ranger 
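#'
#' @examples
#' \dontrun{
#' ## Illustrative sketch only: it assumes the `veteran` data set from the
#' ## survival package; the covariates and prediction times are arbitrary.
#' library(survival)
#' data(veteran)
#' fit <- train_classifier(Surv(time, status) ~ karno + diagtime + age + prior,
#'                         dat = veteran, method = "gbm",
#'                         predict.times = c(90, 180))
#' fit$threshold
#' head(fit$fitted)
#' }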
#'
NULL 
#' @rdname train.classifier   
#' @export
train_classifier <- function(form, dat, method = "gbm", predict.times, trControl = NULL,
                             parallel = FALSE, mc.cores = 2, seed = 123, ...){
### 
dots <- list(...)

#predict.times <- sort(predict.times)

  if(parallel) {
    ## parallel-safe RNG and parallel::mclapply (Unix-like systems only)
    pfun <- parallel::mclapply
    set.seed(seed, "L'Ecuyer")
  } else {
    set.seed(seed)
    pfun <- lapply
  }

if(is.null(trControl)){
tuneLength = 3
cntrl <- trainControl(method = "cv", number = 5, summaryFunction = twoClassSummary, classProbs = TRUE, 
savePredictions = "final", allowParallel = FALSE)
} else {
tuneLength = trControl$tuneLength
cfun1 <- get("trainControl")
cntrl <- trainControl()
nme <- names(trControl)[names(trControl)%in%methods::formalArgs(cfun1)]
cntrl[nme] <- trControl[nme]
}
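## A user-supplied `trControl` is a plain named list: entries whose names match
## arguments of caret::trainControl() are copied into the control object, and an
## optional `tuneLength` entry sets the size of the tuning grid, e.g. (values
## illustrative only):
##   trControl = list(method = "repeatedcv", number = 10, repeats = 3, tuneLength = 5)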

#form <- formula(paste('Surv(', 'time', ',', status, ') ~ ', 
#             paste0(c(rhs.vars), collapse = " + ")))
### use the same survival formula for classification: keep the covariates, drop Surv()
resp.vars <- all.vars(form[[2]])
time = resp.vars[1]; status = resp.vars[2]
rhs.vars <- all.vars(form[[3]])
form <- formula(paste(status, ' ~ ', paste0(c(rhs.vars), collapse = " + ")))
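## e.g. Surv(time, status) ~ age + karno  is rewritten as  status ~ age + karno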
 
extra.para <- Models(method)
pkg <- getModelInfo(method, regex = FALSE)[[1]]$library
for(i in seq_along(pkg)) do.call("requireNamespace", list(package = pkg[i]))

fun <- classifier.name(method)
cfun <- get(fun)

## keep only those ... arguments accepted by the underlying classifier function
csel <- dots[names(dots) %in% methods::formalArgs(cfun)]
csel <- csel[unique(names(csel))]
                 
obs = dat[, status]
ti = dat[, time]    

pp <- list()
model <- list()
thresh <- AUC <- numeric(length(predict.times))

for(ii in 1:length(predict.times)){
    y = dat[, status] 
    dd <- dat 
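    ## dichotomize the outcome at this horizon: subjects whose observed time
    ## exceeds predict.times[ii] are labelled "No"; events at or before the
    ## horizon become "Yes" (subjects censored before the horizon remain "No")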
    
    y[(predict.times[ii] < ti)] <- 0
#    y[(predict.times[ii] >= v & predict.times[ii] < ti)] <- 0
    dd[, status] <- y 
    dd[, status] <- factor(ifelse(dd[, status]==1, "Yes", "No"))      

def.args <- list(form = form, data = dd, method = method, trControl = cntrl, 
                 tuneLength = tuneLength, metric = "ROC")
#def.args <- list(x = dd[, rhs.vars], y = dd[, status], method = method, trControl = cntrl, 
#                 tuneLength = tuneLength, metric = "ROC")

args <- c(def.args, csel, extra.para)
args[sapply(args, is.null)] <- NULL
args <- args[unique(names(args))]
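## (def.args, user-supplied dots and method-specific extras are combined in that
##  order; NULL entries are dropped and duplicated names keep the first occurrence)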

############ training 
model[[ii]] <- do.call(train, args)

pred <- model[[ii]]$pred

pred <- pred[, c("obs", "Yes", "rowIndex", "Resample")]

tx <- ti[pred$rowIndex] 
cls <- obs[pred$rowIndex]
 
pred$rowIndex <- NULL 
pred$time = tx
pred$class = cls
pred$time.id = ii 
pp[[ii]] <- pred 

## find thresholds 
samp <- unique(pred$Resample)

perf <- lapply(samp, function(xx) {
  
p <- pred$Yes[pred$Resample != xx]
yy <- pred$obs[pred$Resample != xx]
yy = ifelse(yy == "Yes", 1, 0) 
thsh <- opt.thresh(prob=p, obs=yy)

p <- pred$Yes[pred$Resample == xx]
yy <- pred$obs[pred$Resample == xx]
yy = ifelse(yy == "Yes", 1, 0) 
pm <- Performance.measures(pred = p, obs = yy, threshold = thsh)
list(AUC = pm$AUC, threshold = thsh)
})
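## for each resample the threshold is tuned on the other resamples' out-of-fold
## predictions and its AUC is computed on the held-out resample; the threshold
## from the best-performing resample is kept for this prediction time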
A <- sapply(perf, function(xx) xx$AUC)
ix <- which.max(A)
AUC[ii] <- A[ix]
thresh[ii] <- sapply(perf, function(xx) xx$threshold)[ix]
  }

prob <- do.call(rbind, pp)

fitted <- do.call(cbind, lapply(model, function(xx)
predict(xx, newdata = dat, type = "prob")[,"Yes"])) 
colnames(fitted) <- paste0("T", 1:length(predict.times))
#pp <- apply(fitted, 2, rescale, to = c(1, 10))
#colnames(pp) <- paste0(paste0("T", 1:length(predict.times)), ".scale")
fitted <- cbind(obs = dat[, status], fitted)
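## `fitted` now holds the observed status followed by the predicted event
## probabilities on the full data, one column (T1, T2, ...) per prediction time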
#fitted <- cbind(fitted, Total.Risk = rescale(apply(fitted[, paste0("T", 1:length(predict.times))], 1, sum),to = c(1, 10)))
threshold = thresh 

names(model) <- paste0("T", 1:length(predict.times))
names(threshold) <- paste0("T", 1:length(predict.times)) 

return(list(finalModel = model, fitted = fitted, threshold = threshold, resamples = prob,
            predict.times = predict.times, bestTune = lapply(model, function(xx) xx$bestTune),
            method = method))
}
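
## Hedged usage sketch (not run): given a fitted object `fit` from train_classifier()
## and new data `newdat` with the same covariates, class probabilities and labels at
## the first prediction time could be obtained along these lines:
##   p1 <- predict(fit$finalModel$T1, newdata = newdat, type = "prob")[, "Yes"]
##   lab1 <- ifelse(p1 >= fit$threshold["T1"], "Yes", "No")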