library(learnr)
library(titanic)
library(rpart)
library(partykit)
library(caret)
library(verification)
In this notebook we use the Titanic data that serves on Kaggle (https://www.kaggle.com) as an introductory competition for getting familiar with machine learning. It includes information on a set of Titanic passengers, such as age, sex, ticket class and whether they survived the Titanic tragedy. (Note that the titanic package also provides a separate test set that excludes the survival variable.)
Source: https://www.kaggle.com/c/titanic/data
titanic <- titanic_train
str(titanic)
We begin with some minor data preparations. The lapply() function is a handy tool if the task is to apply the same transformation (e.g. as.factor()) to multiple columns of a data frame.
titanic[, c(2:3,5,12)] <- lapply(titanic[, c(2:3,5,12)], as.factor)
Next we split the data into a training and a test part. This can be done by random sampling with sample().
set.seed(3225)
train <- sample(1:nrow(titanic), 0.8*nrow(titanic))
titanic_train <- titanic[train,]
titanic_test <- titanic[-train,]
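As a quick sanity check (this snippet is an addition, not part of the original notebook), we can verify that the class balance is roughly similar in both parts:
# compare the proportion of survivors in the training and test parts
prop.table(table(titanic_train$Survived))
prop.table(table(titanic_test$Survived))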
Our task is to predict the survival status of the Titanic passengers. As a first attempt, we grow a classification tree with rpart, which follows the CART idea. Tree size is controlled by the default options (see ?rpart).
tree1 <- rpart(Survived ~ Pclass + Sex + Age + Fare + Embarked,
               data = titanic_train, method = "class")
tree1
In addition to just printing the tree structure, calling summary() with the tree object gives us a lot of information on the grown tree.
summary(tree1)
Of course, trees are (usually) best represented by a plot. Here we use the partykit package to first convert the tree into the party format and then use plot() on the new object.
party_tree1 <- as.party(tree1)
plot(party_tree1, gp = gpar(fontsize = 9))
Let's build a larger tree.
tree2 <- rpart(Survived ~ Pclass + Sex + Age + Fare + Embarked,
               data = titanic_train,
               control = rpart.control(minsplit = 10,  # minimal obs in a node
                                       minbucket = 3,  # minimal obs in any terminal node
                                       cp = 0.001,     # min improvement through splitting
                                       maxdepth = 30   # maximum tree depth
                                       ))
Unfortunately, this new tree might be too large to plot.
party_tree2 <- as.party(tree2)
plot(party_tree2, gp = gpar(fontsize = 6))
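To gauge the size of the tree without plotting it, we could count its terminal nodes directly from the rpart frame (an optional sketch, not part of the original notebook):
# number of terminal (leaf) nodes in the large tree
sum(tree2$frame$var == "<leaf>")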
This large tree is likely to overfit and will not generalize well to new data. Therefore, we use printcp and plotcp, which help us determine the best subtree. Multiplying the Root node error by xerror gives us the estimated test error for each subtree based on cross-validation.
printcp(tree2)
plotcp(tree2)
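As an optional sketch (not part of the original notebook), this estimated CV test error can also be reconstructed explicitly. With the default priors, the root node error equals the misclassification rate of always predicting the majority class:
# root node error = 1 - proportion of the majority class
root_err <- 1 - max(prop.table(table(titanic_train$Survived)))
# estimated CV test error per subtree: root node error * xerror
round(root_err * tree2$cptable[, "xerror"], 4)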
On this basis, we are interested in picking the cp value that is associated with the smallest CV error. We could do this by hand or by using a few simple lines of code.
minx <- which.min(tree2$cptable[,"xerror"])
mincp <- tree2$cptable[minx,"CP"]
mincp
Alternatively, we could also pick the best subtree based on the 1-SE rule. We are again interested in storing the corresponding cp value for tree pruning in the next step.
minx <- which.min(tree2$cptable[,"xerror"])
minxse <- tree2$cptable[minx,"xerror"] + tree2$cptable[minx,"xstd"]
minse <- which(tree2$cptable[1:minx,"xerror"] < minxse)
mincp2 <- tree2$cptable[minse[1],"CP"]
mincp2
Now we can get the best subtree with the prune() function. First based on the smallest CV error...
p_tree <- prune(tree2, cp = mincp)
p_tree
...and now based on the 1-SE rule.
p_tree2 <- prune(tree2, cp = mincp2)
p_tree2
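As a quick comparison (an addition, not part of the original notebook), we can count the terminal nodes of both pruned trees; the tree chosen by the 1-SE rule is typically the smaller one:
# number of terminal nodes in each pruned tree
sum(p_tree$frame$var == "<leaf>")
sum(p_tree2$frame$var == "<leaf>")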
Finally, we can use the pruned trees to predict the outcome in the holdout (test) set. Prediction performance can be evaluated with confusionMatrix from caret.
y_tree <- predict(p_tree, newdata = titanic_test, type = "class")
y_tree2 <- predict(p_tree2, newdata = titanic_test, type = "class")
confusionMatrix(y_tree, titanic_test$Survived, mode = "everything", positive = "1")
confusionMatrix(y_tree2, titanic_test$Survived, mode = "everything", positive = "1")
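As a quick cross-check (not part of the original notebook), the overall accuracy reported by confusionMatrix can also be computed directly from the predicted classes:
# share of correctly classified passengers in the test set
mean(y_tree == titanic_test$Survived)
mean(y_tree2 == titanic_test$Survived)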
We can also predict probabilities instead of class membership.
yp_tree <- predict(p_tree, newdata = titanic_test, type = "prob")[,2]
prob <- verify(pred = yp_tree,
               obs = as.numeric(as.character(titanic_test$Survived)),
               frcst.type = "prob",
               obs.type = "binary")
summary(prob)
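As an optional extension (not part of the original notebook), the verification package also provides roc.plot() for drawing an ROC curve from the predicted probabilities:
# ROC curve for the probability forecasts of the pruned tree
roc.plot(x = as.numeric(as.character(titanic_test$Survived)), pred = yp_tree)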