R/bin_rpart.R

Defines functions bin.rpart

Documented in bin.rpart

#' Bin numerical values based on recursive partitioning (rpart)
#' 
#' \code{bin.rpart} relies on \code{\link[rpart]{rpart}}
#' function to split numerical values into different nodes. According to the 
#' tree-structure splits generated form \code{rpart}, \code{bin.rpart} further 
#' divides the numerical values into different bins, and record the cut points. 
#' The usage of \code{bin.rpart} is similar to \code{rpart}, except that the 
#' argument of \emph{control} in \code{rpart} is named as \emph{rcontrol} in 
#' \code{bin.rpart}
#' 
#' @param formula The formula for rpart
#' @param data The data frame used for binning
#' @param rcontrol The arguments passed into \code{\link[rpart]{rpart.control}} 
#' @param n.group Number of acceptable binning groups. It can be NULL,
#'    a single number (e.g., 5), or a vector (e.g., 3:7). If the value is NULL, it 
#'    returns the output with the default \emph{rpart.control}
#'    settings. If the n.group is a numeric value, it will change the \emph{cp} 
#'    value within \emph{rpart.control} automatically, 
#'    until it gets the desirable number of groups 
#' @param ... All other arguments that can be passed to \code{rpart} 
#' 
#' @return The cut points (\emph{cut.points}) and \emph{bins}. 
#' @examples
#' data <- rpart::stagec
#' bin.rpart(pgstat ~ age, data = data)
#' @export

bin.rpart <- function(formula, data, rcontrol = rpart.control(), n.group = NULL, 
  ...) {
  # This function is used to bin the numerical variable for survival model
  # Arg:
  #    formula: the formula used for rpart
  #    data: the dataset used for rpart
  #    n.group: the acceptable number of groups (NA group not counted for)
  #    rcontrol: the control used for rpart
  
  # The NA values are removed by the rpart function automatically
  row.names(data) <- 1:nrow(data)
  
  vars <- all.vars(formula)
  x.num <- vars[length(vars)]
  
  # if the minbucket is the default value 7, then update it to 1% of the data
  # if(rcontrol$minbucket == 7) {
  #   rcontrol$minbucket <- .01 * nrow(data)
  # }
  
  rp.tree <- rpart(formula, data, control = rcontrol, ...)
  # rp.tree <- rpart(formula, data, control = rcontrol)
  
  ## if n.group is NULL, and no group is found, return 'No Bin'
  if(is.null(n.group) & length(unique(rp.tree$where)) == 1) {
    cat(c(x.num, ': No Bin \n'))
    return('No Bin')
  }
  
  ## if n.group is not NULL, change cp to find the possible bins within n.group
  while((!is.null(n.group)) & (!length(unique(rp.tree$where)) %in% n.group)) {
    multipler <- ifelse(length(unique(rp.tree$where)) > median(n.group), 1.1, .9)
    rcontrol$cp <- rcontrol$cp * multipler
    rp.tree <- rpart(formula, data, control = rcontrol, ...)
    # rp.tree <- rpart(formula, data, control = rcontrol)
  }
  
  tree.where <- data.frame(Where = rp.tree$where)
  
  tree.value <- data.frame(Value = data[, x.num], Where = 'Missing',
    stringsAsFactors = F)
  tree.value[row.names(tree.where), 'Where'] <- tree.where$Where
  
  tree.cut <- dplyr::group_by(tree.value, Where) %>%
    dplyr::summarise(Cut_Start = min(Value), Cut_End = max(Value)) %>%
    dplyr::arrange(Cut_End)
  
  if(is.na(tree.cut$Cut_End[nrow(tree.cut)])) {
    cut.p <- tree.cut$Cut_End[1:(nrow(tree.cut) - 2)]
  } else {
    cut.p <- tree.cut$Cut_End[1:(nrow(tree.cut) - 1)]
  }
  
  cat(c(x.num, ':', cut.p, '\n'))
  
  bin.names <- c(paste('<=', cut.p[1]), rep(NA, length(cut.p) - 1),
    paste('>', cut.p[length(cut.p)]))
  
  if(length(cut.p) > 1) {
    for (i in 2:length(cut.p)) {
      bin.names[i] <- paste(cut.p[i - 1], '< \u00B7 <=', cut.p[i])
    }}  # \u00B7 is the unicode for mid-dot, \u2022 is for bullet point
  
  # Turn off this, to make the naming more consistent
  # check whether the Cut_Start and Cut_End are the same
  # if the same, the <, =, or > signs is not needed
  # bin.names <- ifelse(tree.cut$Cut_Start[1:length(bin.names)] ==
  #     tree.cut$Cut_End[1:length(bin.names)], tree.cut$Cut_End, bin.names)
  
  tree.cut$Bin <- 'Missing'
  tree.cut$Bin[1:length(bin.names)] <- bin.names
  
  x.bins <- factor(tree.cut$Bin[match(tree.value$Where, tree.cut$Where)],
    levels = tree.cut$Bin)
  
  return(list(cut.points = cut.p, bins = x.bins))
}
# dt <- read.csv('C:/Projects/AlumniConvProg/data/Data_Associate.csv')
# bin.rpart(formula = Surv(Conversion_Time_Months, Conversion_Status) ~ WF_Count,
#   rcontrol = rpart.control(cp = 0.0001, minbucket = nrow(dt.conv) * .01),
#   data = dt.conv, n.group = 3:7)
JianhuaHuang/streamit documentation built on May 7, 2019, 10:40 a.m.