# Tests for add_recursive_ml_forecast_model(): successful multivariate and
# univariate runs, plus calls with invalid arguments that must error.
context("add_recursive_ml_forecast_model")
test_that("check add_recursive_ml_forecast_model with valid, differing, inputs", {
  # Run the same checks for each regressor configuration: each regressor on
  # its own, then both together.
  xregs <- list("spotprice", "gemprice", c("spotprice", "gemprice"))
  # Names expected in the output list. The third chained call below re-uses
  # "fc_svm_rec_2", so no third entry may appear.
  fc_names <- c("fc_svm_rec_1", "fc_svm_rec_2")
  for (xreg_cols in xregs) {
    function_input <- tstools::initialize_ts_forecast_data(
      data = dummy_gasprice,
      date_col = "year_month",
      col_of_interest = "gasprice",
      group_cols = c("state", "oil_company"),
      xreg_cols = xreg_cols
    ) %>%
      dplyr::filter(grouping == "state = New York & oil_company = CompanyA")
    # Rows 1-189 train the model; rows 190-191 form a 2-row validation
    # series, matching periods_ahead = 2.
    ts_object_train <- function_input %>%
      dplyr::slice(1:189) %>%
      tstools::transform_data_to_ts_object()
    ts_object_valid <- function_input %>%
      dplyr::slice(190:191) %>%
      tstools::transform_data_to_ts_object()
    expect_silent(
      function_output <- list() %>%
        add_recursive_ml_forecast_model(
          caret_model_tag = "svmRadialSigma",
          ts_object_train = ts_object_train,
          ts_object_valid = ts_object_valid,
          periods_ahead = 2,
          fc_name = "fc_svm_rec_1",
          model_type = "multivariate",
          verbose = FALSE
        ) %>%
        add_recursive_ml_forecast_model(
          caret_model_tag = "svmRadialSigma",
          ts_object_train = ts_object_train,
          ts_object_valid = ts_object_valid,
          periods_ahead = 2,
          fc_name = "fc_svm_rec_2",
          model_type = "multivariate",
          verbose = FALSE
        ) %>%
        add_recursive_ml_forecast_model(
          caret_model_tag = "svmRadialSigma",
          ts_object_train = ts_object_train,
          ts_object_valid = ts_object_valid,
          periods_ahead = 2,
          fc_name = "fc_svm_rec_2",
          model_type = "multivariate",
          verbose = FALSE
        )
    )
    expect_true(is.list(function_output))
    expect_equal(names(function_output), fc_names)
    # Check every forecast entry, not just the first one. The info label
    # identifies which entry / regressor set failed.
    for (model_name in fc_names) {
      label <- paste0(
        model_name, " [xreg: ", paste(xreg_cols, collapse = ", "), "]"
      )
      model_output <- function_output[[model_name]]
      expect_equal(names(model_output), c("model", "fc_data"), info = label)
      expect_equal(class(model_output$model), "character", info = label)
      expect_equal(nrow(model_output$fc_data), 2, info = label)
      expect_equal(ncol(model_output$fc_data), 3, info = label)
      expect_equal(
        model_output$fc_data$period, c(200610, 200611),
        info = label
      )
      expect_equal(
        unique(model_output$fc_data$fc_date), 200609,
        info = label
      )
    }
  }
})
test_that("check add_recursive_ml_forecast_model for univariate models with valid, differing, inputs", {
  # Univariate case: the New York / CompanyA series without external
  # regressors, used in full as training data (no validation series).
  function_input <- tstools::initialize_ts_forecast_data(
    data = dummy_gasprice,
    date_col = "year_month",
    col_of_interest = "gasprice",
    group_cols = c("state", "oil_company")
  ) %>%
    dplyr::filter(grouping == "state = New York & oil_company = CompanyA") %>%
    tstools::transform_data_to_ts_object()
  # capture.output() swallows any console output produced while fitting.
  # nullfile() is the platform's null device. A literal "NUL" is only a
  # null device on Windows; elsewhere it would create a stray file named
  # "NUL" in the working directory.
  capture.output(
    function_output <- list() %>%
      add_recursive_ml_forecast_model(
        caret_model_tag = "svmRadialSigma",
        ts_object_train = function_input,
        periods_ahead = 2,
        fc_name = "fc_svm_rec_1",
        model_type = "univariate",
        verbose = FALSE
      ) %>%
      add_recursive_ml_forecast_model(
        caret_model_tag = "svmRadialSigma",
        ts_object_train = function_input,
        periods_ahead = 2,
        fc_name = "fc_svm_rec_2",
        model_type = "univariate",
        verbose = FALSE
      ) %>%
      add_recursive_ml_forecast_model(
        caret_model_tag = "svmRadialSigma",
        ts_object_train = function_input,
        periods_ahead = 2,
        fc_name = "fc_svm_rec_2",
        model_type = "univariate",
        verbose = FALSE
      ),
    file = nullfile()
  )
  # The third call re-uses "fc_svm_rec_2", so only two entries are expected.
  fc_names <- c("fc_svm_rec_1", "fc_svm_rec_2")
  expect_true(is.list(function_output))
  expect_equal(names(function_output), fc_names)
  # Check every forecast entry, not just the first one.
  for (model_name in fc_names) {
    model_output <- function_output[[model_name]]
    expect_equal(names(model_output), c("model", "fc_data"), info = model_name)
    expect_equal(class(model_output$model), "character", info = model_name)
    expect_equal(nrow(model_output$fc_data), 2, info = model_name)
    expect_equal(ncol(model_output$fc_data), 3, info = model_name)
    expect_equal(
      model_output$fc_data$period, c(200612, 200701),
      info = model_name
    )
    expect_equal(
      unique(model_output$fc_data$fc_date), 200611,
      info = model_name
    )
  }
})
test_that("check add_recursive_ml_forecast_model for univariate models with invalid inputs", {
  # Fixtures: the New York / CompanyA grouping with both regressors, split
  # into a 189-row training series and a 2-row validation series.
  ny_company_a <- tstools::initialize_ts_forecast_data(
    data = dummy_gasprice,
    date_col = "year_month",
    col_of_interest = "gasprice",
    group_cols = c("state", "oil_company"),
    xreg_cols = c("spotprice", "gemprice")
  ) %>%
    dplyr::filter(grouping == "state = New York & oil_company = CompanyA")
  train_ts <- tstools::transform_data_to_ts_object(
    dplyr::slice(ny_company_a, 1:189)
  )
  valid_ts <- tstools::transform_data_to_ts_object(
    dplyr::slice(ny_company_a, 190:191)
  )

  # fc_models given as a character string, with no other arguments.
  expect_error(
    add_recursive_ml_forecast_model(fc_models = "potato")
  )

  # Raw dummy_gasprice data passed as the training series.
  expect_error(
    add_recursive_ml_forecast_model(
      fc_models = list(),
      ts_object_train = dummy_gasprice,
      periods_ahead = 12,
      fc_name = "fc_svm_rec_1",
      model_type = "univariate"
    )
  )

  # Raw dummy_gasprice data passed as the validation series.
  expect_error(
    add_recursive_ml_forecast_model(
      fc_models = list(),
      ts_object_train = train_ts,
      ts_object_valid = dummy_gasprice,
      periods_ahead = 12,
      fc_name = "fc_svm_rec_1",
      model_type = "multivariate"
    )
  )

  # Numeric instead of character forecast name.
  expect_error(
    add_recursive_ml_forecast_model(
      fc_models = list(),
      ts_object_train = train_ts,
      periods_ahead = 2,
      fc_name = 42,
      model_type = "univariate"
    )
  )

  # Unrecognised model_type.
  expect_error(
    add_recursive_ml_forecast_model(
      fc_models = list(),
      ts_object_train = train_ts,
      periods_ahead = 12,
      fc_name = "fc_svm_rec_1",
      model_type = "omnivariate"
    )
  )

  # Multivariate model without any validation series.
  expect_error(
    add_recursive_ml_forecast_model(
      fc_models = list(),
      ts_object_train = train_ts,
      periods_ahead = 12,
      fc_name = "fc_svm_rec_1",
      model_type = "multivariate"
    )
  )

  # Multivariate model whose 2-row validation series is shorter than the
  # 12 requested forecast periods.
  expect_error(
    add_recursive_ml_forecast_model(
      fc_models = list(),
      ts_object_train = train_ts,
      ts_object_valid = valid_ts,
      periods_ahead = 12,
      fc_name = "fc_svm_rec_1",
      model_type = "multivariate"
    )
  )

  # Negative forecast horizon.
  expect_error(
    add_recursive_ml_forecast_model(
      fc_models = list(),
      ts_object_train = train_ts,
      periods_ahead = -3,
      fc_name = "fc_svm_rec_1",
      model_type = "univariate"
    )
  )

  # Negative history length.
  expect_error(
    add_recursive_ml_forecast_model(
      fc_models = list(),
      ts_object_train = train_ts,
      periods_ahead = 12,
      periods_history = -3,
      fc_name = "fc_svm_rec_1",
      model_type = "univariate"
    )
  )
})
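# NOTE(review): the two lines below look like leftover web-page text, not R
# code; they are commented out so the test file parses.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.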