# tests/testthat/test-schedulers.R

library(purrr)

# Checks the learning-rate schedulers: each schedule_*() function maps an
# epoch number to a learning rate and is compared against its closed-form
# expression; invalid arguments must raise errors; and set_learn_rate() must
# give the same results as calling the matching schedule_*() directly.
#
# NOTE: expect_snapshot_error() results are recorded in order under this
# test's name in the snapshot file (_snaps/schedulers.md), so keep the test
# name and the order of those calls unchanged.
test_that("scheduling functions", {

  x <- 0:100

  # Exponential decay: initial * exp(-decay * epoch) ---------------------------

  # Defaults are expected to be initial = 0.1, decay = 1.
  expect_equal(
    map_dbl(x, schedule_decay_expo),
    0.1 * exp(-x)
  )

  expect_equal(
    map_dbl(x, schedule_decay_expo, initial = 1/3, decay = 7/8),
    1 / 3 * exp(-7 / 8 * x)
  )

  expect_snapshot_error(schedule_decay_expo(1, initial = -1))
  expect_snapshot_error(schedule_decay_expo(1, decay = -1))

  # Time decay: initial / (1 + decay * epoch) ----------------------------------

  expect_equal(
    map_dbl(x, schedule_decay_time),
    0.1 / (1 + x)
  )

  expect_equal(
    map_dbl(x, schedule_decay_time, initial = 1/3, decay = 7/8),
    1 / 3 / (1 + 7 / 8 * x)
  )

  expect_snapshot_error(schedule_decay_time(1, initial = -1))
  expect_snapshot_error(schedule_decay_time(1, decay = -1))

  # Step decay: initial * reduction ^ floor(epoch / steps) ---------------------

  # Defaults are expected to be reduction = 1/2, steps = 5.
  expect_equal(
    map_dbl(x, schedule_step),
    0.1 * (1 / 2) ^ floor(x / 5)
  )

  expect_equal(
    map_dbl(x, schedule_step, initial = 1/3, reduction = 7/8, steps = 3),
    1 / 3 * (7 / 8) ^ floor(x / 3)
  )

  expect_snapshot_error(schedule_step(1, initial = -1))
  expect_snapshot_error(schedule_step(1, reduction = -1))
  expect_snapshot_error(schedule_step(1, steps = -1))

  # Cyclic ---------------------------------------------------------------------

  # With the defaults the rate rises for 5 epochs, then falls for 5, changing
  # by 0.0198 per epoch and bottoming out at 0.001 on every multiple of 10.
  cyclic <- map_dbl(x, schedule_cyclic)

  # Use expect_equal() rather than `==` so float comparisons get a tolerance
  # and a useful failure message.
  at_min <- x %% 10 == 0
  expect_equal(cyclic[at_min], rep(0.001, sum(at_min)))

  inc <- 0.0198
  expect_equal(
    abs(diff(cyclic)),
    rep(inc, length(x) - 1),
    tolerance = 0.001
  )

  # sign() returns exact -1/0/1, so no tolerance is needed here.
  expect_equal(
    sign(diff(cyclic)),
    rep(rep(c(1, -1), each = 5), times = 10)
  )

  # Doubling step_size doubles the period: minima every 20 epochs.
  at_min_wide <- x %% 20 == 0
  expect_equal(
    map_dbl(x[at_min_wide], schedule_cyclic, step_size = 10),
    rep(0.001, sum(at_min_wide))
  )

  expect_snapshot_error(schedule_cyclic(1, step_size = -1))
  expect_snapshot_error(schedule_cyclic(1, largest = -1))

  # set_learn_rate() dispatch --------------------------------------------------

  # type = "none" returns the supplied rate unchanged, at every epoch, and
  # tolerates unused extra arguments. A real epoch value is passed here: the
  # original used `.x` outside a lambda, which only worked because the
  # argument was never evaluated.
  expect_equal(set_learn_rate(1, 1, type = "none"), 1)
  expect_equal(
    map_dbl(x, ~ set_learn_rate(.x, 0.01, type = "none", potato = 1)),
    rep(0.01, length(x))
  )

  # For a real schedule, the result must match calling schedule_*() directly;
  # the learn_rate argument (0.1) must not change the result when the
  # schedule's own `initial` is supplied.
  expect_equal(
    map_dbl(x, schedule_decay_time, initial = 1/3, decay = 7/8),
    map_dbl(
      x,
      ~ set_learn_rate(.x, 0.1, "decay_time", initial = 1/3, decay = 7/8)
    )
  )

  expect_equal(
    map_dbl(x, schedule_decay_expo, initial = 1/3, decay = 7/8),
    map_dbl(
      x,
      ~ set_learn_rate(.x, 0.1, "decay_expo", initial = 1/3, decay = 7/8)
    )
  )

  expect_equal(
    map_dbl(x, schedule_step, initial = 1/3, reduction = 7/8, steps = 3),
    map_dbl(x, ~ set_learn_rate(
      .x, 0.1, "step",
      initial = 1/3, reduction = 7/8, steps = 3
    ))
  )

  expect_snapshot_error(set_learn_rate(1, 1, type = "decay_time", initial = -1))
  expect_snapshot_error(set_learn_rate(1, 1, type = "random"))

})
# tidymodels/lantern documentation built on Feb. 28, 2024, 12:59 a.m.