Description
Plots the tree with the fewest ("least") or the most ("most") nodes from a random forest model.
Usage

plot_rf_tree(model_rf, spec = c("least", "most")[1])
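Arguments

model_rf    model of class train from caret, fitted with method = "rf" (its finalModel is a randomForest object)
spec        which tree to plot: "least" (the tree with the fewest nodes, the default) or "most" (the tree with the most nodes)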
Value

A ggplot object showing the tree of the random forest with the fewest or the most nodes, as selected by spec.
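The page does not say how the "least" and "most" node trees are identified. A minimal sketch, assuming the selection is by number of terminal nodes as reported by randomForest::treesize(); select_tree() is a hypothetical helper written for illustration, not part of this package, and plot_rf_tree() may select trees differently.

```r
library(randomForest)

# Hypothetical helper (an assumption, not this package's code): return the
# index of the tree with the fewest ("least") or most ("most") terminal nodes.
select_tree <- function(rf, spec = c("least", "most")) {
  spec <- match.arg(spec)
  # treesize() gives the number of terminal nodes of every tree in the forest
  n_leaves <- randomForest::treesize(rf, terminal = TRUE)
  if (spec == "least") which.min(n_leaves) else which.max(n_leaves)
}

set.seed(42)
rf <- randomForest(Species ~ ., data = iris, ntree = 50)

k_least <- select_tree(rf, "least")
k_most  <- select_tree(rf, "most")

# getTree() returns one row per node (splits and leaves) of tree k
tree_most <- randomForest::getTree(rf, k = k_most, labelVar = TRUE)
head(tree_most)
```

With a caret model such as model_rf from the Examples, the underlying randomForest object is model_rf$finalModel, so select_tree(model_rf$finalModel, "most") would pick the tree in the same way.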
Examples

library(caret)
# Stratified 70/30 train/test split of iris
set.seed(42)
index <- createDataPartition(iris$Species, p = 0.7, list = FALSE)
train_data <- iris[index, ]
test_data <- iris[-index, ]
# 5-fold cross-validation repeated twice, with SMOTE resampling inside each fold
fitControl <- trainControl(method = "repeatedcv",
                           number = 5,
                           repeats = 2,
                           sampling = "smote",
                           savePredictions = TRUE,
                           verboseIter = FALSE)

# Candidate values of mtry (number of predictors sampled at each split)
rfGrid <- expand.grid(mtry = 2:4)
# Fit the random forest with the resampling settings and tuning grid above
set.seed(42)
model_rf <- train(Species ~ .,
                  data = train_data,
                  method = "rf",
                  preProcess = c("scale", "center"),
                  trControl = fitControl,
                  tuneGrid = rfGrid)
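# Resampled accuracy across the mtry grid (base graphics and ggplot2)
plot(model_rf)
ggplot(model_rf)
# OOB and per-class error of the final forest as trees are added
plot(model_rf$finalModel)
# Trees with the most and the fewest nodes
plot_rf_tree(model_rf, "most")
plot_rf_tree(model_rf, "least")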