#' @include FamiliarS4Generics.R
#' @include FamiliarS4Classes.R
NULL
# familiarRFSRC ----------------------------------------------------------------
# Model object for random forests built with randomForestSRC. The "seed" slot
# retains the seed used to grow the forest, so that the forest can be
# recreated later (e.g. when an anonymous forest is stored).
setClass(
  "familiarRFSRC",
  contains = "familiarModel",
  slots = list("seed" = "ANY"),
  prototype = list("seed" = NULL))
# familiarRFSRCDefault ---------------------------------------------------------
# Variant of familiarRFSRC that relies on randomForestSRC's own defaults for
# most forest parameters (see its get_default_hyperparameters method).
setClass(
  "familiarRFSRCDefault",
  contains = "familiarRFSRC")
# initialize -------------------------------------------------------------------
setMethod(
  "initialize",
  signature(.Object = "familiarRFSRC"),
  function(.Object, ...) {
    # Let the parent class initialise the object first.
    .Object <- callNextMethod()

    # This learner is implemented through the randomForestSRC package.
    .Object@package <- "randomForestSRC"

    return(.Object)
  }
)
# is_available -----------------------------------------------------------------
setMethod(
  "is_available",
  signature(object = "familiarRFSRC"),
  function(object, ...) {
    # Random forests exist for all outcome types and variable importance
    # methods, so this learner is always available. The "count" outcome type
    # only triggers a deprecation notice.
    if (object@outcome_type == "count") ..deprecation_count()

    return(TRUE)
  }
)
# get_default_hyperparameters (familiarRFSRC) ----------------------------------
setMethod(
  "get_default_hyperparameters",
  signature(object = "familiarRFSRC"),
  function(object, data = NULL, user_list = NULL, ...) {
    # Returns the default hyperparameter set for the fully parametrised
    # randomForestSRC learner. When data is NULL only the hyperparameter
    # names are declared; otherwise complete definitions (defaults, ranges,
    # randomisation flags) are returned.

    # Initialise list and declare hyperparameter entries.
    param <- list()
    param$sign_size <- list()
    param$n_tree <- list()
    param$sample_size <- list()
    param$m_try <- list()
    param$node_size <- list()
    param$tree_depth <- list()
    param$n_split <- list()
    param$split_rule <- list()
    param$sample_weighting <- list()
    param$sample_weighting_beta <- list()

    # Variable importance only parameters (are not optimised by
    # hyperparameter optimisation).
    param$fs_vimp_method <- list()
    param$fs_vh_fold <- list()
    param$fs_vh_step_size <- list()
    param$fs_vh_n_rep <- list()

    # If data is not provided, return the list with hyperparameter names only.
    if (is.null(data)) return(param)

    # Get the number of samples. NOTE(review): the familiarRFSRCDefault
    # method obtains this via get_n_samples(data, id_depth = "series");
    # confirm that both approaches are equivalent.
    n_samples <- data.table::uniqueN(data@data, by = get_id_columns(id_depth = "series"))

    # signature size -----------------------------------------------------------
    param$sign_size <- .get_default_sign_size(data = data)

    # number of trees ----------------------------------------------------------
    # Note that the number of trees is defined in powers of 2, based on Oshiro,
    # T. M., Perez, P. S., & Baranauskas, J. A. (2012, July). How many trees in
    # a random forest?. In MLDM (pp. 154-168).
    param$n_tree <- .set_hyperparameter(
      default = c(4, 8, 10),
      type = "integer",
      range = c(4, 10),
      valid_range = c(0, Inf),
      randomise = TRUE)

    # sample size --------------------------------------------------------------
    # Note that the sample size is here noted as a fraction, which corresponds
    # to the usage in random forest SRC. The lower bound ensures at least two
    # samples are drawn.
    param$sample_size <- .set_hyperparameter(
      default = c(0.30, 0.50, 0.70, 0.90),
      type = "numeric",
      range = c(2.0 / n_samples, 1.00),
      valid_range = c(0, 1.0),
      randomise = TRUE)

    # number of candidate features selected at node ----------------------------
    # Note that the number of features is here noted as a fraction, but is used
    # in randomforestSRC as an integer. Familiar ensures that always at least 1
    # feature is available as a candidate.
    param$m_try <- .set_hyperparameter(
      default = c(0.1, 0.3, 0.5, 1.0),
      type = "numeric",
      range = c(0.0, 1.0),
      randomise = TRUE)

    # terminal node size -------------------------------------------------------
    # Number of instances in the terminal node. Larger terminal node sizes limit
    # tree depth and overfitting.

    # Define the default range.
    node_size_range <- c(5, floor(n_samples / 3))

    # Define the default values, keeping only those that fall inside the range.
    node_size_default <- c(5, 10, 20, 50)
    node_size_default <- node_size_default[
      node_size_default >= node_size_range[1] &
        node_size_default <= node_size_range[2]]

    # Set the node_size parameter.
    param$node_size <- .set_hyperparameter(
      default = node_size_default,
      type = "integer",
      range = node_size_range,
      valid_range = c(1, Inf),
      randomise = TRUE)

    # maximum tree depth -------------------------------------------------------
    # Determines the depth trees are allowed to grow to. Larger depths increase
    # the risk of overfitting.
    param$tree_depth <- .set_hyperparameter(
      default = c(1, 2, 3, 7),
      type = "integer",
      range = c(1, 10),
      valid_range = c(1, Inf),
      randomise = TRUE)

    # number of split points ---------------------------------------------------
    # The number of split points for each candidate variable is not randomised
    # by default, and deterministic splitting is used (nsplit = 0).
    param$n_split <- .set_hyperparameter(
      default = 0,
      type = "integer",
      range = c(0, 10),
      valid_range = c(0, Inf),
      randomise = FALSE)

    # splitting rule -----------------------------------------------------------
    # Splitting rule is dependent on the outcome type.
    if (data@outcome_type %in% c("binomial", "multinomial")) {
      split_rule_range <- c("gini", "auc", "entropy")
      split_rule_default <- "gini"
    } else if (data@outcome_type %in% c("continuous", "count")) {
      split_rule_range <- c("mse", "quantile.regr", "la.quantile.regr")
      split_rule_default <- "mse"
    } else if (data@outcome_type == "survival") {
      split_rule_range <- c("logrank", "logrankscore", "bs.gradient")
      split_rule_default <- "logrank"
    } else if (data@outcome_type == "competing_risk") {
      split_rule_range <- c("logrankCR", "logrank")
      split_rule_default <- "logrankCR"
    } else {
      ..error_no_known_outcome_type(data@outcome_type)
    }

    # Set the split_rule parameter.
    param$split_rule <- .set_hyperparameter(
      default = split_rule_default,
      type = "factor",
      range = split_rule_range,
      randomise = FALSE)

    # sample weighting method --------------------------------------------------
    # Class imbalances may lead to learning majority classes. This can be
    # partially mitigated by increasing weight of minority classes.
    param$sample_weighting <- .get_default_sample_weighting_method(
      outcome_type = object@outcome_type)

    # effective number of samples beta -----------------------------------------
    # Specifies the beta parameter for effective number sample weighting method.
    # See Cui et al. (2019).
    param$sample_weighting_beta <- .get_default_sample_weighting_beta(
      method = c(
        param$sample_weighting$init_config,
        user_list$sample_weighting),
      outcome_type = object@outcome_type)

    # variable importance method -----------------------------------------------
    param$fs_vimp_method <- .set_hyperparameter(
      default = "permutation",
      type = "factor",
      range = c("permutation", "minimum_depth", "variable_hunting", "holdout"),
      randomise = FALSE)

    # variable hunting cross-validation folds (variable importance only) -------
    param$fs_vh_fold <- .set_hyperparameter(
      default = 5,
      type = "integer",
      range = c(2, n_samples),
      valid_range = c(2, Inf),
      randomise = FALSE)

    # variable hunting step size (variable importance only) --------------------
    param$fs_vh_step_size <- .set_hyperparameter(
      default = 1,
      type = "integer",
      range = c(1, 50),
      valid_range = c(1, Inf),
      randomise = FALSE)

    # variable hunting repetitions (variable importance only) ------------------
    param$fs_vh_n_rep <- .set_hyperparameter(
      default = 50,
      type = "integer",
      range = c(1, 50),
      valid_range = c(1, Inf),
      randomise = FALSE)

    return(param)
  }
)
# get_default_hyperparameters (familiarRFSRCDefault) ---------------------------
setMethod(
  "get_default_hyperparameters",
  signature(object = "familiarRFSRCDefault"),
  function(object, data = NULL, user_list = NULL, ...) {
    # Returns the reduced hyperparameter set of the default random forest.
    # Forest-growing parameters are left to randomForestSRC's own defaults;
    # only signature size, sample weighting and variable-importance options
    # are exposed.

    # Declare the hyperparameter entries in order. The fs_* parameters are
    # used for variable importance only and are not optimised by
    # hyperparameter optimisation.
    param <- list()
    for (parameter_name in c(
      "sign_size", "sample_weighting", "sample_weighting_beta",
      "fs_vimp_method", "fs_vh_fold", "fs_vh_step_size", "fs_vh_n_rep")) {
      param[[parameter_name]] <- list()
    }

    # Without data, only the hyperparameter names can be provided.
    if (is.null(data)) return(param)

    # Get the number of samples at series level.
    n_samples <- get_n_samples(data, id_depth = "series")

    # signature size -----------------------------------------------------------
    param$sign_size <- .get_default_sign_size(data = data)

    # sample weighting method --------------------------------------------------
    # Class imbalances may lead to learning majority classes. This can be
    # partially mitigated by increasing weight of minority classes.
    param$sample_weighting <- .get_default_sample_weighting_method(
      outcome_type = object@outcome_type)

    # effective number of samples beta -----------------------------------------
    # Specifies the beta parameter for effective number sample weighting method.
    # See Cui et al. (2019).
    param$sample_weighting_beta <- .get_default_sample_weighting_beta(
      method = c(
        param$sample_weighting$init_config,
        user_list$sample_weighting),
      outcome_type = object@outcome_type)

    # variable importance method -----------------------------------------------
    param$fs_vimp_method <- .set_hyperparameter(
      default = "permutation",
      type = "factor",
      range = c("permutation", "minimum_depth", "variable_hunting", "holdout"),
      randomise = FALSE)

    # variable hunting cross-validation folds (variable importance only) -------
    param$fs_vh_fold <- .set_hyperparameter(
      default = 5,
      type = "integer",
      range = c(2, n_samples),
      valid_range = c(2, Inf),
      randomise = FALSE)

    # variable hunting step size (variable importance only) --------------------
    param$fs_vh_step_size <- .set_hyperparameter(
      default = 1,
      type = "integer",
      range = c(1, 50),
      valid_range = c(1, Inf),
      randomise = FALSE)

    # variable hunting repetitions (variable importance only) ------------------
    param$fs_vh_n_rep <- .set_hyperparameter(
      default = 50,
      type = "integer",
      range = c(1, 50),
      valid_range = c(1, Inf),
      randomise = FALSE)

    return(param)
  }
)
# get_prediction_type ----------------------------------------------------------
setMethod(
  "get_prediction_type",
  signature(object = "familiarRFSRC"),
  function(object, type = "default") {
    # Maps a requested prediction type to the prediction type produced by
    # survival forests. Non-survival outcomes are handled by the parent
    # method.
    if (object@outcome_type != "survival") return(callNextMethod())

    # Standard predictions are cumulative hazards.
    if (type == "default") return("cumulative_hazard")
    if (type %in% c("survival", "survival_probability")) return("survival_probability")
    if (type %in% c("response", "cumulative_hazard")) return("cumulative_hazard")

    stop(paste0("Prediction type is not implemented: ", type))
  }
)
# ..train ----------------------------------------------------------------------
setMethod(
  "..train",
  signature(
    object = "familiarRFSRC",
    data = "dataObject"),
  function(object, data, anonymous = TRUE, ...) {
    # Trains a randomForestSRC forest on the provided data and stores it in
    # the model slot. When anonymous is TRUE and the installed package
    # version supports it, an anonymous forest is grown.

    # Aggregate repeated measurement data - randomForestSRC does not
    # facilitate repeated measurements.
    data <- aggregate_data(data = data)

    # Check if training data is ok. NOTE(review): this pattern relies on
    # has_bad_training_data() returning a value that if() can interpret as a
    # logical -- confirm against its definition.
    if (reason <- has_bad_training_data(object = object, data = data)) {
      return(callNextMethod(object = .why_bad_training_data(
        object = object,
        reason = reason)))
    }

    # Check if hyperparameters are set.
    if (is.null(object@hyperparameters)) {
      return(callNextMethod(object = ..update_errors(
        object = object,
        ..error_message_no_optimised_hyperparameters_available())))
    }

    # Check that required packages are loaded and installed.
    require_package(object, "train")

    # Find feature columns in the data table.
    feature_columns <- get_feature_columns(x = data)

    # Parse the model formula according to the outcome type.
    if (object@outcome_type == "survival") {
      # A local instance of survival::Surv() allows the formula to resolve
      # Surv() without attaching the survival package.
      Surv <- survival::Surv

      formula <- stats::reformulate(
        termlabels = feature_columns,
        response = quote(Surv(outcome_time, outcome_event)))

    } else if (object@outcome_type %in% c("binomial", "multinomial", "count", "continuous")) {
      formula <- stats::reformulate(
        termlabels = feature_columns,
        response = quote(outcome))

    } else {
      ..error_outcome_type_not_implemented(object@outcome_type)
    }

    # Extract hyperparameters from the model object.
    param <- object@hyperparameters

    # Determine the sample size. Sampling is with replacement ("swr") when
    # all instances are drawn, and without replacement ("swor") otherwise.
    # Scalar if/else is used instead of ifelse() on a scalar condition.
    sample_size <- ceiling(param$sample_size * nrow(data@data))
    sample_type <- if (sample_size == nrow(data@data)) "swr" else "swor"

    # Set instance weights, e.g. to mitigate class imbalance.
    weights <- create_instance_weights(
      data = data,
      method = object@hyperparameters$sample_weighting,
      beta = ..compute_effective_number_of_samples_beta(
        object@hyperparameters$sample_weighting_beta),
      normalisation = "average_one")

    # Set the forest seed. If the object comes with a defined seed, use that
    # seed so the forest can be recreated; otherwise draw a new negative
    # integer seed. if/else avoids ifelse() stripping attributes and
    # evaluates the random draw only when actually needed.
    forest_seed <- if (is.null(object@seed)) {
      as.integer(stats::runif(1, -100000, -1))
    } else {
      object@seed
    }

    # Use the anonymised version of the forest, if available.
    if (utils::packageVersion("randomForestSRC") >= "2.11.0" && anonymous) {
      forest_function <- randomForestSRC::rfsrc.anonymous
    } else {
      forest_function <- randomForestSRC::rfsrc
    }

    # Set learner arguments.
    learner_arguments <- list(
      formula,
      "data" = data@data,
      "case.wt" = weights,
      "seed" = forest_seed)

    # Add additional hyperparameters for models that use them. The default
    # learner (familiarRFSRCDefault) relies on randomForestSRC defaults.
    if (!is(object, "familiarRFSRCDefault")) {
      learner_arguments <- c(
        learner_arguments,
        list(
          "ntree" = 2^param$n_tree,
          "samptype" = sample_type,
          "sampsize" = sample_size,
          "mtry" = max(c(1, ceiling(param$m_try * length(feature_columns)))),
          "nodesize" = param$node_size,
          "nodedepth" = param$tree_depth,
          "nsplit" = param$n_split,
          "splitrule" = as.character(param$split_rule)))
    }

    # Train the model, capturing warnings and errors.
    model <- do.call_with_handlers(
      forest_function,
      args = learner_arguments)

    # Propagate captured warnings and errors to the familiar object.
    object <- ..update_warnings(object = object, model$warning)
    object <- ..update_errors(object = object, model$error)
    model <- model$value

    # Check if the model trained at all.
    if (!is.null(object@messages$error)) return(callNextMethod(object = object))

    # Add the model to the object.
    object@model <- model

    # Store the seed used to create the object. This helps recreate the
    # object in case the model is an anonymous forest.
    object@seed <- forest_seed

    # Set learner package version.
    object <- set_package_version(object)

    return(object)
  }
)
# ..train_naive ----------------------------------------------------------------
setMethod(
  "..train_naive",
  signature(
    object = "familiarRFSRC",
    data = "dataObject"),
  function(object, data, ...) {
    # Replaces the random forest by a naive model appropriate for the
    # outcome type, and trains that model instead.
    if (object@outcome_type %in% c("count", "continuous", "binomial", "multinomial")) {
      object <- methods::new("familiarNaiveModel", object)

    } else if (object@outcome_type %in% c("survival")) {
      object <- methods::new("familiarNaiveCumulativeIncidenceModel", object)
    }

    return(..train(object = object, data = data, ...))
  }
)
# ..predict --------------------------------------------------------------------
setMethod(
  "..predict",
  signature(
    object = "familiarRFSRC",
    data = "dataObject"),
  function(object, data, type = "default", time = NULL, ...) {
    # Makes predictions with the trained forest. "default" and
    # "survival_probability" types produce a familiar prediction table; any
    # other type is forwarded to randomForestSRC's predict method.

    # Check that required packages are loaded and installed.
    require_package(object, "predict")

    if (type %in% c("default", "survival_probability")) {
      # Default method --------------------------------------------------------

      # Check if the model was trained.
      if (!model_is_trained(object)) return(callNextMethod())

      # Check if the data is empty.
      if (is_empty(data)) return(callNextMethod())

      # Get an empty prediction table.
      prediction_table <- get_placeholder_prediction_table(
        object = object,
        data = data,
        type = type)

      # Make predictions using the model.
      model_predictions <- predict(
        object = object@model,
        newdata = data@data)

      if (object@outcome_type %in% c("binomial", "multinomial")) {
        # categorical outcomes -------------------------------------------------

        # Set predicted class.
        prediction_table[, "predicted_class" := model_predictions$class]

        # Add class probabilities, one column per class.
        class_probability_columns <- get_class_probability_name(x = object)
        for (ii in seq_along(class_probability_columns)) {
          prediction_table[, (class_probability_columns[ii]) := model_predictions$predicted[, ii]]
        }

      } else if (object@outcome_type %in% c("continuous", "count")) {
        # numerical outcomes ---------------------------------------------------

        # Extract predicted regression values.
        prediction_table[, "predicted_outcome" := model_predictions$predicted]

      } else if (object@outcome_type %in% c("survival")) {
        # survival outcomes ----------------------------------------------------

        # Get the unique event times of the forest.
        event_times <- model_predictions$time.interest

        # Set default time, if not provided. Note that if/else is used
        # instead of ifelse() so that a vector of time points is not
        # silently truncated to its first element.
        if (is.null(time)) time <- max(event_times)

        if (type == "default") {
          # Cumulative hazard.
          # Get the cumulative hazards at the given time point.
          prediction_table <- process_random_forest_survival_predictions(
            event_matrix = model_predictions$chf,
            event_times = event_times,
            prediction_table = prediction_table,
            time = time,
            type = "cumulative_hazard")

        } else if (type == "survival_probability") {
          # Survival probability.
          # Get the survival probability at the given time point.
          prediction_table <- process_random_forest_survival_predictions(
            event_matrix = model_predictions$survival,
            event_times = event_times,
            prediction_table = prediction_table,
            time = time,
            type = "survival")

        } else {
          # Defensive: unreachable given the outer type check.
          ..error_outcome_type_not_implemented(object@outcome_type)
        }
      }

      return(prediction_table)

    } else {
      # User-specified method --------------------------------------------------

      # Check if the model was trained.
      if (!model_is_trained(object)) {
        return(NULL)
      }

      # Check if the data is empty.
      if (is_empty(data)) {
        return(NULL)
      }

      # Make predictions using the model, forwarding any extra arguments.
      return(predict(
        object = object@model,
        newdata = data@data,
        ...))
    }
  }
)
# ..predict_survival_probability -----------------------------------------------
setMethod(
  "..predict_survival_probability",
  signature(
    object = "familiarRFSRC",
    data = "dataObject"),
  function(object, data, time, ...) {
    # Survival probabilities are only defined for survival outcomes; other
    # outcome types are handled by the parent method.
    if (!object@outcome_type %in% c("survival")) return(callNextMethod())

    # Check that required packages are loaded and installed.
    require_package(object, "predict")

    # Delegate to the general prediction method.
    return(..predict(
      object = object,
      data = data,
      time = time,
      type = "survival_probability"))
  }
)
# ..vimp -----------------------------------------------------------------------
setMethod(
  "..vimp",
  signature(object = "familiarRFSRC"),
  function(object, data = NULL, ...) {
    # Computes variable importance using the method stored in the
    # fs_vimp_method hyperparameter: "permutation", "minimum_depth",
    # "variable_hunting" or "holdout". Returns a vimpTable object.

    # Attempt to train the model if it has not been trained yet. A
    # non-anonymous forest is requested, because the variable importance
    # routines do not work on anonymous forests.
    if (!model_is_trained(object)) {
      object <- .train(
        object = object,
        data = data,
        get_additional_info = FALSE,
        trim_model = FALSE,
        anonymous = FALSE)
    }

    # Check if the model has been trained upon retry.
    if (!model_is_trained(object)) return(callNextMethod())

    # Check that required packages are loaded and installed.
    require_package(object, "vimp")

    # Find the vimp_method specified.
    vimp_method <- object@hyperparameters$fs_vimp_method
    if (is.null(vimp_method)) {
      ..error_reached_unreachable_code(paste0(
        "..vimp,familiarRFSRC: vimp method was not specified as a ",
        "hyperparameter (fs_vimp_method)"))
    }
    vimp_method <- as.character(vimp_method)

    # Extract the variable importance score.
    if (vimp_method == "permutation") {
      # Check if the model is anonymous, and rebuild if it is. VIMP does not
      # work otherwise.
      if (inherits(object@model, "anonymous")) {
        object <- .train(
          object = object,
          data = data,
          get_additional_info = FALSE,
          trim_model = FALSE,
          anonymous = FALSE)

        # Check that the non-anonymous model is trained.
        if (!model_is_trained(object)) return(callNextMethod())
      }

      # Determine permutation variable importance.
      vimp_score <- randomForestSRC::vimp(
        object = object@model,
        importance = "permute"
      )$importance

      # Check that the variable importance score is not empty.
      if (is_empty(vimp_score)) return(callNextMethod())

      # The variable importance score for binomial and multinomial outcomes
      # is reported per class; select the first column.
      if (is.matrix(vimp_score)) {
        vimp_score_names <- rownames(vimp_score)
        vimp_score <- vimp_score[, 1]
      } else {
        vimp_score_names <- names(vimp_score)
      }

      # Create variable importance object.
      vimp_object <- methods::new(
        "vimpTable",
        vimp_table = data.table::data.table(
          "score" = vimp_score,
          "name" = vimp_score_names),
        score_aggregation = "max",
        invert = TRUE)

      return(vimp_object)

    } else if (vimp_method == "minimum_depth") {
      # Check if the model is anonymous, and rebuild if it is. VIMP does not
      # work otherwise.
      if (inherits(object@model, "anonymous")) {
        object <- .train(
          object = object,
          data = data,
          get_additional_info = FALSE,
          trim_model = FALSE,
          anonymous = FALSE)

        # Check that the non-anonymous model is trained.
        if (!model_is_trained(object)) return(callNextMethod())
      }

      # Determine minimum depth variable importance.
      vimp_score <- randomForestSRC::var.select(
        object = object@model,
        method = "md",
        verbose = FALSE
      )$md.obj$order

      # Check that the variable importance score is not empty.
      if (is_empty(vimp_score)) return(callNextMethod())

      # Select the "min depth" column, which is the first column.
      if (is.matrix(vimp_score)) {
        vimp_score_names <- rownames(vimp_score)
        vimp_score <- vimp_score[, 1]
      } else {
        vimp_score_names <- names(vimp_score)
      }

      # Create variable importance object. Smaller minimum depths indicate
      # more important features, hence invert = FALSE.
      vimp_object <- methods::new(
        "vimpTable",
        vimp_table = data.table::data.table(
          "score" = vimp_score,
          "name" = vimp_score_names),
        score_aggregation = "max",
        invert = FALSE)

      return(vimp_object)

    } else if (vimp_method == "variable_hunting") {
      # Check if the model is anonymous, and rebuild if it is. VIMP does not
      # work otherwise.
      if (inherits(object@model, "anonymous")) {
        object <- .train(
          object = object,
          data = data,
          get_additional_info = FALSE,
          trim_model = FALSE,
          anonymous = FALSE)

        # Check that the non-anonymous model is trained.
        if (!model_is_trained(object)) return(callNextMethod())
      }

      # Perform variable hunting.
      vimp_score <- randomForestSRC::var.select(
        object = object@model,
        method = "vh",
        K = object@hyperparameters$fs_vh_fold,
        nstep = object@hyperparameters$fs_vh_step_size,
        nrep = object@hyperparameters$fs_vh_n_rep,
        verbose = FALSE,
        refit = FALSE
      )$varselect

      # Check that the variable importance score is not empty.
      if (is_empty(vimp_score)) return(callNextMethod())

      # Select the "rel.freq" column, which is the second column.
      # NOTE(review): if varselect is a data.frame rather than a matrix, the
      # else-branch would pick up column names instead of feature names --
      # confirm the return type of var.select.
      if (is.matrix(vimp_score)) {
        vimp_score_names <- rownames(vimp_score)
        vimp_score <- vimp_score[, 2]
      } else {
        vimp_score_names <- names(vimp_score)
      }

      # Create variable importance object.
      vimp_object <- methods::new(
        "vimpTable",
        vimp_table = data.table::data.table(
          "score" = vimp_score,
          "name" = vimp_score_names),
        score_aggregation = "max",
        invert = TRUE)

      return(vimp_object)

    } else if (vimp_method == "holdout") {
      # Check that data is present. Fixed package name in the message: the
      # package is randomForestSRC, not randomForestRFSRC.
      if (is_empty(data)) {
        warning(paste0(
          "Data is required to assess variable importance using ",
          "randomForestSRC::holdout.vimp"))

        return(callNextMethod())
      }

      if (get_n_features(data) <= 1) {
        warning(paste0(
          "Variable importance using randomForestSRC::holdout.vimp ",
          "requires more than 1 feature."))

        # Create a placeholder variable importance object with zero scores.
        vimp_object <- methods::new(
          "vimpTable",
          vimp_table = data.table::data.table(
            "score" = 0.0,
            "name" = get_feature_columns(x = data)),
          score_aggregation = "max",
          invert = TRUE)

        return(vimp_object)
      }

      # Find feature columns in the data.
      feature_columns <- get_feature_columns(x = data)

      # Parse formula.
      if (object@outcome_type == "survival") {
        # Creating a local instance of survival::Surv() prevents an error.
        Surv <- survival::Surv

        formula <- stats::reformulate(
          termlabels = feature_columns,
          response = quote(Surv(outcome_time, outcome_event)))

      } else if (object@outcome_type %in% c("binomial", "multinomial", "count", "continuous")) {
        formula <- stats::reformulate(
          termlabels = feature_columns,
          response = quote(outcome))

      } else {
        ..error_outcome_type_not_implemented(object@outcome_type)
      }

      # Get arguments for the holdout.vimp learner.
      learner_arguments <- list(
        formula,
        "data" = data@data,
        "verbose" = FALSE)

      # Get additional variables for the optimised learner.
      if (!is(object, "familiarRFSRCDefault")) {
        # Determine the sample size. Scalar if/else replaces ifelse() on a
        # scalar condition.
        sample_size <- ceiling(object@hyperparameters$sample_size * nrow(data@data))
        sample_type <- if (sample_size == nrow(data@data)) "swr" else "swor"

        learner_arguments <- c(
          learner_arguments,
          list(
            "mtry" = max(c(1, ceiling(object@hyperparameters$m_try * get_n_features(data)))),
            "samptype" = sample_type,
            "sampsize" = sample_size,
            "nodesize" = object@hyperparameters$node_size,
            "nodedepth" = object@hyperparameters$tree_depth,
            "nsplit" = object@hyperparameters$n_split,
            "splitrule" = as.character(object@hyperparameters$split_rule)))
      }

      # Perform holdout variable importance.
      vimp_score <- do.call(
        randomForestSRC::holdout.vimp,
        args = learner_arguments
      )$importance

      # Check that the variable importance score is not empty.
      if (is_empty(vimp_score)) return(callNextMethod())

      # Select the "all" column, which is the first column.
      if (is.matrix(vimp_score)) {
        vimp_score_names <- rownames(vimp_score)
        vimp_score <- vimp_score[, 1]
      } else {
        vimp_score_names <- names(vimp_score)
      }

      # Create variable importance object.
      vimp_object <- methods::new(
        "vimpTable",
        vimp_table = data.table::data.table(
          "score" = vimp_score,
          "name" = vimp_score_names),
        score_aggregation = "max",
        invert = TRUE)

      return(vimp_object)

    } else {
      ..error_reached_unreachable_code(paste0(
        "..vimp,familiarRFSRC: unknown vimp method was specified: ", vimp_method))
    }
    # The unreachable trailing return(NULL) was removed: every branch above
    # returns or raises an error.
  }
)
# ..set_calibration_info -------------------------------------------------------
setMethod(
  "..set_calibration_info",
  signature(object = "familiarRFSRC"),
  function(object, data) {
    # Attaches calibration information (baseline survival) to the object.

    # Skip if calibration information is already present.
    if (has_calibration_info(object)) return(object)

    # Baseline survival is only computed for survival outcomes; other outcome
    # types are handled by the parent method.
    if (object@outcome_type != "survival") return(callNextMethod())

    object@calibration_info <- get_baseline_survival(data = data)

    return(object)
  }
)
# ..set_vimp_parameters --------------------------------------------------------
setMethod(
  "..set_vimp_parameters",
  signature(object = "familiarRFSRC"),
  function(object, method, ...) {
    # Translates the user-facing variable importance method name into the
    # internal fs_vimp_method hyperparameter and stores it on the object.

    # Map internal method names to the user-facing name prefixes, in the
    # order in which they should be matched.
    prefix_map <- list(
      "permutation" = c(
        "random_forest_permutation",
        "random_forest_rfsrc_permutation"),
      "minimum_depth" = c(
        "random_forest_minimum_depth",
        "random_forest_rfsrc_minimum_depth"),
      "variable_hunting" = c(
        "random_forest_variable_hunting",
        "random_forest_rfsrc_variable_hunting"),
      "holdout" = c(
        "random_forest_holdout",
        "random_forest_rfsrc_holdout"))

    # Find the first internal method whose prefixes match.
    vimp_method <- NULL
    for (candidate_method in names(prefix_map)) {
      if (startswith_any(method, prefix = prefix_map[[candidate_method]])) {
        vimp_method <- candidate_method
        break
      }
    }

    if (is.null(vimp_method)) {
      ..error_reached_unreachable_code(paste0(
        "..set_vimp_parameters,familiarRFSRC: unknown feature selection ",
        "method was specified."))
    }

    # Start from the existing hyperparameter list, if any.
    hyperparameters <- object@hyperparameters
    if (is.null(hyperparameters)) hyperparameters <- list()

    # Update the fs_vimp_method hyperparameter and store it in the object.
    hyperparameters$fs_vimp_method <- vimp_method
    object@hyperparameters <- hyperparameters

    return(object)
  }
)
# .trim_model-------------------------------------------------------------------
setMethod(
  ".trim_model",
  signature(object = "familiarRFSRC"),
  function(object, ...) {
    # Removes bulky, recomputable elements from the stored forest to reduce
    # the size of the familiar model object.

    # Replace the stored call by a placeholder.
    object@model$call <- call("trimmed")

    # Capture the printed representation before trimming further.
    object <- .capture_show(object)

    # Drop stored (out-of-bag) predictions and the redundant copy of the
    # outcome data.
    object@model$predicted <- NULL
    object@model$predicted.oob <- NULL
    object@model$yvar <- NULL

    # Mark the object as trimmed.
    object@is_trimmed <- TRUE

    return(object)
  }
)
.get_available_rfsrc_learners <- function(show_general = TRUE) {
  # Names under which the randomForestSRC learner can be requested.
  learner_names <- c("random_forest", "random_forest_rfsrc")

  return(learner_names)
}
.get_available_rfsrc_default_learners <- function(show_general = TRUE) {
  # Default-parameter variants carry a "_default" suffix.
  learner_names <- .get_available_rfsrc_learners(show_general = show_general)

  return(paste0(learner_names, "_default"))
}
.get_available_rfsrc_vimp_methods <- function(show_general = TRUE) {
  # Names under which the randomForestSRC variable importance methods can be
  # requested, both with the general "random_forest" prefix and the explicit
  # "random_forest_rfsrc" prefix. Order matches the original declaration.
  core_methods <- c("permutation", "minimum_depth", "variable_hunting")

  return(c(
    paste0("random_forest_", core_methods),
    paste0("random_forest_rfsrc_", core_methods),
    "random_forest_holdout",
    "random_forest_rfsrc_holdout"))
}
.get_available_rfsrc_default_vimp_methods <- function(show_general = TRUE) {
  # Default-parameter variants carry a "_default" suffix.
  method_names <- .get_available_rfsrc_vimp_methods(show_general = show_general)

  return(paste0(method_names, "_default"))
}