
# Defines functions: paste.split.labs, tweak.splits, get.lsplit.rsplit, internal.split.labs, split_labs_wrapper

# split.labs.R: functions for generating split.labels

# Generate the split labels for the tree x, optionally pass them through a
# user supplied split.fun, and check the result.
#
# Arguments:
#   x               the tree object, handed unchanged to the helpers below
#   split.fun       NULL, or a function(x, labs, digits, varlen, faclen)
#                   that returns the final labels
#   split.fun.name  handed to check.returned.labs
#   eq              the "equals" string; also used as logical.eq (see below)
#   trace           accepted for interface compatibility; not used here
#   (all other arguments are forwarded unchanged to internal.split.labs)
#
# Returns the labels (after split.fun, if one was supplied).

split_labs_wrapper <- function(x, split.fun, split.fun.name,
                               split.prefix, split.suffix,
                               right.split.prefix, right.split.suffix,
                               type, clip.facs,
                               clip.left.labs, clip.right.labs, xflip,
                               digits, varlen, faclen, roundint, trace,
                               facsep, eq, lt, ge)
{
    # remember the caller's eq: it is still needed for logical and 01
    # predictors even when eq itself is replaced by the flag below
    logical.eq <- eq
    # NOTE(review): this guard was missing from the file as received and has
    # been restored from the indentation of the assignment it controls;
    # internal.split.labs looks for the "|" flag only when clip.facs is set
    if(clip.facs)
        eq <- "|" # special value used as a flag

    labs <- internal.split.labs(x, type,
                                digits, varlen, faclen, roundint,
                                clip.facs, clip.left.labs, clip.right.labs, xflip,
                                facsep, eq, logical.eq, lt, ge,
                                split.prefix, right.split.prefix,
                                split.suffix, right.split.suffix)

    if(!is.null(split.fun)) { # call user's split.fun?
        check.func.args(split.fun, "split.fun",
                        function(x, labs, digits, varlen, faclen) NA)
        labs <- split.fun(x, labs, abs(digits), varlen, faclen)
    }
    # check returned labs because split.fun may be user supplied
    check.returned.labs(x, labs, split.fun.name)
    labs
}

# Modified version of labels.rpart.
# This uses format0 instead of formatg and has various other extensions.
# Build the vector of split labels for the tree x.
#
# The left and right halves of each label are generated by
# get.lsplit.rsplit and then pasted together with the variable names,
# prefixes and suffixes by paste.split.labs.
#
# eq is either the "equals" string or the flag "|" (used when clip.facs
# is set), logical.eq is always the "equals" string.
#
# Returns "root" for a tree with a single node, else the value of
# paste.split.labs.
internal.split.labs <- function(x, type,
                                digits, varlen, faclen, roundint,
                                clip.facs, clip.left.labs, clip.right.labs, xflip,
                                facsep, eq, logical.eq, lt, ge,
                                split.prefix, right.split.prefix,
                                split.suffix, right.split.suffix)
{
    frame <- x$frame
    if(nrow(frame) == 1) # special case, no splits?
        return("root")   # NOTE: return
    leaf <- is.leaf(frame)

    # variable names for the (primary) splits, factor levels to character
    split.var.names <- as.character(frame$var[!leaf])
    clip.left.labs  <- recycle(clip.left.labs,  split.var.names)
    clip.right.labs <- recycle(clip.right.labs, split.var.names)

    # isplit is the row index of the primary split in x$splits
    index  <- cumsum(c(1, frame$ncompete + frame$nsurrogate + !leaf))
    isplit <- index[c(!leaf, FALSE)]

    # NOTE(review): trace is not an argument of this function, so the name
    # below resolves to whatever "trace" is visible in an enclosing
    # environment -- confirm that is what get.lsplit.rsplit should receive
    split <- get.lsplit.rsplit(x, isplit, split.var.names,
                               type, clip.left.labs, clip.right.labs, xflip,
                               digits, faclen, roundint, trace,
                               facsep, eq, logical.eq, lt, ge)

    # We now have something like this:
    #    split.var.names:   sex   age     pclass     sibsp   pclass
    #    split$lsplit:      mal   >=9.5   =2nd,3rd   >=2.5   =3rd
    #    split$rsplit:      fml   <9.5    =1st       <2.5    =1st,2nd

    if(clip.facs) {
        is.eq.l <- substr(split$lsplit, 1, 1) == "|" # special value used as a flag
        is.eq.r <- substr(split$rsplit, 1, 1) == "|"
        split.var.names[is.eq.l | is.eq.r] <- ""                     # drop var name
        split$lsplit[is.eq.l] <- substring(split$lsplit[is.eq.l], 2) # drop leading |
        split$rsplit[is.eq.r] <- substring(split$rsplit[is.eq.r], 2)
    }
    # NOTE(review): the head of this call was missing from the file as
    # received; it has been restored to match the parameter list of
    # paste.split.labs
    paste.split.labs(frame,
                     split.var.names, split$lsplit, split$rsplit,
                     type, clip.facs,
                     clip.left.labs, clip.right.labs, xflip, varlen,
                     split.prefix, right.split.prefix,
                     split.suffix, right.split.suffix)
}

# Generate the left and right halves of the split labels (without the
# variable names), one element per primary split.
#
# isplit is the row index of each primary split in the splits matrix.
# For continuous splits the halves are lt or ge pasted onto the formatted
# cut value.  For categorical splits they are the (abbreviated) factor
# levels going left and right, joined by facsep and optionally prefixed
# by eq.
#
# Returns list(lsplit, rsplit), two character vectors of length(isplit).
get.lsplit.rsplit <- function(x, isplit, split.var.names,
                              type, clip.left.labs, clip.right.labs, xflip,
                              digits, faclen, roundint, trace,
                              facsep, eq, logical.eq, lt, ge)
{
    splits <- tweak.splits(x, roundint, digits, trace)
    ncat   <- splits[isplit, "ncat"]
    lsplit <- rsplit <- character(length(isplit))

    is.con <- ncat <= 1             # continuous vars (a logical vector)
    if(any(is.con)) {               # any continuous vars?
        cut <- splits[isplit[is.con], "index"]
        formatted.cut <- format0(cut, digits)
        is.less.than <- ncat < 0
        lsplit[is.con] <- paste0(ifelse(is.less.than, lt, ge)[is.con], formatted.cut)
        rsplit[is.con] <- paste0(ifelse(is.less.than, ge, lt)[is.con], formatted.cut)
        # print logical and 01 predictors as "Survived = 1" or "Survived = 0"
        islogical <- x$varinfo$islogical[isplit]
        is01      <- x$varinfo$is01[isplit]
        if(!anyNA(islogical) && !anyNA(is01) && (any(islogical) || any(is01))) {
            isbool <- islogical | (roundint & is01)
            eq0 <- paste0(logical.eq, "0")
            eq1 <- paste0(logical.eq, "1")
            lsplit[isbool] <- paste0(ifelse(is.less.than, eq0, eq1)[isbool])
            rsplit[isbool] <- paste0(ifelse(is.less.than, eq1, eq0)[isbool])
        }
    }
    is.cat <- ncat > 1              # categorical variables (a logical vector)
    if(any(is.cat)) {               # any categorical variables?
        # jrow is the row numbers of factors within lsplit and rsplit
        # cindex is the index on the "xlevels" list
        jrow <- seq_along(ncat)[is.cat]
        crow <- splits[isplit[is.cat], "index"] # row number in csplit
        xlevels <- attr(x, "xlevels")
        cindex <- match(split.var.names, names(xlevels))[is.cat]
        # decide if we must add a "=" prefix
        paste.left.eq  <- !is.fancy(type) | (if(xflip) !clip.right.labs else !clip.left.labs)
        paste.right.eq <- !is.fancy(type) | (if(xflip) !clip.left.labs  else !clip.right.labs)
        collapse <- if(faclen==1) "" else facsep # loop invariant
        for(i in seq_along(jrow)) {
            node.xlevels <- my.abbreviate(xlevels[[cindex[i]]],
                                          faclen, one.is.special=TRUE)
            j <- jrow[i]
            # direction of each level: 1=left 2=neither 3=right
            # (named so it does not mask the splits matrix above)
            csplit.row <- x$csplit[crow[i],]
            left  <- which(csplit.row == 1)
            right <- which(csplit.row == 3)
            lsplit[j] <- paste0(node.xlevels[left],  collapse=collapse)
            rsplit[j] <- paste0(node.xlevels[right], collapse=collapse)
            # NOTE(review): the two guards below were missing from the file
            # as received; they are restored from the indentation of the
            # assignments and from paste.left.eq and paste.right.eq, which
            # are computed above and would otherwise be unused
            if(paste.left.eq[j])
                lsplit[j] <- paste0(eq, lsplit[j])
            if(paste.right.eq[j])
                rsplit[j] <- paste0(eq, rsplit[j])
        }
    }
    list(lsplit=lsplit, rsplit=rsplit)
}

# If roundint is TRUE then round up the entries in splits$index for
# variables that are integral. (All values in the input data for the
# variable must be integral, not just the entry in splits$index.)
# Add verysmall to all other cuts where verysmall is something like 1e-10.
# This is so format(cut, digits=digits) always rounds up (rather than the
# default behaviour of format which is to round to even when the last digit
# is even).
# This gives more easily interpretable results, especially because split
# cuts ending in .5 are common for integer predictors.
# Thus "nbr.family.members < 2.5" is now rounded to "nbr.family.members < 3"
# (rather than "nbr.family.members < 2", which was misleading).
# e.g. format(1234.5,       digits=2) is 1234   rounds down (not what we want)
#      format(1235.5,       digits=2) is 1236   rounds up
#      format(1234.5+1e-10, digits=2) is 1235   rounds up
#      format(1235.5+1e-10, digits=2) is 1236   rounds up
# Note that if roundint is FALSE then we still bump all entries by verysmall.

# Returns obj$splits with the "index" column adjusted: every cut is bumped
# by a very small amount, and if roundint is TRUE the cuts of the variables
# flagged in obj$varinfo$isroundint are instead rounded up to an integer.
# The digits and trace arguments are not used in the body.
#
# NOTE(review): the bodies of the three early-exit guards and the final
# return value were missing from the file as received.  They are restored
# here as "return the splits as they currently stand" -- confirm.
tweak.splits <- function(obj, roundint, digits, trace)
{
    splits <- obj$splits
    if(is.null(splits))               # semtree object (from package semtree)
        return(splits)
    if(!is.numeric(splits[,"index"])) # paranoia
        return(splits)

    # verysmall is 1e-10 for cuts with magnitude 1 or more, and
    # proportionally smaller for cuts with magnitude less than 1
    verysmall <- floor(log10(abs(splits[,"index"])))
    verysmall[verysmall > 0] <- 0
    verysmall <- exp10(-abs(verysmall) - 10)

    splits[,"index"] <- splits[,"index"] + verysmall

    # Because rpart.model.frame() may fail or give warnings, we do
    # processing here only if the roundint argument is TRUE.
    if(roundint) {
        isroundint <- obj$varinfo$isroundint # vector of bools
        if(anyNA(isroundint)) # couldn't get the isroundint vector
            return(splits)
        for(varname in names(isroundint)) if(isroundint[varname]) {
            i <- rownames(splits) %in% varname
            if(any(i)) # undo the bump before rounding up
                splits[,"index"][i] <- ceiling(splits[,"index"][i] - verysmall[i])
        }
    }
    splits
}

# Paste the various constituents to create the split labels vector.
# On entry we have something like this:
#    split.var.names sex   age     pclass     sibsp   pclass
#    lsplit:         mal   >=9.5   =2nd,3rd   >=2.5   =3rd
#    rsplit:         fml   <9.5    =1st       <2.5    =1st,2nd

# Returns a character vector of labels with one element per row of frame.
# Each node gets the label of the branch leading to it from its parent:
# the left half of the parent's split for even node numbers, the right
# half for odd node numbers.  The first element is always "root".
#
# NOTE(review): several guard lines, the tail of the right-hand paste0
# call and the final return value were missing from the file as received.
# They are restored below from the indentation of the surviving lines and
# from the way get.lsplit.rsplit treats xflip -- confirm.
paste.split.labs <- function(frame, split.var.names, lsplit, rsplit,
                             type, clip.facs,
                             clip.left.labs, clip.right.labs, xflip, varlen,
                             split.prefix, right.split.prefix,
                             split.suffix, right.split.suffix)
{
    nodes <- as.numeric(row.names(frame))
    is.right <- nodes %% 2 == 1
    leaf <- is.leaf(frame)
    # index of each node's parent among the splits (the non-leaf nodes)
    parent <- match(nodes %/% 2, nodes[!leaf])
    split.var.names <- my.abbreviate(split.var.names, varlen)
    left.names <- right.names <- split.var.names

    if(is.fancy(type)) {
        # drop the variable name on the clipped side;
        # xflip swaps the two sides
        if(xflip)
            right.names[clip.left.labs] <- ""
        else
            left.names[clip.left.labs] <- ""
        if(xflip)
            left.names[clip.right.labs] <- ""
        else
            right.names[clip.right.labs] <- ""
    }
    # NOTE(review): presumably an unspecified (NULL) right prefix or suffix
    # means "same as the left one" -- confirm against the caller
    if(is.null(right.split.prefix))
        right.split.prefix <- split.prefix
    if(is.null(right.split.suffix))
        right.split.suffix <- split.suffix

    # the heart of this function

    labs  <- paste0(split.prefix, left.names[parent], lsplit[parent], split.suffix)
    labs[is.right] <- paste0(right.split.prefix,
                             right.names[parent],
                             rsplit[parent],
                             right.split.suffix)[is.right]
    labs[1] <- "root" # was "NANA" because of above paste0
    labs
}

# Try the rpart.plot package in your browser
#
# Any scripts or data that you put into this service are public.
#
# rpart.plot documentation built on May 29, 2024, 12:07 p.m.