# tests/testthat/test-flags.R

# Label the tests in this file as the "flags" context in testthat output.
# NOTE(review): context() is superseded in testthat's 3rd edition (the file
# name is used as the label instead); remove it if the package opts into
# edition 3.
context("flags")

# Building the shared test flags with define_flags() (a helper defined
# outside this file) should reproduce the reference snapshot in flags.rds.
test_that("flags can be defined", {
  with_tests_dir({
    # NOTE(review): with_tests_dir() is assumed to evaluate this block from
    # the tests directory so the relative "flags.rds" path resolves.
    defined <- define_flags()
    snapshot <- readRDS("flags.rds")
    expect_equivalent(defined, snapshot)
  })
})

# Each flag produced by define_flags() should be stored with the expected
# base R type (as reported by typeof()).
test_that("flags are assigned the correct types", {
  with_tests_dir({
    FLAGS <- define_flags()

    # Flag name -> expected typeof() result.
    expected_types <- c(
      learning_rate = "double",
      max_steps = "integer",
      data_dir = "character",
      fake_data = "logical"
    )

    for (flag in names(expected_types)) {
      expect_type(FLAGS[[flag]], expected_types[[flag]])
    }
  })
})

# With no command-line arguments supplied, parsing should leave every flag at
# its default value, matching the reference snapshot in flags.rds.
test_that("flags_parse returns defaults when there are no overrides", {
  with_tests_dir({
    # Pass an explicitly empty argument vector so the result cannot be
    # affected by whatever command-line arguments the test runner itself
    # received. NOTE(review): the `arguments` default of flags() is presumably
    # commandArgs(TRUE); confirm define_flags() forwards it unchanged.
    FLAGS <- define_flags(arguments = character())
    expect_equivalent(FLAGS, readRDS("flags.rds"))
  })
})

# A command-line option should override the default flag value. The dashed
# option name "--learning-rate" maps onto the underscored flag learning_rate.
test_that("flags_parse overrides based on command line args", {
  with_tests_dir({
    cli_args <- c("--learning-rate", "0.02")
    overridden <- define_flags(arguments = cli_args)
    expect_equal(overridden$learning_rate, 0.02)
  })
})

# An option that names no defined flag ("--learn-rate" instead of
# "--learning-rate") must make flag parsing fail.
test_that("flags_parse throws an error for unknown command line args", {
  with_tests_dir({
    # Only the error matters; the return value is never inspected, so there
    # is no need to bind it to a variable.
    # NOTE(review): consider pinning the expected message or condition class
    # (`regexp =` / `class =`) so that an unrelated error cannot satisfy
    # this expectation.
    expect_error(define_flags(arguments = c("--learn-rate", "0.02")))
  })
})

# Flag values can be overridden from a YAML config file, either from its
# default section or from a named profile selected via `config`.
test_that("flags_parse overrides based on config file values", {
  with_tests_dir({
    # Plain override file.
    from_file <- define_flags(file = "flags-override.yml")
    expect_equal(from_file$learning_rate, 0.02)

    # Override taken from the "myconfig" profile of a multi-profile file.
    from_profile <- define_flags(
      file = "flags-profile-override.yml",
      config = "myconfig"
    )
    expect_equal(from_profile$learning_rate, 0.03)
  })
})

# Options that appear after "--args" are passthrough (here, --job-dir) and
# must not be parsed as flags. Options before "--args" still apply.
test_that("flags_parse skips --args for passthrough args", {
  with_tests_dir({
    FLAGS <- flags(
      flag_numeric("gradient_descent_optimizer", 0.5),
      # A character vector, the same shape commandArgs() produces and the
      # other tests in this file use, rather than a list.
      arguments = c(
        "--gradient-descent-optimizer",
        "0.47",
        "--args",
        "--job-dir",
        "gs://rstudio-cloudml/mnist/staging/2"
      )
    )

    expect_equal(FLAGS$gradient_descent_optimizer, 0.47)
    # The passthrough option must not have been turned into a flag.
    expect_false("job_dir" %in% names(FLAGS))
  })
})
# rstudio/tfruns documentation built on Feb. 6, 2024, 11:29 a.m.