#' ID3
#'
#' A Decision Tree implemented using the ID3 algorithm.
#'
#' @examples
#' clf <- ID3$new("monks")
#' clf$train(monks1_train[, -c(1,8)], monks1_train[,1])
#' preds <- clf$predict(monks1_test[, -c(1,8)])
#' accuracy <- sum(preds == monks1_test[,1])/nrow(monks1_test)
#' data.tree::SetEdgeStyle(clf$tree, fontname = 'helvetica', label = GetEdgeLabel)
#' data.tree::SetNodeStyle(clf$tree, fontname = 'helvetica', label = GetNodeLabel)
#' plot(clf$tree)
#'
#' @export
ID3 <- R6::R6Class("ID3",
  public = list(
    # Root `data.tree::Node` of the decision tree. Internal nodes carry
    # `feature` (name of the column they split on) and `majority` (most
    # frequent training label seen at the node); leaves carry
    # `feature == "class"` and the predicted label in `class`.
    tree = NULL,
    # Depth limit used by the most recent call to `train()`.
    max_depth = NULL,

    # Create an untrained classifier whose root node is named `name`.
    initialize = function(name = "tree_root") {
      self$tree <- data.tree::Node$new(name)
    },

    # Fit the tree.
    #   features:  data.frame with one column per (categorical) feature.
    #   labels:    vector with one class label per row of `features`.
    #   max_depth: maximum number of splits on any root-to-leaf path.
    # Returns the object invisibly. Class labels are stored as character.
    train = function(features, labels, max_depth = ncol(features)) {
      if (!is.data.frame(features)) {
        stop("`features` must be a data.frame", call. = FALSE)
      }
      if (length(labels) != nrow(features)) {
        stop("`labels` must have one entry per row of `features`",
             call. = FALSE)
      }
      if (length(labels) == 0L) {
        stop("cannot train on an empty data set", call. = FALSE)
      }
      self$max_depth <- max_depth
      # Start from a fresh root so that calling train() again does not append
      # a second set of children to the previously grown tree.
      self$tree <- data.tree::Node$new(self$tree$name)
      # Every value each feature takes in the full training set; each split
      # creates one child per value, even when a value is absent locally.
      feature_ranges <- lapply(features, unique)
      private$train_help(self$tree, features, labels, feature_ranges)
      invisible(self)
    },

    # Predict class labels.
    #   features: a single observation (named vector or list) or a
    #             data.frame / matrix with one observation per row.
    # Returns a single label, or a vector with one label per row.
    predict = function(features) {
      private$check_trained()
      if (is.vector(features)) {
        # single observation
        return(private$predict_help(self$tree, features))
      }
      if (!is.data.frame(features)) {
        features <- as.data.frame(features, stringsAsFactors = FALSE)
      }
      # drop = FALSE keeps the column name when there is only one feature
      preds <- lapply(seq_len(nrow(features)), function(row) {
        private$predict_help(self$tree, features[row, , drop = FALSE])
      })
      if (length(preds) == 0L) character(0) else unlist(preds)
    },

    # Prune the tree against a validation set: a subtree is replaced by a
    # leaf predicting its majority class unless the subtree classifies
    # strictly more validation rows correctly than that leaf would.
    # Returns the object invisibly.
    prune = function(validation_features, validation_labels) {
      private$check_trained()
      private$prune_help(self$tree, validation_features, validation_labels)
      invisible(self)
    }
  ),
  private = list(
    # Stop with a clear message when the root has not been set up by train().
    check_trained = function() {
      if (is.null(self$tree$feature)) {
        stop("the tree has not been trained yet; call `train()` first",
             call. = FALSE)
      }
    },

    # Shannon entropy (base 2) of a vector of labels.
    entropy = function(labels) {
      ent_sum <- 0
      num_labels <- length(labels)
      for (val in unique(labels)) {
        num_val <- length(labels[labels == val])
        percent <- num_val / num_labels
        ent_sum <- ent_sum - percent * log(percent, 2)
      }
      ent_sum
    },

    # Information gain obtained by splitting `labels` on `feature`:
    # entropy of the parent minus the weighted entropy of the children.
    info_gain = function(feature, labels) {
      ent_parent <- private$entropy(labels)
      ent_children <- 0
      for (val in unique(feature)) {
        in_branch <- feature == val
        relative_weight <- length(feature[in_branch]) / length(feature)
        ent_children <- ent_children +
          relative_weight * private$entropy(labels[which(in_branch)])
      }
      ent_parent - ent_children
    },

    # Number of edges between `node` and the root of its tree.
    depth = function(node) {
      curr_depth <- 0
      while (!node$isRoot) {
        node <- node$parent
        curr_depth <- curr_depth + 1
      }
      curr_depth
    },

    # Recursively grow the subtree rooted at `node` from the rows in
    # `features` / `labels`. `features` only holds the columns that have not
    # been used on the path from the root.
    train_help = function(node, features, labels, feature_ranges) {
      majority <- names(which.max(table(labels)))
      if (length(unique(labels)) == 1) {
        # all labels agree: pure leaf (stored as character so that every
        # leaf, pure or majority-based, yields the same type)
        node$feature <- "class"
        node$class <- as.character(unique(labels))
      } else if (private$depth(node) >= self$max_depth ||
                 ncol(features) == 0L) {
        # depth limit reached or no feature left to split on: majority leaf
        node$feature <- "class"
        node$class <- majority
      } else {
        # pick the feature with the highest information gain
        # (first one wins on ties)
        max_ig <- -Inf
        max_f <- NA_integer_
        for (f in seq_len(ncol(features))) {
          split_on_f_ig <- private$info_gain(features[[f]], labels)
          if (split_on_f_ig > max_ig) {
            max_ig <- split_on_f_ig
            max_f <- f
          }
        }
        node$feature <- names(features)[max_f]
        node$majority <- majority
        split_values <- features[[max_f]]
        # one child per value the feature takes anywhere in the training set
        for (value in feature_ranges[[node$feature]]) {
          child <- node$AddChild(value)
          in_branch <- split_values == value
          if (!any(in_branch)) {
            # value never occurs in this subset: predict the parent majority
            child$feature <- "class"
            child$class <- majority
          } else {
            private$train_help(
              child,
              features[in_branch, -max_f, drop = FALSE],
              labels[in_branch],
              feature_ranges
            )
          }
        }
      }
    },

    # Walk from `node` down to a leaf following the values in `features`
    # (a one-row data.frame, named vector or list) and return its class.
    predict_help = function(node, features) {
      while (node$feature != "class") {
        if (!(node$feature %in% names(features))) {
          stop("feature '", node$feature, "' is missing from the input",
               call. = FALSE)
        }
        value <- as.character(features[[node$feature]])
        matched <- NULL
        for (child in node$children) {
          # isTRUE() also guards against NA feature values
          if (isTRUE(child$name == value)) {
            matched <- child
            break
          }
        }
        if (is.null(matched)) {
          # Value not seen when this node was built: there is no branch to
          # follow (the loop would otherwise never terminate), so answer
          # with the majority class recorded at this node.
          return(node$majority)
        }
        node <- matched
      }
      node$class
    },

    # Recursive worker for prune(); the validation data passed in are the
    # rows that reach `node`.
    prune_help = function(node, validation_features, validation_labels) {
      if (node$feature != "class") {
        n_rows <- nrow(validation_features)
        curr_preds <- character(0)
        if (n_rows > 0L) {
          curr_preds <- unlist(lapply(seq_len(n_rows), function(row) {
            private$predict_help(
              node, validation_features[row, , drop = FALSE]
            )
          }))
        }
        subtree_hits <- sum(curr_preds == validation_labels)
        majority_hits <- sum(node$majority == validation_labels)
        if (subtree_hits <= majority_hits) {
          # Subtree is no better than a majority leaf (this includes the
          # case where no validation row reaches the node): collapse it.
          for (child in node$children) {
            node$RemoveChild(child$name)
          }
          node$feature <- "class"
          node$class <- node$majority
          node$RemoveAttribute("majority")
        } else {
          split_values <- validation_features[[node$feature]]
          for (child in node$children) {
            in_branch <- split_values == child$name
            private$prune_help(
              child,
              validation_features[in_branch, , drop = FALSE],
              validation_labels[in_branch]
            )
          }
        }
      }
    }
  )
)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.