#' Soothsayer oracle training utilities
#' @description Utilities for training oracle models
#' @param x The time series (a sorted data.frame, or a tsibble).
#' @param min_h The minimal length of the forecast
#' @param max_h The maximal length of the forecast
#' @param ... Not implemented.
#' @return **x** with an added column **forecast_h**, holding a forecast length
#'   sampled uniformly between **min_h** and **max_h**. The upper end is
#'   capped at half the number of rows of **x**, but is never below **min_h**.
#' @rdname oracle_utils
#' @export
random_forecast_h <- function(x, min_h = 1, max_h = 12, ...) {
  # Candidate horizons: at most half the series length, at least min_h.
  candidates <- min_h:max(min(nrow(x) / 2, max_h), min_h)

  # Draw by position rather than with sample(candidates, 1): sample() expands
  # a single number n into 1:n, so a lone candidate (e.g. min_h == max_h == 5)
  # would otherwise produce horizons below min_h.
  # consider sample weights?
  draw_h <- function() {
    candidates[sample.int(length(candidates), size = 1L)]
  }

  # Calling a local function (instead of referring to a local vector) keeps
  # the draw inside mutate() while making sure no column of `x` can mask it.
  dplyr::mutate(x, forecast_h = draw_h())
}
#' Attach a forecast horizon to a series
#'
#' @param x A data frame; `ts_train_test()` passes one key group at a time.
#' @param forecast_h Either a function, which is called as `forecast_h(x)` and
#'   whose result is returned unchanged, or a numeric value that is written to
#'   a `forecast_h` column of `x`.
#' @return The result of `forecast_h(x)`, or `x` with a `forecast_h` column.
#'   Aborts when `forecast_h` is neither a function nor numeric.
#' @noRd
handle_forecast_h <- function(x, forecast_h) {
  if (is.function(forecast_h)) {
    return(forecast_h(x))
  }
  if (!is.numeric(forecast_h)) {
    rlang::abort("forecast_h must be a function, or an integer value of type numeric.")
  }
  # `!!` injects the value of the argument. A bare `forecast_h` would resolve
  # to an already existing `forecast_h` column of `x` first, silently ignoring
  # the horizon that was asked for.
  dplyr::mutate(x, forecast_h = !!forecast_h)
}
#' Split a tsibble into a training and a test part
#'
#' Every series (one per key combination) gets a forecast horizon via
#' `handle_forecast_h()`. Rows whose index is strictly below
#' `max(index) - forecast_h + 1` form the training set, the rest the test set.
#'
#' @param ts_tbl A tsibble.
#' @param forecast_h A function or numeric value, forwarded to
#'   `handle_forecast_h()`.
#' @param ... Currently unused.
#' @return A list with `train` and `test` (tsibbles without the helper columns
#'   `forecast_h` and `train`) and `forecast_h` (a data frame with the distinct
#'   horizon per key combination).
#' @noRd
ts_train_test <- function(ts_tbl,
                          forecast_h = random_forecast_h,
                          ...) {
  time_index <- tsibble::index_var(ts_tbl)
  key_var <- tsibble::key_vars(ts_tbl)

  # Attach a `forecast_h` column to each series separately, so that a
  # horizon function sees one series at a time.
  series <- tsibble::as_tibble(ts_tbl)
  series <- dplyr::group_by(series, !!!tsibble::key(ts_tbl))
  series <- dplyr::group_split(series)
  series <- purrr::map(series, function(piece) handle_forecast_h(piece, forecast_h))
  series <- dplyr::bind_rows(series)

  # Flag the training rows per series. The index column is looked up by its
  # actual name rather than assuming it is called "index", and `forecast_h`
  # is read explicitly from the data so it cannot fall back to the argument.
  # NOTE(review): the arithmetic assumes that subtracting a number from the
  # index moves it by whole observations - confirm for irregular indices.
  series <- dplyr::group_by(series, !!!tsibble::key(ts_tbl))
  series <- dplyr::mutate(
    series,
    train = .data[[time_index]] <
      max(.data[[time_index]]) - .data[["forecast_h"]] + 1
  )
  series <- dplyr::ungroup(series)

  # Drop the auxiliary `period` column when it exists; any_of() keeps this
  # from aborting on data that never had one.
  series <- dplyr::select(series, -tidyselect::any_of("period"))
  series <- tsibble::as_tsibble(series, index = time_index, key = key_var)

  # One row per key combination with the horizon that was used.
  horizons <- dplyr::select(
    as.data.frame(series),
    tidyselect::all_of(c("forecast_h", key_var))
  )
  horizons <- dplyr::distinct(horizons)

  series <- dplyr::select(series, -tidyselect::all_of("forecast_h"))

  train <- dplyr::filter(series, train)
  train <- dplyr::select(train, -tidyselect::all_of("train"))

  # The helper flag is removed from the test part as well, so both parts
  # come back with the same columns.
  test <- dplyr::filter(series, !train)
  test <- dplyr::select(test, -tidyselect::all_of("train"))

  list(train = train, test = test, forecast_h = horizons)
}
#' Fit a set of model definitions to the training data
#'
#' @param train A tsibble with exactly one measured variable, which is used as
#'   the response of every model.
#' @param models A named list of functions; each is called with the response
#'   variable and must return a model definition accepted by
#'   `fabletools::model()`.
#' @return The result of `fabletools::model()`, fitted with `.safely = TRUE`.
#' @noRd
fit_models <- function(train,
                       models = list(ar = fable::AR)) {
  values_from <- tsibble::measured_vars(train)
  # rlang::sym() only accepts a single string; fail early with a message that
  # names the actual problem instead of a symbol-conversion error.
  if (length(values_from) != 1L) {
    rlang::abort(paste0(
      "`train` must have exactly one measured variable, not ",
      length(values_from), "."
    ))
  }
  response <- rlang::sym(values_from)

  # Build one model definition per entry, each modelling the response column.
  model_set <- purrr::map(models, function(model_fn) model_fn(!!response))
  fabletools::model(train, !!!model_set, .safely = TRUE)
}
#' Forecast the test period with previously fitted models
#'
#' @param test The test data, passed to `fabletools::forecast()` as `new_data`.
#' @param fitted_models The fitted models to forecast from.
#' @return A list with a single element `forecasts`.
#' @noRd
forecast_models <- function(test,
                            fitted_models) {
  # `times = 0` travels on to the model-specific forecast methods.
  # NOTE(review): presumably this switches off simulated sample paths for
  # models that support them - confirm against the fable documentation.
  list(
    forecasts = fabletools::forecast(
      fitted_models,
      new_data = test,
      times = 0
    )
  )
}
# Every accuracy measure collection shipped with fabletools, gathered in one
# (nested) list; used as the `measures` argument of fabletools::accuracy() in
# soothsayer_forecaster().
all_accuracy_measures <- list(
  fabletools::point_accuracy_measures,
  fabletools::interval_accuracy_measures,
  fabletools::distribution_accuracy_measures,
  fabletools::directional_accuracy_measures
)
#' Soothsayer oracle training utilities
#' @description Utilities for training oracle models
#' @param series The series to use for modelling, a \link[tsibble]{tsibble}. Note that this must only have a
#' single measure variable.
#' @param models A named list of fable compatible model functions, such as \link[fable]{ARIMA}.
#' This must be a function which returns a valid fable model definition, preferably generated by
#' \link[fabletools]{new_model_definition}
#' @param forecast_h The forecast length h, either an integer, or a function which can return one,
#' possibly based on the underlying time series. For example, see \link[soothsayer]{random_forecast_h}.
#' @param save_forecast_experiment Whether to save the forecasting experiment, by default **TRUE**, and uses
#' quicksave (see \link[qs]{qsave})
#' @param save_outfile An optional argument to add an identifier to the saved objects.
#' @param save_folder The folder to save the objects to.
#' @param save_n_threads The number of threads to use when saving objects, see documentation for \link[qs]{qsave}.
#' @return A list with fitted models (**models**), generated forecasts (**forecasts**) and computed
#' forecast accuracies (**accuracies**); the accuracies carry the forecast length used for each series.
#' @rdname oracle_utils
#' @export
soothsayer_forecaster <- function(series,
                                  models = list(
                                    arima = fable::ARIMA,
                                    snaive = fable::SNAIVE,
                                    rw = fable::RW,
                                    theta = fable::THETA,
                                    ets = fable::ETS,
                                    nnetar = fable::NNETAR,
                                    croston = fable::CROSTON,
                                    ar = fable::AR
                                    # ar1 = fix_model_parameters(fable::AR, order(1)),
                                    # ar3 = fix_model_parameters(fable::AR, order(3)),
                                    # arma11 = fix_model_parameters(fable::ARIMA, pdq(1,0,1)),
                                    # arma31 = fix_model_parameters(fable::ARIMA, pdq(3,0,1))
                                  ),
                                  forecast_h = random_forecast_h,
                                  save_forecast_experiment = TRUE,
                                  save_outfile = "1",
                                  save_folder = paste0(
                                    "experiment_",
                                    Sys.Date()
                                  ),
                                  save_n_threads = 12) {
  # Writes one stage of the experiment to
  # <save_folder>/<stage>/<stage>_<save_outfile>.qs; a no-op when saving is
  # switched off.
  save_stage <- function(object, stage) {
    if (!save_forecast_experiment) {
      return(invisible(NULL))
    }
    stage_dir <- file.path(save_folder, stage)
    fs::dir_create(stage_dir)
    qs::qsave(
      object,
      file = file.path(stage_dir, paste0(stage, "_", save_outfile, ".qs")),
      nthreads = save_n_threads
    )
    invisible(NULL)
  }

  train_test <- ts_train_test(series, forecast_h = forecast_h)

  models <- fit_models(
    train = train_test[["train"]],
    models = models
  )
  save_stage(models, "models")

  # fit_models() returns the mable itself, so `models[["models"]]` would be
  # NULL and the forecast step would fail. A list wrapper with a `models`
  # element is still accepted.
  fitted_mable <- if (inherits(models, "mdl_df")) models else models[["models"]]
  forecasts <- forecast_models(train_test[["test"]], fitted_mable)
  save_stage(forecasts, "forecasts")

  accuracies <- fabletools::accuracy(
    forecasts[["forecasts"]],
    train_test[["test"]],
    measures = all_accuracy_measures
  )
  # Join the horizon of each series on the tsibble's actual key variables -
  # the same columns ts_train_test() puts next to `forecast_h` - instead of
  # assuming a single column literally named "key".
  accuracies <- dplyr::left_join(
    accuracies,
    train_test[["forecast_h"]],
    by = tsibble::key_vars(series)
  )
  save_stage(accuracies, "accuracies")

  list(
    models = models,
    forecasts = forecasts,
    accuracies = accuracies
  )
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.