Nothing
# NBEATS ENSEMBLE TEST ----
# NOTE(review): context() is deprecated in testthat 3rd edition -- confirm
# which edition this package targets before removing it.
context("Test NBEATS ENSEMBLE")
# MODEL FITTING ----
test_that("nbeats ensemble: model fitting", {

  skip_if_no_gluonts()

  # ---- Specification ----
  # `<<-` publishes the spec outside this test block so that it outlives
  # the test; the assignment operator is kept on purpose.
  model_spec <<- set_engine(
    nbeats(
      id                      = "id",
      freq                    = "M",
      prediction_length       = 24,
      epochs                  = 1,
      batch_size              = 2,
      num_batches_per_epoch   = 10,
      learn_rate              = 0.01,
      learn_rate_decay_factor = 0.25,
      learn_rate_min          = 2e-5,
      patience                = 100,
      clip_gradient           = 1,
      penalty                 = 0.2,
      lookback_length         = list(12),
      loss_function           = list("MAPE"),
      bagging_size            = 2,
      num_stacks              = 10,
      num_blocks              = list(2)
    ),
    "gluonts_nbeats_ensemble"
  )

  # ---- Fit on the training split (log-transformed target) ----
  train_tbl    <- training(m750_splits)
  fitted_model <- fit(model_spec, log(value) ~ date + id, data = train_tbl)

  # print() must hand back the object it was given
  expect_equal(print(fitted_model), fitted_model)

  # ---- Shape of the fitted object ----
  impl   <- fitted_model$fit
  extras <- impl$extras

  expect_s3_class(impl, "nbeats_ensemble_fit_impl")
  expect_s3_class(impl$data, "tbl_df")
  expect_equal(names(impl$data)[1], "date")

  expect_equal(extras$id, "id")
  expect_equal(extras$idx_column, "date")
  expect_equal(extras$value_column, "value")
  expect_s3_class(extras$constructed_tbl[[1]], "data.frame")

  # ---- Underlying predictor (first model of the ensemble fit) ----
  predictor <- impl$models$model_1

  expect_s3_class(
    predictor,
    "gluonts.model.n_beats._ensemble.NBEATSEnsemblePredictor"
  )
  expect_equal(py_to_r(predictor$freq), "M")
  expect_equal(py_to_r(predictor$prediction_length), 24)

  # ---- Preprocessing metadata ----
  expect_equal(fitted_model$preproc$y_var, "value")

  # ---- Forecast on the testing split ----
  test_tbl        <- testing(m750_splits)
  calibrated      <- modeltime_calibrate(fitted_model, test_tbl)
  predictions_tbl <- modeltime_forecast(calibrated, new_data = test_tbl)

  # One prediction per test row, aligned on the date index
  expect_identical(nrow(test_tbl), nrow(predictions_tbl))
  expect_identical(test_tbl$date, predictions_tbl$.index)

})
# UPDATE MODEL SPEC ----
testthat::test_that("nbeats ensemble: update model spec", {

  skip_if_no_gluonts()

  # Baseline spec is built locally so this test does not depend on a global
  # left behind by another test (it can be run alone or in any order).
  # The values match the spec used by the model-fitting test in this file.
  model_spec <- nbeats(
    id                      = "id",
    freq                    = "M",
    prediction_length       = 24,
    epochs                  = 1,
    batch_size              = 2,
    num_batches_per_epoch   = 10,
    learn_rate              = 0.01,
    learn_rate_decay_factor = 0.25,
    learn_rate_min          = 2e-5,
    patience                = 100,
    clip_gradient           = 1,
    penalty                 = 0.2,
    lookback_length         = list(12),
    loss_function           = list("MAPE"),
    bagging_size            = 2,
    num_stacks              = 10,
    num_blocks              = list(2)
  ) %>%
    set_engine("gluonts_nbeats_ensemble")

  model_spec_updated <- model_spec %>%
    update(
      id                      = "id_2",
      freq                    = "D",
      prediction_length       = 36,
      epochs                  = 2,
      batch_size              = 4,
      num_batches_per_epoch   = 6,
      learn_rate              = 0.0001,
      learn_rate_decay_factor = 0.5,
      learn_rate_min          = 1e-5,
      patience                = 10,
      clip_gradient           = 10
    )

  # Arguments passed to update() must carry the new values
  expect_equal(eval_tidy(model_spec_updated$args$id), "id_2")
  expect_equal(eval_tidy(model_spec_updated$args$freq), "D")
  expect_equal(eval_tidy(model_spec_updated$args$prediction_length), 36)
  expect_equal(eval_tidy(model_spec_updated$args$epochs), 2)
  expect_equal(eval_tidy(model_spec_updated$args$batch_size), 4)
  expect_equal(eval_tidy(model_spec_updated$args$num_batches_per_epoch), 6)
  expect_equal(eval_tidy(model_spec_updated$args$learn_rate), 0.0001)
  expect_equal(eval_tidy(model_spec_updated$args$learn_rate_decay_factor), 0.5)
  expect_equal(eval_tidy(model_spec_updated$args$learn_rate_min), 1e-5)
  expect_equal(eval_tidy(model_spec_updated$args$patience), 10)
  expect_equal(eval_tidy(model_spec_updated$args$clip_gradient), 10)

  # Arguments not passed to update() must keep their baseline values
  expect_equal(eval_tidy(model_spec_updated$args$loss_function), list("MAPE"))
  expect_equal(eval_tidy(model_spec_updated$args$num_stacks), 10)
  expect_equal(eval_tidy(model_spec_updated$args$num_blocks), list(2))

})
# CHECKS / VALIDATIONS ----
testthat::test_that("nbeats ensemble: checks/validations", {

  skip_if_no_gluonts()

  # A complete spec is built locally so the column-validation checks below do
  # not depend on a global left behind by another test. The values match the
  # spec used by the model-fitting test in this file.
  model_spec <- nbeats(
    id                      = "id",
    freq                    = "M",
    prediction_length       = 24,
    epochs                  = 1,
    batch_size              = 2,
    num_batches_per_epoch   = 10,
    learn_rate              = 0.01,
    learn_rate_decay_factor = 0.25,
    learn_rate_min          = 2e-5,
    patience                = 100,
    clip_gradient           = 1,
    penalty                 = 0.2,
    lookback_length         = list(12),
    loss_function           = list("MAPE"),
    bagging_size            = 2,
    num_stacks              = 10,
    num_blocks              = list(2)
  ) %>%
    set_engine("gluonts_nbeats_ensemble")

  # Missing prediction length
  expect_error({
    nbeats() %>%
      set_engine("gluonts_nbeats_ensemble") %>%
      fit(value ~ date + id, training(m750_splits))
  })

  # Missing freq
  expect_error({
    nbeats(prediction_length = 24) %>%
      set_engine("gluonts_nbeats_ensemble") %>%
      fit(value ~ date + id, training(m750_splits))
  })

  # Missing ID argument
  expect_error({
    nbeats(freq = "M", prediction_length = 24) %>%
      set_engine("gluonts_nbeats_ensemble") %>%
      fit(value ~ date + id, training(m750_splits))
  })

  # ID column not provided in the formula
  expect_error({
    model_spec %>%
      fit(value ~ date, training(m750_splits))
  })

  # Date column not provided in the formula
  expect_error({
    model_spec %>%
      fit(value ~ id, training(m750_splits))
  })

})
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.