R/xgb.plot_Function.R

#' Plot the trees of an xgboost model as a DiagrammeR graph
#'
#' Flattens the model's trees with `xgb.model.dt.tree()`, turns every row into
#' a graph node and every non-leaf row into two edges (one to its `Yes` child,
#' one to its `No` child), and renders the result with
#' `DiagrammeR::render_graph()`.
#'
#' @param feature_names Passed unchanged to `xgb.model.dt.tree()`.
#' @param model An object inheriting from class `"xgb.Booster"`; anything else
#'   (including the default `NULL`) raises an error.
#' @param n_first_tree Passed unchanged to `xgb.model.dt.tree()`.
#' @param plot_width,plot_height Passed as `width` / `height` to
#'   `DiagrammeR::render_graph()`.
#' @param ... Currently ignored.
#' @return The value returned by `DiagrammeR::render_graph()`.
xgb.plot <- function(feature_names = NULL, model = NULL, n_first_tree = NULL,
                     plot_width = NULL, plot_height = NULL, ...) {
  # inherits() copes with multi-class objects, where `class(x) != "..."`
  # yields a length > 1 condition that `if` rejects (R >= 4.2).
  if (!inherits(model, "xgb.Booster")) {
    stop("model: Has to be an object of class xgb.Booster model generated by the xgb.train function.")
  }
  # Check for the optional dependency without attaching it; every DiagrammeR
  # function below is called with an explicit `DiagrammeR::` prefix.
  if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
    stop("DiagrammeR package is required for xgb.plot", call. = FALSE)
  }

  trees <- xgb.model.dt.tree(feature_names = feature_names, model = model,
                             n_first_tree = n_first_tree)
  ids <- trees[["ID"]]
  feature <- trees[["Feature"]]
  # NA-safe masks: a row with an NA feature is treated as neither a leaf
  # nor a split.
  is_leaf <- feature %in% "Leaf"
  is_split <- !is.na(feature) & feature != "Leaf"

  node_label <- paste0(feature, "\\nCover: ", trees[["Cover"]],
                       "\\nGain: ", trees[["Quality"]])
  node_shape <- ifelse(is_leaf, "oval", "rectangle")
  node_fill <- ifelse(is_leaf, "Khaki", "Beige")

  # Nodes are created in reverse row order; edge endpoints are resolved with
  # match() against the same reversed IDs so node indices line up.
  rev_ids <- rev(ids)
  nodes <- DiagrammeR::create_node_df(
    n = length(ids),
    label = rev(node_label),
    style = "filled",
    color = "DimGray",
    fillcolor = rev(node_fill),
    shape = rev(node_shape),
    data = rev(feature),
    fontname = "Helvetica",
    fontcolor = "black"
  )

  # Each split row contributes two edges: first all `Yes` edges (labelled
  # "< <Split>"), then all `No` edges (unlabelled).
  split_ids <- ids[is_split]
  edges <- DiagrammeR::create_edge_df(
    from = match(rep(split_ids, times = 2), rev_ids),
    to = match(c(trees[["Yes"]][is_split], trees[["No"]][is_split]), rev_ids),
    label = c(paste("<", trees[["Split"]][is_split]),
              rep("", length(split_ids))),
    color = "DimGray",
    arrowsize = "1.5",
    arrowhead = "vee",
    fontname = "Helvetica",
    rel = "leading_to"
  )

  graph <- DiagrammeR::create_graph(nodes_df = nodes, edges_df = edges)
  DiagrammeR::render_graph(graph, width = plot_width, height = plot_height)
}
Ehsan-F/R-Mixtape documentation built on June 24, 2020, 12:22 a.m.