##' @title Probability tree class
##' @description Subclass for probability trees.
##' Contains all fields and methods specific to probability trees.
TreeProbability <- setRefClass("TreeProbability",
  contains = "Tree",
  fields = list(),
  methods = list(

    ## Decide whether the node `nodeID` is split further.
    ## Returns NULL (no split) if the node holds at most `min_node_size`
    ## samples or if all responses in the node are identical; otherwise
    ## returns the result of findBestSplit().
    splitNodeInternal = function(nodeID, possible_split_varIDs) {
      node_samples <- sampleIDs[[nodeID]]

      ## Check node size, stop if maximum reached
      if (length(node_samples) <= min_node_size) {
        return(NULL)
      }

      ## Stop if node is pure (response is read from column 1 of the data)
      unique_response <- unique(data$subset(node_samples, 1))
      if (length(unique_response) == 1) {
        return(NULL)
      }

      ## Find best split, stop if no decrease of impurity
      return(findBestSplit(nodeID, possible_split_varIDs))
    },

    ## Search all candidate variables for the split with the largest
    ## criterion value in node `nodeID`.
    ## Side effect: writes split_levels_left[[nodeID]] whenever a variable
    ## becomes the current best candidate.
    ## Returns NULL if no valid split was found, otherwise a list with
    ## `varID` (integer) and `value`.
    findBestSplit = function(nodeID, possible_split_varIDs) {
      ## Initialize; a negative varID marks "no split found yet"
      best_split <- list(decrease = -1, varID = -1, value = -1)

      ## Get response
      node_samples <- sampleIDs[[nodeID]]
      response <- data$subset(node_samples, 1)

      ## For all possible variables (seq_along is safe for an empty vector)
      for (i in seq_along(possible_split_varIDs)) {
        split_varID <- possible_split_varIDs[i]
        data_values <- data$subset(node_samples, split_varID)

        ## Handle ordered factors
        if (!is.numeric(data_values) && !is.ordered(data_values) &&
            unordered_factors == "order_split") {
          ## Order factor levels by the mean numeric response per level
          num.response <- as.numeric(response)
          means <- vapply(levels(data_values), function(x) {
            mean(num.response[data_values == x])
          }, FUN.VALUE = numeric(1))
          levels.ordered <- as.character(levels(data_values)[order(means)])

          ## Get all levels not in node
          levels.missing <- setdiff(levels(data_values), levels.ordered)
          levels.ordered <- c(levels.missing, levels.ordered)

          ## Return reordered factor
          data_values <- factor(data_values, levels = levels.ordered,
                                ordered = TRUE)
        }

        ## If still not ordered, use partition splitting
        if (!is.numeric(data_values) && !is.ordered(data_values)) {
          best_split <- findBestSplitValuePartition(split_varID, data_values,
                                                    best_split, response)

          ## Set split levels left
          if (best_split$varID == split_varID) {
            split_levels_left[[nodeID]] <<- best_split$values_left
          }
        } else {
          best_split <- findBestSplitValueOrdered(split_varID, data_values,
                                                  best_split, response)

          ## Set split levels left (empty if ordered splitting)
          if (unordered_factors == "order_split") {
            if (best_split$varID == split_varID) {
              split_levels_left[[nodeID]] <<-
                unique(data_values[data_values <= best_split$value])

              if (is.factor(data_values)) {
                ## Use same splits as in partition: encode the left levels
                ## as a bit pattern over the original level order and swap
                ## sides if the pattern reaches the bit of the largest
                ## level code present in the node.
                original_values <- data$subset(node_samples, split_varID)
                ints <- as.integer(factor(split_levels_left[[nodeID]],
                                          levels = levels(original_values)))
                if (sum(2^(ints - 1)) >=
                    2^(max(as.numeric(original_values)) - 1)) {
                  split_levels_left[[nodeID]] <<-
                    unique(data_values[data_values > best_split$value])
                }
              }
            }
          } else {
            if (best_split$varID == split_varID) {
              split_levels_left[[nodeID]] <<- list()
            }
          }
        }
      }

      if (best_split$varID < 0) {
        ## Stop if no good split found
        return(NULL)
      }

      ## Return best split
      list(varID = as.integer(best_split$varID), value = best_split$value)
    },

    ## Evaluate every threshold split `data_values <= v` for one numeric or
    ## ordered variable and update `best_split` if a candidate beats
    ## `best_split$decrease`.
    ## Returns the (possibly updated) `best_split` list.
    findBestSplitValueOrdered = function(split_varID, data_values, best_split, response) {
      ## For all possible splits
      possible_split_values <- unique(data_values)
      for (j in seq_along(possible_split_values)) {
        split_value <- possible_split_values[j]

        ## Count classes in childs
        idx <- data_values <= split_value
        class_counts_left <- tabulate(response[idx])
        class_counts_right <- tabulate(response[!idx])

        ## Skip if one child empty
        if (sum(class_counts_left) == 0 || sum(class_counts_right) == 0) {
          next
        }

        if (splitrule == "Gini") {
          ## Decrease of impurity
          decrease <- sum(class_counts_left^2) / sum(class_counts_left) +
            sum(class_counts_right^2) / sum(class_counts_right)
        } else {
          stop("Unknown splitrule.")
        }

        ## Use this split if better than before
        if (decrease > best_split$decrease) {
          best_split$value <- split_value
          best_split$varID <- split_varID
          best_split$decrease <- decrease
        }
      }
      return(best_split)
    },

    ## Evaluate all 2-partitions of the levels observed in `data_values`
    ## for one unordered variable and update `best_split` if a candidate
    ## beats `best_split$decrease`. The winning left-hand levels are stored
    ## in `best_split$values_left`.
    ## Returns the (possibly updated) `best_split` list.
    findBestSplitValuePartition = function(split_varID, data_values, best_split, response) {
      ## For all possible splits
      possible_split_values <- sort(unique(data_values))
      num_values <- length(possible_split_values)

      ## For all 2^(n-1)-1 2-partitions; with fewer than two distinct values
      ## there is nothing to enumerate, so the count is clamped at zero.
      num_partitions <- max(2^(num_values - 1) - 1, 0)
      for (j in seq_len(num_partitions)) {
        ## Convert number to logic vector
        left_idx <- as.bitvect(j, length = num_values)
        values_left <- possible_split_values[left_idx]

        ## Count classes in childs
        idx <- data_values %in% values_left
        class_counts_left <- tabulate(response[idx])
        class_counts_right <- tabulate(response[!idx])

        ## Skip if one child empty
        if (sum(class_counts_left) == 0 || sum(class_counts_right) == 0) {
          next
        }

        if (splitrule == "Gini") {
          ## Decrease of impurity
          decrease <- sum(class_counts_left^2) / sum(class_counts_left) +
            sum(class_counts_right^2) / sum(class_counts_right)
        } else {
          stop("Unknown splitrule.")
        }

        ## Use this split if better than before
        if (decrease > best_split$decrease) {
          best_split$values_left <- values_left
          best_split$varID <- split_varID
          best_split$decrease <- decrease
        }
      }
      return(best_split)
    },

    ## Node estimate; always NA for probability trees.
    estimate = function(nodeID) {
      ## Return only NA, value is not used later
      return(NA)
    },

    ## Class fractions of the samples in node `nodeID`: a table of the
    ## response values divided by the number of samples in the node.
    getNodePrediction = function(nodeID) {
      ## Return class fractions
      node_samples <- sampleIDs[[nodeID]]
      table(data$subset(node_samples, 1)) / length(node_samples)
    },

    ## Sum of squared differences between 1 and the probability assigned to
    ## the true class of each out-of-bag sample, divided by the number of
    ## out-of-bag samples.
    predictionError = function(pred = NULL) {
      if (is.null(pred)) {
        pred <- predictOOB()
      }

      ## NOTE(review): `pred` is assigned above but never read below; the
      ## computation uses `probabilities`, which is not defined in this
      ## class. Confirm that it is supplied by the parent class or the
      ## enclosing environment, or whether `pred` was intended here.

      ## Get probabilities for true classes
      response <- as.numeric(data$subset(oob_sampleIDs, 1))
      num_samples <- nrow(probabilities)
      num_classes <- ncol(probabilities)

      ## Logical mask with TRUE at [i, k] when sample i has class code k
      idx <- response == matrix(seq_len(num_classes), byrow = TRUE,
                                nrow = num_samples, ncol = num_classes)
      true_class_probabilities <- probabilities[idx]
      sum((true_class_probabilities - 1)^2, na.rm = TRUE) / length(oob_sampleIDs)
    })
)
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.