R/tree.class.R

# VT.TREE.CLASS -----------------------------------------------------------

#' Classification tree to find subgroups
#'
#' See \code{\link{VT.tree}}
#'
#' Reference class extending \code{VT.tree}. A binary outcome is derived
#' from the differences stored in \code{vt.difft$difft}. An observation is
#' labelled 1 when its difference is \code{>= threshold} (\code{sens = ">"})
#' or \code{<= threshold} (any other value of \code{sens}), and 0 otherwise.
#' \code{run()} then fits an \code{rpart} classification tree on that
#' outcome.
#'
#' @include tree.R
#'
#' @name VT.tree.class
#'
#' @export VT.tree.class
#'
#' @import methods
#'
VT.tree.class <- setRefClass(
  Class = "VT.tree.class",

  contains = "VT.tree",

  methods = list(
    initialize = function(vt.difft, threshold = 0.05, sens = ">", screening = NULL){
      "Build the tree object and derive the binary outcome from vt.difft$difft"
      callSuper(vt.difft, threshold, sens, screening)

      .self$name <- .self$computeNameOfTree("class")

      # Label as 1 the observations whose difference lies on the side of
      # the threshold selected by `sens`; any value other than ">" is
      # handled as "<".
      if (.self$sens == ">") {
        .self$outcome <- ifelse(.self$vt.difft$difft >= .self$threshold, 1, 0)
      } else {
        .self$outcome <- ifelse(.self$vt.difft$difft <= .self$threshold, 1, 0)
      }
    },

    run = function(...){
      "VT.tree.class:run(...) Compute classification tree with rpart parameters"
      callSuper()

      data <- .self$getData()
      # Fit a tree only when at least one observation is labelled 1;
      # otherwise the predictions are simply the outcome itself.
      # `na.rm = TRUE` keeps a missing outcome from making the `if`
      # condition NA (which would raise an error).
      if (any(data[, 1] != 0, na.rm = TRUE)) {
        .self$tree <- rpart::rpart(
          as.formula(paste(.self$name, ".", sep = "~")),
          data = data, method = "class", ...
        )
        # Convert predicted class labels through their character values.
        # `as.numeric(factor) - 1` relies on the levels being c("0", "1")
        # and turns a single-level "1" factor (every outcome equal to 1)
        # into all zeros.
        .self$Ahat <- as.numeric(as.character(
          predict(.self$tree, data, type = "class")
        ))
      } else {
        .self$Ahat <- .self$outcome
      }

      invisible(.self$tree)
    },

    sumup = function(){
      "VT.tree.class:sumup() Print a header, then the parent class summary"
      cat("Classification Tree")
      callSuper()
    }
  )
)

# Lock the `threshold` and `vt.difft` fields. Per the ReferenceClasses
# documentation, a locked field may be assigned only once; presumably that
# single assignment happens in the parent initializer reached through
# callSuper() (TODO confirm in tree.R). Any later modification is an error.
VT.tree.class$lock("threshold", "vt.difft")

Try the aVirtualTwins package in your browser

Any scripts or data that you put into this service are public.

aVirtualTwins documentation built on May 2, 2019, 4:04 p.m.