# R/forest.fold.R

# VT.FOREST.FOLD ----------------------------------------------------------

#' Difft via k random forests
#' 
#' A reference class to compute twins via k random forests
#' 
#' \code{VT.forest.fold} extends \code{VT.forest}
#' 
#' Twins are estimated by k-fold cross validation. For each fold, a forest is
#' fitted on the other k-1 folds (about (k-1)/k of the data) and then used to
#' estimate twin1 and twin2 on the held-out fold (the remaining 1/k of the data).
#' 
#' @include forest.R
#'   
#' @field interactions logical set TRUE if model has been computed with 
#'   interactions
#' @field fold numeric, number of folds, i.e. number of forests (k)
#' @field ratio numeric experimental, used to balance sampsize. Defaults to 1.
#' @field groups vector Define which observations belong to which group
#' @field ... field from parent class : \code{\link{VT.forest}}
#'   
#' @name VT.forest.fold
#'   
#' @seealso \code{\link{VT.difft}}, \code{\link{VT.forest}}, 
#'   \code{\link{VT.forest.one}}, \code{\link{VT.forest.double}}
#'   
#' @import methods
#' 
#' @export VT.forest.fold
#' 

VT.forest.fold <- setRefClass(
  Class = "VT.forest.fold",
  
  contains = "VT.forest",
  
  fields = list(
    interactions = "logical",
    fold = "numeric",
    ratio = "numeric",
    groups = "vector"
  ),
  
  methods = list(
    # Store the fold-specific settings, then hand vt.object and any remaining
    # arguments to the parent class initializer.
    #
    # vt.object    : object forwarded to the parent initializer.
    # fold         : number of folds (k), i.e. number of forests to fit.
    # ratio        : multiplier used in runOneForest() when computing the
    #                sample size of the first class. Defaults to 1.
    # interactions : passed to vt.object$getX(interactions = ) when the
    #                predictors are built.
    initialize = function(vt.object, fold, ratio = 1, interactions = TRUE, ...){
      
      .self$fold <- fold
      
      .self$ratio <- ratio
      
      .self$interactions <- interactions
      
      callSuper(vt.object, ...)
    },
    
    # Randomly assign every row of vt.object$data to one of the folds, fit one
    # forest per fold via runOneForest() and finally call computeDifft().
    #
    # parallel : currently unused by this method.
    # ...      : forwarded to runOneForest(), which forwards it to
    #            randomForest::randomForest().
    run = function(parallel = FALSE, ...){
      
      k <- .self$fold
      # With fewer than two folds the training set of the single fold would be
      # empty, so fail early with an explicit message.
      if(length(k) != 1 || is.na(k) || k < 2)
        stop("fold must be a single number greater than or equal to 2.")
      
      # floor() mirrors what 1:k did for a fractional k; seq_len() avoids the
      # descending 1:0 sequence.
      fold.ids <- seq_len(floor(k))
      
      .self$groups <- sample(fold.ids, nrow(.self$vt.object$data), replace = TRUE)
      
      for(g in fold.ids){
        # Assignment is random with replacement, so a fold may end up with no
        # observation: there is then nothing to predict for it.
        if(!any(.self$groups == g)) next
        .self$runOneForest(g, ...)
      }
      
      .self$computeDifft()
    },
    
    # Fit a random forest on every observation that is NOT in `group`, then
    # fill twin1 and twin2 for the observations that ARE in `group`.
    #
    # group : fold identifier, compared against .self$groups.
    # ...   : forwarded to randomForest::randomForest().
    runOneForest = function(group, ...){
      if(!requireNamespace("randomForest", quietly = TRUE)) stop("randomForest package must be loaded.")
      
      in.train <- .self$groups != group
      
      data <- .self$vt.object$getX(interactions = .self$interactions)
      # First column of the getX() result is dropped -- presumably the
      # response; TODO confirm against VT.object$getX().
      X <- data[in.train, -1, drop = FALSE]
      # Response is taken from the first column of the raw data.
      Y <- .self$vt.object$data[in.train, 1]
      Yeff <- table(Y) # 1 -> levels(Y)[1] & 2 -> levels(Y)[2]
      sampmin <- min(Yeff[1], Yeff[2])
      
      # NOTE(review): ratio only has an effect when the second class is the
      # smaller (or equal) one, and it scales Yeff[1] rather than sampmin.
      # Kept as is -- confirm this is the intended balancing rule.
      if(sampmin == Yeff[2]){
        samp2 <- sampmin
        samp1 <- min(Yeff[1], round(.self$ratio*Yeff[1], digits = 0))
      }else{
        samp2 <- Yeff[2]
        samp1 <- sampmin
      }
      
      # requireNamespace() loads but does not attach the package, so the
      # function has to be called with an explicit namespace.
      rf <- randomForest::randomForest(x = X, y = Y, sampsize = c(samp1, samp2), keep.forest = TRUE, ...)
      
      .self$computeTwin1(rf, group)
      .self$computeTwin2(rf, group)
    },
    
    # Predict, with forest `rfor`, the observations belonging to `group` and
    # store the result in the matching positions of twin1.
    # Returns twin1 invisibly.
    computeTwin1 = function(rfor, group){
      
      in.fold <- .self$groups == group
      
      data <- .self$vt.object$getX(interactions = .self$interactions)
      data <- data[in.fold, -1, drop = FALSE]
      
      .self$twin1[in.fold] <- VT.predict(rfor = rfor, newdata = data,  type = .self$vt.object$type)
      
      return(invisible(.self$twin1))
    },
    
    # Same as computeTwin1() but predictions are made after calling
    # vt.object$switchTreatment(); results go to twin2.
    # Returns twin2 invisibly.
    computeTwin2 = function(rfor, group){
      
      .self$vt.object$switchTreatment()
      # vt.object is shared state: call switchTreatment() a second time even if
      # the prediction below fails (presumably this restores the original
      # treatment -- the original code relied on the same double call).
      on.exit(.self$vt.object$switchTreatment(), add = TRUE)
      
      in.fold <- .self$groups == group
      
      data <- .self$vt.object$getX(interactions = .self$interactions)
      # NOTE(review): unlike computeTwin1(), the first column is not dropped
      # here. Kept unchanged -- confirm whether this is intentional.
      data <- data[in.fold, ]
      
      .self$twin2[in.fold] <- VT.predict(rfor = rfor, newdata = data,  type = .self$vt.object$type)
      
      return(invisible(.self$twin2))
    }
  )
)
# prise6/aVirtualTwins documentation built on May 8, 2019, 6:50 p.m.