# tests/testthat/data/rapp.test.1/packrat/lib-R/rpart/tests/priors.R

#
# A hard-core test of losses and priors
#  Simple data set where I know what the answers must be
#
library(rpart)
# Compare two objects by value only: both sides are flattened with
# as.vector() so that dim, names and similar attributes are ignored.
# Anything in `...` (e.g. tolerance) is forwarded to all.equal().
# Returns whatever all.equal() returns: TRUE, or a character description
# of the differences.
aeq <- function(x, y, ...) {
    target  <- as.vector(x)
    current <- as.vector(y)
    all.equal(target, current, ...)
}

# Fixed pseudo-noise (the leading digits of pi, divided by 5) used to
# perturb x3.
dummy <- c(3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9) / 5

# Fifteen observations: a three-level factor response `y` cycling 1, 2, 3,
# a running index x1, a 1..6 cycle x2, and x3 = 10 * (class + noise).
pdata <- data.frame(
    y  = factor(rep(1:3, times = 5)),
    x1 = seq_len(15),
    x2 = rep(1:6, length.out = 15),
    x3 = (rep(1:3, times = 5) + dummy) * 10
)

# Blank out x3 in rows 1, 5 and 10.
is.na(pdata$x3) <- c(1, 5, 10)
# Make things unbalanced: row 15 moves to class 1, giving counts 6, 5, 4.
pdata$y[15] <- 1

# Fit a classification tree on pdata limited to a single split
# (maxdepth=1), supplying explicit class priors and a 3x3 loss matrix
# through `parms`.  The loss matrix is filled row by row; `byrow` is
# spelled TRUE rather than T because T is an ordinary variable that can be
# reassigned.
pfit <- rpart(y ~ x1 + x2 + x3, pdata,
              cp=0, xval=0, minsplit=5, maxdepth=1,
              parms=list(prior=c(.2, .3, .5),
                         loss =matrix(c(0,2,2,2,0,6,1,1,0), 3,3,byrow=TRUE)))

#
# See section 12.1 of the report for these numbers
#

# Reference quantities for the hand calculations that follow.
ntot   <- c(6, 5, 4)         # class counts of pdata$y
phat   <- ntot / sum(ntot)   # observed class probabilities
                             # (the name phat is rebound to a function below)
prior  <- c(.2, .3, .5)      # priors
aprior <- c(4, 12, 5) / 21   # altered priors: prior * rowSums(lmat), rescaled
# Loss matrix; loss() reads column j as the per-class losses incurred when
# class j is predicted.  It is built exactly like the `loss` argument given
# to rpart() so the two agree in every cell (the first column previously
# had its two off-diagonal entries swapped).
lmat   <- matrix(c(0,2,2, 2,0,6, 1,1,0), 3, 3, byrow = TRUE)

# Gini impurity of a probability vector p: one minus the sum of squares.
gini <- function(p) {
    squared <- p^2
    1 - sum(squared)
}
# Total loss of labelling a node as `class`: the per-class weights `n`
# multiplied by the matching column of the global loss matrix lmat, summed.
loss <- function(n, class) {
    penalties <- lmat[, class]
    sum(n * penalties)
}

# Rescale per-class counts `n` by prior / ntot, elementwise.  With the
# defaults, the full counts c(6,5,4) map back onto the priors themselves.
phat <- function(n, ntot=c(6,5,4), prior=c(.2, .3, .5)) {
    weighted <- n * prior
    weighted / ntot
}

# Are the losses correct?
# With surrogates in use, the two children hold class counts (4,4,0) and
#   (2,1,4).  The expected values take class 2 as the label of the root and
#   of the first child, and class 3 as the label of the second child.
aeq(pfit$frame$dev / 15,
    unlist(Map(loss,
               list(prior, phat(c(4, 4, 0)), phat(c(2, 1, 4))),
               c(2, 2, 3))))
# Node probabilities?  The root has probability 1; each child gets the
#   total of its prior-weighted counts.
aeq(pfit$frame$yval2[, 8],
    c(1, vapply(list(c(4, 4, 0), c(2, 1, 4)),
                function(counts) sum(phat(counts)),
                numeric(1))))

# Within-node class probabilities: the priors at the root, and the
#   prior-weighted counts rescaled to sum to one in each child.
aeq(pfit$frame$yval2[, 5:7],
    rbind(prior,
          prop.table(phat(c(4, 4, 0))),
          prop.table(phat(c(2, 1, 4)))))

# Node and class probabilities again, this time under the altered priors:
# same rescaling as phat(), but `prior` defaults to the global aprior.
phat2 <- function(n, ntot=c(6,5,4), prior=aprior) {
    weighted <- n * prior
    weighted / ntot
}

# Gini loss of a node given its class counts, under the altered priors:
# the node's total weight times the gini impurity of its normalised class
# weights.  Used below for the base data and for the best splits on
# variables 1, 2, 3.
gfun <- function(n) {
    node_p <- phat2(n)
    mass   <- sum(node_p)
    mass * gini(node_p / mass)
}
           
# Expected improvements, listed in the order x3, x2, x1 (best split to
#   worst).
# For x3 the missing values cause the "parent" to be viewed as having 12
#   obs instead of 15.
# Each triple is (parent, first child, second child) class counts; the
#   entry is gini(parent) - gini(children), scaled by 15.
aeq(pfit$splits[1:3, 3],
    15 * vapply(
        list(x3 = list(c(4, 4, 4), c(3, 4, 0), c(1, 0, 4)),
             x2 = list(c(6, 5, 4), c(6, 5, 2), c(0, 0, 2)),
             x1 = list(c(6, 5, 4), c(4, 4, 4), c(2, 1, 0))),
        function(node) gfun(node[[1]]) - (gfun(node[[2]]) + gfun(node[[3]])),
        numeric(1)))
# rappster/rapp documentation built on May 26, 2019, 11:56 p.m.