# tests/testthat/test_regr_keras.R

test_that("autotest regression custom model", {
  skip_on_os("solaris")

  # Sequential net: two 12-unit relu layers on 2 inputs, then one linear output unit.
  net = keras_model_sequential()
  net = layer_dense(net, units = 12L, input_shape = 2L, activation = "relu")
  net = layer_dense(net, units = 12L, activation = "relu")
  net = layer_dense(net, units = 1L, activation = "linear")
  net = compile(
    net,
    optimizer = optimizer_adam(lr = 10e-3),
    loss = "mean_squared_error",
    metrics = "mean_squared_logarithmic_error"
  )

  # Plug the compiled model into the learner and keep training short (3 epochs).
  keras_learner = LearnerRegrKeras$new()
  keras_learner$param_set$values$model = net
  keras_learner$param_set$values$epochs = 3L
  expect_learner(keras_learner)

  # Run the autotest with the feat_single and sanity tests excluded and the
  # replicability check switched off.
  result = run_autotest(keras_learner, exclude = "(feat_single|sanity)", check_replicable = FALSE)
  expect_true(result, info = result$error)
  k_clear_session()
})

test_that("autotest low memory generator", {
  skip_on_os("solaris")

  # Sequential net: two 12-unit relu layers on 2 inputs, then one linear output unit.
  net = keras_model_sequential()
  net = layer_dense(net, units = 12L, input_shape = 2L, activation = "relu")
  net = layer_dense(net, units = 12L, activation = "relu")
  net = layer_dense(net, units = 1L, activation = "linear")
  net = compile(
    net,
    optimizer = optimizer_adam(lr = 10e-3),
    loss = "mean_squared_error",
    metrics = "mean_squared_logarithmic_error"
  )

  # Same learner as the custom-model test, but with low_memory switched on.
  keras_learner = LearnerRegrKeras$new()
  keras_learner$param_set$values$model = net
  keras_learner$param_set$values$low_memory = TRUE
  keras_learner$param_set$values$epochs = 3L
  expect_learner(keras_learner)

  # Run the autotest with the feat_single and sanity tests excluded and the
  # replicability check switched off.
  result = run_autotest(keras_learner, exclude = "(feat_single|sanity)", check_replicable = FALSE)
  expect_true(result, info = result$error)
  k_clear_session()
})

test_that("autotest low memory zero validation_split", {
  skip_on_os("solaris")

  # Sequential net: two 12-unit relu layers on 2 inputs, then one linear output unit.
  net = keras_model_sequential()
  net = layer_dense(net, units = 12L, input_shape = 2L, activation = "relu")
  net = layer_dense(net, units = 12L, activation = "relu")
  net = layer_dense(net, units = 1L, activation = "linear")
  net = compile(
    net,
    optimizer = optimizer_adam(lr = 10e-3),
    loss = "mean_squared_error",
    metrics = "mean_squared_logarithmic_error"
  )

  # low_memory is on and validation_split is set to 0 for this variant.
  keras_learner = LearnerRegrKeras$new()
  keras_learner$param_set$values$model = net
  keras_learner$param_set$values$low_memory = TRUE
  keras_learner$param_set$values$validation_split = 0
  keras_learner$param_set$values$epochs = 3L
  expect_learner(keras_learner)

  # Run the autotest with the feat_single and sanity tests excluded and the
  # replicability check switched off.
  result = run_autotest(keras_learner, exclude = "(feat_single|sanity)", check_replicable = FALSE)
  expect_true(result, info = result$error)
  k_clear_session()
})

test_that("autotest feed forward", {
  skip_on_os("solaris")

  # Feed-forward learner with its default architecture; only epochs is overridden.
  ff_learner = LearnerRegrKerasFF$new()
  ff_learner$param_set$values$epochs = 3L
  expect_learner(ff_learner)

  # Run the autotest with the feat_single and sanity tests excluded and the
  # replicability check switched off.
  result = run_autotest(
    ff_learner,
    exclude = "(feat_single|sanity)",
    check_replicable = FALSE
  )
  expect_true(result, info = result$error)
  k_clear_session()
})

test_that("Learner methods", {
  # This test trains a keras model, so it carries the same platform guard as
  # the other keras tests in this file.
  skip_on_os("solaris")

  fp = tempfile(fileext = ".h5")
  # Named `learner` rather than `lrn` so the local does not shadow the lrn() constructor.
  learner = lrn("regr.kerasff", epochs = 3L)
  task = mlr_tasks$get("mtcars")

  # plot() and save() are expected to error before the learner has been trained.
  expect_error(learner$plot())
  expect_error(learner$save(fp))
  learner$train(task)

  # Saving to h5
  learner$save(fp)
  expect_file_exists(fp)

  # Plotting
  p = learner$plot()
  expect_class(p, "ggplot")

  # Load the model back from the file written above and predict with it.
  learner$load_model_from_file(fp)
  prd = learner$predict(task)
  expect_class(prd, "Prediction")

  unlink(fp)
  k_clear_session()
})
# mlr-org/mlr3keras documentation built on April 12, 2022, 11:35 a.m.