R/hhcartr_export_displayTree.R

Defines functions displayTree dispnode navigate_hash

Documented in displayTree dispnode navigate_hash

# source: hhcartr_export_displayTree.R

#########################################################################################
#'
#' navigate_hash generates DOT statements for one tree of the current model.
#'
#' This internal function rebuilds the package DOT statement list for the selected tree:
#' a graph header and title, one statement for the root node, and (via dispnode) the
#' statements for every node below it. The statements are then written to a temporary
#' .gv file and passed to grViz; the temporary file is always removed afterwards.
#'
#' @param hobj List of all tree objects; element numtree is used as the root node.
#' @param numtree The number of the tree to display. It must be one of
#'   1..length(hobj); any other value (including the default NA) raises an error.
#' @param dataset_description A brief description of the dataset being used.
#'
#' @return NULL, invisibly. The generated statements are left in the DOT list
#'   (see get_dot_list()).
#'
navigate_hash <- function(hobj, numtree = NA, dataset_description = "Unknown"){

  # values handed to dispnode for the children of the root node.
  # NOTE(review): edges from the root are drawn from id 0 while the root label uses
  # root$node_objectid - this assumes the root's node_objectid is 0; confirm.
  root_parent_id <- 0
  root_level     <- 1

  # skeleton DOT statements for a splitting node: z[..] is used when the node is flagged
  # node_using_householder, X[..] otherwise.
  label_1X <- "%s [label = \"Node %s\\nX[%s] <= %s\\ngini index = %s\\nsamples = %s\\nvalue = %s\"];"
  label_1z <- "%s [label = \"Node %s\\nz[%s] <= %s\\ngini index = %s\\nsamples = %s\\nvalue = %s\"];"

  # make sure users tree number is within range of trees we have. seq_len() keeps the
  # valid range empty for an empty hobj (seq(1, 0) would be c(1, 0) and let 1 through).
  if(!(numtree %in% seq_len(length(hobj)))){
    msg <- "hhcartr(displayTree) Tree %s requested, only %s trees available."
    msgs <- sprintf(msg, numtree, length(hobj))
    stop(msgs)
  }
  root <- hobj[[numtree]]

  # initialise DOT statement list before we start
  clear_dot_list()

  # add initial DOT statements
  append_dot_list("digraph rmarkdown {")
  # nodesep is the Graphviz attribute for spacing between nodes of the same rank.
  append_dot_list("nodesep = 0.75;")
  # plot title.
  title_dot_stmt <- "graph [label = 'displayTree for Tree %s. Dataset: %s.\\nn_folds = %s, n_trees = %s, n_min = %s,\\nmin_node_impurity = %s, useIdentity = %s,\\nseed = %s.', labelloc='t', fontsize=28];"
  title_dot_stmt_formatted <- sprintf(title_dot_stmt,
                                      numtree,
                                      dataset_description,
                                      pkg.env$n_folds,
                                      pkg.env$n_trees,
                                      pkg.env$n_min,
                                      pkg.env$min_node_impurity,
                                      pkg.env$useIdentity,
                                      pkg.env$seed)
  append_dot_list(title_dot_stmt_formatted)
  append_dot_list("node [shape = box, color = black, fontsize = 12];")

  # restart node numbering for this tree
  zero_objectid_count()

  # build and save the DOT statement for the root node
  if(root$node_using_householder){
    label_1 <- label_1z
  } else {
    label_1 <- label_1X
  }
  label_1_o <- sprintf(label_1, root$node_objectid,
                                root$node_objectid,
                                root$node_feature_index,
                                round(root$node_threshold, 4),
                                round(root$node_gini, 4),
                                root$node_tot_samples,
                                root$node_num_samples_per_class)
  append_dot_list(label_1_o)

  # if we have a LEFT node go run chain of nodes
  if(!root$node_children_left_NA){
    dispnode(root$node_children_left, root_level, root_parent_id)
  }
  # if we have a RIGHT node go run chain of nodes
  if(!root$node_children_right_NA){
    dispnode(root$node_children_right, root_level, root_parent_id)
  }

  # terminate DOT statements
  append_dot_list("}")

  # write every DOT statement to a temporary file. writeLines() opens, writes and closes
  # the file itself, so the content is complete on disk before grViz reads it, and no
  # connection is left open. The file is removed however this function exits.
  file_name <- tempfile(fileext = ".gv")
  on.exit(unlink(file_name), add = TRUE)
  writeLines(unlist(get_dot_list()), file_name)

  # the value returned by grViz is not used here (displayTree builds the graph it returns
  # from the DOT list); an error is re-raised with the file name added to its message.
  tryCatch(grViz(diagram = file_name), error = function(c) {
    c$message <- paste0(c$message, " (in ", file_name, ")")
    stop(c)
  })

  invisible(NULL)
}

#########################################################################################
#'
#' dispnode generates DOT statements for one node and all nodes below it.
#'
#' This internal function numbers the node it is given, appends the node's DOT statement
#' and the edge from its parent to the DOT list, and then calls itself for the left and
#' right child nodes (each one only if present). It is called by function navigate_hash.
#'
#' @param nn current node of the current tree
#' @param level current depth in the current tree (passed down, incremented per level)
#' @param pid the parent-id of the current node in the current tree
#' @return NULL, invisibly
#'
dispnode <- function(nn, level, pid){
  # skeleton DOT statements for a splitting node: z[..] is used when the node is flagged
  # node_using_householder, X[..] otherwise.
  label_1X <- "%s [label = \"Node %s\\nX[%s] <= %s\\ngini index = %s\\nsamples = %s\\nvalue = %s\"];"
  label_1z <- "%s [label = \"Node %s\\nz[%s] <= %s\\ngini index = %s\\nsamples = %s\\nvalue = %s\"];"
  # skeleton DOT statement for a terminal node (no split shown)
  label_2 <- "%s [label=\"Node %s\\ngini index = %s\\nsamples = %s\\nvalue = %s\"] ;"
  # skeleton DOT statement for the edge parent -> node
  label_3 <- "%s -> %s ;"

  # keep count of the nodes in the current tree
  increment_objectid_count()
  # assign current node count as this nodes objectid
  nn$node_objectid <- get_objectid_count()
  # assign this nodes parentid
  nn$node_parentid <- pid

  has_left  <- !nn$node_children_left_NA
  has_right <- !nn$node_children_right_NA

  if(!has_left && !has_right){
    # create DOT statement for terminal node
    node_stmt <- sprintf(label_2,
                         nn$node_objectid,
                         nn$node_objectid,
                         round(nn$node_gini, 4),
                         nn$node_tot_samples,
                         nn$node_num_samples_per_class)
  } else {
    # create DOT statement for internal node
    if(nn$node_using_householder){
      label_1 <- label_1z
    } else {
      label_1 <- label_1X
    }
    node_stmt <- sprintf(label_1,
                         nn$node_objectid,
                         nn$node_objectid,
                         nn$node_feature_index,
                         round(nn$node_threshold, 4),
                         round(nn$node_gini, 4),
                         nn$node_tot_samples,
                         nn$node_num_samples_per_class)
  }
  # add node statement, then the statement showing how the nodes relate
  append_dot_list(node_stmt)
  append_dot_list(sprintf(label_3, nn$node_parentid, nn$node_objectid))

  # depth of the child nodes
  level <- level + 1

  # The two children are handled independently: a missing left child must not stop the
  # right subtree from being emitted.
  if(has_left){
    dispnode(nn$node_children_left, level, nn$node_objectid)
  }
  if(has_right){
    dispnode(nn$node_children_right, level, nn$node_objectid)
  }

  invisible(NULL)
}

#########################################################################################
#'
#' displayTree Display one selected decision tree created from DOT statements.
#'
#' This function displayTree() generates DOT statements for the selected decision tree
#' in the current model. The DOT statements are written to a temporary file in the tmp
#' directory. Function grViz() from the DiagrammeR package is then called to visualize
#' the graph. displayTree() will check to make sure package DiagrammeR is installed and
#' loaded before attempting generation of DOT statements.
#'
#' @param ntree The number of the tree the user wishes to display. Ignored (tree 1 is
#'   used) when rpart_ is supplied.
#' @param rpart_ If specified, it is the rpart() equivalent tree in hhcartr format. The
#'   default, a single NA, means "use the trees of the current model".
#'
#' @return The graph object returned by grViz() for the generated DOT statements.
#'
#' @example man/examples/displayTree.R
#'
#' @export
displayTree <- function(ntree = 1, rpart_ = NA){
  # rpart_ counts as "not supplied" only when it is a single atomic NA. Calling is.na()
  # directly on a tree object or list gives one value per element (or none), which is
  # not a valid if() condition.
  rpart_not_supplied <- is.atomic(rpart_) && length(rpart_) == 1L && is.na(rpart_)

  if(rpart_not_supplied){
    allres <- pkg.env$folds_trees
  } else {
    allres <- rpart_
    ntree <- 1
  }
  # allres is a list(); make sure there is at least one tree before proceeding...
  if(length(allres) == 0){
    stop("hhcartr(displayTree) no trees found. Run fit() or rpartTree() first.")
  }
  # ensure package DiagrammeR is installed and loaded.
  packages <- c("DiagrammeR")
  check_package(packages)

  # get the dataset description
  dataset_description <- pkg.env$model_data_description
  # go generate DOT statements
  navigate_hash(allres, ntree, dataset_description)

  return(grViz(unlist(get_dot_list())))
}

Try the hhcartr package in your browser

Any scripts or data that you put into this service are public.

hhcartr documentation built on July 2, 2021, 9:06 a.m.