R/random_forest.R

Defines functions random_forest

#' Choose the number of trees for a caret random forest by cross-validation
#'
#' For each candidate `ntree` value, fits a random forest with
#' `caret::train(method = "rf")` under repeated k-fold cross-validation.
#' caret selects `mtry` for each fit (`model$bestTune`). The cross-validated
#' R-squared at that selected `mtry` is recorded, and the fit with the highest
#' R-squared across all candidates is returned.
#'
#' @param train_inputs Predictors, passed as the first argument to
#'   `caret::train`.
#' @param train_labels Outcome, passed as the second argument to
#'   `caret::train`. Selection uses the `Rsquared` column of the resampling
#'   results, so this presumably needs to be a numeric (regression) outcome.
#' @param folds.no Number of folds (`number` in `caret::trainControl`).
#' @param sampling.no Number of repeats (`repeats` in `caret::trainControl`).
#' @param trees.no Numeric vector of candidate `ntree` values.
#'   Defaults to `c(10, 20, 30)`.
#' @return The `caret::train` fit with the highest cross-validated R-squared.
#'   Ties go to the earliest candidate in `trees.no`.
random_forest <- function(train_inputs, train_labels, folds.no, sampling.no,
                          trees.no = seq(10, 30, 10)) {
  if (length(trees.no) == 0) {
    stop("`trees.no` must contain at least one value.", call. = FALSE)
  }

  # The resampling scheme is the same for every candidate, so build it once.
  # It is namespaced like caret::train so caret does not need to be attached.
  ctrl <- caret::trainControl(method = "repeatedcv",
                              number = folds.no,
                              repeats = sampling.no,
                              search = "grid")

  # Keep fits by position. Indexing the list by the tree count itself padded
  # it with NULLs up to max(trees.no) and misbehaved for non-integer counts.
  models <- lapply(trees.no, function(n_trees) {
    caret::train(train_inputs,
                 train_labels,
                 method = "rf",
                 ntree = n_trees,
                 trControl = ctrl)
  })

  # Exactly one value per fit: the resampled R-squared at the selected tuning
  # row. A missing row or missing column becomes NA instead of being dropped,
  # so r2 stays aligned with `models`.
  r2 <- vapply(models, function(model) {
    best <- merge(model$bestTune, model$results)
    if (nrow(best) == 0 || !("Rsquared" %in% names(best))) {
      return(NA_real_)
    }
    as.numeric(best$Rsquared[[1]])
  }, numeric(1))

  # which.max() skips NA; if every value is NA there is nothing to compare.
  if (all(is.na(r2))) {
    stop("No candidate model has a cross-validated R-squared; ",
         "random_forest() selects by R-squared and needs a numeric outcome.",
         call. = FALSE)
  }

  models[[which.max(r2)]]
}
bhklab/PharmacoGxML documentation built on July 9, 2019, 2:44 a.m.