# tests/testthat/test-algo-nbeats-ensemble.R

# NBEATS ENSEMBLE TEST ----
# NOTE(review): context() is deprecated in the testthat 3rd edition; confirm
# which edition this package declares before removing the call.
context("Test NBEATS ENSEMBLE")


# gc()
# py_gc <- reticulate::import("gc")
# py_gc$collect()

# MODEL FITTING ----

# End-to-end test of the "gluonts_nbeats_ensemble" engine. It runs three
# sections inside one test block so that the model spec is shared and the
# model is fit only once:
#   1. fit a spec and inspect the structure of the fitted object and forecast
#   2. update() the spec and confirm new and untouched arguments
#   3. confirm that incomplete specs / formulas raise an error on fit()
test_that("nbeats ensemble: model fitting", {

    # A single guard covers the whole block: when a skip is signalled,
    # testthat abandons the remainder of the test, so the later sections do
    # not need to repeat it.
    skip_if_no_gluonts()

    # skip_on_ci()

    # Extract the train / test partitions once and reuse them below.
    train_tbl <- training(m750_splits)
    test_tbl  <- testing(m750_splits)

    # MODEL FITTING ----

    # Model Spec
    model_spec <- nbeats(
        id                      = "id",
        freq                    = "M",
        prediction_length       = 24,
        epochs                  = 1,
        batch_size              = 2,
        num_batches_per_epoch   = 10,
        learn_rate              = 0.01,
        learn_rate_decay_factor = 0.25,
        learn_rate_min          = 2e-5,
        patience                = 100,
        clip_gradient           = 1,
        penalty                 = 0.2,

        lookback_length         = list(12),
        loss_function           = list("MAPE"),
        bagging_size            = 2,
        num_stacks              = 10,
        num_blocks              = list(2),

        scale                   = TRUE
    ) %>%
        set_engine("gluonts_nbeats_ensemble")

    # ** MODEL FIT

    # Model Fit
    model_fit <- model_spec %>%
        fit(log(value) ~ date + id, data = train_tbl)

    # Test print: print() is expected to hand back the object it was given
    expect_equal(print(model_fit), model_fit)

    # Structure

    expect_s3_class(model_fit$fit, "nbeats_ensemble_fit_impl")

    expect_s3_class(model_fit$fit$data, "tbl_df")

    expect_equal(names(model_fit$fit$data)[1], "date")

    expect_equal(model_fit$fit$extras$id, "id")
    expect_equal(model_fit$fit$extras$idx_column, "date")
    expect_equal(model_fit$fit$extras$value_column, "value")
    expect_s3_class(model_fit$fit$extras$constructed_tbl, "data.frame")
    expect_s3_class(model_fit$fit$extras$scale_params, "data.frame")

    # $fit

    # expect_s3_class(model_fit$fit$models$model_1, "gluonts.model.n_beats._ensemble.NBEATSEnsemblePredictor")

    # The predictor attributes are converted with py_to_r() before comparing
    expect_equal(model_fit$fit$models$model_1$freq %>% py_to_r(), "M")
    expect_equal(model_fit$fit$models$model_1$prediction_length %>% py_to_r(), 24)

    # $preproc

    expect_equal(model_fit$preproc$y_var, "value")

    # ** PREDICTIONS

    # Predictions
    predictions_tbl <- model_fit %>%
        modeltime_calibrate(test_tbl) %>%
        modeltime_forecast(new_data = test_tbl)

    # Structure: observed value first, expected value second, so that failure
    # messages label the two sides correctly.
    expect_identical(nrow(predictions_tbl), nrow(test_tbl))
    expect_identical(predictions_tbl$.index, test_tbl$date)

    # UPDATE MODEL SPEC ----

    model_spec_updated <- model_spec %>%
        update(
            id                      = "id_2",
            freq                    = "D",
            prediction_length       = 36,
            epochs                  = 2,
            batch_size              = 4,
            num_batches_per_epoch   = 6,
            learn_rate              = 0.0001,
            learn_rate_decay_factor = 0.5,
            learn_rate_min          = 1e-5,
            patience                = 10,
            clip_gradient           = 10
        )

    # Arguments passed to update() take the new values
    expect_equal(eval_tidy(model_spec_updated$args$id), "id_2")
    expect_equal(eval_tidy(model_spec_updated$args$freq), "D")
    expect_equal(eval_tidy(model_spec_updated$args$prediction_length), 36)
    expect_equal(eval_tidy(model_spec_updated$args$epochs), 2)
    expect_equal(eval_tidy(model_spec_updated$args$batch_size), 4)
    expect_equal(eval_tidy(model_spec_updated$args$num_batches_per_epoch), 6)
    expect_equal(eval_tidy(model_spec_updated$args$learn_rate), 0.0001)
    expect_equal(eval_tidy(model_spec_updated$args$learn_rate_decay_factor), 0.5)
    expect_equal(eval_tidy(model_spec_updated$args$learn_rate_min), 1e-5)
    expect_equal(eval_tidy(model_spec_updated$args$patience), 10)
    expect_equal(eval_tidy(model_spec_updated$args$clip_gradient), 10)

    # Arguments not passed to update() keep the values from the original spec
    expect_equal(eval_tidy(model_spec_updated$args$loss_function), list("MAPE"))
    expect_equal(eval_tidy(model_spec_updated$args$num_stacks), 10)
    expect_equal(eval_tidy(model_spec_updated$args$num_blocks), list(2))

    # CHECKS / VALIDATIONS ----

    # Missing prediction length
    expect_error({
        nbeats() %>%
            set_engine("gluonts_nbeats_ensemble") %>%
            fit(value ~ date + id, train_tbl)
    })

    # Missing freq
    expect_error({
        nbeats(prediction_length = 24) %>%
            set_engine("gluonts_nbeats_ensemble") %>%
            fit(value ~ date + id, train_tbl)
    })

    # Missing ID argument
    expect_error({
        nbeats(freq = "M", prediction_length = 24) %>%
            set_engine("gluonts_nbeats_ensemble") %>%
            fit(value ~ date + id, train_tbl)
    })

    # ID column not provided
    expect_error({
        model_spec %>%
            fit(value ~ date, train_tbl)
    })

    # Date column not provided
    expect_error({
        model_spec %>%
            fit(value ~ id, train_tbl)
    })
})
# business-science/modeltime.gluonts documentation built on Jan. 20, 2024, 3:59 a.m.