utils::globalVariables(c("node", "Y", "p", "variable", "decImp", "splitLabel", "ID", "index", "Wt", "prob")) # to avoid CRAN check errors for tidyverse programming
#' Explainable Ensemble Tree
#'
#' It creates an explainable tree for a Random Forest. Explainable Ensemble Trees (E2Tree) aim to generate a “new tree” that explains and represents the relational structure between the response variable and the predictors. This leads to a tree structure similar to the one obtained for a decision tree, while exploiting the advantages of a dendrogram-like output.
#'
#' @param formula is a formula describing the model to be fitted, with a response but no interaction terms.
#' @param data a data frame containing the variables in the model; the variables named in the formula are interpreted in this data frame.
#' @param D is the dissimilarity matrix, measuring the discordance between two observations with respect to a given classifier of a random forest model. It is obtained with the \link{createDisMatrix} function.
#' @param ensemble is an ensemble tree object (at the moment, only random forest objects are supported).
#' @param setting is a list containing the set of stopping rules for the tree building procedure.
#' \tabular{lll}{
#' \code{impTotal}\tab \tab The threshold for the impurity in the node\cr
#' \code{maxDec}\tab \tab The threshold for the decrease in node impurity: a split is retained only if it decreases the impurity by more than this value\cr
#' \code{n}\tab \tab The minimum number of observations in the node \cr
#' \code{level}\tab \tab The maximum depth of the tree (levels) \cr}
#' Default is \code{setting=list(impTotal=0.1, maxDec=0.01, n=2, level=5)}.
#'
#' @return An \code{e2tree} object, which is a list with the following components:
#' \tabular{lll}{
#' \code{tree}\tab \tab A data frame representing the main structure of the tree, aimed at explaining and graphically representing the relationships and interactions between the variables used by the ensemble method. \cr
#' \code{call}\tab \tab The matched call\cr
#' \code{terms}\tab \tab A list of terms and attributes \cr
#' \code{control}\tab \tab A list containing the set of stopping rules for the tree building procedure \cr
#' \code{varimp}\tab \tab A list containing a table and a plot of variable importance. Variable importance is a quantitative measure of the contribution of each variable to the model's predictions: it quantifies the influence each variable has on the model's overall performance, provides insight into the relative significance of the variables in explaining the observed outcomes, and aids in understanding the underlying relationships and dynamics within the model. \cr}
#'
#' @examples
#' \donttest{
#' ## Classification:
#' data(iris)
#'
#' # Create training and validation set:
#' smp_size <- floor(0.75 * nrow(iris))
#' train_ind <- sample(seq_len(nrow(iris)), size = smp_size)
#' training <- iris[train_ind, ]
#' validation <- iris[-train_ind, ]
#' response_training <- training[,5]
#' response_validation <- validation[,5]
#'
#' # Perform training:
#' ensemble <- randomForest::randomForest(Species ~ ., data = training,
#'                                        importance = TRUE, proximity = TRUE)
#'
#' D <- createDisMatrix(ensemble, data = training, label = "Species",
#'                      parallel = list(active = FALSE, no_cores = 1))
#'
#' setting <- list(impTotal = 0.1, maxDec = 0.01, n = 2, level = 5)
#' tree <- e2tree(Species ~ ., training, D, ensemble, setting)
#'
#' ## Regression
#' data("mtcars")
#'
#' # Create training and validation set:
#' smp_size <- floor(0.75 * nrow(mtcars))
#' train_ind <- sample(seq_len(nrow(mtcars)), size = smp_size)
#' training <- mtcars[train_ind, ]
#' validation <- mtcars[-train_ind, ]
#' response_training <- training[,1]
#' response_validation <- validation[,1]
#'
#' # Perform training
#' ensemble <- randomForest::randomForest(mpg ~ ., data = training, ntree = 1000,
#'                                        importance = TRUE, proximity = TRUE)
#'
#' D <- createDisMatrix(ensemble, data = training, label = "mpg",
#'                      parallel = list(active = FALSE, no_cores = 1))
#'
#' setting <- list(impTotal = 0.1, maxDec = 1e-6, n = 2, level = 5)
#' tree <- e2tree(mpg ~ ., training, D, ensemble, setting)
#'
#' }
#'
#' @export
e2tree <- function(formula, data, D, ensemble,
                   setting = list(impTotal = 0.1, maxDec = 0.01, n = 2, level = 5)) {

  # === Input validation ===
  # Validate formula
  if (!inherits(formula, "formula")) {
    stop("'formula' must be a valid formula object.")
  }
  # Validate data
  if (!is.data.frame(data) || nrow(data) == 0) {
    stop("'data' must be a non-empty data frame.")
  }
  # Validate D (dissimilarity matrix)
  if (!is.matrix(D) || nrow(D) != ncol(D)) {
    stop("'D' must be a square dissimilarity matrix.")
  }
  # Validate ensemble
  if (!inherits(ensemble, "randomForest")) {
    stop("'ensemble' must be a trained 'randomForest' model.")
  }
  # Validate ensemble type
  if (!ensemble$type %in% c("classification", "regression")) {
    stop("'type' in the ensemble object must be either 'classification' or 'regression'.")
  }
  # Validate setting
  if (!is.list(setting) || !all(c("impTotal", "maxDec", "n", "level") %in% names(setting))) {
    stop("'setting' must be a list with elements 'impTotal', 'maxDec', 'n', and 'level'.")
  }
  # Ensure all setting values are positive numbers
  if (!all(sapply(setting, is.numeric)) || any(unlist(setting) <= 0)) {
    stop("All values in 'setting' must be positive numeric values.")
  }
  # === Proceed with the function ===
  row.names(data) <- NULL

  Call <- match.call()
  mf <- match.call(expand.dots = FALSE)
  m <- match(c("formula", "data"), names(mf), 0L)
  mf <- mf[c(1L, m)]
  mf$drop.unused.levels <- TRUE
  mf[[1L]] <- quote(stats::model.frame)
  mf <- eval(mf, parent.frame())
  Terms <- attributes(mf)$terms
  response <- mf[, 1]
  X <- mf[, -1]

  type <- ensemble$type

  setting$tMax <- 1
  for (i in 1:setting$level) setting$tMax <- setting$tMax * 2 + 1
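  # tMax is the number of nodes in a full binary tree of depth 'level':
  # 2^(level + 1) - 1 (e.g., the default level = 5 yields at most 63 nodes).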
  ## Identify qualitative variables and their number of categories:
  # Determine the class of each variable in X
  var_classes <- unlist(lapply(X, class))
  # Identify indices of factor and character variables
  ind <- which(var_classes %in% c("factor", "character"))
  # Count the unique categories of each factor and character variable
  ncat <- NULL
  if (length(ind) > 0) {
    for (i in 1:length(ind)) {
      ncat[i] <- length(unique(X[, ind[i]]))
    }
    names(ncat) <- names(X)[ind]
  }
  # Update var_classes with the number of categories of the factor and
  # character variables, and set all other variable types to -1
  var_classes[names(ncat)] <- ncat
  var_classes[!names(var_classes) %in% names(ncat)] <- -1
  # Convert var_classes values to numeric, keeping the names of X
  var_classes <- as.numeric(var_classes)
  names(var_classes) <- names(X)

  ## Generate the split matrix S from the original predictor matrix X
  res <- split(X)
  S <- res$S
  l <- res$lab
  rm(res)
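  # Assumed layout (inferred from its use below): S is an n x q 0/1 matrix
  # over the q candidate binary splits, with S[i, j] = 1 when observation i
  # satisfies split j and is therefore sent to the left child; 'l' holds the
  # splitting variable of each column, and the split labels take the form
  # "x <= v" or "x %in% c(...)".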
  # Initialize the node assignment vector (all observations start in the root)
  nodes <- rep(1, nrow(S))
  # List of non-terminal nodes still to be processed
  nterm <- 1
  m <- unique(response)
if (type == "classification") {
m_label <- paste("V", seq(1,length(m)*2), sep="")
labels <- c("node","n", "pred", "prob", "impTotal", "impChildren","decImp","decImpSur","variable" ,"split", "splitLabel", "variableSur", "splitLabelSur","parent","children","terminal", "obs", "path", "ncat",m_label)
info <- data.frame(matrix(NA,setting$tMax,length(labels)))
names(info) <- labels
indv <- (ncol(info)-length(m_label)+1):ncol(info)
} else {
labels <- c("node","n", "pred", "prob", "impTotal", "impChildren","decImp","decImpSur","variable" ,"split", "splitLabel", "variableSur", "splitLabelSur","parent","children","terminal", "obs", "path", "ncat")
info <- data.frame(matrix(NA,setting$tMax,length(labels)))
names(info) <- labels
}
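  # For classification, the extra columns V1..V(2m) store each node's
  # per-class counts followed by the per-class proportions (m = number of
  # classes); they are filled column-major by the yval computation below.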
  N <- NULL
  ### Variance of the response in the root node (regression only)
  vart1 <- ifelse(type == "regression", variance(response), NA)
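  # vart1 is handed to eStoppingRules() at each node; presumably it lets the
  # stopping rules relate a node's impurity to the overall response variance
  # in the regression case (it is NA for classification).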
  while (length(nterm) > 0) {
    t <- tail(nterm, 1)
    N <- c(N, t)
    index <- which(nodes == t)

    ### Check stopping rules ----
    results <- eStoppingRules(D, index, t, setting, response, ensemble, vart1)

    ### Compute the response in the node
    switch(type,
           classification = {
             res <- moda(response[index])
           },
           regression = {
             # node mean and sum of squared deviations from the ensemble predictions
             res <- c(mean(response[index]),
                      sum((response[index] - ensemble$predicted[index])^2))
           })

    info$pred[t] <- res[1]             # mean for regression, modal class for classification
    info$prob[t] <- as.numeric(res[2]) # squared-error term for regression
    info$node[t] <- t
    info$parent[t] <- floor(t / 2)
    info$n[t] <- length(index)
    info$impTotal[t] <- results$impTotal

    ##### Per-class statistics (counts and proportions)
    if (type == "classification") {
      yval <- data.frame(Y = response[index]) %>%
        mutate(Y = factor(Y, levels = levels(m))) %>%
        group_by(Y) %>%
        summarize(n = n(),
                  p = n / length(response[index])) %>%
        complete(Y, fill = list(n = 0, p = 0)) %>%
        select(n, p) %>%
        as.matrix() %>%
        c()
      info[t, indv] <- yval
    }
    if (isTRUE(results$sRule)) {
      # A stopping rule fired: close the node
      info$terminal[t] <- TRUE
      info$obs[t] <- list(index)
      info$path[t] <- paths(info, t)
    } else {
      imp <- eImpurity(D, index, S)
      s <- which.min(imp)
      v <- gsub(" <=.*| %in%.*", "", names(s))
      # Surrogate split: the best split on a different variable
      s2 <- which.min(imp[!(l %in% v)])
      # Decrease of impurity of the best and surrogate splits
      decImp <- results$impTotal - imp[s]
      decImpSur <- results$impTotal - imp[s2]

      # Check the maximum-impurity-decrease stopping rule
      if (decImp <= setting$maxDec) {
        info$terminal[t] <- TRUE
        info$obs[t] <- list(index)
        info$path[t] <- paths(info, t)
      } else if (suppressWarnings(Wtest(Y = response[index], X = S[index, s],
                                        p.value = 0.05, type = ensemble$type))) {
        # Mann-Whitney stopping rule for the regression case: the split is
        # accepted only if the hypothesis that the response distributions in
        # the two children are equal is rejected (always TRUE for classification)
        info$impChildren[t] <- as.numeric(imp[s])
        info$decImp[t] <- as.numeric(decImp)
        info$decImpSur[t] <- as.numeric(decImpSur)
        info$split[t] <- as.numeric(s)
        info$splitLabel[t] <- names(s)
        info$splitLabelSur[t] <- names(s2)
        info$terminal[t] <- FALSE
        info$parent[t] <- floor(t / 2)
        info$children[t] <- list(c(t * 2, t * 2 + 1))
        info$impTotal[t] <- results$impTotal
        info$obs[t] <- list(index)
        info$path[t] <- paths(info, t)
        info$variable[t] <- gsub(" <=.*| %in%.*", "", info$splitLabel[t])
        info$ncat[t] <- var_classes[info$variable[t]]
        # Send each observation to the left (2t) or right (2t + 1) child:
        # S[index, s] == 1 satisfies the split and goes left
        nodes[index] <- (nodes[index] * 2 + 1) - S[index, s]
        nterm <- c(nterm, sort(unique(nodes[index]), decreasing = TRUE))
      } else {
        info$terminal[t] <- TRUE
        info$obs[t] <- list(index)
        info$path[t] <- paths(info, t)
      }
    }
    nterm <- setdiff(nterm, t)
  }
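  # At this point 'N' records the order in which nodes were visited (the tree
  # is grown depth first, left child before right sibling, since children are
  # appended in decreasing order and tail() pops the last element); it is
  # used below to reorder the rows of 'info'.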
  info <- info %>% drop_na(node)
  info$pred_val <- as.numeric(factor(info$pred))
  info$variableSur <- gsub(" <=.*| %in%.*", "", info$splitLabelSur)

  if (type == "classification") {
    yval2 <- as.matrix(info[, indv])
    ypred <- info$pred_val
    nodeprob <- info$n / first(info$n)
    info <- info[, -indv]
    yval2 <- cbind(ypred, yval2)
    attr(yval2, "dimnames")[[2]] <- paste("V", seq(ncol(yval2)), sep = "")
    info$yval2 <- cbind(yval2, nodeprob)
  }

  ylevels <- as.character(unique(response))
  row.names(info) <- info$node
  info <- info[as.character(N), ]

  ## Normalized variance in the nodes (regression only)
  if (type == "regression") {
    info$Wt <- NULL
    maxvar <- diff(range(response))^2L / 9L
    size <- length(response)
    for (i in info$node) {
      indice <- unlist(eval(parse(text = info$obs[info$node == i])))
      v <- 1 - (variance(response[indice]) / (maxvar * length(indice) / size))
      info$Wt[info$node == i] <- ifelse(v < 0, 0, v)
    }
    info <- info %>% relocate(Wt, .after = prob)
  }

  object <- csplit_str(info, X, ncat, call = Call, terms = Terms,
                       control = setting, ylevels = ylevels)
  # object$varimp <- vimp(object, data = data)  # NOTE: currently broken, kept disabled
  object$N <- N

  class(object) <- c("list", "e2tree")
  return(object)
}
paths <- function(info, t) {
  # Reconstruct the logical path from the root to node t by walking up the
  # parent chain; right children negate ("!") their parent's split condition
  path <- NULL
  while (t[1] > 1) {
    tp <- floor(t[1] / 2)
    symb <- ifelse((t[1] %% 2) == 0, "", "!")
    t <- c(tp, t)
    path <- c(paste(symb, info$splitLabel[tp], sep = ""), path)
  }
  path <- paste(path, sep = "", collapse = " & ")
  return(path)
}
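## Illustrative (hypothetical) example: if node 1 splits on "x1 <= 3" and
## node 2 on "x2 %in% c('a')", then node 5, the right child of node 2, gets
## the path "x1 <= 3 & !x2 %in% c('a')".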
### Creation of the objects needed by rpart.plot ----
csplit_str <- function(info, X, ncat, call, terms, control, ylevels) {
  var_lev <- attribute <- list()
  for (i in names(ncat)) {
    var_lev[[i]] <- seq(ncat[i])
    names(var_lev[[i]]) <- unique(X[, i])
    attribute[[i]] <- as.character(unique(X[, i]))
  }

  # Splits actually used in the tree
  splits <- info %>%
    dplyr::filter(!is.na(variable)) %>%
    dplyr::select(n, ncat, variable, decImp, splitLabel)

  # Index for numerical predictors: the split threshold
  if (nrow(info) > 1) {
    varnumerics <- strsplit(splits$splitLabel[splits$ncat == -1], "<=")
    splits$index[splits$ncat == -1] <- unlist(lapply(varnumerics, function(x) x[2]))
  }
  # Index for categorical predictors: the corresponding row of csplit
  splits$index[splits$ncat != -1] <- seq(sum(splits$ncat != -1))
  splits$index <- as.numeric(splits$index)

  catsplits <- splits %>%
    dplyr::filter(ncat != -1) %>%
    dplyr::mutate(ID = row_number()) %>%
    dplyr::group_by(ID) %>%
    dplyr::mutate(modal = gsub(paste0(variable, " %in% "), "", splitLabel))

  splits <- splits %>%
    dplyr::select(n, ncat, decImp, index) %>%
    dplyr::rename(improve = decImp,
                  count = n) %>%
    dplyr::mutate(adj = 0) %>%
    as.matrix()
  attr(splits, "dimnames")[[1]] <- info$variable[!is.na(info$variable)]

  ### Creation of csplit
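  # The csplit matrix follows rpart's coding for categorical splits: one row
  # per split and one column per category level, with 1 meaning the level
  # goes to the left branch, 3 to the right branch, and 2 that it is unused.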
  if (nrow(catsplits) > 0) {
    csplit <- matrix(2, nrow(catsplits), max(catsplits$ncat))
    for (i in 1:nrow(csplit)) {
      modal <- eval(parse(text = catsplits$modal[i]))
      vec <- var_lev[[catsplits$variable[i]]]
      ind <- vec[modal]
      csplit[i, ind] <- 1
      ind <- vec[setdiff(names(vec), modal)]
      csplit[i, ind] <- 3
    }
    csplit <- as.matrix(csplit)
  } else {
    csplit <- NULL
  }

  object <- list(tree = info, csplit = csplit, splits = splits, call = call,
                 terms = terms, control = control)
  attr(object, "xlevels") <- attribute
  attr(object, "ylevels") <- ylevels
  return(object)
}
## Variance
variance <- function(x) {
  # Population variance (divides by n, not n - 1)
  sum((x - mean(x))^2) / length(x)
}
Wtest <- function(Y, X, type, p.value = 0.05) {
  # Classification always passes; for regression, a Wilcoxon (Mann-Whitney)
  # test must reject equality of the response distributions in the two
  # candidate children at the given significance level
  switch(type,
         "classification" = { resp <- TRUE },
         "regression"     = { resp <- stats::wilcox.test(Y ~ X)$p.value <= p.value }
  )
  return(resp)
}