R/tf_estimator.R

Defines functions variable_value variable_names export_savedmodel.tf_estimator evaluate.tf_estimator predict.tf_estimator train.tf_estimator is.tf_custom_estimator new_tf_classifier new_tf_regressor new_tf_estimator

Documented in evaluate.tf_estimator export_savedmodel.tf_estimator predict.tf_estimator train.tf_estimator variable_names variable_value

# Internal constructor for the "tf_estimator" S3 class.
#
# Bundles the wrapped estimator, the arguments it was built with, and any
# additional named fields supplied through `...` into a list. The class
# vector is `subclass` (if any) followed by "tf_estimator".
new_tf_estimator <- function(estimator, args = NULL, ...,
                             subclass = NULL) {
  fields <- list(estimator = estimator, args = args, ...)
  class(fields) <- c(subclass, "tf_estimator")
  fields
}

# Internal constructor for regression estimators: delegates to
# new_tf_estimator(), appending "tf_estimator_regressor" after any
# caller-supplied subclasses.
new_tf_regressor <- function(estimator, args = NULL, ..., subclass = NULL) {
  regressor_classes <- c(subclass, "tf_estimator_regressor")
  new_tf_estimator(estimator, args = args, ..., subclass = regressor_classes)
}

# Internal constructor for classification estimators: delegates to
# new_tf_estimator(), appending "tf_estimator_classifier" after any
# caller-supplied subclasses.
new_tf_classifier <- function(estimator, args = NULL, ..., subclass = NULL) {
  new_tf_estimator(
    estimator,
    args = args,
    ...,
    subclass = append(subclass, "tf_estimator_classifier")
  )
}

# Predicate: does `object` carry the "tf_custom_estimator" class?
is.tf_custom_estimator <- function(object) {
  custom_class <- "tf_custom_estimator"
  inherits(object, what = custom_class)
}

#' Base Documentation for Canned Estimators
#' 
#' @param object A TensorFlow estimator.
#' 
#' @param feature_columns An \R list containing all of the feature columns used
#'   by the model (typically, generated by [feature_columns()]).
#' 
#' @param model_dir Directory to save the model parameters, graph, and so on.
#'   This can also be used to load checkpoints from the directory into an
#'   estimator to continue training a previously saved model.
#'
#' @param label_dimension Number of regression targets per example. This is the 
#'   size of the last dimension of the labels and logits `Tensor` objects 
#'   (typically, these have shape `[batch_size, label_dimension]`).
#'   
#' @param label_vocabulary A list of strings representing possible label values.
#'   If given, labels must be of string type and take values in
#'   `label_vocabulary`. If it is not given, labels are assumed to be already
#'   encoded as integer or float within `[0, 1]` for `n_classes == 2`, and
#'   as integer values in `{0, 1, ..., n_classes - 1}` for `n_classes > 2`.
#'   An error is raised if no vocabulary is provided and the labels are
#'   strings.
#'
#' @param weight_column A string, or a numeric column created by
#'   [column_numeric()], defining the feature column that represents weights.
#'   It is used to down-weight or boost examples during training, and is
#'   multiplied by the loss of the example. If it is a string, it is used as a
#'   key to fetch the weight tensor from the `features` argument. If it is a
#'   numeric column, the raw tensor is fetched by key `weight_column$key`, and
#'   `weight_column$normalizer_fn` is then applied to it to get the weight
#'   tensor.
#'   
#' @param n_classes The number of label classes.
#'   
#' @param config A run configuration created by [run_config()], used to
#'   configure the runtime settings.
#'   
#' @param input_layer_partitioner An optional partitioner for the input layer.
#'   Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
#' 
#' @param partitioner An optional partitioner for the input layer.
#' 
#' @name estimators
NULL

#' Base Documentation for train, evaluate, and predict. 
#' 
#' @param input_fn An input function, typically generated by the [input_fn()]
#'   helper function.
#'
#' @param hooks A list of \R functions, to be used as callbacks inside the
#'   training loop. By default, `hook_history_saver(every_n_step = 10)` and
#'   `hook_progress_bar()` will be attached if not provided to save the metrics
#'   history and create the progress bar.
#'
#' @param checkpoint_path The path to a specific model checkpoint to be used
#'   (for example, for prediction or evaluation). If `NULL` (the default), the
#'   latest checkpoint in `model_dir` is used.
#'
#' @name train-evaluate-predict
NULL

#' Train an Estimator
#'
#' Train an estimator on a set of input data provided by the `input_fn()`.
#'
#' @inheritParams train-evaluate-predict
#'
#' @template roxlate-object-estimator
#'
#' @param steps The number of steps for which the model should be trained on
#'   this particular `train()` invocation. If `NULL` (the default), this
#'   function will either train forever, or until the supplied `input_fn()` has
#'   provided all available data.
#' @param max_steps The total number of steps for which the model should be
#'   trained. If set, `steps` must be `NULL`. If the estimator has already been
#'   trained a total of `max_steps` times, then no training will be performed.
#' @param saving_listeners (Available since TensorFlow v1.4) A list of
#'   `CheckpointSaverListener` objects used for callbacks that run immediately
#'   before or after checkpoint savings.
#' @param ... Optional arguments, passed on to the estimator's `train()` method.
#'
#' @return A data.frame of the training loss history (returned invisibly).
#' @export
#' @family custom estimator methods
train.tf_estimator <- function(object,
                               input_fn,
                               steps = NULL,
                               hooks = NULL,
                               max_steps = NULL,
                               saving_listeners = NULL,
                               ...)
{
  train_args <- list(
    input_fn = normalize_input_fn(object, input_fn),
    steps = cast_nullable_scalar_integer(steps),
    max_steps = cast_nullable_scalar_integer(max_steps),
    ...
  )

  # `saving_listeners` is only forwarded on TensorFlow >= 1.4. Assigning NULL
  # into a list slot that does not exist leaves the list unchanged.
  supports_listeners <- tf_version() >= "1.4"
  if (supports_listeners) {
    train_args[["saving_listeners"]] <- saving_listeners
  }

  train_args[["hooks"]] <- resolve_train_hooks(hooks, steps)

  # Run the Python estimator's train() with TensorFlow logging verbosity
  # set to WARN for the duration of the call.
  with_logging_verbosity(tf$logging$WARN, {
    do.call(object$estimator$train, train_args)
  })

  # Relocate the tfevents file into a dedicated /logs folder under model_dir.
  mv_tf_events_file(model_dir(object))

  history <- new_tf_estimator_history(.globals$history[[mode_keys()$TRAIN]])
  history_metadata <- compose_history_metadata(history)
  tfruns::write_run_metadata("metrics", history_metadata)
  invisible(history)
}

#' Generate Predictions with an Estimator
#'
#' Generate predicted labels / values for input data provided by `input_fn()`.
#'
#' @inheritParams train-evaluate-predict
#'
#' @template roxlate-object-estimator
#'
#' @param predict_keys The types of predictions that should be produced, as an
#'   \R list. When this argument is not specified (the default), all possible
#'   predicted values will be returned.
#' @param simplify Whether to simplify prediction results into a \code{tibble},
#'   as opposed to a list. Defaults to \code{TRUE}. Ignored when
#'   `as_iterable = TRUE`.
#' @param as_iterable Boolean; should a raw Python generator be returned? When
#'   `FALSE` (the default), the predicted values will be consumed from the
#'   generator and returned as an \R object.
#' @param yield_single_examples (Available since TensorFlow v1.7) If `FALSE`,
#'   yields the whole batch as returned by the `model_fn` instead of decomposing
#'   the batch into individual elements. This is useful if `model_fn` returns some
#'   tensors with first dimension not equal to the batch size.
#' @param ... Optional arguments passed on to the estimator's `predict()`
#'   method.
#'
#' @return When `as_iterable = TRUE`, the raw Python generator returned by the
#'   estimator's `predict()` method. Otherwise, the collected predictions,
#'   simplified into a \code{tibble} when `simplify = TRUE`, or left as a list
#'   when `simplify = FALSE`.
#'
#' @section Yields: Evaluated values of `predictions` tensors.
#'
#' @section Raises: ValueError: Could not find a trained model in model_dir.
#'   ValueError: if batch length of predictions are not same. ValueError: If
#'   there is a conflict between `predict_keys` and `predictions`. For example
#'   if `predict_keys` is not `NULL` but `EstimatorSpec.predictions` is not a
#'   `dict`.
#'
#' @export
#' @family custom estimator methods
predict.tf_estimator <- function(object,
                                 input_fn,
                                 checkpoint_path = NULL,
                                 predict_keys = c("predictions", "classes", "class_ids", "logistic", "logits", "probabilities"),
                                 hooks = NULL,
                                 as_iterable = FALSE,
                                 simplify = TRUE,
                                 yield_single_examples = TRUE,
                                 ...)
{
  predict_keys <- resolve_predict_keys(match.arg(predict_keys, several.ok = TRUE))
  args <- list(
    input_fn = normalize_input_fn(object, input_fn),
    checkpoint_path = checkpoint_path,
    hooks = normalize_session_run_hooks(hooks),
    predict_keys = predict_keys,
    ...
  )

  # `yield_single_examples` is only forwarded on TensorFlow >= 1.7.
  if (tf_version() >= "1.7") {
    args$yield_single_examples <- yield_single_examples
  }

  predictions <- do.call(object$estimator$predict, args)

  # Honour the documented contract: the caller asked for the raw Python
  # generator, so hand it back untouched instead of feeding it to
  # simplify_results().
  if (as_iterable) {
    return(predictions)
  }

  is_python_iterator <-
    inherits(predictions, "python.builtin.iterator") ||
    inherits(predictions, "python.builtin.generator")

  if (!is_python_iterator) {
    warning("predictions are not iterable, no need to convert again")
  } else {
    predictions <- iterate(predictions)

    # Any `classes` entry made up entirely of Python bytes objects is decoded
    # into an R character vector.
    for (i in seq_along(predictions)) {
      classes <- predictions[[i]]$classes
      if (!is.list(classes)) {
        next
      }
      is_bytes <- vapply(classes, inherits, logical(1),
                         what = "python.builtin.bytes")
      if (all(is_bytes)) {
        predictions[[i]]$classes <- vapply(classes, function(cls) {
          cls$decode()
        }, character(1))
      }
    }
  }

  simplify_results(predictions, simplify)
}

#' Evaluate an Estimator
#'
#' Evaluate an estimator on input data provided by an `input_fn()`.
#'
#' For each step, this method will call `input_fn()` to produce a single batch
#' of data. Evaluation continues until:
#'
#' - `steps` batches are processed, or
#' - The `input_fn()` is exhausted of data.
#'
#' @inheritParams train-evaluate-predict
#'
#' @template roxlate-object-estimator
#'
#' @param name Name of the evaluation if user needs to run multiple evaluations
#'   on different data sets, such as on training data vs test data. Metrics for
#'   different evaluations are saved in separate folders, and appear separately
#'   in tensorboard.
#' @param steps The number of steps for which the model should be evaluated on
#'   this particular `evaluate()` invocation. If `NULL` (the default), this function
#'   will either evaluate forever, or until the supplied `input_fn()` has provided
#'   all available data.
#' @param simplify Whether to simplify evaluation results into a \code{tibble}, as 
#'   opposed to a list. Defaults to \code{TRUE}.
#' @param ... Optional arguments passed on to the estimator's `evaluate()`
#'   method.
#'
#' @return The evaluation metrics: a \code{tibble} when `simplify = TRUE`,
#'   otherwise an \R list.
#'
#' @export
#' @family custom estimator methods
evaluate.tf_estimator <- function(object,
                                  input_fn,
                                  steps = NULL,
                                  checkpoint_path = NULL,
                                  name = NULL,
                                  hooks = NULL,
                                  simplify = TRUE,
                                  ...)
{
  # The arguments are built and the Python evaluate() is called while
  # TensorFlow logging verbosity is set to WARN.
  evaluation_results <- with_logging_verbosity(tf$logging$WARN, {
    eval_args <- list(
      input_fn = normalize_input_fn(object, input_fn),
      steps = cast_nullable_scalar_integer(steps),
      checkpoint_path = checkpoint_path,
      name = name,
      hooks = resolve_eval_hooks(hooks, steps),
      ...
    )
    do.call(object$estimator$evaluate, eval_args)
  })

  # Record the raw metrics for tfruns before simplifying them for the caller.
  tfruns::write_run_metadata("evaluation", evaluation_results)
  simplify_results(evaluation_results, simplify)
}

#' Save an Estimator
#'
#' Save an estimator (alongside its weights) to the directory `export_dir_base`.
#'
#' @details
#'
#' This method builds a new graph by first calling the serving_input_receiver_fn
#' to obtain feature `Tensor`s, and then calling this `Estimator`'s model_fn to
#' generate the model graph based on those features. It restores the given
#' checkpoint (or, lacking that, the most recent checkpoint) into this graph in
#' a fresh session. Finally it creates a timestamped export directory below the
#' given export_dir_base, and writes a `SavedModel` into it containing a single
#' `MetaGraphDef` saved from this session. The exported `MetaGraphDef` will
#' provide one `SignatureDef` for each element of the export_outputs dict
#' returned from the model_fn, named using the same keys. One of these keys is
#' always signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating
#' which signature will be served when a serving request does not specify one.
#' For each signature, the outputs are provided by the corresponding
#' `ExportOutput`s, and the inputs are always the input receivers provided by
#' the serving_input_receiver_fn. Extra assets may be written into the
#' SavedModel via the extra_assets argument. This should be a dict, where each
#' key gives a destination path (including the filename) relative to the
#' assets.extra directory. The corresponding value gives the full path of the
#' source file to be copied. For example, the simple case of copying a single
#' file without renaming it is specified as `{'my_asset_file.txt':
#' '/path/to/my_asset_file.txt'}`.
#'
#' When `versioned = FALSE`, the model is first exported below `tempdir()` and
#' its contents are then copied into `export_dir_base`. If that copy fails, an
#' error is raised and the staged export is left in place.
#'
#' @template roxlate-object-estimator
#'
#' @param export_dir_base A string containing a directory in which to export the
#'   SavedModel.
#' @param serving_input_receiver_fn A function that takes no argument and
#'   returns a `ServingInputReceiver`. Required for custom models.
#' @param assets_extra A dict specifying how to populate the assets.extra
#'   directory within the exported SavedModel, or `NULL` if no extra assets are
#'   needed.
#' @param as_text whether to write the SavedModel proto in text format.
#' @param checkpoint_path The checkpoint path to export. If `NULL` (the
#'   default), the most recent checkpoint found within the model directory is
#'   chosen.
#' @param overwrite Should the \code{export_dir} directory be overwritten?
#' @param versioned Should the model be exported under a versioned subdirectory?
#' @param ... Optional arguments passed on to the estimator's
#'   `export_savedmodel()` method.
#'
#' @return The path to the exported directory, as a string.
#'
#' @section Raises: ValueError: if no serving_input_receiver_fn is provided, no
#'   export_outputs are provided, or no checkpoint can be found.
#'
#' @export
#' @family custom estimator methods
export_savedmodel.tf_estimator <- function(object,
                                           export_dir_base,
                                           serving_input_receiver_fn = NULL,
                                           assets_extra = NULL,
                                           as_text = FALSE,
                                           checkpoint_path = NULL,
                                           overwrite = TRUE,
                                           versioned = !overwrite,
                                           ...)
{
  # An unversioned export writes directly into export_dir_base, so refuse to
  # replace an existing directory unless overwriting was requested.
  if (!overwrite && !versioned && file.exists(export_dir_base))
    stop("Path '", export_dir_base, "' already exists, use 'overwrite = TRUE' instead.")

  if (is.null(serving_input_receiver_fn)) {
    if (is.tf_custom_estimator(object))
      stop("A 'tf_custom_estimator' requires a custom `serving_input_receiver_fn`.")

    # Feature columns are looked up under each of these argument names.
    feature_columns_spec <- c(
      object$args$dnn_feature_columns,
      object$args$linear_feature_columns,
      object$args$feature_columns
    )

    object_classes <- class(object)
    is_regressor <- any(grepl("regressor", object_classes, fixed = TRUE))
    is_classifier <- any(grepl("classifier", object_classes, fixed = TRUE))

    if (!is_regressor && !is_classifier) {
      stop("Currently only classifier and regressor are supported. Please specify a custom serving_input_receiver_fn. ")
    }

    if (tf_version() < "1.4") {
      # TensorFlow < 1.4: build a parsing serving input receiver from the
      # parse-example spec matching the estimator type.
      if (is_regressor) {
        input_spec <- regressor_parse_example_spec(
          feature_columns = feature_columns_spec,
          weight_column = object$args$weight_column,
          label_key = "label"
        )
      } else {
        input_spec <- classifier_parse_example_spec(
          feature_columns = feature_columns_spec,
          weight_column = object$args$weight_column,
          label_key = "label"
        )
      }

      serving_input_receiver_fn <- tf$estimator$export$build_parsing_serving_input_receiver_fn(input_spec)
    } else {
      # TensorFlow >= 1.4: build a raw serving input receiver with one
      # placeholder (with a constant default) per feature column.
      features <- list()
      for (feature in feature_columns_spec) {
        default_tensor <- tf$constant(value = feature$dtype$min, shape = shape(1, feature$shape))

        # first dimension is variable since it's required by cloudml-like interfaces to push multiple instances
        features[[feature$name]] <- tf$placeholder_with_default(
          input = default_tensor,
          shape = shape(NULL, feature$shape)
        )
      }

      serving_input_receiver_fn <- tf$estimator$export$build_raw_serving_input_receiver_fn(features)
    }
  }

  # Versioned exports go straight below export_dir_base (a timestamped
  # subdirectory is created there). Unversioned exports are staged below
  # tempdir() and copied into export_dir_base afterwards.
  export_target <- if (versioned) export_dir_base else tempdir()

  export_result <- object$estimator$export_savedmodel(
    export_dir_base = export_target,
    serving_input_receiver_fn = serving_input_receiver_fn,
    assets_extra = assets_extra,
    as_text = as_text,
    checkpoint_path = checkpoint_path,
    ...
  )

  if (!versioned) {
    if (overwrite && file.exists(export_dir_base))
      unlink(export_dir_base, recursive = TRUE)

    if (!file.exists(export_dir_base))
      dir.create(export_dir_base, recursive = TRUE)

    # Copy the *contents* of the staged export into export_dir_base. Only
    # delete the staged copy once the copy is known to have succeeded;
    # otherwise the exported model would be silently lost.
    copied <- file.copy(file.path(export_result, "."), export_dir_base, recursive = TRUE)
    if (!all(copied)) {
      stop("Failed to copy the exported model from '", export_result,
           "' to '", export_dir_base, "'; the staged export was left in place.")
    }

    unlink(export_result, recursive = TRUE)
    export_result <- export_dir_base
  }

  invisible(export_result)
}

#' Get variable names and values associated with an estimator
#' 
#' These helper functions extract the names and values of variables
#'   in the graphs associated with trained estimator models.
#'   
#' @name variable_names_values
#' @param object A trained estimator model.
#' @return For \code{variable_names()}, a vector of variable names. For \code{variable_value()}, a named list of variable values.
#' @export
variable_names <- function(object) {
  model_dir <- object$estimator$model_dir
  # An empty model directory is treated as "not trained yet".
  if (length(list.files(model_dir)) == 0)
    stop("'variable_names()' must be called on a trained model")

  if (tensorflow::tf_version() >= "1.4") {
    # TensorFlow >= 1.4: ask the estimator for its variable names directly.
    object$estimator$get_variable_names()
  } else {
    # Older TensorFlow: fall back to the list_variable_names() helper on
    # the model directory.
    model_dir %>%
      list_variable_names() %>%
      unlist()
  }
}

#' @rdname variable_names_values
#' @param variable (Optional) Names of variables to extract as a character vector. If not specified, values for all variables are returned.
#' @export
variable_value <- function(object, variable = NULL) {
  model_dir <- object$estimator$model_dir
  # An empty model directory is treated as "not trained yet".
  if (length(list.files(model_dir)) == 0)
    stop("'variable_value()' must be called on a trained model")

  available <- variable_names(object)

  if (is.null(variable)) {
    # No selection given: return every variable in the model.
    variable <- available
  } else {
    missing_vars <- unlist(variable[!(variable %in% available)])
    if (length(missing_vars))
      stop("Variable not found: ", paste0(missing_vars, collapse = ", "))
  }

  if (tensorflow::tf_version() >= "1.4") {
    # TensorFlow >= 1.4: read each value through the estimator.
    values <- lapply(variable, object$estimator$get_variable_value)
  } else {
    # Older TensorFlow: load the latest checkpoint and read tensors from it.
    checkpoint_path <- latest_checkpoint(model_dir)
    reader <- tf$python$training$checkpoint_utils$load_checkpoint(checkpoint_path)
    values <- lapply(variable, function(var_name) reader$get_tensor(var_name[[1]]))
  }

  rlang::set_names(values, unlist(variable))
}

Try the tfestimators package in your browser

Any scripts or data that you put into this service are public.

tfestimators documentation built on Aug. 10, 2021, 1:06 a.m.