GridSearchCV | R Documentation |
GridSearchCV allows the user to specify a Grid Search schema for tuning predictive model hyper-parameters with Cross-Validation. GridSearchCV gives the user complete flexibility in the predictive model and performance metrics.
learner
Predictive modeling function.
scorer
List of performance metric functions.
splitter
Function that splits data into cross validation folds.
tune_params
Data.frame of the full hyper-parameter grid created from $tune_params.
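As a rough illustration of what a "full hyper-parameter grid" looks like, the sketch below expands a named list of candidate values into one row per combination using base R's expand.grid. This is an assumption about the shape of the grid only; how GridSearchCV builds its own grid internally is not documented on this page.

```r
# Named list of candidate values, in the same form as the tune_params
# argument used in the examples on this page.
tune_params <- list(
  minsplit = seq(10, 30, by = 5),  # 5 candidate values
  maxdepth = seq(20, 30, by = 2)   # 6 candidate values
)

# Illustrative only: the Cartesian product of all candidate values,
# one row per hyper-parameter combination.
param_grid <- expand.grid(tune_params, stringsAsFactors = FALSE)

nrow(param_grid)     # 5 * 6 = 30 combinations
head(param_grid, 3)
```

Each additional tuned parameter multiplies the number of rows, and therefore the number of models fitted per cross validation fold.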
fit()
fit tunes user-specified model hyper-parameters via Grid Search and Cross-Validation.
GridSearchCV$fit(
  formula = NULL,
  data = NULL,
  x = NULL,
  y = NULL,
  progress = FALSE
)
formula
An object of class formula: a symbolic description of the model to be fitted.
data
An optional data frame, or other object containing the variables in the model. If data is not provided, how formula is handled depends on $learner.
x
Predictor data (independent variables), alternative interface to data with formula.
y
Response vector (dependent variable), alternative interface to data with formula.
progress
Logical indicating whether to print progress across the hyper-parameter grid.
fit follows standard R modeling convention by surfacing a formula modeling interface as well as an alternate matrix option. The user should use whichever interface is supported by the specified $learner function.
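To make the two calling conventions concrete, here is a small sketch using base R's lm (formula interface) and lm.fit (matrix interface) on the built-in mtcars data. These functions are chosen purely for illustration and are not part of GridSearchCV; how a given $learner is dispatched depends on that learner.

```r
# Formula interface: the model is described symbolically and the variables
# are looked up in `data`.
fit_formula <- lm(mpg ~ wt + hp, data = mtcars)

# Matrix interface: predictors and response are passed directly as x and y.
# lm.fit() does not add an intercept, so a column of ones is bound on here.
x <- cbind(`(Intercept)` = 1, as.matrix(mtcars[, c("wt", "hp")]))
y <- mtcars$mpg
fit_matrix <- lm.fit(x = x, y = y)

# Compare the coefficient estimates from the two interfaces.
all.equal(coef(fit_formula), fit_matrix$coefficients)
```

A learner of the first kind would be fitted with the formula and data arguments of fit, and a learner of the second kind with x and y.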
An object of class FittedGridSearchCV.
if (require(rpart) && require(rsample) && require(yardstick)) {

  iris_new <- iris[sample(1:nrow(iris), nrow(iris)), ]
  iris_new$Species <- factor(iris_new$Species == "virginica")
  iris_train <- iris_new[1:100, ]
  iris_validate <- iris_new[101:150, ]

  ### Basic Example

  iris_grid_cv <- GridSearchCV$new(
    learner = rpart::rpart,
    learner_args = list(method = "class"),
    tune_params = list(
      minsplit = seq(10, 30, by = 5),
      maxdepth = seq(20, 30, by = 2)
    ),
    splitter = rsample::vfold_cv,
    splitter_args = list(v = 3),
    scorer = list(accuracy = yardstick::accuracy_vec),
    optimize_score = "max",
    prediction_args = list(accuracy = list(type = "class"))
  )
  iris_grid_cv_fitted <- iris_grid_cv$fit(
    formula = Species ~ .,
    data = iris_train
  )

  ### Example with multiple metric functions

  iris_grid_cv <- GridSearchCV$new(
    learner = rpart::rpart,
    learner_args = list(method = "class"),
    tune_params = list(
      minsplit = seq(10, 30, by = 5),
      maxdepth = seq(20, 30, by = 2)
    ),
    splitter = rsample::vfold_cv,
    splitter_args = list(v = 3),
    scorer = list(
      accuracy = yardstick::accuracy_vec,
      auc = yardstick::roc_auc_vec
    ),
    optimize_score = "max",
    prediction_args = list(
      accuracy = list(type = "class"),
      auc = list(type = "prob")
    ),
    convert_predictions = list(
      accuracy = NULL,
      auc = function(i) i[, "FALSE"]
    )
  )
  iris_grid_cv_fitted <- iris_grid_cv$fit(
    formula = Species ~ .,
    data = iris_train
  )

}
new()
Create a new GridSearchCV object.
GridSearchCV$new(
  learner = NULL,
  tune_params = NULL,
  splitter = NULL,
  scorer = NULL,
  optimize_score = c("min", "max"),
  learner_args = NULL,
  splitter_args = NULL,
  scorer_args = NULL,
  prediction_args = NULL,
  convert_predictions = NULL
)
learner
Function that estimates a predictive model. It is essential that this function support either a formula interface with formula and data arguments, or an alternate matrix interface with x and y arguments.
tune_params
A named list specifying the arguments of $learner to tune.
splitter
Either a function that computes cross validation folds from an input data set, or a pre-computed list of cross validation fold indices. If splitter is a function, it must have a data argument for the input data, and it must return a list of cross validation fold indices. If splitter is a list of integers, the number of cross validation folds is length(splitter) and each element contains the indices of the data observations that are included in that fold.
scorer
A named list of metric functions to evaluate model performance on evaluation_data. Any provided metric function must have truth and estimate arguments, for true outcome values and predicted outcome values respectively, and must return a single numeric metric value. The last metric function will be the one used to identify the optimal model from the Grid Search.
optimize_score
One of "max" or "min"; whether to maximize or minimize the metric defined in scorer to find the optimal Grid Search parameters.
learner_args
A named list of additional arguments to pass to learner.
splitter_args
A named list of additional arguments to pass to splitter.
scorer_args
A named list of additional arguments to pass to scorer. scorer_args must either be length 1 or length(scorer) in the case where different arguments are being passed to each scoring function.
prediction_args
A named list of additional arguments to pass to predict. prediction_args must either be length 1 or length(scorer) in the case where different arguments are being passed to each scoring function.
convert_predictions
A list of functions to convert predicted values prior to being evaluated by the metric functions supplied in scorer. This list should either be length 1, in which case the same function will be applied to all predicted values, or length(scorer), in which case each function in convert_predictions will correspond with each function in scorer.
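The splitter and scorer requirements described above can be sketched with plain base R functions. Both functions below, kfold_splitter and misclassification_rate, are hypothetical names written for illustration and follow only the textual contract given here (a data argument returning a list of fold indices; truth and estimate arguments returning one number). They have not been verified against the package itself.

```r
# Hypothetical splitter: takes `data` plus a fold count `v` and returns a
# list of integer index vectors, one per fold.
kfold_splitter <- function(data, v = 3) {
  n <- nrow(data)
  fold_id <- sample(rep(seq_len(v), length.out = n))  # shuffled fold labels
  unname(split(seq_len(n), fold_id))
}

# Hypothetical scorer: takes `truth` and `estimate` and returns one number,
# the share of observations that were misclassified.
misclassification_rate <- function(truth, estimate) {
  mean(as.character(truth) != as.character(estimate))
}

set.seed(1)
folds <- kfold_splitter(iris, v = 3)
lengths(folds)                        # 50 50 50

misclassification_rate(
  truth    = c("a", "b", "a", "b"),
  estimate = c("a", "b", "b", "b")
)                                     # 0.25
```

Because lower values of this metric are better, a scorer like this would be paired with optimize_score = "min".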
An object of class GridSearchCV.
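The role of a convert_predictions function is easiest to see on a toy input. The sketch below builds a small class-probability matrix by hand, shaped like the output of predict with type = "prob" for an rpart classification tree (one column per class level), and applies a converter in the style of the one used in the example on this page. The matrix values and the name to_false_prob are made up for illustration.

```r
# Toy class-probability matrix: one row per observation, one column per
# class level. The values are invented for illustration.
prob_matrix <- matrix(
  c(0.9, 0.1,
    0.3, 0.7,
    0.6, 0.4),
  ncol = 2, byrow = TRUE,
  dimnames = list(NULL, c("FALSE", "TRUE"))
)

# Converter: keep only the "FALSE" column so that a metric expecting a
# numeric vector of probabilities receives one.
to_false_prob <- function(i) i[, "FALSE"]

to_false_prob(prob_matrix)   # 0.9 0.3 0.6
```

A metric such as accuracy that works on class predictions directly needs no conversion, which is why the example on this page supplies NULL for it.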
clone()
The objects of this class are cloneable with this method.
GridSearchCV$clone(deep = FALSE)
deep
Whether to make a deep clone.