plot_rf_tree: Plot Random Forest trees


Description

Plots the tree with the fewest or the most nodes from a random forest model.

Usage

plot_rf_tree(model_rf, spec = c("least", "most")[1])

Arguments

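model_rf

random forest model fitted with caret; the Examples pass the object returned by train(method = "rf")

spec

which tree to plot: "least" for the tree with the fewest nodes, "most" for the tree with the most nodes; defaults to "least"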

Value

A ggplot object showing the tree in the random forest that has the fewest or the most nodes.

Examples

library(caret)


set.seed(42)
index <- createDataPartition(iris$Species, p = 0.7, list = FALSE)
train_data <- iris[index, ]
test_data  <- iris[-index, ]

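# 5-fold cross-validation, repeated twice
fitControl <- trainControl(method = "repeatedcv",
                           number = 5,
                           repeats = 2,
                           # SMOTE subsampling needs an extra package installed
                           # (which one depends on the caret version)
                           sampling = "smote",
                           savePredictions = TRUE,
                           verboseIter = FALSE)

# candidate values for mtry (iris has 4 predictors)
rfGrid <- expand.grid(mtry = 2:4)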

# run model
set.seed(42)
model_rf <- train(Species ~ .,
                  data = train_data,
                  method = "rf",
                  preProcess = c("scale", "center"),
                  trControl = fitControl,  # train() takes the control object as trControl
                  tuneGrid = rfGrid)

plot(model_rf)
ggplot(model_rf)
plot(model_rf$finalModel)

plot_rf_tree(model_rf, "most")
plot_rf_tree(model_rf, "least")
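
The source of plot_rf_tree is not shown on this page, so the following is only a sketch of one way the "least" and "most" trees could be located, not the package's actual implementation. It assumes the randomForest package and uses its treesize() and getTree() functions; the names rf, n_nodes, least_idx and most_idx are illustrative. The forest is fitted directly with randomForest() to keep the sketch self-contained; with a caret model, model_rf$finalModel would presumably play the role of rf.

```r
library(randomForest)

set.seed(42)
# Illustrative forest; not the caret model from the Examples above
rf <- randomForest(Species ~ ., data = iris, ntree = 50)

# Number of nodes in each tree (terminal = FALSE counts all nodes, not only leaves)
n_nodes <- treesize(rf, terminal = FALSE)

# Indices of the trees with the fewest and the most nodes
least_idx <- which.min(n_nodes)
most_idx  <- which.max(n_nodes)

# Split table of the smallest tree: one row per node
least_tree <- getTree(rf, k = least_idx, labelVar = TRUE)
head(least_tree)
```

The table returned by getTree() lists each node's daughters, split variable and split point, which is the information a plotting function would need to draw the tree.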

HanjoStudy/quotidieR documentation built on May 5, 2019, 6:13 p.m.