tests/testthat/test-tuning.R

test_that("Tuning - random search", {
    Sys.setenv("OMP_NUM_THREADS" = 1)

    set.seed(123)
    torch::torch_manual_seed(1234)

    tuned_choice <- sits_tuning(
        samples_modis_ndvi,
        ml_method = sits_tempcnn(),
        params = sits_tuning_hparams(
            epochs = 10,
            optimizer = torch::optim_adam,
            opt_hparams = list(
                lr = choice(0.01, 0.05, 0.001, 0.0005)
            )
        ),
        trials = 2,
        multicores = 1,
        progress = FALSE
    )
    expect_true(max(tuned_choice$accuracy) > 0.5)
    expect_true(max(tuned_choice$kappa) > 0.5)

    tuned_beta <- sits_tuning(
        samples_modis_ndvi,
        ml_method = sits_tempcnn(),
        params = sits_tuning_hparams(
            epochs = 10,
            optimizer = torch::optim_adam,
            opt_hparams = list(
                lr = beta(0.2, 4)
            )
        ),
        trials = 2,
        multicores = 2,
        progress = FALSE
    )

    expect_true(max(tuned_beta$accuracy) > 0.5)
    expect_true(max(tuned_beta$kappa) > 0.5)

    tuned_unif <- sits_tuning(
        samples_modis_ndvi,
        ml_method = sits_tempcnn(),
        params = sits_tuning_hparams(
            epochs = 10,
            optimizer = torch::optim_adam,
            opt_hparams = list(
                lr = uniform(min = 0.01, max = 0.02)
            )
        ),
        trials = 2,
        multicores = 1,
        progress = FALSE
    )

    expect_true(max(tuned_unif$accuracy) > 0.5)
    expect_true(max(tuned_unif$kappa) > 0.5)

    tuned_norm <- sits_tuning(
        samples_modis_ndvi,
        ml_method = sits_tempcnn(),
        params = sits_tuning_hparams(
            epochs = 10,
            optimizer = torch::optim_adam,
            opt_hparams = list(
                lr = normal(mean = 0.01, sd = 0.002)
            )
        ),
        trials = 2,
        multicores = 1,
        progress = FALSE
    )

    expect_true(max(tuned_norm$accuracy) > 0.5)
    expect_true(max(tuned_norm$kappa) > 0.5)

    tuned_lnorm <- sits_tuning(
        samples_modis_ndvi,
        ml_method = sits_tempcnn(),
        params = sits_tuning_hparams(
            epochs = 10,
            optimizer = torch::optim_adam,
            opt_hparams = list(
                lr = lognormal(meanlog = -4.5, sdlog = 1)
            )
        ),
        trials = 2,
        multicores = 1,
        progress = FALSE
    )

    expect_true(max(tuned_lnorm$accuracy) > 0.5)
    expect_true(max(tuned_lnorm$kappa) > 0.5)

    tuned_lunif <- sits_tuning(
        samples_modis_ndvi,
        ml_method = sits_tempcnn(),
        params = sits_tuning_hparams(
            epochs = 10,
            optimizer = torch::optim_adam,
            opt_hparams = list(
                lr = loguniform(10^-2, 10^-4)
            )
        ),
        trials = 2,
        multicores = 1,
        progress = FALSE
    )

    expect_true(max(tuned_lunif$accuracy) > 0.5)
    expect_true(max(tuned_lunif$kappa) > 0.5)

    Sys.unsetenv("OMP_NUM_THREADS")
})
e-sensing/sits documentation built on May 11, 2024, 8:18 p.m.