predict.tf_estimator: Generate Predictions with an Estimator


View source: R/tf_estimator.R

Description

Generates predicted labels or values for the input data provided by input_fn().

Usage

## S3 method for class 'tf_estimator'
predict(
  object,
  input_fn,
  checkpoint_path = NULL,
  predict_keys = c("predictions", "classes", "class_ids", "logistic", "logits",
    "probabilities"),
  hooks = NULL,
  as_iterable = FALSE,
  simplify = TRUE,
  yield_single_examples = TRUE,
  ...
)
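A minimal end-to-end sketch of calling predict() on a trained estimator, adapted from the usual tfestimators workflow. It assumes the tfestimators package and a compatible TensorFlow 1.x installation are available; the helper name mtcars_input_fn() and the choice of features are illustrative, not part of this page.

```r
library(tfestimators)

# Hypothetical helper: builds an input function from a data frame,
# using two numeric features of mtcars and mpg as the response.
mtcars_input_fn <- function(data) {
  input_fn(data, features = c("disp", "cyl"), response = "mpg")
}

# A small linear regressor, trained briefly for illustration only.
cols <- feature_columns(column_numeric("disp"), column_numeric("cyl"))
model <- linear_regressor(feature_columns = cols)
train(model, input_fn = mtcars_input_fn(mtcars))

# Predict on three rows; with simplify = TRUE (the default) the
# results are collected into a tibble.
preds <- predict(model, input_fn = mtcars_input_fn(mtcars[1:3, ]))
print(preds)
```

Because checkpoint_path is left as NULL, the prediction uses the latest checkpoint written to the model directory by train().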

Arguments

object

A TensorFlow estimator.

input_fn

An input function, typically generated by the input_fn() helper function.

checkpoint_path

The path to a specific model checkpoint to be used for prediction. If NULL (the default), the latest checkpoint in model_dir is used.

predict_keys

The types of predictions that should be produced, as a character vector. When this argument is not specified (the default), all available predicted values are returned.

hooks

A list of R functions to be used as callbacks inside the prediction call. If not provided, hook_history_saver(every_n_step = 10) and hook_progress_bar() are attached by default to save the metrics history and display a progress bar.

as_iterable

Boolean; should a raw Python generator be returned? When FALSE (the default), the predicted values will be consumed from the generator and returned as an R object.

simplify

Whether to simplify prediction results into a tibble, as opposed to a list. Defaults to TRUE.

yield_single_examples

(Available since TensorFlow v1.7) If FALSE, yields the whole batch as returned by the model_fn instead of decomposing the batch into individual elements. This is useful if model_fn returns some tensors whose first dimension is not equal to the batch size.

...

Optional arguments passed on to the estimator's predict() method.
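The sketch below shows how predict_keys and simplify change what comes back. It assumes tfestimators with a TensorFlow 1.x backend; mtcars_input_fn() is a hypothetical helper, and "predictions" is used because that is the output key a canned regressor produces.

```r
library(tfestimators)

# Hypothetical helper: input function over two numeric mtcars features.
mtcars_input_fn <- function(data) {
  input_fn(data, features = c("disp", "cyl"), response = "mpg")
}

cols <- feature_columns(column_numeric("disp"), column_numeric("cyl"))
model <- linear_regressor(feature_columns = cols)
train(model, input_fn = mtcars_input_fn(mtcars))

obs <- mtcars[1:3, ]

# Request only the "predictions" output, simplified into a tibble.
preds_tbl <- predict(model, input_fn = mtcars_input_fn(obs),
                     predict_keys = "predictions")

# Same request, but keep the per-example results as a plain list.
preds_list <- predict(model, input_fn = mtcars_input_fn(obs),
                      predict_keys = "predictions", simplify = FALSE)
```

With yield_single_examples = TRUE (the default), each element of preds_list corresponds to one input row.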

Yields

Evaluated values of the prediction tensors.
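When as_iterable = TRUE, the raw Python generator is returned and predictions can be pulled one at a time. This is a sketch assuming tfestimators, TensorFlow 1.x, and the reticulate package are installed; mtcars_input_fn() is a hypothetical helper, and passing simplify = FALSE alongside as_iterable = TRUE is a defensive assumption rather than a documented requirement.

```r
library(tfestimators)
library(reticulate)

# Hypothetical helper: input function over two numeric mtcars features.
mtcars_input_fn <- function(data) {
  input_fn(data, features = c("disp", "cyl"), response = "mpg")
}

cols <- feature_columns(column_numeric("disp"), column_numeric("cyl"))
model <- linear_regressor(feature_columns = cols)
train(model, input_fn = mtcars_input_fn(mtcars))

# Return the generator itself; simplify = FALSE is passed defensively
# so that no R-side post-processing is applied to it.
gen <- predict(model, input_fn = mtcars_input_fn(mtcars[1:3, ]),
               as_iterable = TRUE, simplify = FALSE)

# Pull the first prediction; iter_next() returns NULL once the
# generator is exhausted.
first <- iter_next(gen)
str(first)
```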

Raises

ValueError: If a trained model could not be found in model_dir.
ValueError: If the batch lengths of the predictions are not the same.
ValueError: If there is a conflict between predict_keys and predictions; for example, if predict_keys is not NULL but EstimatorSpec.predictions is not a dict.

See Also

Other custom estimator methods: estimator_spec(), estimator(), evaluate.tf_estimator(), export_savedmodel.tf_estimator(), train.tf_estimator()


rstudio/tfestimators documentation built on Nov. 24, 2021, 6:56 a.m.