R/reduce.surrogates.R

Defines functions less.surrogates.node less.surrogates.trees reduce.surrogates

Documented in reduce.surrogates

#' Reduce surrogate variables in a random forest.
#'
#' This function can be applied to reduce the number of surrogate variables in a forest that was created by the getTreeranger, addLayer
#' and getSurrogates functions. Hence, it can be applied to forests that were used for surrogate minimal depth variable importance.
#'
#' @param forest a list containing allvariables and trees. Allvariables is a vector of all variable names in the original data set (strings). Trees is a list of trees that was generated by getTreeranger, addLayer, and getSurrogates functions.
#' @param s number of surrogate variables in the new forest (has to be smaller than the number of surrogate variables in the RF in trees)
#'
#' @return
#' forest with s surrogate variables.
#'
#' @examples
#' # read data
#' data("SMD_example_data")
#' \donttest{
#' ###### use result of SMD variable importance and reduce surrogate variables to 10
#' # select variables with smd variable importance (usually more trees are needed)
#' set.seed(42)
#' res = var.select.smd(x = SMD_example_data[,2:ncol(SMD_example_data)], y = SMD_example_data[,1], s = 100, ntree = 10)
#' forest.new = reduce.surrogates(forest = res$forest, s = 10)
#'
#' # execute SMD on tree with reduced number of surrogates
#' res.new = var.select.smd(create.forest = FALSE, forest = forest.new)
#' res.new$var
#'
#' # investigate variable relations
#' rel = var.relations(forest = forest.new, variables=c("X1","X7"), candidates = res$forest[["allvariables"]][1:100], t = 5)
#' rel$var
#'}
#' @export


reduce.surrogates = function(forest, s = 10){

  trees = forest[["trees"]]
  ntree = length(trees)
  trees.new = lapply(1:ntree, less.surrogates.trees, trees,s)
  forest.new = list(trees = trees.new, allvariables = forest[["allvariables"]])
return(forest.new)
}

less.surrogates.trees = function(i = 1, trees,s){
  tree = trees[[i]]
  n.nodes = length(tree)
  surr.now = sapply((lapply(tree,"[",-c(1:7))),length)/2
  surr.next = surr.now
  surr.next[which(surr.now >= s)] = s
  tree.new = lapply(1:n.nodes, less.surrogates.node, tree, surr.next,surr.now)
  return(tree.new)
}


less.surrogates.node = function(j=1, tree,surr.next, surr.now){
  node = tree[[j]]
  if (length(node) == 7) {
    node.new = node
  }
  if (length(node) > 7) {
node.new = node[c(1:(7 + surr.next[j]), (8 + surr.now[j]):(7 + surr.now[j] + surr.next[j]))]
}
return(node.new)
}
StephanSeifert/SurrogateMinimalDepth documentation built on Aug. 7, 2023, 1:59 a.m.