## Tune mtry and nodesize for a supervised random forest.
##
## For every candidate nodesize a trial forest is grown at the starting mtry,
## then mtry is stepped to the left (smaller) and to the right (larger) by
## stepFactor for as long as the error keeps improving.  The error of a trial
## forest is the mean (NA removed) of get.mv.error(); it is measured on a
## held-out sample when the data are large enough, otherwise on the object
## returned by the grow call itself.
##
## Arguments:
##   formula      model formula; required.
##   data         data frame; it is rebuilt from the y- and x-variables
##                returned by a one-tree rfsrc() stump before tuning starts.
##   mtryStart    mtry used for the first trial forest of each nodesize.  The
##                default is evaluated lazily, i.e. against the rebuilt data.
##   nodesizeTry  nodesize values to loop over.
##   ntreeTry     number of trees in every trial forest.
##   sampsize     tree sample size: a number, or a function of the sample size.
##                Also used as the size of the hold-out set.
##   nsplit       forwarded to rfsrc.fast().
##   stepFactor   multiplicative step for mtry; must be > 1.
##   improve      minimum relative improvement in error needed to keep
##                searching in a direction; must be >= 0.
##   strikeout    cap on the number of negative-improvement steps tolerated
##                within one search direction.
##   maxIter      cap on the number of steps within one search direction.
##   trace        print progress?
##   doBest       fit and return a forest at the optimal (nodesize, mtry)?
##   ...          forwarded to rfsrc.fast().  If it contains perf.type, that
##                value is also used when predicting on the hold-out set.
##
## Value: a list with
##   results  matrix with columns nodesize, mtry, err (one row per trial),
##   optimal  the nodesize/mtry of the row with the smallest err,
##   rf       the forest fitted at the optimum (NULL when doBest is FALSE).
tune.rfsrc <- function(formula, data,
                       mtryStart = ncol(data) / 2,
                       nodesizeTry = c(1:9, seq(10, 100, by = 5)), ntreeTry = 100,
                       sampsize = function(x){min(x * .632, max(150, x ^ (3/4)))},
                       nsplit = 1, stepFactor = 1.25, improve = 1e-3, strikeout = 3, maxIter = 25,
                       trace = FALSE, doBest = TRUE, ...)
{
  ## initial checks - all are fatal
  if (improve < 0) {
    stop("improve must be non-negative.")
  }
  if (stepFactor <= 1) {
    stop("stepFactor must be great than 1.")
  }
  if (missing(formula)) {
    stop("a formula must be supplied (only supervised forests allowed).")
  }
  ## re-define the original data in case there are missing values
  stump <- rfsrc(formula, data, nodedepth = 0, perf.type = "none", save.memory = TRUE,
                 ntree = 1, splitrule = "random")
  n <- stump$n
  yvar.names <- stump$yvar.names
  data <- data.frame(stump$yvar, stump$xvar)
  colnames(data)[seq_along(yvar.names)] <- yvar.names
  rm(stump)
  ## sample size
  ssize <- if (is.function(sampsize)) sampsize(n) else sampsize
  ## performance type used when predicting on the hold-out set
  user.perf.type <- list(...)[["perf.type"]]
  perf.type <- if (is.null(user.perf.type)) "default" else user.perf.type
  ## now hold out a test data set equal to the tree sample size (if possible)
  if ((2 * ssize) < n) {
    tst <- sample(seq_len(n), size = ssize, replace = FALSE)
    trn <- setdiff(seq_len(n), tst)
    newdata <- data[tst, , drop = FALSE]
  } else {
    trn <- seq_len(n)
    newdata <- NULL
  }
  ## Grow a trial forest on the training rows without computing performance.
  ## The perf.type formal swallows any perf.type found in '...', so the grow
  ## call receives perf.type exactly once ("none") instead of twice.
  grow.holdout <- function(..., .mtry, .nodesize, perf.type = NULL) {
    rfsrc.fast(formula, data[trn, , drop = FALSE], ntree = ntreeTry, mtry = .mtry,
               nodesize = .nodesize, sampsize = ssize, nsplit = nsplit,
               forest = TRUE, perf.type = "none", save.memory = FALSE, ...)
  }
  ## Fit one trial forest and score it; returns list(forest, error).
  fit.trial <- function(..., .mtry, .nodesize) {
    if (is.null(newdata)) {
      o <- rfsrc.fast(formula, data, ntree = ntreeTry, mtry = .mtry, nodesize = .nodesize,
                      sampsize = ssize, nsplit = nsplit, ...)
      err <- mean(get.mv.error(o, TRUE), na.rm = TRUE)
    } else {
      o <- grow.holdout(..., .mtry = .mtry, .nodesize = .nodesize)
      err <- mean(get.mv.error(predict(o, newdata, perf.type = perf.type), TRUE), na.rm = TRUE)
    }
    list(forest = o, error = err)
  }
  ## one result matrix per nodesize that was actually searched
  res <- list()
  ## loop over nodesize
  for (nsz in nodesizeTry) {
    ## acquire the starting error rate
    start <- fit.trial(..., .mtry = mtryStart, .nodesize = nsz)
    o <- start$forest
    errorOld <- start$error
    ## adopt the mtry reported by the fitted forest; it also seeds the next nodesize
    mtryStart <- o$mtry
    mtryMax <- length(o$xvar.names)
    ## verbose output
    if (trace) {
      cat("nodesize = ", nsz,
          " mtry =", mtryStart,
          "error =", paste0(100 * round(errorOld, 4), "%"), "\n")
    }
    ## terminate if we are stumping
    if (mean(o$leaf.count) <= 2) {
      break
    }
    ## errors for this nodesize, keyed by mtry
    oobError <- list()
    oobError[[1]] <- errorOld
    names(oobError)[1] <- mtryStart
    ## search over mtry; errorOld carries over from "left" to "right", so the
    ## right-hand search has to beat the best error found so far
    for (direction in c("left", "right")) {
      if (trace) {
        cat("Searching", direction, "...\n")
      }
      ## start above the threshold so the loop body runs at least once
      Improve <- 1.1 * improve
      mtryCur <- mtryStart
      counter2 <- 1
      strikes <- 0
      while (counter2 <= maxIter && (Improve >= improve || (Improve < 0 && strikes < strikeout))) {
        counter2 <- counter2 + 1
        if (Improve < 0) {
          strikes <- strikes + 1
        }
        mtryOld <- mtryCur
        ## always move by at least one, but stay inside [1, mtryMax]
        if (direction == "left") {
          mtryCur <- max(1, min(ceiling(mtryCur / stepFactor), mtryCur - 1))
        } else {
          mtryCur <- min(mtryMax, max(floor(mtryCur * stepFactor), mtryCur + 1))
        }
        if (mtryCur == mtryOld) {
          break
        }
        errorCur <- fit.trial(..., .mtry = mtryCur, .nodesize = nsz)$error
        if (trace) {
          cat("nodesize = ", nsz,
              " mtry =", mtryCur,
              " error =", paste0(100 * round(errorCur, 4), "%"), "\n")
        }
        oobError[[as.character(mtryCur)]] <- errorCur
        Improve <- 1 - errorCur / errorOld
        ## 0/0 (both errors zero) or a missing error gives NaN, which would
        ## abort the comparisons below; treat it as "no improvement"
        if (is.na(Improve)) {
          Improve <- 0
        }
        if (trace) {
          cat(Improve, improve, "\n")
        }
        if (Improve > improve) {
          errorOld <- errorCur
        }
      }
    }## finished inner loop
    ## save the results by nodesize
    mtry <- sort(as.numeric(names(oobError)))
    err <- unlist(oobError[as.character(mtry)])
    res[[length(res) + 1]] <- cbind(nodesize = nsz, mtry = mtry, err = err)
  }## finished outer loop
  ## check that there are results: the list stays empty when the very first
  ## nodesize already stumps (or nodesizeTry is empty)
  if (length(res) == 0) {
    stop("NULL results - something is wrong, check parameter (tuning) settings\n")
  }
  ## deparse the results and sort along nodesize and mtry;
  ## drop = FALSE keeps a single-row result as a matrix
  res <- do.call(rbind, res)
  res <- res[order(res[, 1], res[, 2]), , drop = FALSE]
  rownames(res) <- seq_len(nrow(res))
  ## determine the optimal value
  opt.idx <- which.min(res[, 3])
  ## fit the optimized forest?
  rf <- NULL
  if (doBest) {
    rf <- rfsrc.fast(formula, data,
                     mtry = res[opt.idx, 2],
                     nodesize = res[opt.idx, 1],
                     sampsize = sampsize,
                     nsplit = nsplit,
                     forest = TRUE, ...)
  }
  ## return the goodies
  list(results = res, optimal = res[opt.idx, -3], rf = rf)
}
## alias: make the tuner callable under the shorter name tune()
tune <- tune.rfsrc
## Add the following code to your website.
## For more information on customizing the embed code, read Embedding Snippets.