# tests/testthat/test-ensemble_average.R

# Load testthat before calling any of its functions (context() included) so
# the file also works when sourced on its own.
library(testthat)

context("TEST: ensemble_average()")

# USED FOR TESTS ----
# Shared fixtures: three models fitted on the m750 training split and
# collected into the modeltime table that the tests below ensemble.

# Build a workflow from `model_spec` + `wflw_recipe` and fit it on the
# m750 training split.
fit_on_m750_train <- function(model_spec, wflw_recipe) {
    workflow() %>%
        add_model(spec = model_spec) %>%
        add_recipe(recipe = wflw_recipe) %>%
        fit(training(m750_splits))
}

# Plain value ~ date recipe used by both the ARIMA and Prophet workflows
rec_value_date <- recipe(value ~ date, data = training(m750_splits))

wflw_fit_arima <- fit_on_m750_train(
    model_spec  = arima_reg(seasonal_period = 12) %>% set_engine("auto_arima"),
    wflw_recipe = rec_value_date
)

wflw_fit_prophet <- fit_on_m750_train(
    model_spec  = prophet_reg() %>% set_engine("prophet"),
    wflw_recipe = rec_value_date
)

# The glmnet model gets engineered calendar features instead of the raw
# date: expand the date, drop time-of-day and iso/xts variant columns,
# remove zero-variance predictors, normalize numerics, one-hot encode
# nominals, and finally drop the raw date column.
rec_glmnet <- recipe(value ~ date, data = training(m750_splits)) %>%
    step_timeseries_signature(date) %>%
    step_rm(matches("(iso$)|(xts$)|(am.pm)|(hour$)|(minute)|(second)")) %>%
    step_zv(all_predictors()) %>%
    step_normalize(all_numeric_predictors()) %>%
    step_dummy(all_nominal_predictors(), one_hot = TRUE) %>%
    step_rm(date)

wflw_fit_glmnet <- fit_on_m750_train(
    model_spec  = linear_reg(penalty = 0.1) %>% set_engine("glmnet"),
    wflw_recipe = rec_glmnet
)

m750_models_2 <- modeltime_table(
    wflw_fit_arima,
    wflw_fit_prophet,
    wflw_fit_glmnet
)


# TEST ENSEMBLE AVERAGE ----

# Median ----
test_that("ensemble_average(type = 'median')", {

    # Exercises build -> calibrate -> accuracy -> forecast -> refit
    # (sequential and parallel) for a median ensemble; skipped on CRAN.
    testthat::skip_on_cran()

    ensemble_fit_median <- m750_models_2 %>%
        ensemble_average(type = "median")

    # Structure
    expect_s3_class(ensemble_fit_median, "mdl_time_ensemble")
    expect_s3_class(ensemble_fit_median, "mdl_time_ensemble_avg")

    expect_s3_class(ensemble_fit_median$model_tbl, "mdl_time_tbl")
    expect_equal(ensemble_fit_median$parameters$type, "median")
    expect_equal(ensemble_fit_median$n_models, 3)
    expect_equal(ensemble_fit_median$desc, "ENSEMBLE (MEDIAN): 3 MODELS")

    # Print
    expect_equal(print(ensemble_fit_median), ensemble_fit_median)

    # Modeltime Table
    expect_equal(
        modeltime_table(ensemble_fit_median) %>% pull(.model_desc),
        "ENSEMBLE (MEDIAN): 3 MODELS"
    )

    # Calibration
    calibration_tbl <- ensemble_fit_median %>%
        modeltime_calibrate(testing(m750_splits))

    expect_false(is.na(calibration_tbl$.type))

    # Accuracy (expect_lt reports the actual value when it fails)
    accuracy_tbl <- calibration_tbl %>% modeltime_accuracy()

    expect_false(is.na(accuracy_tbl$mae))
    expect_lt(accuracy_tbl$mae, 300)

    # Forecast: expect one row per test-split observation plus one row per
    # actual observation. Derive the horizon from the split instead of
    # hard-coding it.
    n_test   <- nrow(testing(m750_splits))
    n_actual <- nrow(m750)

    forecast_tbl <- calibration_tbl %>%
        modeltime_forecast(
            new_data    = testing(m750_splits),
            actual_data = m750
        )

    expect_equal(nrow(forecast_tbl), n_test + n_actual)
    expect_equal(ncol(forecast_tbl), 7)

    # Forecast - Test Keep New Data (same rows, extra columns)
    forecast_tbl <- calibration_tbl %>%
        modeltime_forecast(
            new_data    = testing(m750_splits),
            actual_data = m750,
            keep_data   = TRUE
        )

    expect_equal(nrow(forecast_tbl), n_test + n_actual)
    expect_equal(ncol(forecast_tbl), 10)

    # Refit (sequential)
    refit_tbl <- calibration_tbl %>%
        modeltime_refit(m750, control = control_refit())

    # Refit in Parallel ----
    # parallel_stop() sits in `finally` so that a failing refit cannot leave
    # the cluster registered for the tests that run afterwards.
    parallel_start(2)
    refit_par_tbl <- tryCatch(
        calibration_tbl %>%
            modeltime_refit(
                m750,
                control = control_refit(
                    verbose   = TRUE,
                    allow_par = TRUE,
                    cores     = 2,
                    packages  = "modeltime.ensemble")
            ),
        finally = parallel_stop()
    )

    # Number of rows the first sub-model of the refitted ensemble was
    # trained on.
    n_training_rows <- function(refitted_tbl) {
        refitted_tbl %>%
            pluck(".model", 1, "model_tbl", ".model", 1,
                  "fit", "fit", "fit", "data") %>%
            nrow()
    }

    # Both refit paths should retrain on the full m750 series
    expect_equal(n_training_rows(refit_tbl), nrow(m750))
    expect_equal(n_training_rows(refit_par_tbl), nrow(m750))

})

# Mean ----
test_that("ensemble_average(type = 'mean')", {

    ensemble_fit_mean <- ensemble_average(m750_models_2, type = "mean")

    # The object carries both ensemble classes and records its configuration
    for (ensemble_cls in c("mdl_time_ensemble", "mdl_time_ensemble_avg")) {
        expect_s3_class(ensemble_fit_mean, ensemble_cls)
    }

    expect_s3_class(ensemble_fit_mean$model_tbl, "mdl_time_tbl")
    expect_equal(ensemble_fit_mean$parameters$type, "mean")
    expect_equal(ensemble_fit_mean$n_models, 3)
    expect_equal(ensemble_fit_mean$desc, "ENSEMBLE (MEAN): 3 MODELS")

    # Forecasting over the test split yields one row per test date, in order
    test_split_tbl <- testing(m750_splits)

    fcast <- modeltime_forecast(
        modeltime_table(ensemble_fit_mean),
        new_data = test_split_tbl
    )

    expect_equal(nrow(fcast), nrow(test_split_tbl))
    expect_equal(fcast$.index, test_split_tbl$date)

})

# Checks/Errors ----
test_that("Checks/Errors: ensemble_average()", {

    # No object supplied at all
    expect_error(ensemble_average())

    # Object that is not a modeltime table
    expect_error(ensemble_average(1))

    # Unsupported aggregation type
    expect_error(ensemble_average(m750_models_2, type = "blah"))

    # A single-model table cannot be ensembled (the slice stays inside
    # expect_error, as before)
    expect_error(
        ensemble_average(dplyr::slice(m750_models_2, 1), type = "mean")
    )

})

# Try the modeltime.ensemble package in your browser
#
# Any scripts or data that you put into this service are public.
#
# modeltime.ensemble documentation built on April 18, 2023, 5:09 p.m.