View source: R/hmda.suggest.param.R
hmda.suggest.param    R Documentation
Suggests candidate hyperparameter values for tree-based
algorithms. It computes a hyperparameter grid whose total number
of model combinations is close to a specified target. For GBM models,
the default candidates include max_depth, ntrees, learn_rate,
sample_rate, and col_sample_rate. For DRF models, if a vector of predictor
variables (x) and a modeling family ("regression" or "classification")
are provided, a vector of mtries values is also suggested.
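To make "total number of model combinations" concrete, the short sketch below counts the models in a grid shaped like the one this function returns (a named list of value vectors). The values are illustrative assumptions, not the package defaults, and the count assumes a full Cartesian search in which every combination of values is trained as one model.

```r
# Hypothetical candidate grid with the same shape as the returned value:
# a named list of hyperparameter value vectors (values are illustrative only).
grid <- list(
  max_depth       = c(3, 5, 7),
  ntrees          = c(100, 300),
  learn_rate      = c(0.01, 0.05, 0.1),
  sample_rate     = c(0.8, 1.0),
  col_sample_rate = c(0.8, 1.0)
)

# Under a full Cartesian search, the number of models is the product of the
# vector lengths: 3 * 2 * 3 * 2 * 2.
n_combinations <- prod(lengths(grid))
print(n_combinations)

# expand.grid() enumerates every combination explicitly, one row per model.
all_models <- expand.grid(grid)
print(nrow(all_models))
```

Adding or removing a single value in any one vector changes the total multiplicatively, which is why the grid has to be tuned toward n_models rather than hitting it exactly.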
hmda.suggest.param(algorithm, n_models, x = NULL, family = NULL)
algorithm: A character string specifying the algorithm, either "gbm" (gradient boosting machine) or "drf" (distributed random forest).
n_models: An integer giving the desired approximate number of model combinations in the grid. Must be at least 100.
x: (Optional) A vector of predictor names. If provided and its length is at least 20, it is used to compute mtries for DRF.
family: (Optional) A character string indicating the modeling family, either "classification" or "regression". Used together with x to compute mtries for DRF.
The function first checks that n_models is at least 100,
then validates the family parameter if provided. The
algorithm name is normalized to lowercase and must be either
"gbm" or "drf". For "gbm", a default grid of hyperparameters is
defined. For "drf", if both x and family are provided,
the function computes mtries via suggest_mtries(); otherwise,
a default grid is set without mtries. Finally, the candidate grid is
pruned or expanded using hmda.adjust.params() so that the
total number of combinations is near n_models.
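The input checks described above can be sketched as a small standalone helper. This is a minimal illustration under stated assumptions: check_suggest_inputs() is a hypothetical name, not a function exported by HMDA, and it only mirrors the documented validation and the condition under which mtries is considered. It does not reproduce suggest_mtries() or hmda.adjust.params().

```r
# Hypothetical helper (NOT part of HMDA) mirroring the documented checks.
# Returns the normalized algorithm name and whether mtries would be considered.
check_suggest_inputs <- function(algorithm, n_models, x = NULL, family = NULL) {
  # n_models must be a single number of at least 100
  if (!is.numeric(n_models) || length(n_models) != 1 || n_models < 100) {
    stop("n_models must be at least 100")
  }
  # family, when supplied, must be one of the two supported values
  if (!is.null(family) && !family %in% c("classification", "regression")) {
    stop("family must be either 'classification' or 'regression'")
  }
  # the algorithm name is normalized to lowercase before matching
  algorithm <- tolower(algorithm)
  if (!algorithm %in% c("gbm", "drf")) {
    stop("algorithm must be either 'gbm' or 'drf'")
  }
  # assumption based on the argument descriptions: mtries is only considered
  # for DRF when both x and family are supplied and x has >= 20 predictors
  use_mtries <- algorithm == "drf" && !is.null(x) && !is.null(family) &&
    length(x) >= 20
  list(algorithm = algorithm, use_mtries = use_mtries)
}

res <- check_suggest_inputs("DRF", 150,
                            x = paste0("V", 1:100),
                            family = "classification")
str(res)
```

Failing fast on these inputs keeps an invalid family or a too-small n_models from silently producing a grid that cannot be adjusted toward the target.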
A named list of hyperparameter value vectors. This list is suitable for use with HMDA and H2O grid search functions.
## Not run:
library(HMDA)
library(h2o)
h2o.init()
# Example 1: Suggest hyperparameters for GBM with about 120 models.
params_gbm <- hmda.suggest.param("gbm", n_models = 120)
print(params_gbm)
# Example 2: Suggest hyperparameters for DRF (classification) with
# 100 predictors.
params_drf <- hmda.suggest.param(
algorithm = "drf",
n_models = 150,
x = paste0("V", 1:100),
family = "classification"
)
print(params_drf)
## End(Not run)