# inst/slowtests/usersplits.R

# usersplits.R: copied and extended from rpart/tests (july 2018, rpart version 4.1-13)

# Run the shared test prolog.  Its contents are not visible from this file;
# presumably it loads rpart.plot and prepares the output device -- TODO confirm.
source("test.prolog.R")

# Any necessary setup
library(rpart)
# Session-wide options: drop incomplete rows, and print 4 significant digits
options(na.action="na.omit")
options(digits=4) # to match earlier output
# Fix the random number seed for the rest of the script
set.seed(1234)

# Build the working data frame from the state.x77 matrix plus the region
# factor, then replace all nine column names with lower-case names.
mystate <- data.frame(state.x77, region=factor(state.region))
names(mystate) <- c("population","income" , "illiteracy","life" ,
       "murder", "hs.grad", "frost",     "area",      "region")
#
# Test out the "user mode" functions, with an anova variant
#

# The 'evaluation' function.  Called once per node.
#  Produce a label (1 or more elements long) for labeling each node,
#  and a deviance.  The latter is
#       - of length 1
#       - equal to 0 if the node is "pure" in some sense (unsplittable)
#       - does not need to be a deviance: any measure that gets larger
#            as the node is less acceptable is fine.
#       - the measure underlies cost-complexity pruning, however
temp1 <- function(y, wt, parms) {
    # Node summary for the anova variant.
    #   y, wt : response values and case weights of the observations in a node
    #   parms : accepted but not used
    # Returns a list with
    #   label    : the weighted mean of y
    #   deviance : the weighted sum of squared deviations from that mean
    total.wt <- sum(wt)
    node.mean <- sum(y*wt)/total.wt
    resid <- y - node.mean
    list(label = node.mean, deviance = sum(wt*resid^2))
    }

# The split function, where most of the work occurs.
#   Called once per split variable per node.
# If continuous=T
#   The actual x variable is ordered
#   y is supplied in the sort order of x, with no missings,
#   return two vectors of length (n-1):
#      goodness = goodness of the split, larger numbers are better.
#                 0 = couldn't find any worthwhile split
#        the ith value of goodness evaluates splitting obs 1:i vs (i+1):n
#      direction= -1 = send "x< cutpoint" to the left side of the tree
#                  1 = send "x< cutpoint" to the right
#         this is not a big deal, but making larger "mean y's" move towards
#         the right of the tree, as we do here, seems to make it easier to
#         read
# If continuous=F, x is a set of integers defining the groups for an
#   unordered predictor.  In this case:
#       direction = a vector of length m= "# groups".  It asserts that the
#           best split can be found by lining the groups up in this order
#           and going from left to right, so that only m-1 splits need to
#           be evaluated rather than 2^(m-1)
#       goodness = m-1 values, as before.
#
# The reason for returning a vector of goodness is that the C routine
#   enforces the "minbucket" constraint. It selects the best return value
#   that is not too close to an edge.
temp2 <- function(y, wt, x, parms, continuous) {
    # Candidate-split scoring for the anova variant.
    #   continuous=TRUE : element i of the result scores the split of
    #       observations 1:i versus (i+1):n, in the order y and wt are given.
    #   continuous=FALSE: x holds group codes; the groups are lined up by
    #       their weighted mean of y and the cuts along that line are scored.
    # Goodness is the weighted between-side sum of squares of the centered
    # response divided by its weighted total sum of squares.
    # 'parms' is accepted but not used.
    centered <- y - sum(y*wt)/sum(wt)

    if (!continuous) {
        # Per-group totals of weight and of weighted centered response
        groups   <- sort(unique(x))
        grp.wt   <- tapply(wt, x, sum)
        grp.ysum <- tapply(centered*wt, x, sum)

        # Ordering the groups by mean lets the running-sum arithmetic of the
        # continuous case be reused unchanged
        ord  <- order(grp.ysum/grp.wt)
        last <- length(ord)
        cum.y    <- cumsum(grp.ysum[ord])[-last]
        wt.left  <- cumsum(grp.wt[ord])[-last]
        wt.right <- sum(wt) - wt.left
        mean.left  <- cum.y/wt.left
        mean.right <- -cum.y/wt.right
        between <- wt.left*mean.left^2 + wt.right*mean.right^2
        return(list(goodness = between/sum(wt*centered^2),
                    direction = groups[ord]))
        }

    # Continuous predictor: running sums over all but the last observation
    last <- length(y)
    cum.y    <- cumsum(centered*wt)[-last]
    wt.left  <- cumsum(wt)[-last]
    wt.right <- sum(wt) - wt.left
    mean.left  <- cum.y/wt.left
    mean.right <- -cum.y/wt.right
    between <- wt.left*mean.left^2 + wt.right*mean.right^2
    # direction is the sign of the left-hand mean of the centered response
    list(goodness = between/sum(wt*centered^2), direction = sign(mean.left))
    }

# The init function:
#   fix up y to deal with offsets
#   return a dummy parms list
#   numresp is the number of values produced by the eval routine's "label"
#   numy is the number of columns for y
#   summary is a function used to print one line in summary.rpart
# In general, this function would also check for bad data, see rpart.poisson
#   for instance.
temp3 <- function(y, offset, parms, wt) {
    # Initialisation for the anova variant.
    #   y      : response
    #   offset : subtracted from y when it is not NULL
    #   parms, wt : accepted but not used here
    # Returns the adjusted y, a dummy parms value of 0, numresp=1, numy=1 and
    # a function that formats one summary line (mean and MSE) for a node.
    summarize.node <- function(yval, dev, wt, ylevel, digits ) {
        paste0("  mean=", format(signif(yval, digits)),
               ", MSE=" , format(signif(dev/wt, digits)))
        }
    adjusted <- if (is.null(offset)) y else y-offset
    list(y=adjusted, parms=0, numresp=1, numy=1, summary=summarize.node)
    }


# Bundle the three user-written functions in the list form that rpart's
# 'method' argument accepts for user-defined splitting.
alist <- list(eval=temp1, split=temp2, init=temp3)

# Fit income with the user-written anova variant; xval=0 turns off
# cross-validation.
fit1 <- rpart(income ~population +illiteracy  + murder + hs.grad + region,
             mystate, control=rpart.control(minsplit=10, xval=0),
             method=alist)

# rpart.plot, rpart.rules and rpart.predict are not defined in this file nor
# attached by library(rpart); presumably they come from rpart.plot, loaded by
# the prolog -- TODO confirm.
rpart.plot(fit1, clip.facs=T, varlen=-4, faclen=4, main="fit1")
rpart.rules(fit1, clip.facs=T, varlen=-4, faclen=4)
head(rpart.predict(fit1, rules=TRUE))
# rpart.predict must agree exactly with predict, with and without rules=TRUE
# (first column of the result), on the training data and on rows 5:7.
stopifnot(identical(predict(fit1), rpart.predict(fit1)))
stopifnot(max(abs(predict(fit1) - rpart.predict(fit1, rules=TRUE)[,1])) == 0)
stopifnot(identical(predict(fit1, newdata=mystate[5:7,]), rpart.predict(fit1, newdata=mystate[5:7,])))
stopifnot(max(abs(predict(fit1, newdata=mystate[5:7,]) - rpart.predict(fit1, rules=TRUE, newdata=mystate[5:7,])[,1])) == 0)

# Same model and control settings as fit1, but with the built-in anova method.
fit2 <- rpart(income ~population +illiteracy + murder + hs.grad + region,
             mystate, control=rpart.control(minsplit=10, xval=0),
              method='anova')

rpart.plot(fit2, clip.facs=T, varlen=-4, faclen=4, main="fit2")
rpart.rules(fit2, clip.facs=T, varlen=-4, faclen=4)
head(rpart.predict(fit2, rules=TRUE, nn=TRUE))
# rpart.predict must agree exactly with predict, with and without rules=TRUE,
# on the training data and on rows 5:7.
stopifnot(identical(predict(fit2), rpart.predict(fit2)))
stopifnot(max(abs(predict(fit2) - rpart.predict(fit2, rules=TRUE)[,1])) == 0)
stopifnot(identical(predict(fit2, newdata=mystate[5:7,]), rpart.predict(fit2, newdata=mystate[5:7,])))
stopifnot(max(abs(predict(fit2, newdata=mystate[5:7,]) - rpart.predict(fit2, rules=TRUE, newdata=mystate[5:7,])[,1])) == 0)

# Other than their call statement, and a longer "functions" component in
#  fit1, fit1 and fit2 should be identical.
# NOTE: these all.equal() results are evaluated at top level, not wrapped in
#  stopifnot(), so a mismatch shows up in the output rather than as an error.
all.equal(fit1$frame, fit2$frame)
all.equal(fit1$splits, fit2$splits)
all.equal(fit1$csplit, fit2$csplit)
all.equal(fit1$where, fit2$where)
all.equal(fit1$cptable, fit2$cptable)

# Now try xpred on it
# Cross-validation group labels 1..5, recycled to length 50
xvtemp <- rep(1:5, length=50)
xp1 <- xpred.rpart(fit1, xval=xvtemp)
xp2 <- xpred.rpart(fit2, xval=xvtemp)
# Compare two objects as plain vectors (attributes dropped) via all.equal
aeq <- function(x,y) all.equal(as.vector(x), as.vector(y))
stopifnot(aeq(xp1, xp2))

# Built-in anova fit that uses the same group labels for its own
# cross-validation, so its cptable can be compared with xp1 below.
fit3 <- rpart(income ~population +illiteracy + murder + hs.grad + region,
             mystate, control=rpart.control(minsplit=10, xval=xvtemp),
              method='anova')
rpart.plot(fit3, clip.facs=T, varlen=-4, faclen=4, main="fit3")
rpart.rules(fit3, clip.facs=T, varlen=-4, faclen=4)
head(rpart.predict(fit3, rules=TRUE, nn=TRUE))
stopifnot(identical(predict(fit3), rpart.predict(fit3)))
stopifnot(max(abs(predict(fit3) - rpart.predict(fit3, rules=TRUE)[,1])) == 0)
stopifnot(identical(predict(fit3, newdata=mystate[5:7,]), rpart.predict(fit3, newdata=mystate[5:7,])))
stopifnot(max(abs(predict(fit3, newdata=mystate[5:7,]) - rpart.predict(fit3, newdata=mystate[5:7,], rules=TRUE)[,1])) == 0)
# Column sums of squared cross-validated prediction errors, scaled by the
# root node deviance, are compared with column 4 of fit3's cptable.
zz <- apply((mystate$income - xp1)^2,2, sum)
stopifnot(aeq(zz/fit1$frame$dev[1], fit3$cptable[,4]))  #reproduce xerror

# Subtract each column's mean squared error from its squared errors, then
# take the root of the column sums of squares, scaled by the root deviance;
# compared with column 5 of fit3's cptable.
zz2 <- sweep((mystate$income-xp1)^2,2, zz/nrow(xp1))
zz2 <- sqrt(apply(zz2^2, 2, sum))/ fit1$frame$dev[1]
stopifnot(aeq(zz2, fit3$cptable[,5]))          #reproduce se(xerror)

#--- following copied and extended from usercode.R vignette code ---

###################################################
### code chunk number 1: usercode.Rnw:26-28
###################################################
# Vignette display settings: continuation prompt and a 60-column output width
options(continue="  ", width=60)
# Register a "fig" hook that sets the plot margins; it is invoked explicitly
# later in this script via getOption("SweaveHooks").
options(SweaveHooks=list(fig=function() par(mar=c(5.1, 4.1, .3, 1.1))))


###################################################
### code chunk number 2: usercode.Rnw:85-99
###################################################
itemp <- function(y, offset, parms, wt) {
    # Initialisation function for the user-written anova method.
    #   y      : response; a matrix with more than one column is an error
    #   offset : subtracted from y when it has non-zero length
    #   parms  : ignored, with a warning if a non-empty value is supplied
    #   wt     : accepted but not used here
    # Returns y stripped to a plain vector, parms=NULL, numresp=1, numy=1
    # and a function that formats one summary line (mean and MSE) per node.
    if (is.matrix(y) && ncol(y) > 1) {
        stop("Matrix response not allowed")
    }
    if (!missing(parms) && length(parms) > 0) {
        warning("parameter argument ignored")
    }
    if (length(offset)) {
        y <- y - offset
    }

    node.summary <- function(yval, dev, wt, ylevel, digits ) {
        paste0("  mean=", format(signif(yval, digits)),
               ", MSE=" , format(signif(dev/wt, digits)))
    }
    # Re-home the summary function in the global environment so it does not
    # carry this call's frame along with it
    environment(node.summary) <- .GlobalEnv

    list(y = c(y), parms = NULL, numresp = 1, numy = 1,
         summary = node.summary)
}


###################################################
### code chunk number 3: usercode.Rnw:155-163
###################################################
temp <- 4
fun1 <- function(x) {
    # Scoping demonstration: the inner function reads z from this function's
    # frame and temp from the global environment.  q is assigned but never
    # read.
    q <- 15
    z <- 10
    add.offsets <- function(v) {
        v + z + temp
    }
    squared <- x^2
    add.offsets(squared)
}
fun1(5)


###################################################
### code chunk number 4: usercode.Rnw:194-199
###################################################
etemp <- function(y, wt, parms) {
    # Evaluation function for the user-written anova method.
    # Returns the node label (weighted mean of y) and the node deviance
    # (weighted sum of squared deviations from that mean).  'parms' is unused.
    center <- sum(y*wt)/sum(wt)
    sq.dev <- wt*(y - center)^2
    list(label = center, deviance = sum(sq.dev))
}


###################################################
### code chunk number 5: usercode.Rnw:249-284
###################################################
stemp <- function(y, wt, x, parms, continuous)
{
    # Split function for the user-written anova method.
    #   continuous=TRUE : scores the split of observations 1:i versus (i+1):n
    #       for every i, in the order y and wt are supplied; direction is the
    #       sign of the left-hand mean of the centered response.
    #   continuous=FALSE: x holds group codes; groups are ordered by their
    #       weighted mean and each cut along that ordering is scored;
    #       direction is the sorted group codes in that order.
    # 'parms' is accepted but not used.
    yc <- y - sum(y*wt)/sum(wt)

    # Given running totals of weighted response and of weight on the left of
    # each cut, return the left-hand means and the goodness of every cut.
    cut.stats <- function(ysum.left, wt.left) {
        wt.right <- sum(wt) - wt.left
        lmean <- ysum.left/wt.left
        rmean <- -ysum.left/wt.right
        list(lmean = lmean,
             goodness = (wt.left*lmean^2 + wt.right*rmean^2)/sum(wt*yc^2))
    }

    if (continuous) {
        # Running sums over the observations, dropping the final position
        drop.last <- -length(y)
        cuts <- cut.stats(cumsum(yc*wt)[drop.last], cumsum(wt)[drop.last])
        list(goodness = cuts$goodness, direction = sign(cuts$lmean))
    } else {
        # Per-group totals, then the same running sums over the groups
        # arranged by increasing mean
        by.group.wt <- tapply(wt, x, sum)
        by.group.y  <- tapply(yc*wt, x, sum)
        ord <- order(by.group.y/by.group.wt)
        drop.last <- -length(ord)
        cuts <- cut.stats(cumsum(by.group.y[ord])[drop.last],
                          cumsum(by.group.wt[ord])[drop.last])
        list(goodness = cuts$goodness, direction = sort(unique(x))[ord])
    }
}


###################################################
### code chunk number 6: usercode.Rnw:327-342
###################################################
library(rpart)
# Rebuild mystate for the vignette code: this time the column names are just
# the data.frame-generated names folded to lower case.
mystate <- data.frame(state.x77, region=state.region)
names(mystate) <- casefold(names(mystate)) #remove mixed case
# User-written method: evaluation, split and init functions
ulist <- list(eval = etemp, split = stemp, init = itemp)
# fit1 uses the user-written method; no xval is given for this fit.
fit1 <- rpart(murder ~ population + illiteracy + income + life.exp +
              hs.grad + frost + region, data = mystate,
              method = ulist, minsplit = 10)
rpart.rules(fit1)
# fit2 is the built-in anova fit, with xval = 0.
fit2 <- rpart(murder ~ population + illiteracy + income + life.exp +
              hs.grad + frost + region, data = mystate,
              method = 'anova', minsplit = 10, xval = 0)
rpart.rules(fit2)
# Component-wise comparison of the two fits.  The results are evaluated at
# top level, not asserted with stopifnot().
all.equal(fit1$frame, fit2$frame)
all.equal(fit1$splits, fit2$splits)
all.equal(fit1$csplit, fit2$csplit)
all.equal(fit1$where, fit2$where)
all.equal(fit1$cptable, fit2$cptable)


###################################################
### code chunk number 7: usercode.Rnw:358-369
###################################################
# Cross-validation group labels 1..10, recycled over the rows of mystate
xgroup <- rep(1:10, length = nrow(mystate))
xfit <- xpred.rpart(fit1, xgroup)
# Mean squared cross-validated prediction error for each column of xfit
xerror <- colMeans((xfit - mystate$murder)^2)

# Built-in anova fit using the same group labels for its cross-validation
fit2b <-  rpart(murder ~ population + illiteracy + income + life.exp +
                hs.grad + frost + region, data = mystate,
                method = 'anova', minsplit = 10, xval = xgroup)
rpart.rules(fit2b)
# Deviance divided by weight for the first row of the frame
topnode.error <- (fit2b$frame$dev/fit2b$frame$wt)[1]

# Scale the errors by the top-node value and compare with column 4 of
# fit2b's cptable; the result is evaluated at top level, not asserted.
xerror.relative <- xerror/topnode.error
all.equal(xerror.relative, fit2b$cptable[, 4], check.attributes = FALSE)


###################################################
### code chunk number 8: fig1
###################################################
# Call the "fig" hook stored in the SweaveHooks option
getOption("SweaveHooks")[["fig"]]()
# Sort the data by illiteracy so the response is in the order of the
# predictor before calling the split function directly.
tdata <- mystate[order(mystate$illiteracy), ]
n <- nrow(tdata)
# Unit weights; continuous = TRUE selects the ordered-predictor branch.
temp <- stemp(tdata$income, wt = rep(1, n), tdata$illiteracy,
              parms = NULL, continuous = TRUE)
# Midpoints between consecutive illiteracy values: one per candidate split
xx <- (tdata$illiteracy[-1] + tdata$illiteracy[-n])/2
plot(xx, temp$goodness, xlab = "Illiteracy cutpoint",
     ylab = "Goodness of split")
# Overlay a smoothing spline (df = 4) as a dashed line
lines(smooth.spline(xx, temp$goodness, df = 4), lwd = 2, lty = 2)


###################################################
### code chunk number 9: usercode.Rnw:438-458
###################################################
loginit <- function(y, offset, parms, wt)
{
    # Initialisation function for the user-written logistic method.
    #   y      : response; every value must be 0 or 1
    #   offset : replaced by 0 when NULL
    #   parms, wt : accepted but not used here
    # Returns a two-column y (response, offset), parms=0, numresp=2, numy=2
    # and a function that formats the per-node summary text.
    if (is.null(offset)) {
        offset <- 0
    }
    if (!all(y == 0 | y == 1)) stop ('response must be 0/1')

    node.summary <- function(yval, dev, wt, ylevel, digits ) {
        paste0("events=",  round(yval[,1]),
               ", coef= ", format(signif(yval[,2], digits)),
               ", deviance=" , format(signif(dev, digits)))
    }
    # Re-home the summary function in the global environment so it does not
    # carry this call's frame along with it
    environment(node.summary) <- .GlobalEnv

    list(y = cbind(y, offset), parms = 0, numresp = 2, numy = 2,
         summary = node.summary)
    }

logeval <- function(y, wt, parms)
{
    # Evaluation function for the user-written logistic method.
    #   y     : two-column matrix; column 1 is the 0/1 response and column 2
    #           is used as the offset
    #   wt    : case weights passed to glm
    #   parms : accepted but not used
    # Fits an intercept-only binomial glm and returns
    #   label    : the sum of the response column followed by the fitted
    #              coefficient(s)
    #   deviance : the deviance of that fit
    # Argument and component names are spelled out in full (family, weights,
    # coefficients) rather than relying on partial matching.
    tfit <- glm(y[,1] ~ offset(y[,2]), family = binomial, weights = wt)
    list(label = c(sum(y[,1]), tfit$coefficients), deviance = tfit$deviance)
}


###################################################
### code chunk number 10: usercode.Rnw:466-502
###################################################
logsplit <- function(y, wt, x, parms, continuous)
{
    # Split function for the user-written logistic method.
    #   y          : two-column matrix; column 1 is the 0/1 response and
    #                column 2 is used as the offset
    #   wt         : case weights passed to glm
    #   x          : the candidate split variable
    #   parms      : accepted but not used
    #   continuous : TRUE for an ordered predictor, FALSE for group codes
    # Returns list(goodness, direction).  Goodness is the drop from the null
    # deviance to the deviance of a binomial glm that adds an indicator for
    # the split.
    #
    # Loops use seq_len()/seq_along() so that an empty range performs no
    # iterations (1:0 would iterate over c(1, 0)), and glm arguments and fit
    # components are named in full rather than partially matched.
    if (continuous) {
        # continuous x variable: do all the logistic regressions
        n <- nrow(y)
        ncut <- max(n - 1, 0)   # number of candidate cut positions
        goodness <- double(ncut)
        direction <- goodness
        temp <- rep(0, n)
        for (i in seq_len(ncut)) {
            # temp marks observations 1:i, the left side of this cut
            temp[i] <- 1
            # Tied neighbours cannot be separated; their goodness stays 0
            if (x[i] != x[i+1]) {
                tfit <- glm(y[,1] ~ temp + offset(y[,2]), family = binomial,
                            weights = wt)
                goodness[i] <- tfit$null.deviance - tfit$deviance
                direction[i] <- sign(tfit$coefficients[2])
            }
        }
    } else {
        # Categorical X variable
        # First, find out what order to put the categories in, which
        #  will be the order of the coefficients in this model
        tfit <- glm(y[,1] ~ factor(x) + offset(y[,2]) - 1, family = binomial,
                    weights = wt)
        ngrp <- length(tfit$coefficients)
        direction <- rank(rank(tfit$coefficients) + runif(ngrp, 0, 0.1)) #break ties
        # breaking ties -- if 2 groups have exactly the same p-hat, it
        #  does not matter which order I consider them in.  And the calling
        #  routine wants an ordering vector.
        #
        xx <- direction[match(x, sort(unique(x)))] #relabel from small to large
        goodness <- double(length(direction) - 1)
        for (i in seq_along(goodness)) {
            tfit <- glm(y[,1] ~ I(xx > i) + offset(y[,2]), family = binomial,
                        weights = wt)
            goodness[i] <- tfit$null.deviance - tfit$deviance
        }
    }
    list(goodness=goodness, direction=direction)
}

# Run the shared test epilog.  Its contents are not visible from this file;
# presumably it finalises the test output -- TODO confirm.
source("test.epilog.R")

# Try the rpart.plot package in your browser
#
# Any scripts or data that you put into this service are public.
#
# rpart.plot documentation built on May 29, 2024, 12:07 p.m.