GridSearchCV | R Documentation |
Runs a grid search cross-validation scheme to find the best model training parameters.
Grid search CV trains a machine learning model on multiple combinations of training hyperparameters and finds the combination that optimizes the evaluation metric. It creates an exhaustive set of hyperparameter combinations and trains the model on each one.
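As a rough illustration of what "an exhaustive set of hyperparameter combinations" means, the sketch below expands a parameter list into its Cartesian product using base R's expand.grid(). It only shows the idea. It is not taken from superml's source, and the package may build its grid differently.

```r
# Parameter list in the same shape as the `parameters` argument.
parameters <- list(n_estimators = c(100), max_depth = c(5, 2, 10))

# expand.grid() builds the Cartesian product: one row per combination.
grid <- expand.grid(parameters, stringsAsFactors = FALSE)

print(grid)

# Number of candidate settings a grid search would have to evaluate.
n_combinations <- nrow(grid)
```

With one value for n_estimators and three for max_depth, the grid has three rows. Each added parameter multiplies the number of combinations to train.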
trainer
superml trainer object, for example XGBTrainer, RFTrainer, or NBTrainer
parameters
a list of parameters to tune
n_folds
number of folds to use to split the train data
scoring
scoring metric used to evaluate the best model; multiple values can be provided. Currently supports: auc, accuracy, mse, rmse, logloss, mae, f1, precision, recall
evaluation_scores
parameter for internal use
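To make the n_folds field more concrete, here is a minimal base R sketch of splitting the training data into folds. The shuffling scheme and the variable names (fold_id, train_idx, valid_idx) are assumptions made for illustration; superml's internal splitting may differ.

```r
set.seed(42)                 # make the shuffle reproducible
n_folds <- 3
n_rows  <- nrow(iris)        # iris ships with base R (datasets package)

# Assign each row to a fold: repeat 1..n_folds to cover all rows, then shuffle.
fold_id <- sample(rep(seq_len(n_folds), length.out = n_rows))

# For fold k, rows with fold_id == k are held out; the rest are used for training.
for (k in seq_len(n_folds)) {
  valid_idx <- which(fold_id == k)
  train_idx <- which(fold_id != k)
  cat("fold", k, ": train =", length(train_idx),
      ", valid =", length(valid_idx), "\n")
}
```

Every row is held out exactly once, so each parameter combination is scored n_folds times and the scores can be averaged.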
new()
GridSearchCV$new(trainer = NA, parameters = NA, n_folds = NA, scoring = NA)
trainer
superml trainer object, for example XGBTrainer, RFTrainer, or NBTrainer
parameters
list, a list of parameters to tune
n_folds
integer, number of folds to use to split the train data
scoring
character, scoring metric used to evaluate the best model; multiple values can be provided. Currently supports: auc, accuracy, mse, rmse, logloss, mae, f1, precision, recall
Create a new 'GridSearchCV' object.
A 'GridSearchCV' object.
rf <- RFTrainer$new()
gst <- GridSearchCV$new(trainer = rf,
                        parameters = list(n_estimators = c(100),
                                          max_depth = c(5, 2, 10)),
                        n_folds = 3,
                        scoring = c('accuracy', 'auc'))
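For readers unfamiliar with the names accepted by scoring, the following base R sketch computes a few of them by hand on made-up toy vectors. These are the textbook definitions. Whether superml computes them in exactly this way (for example, which class counts as positive) is an assumption.

```r
# Toy labels and predictions, made up for illustration only.
actual    <- c(1, 0, 1, 1, 0)
predicted <- c(1, 0, 0, 1, 1)

accuracy  <- mean(actual == predicted)            # share of exact matches
tp        <- sum(predicted == 1 & actual == 1)    # true positives
precision <- tp / sum(predicted == 1)
recall    <- tp / sum(actual == 1)
f1        <- 2 * precision * recall / (precision + recall)

# Regression-style metrics on toy numeric values.
y     <- c(3.0, 5.0, 2.0)
y_hat <- c(2.0, 5.0, 4.0)
mse   <- mean((y - y_hat)^2)
rmse  <- sqrt(mse)
mae   <- mean(abs(y - y_hat))
```

Note that accuracy, auc, f1, precision, and recall are better when higher, while mse, rmse, mae, and logloss are better when lower.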
fit()
GridSearchCV$fit(X, y)
X
data.frame or data.table
y
character, name of target variable
Trains the model using grid search
NULL
rf <- RFTrainer$new()
gst <- GridSearchCV$new(trainer = rf,
                        parameters = list(n_estimators = c(100),
                                          max_depth = c(5, 2, 10)),
                        n_folds = 3,
                        scoring = c('accuracy', 'auc'))
data("iris")
gst$fit(iris, "Species")
best_iteration()
GridSearchCV$best_iteration(metric = NULL)
metric
character, which metric to use for evaluation
Returns the best parameters
a list of best parameters
rf <- RFTrainer$new()
gst <- GridSearchCV$new(trainer = rf,
                        parameters = list(n_estimators = c(100),
                                          max_depth = c(5, 2, 10)),
                        n_folds = 3,
                        scoring = c('accuracy', 'auc'))
data("iris")
gst$fit(iris, "Species")
gst$best_iteration()
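Conceptually, choosing the best iteration amounts to picking the row of a score table that optimizes one metric. The sketch below shows that idea in base R. The score values, the column names (accuracy_avg, auc_avg), and the helper pick_best() are all hypothetical and do not reflect the actual layout of evaluation_scores or the internals of best_iteration().

```r
# Hypothetical cross-validated scores, one row per parameter combination.
# The numbers are invented for illustration, not real results.
scores <- data.frame(
  n_estimators = c(100, 100, 100),
  max_depth    = c(5, 2, 10),
  accuracy_avg = c(0.94, 0.91, 0.95),
  auc_avg      = c(0.97, 0.95, 0.96)
)

# Hypothetical helper: return the row that maximizes the chosen metric column.
pick_best <- function(scores, metric) {
  best_row <- scores[which.max(scores[[metric]]), ]
  as.list(best_row)
}

best_by_accuracy <- pick_best(scores, "accuracy_avg")
best_by_auc      <- pick_best(scores, "auc_avg")
```

The two metrics can disagree, which is why a metric argument is useful when several scoring values were supplied. For error metrics such as mse or rmse, the analogous selection would use which.min() instead.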
clone()
The objects of this class are cloneable with this method.
GridSearchCV$clone(deep = FALSE)
deep
Whether to make a deep clone.
## ------------------------------------------------
## Method `GridSearchCV$new`
## ------------------------------------------------
rf <- RFTrainer$new()
gst <- GridSearchCV$new(trainer = rf,
                        parameters = list(n_estimators = c(100),
                                          max_depth = c(5, 2, 10)),
                        n_folds = 3,
                        scoring = c('accuracy', 'auc'))
## ------------------------------------------------
## Method `GridSearchCV$fit`
## ------------------------------------------------
rf <- RFTrainer$new()
gst <- GridSearchCV$new(trainer = rf,
                        parameters = list(n_estimators = c(100),
                                          max_depth = c(5, 2, 10)),
                        n_folds = 3,
                        scoring = c('accuracy', 'auc'))
data("iris")
gst$fit(iris, "Species")
## ------------------------------------------------
## Method `GridSearchCV$best_iteration`
## ------------------------------------------------
rf <- RFTrainer$new()
gst <- GridSearchCV$new(trainer = rf,
                        parameters = list(n_estimators = c(100),
                                          max_depth = c(5, 2, 10)),
                        n_folds = 3,
                        scoring = c('accuracy', 'auc'))
data("iris")
gst$fit(iris, "Species")
gst$best_iteration()