### R code from vignette source 'constparty.Rnw'
###################################################
### code chunk number 1: setup
###################################################
suppressWarnings(RNGversion("3.5.2"))
options(width = 70)
library("partykit")
set.seed(290875)
###################################################
### code chunk number 2: Titanic
###################################################
data("Titanic", package = "datasets")
ttnc <- as.data.frame(Titanic)
ttnc <- ttnc[rep(1:nrow(ttnc), ttnc$Freq), 1:4]
names(ttnc)[2] <- "Gender"
###################################################
### code chunk number 3: rpart
###################################################
library("rpart")
(rp <- rpart(Survived ~ ., data = ttnc, model = TRUE))
###################################################
### code chunk number 4: rpart-party
###################################################
(party_rp <- as.party(rp))
###################################################
### code chunk number 5: rpart-plot-orig
###################################################
plot(rp)
text(rp)
###################################################
### code chunk number 6: rpart-plot
###################################################
plot(party_rp)
###################################################
### code chunk number 7: rpart-pred
###################################################
all.equal(predict(rp), predict(party_rp, type = "prob"),
check.attributes = FALSE)
###################################################
### code chunk number 8: rpart-fitted
###################################################
str(fitted(party_rp))
###################################################
### code chunk number 9: rpart-prob
###################################################
prop.table(do.call("table", fitted(party_rp)), 1)
###################################################
### code chunk number 10: J48
###################################################
if (require("RWeka")) {
j48 <- J48(Survived ~ ., data = ttnc)
} else {
j48 <- rpart(Survived ~ ., data = ttnc)
}
print(j48)
###################################################
### code chunk number 11: J48-party
###################################################
(party_j48 <- as.party(j48))
###################################################
### code chunk number 12: J48-plot
###################################################
plot(party_j48)
###################################################
### code chunk number 13: J48-pred
###################################################
all.equal(predict(j48, type = "prob"), predict(party_j48, type = "prob"),
check.attributes = FALSE)
###################################################
### code chunk number 14: PMML-Titantic
###################################################
ttnc_pmml <- file.path(system.file("pmml", package = "partykit"),
"ttnc.pmml")
(ttnc_quest <- pmmlTreeModel(ttnc_pmml))
###################################################
### code chunk number 15: PMML-Titanic-plot1
###################################################
plot(ttnc_quest)
###################################################
### code chunk number 16: ttnc2-reorder
###################################################
ttnc2 <- ttnc[, names(ttnc_quest$data)]
for(n in names(ttnc2)) {
if(is.factor(ttnc2[[n]])) ttnc2[[n]] <- factor(
ttnc2[[n]], levels = levels(ttnc_quest$data[[n]]))
}
###################################################
### code chunk number 17: PMML-Titanic-augmentation
###################################################
ttnc_quest2 <- party(ttnc_quest$node,
data = ttnc2,
fitted = data.frame(
"(fitted)" = predict(ttnc_quest, ttnc2, type = "node"),
"(response)" = ttnc2$Survived,
check.names = FALSE),
terms = terms(Survived ~ ., data = ttnc2)
)
ttnc_quest2 <- as.constparty(ttnc_quest2)
###################################################
### code chunk number 18: PMML-Titanic-plot2
###################################################
plot(ttnc_quest2)
###################################################
### code chunk number 19: PMML-write
###################################################
library("pmml")
tfile <- tempfile()
write(toString(pmml(rp)), file = tfile)
###################################################
### code chunk number 20: PMML-read
###################################################
(party_pmml <- pmmlTreeModel(tfile))
all.equal(predict(party_rp, newdata = ttnc, type = "prob"),
predict(party_pmml, newdata = ttnc, type = "prob"),
check.attributes = FALSE)
###################################################
### code chunk number 21: mytree-1
###################################################
findsplit <- function(response, data, weights, alpha = 0.01) {
## extract response values from data
y <- factor(rep(data[[response]], weights))
## perform chi-squared test of y vs. x
mychisqtest <- function(x) {
x <- factor(x)
if(length(levels(x)) < 2) return(NA)
ct <- suppressWarnings(chisq.test(table(y, x), correct = FALSE))
pchisq(ct$statistic, ct$parameter, log = TRUE, lower.tail = FALSE)
}
xselect <- which(names(data) != response)
logp <- sapply(xselect, function(i) mychisqtest(rep(data[[i]], weights)))
names(logp) <- names(data)[xselect]
## Bonferroni-adjusted p-value small enough?
if(all(is.na(logp))) return(NULL)
minp <- exp(min(logp, na.rm = TRUE))
minp <- 1 - (1 - minp)^sum(!is.na(logp))
if(minp > alpha) return(NULL)
## for selected variable, search for split minimizing p-value
xselect <- xselect[which.min(logp)]
x <- rep(data[[xselect]], weights)
## set up all possible splits in two kid nodes
lev <- levels(x[drop = TRUE])
if(length(lev) == 2) {
splitpoint <- lev[1]
} else {
comb <- do.call("c", lapply(1:(length(lev) - 2),
function(x) combn(lev, x, simplify = FALSE)))
xlogp <- sapply(comb, function(q) mychisqtest(x %in% q))
splitpoint <- comb[[which.min(xlogp)]]
}
## split into two groups (setting groups that do not occur to NA)
splitindex <- !(levels(data[[xselect]]) %in% splitpoint)
splitindex[!(levels(data[[xselect]]) %in% lev)] <- NA_integer_
splitindex <- splitindex - min(splitindex, na.rm = TRUE) + 1L
## return split as partysplit object
return(partysplit(varid = as.integer(xselect),
index = splitindex,
info = list(p.value = 1 - (1 - exp(logp))^sum(!is.na(logp)))))
}
###################################################
### code chunk number 22: mytree-2
###################################################
growtree <- function(id = 1L, response, data, weights, minbucket = 30) {
## for less than 30 observations stop here
if (sum(weights) < minbucket) return(partynode(id = id))
## find best split
sp <- findsplit(response, data, weights)
## no split found, stop here
if (is.null(sp)) return(partynode(id = id))
## actually split the data
kidids <- kidids_split(sp, data = data)
## set up all daugther nodes
kids <- vector(mode = "list", length = max(kidids, na.rm = TRUE))
for (kidid in 1:length(kids)) {
## select observations for current node
w <- weights
w[kidids != kidid] <- 0
## get next node id
if (kidid > 1) {
myid <- max(nodeids(kids[[kidid - 1]]))
} else {
myid <- id
}
## start recursion on this daugther node
kids[[kidid]] <- growtree(id = as.integer(myid + 1), response, data, w)
}
## return nodes
return(partynode(id = as.integer(id), split = sp, kids = kids,
info = list(p.value = min(info_split(sp)$p.value, na.rm = TRUE))))
}
###################################################
### code chunk number 23: mytree-3
###################################################
mytree <- function(formula, data, weights = NULL) {
## name of the response variable
response <- all.vars(formula)[1]
## data without missing values, response comes last
data <- data[complete.cases(data), c(all.vars(formula)[-1], response)]
## data is factors only
stopifnot(all(sapply(data, is.factor)))
if (is.null(weights)) weights <- rep(1L, nrow(data))
## weights are case weights, i.e., integers
stopifnot(length(weights) == nrow(data) &
max(abs(weights - floor(weights))) < .Machine$double.eps)
## grow tree
nodes <- growtree(id = 1L, response, data, weights)
## compute terminal node number for each observation
fitted <- fitted_node(nodes, data = data)
## return rich constparty object
ret <- party(nodes, data = data,
fitted = data.frame("(fitted)" = fitted,
"(response)" = data[[response]],
"(weights)" = weights,
check.names = FALSE),
terms = terms(formula))
as.constparty(ret)
}
###################################################
### code chunk number 24: mytree-4
###################################################
(myttnc <- mytree(Survived ~ Class + Age + Gender, data = ttnc))
###################################################
### code chunk number 25: mytree-5
###################################################
plot(myttnc)
###################################################
### code chunk number 26: mytree-pval
###################################################
nid <- nodeids(myttnc)
iid <- nid[!(nid %in% nodeids(myttnc, terminal = TRUE))]
(pval <- unlist(nodeapply(myttnc, ids = iid,
FUN = function(n) info_node(n)$p.value)))
###################################################
### code chunk number 27: mytree-nodeprune
###################################################
myttnc2 <- nodeprune(myttnc, ids = iid[pval > 1e-5])
###################################################
### code chunk number 28: mytree-nodeprune-plot
###################################################
plot(myttnc2)
###################################################
### code chunk number 29: mytree-glm
###################################################
logLik(glm(Survived ~ Class + Age + Gender, data = ttnc,
family = binomial()))
###################################################
### code chunk number 30: mytree-bs
###################################################
bs <- rmultinom(25, nrow(ttnc), rep(1, nrow(ttnc)) / nrow(ttnc))
###################################################
### code chunk number 31: mytree-ll
###################################################
bloglik <- function(prob, weights)
sum(weights * dbinom(ttnc$Survived == "Yes", size = 1,
prob[,"Yes"], log = TRUE))
###################################################
### code chunk number 32: mytree-bsll
###################################################
f <- function(w) {
tr <- mytree(Survived ~ Class + Age + Gender, data = ttnc, weights = w)
bloglik(predict(tr, newdata = ttnc, type = "prob"), as.numeric(w == 0))
}
apply(bs, 2, f)
###################################################
### code chunk number 33: mytree-node
###################################################
nttnc <- expand.grid(Class = levels(ttnc$Class),
Gender = levels(ttnc$Gender), Age = levels(ttnc$Age))
nttnc
###################################################
### code chunk number 34: mytree-prob
###################################################
predict(myttnc, newdata = nttnc, type = "node")
predict(myttnc, newdata = nttnc, type = "response")
predict(myttnc, newdata = nttnc, type = "prob")
###################################################
### code chunk number 35: mytree-FUN
###################################################
predict(myttnc, newdata = nttnc, FUN = function(y, w)
rank(table(rep(y, w))))
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.