# tests/testthat/test-tf-custom-models.R

# Declare the testthat context label under which this file's tests are
# reported.
context("Testing tf_custom_models methods")

# Load the shared test helpers. The relative path is resolved against the
# working directory at the time this file is run.
# NOTE(review): presumably this is where test_succeeds() and
# simple_custom_model_fn come from -- both are used by the test in this file
# but are not defined here; confirm against helper-utils.R.
source("helper-utils.R")

# End-to-end check of a custom estimator on the iris data set: train for a
# couple of steps, then verify the shape and sanity of the predictions (both
# simplified and raw) and the names of the evaluation results.
test_succeeds("custom model works on iris data", {

  # Sizes that the expectations below are checked against.
  n_obs <- nrow(iris)
  n_classes <- length(unique(iris$Species))

  feature_names <- c(
    "Sepal.Length",
    "Sepal.Width",
    "Petal.Length",
    "Petal.Width"
  )

  iris_input_fn <- input_fn(
    object = iris,
    response = "Species",
    features = feature_names,
    batch_size = 10L
  )

  model_dir <- tempfile()

  # training
  custom_estimator <- estimator(
    model_fn = simple_custom_model_fn,
    model_dir = model_dir
  )
  train(custom_estimator, input_fn = iris_input_fn, steps = 2L)

  # check whether tensorboard works with custom estimator
  # tensorboard(log_dir = model_dir, launch_browser = FALSE)

  # predictions simplified: one row per observation, two columns
  simplified_preds <- predict(
    custom_estimator,
    input_fn = iris_input_fn,
    simplify = TRUE
  )
  expect_equal(dim(simplified_preds), c(n_obs, 2))

  # predictions not simplified: one list element per observation
  raw_preds <- predict(
    custom_estimator,
    input_fn = iris_input_fn,
    simplify = FALSE
  )
  expect_equal(length(raw_preds), n_obs)

  # extract predicted classes
  class_per_obs <- lapply(raw_preds, function(pred) pred$class)
  predicted_classes <- unlist(class_per_obs)
  expect_equal(length(predicted_classes), n_obs)

  # extract predicted probabilities
  predicted_probs <- lapply(raw_preds, function(pred) pred$prob)
  flat_probs <- unlist(predicted_probs)
  expect_equal(length(predicted_probs), n_obs)
  expect_equal(length(flat_probs), n_obs * n_classes)
  expect_lte(max(flat_probs), 1)
  expect_gte(min(flat_probs), 0)

  # each row of probability should sum to 1
  prob_totals <- lapply(raw_preds, function(pred) sum(pred$prob))
  expect_equal(prob_totals, rep(list(1), length(raw_preds)))

  # evaluate
  eval_results <- evaluate(
    custom_estimator,
    iris_input_fn,
    steps = 2L,
    simplify = FALSE
  )
  expect_equal(names(eval_results), c("loss", "global_step"))
})

# Try the tfestimators package in your browser
#
# Any scripts or data that you put into this service are public.
#
# tfestimators documentation built on Aug. 10, 2021, 1:06 a.m.