Nothing
get.tree.rfsrc <- function(object,
tree.id,
target,
m.target = NULL,
time,
surv.type = c("mort", "rel.freq", "surv", "years.lost", "cif", "chf"),
class.type = c("bayes", "rfq", "prob"),
ensemble = FALSE,
oob = TRUE,
show.plots = TRUE,
do.trace = FALSE)
{
##----------------------------------------------------------------
##
## tolerance value for dealing with float issues related to
## splitting left at an exact split value
##
##----------------------------------------------------------------
tolerance <- sqrt(.Machine$double.eps)
##----------------------------------------------------------------
##
## The following two utilities were copied from the BMS CRAN
## package. Thanks to Martin Feldkircher and Stefan Zeugne for
## these little quickies.
##
##----------------------------------------------------------------
hex2bin <- function (hexcode)
{
if (!is.character(hexcode))
stop("please input a character like '0af34c'")
hexcode <- paste("0", tolower(hexcode), sep = "")
hexobj <- .hexcode.binvec.convert(length(hexcode) * 16L)
return(hexobj$as.binvec(hexcode))
}
.hexcode.binvec.convert <- function (length.of.binvec)
{
if (length(length.of.binvec) > 1)
length.of.binvec = length(length.of.binvec)
addpositions = 4 - length.of.binvec%%4
positionsby4 = (length.of.binvec + addpositions)/4
hexvec = c(0:9, "a", "b", "c", "d", "e", "f")
hexcodelist = list(`0` = numeric(4),
`1` = c(0, 0, 0, 1),
`2` = c(0, 0, 1, 0),
`3` = c(0, 0, 1, 1),
`4` = c(0, 1, 0, 0),
`5` = c(0, 1, 0, 1),
`6` = c(0, 1, 1, 0),
`7` = c(0, 1, 1, 1),
`8` = c(1, 0, 0, 0),
`9` = c(1, 0, 0, 1),
a = c(1, 0, 1, 0),
b = c(1, 0, 1, 1),
c = c(1, 1, 0, 0),
d = c(1, 1, 0, 1),
e = c(1, 1, 1, 0),
f = c(1, 1, 1, 1))
return(list(
as.hexcode = function(binvec) {
incl = c(numeric(addpositions), binvec)
dim(incl) = c(4, positionsby4)
return(paste(hexvec[crossprod(incl, 2L^(3:0)) + 1], collapse = ""))
},
as.binvec = function(hexcode) {
return(unlist(
hexcodelist[unlist(strsplit(hexcode, "", fixed = TRUE),
recursive = FALSE, use.names = FALSE)],
recursive = FALSE, use.names = FALSE)[-(1:addpositions)])
}))
}
##----------------------------------------------------------------
##
## coherence checks
##
##----------------------------------------------------------------
if (sum(inherits(object, c("rfsrc", "grow"), TRUE) == c(1, 2)) != 2) {
stop("This function only works for objects of class `(rfsrc, grow)'")
}
if (is.forest.missing(object)) {
stop("forest is missing. Re-run rfsrc (grow call) with forest=TRUE")
}
if (inherits(object, "anonymous")) {
## anonymous <- TRUE
stop("get.tree does not currently work with anonymous forests\n")
}
else {
anonymous <- FALSE
}
## only one tree allowed
tree.id <- tree.id[1]
## must be integer from 1...ntree
if (tree.id < 1 || tree.id > object$ntree) {
stop("tree id must be integer from 1 to ntree")
}
## ensure coherency of the multivariate target
m.target <- get.univariate.target(object, m.target)
##----------------------------------------------------------------
##
##
## take care of some preliminary processing for prediction
## this will be used for labels on terminal nodes
##
##
##----------------------------------------------------------------
## process the object depending on the underlying family
family <- object$forest$family
## coherence for prediction
predict.flag <- TRUE
if (family == "unsupv" || family == "surv-TDC") {
predict.flag <- FALSE
}
if (predict.flag) {
## survival and competing risk families.
if (grepl("surv", family)) {
## extract event information
event.info <- object$event.info
cens <- event.info$cens
event.type <- event.info$event.type
## assign time if missing
if (missing(time)) {
time <- median(event.info$time.interest, na.rm = TRUE)
}
## check for single time point
if (length(time) > 1) {
stop("time must be a single value: ", time)
}
## competing risk
if (family == "surv-CR") {
if (missing(target)) {
target <- 1
}
else {
if (target < 1 || target > max(event.type, na.rm = TRUE)) {
stop("'target' is specified incorrectly")
}
}
## set the surv.type
surv.type <- setdiff(surv.type, c("mort", "rel.freq", "surv"))[1]
pred.type <- match.arg(surv.type, c("years.lost", "cif", "chf"))
}
## survival
else {
target <- 1
## set the surv.type
surv.type <- setdiff(surv.type, c("years.lost", "cif", "chf"))[1]
pred.type <- match.arg(surv.type, c("rel.freq", "mort", "chf", "surv"))
}
}
## univariate and multivariate families.
else {
## assign a null time value
event.info <- time <- NULL
## factor outcome
if(is.factor(coerce.multivariate(object, m.target)$yvar)) {
object.yvar.levels <- levels(coerce.multivariate(object, m.target)$yvar)
pred.type <- match.arg(class.type, c("bayes", "rfq", "prob"))
if (missing(target)) {
target <- object.yvar.levels[1]
}
if (is.character(target)) {
target <- match(match.arg(target, object.yvar.levels), object.yvar.levels)
}
else {
if ((target > length(object.yvar.levels)) | (target < 1)) {
stop("target is specified incorrectly:", target)
}
}
}
## regression or unsupervised
else {
pred.type <- "y"
target <- NULL
}
}
}
##----------------------------------------------------------------
##
## define the target subset of cases
## if ensemble=TRUE --> this is the entire data
## if ensemble=FALSE --> this is inbag or oob cases
##
##----------------------------------------------------------------
if (ensemble) {
subset <- 1:object$n
}
## restrict the object to the target
## oob not allowed
else {
object <- predict(object, get.tree = tree.id, membership = TRUE, do.trace = do.trace)
subset <- which(object$inbag[, tree.id] != 0)
oob <- FALSE
}
##----------------------------------------------------------------
##
## extract x data for the target subset
## convert the data to numeric mode, apply the na.action protocol
## missing data not allowed
##
##----------------------------------------------------------------
xvar.names <- object$forest$xvar.names
xvar.factor <- object$forest$xvar.factor
if (!anonymous) {
x.data <- object$forest$xvar
if (any(is.na(x.data))) {
stop("missing data not allowed")
}
x.data <- finalizeData(xvar.names, x.data, miss.flag = FALSE)
x.data <- x.data[subset,, drop = FALSE]
}
##----------------------------------------------------------------
##
## now acquire the predicted values for the tree labels
##
##----------------------------------------------------------------
if (predict.flag) {
if (ensemble) {
object <- predict.rfsrc(object, m.target = m.target, do.trace = do.trace)
}
yhat <- extract.pred(object, pred.type, subset, time, m.target, target, oob = oob)
}
##----------------------------------------------------------------
##
## get tree data
##
##----------------------------------------------------------------
## pull the arrays
native.array <- object$forest$nativeArray
native.f.array <- object$forest$nativeFactorArray[[1]]
## added processing needed for factors
f.ctr <- 0
factor.flag <- FALSE
if (!is.null(native.f.array)) {
pt.f <- which(native.array$mwcpSZ != 0)
factPT <- lapply(pt.f, function(j) {
f.ctr <<- f.ctr + 1
step <- native.array$mwcpSZ[j] - 1
mwcpPT <- native.f.array[f.ctr:(f.ctr+step)]
mwcpPT <- paste0(sapply(mwcpPT, function(mwc) {
format(as.hexmode(mwcpPT), 8)
}))
mwcpSZ <- hex2bin(mwcpPT)
paste(mwcpSZ, collapse = "")
})
native.array$contPT[pt.f] <- factPT
factor.flag <- TRUE
}
## define the display tree
display.tree <- native.array[native.array$treeID == tree.id,, drop = FALSE]
## check to see if any factors are left
## store relevant informatio for later split-inequality encodings
if (factor.flag) {
pt.f <- display.tree$mwcpSZ !=0
if (sum(pt.f) > 0) {
f.names <- unique(xvar.names[display.tree$parmID[pt.f]])
}
else {
factor.flag <- FALSE
}
}
##----------------------------------------------------------------
##
## prepare the tree to be converted into a network
##
##----------------------------------------------------------------
## conversion
converted.tree <- display.tree
vars.id <- data.frame(var = c("<leaf>", xvar.names), parmID = 0:length(xvar.names), stringsAsFactors = FALSE)
converted.tree$var <- vars.id$var[match(display.tree$parmID, vars.id$parmID)]
## special symbol to be used for encoding the counter for variables (see note below)
special <- "999_999"
# note: we append a counter to the variables, because the data.tree package has trouble when
# nodes are not unique.
var.count <- 1:nrow(converted.tree)
lapply(unique(converted.tree$var), function(vv) {
pt <- converted.tree$var == vv
var.count[which(pt)] <<- 1:sum(pt)
})
converted.tree$var_count <- var.count
converted.tree$var_conc <- paste0(converted.tree$var, special, converted.tree$var_count)
##----------------------------------------------------------------
##
## convert the tree to a network data frame, using the fact that the
## nativeArray output is a pre-order traversal
##
##----------------------------------------------------------------
## preliminary
from_node <- ""
network <- data.frame()
num.children <- data.frame(converted.tree, children = 0)
num.children <- num.children[num.children$var != "<leaf>",, drop = FALSE]
num.children <- num.children[!duplicated(num.children$var_conc),, drop = FALSE]
num_children <- as.list(rep(0, nrow(num.children)))
names(num_children) <- num.children$var_conc
## loop (using lapply)
lapply(1:nrow(converted.tree), function(i) {
rowi <- converted.tree[i, ]
xs <- converted.tree$contPT[converted.tree$var_conc == from_node]
if(i == 1){
from_node <<- rowi$var_conc
}
else{
## develop the split encoding
if(num_children[[from_node]] == 0) {#left split
side <- "<="
contPT.pretty <- round(as.numeric(xs), 3)
split_ineq_pretty <- paste0(side, contPT.pretty)
}
else {#right split
side <- ">"
split_ineq_pretty <- ""
}
## both numeric and factors are encoded as <= > but factors are secretely in hex notation
## !!ADD MACHINE TOLERANCE WHEN SPLIT IS ON ACTUAL X-VALUE !!
if (is.numeric(xs)) {
xs <- xs + tolerance
}
split_ineq <- paste0(side, xs)
## update the network
to_node <- rowi$var_conc
new_node <- list(from = from_node, to = to_node, split = split_ineq, split.pretty = split_ineq_pretty)
network <<- data.frame(rbind(network, new_node, stringsAsFactors = FALSE))
num_children[[from_node]] <<- num_children[[from_node]] + 1
if(rowi$var != "<leaf>")
from_node <<- to_node
else{
if(i != nrow(converted.tree)){
while(num_children[[from_node]] == 2){
from_node <<- network$from[network$to == from_node]
}
}
}
}
})
##----------------------------------------------------------------
##
## process network for factors - clean up the split encoding
## encode as complementary pair string
##
##----------------------------------------------------------------
if (factor.flag) {
## identify which splits need to be cleaned up
from.names <- network$from
pt.f <- sort(unique(unlist(lapply(f.names, function(ptrn) {
grep(ptrn, from.names)
}))))
## identify levels of the factor
## otherwise we would have huge sets full of levels that aren't used
fs <- gsub(paste0(special,".*"),"",from.names[pt.f])
fs.levels <- sapply(fs, function(fsn) {
#length(levels(x.org.data[, fsn]))
xvar.factor$nlevels[which(fsn == xvar.names)]
})
## clean the splits up and encode as complementary pair sets
split.str <- lapply(1:length(pt.f), function(j) {
str <- network$split[pt.f[j]]
## left split
if (grepl("<=", str)) {
str <- sub("<=", "", str)
str <- strsplit(str, "")[[1]]
cpr <- 1 + length(str) - which(str != "0")
cpr <- cpr[cpr <= fs.levels[j]]
paste0("{", paste(cpr, collapse = ","), "}")
}
## right split
else {
str <- sub(">", "", str)
str <- strsplit(str, "")[[1]]
cpr <- 1 + length(str) - which(str == "0")
cpr <- cpr[cpr <= fs.levels[j]]
paste0("{", paste(cpr, collapse = ","), "}")
}
})
## pretty complementary pairs
split.str.pretty <- lapply(1:length(pt.f), function(j) {
str <- network$split[pt.f[j]]
## left split
if (grepl("<=", str)) {
str <- sub("<=", "", str)
str <- strsplit(str, "")[[1]]
cpr <- 1 + length(str) - which(str != "0")
cpr <- cpr[cpr <= fs.levels[j]]
paste0("{", paste(cpr, collapse = ","), "}")
}
else {
""
}
})
## replace the previous fake encoding with the now correct "set encoding"
network$split[pt.f] <- split.str
network$split.pretty[pt.f] <- split.str.pretty
}
##----------------------------------------------------------------
##
## create a tree object, see the data.tree package
##
##----------------------------------------------------------------
data.tree.network <- data.tree::FromDataFrameNetwork(network, "split")
## INTERNAL NODES
## label the edges with the splits, and relabel the nodes so that the appended counters are not visible
data.tree.network$Get(function(node) {
data.tree::SetEdgeStyle(node, color = "black", label = node$split.pretty, fontcolor = "black")
data.tree::SetNodeStyle(node, color = "black", fontcolor = "black", penwidth = 3,
label = strsplit(node$name, special)[[1]][1])
})
## TERMINAL NODES
## loop through the leaves of the tree, each time getting the path down to the leaf, and using this path
## to establish a filtering condition to obtain all cases that match the splitting on the way to this leaf.
## we filter out cases to correspond to the leaf, and then modify the leaf display to include the number
## of cases, followed by the user requested predicted value
lapply(data.tree.network$leaves, function(node) {
path_list <- node$path
var_list <- sapply(path_list, function(x){strsplit(x, special)[[1]][1]})
var_list[length(var_list)] <- ""
node_iter <- data.tree.network
## make boolean string operatore
call <- lapply(2:(length(path_list)), function(i) {
node_iter <<- node_iter[[path_list[[i]]]]
str <- node_iter$split
## numeric boolean operator
if (!any(grepl("\\{", str))) {
str <- paste0("x.data$", paste0(var_list[i-1], str))
}
## complementary pair boolean operator
else {
str <- gsub("\\{", "", str)
str <- gsub("\\}", "", str)
str <- strsplit(str, ",")[[1]]
str <- paste("==", str, sep = "")
str <- paste0("(",paste(paste0("x.data$", var_list[i-1], str), collapse = "|"),")")
}
str
})
call <- paste(call, collapse = " & ")
## evaluate the boolean operator
## this yields the id's for the cases in the node
if (!anonymous) {
pt <- which(eval(parse(text=call)))
n.cases <- length(pt)
}
## set the edgelabel
edge.label <- node$split.pretty
## set the node label
## extract the predicted value in the node
if (predict.flag && !anonymous) {
if (!is.factor(yhat)) {
yhat <- round(mean(yhat[pt], na.rm = TRUE), 2)
node.label <- paste0("n=", n.cases, "\n", yhat)
}
else {
frqtable <- tapply(yhat[pt], yhat[pt], length)
pred <- names(frqtable)[which.max(frqtable)]
node.label <- paste0("n=", n.cases, "\n", pred)
}
}
## unsupervised family --> no predicted value
else if (!predict.flag && !anonymous) {
node.label <- paste0("n=", n.cases)
}
## anonymous
else {
node.label <- NULL
}
## set styles
data.tree::SetGraphStyle(node, rankdir = "TB")
data.tree::SetEdgeStyle(node, arrowhead = "vee", color = "grey35",
penwidth = 3, label = edge.label)
data.tree::SetNodeStyle(node, style = "filled,rounded", shape = "box",
fillcolor = "GreenYellow", penwidth = 3,
fontname = "helvetica", tooltip = data.tree::GetDefaultTooltip,
label = node.label)
})
invisible(data.tree.network)
}
get.tree <- get.tree.rfsrc
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.