# R/RF1.R
#
# Defines functions RF1
#
# Documented in RF1

# Random Forest classifier 
# ----------------------------------------------------------------------
#' Random Forest classifier
#'
#' Classify the input with a random forest classifier. For every candidate
#' number of trees in `ntree.vec`, `tuneRF()` searches over `mtry` and reports
#' the out-of-bag (OOB) error of each `mtry` value it tried. The `mtry`/`ntree`
#' pair with the lowest OOB error is then used to fit the final forest, which
#' predicts on the test data.
#'
#' @param data A list generated by the function PrepareData. `data$train` must
#' hold `features` (the training features) and `label` (the training labels).
#' `data$test` is either the test features themselves (e.g. a data frame) or a
#' list holding them in an element named `features`.
#' @param ntree.vec A non-empty numeric vector of the numbers of trees that
#' should be tested.
#' @return A list with two elements: `error.grid`, a data frame with columns
#' `mtry`, `OOBError` and `ntree` (one row per combination tried), and
#' `predictions`, the predictions on the test data made with the best
#' parameters.
#' @import randomForest
#' @import assertthat
#' @export
#' @examples
#' \dontrun{
#' # The path below is machine-specific; point it at your own data directory.
#' path <- "/home/rishabh/mres/ml_comp/data/"
#' data <- PrepareData(path, mode = 2, sample = TRUE, size = 100)
#' ntree.vec <- seq(50, 100, 10)
#' RF1(data, ntree.vec)
#' }
RF1 <- function(data, ntree.vec) {
  # Fail early: an empty grid would otherwise surface as an obscure error
  # when the final forest is fitted with zero-length parameters.
  stopifnot(is.numeric(ntree.vec), length(ntree.vec) > 0)

  # Zero-row template. Its only job is to make rbind() dispatch to the
  # data-frame method so the combined grid is a data frame, not a matrix.
  template <- data.frame(mtry = numeric(),
                         OOBError = double(),
                         ntree = numeric())

  # One tuneRF() search per candidate tree count. The tree count must be
  # passed as `ntreeTry`; any other name is swallowed by `...` and every
  # search would silently run with the default number of trees.
  grids <- lapply(ntree.vec, function(ntree) {
    print(paste("ntree =", ntree, " "))
    tune <- tuneRF(data$train$features,
                   data$train$label,
                   ntreeTry = ntree,
                   trace = FALSE,
                   plot = FALSE)
    cbind(tune, ntree = rep(ntree, nrow(tune)))
  })

  # Bind everything in a single call instead of growing a data frame inside
  # the loop. unname() keeps names of `ntree.vec` out of the row labels.
  tune.RF.result <- do.call(rbind, c(list(template), unname(grids)))

  # Refit with the (mtry, ntree) pair that achieved the lowest OOB error.
  best.index <- which.min(tune.RF.result$OOBError)
  mtry <- tune.RF.result$mtry[best.index]
  ntree <- tune.RF.result$ntree[best.index]

  rf.fit <- randomForest(x = data$train$features,
                         y = data$train$label,
                         mtry = mtry,
                         ntree = ntree)

  # `data$test` is either the test features themselves or a list wrapping
  # them in `features`. Decide by structure rather than by length, so a
  # feature data frame with exactly two columns is not mistaken for the
  # wrapped form.
  has.features <- is.list(data$test) && !is.data.frame(data$test) &&
    !is.null(data$test$features)
  test.features <- if (has.features) data$test$features else data$test

  predictions <- predict(rf.fit, test.features)

  list(error.grid = tune.RF.result, predictions = predictions)
}
# rishi1226/classrish documentation built on May 27, 2019, 9:10 a.m.