# tests/testthat/test-ensemble_weighted.R

# TEST: ensemble_weighted()
# context() is deprecated as of testthat 3e; reporters now label results
# with the file name, so no explicit context call is needed here.

# TEST ENSEMBLE WEIGHTED ----

# USED FOR TESTS ----
# Every workflow below is trained on the same training split.
train_tbl <- training(m750_splits)

# ARIMA and Prophet share one preprocessing spec: the raw date is the only
# predictor of `value`.
rec_date_only <- recipe(value ~ date, data = train_tbl)

wflw_fit_arima <- workflow() %>%
    add_recipe(rec_date_only) %>%
    add_model(arima_reg(seasonal_period = 12) %>% set_engine("auto_arima")) %>%
    fit(train_tbl)

wflw_fit_prophet <- workflow() %>%
    add_recipe(rec_date_only) %>%
    add_model(prophet_reg() %>% set_engine("prophet")) %>%
    fit(train_tbl)

# GLMNET recipe: expand the date into time-series signature features, drop
# the ones matched by the regex, remove zero-variance predictors, normalize
# numerics, one-hot encode nominals, and finally drop the raw date column.
# (Inspect the result with: rec_glmnet %>% prep() %>% juice() %>% glimpse())
rec_glmnet <- recipe(value ~ date, data = train_tbl) %>%
    step_timeseries_signature(date) %>%
    step_rm(matches("(iso$)|(xts$)|(am.pm)|(hour$)|(minute)|(second)")) %>%
    step_zv(all_predictors()) %>%
    step_normalize(all_numeric_predictors()) %>%
    step_dummy(all_nominal_predictors(), one_hot = TRUE) %>%
    step_rm(date)

wflw_fit_glmnet <- workflow() %>%
    add_recipe(rec_glmnet) %>%
    add_model(linear_reg(penalty = 0.1) %>% set_engine("glmnet")) %>%
    fit(train_tbl)

# Three-model table (ARIMA, Prophet, GLMNET, in that row order) used by the
# ensemble tests below.
m750_models_2 <- modeltime_table(
    wflw_fit_arima,
    wflw_fit_prophet,
    wflw_fit_glmnet
)


# Weighted ----
# End-to-end check of a weighted ensemble built from `m750_models_2`:
# object structure, printing, forecasting, calibration, accuracy, refitting.
test_that("ensemble_weighted()", {

    loadings <- c(3, 2, 1)

    # Hold-out data is reused throughout. Deriving the horizon from it keeps
    # the row-count expectations in sync with the split (no magic `24`).
    test_tbl <- testing(m750_splits)
    n_test   <- nrow(test_tbl)
    n_actual <- nrow(m750)

    ensemble_fit_wt <- m750_models_2 %>%
        ensemble_weighted(loadings = loadings)

    # Structure
    expect_s3_class(ensemble_fit_wt, "mdl_time_ensemble")
    expect_s3_class(ensemble_fit_wt, "mdl_time_ensemble_wt")

    expect_s3_class(ensemble_fit_wt$model_tbl, "mdl_time_tbl")
    expect_equal(ensemble_fit_wt$parameters$loadings, loadings)
    expect_true(ensemble_fit_wt$parameters$scale_loadings)

    # With scale_loadings = TRUE (asserted above) the stored loadings are
    # the supplied ones rescaled to sum to 1.
    expect_equal(
        ensemble_fit_wt$fit$loadings_tbl$.loadings,
        loadings / sum(loadings)
    )

    expect_equal(ensemble_fit_wt$n_models, 3)
    expect_equal(ensemble_fit_wt$desc, "ENSEMBLE (WEIGHTED): 3 MODELS")

    # Print returns the object unchanged
    expect_equal(print(ensemble_fit_wt), ensemble_fit_wt)

    # Modeltime Table
    expect_equal(
        modeltime_table(ensemble_fit_wt) %>% pull(.model_desc),
        "ENSEMBLE (WEIGHTED): 3 MODELS"
    )

    # Forecast (uncalibrated): one row per hold-out observation
    fcast <- modeltime_table(ensemble_fit_wt) %>%
        modeltime_forecast(test_tbl)

    expect_equal(nrow(fcast), n_test)
    expect_equal(fcast$.index, test_tbl$date)

    # Calibration
    calibration_tbl <- ensemble_fit_wt %>%
        modeltime_calibrate(test_tbl)

    # any() keeps the expectation scalar regardless of the number of rows
    expect_false(any(is.na(calibration_tbl$.type)))

    # Accuracy
    accuracy_tbl <- calibration_tbl %>% modeltime_accuracy()

    expect_false(any(is.na(accuracy_tbl$mae)))

    # Forecast (calibrated) with the full actual series appended
    forecast_tbl <- calibration_tbl %>%
        modeltime_forecast(
            new_data    = test_tbl,
            actual_data = m750
        )

    expect_equal(nrow(forecast_tbl), n_test + n_actual)
    expect_equal(ncol(forecast_tbl), 7)

    # Forecast - keep_data = TRUE: same rows, extra data columns (7 -> 10)
    forecast_tbl <- calibration_tbl %>%
        modeltime_forecast(
            new_data    = test_tbl,
            actual_data = m750,
            keep_data   = TRUE
        )

    expect_equal(nrow(forecast_tbl), n_test + n_actual)
    expect_equal(ncol(forecast_tbl), 10)

    # Refit on the full series
    refit_tbl <- calibration_tbl %>%
        modeltime_refit(m750)

    # Drill into the first sub-model of the refitted ensemble (row 1 of
    # m750_models_2) and grab the data its engine was trained on.
    # NOTE(review): this path depends on internal workflow/parsnip layout.
    training_results_tbl <- refit_tbl %>%
        pluck(".model", 1, "model_tbl", ".model", 1, "fit", "fit", "fit", "data")

    expect_equal(nrow(training_results_tbl), n_actual)

})



# Checks/Errors ----
# Argument validation for ensemble_weighted(): every call below must error.
# Each case is built so that only the condition named in its comment is
# violated, so the test actually exercises that specific check.
test_that("Checks/Errors: ensemble_weighted()", {

    # Object is Missing
    expect_error(ensemble_weighted())

    # Incorrect Object
    expect_error(ensemble_weighted(1))

    # No loadings
    expect_error(ensemble_weighted(m750_models_2))

    # Needs correct number of loadings (3 models, 1 loading)
    expect_error({
        m750_models_2 %>%
            ensemble_weighted(loadings = 1)
    })

    # Needs more than one model. One loading for the one remaining row keeps
    # the loadings-length check satisfied; passing `1:3` here would have
    # tripped that check instead and masked this one.
    expect_error({
        m750_models_2 %>%
            slice(1) %>%
            ensemble_weighted(loadings = 1)
    })

})

# Try the modeltime.ensemble package in your browser
#
# Any scripts or data that you put into this service are public.
#
# modeltime.ensemble documentation built on April 18, 2023, 5:09 p.m.