Nothing
# Low-level constructor for estimator wrapper objects.
#
# Bundles the underlying `estimator`, the `args` it was created with and any
# additional named fields supplied through `...` into a list, then tags that
# list with the S3 classes `c(subclass, "tf_estimator")`.
new_tf_estimator <- function(estimator, args = NULL, ...,
                             subclass = NULL) {
  fields <- list(estimator = estimator, args = args, ...)
  # More specific classes come first so that S3 dispatch prefers them.
  class(fields) <- c(subclass, "tf_estimator")
  fields
}
# Constructor for regressor wrappers: delegates to new_tf_estimator() with
# "tf_estimator_regressor" appended after any caller-supplied `subclass`.
new_tf_regressor <- function(estimator, args = NULL, ..., subclass = NULL) {
  regressor_classes <- c(subclass, "tf_estimator_regressor")
  new_tf_estimator(estimator, args = args, ..., subclass = regressor_classes)
}
# Constructor for classifier wrappers: delegates to new_tf_estimator() with
# "tf_estimator_classifier" appended after any caller-supplied `subclass`.
new_tf_classifier <- function(estimator, args = NULL, ..., subclass = NULL) {
  classifier_classes <- c(subclass, "tf_estimator_classifier")
  new_tf_estimator(estimator, args = args, ..., subclass = classifier_classes)
}
# Predicate: TRUE when `object` carries the "tf_custom_estimator" S3 class,
# FALSE otherwise.
is.tf_custom_estimator <- function(object) {
  inherits(x = object, what = "tf_custom_estimator")
}
#' Base Documentation for Canned Estimators
#'
#' @param object A TensorFlow estimator.
#'
#' @param feature_columns An \R list containing all of the feature columns used
#' by the model (typically, generated by [feature_columns()]).
#'
#' @param model_dir Directory to save the model parameters, graph, and so on.
#' This can also be used to load checkpoints from the directory into an
#' estimator to continue training a previously saved model.
#'
#' @param label_dimension Number of regression targets per example. This is the
#' size of the last dimension of the labels and logits `Tensor` objects
#' (typically, these have shape `[batch_size, label_dimension]`).
#'
#' @param label_vocabulary A list of strings representing possible label values.
#' If given, labels must be string type and have any value in
#' `label_vocabulary`. If it is not given, that means labels are already
#' encoded as integer or float within `[0, 1]` for `n_classes == 2` and
#' encoded as integer values in `{0, 1,..., n_classes -1}` for `n_classes >
#' 2`. Also there will be errors if vocabulary is not provided and labels are
#' string.
#'
#' @param weight_column A string, or a numeric column created by
#' [column_numeric()] defining feature column representing weights. It is used
#' to down weight or boost examples during training. It will be multiplied by
#' the loss of the example. If it is a string, it is used as a key to fetch
#' weight tensor from the `features` argument. If it is a numeric column,
#' then the raw tensor is fetched by key `weight_column$key`, then
#' `weight_column$normalizer_fn` is applied on it to get weight tensor.
#'
#' @param n_classes The number of label classes.
#'
#' @param config A run configuration created by [run_config()], used to configure the runtime
#' settings.
#'
#' @param input_layer_partitioner An optional partitioner for the input layer.
#' Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
#'
#' @param partitioner An optional partitioner for the input layer.
#'
#' @name estimators
NULL
#' Base Documentation for train, evaluate, and predict.
#'
#' @param input_fn An input function, typically generated by the [input_fn()]
#' helper function.
#'
#' @param hooks A list of \R functions, to be used as callbacks inside the
#' training loop. By default, `hook_history_saver(every_n_step = 10)` and
#' `hook_progress_bar()` will be attached if not provided to save the metrics
#' history and create the progress bar.
#'
#' @param checkpoint_path The path to a specific model checkpoint to be used for
#' prediction. If `NULL` (the default), the latest checkpoint in `model_dir`
#' is used.
#'
#' @name train-evaluate-predict
NULL
#' Train an Estimator
#'
#' Train an estimator on a set of input data provided by the `input_fn()`.
#'
#' @inheritParams train-evaluate-predict
#'
#' @template roxlate-object-estimator
#'
#' @param steps The number of steps for which the model should be trained on
#' this particular `train()` invocation. If `NULL` (the default), this
#' function will either train forever, or until the supplied `input_fn()` has
#' provided all available data.
#' @param max_steps The total number of steps for which the model should be
#' trained. If set, `steps` must be `NULL`. If the estimator has already been
#' trained a total of `max_steps` times, then no training will be performed.
#' @param saving_listeners (Available since TensorFlow v1.4) A list of
#' `CheckpointSaverListener` objects used for callbacks that run immediately
#' before or after checkpoint savings.
#' @param ... Optional arguments, passed on to the estimator's `train()` method.
#'
#' @return A data.frame of the training loss history.
#' @export
#' @family custom estimator methods
train.tf_estimator <- function(object,
                               input_fn,
                               steps = NULL,
                               hooks = NULL,
                               max_steps = NULL,
                               saving_listeners = NULL,
                               ...)
{
  # Collect the keyword arguments forwarded to the estimator's `train()`.
  train_args <- list(
    input_fn = normalize_input_fn(object, input_fn),
    steps = cast_nullable_scalar_integer(steps),
    max_steps = cast_nullable_scalar_integer(max_steps),
    ...
  )

  # `saving_listeners` is only forwarded on TensorFlow >= 1.4.
  if (tf_version() >= "1.4") {
    train_args[["saving_listeners"]] <- saving_listeners
  }

  # Hooks are resolved from the raw (un-cast) `steps` value.
  train_args[["hooks"]] <- resolve_train_hooks(hooks, steps)

  with_logging_verbosity(
    tf$logging$WARN,
    do.call(object$estimator$train, train_args)
  )

  # Relocate the tfevents file into a separate /logs folder under model_dir.
  mv_tf_events_file(model_dir(object))

  # Wrap the history recorded under the TRAIN mode key, record it as run
  # metadata, and hand it back invisibly.
  train_key <- mode_keys()$TRAIN
  history <- new_tf_estimator_history(.globals$history[[train_key]])
  tfruns::write_run_metadata("metrics", compose_history_metadata(history))
  invisible(history)
}
#' Generate Predictions with an Estimator
#'
#' Generate predicted labels / values for input data provided by `input_fn()`.
#'
#' @inheritParams train-evaluate-predict
#'
#' @template roxlate-object-estimator
#'
#' @param predict_keys The types of predictions that should be produced, as an
#' \R list. When this argument is not specified (the default), all possible
#' predicted values will be returned.
#' @param simplify Whether to simplify prediction results into a \code{tibble},
#' as opposed to a list. Defaults to \code{TRUE}.
#' @param as_iterable Boolean; should a raw Python generator be returned? When
#' `FALSE` (the default), the predicted values will be consumed from the
#' generator and returned as an \R object.
#' @param yield_single_examples (Available since TensorFlow v1.7) If `FALSE`,
#' yields the whole batch as returned by the `model_fn` instead of decomposing
#' the batch into individual elements. This is useful if `model_fn` returns some
#' tensors with first dimension not equal to the batch size.
#' @param ... Optional arguments passed on to the estimator's `predict()`
#' method.
#'
#' @section Yields: Evaluated values of `predictions` tensors.
#'
#' @section Raises: ValueError: Could not find a trained model in model_dir.
#' ValueError: if batch length of predictions are not same. ValueError: If
#' there is a conflict between `predict_keys` and `predictions`. For example
#' if `predict_keys` is not `NULL` but `EstimatorSpec.predictions` is not a
#' `dict`.
#'
#' @export
#' @family custom estimator methods
predict.tf_estimator <- function(object,
                                 input_fn,
                                 checkpoint_path = NULL,
                                 predict_keys = c("predictions", "classes", "class_ids", "logistic", "logits", "probabilities"),
                                 hooks = NULL,
                                 as_iterable = FALSE,
                                 simplify = TRUE,
                                 yield_single_examples = TRUE,
                                 ...)
{
  # Restrict `predict_keys` to the choices listed in the signature (several
  # may be given), then let resolve_predict_keys() post-process them.
  predict_keys <- resolve_predict_keys(match.arg(predict_keys, several.ok = TRUE))

  # Keyword arguments forwarded to the estimator's `predict()`.
  args <- list(
    input_fn = normalize_input_fn(object, input_fn),
    checkpoint_path = checkpoint_path,
    hooks = normalize_session_run_hooks(hooks),
    predict_keys = predict_keys,
    ...
  )

  # `yield_single_examples` is only forwarded on TensorFlow >= 1.7.
  if (tf_version() >= "1.7") {
    args$yield_single_examples <- yield_single_examples
  }

  predictions <- do.call(object$estimator$predict, args)

  if (!as_iterable) {
    is_python_iterable <-
      inherits(predictions, "python.builtin.iterator") ||
      inherits(predictions, "python.builtin.generator")

    if (!is_python_iterable) {
      warning("predictions are not iterable, no need to convert again")
    } else {
      # Drain the Python iterator/generator into an R list.
      predictions <- iterate(predictions)

      # Convert Python bytestrings back into R strings.
      for (i in seq_along(predictions)) {
        example <- predictions[[i]]

        # A model may yield bare (non-dict) predictions, which arrive here as
        # atomic vectors/arrays. `$` raises "$ operator is invalid for atomic
        # vectors" on those, so only list-like examples are inspected.
        if (!is.list(example)) {
          next
        }

        classes <- example$classes
        if (is.list(classes)) {
          is_bytes <- vapply(classes, function(class) {
            inherits(class, "python.builtin.bytes")
          }, logical(1))

          # Only decode when every entry is a bytes object.
          if (all(is_bytes)) {
            decoded <- vapply(classes, function(class) {
              class$decode()
            }, character(1))
            predictions[[i]]$classes <- decoded
          }
        }
      }
    }
  }

  simplify_results(predictions, simplify)
}
#' Evaluate an Estimator
#'
#' Evaluate an estimator on input data provided by an `input_fn()`.
#'
#' For each step, this method will call `input_fn()` to produce a single batch
#' of data. Evaluation continues until:
#'
#' - `steps` batches are processed, or
#' - The `input_fn()` is exhausted of data.
#'
#' @inheritParams train-evaluate-predict
#'
#' @template roxlate-object-estimator
#'
#' @param name Name of the evaluation if user needs to run multiple evaluations
#' on different data sets, such as on training data vs test data. Metrics for
#' different evaluations are saved in separate folders, and appear separately
#' in tensorboard.
#' @param steps The number of steps for which the model should be evaluated on
#' this particular `evaluate()` invocation. If `NULL` (the default), this function
#' will either evaluate forever, or until the supplied `input_fn()` has provided
#' all available data.
#' @param simplify Whether to simplify evaluation results into a \code{tibble}, as
#' opposed to a list. Defaults to \code{TRUE}.
#' @param ... Optional arguments passed on to the estimator's `evaluate()`
#' method.
#'
#' @return The evaluation metrics, simplified into a \code{tibble} when `simplify = TRUE` (the default) or returned as an \R list otherwise.
#'
#' @export
#' @family custom estimator methods
evaluate.tf_estimator <- function(object,
                                  input_fn,
                                  steps = NULL,
                                  checkpoint_path = NULL,
                                  name = NULL,
                                  hooks = NULL,
                                  simplify = TRUE,
                                  ...)
{
  # The whole call, including argument preparation, runs inside
  # with_logging_verbosity(tf$logging$WARN, ...).
  metrics <- with_logging_verbosity(
    tf$logging$WARN,
    {
      object$estimator$evaluate(
        input_fn = normalize_input_fn(object, input_fn),
        steps = cast_nullable_scalar_integer(steps),
        checkpoint_path = checkpoint_path,
        name = name,
        # Hooks are resolved from the raw (un-cast) `steps` value.
        hooks = resolve_eval_hooks(hooks, steps),
        ...
      )
    }
  )

  # Record the raw results as run metadata before any simplification.
  tfruns::write_run_metadata("evaluation", metrics)

  simplify_results(metrics, simplify)
}
#' Save an Estimator
#'
#' Save an estimator (alongside its weights) to the directory `export_dir_base`.
#'
#' @details
#'
#' This method builds a new graph by first calling the serving_input_receiver_fn
#' to obtain feature `Tensor`s, and then calling this `Estimator`'s model_fn to
#' generate the model graph based on those features. It restores the given
#' checkpoint (or, lacking that, the most recent checkpoint) into this graph in
#' a fresh session. Finally it creates a timestamped export directory below the
#' given export_dir_base, and writes a `SavedModel` into it containing a single
#' `MetaGraphDef` saved from this session. The exported `MetaGraphDef` will
#' provide one `SignatureDef` for each element of the export_outputs dict
#' returned from the model_fn, named using the same keys. One of these keys is
#' always signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating
#' which signature will be served when a serving request does not specify one.
#' For each signature, the outputs are provided by the corresponding
#' `ExportOutput`s, and the inputs are always the input receivers provided by
#' the serving_input_receiver_fn. Extra assets may be written into the
#' SavedModel via the extra_assets argument. This should be a dict, where each
#' key gives a destination path (including the filename) relative to the
#' assets.extra directory. The corresponding value gives the full path of the
#' source file to be copied. For example, the simple case of copying a single
#' file without renaming it is specified as `{'my_asset_file.txt':
#' '/path/to/my_asset_file.txt'}`.
#'
#' @template roxlate-object-estimator
#'
#' @param export_dir_base A string containing a directory in which to export the
#' SavedModel.
#' @param serving_input_receiver_fn A function that takes no argument and
#' returns a `ServingInputReceiver`. Required for custom models.
#' @param assets_extra A dict specifying how to populate the assets.extra
#' directory within the exported SavedModel, or `NULL` if no extra assets are
#' needed.
#' @param as_text whether to write the SavedModel proto in text format.
#' @param checkpoint_path The checkpoint path to export. If `NULL` (the
#' default), the most recent checkpoint found within the model directory is
#' chosen.
#' @param overwrite Should the \code{export_dir_base} directory be overwritten?
#' @param versioned Should the model be exported under a versioned subdirectory?
#' @param ... Optional arguments passed on to the estimator's
#' `export_savedmodel()` method.
#'
#' @return The path to the exported directory, as a string.
#'
#' @section Raises: ValueError: if no serving_input_receiver_fn is provided, no
#' export_outputs are provided, or no checkpoint can be found.
#'
#' @export
#' @family custom estimator methods
export_savedmodel.tf_estimator <- function(object,
                                           export_dir_base,
                                           serving_input_receiver_fn = NULL,
                                           assets_extra = NULL,
                                           as_text = FALSE,
                                           checkpoint_path = NULL,
                                           overwrite = TRUE,
                                           versioned = !overwrite,
                                           ...)
{
  # Refuse to touch an existing target unless the caller asked either to
  # overwrite it or to export into a versioned sub-directory.
  if (!(overwrite || versioned) && file.exists(export_dir_base)) {
    stop("Path '", export_dir_base, "' already exists, use 'overwrite = TRUE' instead.")
  }

  # Without a user-supplied receiver, derive one from the feature columns the
  # estimator was constructed with.
  if (is.null(serving_input_receiver_fn)) {

    if (is.tf_custom_estimator(object)) {
      stop("A 'tf_custom_estimator' requires a custom `serving_input_receiver_fn`.")
    }

    # Gather every kind of feature-column argument the estimator may hold.
    feature_columns_spec <- c(
      object$args$dnn_feature_columns,
      object$args$linear_feature_columns,
      object$args$feature_columns
    )

    # The estimator flavour is detected from its S3 class names.
    object_classes <- class(object)
    is_regressor <- any(grepl("regressor", object_classes))
    is_classifier <- any(grepl("classifier", object_classes))

    if (!is_regressor && !is_classifier) {
      stop("Currently only classifier and regressor are supported. Please specify a custom serving_input_receiver_fn. ")
    }

    if (tf_version() < "1.4") {
      # Older TensorFlow: build a parsing receiver from an example spec.
      build_example_spec <- if (is_regressor) {
        regressor_parse_example_spec
      } else {
        classifier_parse_example_spec
      }
      input_spec <- build_example_spec(
        feature_columns = feature_columns_spec,
        weight_column = object$args$weight_column,
        label_key = "label"
      )
      serving_input_receiver_fn <-
        tf$estimator$export$build_parsing_serving_input_receiver_fn(input_spec)
    } else {
      # Newer TensorFlow: build a raw receiver with one placeholder per
      # feature column.
      features <- list()
      for (column in feature_columns_spec) {
        default_tensor <- tf$constant(
          value = column$dtype$min,
          shape = shape(1, column$shape)
        )
        # The first dimension is left variable since cloudml-like interfaces
        # require it in order to push multiple instances.
        features[[column$name]] <- tf$placeholder_with_default(
          input = default_tensor,
          shape = shape(NULL, column$shape)
        )
      }
      serving_input_receiver_fn <-
        tf$estimator$export$build_raw_serving_input_receiver_fn(features)
    }
  }

  # Versioned exports are written straight into `export_dir_base`; otherwise
  # the export is staged under the session temp directory and copied over.
  export_target <- if (versioned) export_dir_base else tempdir()

  export_result <- object$estimator$export_savedmodel(
    export_dir_base = export_target,
    serving_input_receiver_fn = serving_input_receiver_fn,
    assets_extra = assets_extra,
    as_text = as_text,
    checkpoint_path = checkpoint_path,
    ...
  )

  if (versioned) {
    return(invisible(export_result))
  }

  # Unversioned export: replace (or create) the target directory, copy the
  # contents of the staged export into it, then drop the staged copy.
  if (overwrite && file.exists(export_dir_base)) {
    unlink(export_dir_base, recursive = TRUE)
  }
  if (!file.exists(export_dir_base)) {
    dir.create(export_dir_base, recursive = TRUE)
  }

  # Copying "<dir>/." transfers the directory's contents rather than the
  # directory itself.
  file.copy(file.path(export_result, "."), export_dir_base, recursive = TRUE)
  unlink(export_result, recursive = TRUE)

  invisible(export_dir_base)
}
#' Get variable names and values associated with an estimator
#'
#' These helper functions extract the names and values of variables
#' in the graphs associated with trained estimator models.
#'
#' @name variable_names_values
#' @param object A trained estimator model.
#' @return For \code{variable_names()}, a vector of variable names. For \code{variable_value()}, a named list of variable values.
#' @export
variable_names <- function(object) {
  model_dir <- object$estimator$model_dir

  # Nothing listed under the model directory: there is no trained model to
  # inspect.
  if (length(list.files(model_dir)) == 0) {
    stop("'variable_names()' must be called on a trained model")
  }

  # Before TensorFlow 1.4 the names are read via list_variable_names() on the
  # model directory and flattened.
  if (tensorflow::tf_version() < "1.4") {
    return(unlist(list_variable_names(model_dir)))
  }

  # From 1.4 onwards the estimator reports them itself.
  object$estimator$get_variable_names()
}
#' @rdname variable_names_values
#' @param variable (Optional) Names of variables to extract as a character vector. If not specified, values for all variables are returned.
#' @export
variable_value <- function(object, variable = NULL) {
  model_dir <- object$estimator$model_dir

  # Nothing listed under the model directory: there is no trained model to
  # inspect.
  if (length(list.files(model_dir)) == 0) {
    stop("'variable_value()' must be called on a trained model")
  }

  known_names <- variable_names(object)

  if (is.null(variable)) {
    # No selection given: return every variable.
    variable <- known_names
  } else {
    # Reject any requested name that the model does not define.
    missing_names <- unlist(variable[!(variable %in% known_names)])
    if (length(missing_names) > 0) {
      stop("Variable not found: ", paste0(missing_names, collapse = ", "))
    }
  }

  if (tensorflow::tf_version() >= "1.4") {
    # From 1.4 onwards the estimator can look values up directly.
    values <- lapply(variable, object$estimator$get_variable_value)
  } else {
    # Older TensorFlow: read the tensors from the latest checkpoint found in
    # the model directory.
    checkpoint <- tf$python$training$checkpoint_utils$load_checkpoint(
      latest_checkpoint(model_dir)
    )
    values <- lapply(variable, function(var_name) {
      checkpoint$get_tensor(var_name[[1]])
    })
  }

  rlang::set_names(values, unlist(variable))
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.