Description Usage Arguments Format Details Value Fields Methods Author(s) Examples
ORF is an R6 class. You can use it to create a random forest in several ways, and it supports incremental (online) learning as well as batch learning. In fact, an Online Random Forest is made of a list of Online Random Trees.
ORF$new(param, numTrees)
param
A list of parameters; in the Examples it carries names such as minSamples, minGain, numClasses (or maxDepth for regression), and x.rng.
numTrees
A nonnegative integer indicating how many ORT trees are going to be built.
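A minimal construction sketch may help here; the parameter names are taken from the Examples section below, and the comment glosses are my reading of those names, not an exhaustive specification:

```r
# dataRange() is the helper from this package used in the Examples.
x.rng <- dataRange(iris[1:4])       # per-feature value ranges
param <- list(minSamples = 2,       # samples a node must see before trying a split
              minGain    = 0.2,     # minimum gain required to accept a split
              numClasses = 3,       # > 0 -> classification; 0 (default) -> regression
              x.rng      = x.rng)
orf <- ORF$new(param)               # numTrees can also be passed to set the forest size
```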
An R6Class generator object.
Online Random Forest was first introduced by Amir Saffari et al. Arthur Lui later implemented the algorithm in Python. Following the paper's advice and Lui's implementation, I refactored the code in R with the R6 package. In addition, the implementation of ORF in this package supports both incremental learning and batch learning by combining with randomForest.
For usage, see the details in the description of each field or method.
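The hybrid batch-plus-incremental workflow mentioned above can be sketched as follows, condensed from the Examples section below (the index choices are illustrative):

```r
library(randomForest)
# Batch step: train an ordinary randomForest on an initial chunk of data.
dat <- iris; dat[, 5] <- as.integer(dat[, 5])   # all columns must be numeric/integer
x.rng <- dataRange(dat[1:4])                     # helper from this package
param <- list(minSamples = 2, minGain = 0.2, numClasses = 3, x.rng = x.rng)
ind <- sample(1:150, 30)
rf  <- randomForest(factor(Species) ~ ., data = dat[ind, ], maxnodes = 2, ntree = 100)

# Convert the batch forest into an online forest ...
orf <- ORF$new(param)
orf$generateForest(rf, df.train = dat[ind, ], y.col = "Species")

# ... then keep learning from new samples one at a time.
i <- setdiff(1:150, ind)[1]
orf$update(dat[i, 1:4], dat[i, 5])
```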
An object of R6Class, representing an Online Random Forest.
numClasses
A nonnegative integer indicating how many classes there are when solving a classification problem. Default 0 for regression; if numClasses > 0, classification is performed.
classify
TRUE for classification and FALSE for regression, depending on the value of numClasses.
forest
A list of ORT trees. More details are shown in Online Random Tree.
update(x, y)
When a sample comes in, update all ORT trees in the forest with the sample's x variables and y value.
x - The x variables of one sample. Note that it is a numeric vector, not a scalar.
y - The y value of one sample.
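In a streaming setting, update() is simply called once per arriving sample. A sketch, assuming orf was built as in the Examples and new.dat is a hypothetical data frame of newly arrived iris-style rows (four predictors, then the response):

```r
# Feed arriving samples to the forest one at a time.
for (i in seq_len(nrow(new.dat))) {
  x <- new.dat[i, 1:4]   # predictor columns as a numeric vector
  y <- new.dat[i, 5]     # response value
  orf$update(x, y)       # every ORT tree in the forest is updated
}
```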
generateForest(rf, df.train, y.col)
Generate a list of ORT trees, calling the function ORT$generateTree() internally.
rf - A fitted randomForest object. Each of its tree matrices (as obtained from randomForest::getTree()) must have a column named node.ind. See Examples.
df.train - The training data frame that was used to construct the randomForest, i.e., the data argument of the randomForest function. Note that all columns in df.train must be numeric or integer.
y.col - A character string indicating which column is y, i.e., the dependent variable. Note that the y column must be the last column of df.train.
predict(x, type = c("class", "prob"))
Predict the y value corresponding to x, using all ORT trees.
x - The x variables of one sample. Note that it is a numeric vector, not a scalar.
type - For classification only: "class" predicts the y class of x, while "prob" predicts the probability of each class that x may belong to.
predicts(X, type = c("class", "prob"))
Predict the corresponding y values for a batch of samples, using all ORT trees.
X - A matrix or data frame holding the x variables of a batch of samples.
type - For classification only: "class" predicts the y class of each sample, while "prob" predicts the probability of each class.
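The difference between the two methods is single sample versus batch; a short sketch, assuming orf was built for classification as in the Examples:

```r
one   <- orf$predict(iris[1, 1:4])                     # one sample -> one class label
many  <- orf$predicts(iris[1:10, 1:4])                 # ten samples -> ten labels
probs <- orf$predicts(iris[1:10, 1:4], type = "prob")  # per-class probabilities instead
```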
confusionMatrix(X, y, pretty = FALSE)
Get a confusion matrix of predicted versus true y values. For classification problems only.
X - A matrix or data frame holding the x variables of a batch of samples.
y - A vector of true y values for the batch.
pretty - If TRUE, print a pretty confusion matrix (requires the gmodels package). Default FALSE.
meanTreeSize()
Mean size of the ORT trees in the forest.
meanNumLeaves()
Mean number of leaf nodes of the ORT trees in the forest.
meanTreeDepth()
Mean depth of the ORT trees in the forest.
sdTreeSize()
Standard deviation of the size of the ORT trees in the forest.
sdNumLeaves()
Standard deviation of the number of leaf nodes of the ORT trees in the forest.
sdTreeDepth()
Standard deviation of the depth of the ORT trees in the forest.
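These summary methods are handy for watching the forest's structure evolve; comparing them before and after a round of online updates shows the trees growing, as the Examples do with meanTreeSize(). A sketch, reusing dat and ind.updt from the Examples:

```r
# Structure summaries before and after a round of online updates.
before <- c(orf$meanTreeSize(), orf$meanTreeDepth(), orf$meanNumLeaves())
for (i in ind.updt) orf$update(dat[i, 1:4], dat[i, 5])
after  <- c(orf$meanTreeSize(), orf$meanTreeDepth(), orf$meanNumLeaves())
rbind(before, after)   # trees typically grow as samples stream in
orf$sdTreeSize()       # spread of tree sizes across the forest
```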
Quan Gu
# classification example
dat <- iris; dat[,5] <- as.integer(dat[,5])
x.rng <- dataRange(dat[1:4])
param <- list('minSamples'= 2, 'minGain'= 0.2, 'numClasses'= 3, 'x.rng'= x.rng)
ind.gen <- sample(1:150, 30) # for generating the ORF
ind.updt <- sample(setdiff(1:150, ind.gen), 100) # for updating the ORF
ind.test <- setdiff(setdiff(1:150, ind.gen), ind.updt) # for testing
rf <- randomForest::randomForest(factor(Species) ~ ., data = dat[ind.gen, ], maxnodes = 2, ntree = 100)
orf <- ORF$new(param)
orf$generateForest(rf, df.train = dat[ind.gen, ], y.col = "Species")
orf$meanTreeSize()
for (i in ind.updt) {
orf$update(dat[i, 1:4], dat[i, 5])
}
orf$meanTreeSize()
orf$confusionMatrix(dat[ind.test, 1:4], dat[ind.test, 5], pretty = TRUE)
# compare
table(predict(rf, newdata = dat[ind.test,]) == dat[ind.test, 5])
table(orf$predicts(X = dat[ind.test, 1:4]) == dat[ind.test, 5])
# regression example
if(!require(ggplot2)) install.packages("ggplot2")
data("diamonds", package = "ggplot2")
dat <- as.data.frame(diamonds[sample(1:53000,1000), c(1:6,8:10,7)])
for (col in c("cut","color","clarity")) dat[[col]] <- as.integer(dat[[col]]) # Don't forget !
x.rng <- dataRange(dat[1:9])
param <- list('minSamples'= 10, 'minGain'= 1, 'maxDepth' = 10, 'x.rng'= x.rng)
ind.gen <- sample(1:1000, 800)
ind.updt <- sample(setdiff(1:1000, ind.gen), 100)
ind.test <- setdiff(setdiff(1:1000, ind.gen), ind.updt)
rf <- randomForest::randomForest(price ~ ., data = dat[ind.gen, ], maxnodes = 20, ntree = 100)
orf <- ORF$new(param)
orf$generateForest(rf, df.train = dat[ind.gen, ], y.col = "price")
orf$meanTreeSize()
for (i in ind.updt) {
orf$update(dat[i, 1:9], dat[i, 10])
}
orf$meanTreeSize()
# compare
if(!require(Metrics)) install.packages("Metrics")
preds.rf <- predict(rf, newdata = dat[ind.test,])
Metrics::rmse(preds.rf, dat$price[ind.test])
preds <- orf$predicts(dat[ind.test, 1:9])
Metrics::rmse(preds, dat$price[ind.test]) # progress made by the online updates