
GridSearchCV — R Documentation

Tune Predictive Model Hyper-parameters with Grid Search and Cross-Validation

Description

GridSearchCV allows the user to specify a Grid Search schema for tuning predictive model hyper-parameters with Cross-Validation, giving the user complete flexibility in the choice of predictive model and performance metrics.

Public fields

learner

Predictive modeling function.

scorer

List of performance metric functions.

splitter

Function that splits data into cross validation folds.

tune_params

A data.frame containing the full hyper-parameter grid, expanded from the tune_params list supplied to $new().
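The expansion from a named list of candidate values into the full grid can be reproduced with base R's expand.grid(), which crosses every combination into one row of a data.frame. A minimal sketch (assuming this is the expansion used; the parameter names mirror the Examples below):

```r
# Named list of candidate values per hyper-parameter, as passed to $new()
tune_params <- list(
  minsplit = seq(10, 30, by = 5),   # 5 candidate values
  maxdepth = seq(20, 30, by = 2)    # 6 candidate values
)

# expand.grid() crosses every combination: one data.frame row per candidate model
grid <- expand.grid(tune_params)
nrow(grid)  # 5 x 6 = 30 candidate hyper-parameter combinations
```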

Methods

Public methods


Method fit()

fit tunes user-specified model hyper-parameters via Grid Search and Cross-Validation.

Usage
GridSearchCV$fit(
  formula = NULL,
  data = NULL,
  x = NULL,
  y = NULL,
  progress = FALSE
)
Arguments
formula

An object of class formula: a symbolic description of the model to be fitted.

data

An optional data frame, or other object containing the variables in the model. If data is not provided, how formula is handled depends on $learner.

x

Predictor data (independent variables), alternative interface to data with formula.

y

Response vector (dependent variable), alternative interface to data with formula.

progress

Logical; whether to print progress across the hyper-parameter grid.

Details

fit follows the standard R modeling convention of offering both a formula interface (formula and data arguments) and an alternate matrix interface (x and y arguments). The user should use whichever interface is supported by the specified $learner function.
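The two conventions can be illustrated with base R itself: lm() supports the formula interface, while stats::lm.fit() takes an explicit design matrix x and response vector y, and both recover the same coefficients:

```r
# Formula interface: symbolic model description plus a data frame
fit_formula <- lm(mpg ~ wt + hp, data = mtcars)

# Matrix interface: explicit design matrix (with intercept column) and response vector
x <- cbind(1, as.matrix(mtcars[, c("wt", "hp")]))
y <- mtcars$mpg
fit_matrix <- stats::lm.fit(x = x, y = y)

# Both interfaces estimate the same model
all.equal(unname(coef(fit_formula)), unname(fit_matrix$coefficients))
```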

Returns

An object of class FittedGridSearchCV.

Examples
if (require(rpart) && require(rsample) && require(yardstick)) {

  iris_new <- iris[sample(1:nrow(iris), nrow(iris)), ]
  iris_new$Species <- factor(iris_new$Species == "virginica")
  iris_train <- iris_new[1:100, ]
  iris_validate <- iris_new[101:150, ]

  ### Basic Example

  iris_grid_cv <- GridSearchCV$new(
    learner = rpart::rpart,
    learner_args = list(method = "class"),
    tune_params = list(
      minsplit = seq(10, 30, by = 5),
      maxdepth = seq(20, 30, by = 2)
    ),
    splitter = rsample::vfold_cv,
    splitter_args = list(v = 3),
    scorer = list(accuracy = yardstick::accuracy_vec),
    optimize_score = "max",
    prediction_args = list(accuracy = list(type = "class"))
  )
  iris_grid_cv_fitted <- iris_grid_cv$fit(
    formula = Species ~ .,
    data = iris_train
  )

  ### Example with multiple metric functions

  iris_grid_cv <- GridSearchCV$new(
    learner = rpart::rpart,
    learner_args = list(method = "class"),
    tune_params = list(
      minsplit = seq(10, 30, by = 5),
      maxdepth = seq(20, 30, by = 2)
    ),
    splitter = rsample::vfold_cv,
    splitter_args = list(v = 3),
    scorer = list(
      accuracy = yardstick::accuracy_vec,
      auc = yardstick::roc_auc_vec
    ),
    optimize_score = "max",
    prediction_args = list(
      accuracy = list(type = "class"),
      auc = list(type = "prob")
    ),
    convert_predictions = list(
      accuracy = NULL,
      auc = function(i) i[, "FALSE"]
    )
  )
  iris_grid_cv_fitted <- iris_grid_cv$fit(
    formula = Species ~ .,
    data = iris_train
  )
}

Method new()

Create a new GridSearchCV object.

Usage
GridSearchCV$new(
  learner = NULL,
  tune_params = NULL,
  splitter = NULL,
  scorer = NULL,
  optimize_score = c("min", "max"),
  learner_args = NULL,
  splitter_args = NULL,
  scorer_args = NULL,
  prediction_args = NULL,
  convert_predictions = NULL
)
Arguments
learner

Function that estimates a predictive model. It is essential that this function support either a formula interface with formula and data arguments, or an alternate matrix interface with x and y arguments.

tune_params

A named list specifying the arguments of $learner to tune.

splitter

A function that computes cross-validation folds from an input data set, or a pre-computed list of cross-validation fold indices. If splitter is a function, it must have a data argument for the input data and must return a list of cross-validation fold indices. If splitter is a list of integer vectors, the number of cross-validation folds is length(splitter) and each element contains the indices of the observations included in that fold.
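A pre-computed splitter following the contract above is just a list whose elements are integer index vectors, one per fold. A base-R sketch (the variable names are illustrative):

```r
set.seed(42)
n <- 150

# Shuffle the row indices, then cut them into 3 roughly equal folds
shuffled <- sample(seq_len(n))
folds <- split(shuffled, cut(seq_len(n), breaks = 3, labels = FALSE))

length(folds)        # 3 cross-validation folds
sort(unlist(folds))  # every observation appears in exactly one fold
```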

scorer

A named list of metric functions used to evaluate model performance on the held-out data. Each metric function must have truth and estimate arguments, for the true and predicted outcome values respectively, and must return a single numeric value. The last metric function in the list is the one used to identify the optimal model from the Grid Search.
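Any function with truth and estimate arguments that returns a single numeric value qualifies, so custom metrics can sit alongside the yardstick functions used in the Examples. A hand-rolled RMSE (assuming a numeric outcome):

```r
# Custom metric with the required truth/estimate signature:
# root mean squared error, returning a single numeric value
rmse_vec <- function(truth, estimate) {
  sqrt(mean((truth - estimate)^2))
}

rmse_vec(truth = c(1, 2, 3), estimate = c(1, 2, 5))  # sqrt(4/3), about 1.155
```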

optimize_score

One of "max" or "min"; whether to maximize or minimize the metric defined in scorer when selecting the optimal Grid Search parameters.

learner_args

A named list of additional arguments to pass to learner.

splitter_args

A named list of additional arguments to pass to splitter.

scorer_args

A named list of additional arguments to pass to scorer. scorer_args must be either length 1, in which case the same arguments are passed to every scoring function, or length(scorer), in which case different arguments are passed to each scoring function.

prediction_args

A named list of additional arguments to pass to predict. prediction_args must be either length 1, in which case the same arguments are used for every scoring function, or length(scorer), in which case different arguments are used for each scoring function.

convert_predictions

A list of functions to convert predicted values before they are evaluated by the metric functions supplied in scorer. This list should be either length 1, in which case the same function is applied to all predicted values, or length(scorer), in which case each function in convert_predictions corresponds to the function in the same position in scorer.
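A typical conversion extracts one probability column from a prediction matrix before a vector-valued metric sees it, as in the auc entry of the Examples above. A base-R sketch (the column names and function name are illustrative):

```r
# Simulated prediction matrix, shaped like the output of predict(..., type = "prob")
probs <- matrix(c(0.8, 0.3, 0.2, 0.7), ncol = 2,
                dimnames = list(NULL, c("FALSE", "TRUE")))

# Conversion function: keep only the probability column for the first class,
# turning the matrix into the numeric vector the metric function expects
keep_first <- function(i) i[, "FALSE"]

keep_first(probs)  # c(0.8, 0.3)
```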

Returns

An object of class GridSearchCV.


Method clone()

The objects of this class are cloneable with this method.

Usage
GridSearchCV$clone(deep = FALSE)
Arguments
deep

Whether to make a deep clone.


dmolitor/modelselection documentation built on Jan. 4, 2023, 1:08 p.m.