#' @title Oblique Random Forest Classifier
#' @author annanzrv
#' @name mlr_learners_classif.aorsf
#'
#' @description
#' Accelerated oblique random classification forest.
#' Calls [aorsf::orsf()] from \CRANpkg{aorsf}.
#'
#' @section Initial parameter values:
#' * `n_thread`: This parameter is initialized to 1 (default is 0) to avoid conflicts with the mlr3 parallelization.
#' * `pred_simplify` is not exposed as a hyperparameter. It is set internally to `TRUE` when
#'   `predict_type = "response"` (otherwise the predicted response would be `NA`) and to `FALSE`
#'   when `predict_type = "prob"`.
#' * `oobag_eval_every`: if not set, it is set to the value of `n_tree` during training.
#'
#' @template seealso_learner
#' @examplesIf requireNamespace("aorsf", quietly = TRUE)
#' # Define the Learner
#' learner = mlr3::lrn("classif.aorsf", importance = "anova")
#' print(learner)
#'
#' # Define a Task
#' task = mlr3::tsk("breast_cancer")
#' # Create train and test set
#' ids = mlr3::partition(task)
#'
#' # Train the learner on the training ids
#' learner$train(task, row_ids = ids$train)
#'
#' print(learner$model)
#' print(learner$importance())
#'
#' # Make predictions for the test rows
#' predictions = learner$predict(task, row_ids = ids$test)
#'
#' # Score the predictions
#' predictions$score()
#'
#' @export
LearnerClassifObliqueRandomForest = R6Class("LearnerClassifObliqueRandomForest",
  inherit = LearnerClassif,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      ps = ps(
        attach_data = p_lgl(default = TRUE, tags = "train"),
        epsilon = p_dbl(default = 1e-9, lower = 0, tags = "train"),
        importance = p_fct(levels = c("none", "anova", "negate", "permute"), default = "anova", tags = "train"),
        importance_max_pvalue = p_dbl(default = 0.01, lower = 0.0001, upper = .9999, tags = "train"),
        leaf_min_events = p_int(default = 1L, lower = 1L, tags = "train"),
        leaf_min_obs = p_int(default = 5L, lower = 1L, tags = "train"),
        max_iter = p_int(default = 20L, lower = 1L, tags = "train"),
        method = p_fct(levels = c("glm", "net", "pca", "random"), default = "glm", tags = "train"),
        mtry = p_int(default = NULL, lower = 1L, special_vals = list(NULL), tags = "train"),
        mtry_ratio = p_dbl(lower = 0, upper = 1, tags = "train"),
        n_retry = p_int(default = 3L, lower = 0L, tags = "train"),
        n_split = p_int(default = 5L, lower = 1L, tags = "train"),
        n_thread = p_int(default = 0L, lower = 0L, tags = c("train", "predict", "threads")),
        n_tree = p_int(default = 500L, lower = 1L, tags = "train"),
        na_action = p_fct(levels = c("fail", "omit", "impute_meanmode"), default = "fail", tags = "train"),
        # elastic-net mixing parameter; only meaningful on [0, 1]
        net_mix = p_dbl(default = 0.5, lower = 0, upper = 1, tags = "train"),
        oobag = p_lgl(default = FALSE, tags = "predict"),
        oobag_eval_every = p_int(default = NULL, special_vals = list(NULL), lower = 1, tags = "train"),
        oobag_fun = p_uty(default = NULL, special_vals = list(NULL), tags = "train", custom_check = function(x) checkmate::checkFunction(x, nargs = 3)),
        oobag_pred_type = p_fct(levels = c("none", "leaf", "prob", "class"), default = "prob", tags = "train"),
        pred_aggregate = p_lgl(default = TRUE, tags = "predict"),
        # pred_simplify = p_lgl(default = TRUE, tags = "predict"), # can't be FALSE, otherwise response is NA in prediction
        sample_fraction = p_dbl(lower = 0, upper = 1, default = .632, tags = "train"),
        sample_with_replacement = p_lgl(default = TRUE, tags = "train"),
        scale_x = p_lgl(default = FALSE, tags = "train"),
        split_min_events = p_int(default = 5L, lower = 1L, tags = "train"),
        split_min_obs = p_int(default = 10L, lower = 1L, tags = "train"),
        split_min_stat = p_dbl(default = NULL, special_vals = list(NULL), lower = 0, tags = "train"),
        split_rule = p_fct(levels = c("gini", "cstat"), default = "gini", tags = "train"),
        target_df = p_int(default = NULL, lower = 1L, special_vals = list(NULL), tags = "train"),
        tree_seeds = p_int(default = NULL, lower = 1L, special_vals = list(NULL), tags = "train"),
        verbose_progress = p_lgl(default = FALSE, tags = "train"))
      # single-threaded by default so aorsf does not compete with mlr3's own parallelization
      ps$values = list(n_thread = 1L)
      super$initialize(
        id = "classif.aorsf",
        packages = c("mlr3extralearners", "aorsf"),
        feature_types = c("integer", "numeric", "factor", "ordered"),
        predict_types = c("response", "prob"),
        param_set = ps,
        properties = c("oob_error", "importance", "multiclass", "twoclass", "weights"),
        man = "mlr3extralearners::mlr_learners_classif.aorsf",
        label = "Oblique Random Forest Classifier"
      )
    },

    #' @description
    #' Out-of-bag error, computed as `1 -` the statistic in the first column of
    #' the last row of the model slot `eval_oobag$stat_values`.
    #' @return `numeric(1)`.
    oob_error = function() {
      if (is.null(self$model)) {
        stopf("No model stored")
      }
      stat_values = self$model$eval_oobag$stat_values
      # NROW(NULL) is 0, so this also covers models without stored OOB statistics
      if (NROW(stat_values) == 0L) {
        stopf("No out-of-bag evaluation stored in the model")
      }
      # the last row is taken as the statistic of the final forest
      # (presumably one row per `oobag_eval_every` checkpoint)
      1 - stat_values[nrow(stat_values), 1L]
    },

    #' @description
    #' The importance scores are extracted from the model via [aorsf::orsf_vi()],
    #' with factor levels grouped, and sorted in decreasing order.
    #' @return Named `numeric()`.
    importance = function() {
      if (is.null(self$model)) {
        stopf("No model stored")
      }
      sort(aorsf::orsf_vi(self$model, group_factors = TRUE),
        decreasing = TRUE)
    }
  ),

  private = list(
    .train = function(task) {
      pv = self$param_set$get_values(tags = "train")
      pv = convert_ratio(pv, "mtry", "mtry_ratio", length(task$feature_names))

      # The default value for oobag_eval_every is n_tree, but declaring
      # default = n_tree in p_int() above would be problematic, so fill it in
      # here. Use the configured n_tree, or n_tree's declared default if unset.
      if (is.null(pv[["oobag_eval_every"]])) {
        n_tree = pv[["n_tree"]]
        if (is.null(n_tree)) {
          n_tree = self$param_set$default[["n_tree"]]
        }
        pv[["oobag_eval_every"]] = n_tree
      }

      # Arguments of orsf_control_classification() build the control object;
      # all remaining hyperparameters are passed to orsf() itself.
      args_ctrl = formalArgs(aorsf::orsf_control_classification)
      pv_ctrl = pv[names(pv) %in% args_ctrl]
      pv_train = pv[names(pv) %nin% args_ctrl]
      ctrl = invoke(aorsf::orsf_control_classification, .args = pv_ctrl)

      invoke(
        aorsf::orsf,
        data = task$data(),
        formula = task$formula(),
        weights = private$.get_weights(task),
        control = ctrl,
        no_fit = FALSE,
        .args = pv_train
      )
    },

    .predict = function(task) {
      pars = self$param_set$get_values(tags = "predict")
      newdata = ordered_features(task, self)
      type = if (self$predict_type == "response") "class" else "prob"
      # pred_simplify must be TRUE for class predictions (otherwise the
      # response is NA); prob predictions are requested unsimplified
      pred_simplify = self$predict_type == "response"
      pred = invoke(
        predict,
        self$model,
        new_data = newdata,
        pred_type = type,
        pred_simplify = pred_simplify,
        .args = pars)
      if (self$predict_type == "response") {
        list(response = pred)
      } else {
        list(prob = pred)
      }
    }
  )
)
# Register the learner class under the key "classif.aorsf".
.extralrns_dict$add(
  "classif.aorsf",
  LearnerClassifObliqueRandomForest
)
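# NOTE(review): the two lines below appear to be web-page residue, not R code;
# they are commented out so that the file parses.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.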