# R/saveForest.R

# Defines functions saveForest saveTrees saveForestComponents check_and_create_directory

# Documented in saveForest

#' Save Random Forests
#'
#' Saves a random forest for later use, given that the base R
#' \code{\link[base]{save}} function doesn't work for this package.
#'
#' @param forest The forest to save.
#' @param directory The directory that should be created to save the trees in.
#'   Note that if the directory already exists, an error will be displayed
#'   unless \code{overwrite} is set to TRUE.
#' @param overwrite Should the function overwrite an existing forest; FALSE by
#'   default.
#' @return \code{NULL}, invisibly. The function is called for its side effect
#'   of writing the forest's files into \code{directory}.
#' @export
#' @seealso \code{\link{train}}, \code{\link{loadForest}}
#' @examples
#' # Regression Example
#' x1 <- rnorm(1000)
#' x2 <- rnorm(1000)
#' y <- 1 + x1 + x2 + rnorm(1000)
#'
#' data <- data.frame(x1, x2, y)
#' forest <- train(y ~ x1 + x2, data,
#'  ntree=100, numberOfSplits = 5, mtry = 1, nodeSize = 5)
#'
#' saveForest(forest, "saved_forest")
#' new_forest <- loadForest("saved_forest")
saveForest <- function(forest, directory, overwrite=FALSE){
  # Errors if the directory exists and overwrite is FALSE; otherwise creates it
  check_and_create_directory(directory, overwrite)

  # First save the trees, one file per tree
  saveTrees(forest, directory)

  # Next save the response combiners, the split finder, the covariate list,
  # the parameters and the call that produced the forest
  saveForestComponents(directory,
                       covariateList=forest$covariateList,
                       params=forest$params,
                       forestCall=forest$call)

  invisible(NULL)
}

# Writes every tree of the forest to its own file.
#
# forest:    the forest to save; forest$javaObject is the Java handle whose
#            "getTrees" method is called, and forest$params$ntree is used as
#            the number of trees to write.
# directory: path prefix for the output files. This function assumes that the
#            directory is free for us to write in.
#
# Files are named "<directory>/tree-<index>.tree", where <index> is the
# 1-based tree number, zero-padded. Returns NULL invisibly.
saveTrees <- function(forest, directory){
  forest.java <- forest$javaObject
  numberOfTrees <- forest$params$ntree

  # Nothing to write. This also avoids log10(0) below and the descending
  # sequence that 1:0 would produce.
  if(numberOfTrees < 1){
    return(invisible(NULL))
  }

  tree.collection.java <- .jcall(forest.java, makeResponse(.class_List), "getTrees")

  # Zero-pad the index so that file names sort in tree order. The width
  # formula is kept exactly as it always was (it is never narrower than the
  # number of digits in numberOfTrees) so that file names do not change.
  width <- round(log10(numberOfTrees)) + 1
  treeNames <- paste0(directory, "/tree-", formatC(seq_len(numberOfTrees), width=width, format="d", flag="0"), ".tree")

  for(i in seq_len(numberOfTrees)){
    treeName <- treeNames[i]
    # "get" is called with i-1: R counts from 1, the Java collection from 0
    tree.java <- .jcall(tree.collection.java, makeResponse(.class_Object), "get", as.integer(i-1))
    tree.java <- .jcast(tree.java, .class_Serializable)
    .jcall(.class_DataUtils, "V", "saveObject", tree.java, treeName)
  }

  invisible(NULL)
}

# Saves the non-tree parts of a forest into `directory`.
#
# directory:     path prefix for the output files.
# covariateList: Java object written to "covariateList.jData".
# params:        list of forest parameters; the javaObject entries of its
#                nodeResponseCombiner, splitFinder and forestResponseCombiner
#                elements are written to .jData files, and the list itself is
#                written with saveRDS to "parameters.rData".
# forestCall:    written with saveRDS to "call.rData".
saveForestComponents <- function(directory, covariateList, params, forestCall){

  # Casts a Java object to Serializable and passes it to saveObject, with
  # the target path built from `directory` and the given suffix.
  writeJavaObject <- function(javaObject, pathSuffix){
    serializable.java <- .jcast(javaObject, .class_Serializable)
    .jcall(.class_DataUtils, "V", "saveObject", serializable.java, paste0(directory, pathSuffix))
  }

  # NOTE(review): the three `$javaObject <- NULL` assignments below act on
  # local copies that are never read again, while `params` is saved as it was
  # passed in. Presumably the intent was to strip the Java handles from the
  # saved parameters - confirm against loadForest before changing this.
  nodeCombiner <- params$nodeResponseCombiner
  writeJavaObject(nodeCombiner$javaObject, "/nodeResponseCombiner.jData")
  nodeCombiner$javaObject <- NULL

  finder <- params$splitFinder
  writeJavaObject(finder$javaObject, "/splitFinder.jData")
  finder$javaObject <- NULL

  forestCombiner <- params$forestResponseCombiner
  writeJavaObject(forestCombiner$javaObject, "/forestResponseCombiner.jData")
  forestCombiner$javaObject <- NULL

  writeJavaObject(covariateList, "/covariateList.jData")

  # The plain R pieces go through saveRDS
  saveRDS(object=params, file=paste0(directory, "/parameters.rData"))
  saveRDS(object=forestCall, file=paste0(directory, "/call.rData"))
}

# Prepares `directory` as a freshly created, empty directory.
#
# directory: path of the directory to create.
# overwrite: if FALSE, an already existing path is an error; if TRUE, the
#            existing path (including everything inside it) is deleted first.
#
# Returns TRUE invisibly. Stops if the path exists and overwrite is FALSE, if
# the existing path cannot be removed, or if the directory cannot be created.
check_and_create_directory <- function(directory, overwrite){
  if(file.exists(directory)){
    if(!overwrite){
      stop(paste(directory, "already exists; will not modify it. Please remove/rename it or set overwrite=TRUE"))
    }

    # unlink() only deletes directories when recursive=TRUE; without it the
    # old directory, and any stale files inside it, would be left in place.
    unlink(directory, recursive=TRUE)
    if(file.exists(directory)){
      stop(paste(directory, "already exists and could not be removed"))
    }
  }

  # Fail here with a clear message rather than later when writing into it
  if(!dir.create(directory, showWarnings=FALSE)){
    stop(paste(directory, "could not be created"))
  }

  invisible(TRUE)
}
# jatherrien/largeRCRF documentation built on Nov. 15, 2019, 7:16 a.m.