R/extract_segments.R

Defines functions extract_segments

Documented in extract_segments

#' @title Extract lift segments from an rpart object in a table form
#'
#' @description \code{extract_segments} takes as
#' input a fitted rpart object using the lift_method
#' and returns the resulting segments in table form. See example
#' below for more details on its usage.
#'
#' @param rpart_fit An object of class \code{rpart} fitted with the lift_method method imported by \code{\link{import_lift_method}}.
#' @return A data.frame containing the resulting segments.
#' It contains confidence intervals with the alpha parameter specified in the \code{parms} argument to the \code{rpart} function.
#' @example examples/RCTree_example.R
#' @seealso \code{\link{import_lift_method}}
#' @export


extract_segments <- function(rpart_fit) {
  # Turn the leaves of a fitted rpart tree into a table of segments.
  #
  # Args:
  #   rpart_fit: fitted rpart object; its `frame` is read directly and the
  #     object is passed to `path.rpart()` to obtain the path to every leaf.
  #
  # Returns:
  #   A data.frame with one row per leaf, sorted by decreasing `lift`:
  #     segment    - textual description of the conditions defining the leaf
  #     n          - `n` of the leaf as stored in the frame
  #     lift       - `yval` of the leaf
  #     lift_lower - column 2 of `yval2` (NA when the frame has no `yval2`)
  #     lift_upper - column 3 of `yval2` (NA when the frame has no `yval2`)

  # Collapse every condition a path places on ONE variable into its tightest
  # form. `df` holds the columns `val` (right-hand side) and `eqn` (first
  # operator character) for that variable.
  reduce_conditions <- function(df) {
    if (df$eqn[1] == "=") {
      # Categorical variable: only levels allowed by every split survive.
      if (nrow(df) == 1) {
        kept_levels <- df$val
      } else {
        # strsplit() always yields a list, so Reduce() intersects whole level
        # sets even when all of them happen to have the same length.
        level_sets <- strsplit(df$val, ",", fixed = TRUE)
        kept_levels <- paste0(Reduce(intersect, level_sets), collapse = ",")
      }
      return(paste0("={", kept_levels, "}"))
    }

    # Numeric variable: keep the largest lower bound and smallest upper bound.
    cutpoints <- as.numeric(df$val)
    lower <- cutpoints[df$eqn == ">"]
    upper <- cutpoints[df$eqn == "<"]
    # A side without any split is unbounded; testing the length first avoids
    # the "no non-missing arguments" warning of max()/min() on empty input.
    ans_low <- if (length(lower) > 0) max(lower) else -Inf
    ans_high <- if (length(upper) > 0) min(upper) else Inf
    c(paste0(">=", ans_low), paste0("<", ans_high))
  }

  # Build the segment description for a single root-to-leaf path.
  # Assumes path labels of the form "var< value", "var>=value" or
  # "var=level1,level2" as produced by path.rpart().
  describe_path <- function(path) {
    conditions <- path[-1] # remove root

    # Split each label at its FIRST operator character so that factor levels
    # which themselves contain "<", ">" or "=" (e.g. "<18") stay intact.
    op_pos <- as.integer(regexpr("[<>=]", conditions))
    var <- substring(conditions, 1, op_pos - 1)
    eqn <- substring(conditions, op_pos, op_pos)
    val <- substring(conditions, op_pos + 1)

    # Inequalities may carry a second operator character (">="); drop it.
    is_inequality <- eqn != "="
    val[is_inequality] <- sub("^=", "", val[is_inequality])
    val[val == ""] <- "empty_string"

    parsed <- data.frame(var, val, eqn, stringsAsFactors = FALSE)
    per_variable <- lapply(split(x = parsed, f = parsed$var), reduce_conditions)

    pieces <- vapply(seq_along(per_variable), function(i) {
      paste0(names(per_variable)[i], per_variable[[i]], collapse = ",")
    }, character(1))
    paste0(pieces, collapse = ", ")
  }

  leaves <- rpart_fit$frame
  # Node ids are stored as the row names of the frame.
  leaves$node <- as.integer(row.names(leaves))
  row.names(leaves) <- NULL
  leaves <- leaves[leaves$var == "<leaf>", ]

  paths <- path.rpart(rpart_fit, nodes = leaves$node, print.it = FALSE)
  segment_labels <- vapply(paths, describe_path, character(1))

  # Without interval information the bounds are reported as NA.
  if (is.null(leaves[["yval2"]])) {
    leaves$yval2 <- matrix(ncol = 3, nrow = nrow(leaves))
  }

  segments <- data.frame(segment = segment_labels,
                         n = leaves$n,
                         lift = leaves$yval,
                         lift_lower = leaves$yval2[, 2],
                         lift_upper = leaves$yval2[, 3],
                         row.names = NULL)
  segments[order(segments$lift, decreasing = TRUE), ]
}
IyarLin/RCTree documentation built on April 13, 2020, 12:37 a.m.