# Label the tests in this file for testthat's reporting.
context("Testing tf_custom_models methods")
# Load shared test helpers. Presumably this is where test_succeeds() and
# simple_custom_model_fn (used below, not defined in this file) come from --
# TODO confirm against helper-utils.R.
source("helper-utils.R")
# End-to-end check of a custom estimator on the iris data: train for two
# steps, then verify the shape and contents of the predictions (both
# simplified and raw) and the names reported by evaluate().
test_succeeds("custom model works on iris data", {

  # Number of prediction rows expected back from predict().
  n_obs <- 150

  constructed_input_fn <- input_fn(
    object = iris,
    response = "Species",
    features = c(
      "Sepal.Length",
      "Sepal.Width",
      "Petal.Length",
      "Petal.Width"),
    batch_size = 10L
  )

  model_dir <- tempfile()

  # Training ----
  classifier <- estimator(
    model_fn = simple_custom_model_fn,
    model_dir = model_dir
  )
  train(classifier, input_fn = constructed_input_fn, steps = 2L)

  # check whether tensorboard works with custom estimator
  # tensorboard(log_dir = model_dir, launch_browser = FALSE)

  # Simplified predictions ----
  simplified_preds <- predict(
    classifier,
    input_fn = constructed_input_fn,
    simplify = TRUE
  )
  expect_equal(dim(simplified_preds), c(n_obs, 2))

  # Raw (non-simplified) predictions ----
  raw_preds <- predict(
    classifier,
    input_fn = constructed_input_fn,
    simplify = FALSE
  )
  expect_equal(length(raw_preds), n_obs)

  # One predicted class per prediction entry.
  class_per_obs <- unlist(lapply(raw_preds, function(p) {
    p$class
  }))
  expect_equal(length(class_per_obs), n_obs)

  # Probabilities: one entry per prediction, one value per species,
  # and every value within [0, 1].
  probs_per_obs <- lapply(raw_preds, function(p) {
    p$prob
  })
  flat_probs <- unlist(probs_per_obs)
  n_species <- length(unique(iris$Species))
  expect_equal(length(probs_per_obs), n_obs)
  expect_equal(length(flat_probs), n_obs * n_species)
  expect_lte(max(flat_probs), 1)
  expect_gte(min(flat_probs), 0)

  # Each entry's probabilities should sum to 1.
  prob_totals <- lapply(raw_preds, function(p) sum(p$prob))
  expected_totals <- rep(list(1), length(raw_preds))
  expect_equal(prob_totals, expected_totals)

  # Evaluation ----
  eval_results <- evaluate(
    classifier,
    constructed_input_fn,
    steps = 2L,
    simplify = FALSE
  )
  expect_equal(names(eval_results), c("loss", "global_step"))
})
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.