estimator: Construct a Custom Estimator

View source: R/tf_custom_estimator.R

Description

Construct a custom estimator, to be used to train and evaluate TensorFlow models.

Usage

estimator(
  model_fn,
  model_dir = NULL,
  config = NULL,
  params = NULL,
  class = NULL
)

Arguments

model_fn

The model function. See the Model Functions section for details on the structure of a model function.

model_dir

Directory in which to save model parameters, the graph, and other outputs. It can also be used to load checkpoints from the directory into an estimator to continue training a previously saved model. If NULL, the model_dir in config is used, if set. If both are set, they must be the same. If both are NULL, a temporary directory is used.

config

Configuration object.

params

List of hyperparameters that will be passed into model_fn. Keys are the names of the parameters; values are basic Python types.

class

An optional set of R classes to add to the generated object.

Details

The Estimator object wraps a model which is specified by a model_fn, which, given inputs and a number of other parameters, returns the operations necessary to perform training, evaluation, and prediction.

All outputs (checkpoints, event files, etc.) are written to model_dir, or a subdirectory thereof. If model_dir is not set, a temporary directory is used.
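The rule for choosing the output directory can be illustrated in plain R. This is only a sketch of the documented behavior, not the code that tfestimators or TensorFlow runs: `resolve_model_dir()` and its `config_model_dir` argument are hypothetical names introduced for the illustration.

```r
# Hypothetical helper that mirrors the documented rule for model_dir.
# It is NOT part of tfestimators; it only restates the rule as code.
resolve_model_dir <- function(model_dir = NULL, config_model_dir = NULL) {
  if (!is.null(model_dir) && !is.null(config_model_dir)) {
    # Both set: they must agree.
    if (!identical(model_dir, config_model_dir)) {
      stop("`model_dir` and the `model_dir` in `config` must be the same.")
    }
    return(model_dir)
  }
  # Only one set: use whichever is available.
  if (!is.null(model_dir)) return(model_dir)
  if (!is.null(config_model_dir)) return(config_model_dir)
  # Neither set: fall back to a temporary path.
  tempfile("model_dir_")
}

resolve_model_dir("runs/a")                      # explicit argument wins
resolve_model_dir(config_model_dir = "runs/b")   # falls back to config
resolve_model_dir()                              # temporary path
```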

The config argument can be used to pass a run configuration object containing information about the execution environment. It is passed on to the model_fn if the model_fn has a parameter named "config" (and to the input functions in the same manner). If config is not passed, estimator() instantiates one with defaults suited to local execution. estimator() makes config available to the model (for instance, to allow specialization based on the number of workers available), and also uses some of its fields to control internals, especially regarding checkpointing.

The params argument contains hyperparameters. It is passed to the model_fn if the model_fn has a parameter named "params", and to the input functions in the same manner. estimator() only passes params along; it does not inspect it. The structure of params is therefore entirely up to the developer.

None of estimator's methods can be overridden in subclasses (its constructor enforces this). Subclasses should use model_fn to configure the base class, and may add methods implementing specialized functionality.

Model Functions

The model_fn should be an R function of the form:

function(features, labels, mode, params) {
    # 1. Configure the model via TensorFlow operations.
    # 2. Define the loss function for training and evaluation.
    # 3. Define the training optimizer.
    # 4. Define how predictions should be produced.
    # 5. Return the result as an `estimator_spec()` object.
    estimator_spec(mode, predictions, loss, train_op, eval_metric_ops)
}

The model function's inputs are defined as follows:

features: The feature tensor(s).
labels: The label tensor(s).
mode: The current mode ("train", "eval", or "infer"). These values can be accessed through the mode_keys() object.
params: An optional list of hyperparameters, as received through the estimator() constructor.

See estimator_spec() for more details as to how the estimator specification should be constructed, and https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/estimator/Estimator for more information as to how the model function should be constructed.
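As an illustration of the five steps in the template above, here is a minimal sketch of a regression model function. It assumes a TensorFlow 1.x installation (the `tf$layers`, `tf$losses`, and `tf$train` modules), that `features` arrives as a single numeric tensor, and that `labels` is a float tensor of shape [batch, 1]. The hyperparameter names `hidden_units` and `learning_rate` are made up for this example. The function also takes a `config` argument, as described in Details.

```r
library(tensorflow)
library(tfestimators)

model_fn <- function(features, labels, mode, params, config) {
  # 1. Configure the model: one hidden layer followed by a linear output.
  #    Assumes `features` is a single numeric tensor.
  hidden <- tf$layers$dense(features,
                            units = params$hidden_units,
                            activation = tf$nn$relu)
  output <- tf$layers$dense(hidden, units = 1L)

  # 4. Define how predictions are produced.
  predictions <- list(value = output)

  # In prediction mode there are no labels, so return early.
  if (mode == "infer") {
    return(estimator_spec(mode = mode, predictions = predictions))
  }

  # 2. Define the loss (assumes `labels` has shape [batch, 1]).
  loss <- tf$losses$mean_squared_error(labels, output)

  # 3. Define the training optimizer.
  optimizer <- tf$train$GradientDescentOptimizer(
    learning_rate = params$learning_rate
  )
  train_op <- optimizer$minimize(loss,
                                 global_step = tf$train$get_global_step())

  # 5. Return the result as an estimator_spec() object.
  estimator_spec(mode = mode,
                 predictions = predictions,
                 loss = loss,
                 train_op = train_op)
}

# The list given as `params` is passed through to model_fn unchanged.
model <- estimator(
  model_fn,
  params = list(hidden_units = 8L, learning_rate = 0.01)
)
```

The resulting object can then be used with the methods listed under See Also, together with an input function that yields features and labels in the shapes assumed here.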

See Also

Other custom estimator methods: estimator_spec(), evaluate.tf_estimator(), export_savedmodel.tf_estimator(), predict.tf_estimator(), train.tf_estimator()


tfestimators documentation built on Aug. 10, 2021, 1:06 a.m.