# R/print_methods.R

# Defines functions: str.tf_estimator, print.tf_estimator, tf_estimator_type

# Returns a short label describing an estimator, chosen from its S3 class:
# "regressor", "classifier", or the generic fallback "estimator".
tf_estimator_type <- function(estimator) {
  # Checked in order, so the regressor label wins if both classes are present.
  labels <- c(
    tf_estimator_regressor = "regressor",
    tf_estimator_classifier = "classifier"
  )
  for (cls in names(labels)) {
    if (inherits(estimator, cls)) {
      return(labels[[cls]])
    }
  }
  "estimator"
}

# Print a short summary of a `tf_estimator`.
#
# The summary consists of a header line (estimator type plus the string form
# of the wrapped estimator), one "key: value" line per entry in `fields`
# (currently only the model directory), and a line stating whether the model
# has been trained and, if so, for how many steps. When the wrapped estimator
# is NULL or a null external pointer, only "<pointer: 0x0>" is printed.
#
# x:   object whose `estimator` element holds the wrapped estimator.
# ...: unused; present to match the signature of the `print()` generic.
#
# Returns `x` invisibly, so the method can be used in pipelines and does not
# auto-print a second value at the console.
#' @export
print.tf_estimator <- function(x, ...) {

  # `||` short-circuits, so the pointer check only runs on a non-NULL estimator.
  if (is.null(x$estimator) || py_is_null_xptr(x$estimator)) {
    cat("<pointer: 0x0>\n")
    return(invisible(x))
  }

  header <- sprintf(
    "A TensorFlow %s [%s]",
    tf_estimator_type(x),
    as.character(x$estimator)
  )

  model_dir <- x$estimator$model_dir

  fields <- list(
    "Model Directory" = model_dir
  )

  # NOTE(review): assumes `enumerate()` calls the function once per named
  # element of `fields` and returns the results -- confirm against its
  # definition.
  body <- enumerate(fields, function(key, val) {
    sprintf("%s: %s", key, val)
  })

  # Model checkpoint only exists when it's been trained
  if (!is.null(latest_checkpoint(model_dir))) {
    steps <- as.integer(variable_value(x)[[graph_keys()$GLOBAL_STEP]])
    model_trained_info <- sprintf(
      "Model has been trained for %i %s.",
      steps,
      # Singular only for exactly one step ("0 steps", "1 step", "2 steps").
      if (steps == 1L) "step" else "steps"
    )
  } else {
    model_trained_info <- "Model has not yet been trained."
  }

  # Build the output as a vector of lines instead of paste()-ing with
  # sep/collapse: paste() would recycle the header and training line once per
  # field if `fields` ever grows beyond one entry.
  lines <- c(header, as.character(unlist(body)), model_trained_info)

  # writeLines() terminates the final line with a newline, which
  # cat(..., sep = "\n") does not.
  writeLines(lines)

  invisible(x)
}

# `str()` method for `tf_estimator` objects.
#
# Captures whatever `print(object)` writes and returns it as a single string,
# with the captured lines joined by newlines. Nothing is written to the
# console by this method itself.
#' @export
str.tf_estimator <- function(object, ...) {
  printed_lines <- capture.output(print(object))
  paste0(printed_lines, collapse = "\n")
}
# rstudio/tflearn documentation built on Nov. 25, 2021, 2:45 a.m.