R/rpart.rules.R

Defines functions check.if.dot.arg.supported.by.rpart.rules stop.if.dot.arg.used print.style.tall order.cols order.rows bar.to.space parse.node.lab parse.split.lab var.is.in.rule trim.leading.space.in.columns fit.colname format.rules apply.varlen.to.colnames as.numeric.na.ok trim.surrounding.space format.fit process.rules trace.print.rules get.rule descendants handle.extra.for.rules get.node.and.split.labs get.raw.rules print.rpart.rules rpart.rules

Documented in print.rpart.rules rpart.rules

# rpart.rules.R:

# Values accepted for the 'style' argument of rpart.rules and
# print.rpart.rules (both check it with match.choices).
allowed.styles <- c("wide", "tall", "tallw")

# Return the rules of an rpart tree: a data.frame of strings with one row
# per leaf, classed c("rpart.rules", "data.frame"), carrying attributes
# "style", "eq", "and", and "when" that print.rpart.rules uses.
#
# x:         an rpart object
# style:     "wide", "tall", or "tallw" (see allowed.styles)
# cover:     if TRUE, append a column with the percentage of observations
#            in each leaf
# nn:        if TRUE, prepend a column with the leaf node numbers
# roundint:  passed to get.modelframe.info and get.raw.rules
# clip.facs: if TRUE, print factor/logical conditions as just the level
#            (or "not varname") instead of "varname <eq> level"
# varorder:  variable names (partially matched) to put first in the columns
# ...:       further args (digits, varlen, faclen, eq, lt, ge, and, when, ...)
#            checked by check.if.dot.arg.supported.by.rpart.rules
#
# With the hidden argument rpart.predict=TRUE, returns instead a character
# vector of printed rule strings (used by rpart.predict).
rpart.rules <- function(x=stop("no 'x' argument"),
                        style="wide", cover=FALSE, nn=FALSE,
                        roundint=TRUE, clip.facs=FALSE,
                        varorder=NULL, ...)
{
    if(!inherits(x, "rpart"))
        stop("Not an rpart object")

    # validate the dots args and get their values (including hidden args)
    ret <- check.if.dot.arg.supported.by.rpart.rules(...)
        extra         <- ret$extra
        digits        <- ret$digits
        varlen        <- ret$varlen
        faclen        <- ret$faclen
        trace         <- ret$trace
        facsep        <- ret$facsep
        eq            <- ret$eq
        lt            <- ret$lt
        ge            <- ret$ge
        and           <- ret$and
        when          <- ret$when
        because       <- ret$because
        null.model    <- ret$null.model
        response.name <- ret$response.name
        rpart.predict <- ret$rpart.predict # hidden arguments for rpart.predict
        where         <- ret$where

    obj <- x
    style <- match.choices(style, allowed.styles)
    cover <- check.boolean(cover)
    nn <- check.boolean(nn)
    roundint <- check.boolean(roundint)
    clip.facs <- check.boolean(clip.facs)
    rpart.predict <- check.boolean(rpart.predict)
    digits <- process.digits.arg(digits)
    if(digits < 0) # non negative because we use standard data.frame formatting
        digits <- -digits
    varlen <- check.integer.scalar(varlen, logical.ok=FALSE)
    faclen <- check.integer.scalar(faclen, logical.ok=FALSE)
    # NOTE(review): trace is used here but only validated further below
    obj$varinfo <- get.modelframe.info(obj, roundint, trace,
                                       parent.frame(), "rpart.rules")
    if(is.null(response.name)) # not explicitly specified by the user?
        response.name <- obj$varinfo$response.name
    stopifnot.string(response.name)
    stopifnot.string(facsep)
    # we trim spaces around these because print.data.frame
    # unavoidably adds spaces between columns
    stopifnot.string(eq, allow.empty=TRUE);
    eq <- trim.surrounding.space(eq)
    stopifnot.string(lt, allow.empty=TRUE);
    lt <- trim.surrounding.space(lt)
    stopifnot.string(ge, allow.empty=TRUE);
    ge <- trim.surrounding.space(ge)
    stopifnot.string(and, allow.empty=TRUE);
    and <- trim.surrounding.space(and)
    stopifnot.string(when, allow.empty=TRUE);
    when <- trim.surrounding.space(when)
    # The placeholder keeps the "when" column from being dropped as an
    # all-empty column in process.rules; print.style.tall prints it as "".
    if(when == "" && (style %in% c("tall", "tallw") || rpart.predict))
        when <- ":EMPTY:"
    else if(nrow(obj$frame) == 1)   # null model? (no rules)
        when <- null.model          # hack: the "when" column shows null.model
    stopifnot.string(because, allow.empty=TRUE);
    because <- trim.surrounding.space(because)
    stopifnot.string(null.model)
    trace <- as.numeric(check.numeric.scalar(trace, logical.ok=TRUE))
    # we get the variable names from the splits because
    # attr(x$terms,"dataClasses") sometimes doesn't save the actual
    # variable names (e.g. model fit9 in slowtests/rpart.report.R)
    # (the inner "varnames <-" is redundant: the if's value is assigned anyway)
    varnames <-
        if(nrow(obj$frame) == 1)        # null model? (no rules)
            varnames <- ":NULL.MODEL:"
        else
            unique(rownames(obj$splits))

    # one row per leaf, three columns ("=", "<", ">=") per variable
    ret <- get.raw.rules(obj, extra, varlen, faclen, roundint, trace,
                         facsep, varnames)
        rules          <- ret$rules
        nrules.per.var <- ret$nrules.per.var

    if(trace >= 1)
        trace.print.rules(rules, "raw rules")

    rules <- process.rules(obj, rules, style, cover, nn, clip.facs,
                rpart.predict, where,
                eq, lt, ge, and, when, because, null.model,
                digits, trace, varorder, varlen,
                nrules.per.var, varnames,
                response.name,
                obj$method == "class" || is.class.response(obj),
                attr(obj, "ylevels"))

    # NOTE(review): node.numbers is not used below
    node.numbers <- rownames(rules)

    if(trace >= 1)
        trace.print.rules(rules, "processed rules")

    # attributes used by print.rpart.rules
    class(rules) <- c("rpart.rules", "data.frame")
    attr(rules, "style") <- style
    attr(rules, "eq")    <- eq
    attr(rules, "and")   <- and
    attr(rules, "when")  <- when

    if(rpart.predict) {
        # return a vector of strings, one string for each element of where
        # TODO sometimes still too much whitespace e.g. example(rpart.predict)
        # TODO must also trim trailing space
        rules <- capture.output(print.rpart.rules(rules))
        rules <- rules[-1] # drop data.frame column names
        # trim leading single space added by print.data.frame in print.rpart.rules
        rules <- gsub("^ ", "", rules)
    }
    rules
}
# Print method for the object returned by rpart.rules.
#
# x:     an "rpart.rules" object
# style: "wide" prints x as a plain data.frame without row names;
#        "tall" and "tallw" print each rule over several lines via
#        print.style.tall.  Defaults to the style used when x was created.
# ...:   not used; passing anything is an error.  "digits" (or a prefix of
#        it) gets a specific message, because digits must be given to
#        rpart.rules instead.
#
# Returns x invisibly, as is conventional for print methods.
print.rpart.rules <- function(x=stop("no 'x' argument"),
                              style=attr(x, "style"), ...)
{
    # Some hand holding for an easy error: specifying digits here instead of
    # in rpart.rules.  Any dots arg whose name is a prefix of "digits" of two
    # or more characters (di, dig, digi, digit, digits) is caught.  Checking
    # the names directly avoids `$` partial matching, and with it the need
    # to temporarily change options(warnPartialMatchDollar).
    dot.names <- names(match.call(expand.dots=FALSE)$...)
    if(length(dot.names) > 0 &&
       any(nchar(dot.names) >= 2 & startsWith("digits", dot.names)))
        stop0("specify 'digits' in rpart.rules (not in print.rpart.rules)")

    stop.if.dot.arg.used(...)
    style <- match.choices(style, allowed.styles)
    old.width <- options(width=1e4)$width # very wide, so rules aren't wrapped
    on.exit(options(width=old.width), add=TRUE)
    if(style == "wide") {
        df <- x
        class(df) <- "data.frame" # use the standard data.frame print method
        print(df, row.names=FALSE)
    } else if(style == "tall" || style == "tallw")
        print.style.tall(x, style, eq=attr(x, "eq"),
                         and=attr(x, "and"), when=attr(x, "when"))
    else
        stop0("illegal style ", style)
    invisible(x)
}
# The raw rules are a data.frame like:
#   lab fit iclass cover Girth= Girth< Girth>= Height= Height< Height>=
#        18     18    48           12
#        31     31    29           16      12
#        56     56    23                   16
#
# Or like this for a multiclass model (lab is filled in, and the fit column holds several fits):
#   lab         fit iclass cover   sex= sex< sex>= age= age< age>=
#   1st .74 .16 .10      1    11 female                         35
#   1st .54 .26 .20      1     7   male                   46
#   3rd .21 .25 .54      3     9   male                   46    35
#   3rd .15 .21 .64      3    73                                35
#
# The raw data.frame has full variable names (varlen is only applied later)
#
# fit is the fit displayed in the leaf (could be a prob or vector of probs)
#
# lab is used only for multiclass models
#
# iclass is used only to sort rows on the fitted class for multiclass models
#   for anova models, iclass is floor(fit)
#   for class models, iclass is the fitted class as an integer

# Build the raw rules data.frame: one row per leaf, columns lab, fit,
# iclass, cover, then three columns per variable ("Girth=", "Girth<",
# "Girth>="), holding unformatted strings with full variable names.
# Also count, for each variable, the number of rules it appears in.
# Returns list(rules, nrules.per.var).  The rownames of rules are the
# leaf node numbers.
get.raw.rules <- function(obj, extra, varlen, faclen, roundint, trace,
                          facsep, varnames)
{
    ret <- get.node.and.split.labs(obj, extra, faclen, roundint, trace,
                                  facsep, under.percent=2)
        node.labs  <- ret$node.labs
        split.labs <- ret$split.labs

    frame <- obj$frame
    leaf <- is.leaf(frame)
    node.numbers <- as.numeric(row.names(frame))

    # map frame row to splits row (matrix: column is iframe, row is isplit)
    iframe.to.isplit.mat <- descendants(node.numbers)

    # One rule per leaf, so size the data.frame exactly; there is no
    # arbitrary cap on the number of rules.
    ileaf  <- which(leaf) # frame rows of the leaves, in frame order
    nrules <- length(ileaf)

    # first 4 columns are always lab, fit, iclass, cover; then 3 cols per variable
    # column names are c("lab", "fit", "iclass", "cover",
    #                    "Girth=", "Girth<", "Girth>=", "Height=", "Height<", "Height>=")
    rules <- matrix("", nrow=nrules, ncol=4 + 3 * length(varnames))
    colnames(rules) <- c("lab", "fit", "iclass", "cover",
                          paste0(rep(varnames, each=3), c("=", "<", ">=")))
    rules <- as.data.frame(rules, stringsAsFactors=FALSE)

    # number of rules each variable is in, for sorting columns on var importance
    nrules.per.var <- repl(0, length(varnames))
    names(nrules.per.var) <- varnames

    trace1(trace, "\n")
    for(irule in seq_len(nrules)) {
        ret <- get.rule(obj, rules[irule,], nrules.per.var,
                        ileaf[irule], node.numbers, node.labs, split.labs,
                        iframe.to.isplit.mat, trace)
        rules[irule,] <- ret$rule
        nrules.per.var <- ret$nrules.per.var
    }
    trace1(trace, "\n")
    # rownames are node numbers (we use them only if nn=TRUE)
    rownames(rules) <- rownames(obj$frame)[leaf]
    list(rules=rules, nrules.per.var=nrules.per.var)
}
# Get the node labels and split labels for every row of obj$frame, in the
# raw form that parse.node.lab and parse.split.lab expect: no rounding
# (digits are applied later) and full variable names.
# Returns list(node.labs, split.labs).
get.node.and.split.labs <- function(obj, extra, faclen, roundint, trace,
                                    facsep, under.percent)
{
    # class statistics are needed only for classification models
    is.class <- obj$method == "class" || is.class.response(obj)
    class.stats <- if(is.class) get.class.stats(obj) else NULL
    extra <- handle.extra.for.rules(extra, obj, class.stats)

    # node labels, e.g. "died\n0.17\n61%" (parsed later by parse.node.lab)
    node.labs <- internal.node.labs(obj,
                     node.fun=NULL, node.fun.name="NULL",
                     type=TYPE0.default, extra=extra,
                     under=FALSE, xsep=NULL,
                     digits=-10, # digits are applied later
                     varlen=0,   # full variable names in the raw data.frame
                     prefix="", suffix="", class.stats, under.percent)

    # split labels, e.g. "Girth|<|16.15" (parsed later by parse.split.lab);
    # "|" rather than space delimits the operator so varnames may have spaces
    split.labs <- internal.split.labs(obj, type=TYPE4.fancy.all,
                     digits=-10, # digits are applied later
                     varlen=0,   # full variable names in the raw data.frame
                     faclen=faclen, roundint=roundint,
                     clip.facs=FALSE, # factor names are not stripped here
                     clip.left.labs=FALSE, clip.right.labs=FALSE, xflip=FALSE,
                     trace=trace, facsep=facsep,
                     eq="|=|", logical.eq="|=|", lt="|<|", ge="|>=|",
                     split.prefix="", right.split.suffix="",
                     split.suffix="", right.split.prefix="")

    list(node.labs=node.labs, split.labs=split.labs)
}
# Convert the user's 'extra' argument to the value used for the node labels
# of rpart.rules.  "auto" and 0 become the model's default extra; extras
# that rpart.rules can't use are replaced (with a warning for 1, 2, 3).
# The result always has 100 added, which adds the percentage of
# observations (the cover) to the node labels.
handle.extra.for.rules <- function(extra, obj, class.stats)
{
    if(!is.numeric(extra)) {
        if(!is.auto(extra, n=1))
            stop0("rpart.rules: illegal extra")
        extra <- get.default.extra(obj, class.stats)
    } else {
        stopifnot(length(extra) == 1)
        if(extra >= 100) # work with the base value, 100 is re-added below
            extra <- extra - 100
        unsupported <- c(EX1.NOBS, EX2.CLASS.RATE, EX3.MISCLASS.RATE)
        dont.label  <- c(EX5.PROB.PER.CLASS.DONT,
                         EX7.PROB.2ND.CLASS.DONT,
                         EX11.PROB.ACROSS.ALL.2ND.CLASS.DONT)
        if(extra == EX0) {
            extra <- get.default.extra(obj, class.stats) - 100
        } else if(extra %in% unsupported) {
            warning0(
"extra=", extra, " is not supported by rpart.rules (although useable for plots)")
            extra <- get.default.extra(obj, class.stats) - 100
        } else if(extra %in% dont.label) {
            # use the variant that keeps the class label in the node labels,
            # which parse.node.lab relies on
            extra <- extra - 1
        }
    }
    if(obj$method == "poisson" || obj$method == "exp")
        extra <- 0
    if(extra < 100)
        extra <- extra + 100
    extra
}
# descendants is lifted verbatim from rpart version 4.1-13 dated 2018-02-23
descendants <- function(nodes, include = TRUE)
{
    # Returns an n x n logical matrix desc (n = length(nodes)) where
    # desc[i, j] is TRUE when nodes[j] is a descendant of nodes[i], and
    # also when i == j if include is TRUE.  Node numbers follow the rpart
    # numbering in which the parent of node k is k %/% 2.
    # Assumes nodes[1] is the root (row 1 is marked as ancestor of all).
    # A column j therefore flags the ancestors of nodes[j]; get.rule in this
    # file uses a column to get the path from the root to a leaf.
    n <- length(nodes)
    if (n == 1L) return(matrix(TRUE, 1L, 1L))
    ind <- 1:n
    desc <- matrix(FALSE, n, n)
    if (include) diag(desc) <- TRUE
    parents <- match((nodes %/% 2L), nodes) # index of each node's parent
    lev <- floor(log(nodes, base = 2))      # depth of each node (root is 0)
    desc[1L, 2L:n] <- TRUE                  # the root is everyone's ancestor
    # Walk up one generation per iteration: for each node j at "level" i,
    # mark parents[j] as an ancestor of j, then point parents[j] at its own
    # parent and decrement lev[j].  Depth-0 ancestors are handled above, so
    # stop at level 2.  (If max(lev) is 1, max(lev):2L counts up 1, 2; that
    # only re-marks row 1 and does not change the result.)
    for (i in max(lev):2L) {
        desc[cbind(ind[parents[lev == i]], ind[lev == i])] <- TRUE
        parents[lev == i] <- parents[parents[lev == i]]
        lev[lev == i] <- i - 1L
    }
    desc
}
# Fill in one row of the raw rules for the leaf at frame row iframe.
# rule is a one-row data.frame (the row to fill in); its lab, fit, cover,
# and iclass columns are set, and for each split on the path from the root
# to the leaf the column "varname op" gets the split's cut.
# nrules.per.var is incremented once for each variable used in this rule.
# Returns list(rule, nrules.per.var).
get.rule <- function(obj, rule, nrules.per.var,
                     iframe, node.numbers, node.labs, split.labs,
                     iframe.to.isplit.mat, trace)
{
    leaf <- parse.node.lab(node.labs[iframe]) # iframe indexes a leaf
    rule$lab    <- leaf$lab
    rule$fit    <- leaf$fit
    rule$cover  <- leaf$cover
    rule$iclass <- floor(obj$frame[iframe, "yval"]) # fitted class for class models

    # Split labels from the root down to the leaf, in that order.  The order
    # matters: a later split on the same variable and operator overwrites an
    # earlier one, so with path=c("root", "Girth< 16.15", "Girth< 12.45")
    # the later cut 12.45 is the one kept.
    path <- split.labs[iframe.to.isplit.mat[, iframe]]
    trace1(trace, "iframe %3d node %3d path %s\n",
           iframe, node.numbers[iframe], bar.to.space(path))
    stopifnot(path[1] == "root")
    for(lab in path[-1]) { # skip "root"
        parsed <- parse.split.lab(lab, trace)
        if(!var.is.in.rule(parsed$varname, rule)) # count each var once per rule
            nrules.per.var[parsed$varname] <- nrules.per.var[parsed$varname] + 1
        rule[paste0(parsed$varname, parsed$op)] <- parsed$cut
    }
    list(rule=rule, nrules.per.var=nrules.per.var)
}
# Debugging aid (only used if trace > 0): print msg and the rules as a
# plain data.frame, using a very wide line so nothing is wrapped.
trace.print.rules <- function(rules, msg)
{
    saved.width <- options(width=1e4)$width
    on.exit(options(width=saved.width))
    cat0(msg, ":\n")
    print(structure(rules, class="data.frame"))
    cat0("\n")
}
# Turn the raw rules from get.raw.rules into the final rules data.frame:
# format the fitted values, sort rows (order.rows) and columns (order.cols),
# abbreviate variable names if varlen != 0, build the printable columns
# (format.rules), drop unused columns, and optionally add cover and
# node-number columns.
# If rpart.predict is TRUE, the result instead has one row per element of
# 'where' (frame row indices); unless the model is a null model, the
# columns before "when" are dropped and "when" is replaced by 'because'
# (when there are columns after it).
# Returns a data.frame of strings.
process.rules <- function(obj, rules, style, cover, nn, clip.facs,
                          rpart.predict, where,
                          eq, lt, ge, and, when, because, null.model,
                          digits, trace, varorder, varlen,
                          nrules.per.var, varnames, response.name,
                          is.class.response, ylevels)
{
    ret <- format.fit(rules$fit, digits, is.class.response)
        rules$fit <- ret$fit
        rowmaxs   <- ret$rowmaxs
        ncol.fit  <- ret$ncol.fit
    fit <- rules$fit # needed for fit.colname()
    rules <- order.rows(rules, rowmaxs)
    ret <- order.cols(rules, varorder, varnames, nrules.per.var)
        rules    <- ret$rules
        varnames <- ret$varnames
    if(varlen != 0) { # note that we do this only after using varnames to sort cols above
        ret <- apply.varlen.to.colnames(rules, varnames, varlen)
            colnames(rules) <- ret$colnames
            varnames        <- ret$shortnames
    }
    rules.cover <- rules$cover # save before format.rules drops the cover column
    rules <- format.rules(rules, style, cover, clip.facs, eq, lt, ge, and, when,
                          digits, trace,
                          response.name, varnames, ncol.fit)
    # retain only used columns
    rules <- rules[, apply(rules, 2, function(col) any(col != "")), drop=FALSE]
    if(!rpart.predict) {
        # all columns unnamed, except first column (response.name)
        colnames(rules) <- c(response.name, repl("", ncol(rules)-1))
    } else { # rpart.predict
        if(nrow(obj$frame) == 1) # null model? (no rules)
            rules <- as.data.frame(matrix(paste0(because, paste0(" ", null.model)),
                                   nrow=nrow(rules)), stringsAsFactors=FALSE)
        else {
            # drop all columns before column "when"
            iwhen <- match("when", colnames(rules))[1] # index of "when" column
            if(iwhen < ncol(rules)) {
                rules <- rules[, iwhen:ncol(rules), drop=FALSE]
                rules[1] <- because # replace "when" with "because"
            }
        colnames(rules) <- NULL # all columns unnamed
        }
    }
    if(cover) {
        # append cover column (space in "  cover" to shift printed cover column right)
        rules$cover <- sprint("%3.0f%%", as.numeric(rules.cover))
        colnames(rules)[ncol(rules)] <- "  cover"
    }
    if(nn) {
        colnames <- colnames(rules)
        rules <- cbind(rownames(rules), rules, stringsAsFactors=FALSE)
        colnames(rules) <- c("nn", colnames)
    }
    if(rpart.predict) {
        # must generate the rule for each element of where
        if(style != "wide")
            stop0("style = \"", style, "\" is not supported by rpart.predict")
        check.vec(where, "where", na.ok=TRUE)
        # note: nn is reused here for node numbers (the boolean is no longer needed)
        nn <- as.numeric(rownames(obj$frame)[where])
        stopifnot(!any(is.na(nn)))
        # rules indexed by node number
        rules.nn <- rules
        rules.nn[1:max(nn), ] <- rules[1, , drop=FALSE]
        for(name in rownames(rules))
            rules.nn[as.numeric(name),] <-
                rules[which(rownames(rules) == name), , drop=FALSE]
        rules <- rules.nn[nn, , drop=FALSE]
        rownames(rules) <- NULL
        # retain only used columns
        rules <- rules[, apply(rules, 2, function(col) any(col != "")), drop=FALSE]
    } else if(ncol.fit > 1) { # multiple probabilities
        # add column names etc. for a nice print
        colnames(rules)[2+nn] <- fit.colname(ylevels, fit, ncol.fit)
        rules[,2+nn] <- paste0("[", rules[,2+nn], "]")
        rules[,3+nn] <- paste0(" ", when)
    }
    # The following trim is necessary because format.rules calls format() on
    # the "<" and ">=" values of all rows, but a given column may use only
    # some of those formatted values, so its entries can share leading padding.
    trim.leading.space.in.columns(rules)
}
# Format the fit column of the raw rules for printing.
# fit: character vector, one string per rule holding one or more numbers
#      separated by single spaces, e.g. ".12 .34 .56".
# Returns list(fit=formatted fits, rowmaxs=sort key per rule (for
# order.rows), ncol.fit=number of numbers per rule).
format.fit <- function(fit, digits, is.class.response)
{
    tokens <- strsplit(fit, " ", fixed=TRUE)
    fitmat <- matrix(as.numeric.na.ok(unlist(tokens)),
                     nrow=length(tokens), byrow=TRUE)
    ncol.fit <- ncol(fitmat)

    # sort key: with two columns (binomial model showing both probabilities)
    # use the probability of the 2nd class, otherwise the row's largest value
    rowmaxs <- if(ncol.fit == 2) fitmat[,2] else apply(fitmat, 1, max)

    if(ncol.fit == 1) {
        # a single number per rule, right-aligned for printing
        if(is.class.response) # a probability: always 2 decimal places
            fit <- format(sprint("%.2f", fitmat), justify="right")
        else                  # anova, poisson, or exp model
            fit <- format(fitmat, digits=digits, justify="right")
    } else {
        # Several probabilities per rule: always 2 decimal places (the digits
        # argument is ignored) and no leading zero (".12" not "0.12").  If
        # any probability is 1.00 (or more), each number is 4 chars wide and
        # the dropped zero becomes a space, keeping the columns aligned.
        has.one <- max(rowmaxs, na.rm=TRUE) >= 1
        cells <- sprint(if(has.one) "%4.2f" else "%3.2f", fitmat)
        cells <- matrix(cells, nrow=nrow(fitmat))
        fit <- apply(cells, 1, paste, collapse=" ")
        fit <- gsub("0.", if(has.one) " ." else ".", fit, fixed=TRUE)
    }
    list(fit=fit, rowmaxs=rowmaxs, ncol.fit=ncol.fit)
}
trim.surrounding.space <- function(s)
{
    # Remove at most ONE leading and at most ONE trailing space from each
    # string.  (Deliberately not trimws, which removes all surrounding
    # whitespace.)
    without.lead <- sub("^ ", "", s)
    sub(" $", "", without.lead)
}
# Like as.numeric, but elements that aren't numbers (e.g. "" or "NA")
# silently become NA.
# suppressWarnings muffles the coercion warning where it is raised.  The
# previous approach, temporarily setting options(warn=-1), still let the
# warning reach any calling handlers established by callers, so an
# enclosing tryCatch(warning=...) would abort.
as.numeric.na.ok <- function(x)
{
    suppressWarnings(as.numeric(x))
}
# Abbreviate the variable names (my.abbreviate with varlen) and apply the
# abbreviations to the per-variable column names of rules.
# Returns list(colnames=new column names, shortnames=abbreviated varnames).
apply.varlen.to.colnames <- function(rules, varnames, varlen)
{
    shortnames <- my.abbreviate(varnames, varlen)
    newnames <- colnames(rules)
    # columns 1:4 are lab, fit, iclass, cover; the columns of variable i
    # (e.g. "Girth=", "Girth<", "Girth>=") are at 3*i + 2, 3*i + 3, 3*i + 4
    for(i in seq_along(varnames)) {
        cols <- 3 * i + 2:4
        newnames[cols] <- sub(varnames[i], shortnames[i], newnames[cols], fixed=TRUE)
    }
    list(colnames=newnames, shortnames=shortnames)
}
# Build the printable rules from the (sorted) raw rules.
# Returns a data.frame with a "fit" column (preceded by a "class" column
# holding the class label for multiclass models), a "when" column, then
# generated columns c1, c2, ... holding the words of each rule, e.g.
# and, "Girth", lt, "12.45" (where and, lt, ge, eq are the strings passed in).
# Columns are created per variable for all rows at once, so a variable
# occupies the same columns in every row; rows not using it get "".
# NOTE(review): style, cover, trace, and response.name are not used here.
format.rules <- function(rules, style, cover, clip.facs, eq, lt, ge, and, when,
                         digits, trace,
                         response.name, varnames, ncol.fit)
{
    # n() returns the next generated column name "c1", "c2", ...
    # It increments icol in the enclosing function via <<- (icol is
    # initialized below, before n() is first called).
    n <- function() # generate a new column name
    {
        icol <<- icol + 1
        sprint("c%d", icol)
    }
    #--- format.rules starts here ---
    # build up the new rules column by column

    new <- if(ncol.fit > 1) # multiple responses, add class label
                data.frame(class=rules$lab, fit=rules$fit, stringsAsFactors=FALSE)
           else
                data.frame(fit=rules$fit, stringsAsFactors=FALSE)

    rownames(new) <- rownames(rules)
    new$when <- when
    icol <- 0 # global var for n() function
    # subsequent is TRUE if split is not the first for this rule
    subsequent <- repl(FALSE, nrow(rules))
    for(i in seq_along(varnames)) {
        # process three columns for the variable ("Girth=", "Girth<", "Girth>=")
        # simultaneously process all rows for each column
        varname <- varnames[i]

        ivar <- 3 * i + 2           # index of "Girth=" column (skip lab,fit,iclass,cover)
        rules.eq <- rules[, ivar]   # "Girth=", vector of strings
        rules.lt <- rules[, ivar+1] # "Girth<"
        rules.ge <- rules[, ivar+2] # "Girth>="

        is.eq    <- rules.eq != ""  # vector of bools
        is.lt    <- rules.lt != ""
        is.ge    <- rules.ge != ""
        lt.or.ge <- is.lt | is.ge

        if(any(is.eq)) { # pclass is 1st or 2nd, boolean = 1, boolean = 0
            new[,n()] <- ifelse(subsequent & is.eq, and, "")
            subsequent <- (subsequent | is.eq)
            # formatting for rules with "=" in them works like this:
            # clip.facs FALSE: sex is female & boolean1 = 1 & boolean2 = 0
            # clip.facs TRUE:  female        & boolean1     & not boolean2
            # the checks against "1" and "0" are for logical variables
            new[,n()] <- ifelse(clip.facs  & is.eq & rules.eq == "0",
                            sprint("not %s", varname),
                         ifelse(clip.facs  & is.eq & rules.eq == "1",
                            sprint("%s    ", varname),
                         ifelse(!clip.facs & is.eq,
                            sprint("%s", varname), "")))
            new[,n()] <- ifelse(is.eq & !clip.facs, eq, "")
            new[,n()] <- ifelse(is.eq & (!clip.facs | (rules.eq != "1" & rules.eq != "0")),
                            rules.eq, "")
        } else if(any(lt.or.ge)) { # Girth < 12.45, Girth >= 16.25, or both
            new[,n()] <- ifelse(subsequent & lt.or.ge, and, "")
            subsequent <- (subsequent | lt.or.ge)
            # add "verysmall" so format rounds .5 upwards, not to even
            verysmall <- exp10(-abs(digits) - 8)
            # format() is applied to ALL rows (unused rows become "NA"), so the
            # used entries may share leading padding; process.rules trims it
            if(any(is.lt)) # format for printing
                rules.lt <- format(as.numeric(rules.lt) + verysmall,
                                   digits=digits, justify="right")
            if(any(is.ge))
                rules.ge <- format(as.numeric(rules.ge) + verysmall,
                                   digits=digits, justify="right")
            new[,n()] <- ifelse(lt.or.ge, varname, "")
            if(any(is.lt & is.ge)) {
                # at least one row has both lt and ge although any particular row
                # could be Girth 12.45 to 16.25, Girth < 12.45, Girth >= 16.25
                new[,n()] <- ifelse(is.lt & !is.ge,  lt,
                             ifelse(is.ge & !is.lt,  ge,
                             ifelse(is.ge | is.lt,   eq, "")))
                new[,n()] <- ifelse(is.lt & is.ge,   rules.ge,
                             ifelse(is.lt, rules.lt, ""))
                new[,n()] <- ifelse(is.lt & is.ge,   "to",  "")
                new[,n()] <- ifelse(is.lt & is.ge,   rules.lt,
                             ifelse(is.ge, rules.ge, ""))
            } else {                # Girth < 12.45, Girth >= 16.25
                # no row has both lt and ge
                new[,n()] <- ifelse(is.lt,           lt,
                             ifelse(is.ge,           ge, ""))
                new[,n()] <- ifelse(is.lt, rules.lt,
                             ifelse(is.ge, rules.ge,  ""))
            }
        }
    }
    new
}
# predicting multiple probabilities
# add truncated class names like "1st 2nd 3rd" above the ".12 .34 .56"
fit.colname <- function(ylevels, fit, ncol.fit)
{
    # Note (Aug 2019): there is deliberately no check that ncol.fit equals
    # length(ylevels).  If the last level(s) of the response are unused in
    # the training data, they don't appear in yval2 (see "unusedlev" in the
    # rpart.plot tests), so keep only the first ncol.fit levels.
    levs <- ylevels[1:ncol.fit]

    # Width of one probability in the fit strings: the position of the first
    # space, skipping the first char (which may be a leading space).
    first.space <- unlist(gregexpr(" ", substring(fit, 2)))[1]
    width <- if(first.space < 1) 1 else first.space # < 1 is paranoia

    # truncate/pad each level to that width, e.g. with "%3.3s"
    fmt <- sprint("%%%d.%ds", width, width)
    heading <- paste.collapse(sprint(fmt, levs)) # e.g. "1st 2nd 3rd"
    paste0(heading, " ") # trailing space lines up with the "]" in the entries
}
# if all elements in a column have leading space, trim that space
# For each column of x (data.frame or matrix of strings), remove the
# leading spaces that ALL its non-empty entries have in common, i.e. the
# length of the shortest run of leading spaces among those entries.
# Empty strings and NAs are ignored when computing that length.
# Previously gregexpr's match *position* (always 1) was used as if it were
# the match length, so at most one space was ever removed.
trim.leading.space.in.columns <- function(x)
{
    stopifnot(NROW(x) > 0)
    for(j in seq_len(NCOL(x))) {
        col <- x[,j]
        used <- col[!is.na(col) & col != ""]
        if(length(used) > 0) {
            # "^ *" always matches; match.length is the number of leading spaces
            nlead <- attr(regexpr("^ *", used), "match.length")
            ntrim <- min(nlead)
            if(ntrim > 0)
                x[,j] <- substring(col, ntrim + 1)
        }
    }
    x
}
# true if any of "Girth=", "Girth<", "Girth>=" is used
var.is.in.rule <- function(varname, rule)
{
    # TRUE if any of the variable's columns (e.g. "Girth=", "Girth<",
    # "Girth>=") in the one-row rule is non-empty
    cols <- paste0(varname, c("=", "<", ">="))
    rule[cols[1]] != "" || rule[cols[2]] != "" || rule[cols[3]] != ""
}
# parse split.lab into varname, op, cut
parse.split.lab <- function(split.lab, trace)
{
    # split.lab looks like "Girth|<|16.15": the operator is between two bars
    bars <- gregexpr("|", split.lab, fixed=TRUE)[[1]]
    if(length(bars) != 2 || bars[2] < bars[1] + 2) # will fail if | in varname
        stopf("Cannot parse split.lab %s", bar.to.space(split.lab))
    left  <- bars[1]
    right <- bars[2]
    parsed <- list(varname = substring(split.lab, 1, left - 1),         # "Girth"
                   op      = substring(split.lab, left + 1, right - 1), # "<"
                   cut     = substring(split.lab, right + 1))           # "16.15"
    trace2(trace,
           "                split.lab %-20.20s varname %s op %s cut %s\n",
           bar.to.space(split.lab), parsed$varname, parsed$op, parsed$cut)
    parsed
}
# Parse a leaf's node label into list(lab, fit, cover), e.g.
#   "14.94\n25.81%"            (anova/poisson) -> lab "",     fit "14.94",       cover "25.81"
#   "died\n0.17\n61%"          (class)         -> lab "died", fit "0.17",        cover "61"
#   "1st\n.63  .31  .06\n59%"  (multiclass)    -> lab "1st",  fit ".63 .31 .06", cover "59"
# The fit is normalized to single-space separated numbers, because
# format.fit splits it on single spaces.
parse.node.lab <- function(node.lab)
{
    err <- function(node.lab)
        stop0("Cannot parse node.lab \"",
               gsub("\n", "\\\\n", node.lab[1]), "\"") # convert newline to string "\n"

    #--- parse.node.lab starts here ---
    lab <- fit <- cover <- repl("", length(node.lab))
    i <- gregexpr("\n", node.lab, fixed=TRUE)[[1]]
    if(length(i) > 2) {
        # For labels like this (the fit spans several lines):
        # "Japan\n.00  .00  .00  .11  .89\n.00  .00  .00  .00  .00\n18%"
        # use only the first and last newlines
        i <- c(i[1], i[length(i)])
    }
    if(length(i) == 1) {
        # anova or poisson model "14.94\n25.81%"
        if(i <= 0) # no newline at all
            err(node.lab)
        fit   <- substr(node.lab, 1, i-1)
        cover <- substr(node.lab, i+1, nchar(node.lab)-1) # drop the "%"
    } else if(length(i) == 2) {
        # class model: "died\n0.17\n61%" or "1st\n.63  .31  .06\n59%"
        lab   <- substr(node.lab, 1, i[1]-1)
        fit   <- substr(node.lab, i[1]+1, i[2]-1)
        # Normalize separators so that splitting on " " gives exactly one
        # token per number: first newlines (multi-line fits) become spaces,
        # then runs of spaces collapse to one, then the ends are trimmed.
        # (Only halving double spaces, before converting newlines, could
        # leave empty tokens that format.fit would turn into NA columns.)
        fit   <- gsub("\n", " ", fit, fixed=TRUE)
        fit   <- gsub(" +", " ", fit)
        fit   <- gsub("^ | $", "", fit)
        cover <- substr(node.lab, i[2]+1, nchar(node.lab)-1) # drop the "%"
    } else
        err(node.lab)

    list(lab=lab, fit=fit, cover=cover)
}
# convert | to space and add quotes, for use in messages to the user
# (the "|" arises from call to internal.split.labs with eq="|=|", lt="|<|", ge="|>=|")
bar.to.space <- function(s)
{
    # every "|" becomes a space (chartr maps characters one-to-one),
    # then the result is quoted for messages to the user
    quote.with.c(chartr("|", " ", s))
}
# Sort on iclass, then on fit, then order of rows in splits.
# For anova models, iclass and fit are the same.
# For class models, iclass is the fitted class (as an integer)
# Sort on order of rows in splits means left-to-right in tree.

# Sort the rows of rules on iclass, then on rowmaxs (the sort key from
# format.fit), then on the current row order, so ties keep their original
# order.  NAs sort last.  seq_along (rather than 1:length) keeps this
# correct for zero rows.
order.rows <- function(rules, rowmaxs)
{
    ord <- order(as.numeric(rules$iclass), rowmaxs, seq_along(rowmaxs),
                 na.last=TRUE)
    rules[ord, , drop=FALSE]
}
# change order of columns so most important variables first, override with varorder
# Reorder the variable columns of rules so that variables used in the most
# rules come first.  Variables named in varorder (partially matched) are
# moved to the front, in the order given in varorder.
# Returns list(rules=reordered rules, varnames=varnames in the new order).
order.cols <- function(rules, varorder, varnames, nrules.per.var)
{
    ivars <- order(nrules.per.var, decreasing=TRUE)
    if(!is.null(varorder)) {
        stopifnot(is.character(varorder))
        # prepend in reverse so that the first varorder entry ends up first
        rev.varorder <- rev(varorder)
        matches <- pmatch(rev.varorder, varnames, duplicates.ok=TRUE)
        for(k in seq_along(matches)) {
            if(is.na(matches[k])) {
                warnf(
"varorder=\"%s\" does not uniquely match one of: %s",
                      rev.varorder[k], paste.trunc(quotify(varnames)))
            } else
                ivars <- c(matches[k], ivars[ivars != matches[k]])
        }
    }
    # columns 1:4 are lab, fit, iclass, cover; variable v uses 3v+2 .. 3v+4
    cols <- 3 * rep(ivars, each=3) + c(2, 3, 4)
    list(rules    = rules[, c(1:4, cols), drop=FALSE],
         varnames = varnames[ivars])
}
# Print rules for style="tall" or "tallw": each rule starts with a prolog
# such as "survived is 0.93 when", and each further condition (each "and"
# in the rule) starts a new line indented to line up under the prolog.
# rules: the processed rules data.frame; its first column may be "nn" and
#        its last may be "  cover" (both are detected from the colnames)
# eq, and, when: the strings used when the rules were built
# Output goes to the console via printf.
print.style.tall <- function(rules, style, eq, and, when)
{
    # Start a new line and indent it: 3 spaces for "tall", or the full
    # width of the prolog (plus class probs and nn widths) for "tallw".
    # Uses format, format2, format.nn, and have.nn from the enclosing function.
    newline.with.spaces <- function()
    {
        printf("\n")
        if(style == "tall")
            printf("   ")
        else # "tallw"
            printf(format, "", "  ", "") # prefix space to align with prolog
        printf(format2, "")
        if(have.nn)
            printf(format.nn, "")
    }
    #--- print.style.tall starts here ---
    colnames <- colnames(rules)
    ncol <- ncol(rules)
    have.nn <- colnames[1] == "nn"
    have.cover <- colnames[ncol] == "  cover"
    response.name <- colnames[1 + have.nn]
    class.probs   <- colnames[2 + have.nn]
    have.class.probs <- class.probs != "" # currently used only for multi class models
    if(is.null(and))
        and <- " & "

    # output will be as follows ("format" is for "survived is 0.93"):
    #
    # survived is 0.93 when
    #                  sex is female is
    #                  pclass is 1st or 2nd

    # get format for prolog of each rule   # e.g. "survived is 0.93"
    format <- sprint("%%-%ds %%s %%-%ds", # e.g. "%-8s %s %-4s"
                     nchar(response.name),
                     max(nchar(rules[, 1 + have.nn])))

    nn.width <- if(have.nn) max(nchar(rules[, 1]))+3 else 0 # +3 for "[] "
    format.nn <- sprint("%%-%ds", nn.width) # e.g. "%-0s" or "%-5s"

    format2 <- "%0.0s" # prints nothing unless class probs and "tallw"
    if(have.class.probs) {
        # output will be as follows ("format" is for "pclass is 1st"
        # and "format2" is for "[.74 .16 .10]"):
        #
        # pclass is 1st [.74 .16 .10] when
        #                             age >= 35
        #                             survived is survived

        # prefix spaces to align with prolog
        printf(format.nn, "")
        printf(format, "", "  ", "")
        printf("  %s\n", colnames[2 + have.nn])
        if(style == "tallw")
            format2 <- sprint("%%-%ds  ", nchar(colnames[2 + have.nn]))
    }
    for(i in 1:nrow(rules)) {
        if(have.nn)
            printf(format.nn, sprint("[%s] ", rules[i, 1]))
        printf(format, response.name, eq, rules[i, 1 + have.nn])
        # remaining columns (excluding cover): print each non-empty word,
        # starting a new line at "when" and at each "and"
        for(j in (2 + have.nn):(ncol(rules) - have.cover)) {
            e <- trim.surrounding.space(rules[i, j])
            if(nchar(e)) {
                if(e == when) {
                    if(have.cover)
                        printf(" with cover %-s", gsub("^ *", "", rules[i, ncol]))
                    printf(" %s", if(when == ":EMPTY:") "" else when)
                    newline.with.spaces()
                } else if(e == and)
                    newline.with.spaces()
                else
                    printf(" %s", e)
            }
        }
        printf("\n")
        if(i != nrow(rules)) # blank line between rules
            printf("\n")
    }
}
# stop.if.dot.arg.used declares no arguments, so R itself raises an
# "unused argument" error if it is called with any.  A function that must
# accept dots (to match its generic) but does not actually use them can pass
# its dots here; this lets the user catch mistyped or illegal arguments.
# When called with no arguments it simply returns NULL.
stop.if.dot.arg.used <- function() NULL
# Any argument name not in the list below causes R's usual "unused argument"
# error.  The following args must match the args of prp() except where
# commented below.
check.if.dot.arg.supported.by.rpart.rules <- function(x=stop("no 'x' arg"),
    type=0,
    extra="auto", # different default
    under=FALSE, fallen.leaves=FALSE,
    nn=FALSE, ni=FALSE, yesno=TRUE,
    branch=if(fallen.leaves) 1 else .2,
    uniform=TRUE, left=TRUE, xflip=FALSE, yflip=FALSE,
    digits=2,
    varlen=0, faclen=0, # different defaults
    # roundint=TRUE,
    cex=NULL, tweak=1,
    # clip.facs=FALSE,
    clip.right.labs=TRUE,
    compress=TRUE, ycompress=uniform,
    Margin=0, space=1, gap=NULL,
    snip=FALSE, snip.fun=NULL, trace=FALSE,

    box.col=0, box.palette=0,
    pal.thresh=NULL, pal.node.fun=FALSE,
    border.col=col,
    round=NULL, leaf.round=NULL,
    shadow.col=0, prefix="", suffix="", xsep=NULL,

    under.font=1, under.col=1, under.cex=.8,

    split.cex=1, split.font=2, split.family=1, split.col=1,
    split.box.col=0, split.border.col=0,
    split.lty=1, split.lwd=NULL, split.round=0,
    split.shadow.col=0,
    split.prefix="", right.split.prefix=NULL,
    split.suffix="", right.split.suffix=NULL,
    facsep=" or ", eq=" is ", # different defaults
    lt=" <  ", ge=" >= ",

    branch.col=if(is.zero(branch.type)) 1 else "gray",
    branch.lty=1, branch.lwd=NULL,
    branch.type=0, branch.tweak=1,
    min.branch.width=.002, branch.fill=branch.col,

    nn.cex=NULL, nn.font=3, nn.family="", nn.col=1,
    nn.box.col=0, nn.border.col=nn.col,
    nn.lty=1, nn.lwd=NULL, nn.round=.3,
    yes.text="yes", no.text="no",

    node.fun=NULL,
    split.fun=NULL,
    FUN="text",

    nspace=branch, minbranch=.3, do.par=TRUE,
    add.labs=TRUE,
    clip.left.labs=(type == 5),
    fam.main="",
    yshift=0, yspace=space, shadow.offset=.4,

    split.adj=NULL, split.yshift=0, split.space=space,
    split.yspace=yspace, split.shadow.offset=shadow.offset,

    nn.adj=.5, nn.yshift=0, nn.space=.8, nn.yspace=.5,

    ygap=gap/2, under.ygap=.5, yesno.yshift=0,
    xcompact=TRUE, ycompact=uniform, xcompact.ratio=.8, min.inter.height=4,

    max.auto.cex=1, min.auto.cex=.15, ycompress.cex=.7, accept.cex=1.1,
    shift.amounts=c(1.5, 2),
    Fallen.yspace=.1,
    boxes.include.gap=FALSE,
    legend.x=NULL, legend.y=NULL, legend.cex=1,
    # extra args for rpart.rules, not in prp arg list
    and=" & ", when=" when ", because=" because ", null.model="null model",
    response.name=NULL,
    # hidden args for rpart.predict
    RPART.PREDICT=FALSE, WHERE=NULL)
{
    # Formals that are accepted silently.  Every other formal above is a
    # prp() argument that rpart.rules ignores: passing one is not an error,
    # but it triggers a warning.
    quiet.args <- c("x", "extra", "digits", "varlen", "faclen", "trace",
                    "facsep", "eq", "lt", "ge",
                    "and", "when", "because", "null.model",
                    "response.name", "RPART.PREDICT", "WHERE")

    this.frame <- environment()

    # setdiff keeps the declaration order of the formals, so the warnings
    # appear in the same order as the argument list above
    for(argname in setdiff(names(formals(sys.function())), quiet.args)) {
        was.supplied <- !eval(call("missing", as.name(argname)),
                              envir=this.frame)
        if(was.supplied)
            warnf("rpart.rules: ignoring argument '%s'", argname)
    }

    list(extra=extra, digits=digits, varlen=varlen, faclen=faclen, trace=trace,
         facsep=facsep, eq=eq, lt=lt, ge=ge, and=and,
         when=when, because=because, null.model=null.model,
         response.name=response.name,
         rpart.predict=RPART.PREDICT, where=WHERE)
}

Try the rpart.plot package in your browser

Any scripts or data that you put into this service are public.

rpart.plot documentation built on May 21, 2022, 9:05 a.m.