knitr::opts_chunk$set(
  message = FALSE,
  digits = 3,
  collapse = TRUE,
  comment = "#>"
)
options(digits = 3)
library(dials)
library(rpart)
Some statistical and machine learning models contain tuning parameters (also known as hyperparameters), which are parameters that cannot be directly estimated by the model. An example would be the number of neighbors in a K-nearest neighbors model. To determine reasonable values of these parameters, some indirect method is used, such as resampling or profile likelihood. Search methods, such as genetic algorithms or Bayesian search, can also be used to determine good values.
In any case, some information is needed to create a grid or to validate whether a candidate value is appropriate (e.g. the number of neighbors should be a positive integer).
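As a minimal sketch of such a validity check, assuming the dials package is installed, the neighbors() parameter object and the value_validate() helper (neither is introduced in the text above) can flag candidate values that fall outside the default range:

```r
library(dials)

# neighbors() describes the number of neighbors in a K-nearest neighbors
# model; its default range is assumed here to be the integers 1 through 10.
knn_param <- neighbors()

# Candidate values to check: one plausible, one non-positive, one too large.
candidates <- c(3L, 0L, 20L)

# value_validate() returns one logical per candidate.
is_ok <- value_validate(knn_param, candidates)
is_ok
```

Only the first candidate should pass, since 0 and 20 lie outside the assumed default range.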
Among other things, dials is designed to standardize the names of tuning parameters: it proposes some standardized names so that the user doesn't need to memorize the syntactical minutiae of every package.
Parameter objects contain information about possible values, ranges, types, and other aspects. They have two classes: the general param class and a more specific subclass related to the type of variable. Double and integer valued data have the subclass quant_param while character and logicals have qual_param. There are some common elements for each.
Otherwise, the information contained in parameter objects is different for different data types.
An example of a numeric tuning parameter is the cost-complexity parameter of CART trees, otherwise known as $C_p$. A parameter object for $C_p$ can be created in dials using cost_complexity().
Note that this parameter is handled in log10 units and the default range of values is between 10^-10 and 0.1. The range of possible values can be returned and changed with some utility functions. We'll use the pipe operator here:
library(dplyr)
cost_complexity() %>% range_get()
cost_complexity() %>% range_set(c(-5, 1))

# Or using the `range` argument
# during creation
cost_complexity(range = c(-5, 1))
Values for this parameter can be obtained in a few different ways. To get a sequence of values that span the range:
# Natural units:
cost_complexity() %>% value_seq(n = 4)

# Stay in the transformed space:
cost_complexity() %>% value_seq(n = 4, original = FALSE)
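The relationship between the two sequences can be checked directly. This sketch assumes the default log10 transformation of cost_complexity(); back-transforming the sequence from the transformed space by hand should reproduce the sequence in natural units. The object names are chosen for illustration only.

```r
library(dials)

cp_param <- cost_complexity()

# Sequence in natural units and in the transformed (log10) space
seq_natural <- value_seq(cp_param, n = 4)
seq_log10 <- value_seq(cp_param, n = 4, original = FALSE)

# Undoing the log10 transformation by hand recovers the natural units
all.equal(10^seq_log10, seq_natural)
```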
Random values can be sampled too. A uniform distribution over the range is used. Since this parameter has a transformation associated with it, the values are simulated on the transformed scale and then returned in the natural units (although the original argument can be used here to keep them on the transformed scale):
set.seed(5473)
cost_complexity() %>% value_sample(n = 4)
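To illustrate what sampling on the transformed scale means, the following sketch draws uniform values between the log10 range limits with base R and back-transforms them by hand. It is an illustration of the idea rather than the package's internal code, and it is not guaranteed to reproduce the exact numbers returned by value_sample().

```r
library(dials)

# Range limits in the transformed (log10) units
log_range <- range_get(cost_complexity(), original = FALSE)

set.seed(5473)
# Uniform draws on the log10 scale ...
log_draws <- runif(4, min = log_range$lower, max = log_range$upper)

# ... returned in natural units
manual_draws <- 10^log_draws
manual_draws
```

Because the draws are uniform in log10 units, small values of the parameter are sampled far more densely than they would be under a uniform draw in natural units.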
For CART trees, there is a discrete set of values that exists for a given data set. It may be a good idea to assign these possible values to the object. We can get them by fitting an initial rpart model and then adding the values to the object. For mtcars, there are only three values:
library(rpart)
cart_mod <- rpart(mpg ~ ., data = mtcars, control = rpart.control(cp = 0.000001))
cart_mod$cptable
cp_vals <- cart_mod$cptable[, "CP"]

# We should only keep values associated with at least one split:
cp_vals <- cp_vals[ cart_mod$cptable[, "nsplit"] > 0 ]

# Here the specific Cp values, on their natural scale, are added:
mtcars_cp <- cost_complexity() %>% value_set(cp_vals)
The last call raises an error because the values are not on the transformed scale:
mtcars_cp <- cost_complexity() %>% value_set(log10(cp_vals))
mtcars_cp
Now, if a sequence or random sample is requested, it uses the set values:
mtcars_cp %>% value_seq(2)

# Sampling specific values is done with replacement
mtcars_cp %>%
  value_sample(20) %>%
  table()
Any transformations from the scales package can be used with the numeric parameters, or a custom transformation can be generated with scales::trans_new():
trans_raise <- scales::trans_new(
  "raise",
  transform = function(x) 2^x,
  # The inverse undoes the transform: log2(2^x) == x
  inverse = function(x) log2(x)
)
custom_cost <- cost(range = c(1, 10), trans = trans_raise)
custom_cost

Note that if a transformation is used, the range argument specifies the parameter range on the transformed scale. For this version of cost(), parameter values are sampled between 1 and 10 and then transformed back to the original scale by the inverse log2(). So on the original scale, the sampled values are between log2(1) and log2(10):

log2(c(1, 10))
value_sample(custom_cost, 100) %>% range()
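The same distinction between the transformed and natural scales can be seen with a built-in transformation. In this sketch, the range of cost_complexity() is supplied in log10 units and then read back in both scales; the narrower range c(-5, -1) and the object name are arbitrary choices for illustration.

```r
library(dials)

# The range is given on the transformed (log10) scale
cp_narrow <- cost_complexity(range = c(-5, -1))

# Limits as stored, in log10 units
range_get(cp_narrow, original = FALSE)

# Limits back-transformed to natural units (1e-05 and 0.1)
range_get(cp_narrow)
```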
In the discrete case there is no notion of a range. The parameter objects are defined by their discrete values. For example, consider weight_func(), a parameter for the types of kernel functions that are used with distance functions.
The helper functions are analogous to those for quantitative parameters:
# redefine values
weight_func() %>% value_set(c("rectangular", "triangular"))

weight_func() %>% value_sample(3)

# the sequence is returned in the order of the levels
weight_func() %>% value_seq(3)
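One thing worth noting is that value_set() returns a modified copy rather than changing what weight_func() itself produces. A short sketch, with the hypothetical object name wf_small and an arbitrary seed, shows that sampling from the stored copy only returns the redefined values:

```r
library(dials)

# Restrict the kernel types to two candidate values
wf_small <- value_set(weight_func(), c("rectangular", "triangular"))

set.seed(2024)
draws <- value_sample(wf_small, 10)

# Every draw comes from the restricted set
all(draws %in% c("rectangular", "triangular"))
```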
The package contains two constructors that can be used to create new quantitative and qualitative parameters, new_quant_param() and new_qual_param(). The article "How to create a tuning parameter function" walks you through a detailed example.
There are some cases where the range of parameter values is data dependent. For example, the upper bound on the number of neighbors cannot be known if the number of data points in the training set is not known. For that reason, some parameters have an unknown placeholder:
mtry()

sample_size()

num_terms()

num_comp()

# and so on
These values must be initialized prior to generating parameter values. The finalize() methods can be used to help remove the unknowns:
finalize(mtry(), x = mtcars[, -1])
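The effect of finalizing can be checked programmatically. The sketch below assumes the has_unknowns() helper from dials, which is not mentioned in the text above, and compares the parameter before and after finalize():

```r
library(dials)

# mtry() starts with an unknown upper bound
has_unknowns(mtry())

# Finalizing against the ten predictor columns of mtcars fills it in
mtry_final <- finalize(mtry(), x = mtcars[, -1])
has_unknowns(mtry_final)

# The upper limit now equals the number of predictors
range_get(mtry_final)$upper
```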
Parameter sets are collections of parameters used in a model, recipe, or other object. They can also be created manually and can have alternate identification fields:
glmnet_set <- parameters(list(lambda = penalty(), alpha = mixture()))
glmnet_set

# can be updated too
update(glmnet_set, alpha = mixture(c(.3, .6)))
These objects can be very helpful when creating tuning grids.
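As a brief sketch of that use, a parameter set can be passed to a grid function in place of individual parameter objects. The object name glmnet_grid is chosen for illustration only.

```r
library(dials)

glmnet_set <- parameters(list(lambda = penalty(), alpha = mixture()))

# A regular grid with three levels per parameter: 3 x 3 = 9 combinations
glmnet_grid <- grid_regular(glmnet_set, levels = 3)

# Grid columns are named after the identifiers in the set
names(glmnet_grid)
nrow(glmnet_grid)
```

The grid columns carry the identifiers lambda and alpha rather than the function names penalty and mixture, which is where the alternate identification fields become useful.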
Sets or combinations of parameters can be created for use in grid search. Functions such as grid_regular(), grid_random(), and grid_latin_hypercube() take any number of param objects or a parameter set.
For example, for a glmnet model, a regular grid might be:
grid_regular(
  mixture(),
  penalty(),
  levels = 3 # or c(3, 4), etc
)
and, similarly, a random grid is created using:
set.seed(1041)
grid_random(
  mixture(),
  penalty(),
  size = 6
)
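As the comment in the regular-grid call suggests, levels can also be a vector with one entry per parameter. A brief sketch, assuming the entries are matched to the parameters in the order they are supplied (the object name uneven_grid is illustrative):

```r
library(dials)

# Three levels for mixture() and four for penalty(): 3 x 4 = 12 rows
uneven_grid <- grid_regular(
  mixture(),
  penalty(),
  levels = c(3, 4)
)

nrow(uneven_grid)
length(unique(uneven_grid$mixture))
length(unique(uneven_grid$penalty))
```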