# Learner definition for classification forests fitted with
# randomForestSRC::rfsrc().
#
# Declares the hyperparameters that can be set on the learner, the value that
# is pre-set for `na.action`, and the learner properties.
# Takes no arguments and returns the object built by makeRLearnerClassif().
#' @export
makeRLearner.classif.randomForestSRC = function() {
  makeRLearnerClassif(
    cl = "classif.randomForestSRC",
    package = "randomForestSRC",
    par.set = makeParamSet(
      makeIntegerLearnerParam(id = "ntree", default = 1000L, lower = 1L),
      makeDiscreteLearnerParam(id = "bootstrap", default = "by.root",
        values = c("by.root", "by.node", "none")),
      makeIntegerLearnerParam(id = "mtry", lower = 1L),
      makeIntegerLearnerParam(id = "nodesize", lower = 1L, default = 1L),
      makeIntegerLearnerParam(id = "nodedepth", default = -1L),
      makeDiscreteLearnerParam(id = "splitrule", default = "gini",
        values = c("gini", "gini.unwt", "gini.hvwt", "random")),
      makeIntegerLearnerParam(id = "nsplit", lower = 0L, default = 0L,
        requires = quote(splitrule != "random")), # nsplit is ignored and internally set to 1L for splitrule = "random"
      makeLogicalLearnerParam(id = "split.null", default = FALSE),
      # Accepts both logical flags and the named importance modes.
      makeDiscreteLearnerParam(id = "importance", default = FALSE, tunable = FALSE,
        values = list(`FALSE` = FALSE, `TRUE` = TRUE, "none", "permute", "random", "anti",
          "permute.ensemble", "random.ensemble", "anti.ensemble")),
      # The declared default mirrors rfsrc()'s own default ("na.omit"). The
      # learner nevertheless runs with "na.impute" because that value is
      # pre-set through `par.vals` below. `when = "both"` makes the setting
      # available at training and at prediction time.
      makeDiscreteLearnerParam(id = "na.action", default = "na.omit",
        values = c("na.omit", "na.impute"), when = "both"),
      makeIntegerLearnerParam(id = "nimpute", default = 1L, lower = 1L),
      makeDiscreteLearnerParam(id = "proximity", default = FALSE, tunable = FALSE,
        values = list("inbag", "oob", "all", `TRUE` = TRUE, `FALSE` = FALSE)),
      # sampsize and samptype are only applicable for root-level bootstrapping.
      makeIntegerLearnerParam(id = "sampsize", lower = 1L,
        requires = quote(bootstrap == "by.root")),
      makeDiscreteLearnerParam(id = "samptype", default = "swr", values = c("swr", "swor"),
        requires = quote(bootstrap == "by.root")),
      makeNumericVectorLearnerParam(id = "xvar.wt", lower = 0),
      makeLogicalLearnerParam(id = "forest", default = TRUE, tunable = FALSE),
      makeDiscreteLearnerParam(id = "var.used", default = FALSE, tunable = FALSE,
        values = list(`FALSE` = FALSE, "all.trees", "by.tree")),
      makeDiscreteLearnerParam(id = "split.depth", default = FALSE, tunable = FALSE,
        values = list(`FALSE` = FALSE, "all.trees", "by.tree")),
      makeIntegerLearnerParam(id = "seed", lower = 0L, tunable = FALSE),
      makeLogicalLearnerParam(id = "do.trace", default = FALSE, tunable = FALSE, when = "both"), # is currently ignored
      makeLogicalLearnerParam(id = "membership", default = TRUE, tunable = FALSE),
      makeLogicalLearnerParam(id = "statistics", default = FALSE, tunable = FALSE),
      makeLogicalLearnerParam(id = "tree.err", default = FALSE, tunable = FALSE)
    ),
    # Deliberately differs from the rfsrc() default so that tasks with missing
    # values can be handled (see `note`).
    par.vals = list(na.action = "na.impute"),
    properties = c("missings", "numerics", "factors", "ordered", "prob", "twoclass", "multiclass", "weights", "oobpreds", "featimp"),
    name = "Random Forest",
    short.name = "rfsrc",
    note = '`na.action` has been set to `"na.impute"` by default to allow missing data support.',
    callees = "rfsrc"
  )
}
# Fits the forest: builds the model formula from the task, extracts the
# requested subset of the task data (with the target recoded via
# "drop.levels"), and forwards observation weights as `case.wt`. Any
# hyperparameters arrive through `...` and are passed on to rfsrc().
# Returns the object produced by randomForestSRC::rfsrc().
#' @export
trainLearner.classif.randomForestSRC = function(.learner, .task, .subset, .weights = NULL, ...) {
  model.formula = getTaskFormula(.task)
  train.data = getTaskData(.task, .subset, recode.target = "drop.levels")
  randomForestSRC::rfsrc(model.formula, data = train.data, case.wt = .weights, ...)
}
# Predicts on new data with the stored forest. Membership information is
# always switched off for the prediction call; further arguments are
# forwarded through `...`.
# Returns the `predicted` element of the prediction when the learner's
# predict type is "prob", otherwise its `class` element.
#' @export
predictLearner.classif.randomForestSRC = function(.learner, .model, .newdata, ...) {
  pred = predict(.model$learner.model, newdata = .newdata, membership = FALSE, ...)
  if (.learner$predict.type != "prob") {
    return(pred$class)
  }
  pred$predicted
}
# Extracts the out-of-bag predictions stored on the fitted model
# (`predicted.oob`).
# For any predict type other than "response" the stored object is returned
# unchanged. For "response" each row is mapped to the column name with the
# largest entry, returned as a factor whose levels are the column names.
#' @export
getOOBPredsLearner.classif.randomForestSRC = function(.learner, .model) {
  oob = getLearnerModel(.model, more.unwrap = TRUE)$predicted.oob
  if (.learner$predict.type != "response") {
    return(oob)
  }
  class.names = colnames(oob)
  factor(class.names[max.col(oob)], levels = class.names)
}
# Computes variable importance for the fitted forest by calling
# randomForestSRC::vimp() on the unwrapped model (extra arguments are
# forwarded through `...`) and returns the "all" column of its `importance`
# result.
#' @export
getFeatureImportanceLearner.classif.randomForestSRC = function(.learner, .model, ...) {
  forest = getLearnerModel(.model, more.unwrap = TRUE)
  vimp.result = randomForestSRC::vimp(forest, ...)
  vimp.result$importance[, "all"]
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.