# experiments/deprecated/forecast.R

#' Soothsayer oracle training utilities
#' @description Utilities for training oracle models
#' @param x The time series (a sorted data.frame, or a tsibble).
#' @param min_h The minimal length of the forecast
#' @param max_h The maximal length of the forecast. The effective upper bound is
#' additionally capped at half the number of rows of **x**, but never below **min_h**.
#' @param ... Not implemented.
#' @return **x** with an added column `forecast_h` holding a single forecast length,
#' sampled uniformly from the integers between **min_h** and the effective upper bound.
#' @rdname oracle_utils
#' @export
random_forecast_h <- function( x, min_h = 1, max_h = 12, ... ) {
  # Never forecast more than half of the series, but always allow at least min_h.
  upper_h <- max( min( nrow(x) / 2, max_h ), min_h )
  candidates <- min_h:upper_h
  # Index into the candidates instead of calling sample(candidates, 1): when only
  # one candidate exists, sample(k, 1) would draw from 1:k rather than return k.
  # consider sample weights?
  chosen_h <- candidates[ sample.int( length(candidates), size = 1L ) ]
  # `!!` injects the sampled value, so an existing `forecast_h` column in x
  # cannot shadow it inside the data mask.
  dplyr::mutate( x, forecast_h = !!chosen_h )
}

# Attach a forecast horizon to a single series.
#
# If forecast_h is a function, it is called on x and its result is returned
# as-is (e.g. random_forecast_h, which adds a `forecast_h` column). If it is
# numeric, its value is stored in a `forecast_h` column of x. Anything else
# aborts with an error.
handle_forecast_h <- function( x, forecast_h ) {
  if( is.function(forecast_h) ) {
    return( forecast_h(x) )
  }
  if( is.numeric(forecast_h) ) {
    # `!!` injects the argument's value, so an existing `forecast_h` column in x
    # cannot shadow it inside the data mask.
    return( dplyr::mutate( x, forecast_h = !!forecast_h ) )
  }
  rlang::abort("forecast_h must be a function, or an integer value of type numeric.")
}

# Split a tsibble into per-series training and test sets.
#
# Every key series receives a forecast horizon via handle_forecast_h(). Within
# each series, observations whose index is before the last `forecast_h` time
# points become training data, and the remaining observations become test data.
#
# Returns a list with:
#   train      - tsibble of training observations
#   test       - tsibble of test observations
#   forecast_h - data.frame of the horizon used for each key combination
ts_train_test <- function( ts_tbl,
                           forecast_h = random_forecast_h,
                           ... ) {

  time_index <- tsibble::index_var( ts_tbl )
  key_var <- tsibble::key_vars( ts_tbl )
  key_syms <- tsibble::key( ts_tbl )

  # Assign a horizon to each series separately, because the horizon may depend
  # on the series itself (e.g. its length).
  series <- tsibble::as_tibble( ts_tbl )
  series <- dplyr::group_by( series, !!!key_syms )
  series <- dplyr::group_split( series )
  series <- purrr::map( series, ~ handle_forecast_h( .x, forecast_h ) )
  series <- dplyr::bind_rows( series )

  # Flag training rows per series. Use the tsibble's actual index column rather
  # than assuming it is named "index".
  series <- dplyr::group_by( series, !!!key_syms )
  series <- dplyr::mutate(
    series,
    train = .data[[time_index]] <
      max( .data[[time_index]] ) - .data[["forecast_h"]] + 1
  )
  series <- dplyr::ungroup( series )
  # NOTE(review): a "period" column is dropped when present, presumably a
  # leftover from a specific dataset. any_of() avoids an error when it is absent.
  series <- dplyr::select( series, -tidyselect::any_of( "period" ) )
  series <- tsibble::as_tsibble( series, index = time_index, key = key_var )

  horizons <- dplyr::select( as.data.frame( series ),
                             tidyselect::all_of( c( "forecast_h", key_var ) ) )
  horizons <- dplyr::distinct( horizons )

  series <- dplyr::select( series, -tidyselect::all_of( "forecast_h" ) )

  train <- dplyr::filter( series, .data[["train"]] )
  train <- dplyr::select( train, -tidyselect::all_of( "train" ) )

  # Drop the helper flag from the test set as well, not just from training.
  test <- dplyr::filter( series, !.data[["train"]] )
  test <- dplyr::select( test, -tidyselect::all_of( "train" ) )

  list( train = train, test = test, forecast_h = horizons )
}

# Fit a set of fable models to the training data.
#
# Each element of `models` is a model function (e.g. fable::AR). It is called
# with the symbol of the measured variable of `train`, and the resulting model
# definitions are handed to fabletools::model() with `.safely = TRUE`.
# Returns whatever fabletools::model() returns (a mable).
fit_models <- function( train,
                        models = list( ar = fable::AR)) {
  response_name <- tsibble::measured_vars( train )
  definitions <- purrr::map(
    models,
    function( model_fun ) model_fun( !!rlang::sym( response_name ) )
  )
  fabletools::model( train, !!!definitions, .safely = TRUE )
}

# Forecast with every fitted model.
#
# fitted_models is the model table produced by fit_models(); `test` is passed
# as `new_data` to fabletools::forecast(). `times = 0` is forwarded unchanged
# (NOTE(review): presumably disables bootstrapped sample paths - confirm).
# Returns a list whose single element `forecasts` holds the forecast result.
forecast_models <- function( test,
                             fitted_models ) {
  list(
    forecasts = fabletools::forecast( fitted_models,
                                      new_data = test,
                                      times = 0 )
  )
}

# Four fabletools accuracy-measure families (point, interval, distribution,
# directional), in that order. Used as the `measures` argument of
# fabletools::accuracy() in soothsayer_forecaster().
all_accuracy_measures <- list(
  fabletools::point_accuracy_measures, fabletools::interval_accuracy_measures,
  fabletools::distribution_accuracy_measures, fabletools::directional_accuracy_measures
)

#' Soothsayer oracle training utilities
#' @description Utilities for training oracle models
#' @param series The series to use for modelling, a \link[tsibble]{tsibble}. Note that this must only have a
#' single measure variable.
#' @param models A named list of fable compatible model functions, such as \link[fable]{ARIMA}.
#' This must be a function which returns a valid fable model definition, preferably generated by
#' \link[fabletools]{new_model_definition}
#' @param forecast_h The forecast length h, either an integer, or a function which can return one,
#' possibly based on the underlying time series. For example, see \link[soothsayer]{random_forecast_h}.
#' @param save_forecast_experiment Whether to save the forecasting experiment, by default **TRUE**, and uses
#' quicksave (see \link[qs]{qsave})
#' @param save_outfile An optional argument to add an identifier to the saved objects.
#' @param save_folder The folder to save the objects to. Models, forecasts and accuracies are written to
#' the `models`, `forecasts` and `accuracies` subfolders respectively.
#' @param save_n_threads The number of threads to use when saving objects, see documentation for \link[qs]{qsave}.
#' @return A list with fitted models (the model table from \link[fabletools]{model}), generated forecasts and
#' computed forecast accuracies (joined with the forecast horizon used for each series).
#' @rdname oracle_utils
#' @export
soothsayer_forecaster <- function( series,
                                   models = list( arima = fable::ARIMA,
                                                  snaive = fable::SNAIVE,
                                                  rw = fable::RW,
                                                  theta = fable::THETA,
                                                  ets = fable::ETS,
                                                  nnetar = fable::NNETAR,
                                                  croston = fable::CROSTON,
                                                  ar = fable::AR
                                                  # ar1 = fix_model_parameters(fable::AR, order(1)),
                                                  # ar3 = fix_model_parameters(fable::AR, order(3)),
                                                  # arma11 = fix_model_parameters(fable::ARIMA, pdq(1,0,1)),
                                                  # arma31 = fix_model_parameters(fable::ARIMA, pdq(3,0,1))
                                   ),
                                   forecast_h = random_forecast_h,
                                   save_forecast_experiment = TRUE,
                                   save_outfile = "1",
                                   save_folder = paste0("experiment_",
                                                        Sys.Date()
                                   ),
                                   save_n_threads = 12)
{
  train_test <- ts_train_test( series,
                               forecast_h = forecast_h )

  models <- fit_models( train = train_test[["train"]],
                        models = models )
  if( save_forecast_experiment ) {
    save_experiment_object( models, save_folder, "models",
                            save_outfile, save_n_threads )
  }

  # fit_models() returns the model table itself (it has no "models" column),
  # so it is passed on directly.
  forecasts <- forecast_models( train_test[["test"]], models )
  if( save_forecast_experiment ) {
    save_experiment_object( forecasts, save_folder, "forecasts",
                            save_outfile, save_n_threads )
  }

  # fabletools::accuracy() expects the complete data (training and test): the
  # training part is needed by scaled/directional measures such as MASE.
  accuracies <- fabletools::accuracy( forecasts[["forecasts"]],
                                      series,
                                      measures = all_accuracy_measures )
  # Attach the horizon used for each series, joining on the series' key columns.
  accuracies <- dplyr::left_join( accuracies,
                                  train_test[["forecast_h"]],
                                  by = tsibble::key_vars( series ) )

  if( save_forecast_experiment ) {
    save_experiment_object( accuracies, save_folder, "accuracies",
                            save_outfile, save_n_threads )
  }

  list( models = models,
        forecasts = forecasts,
        accuracies = accuracies )
}

# Save one component of a forecasting experiment with qs::qsave() to
# <save_folder>/<component>/<component>_<save_outfile>.qs, creating the
# folders if needed.
save_experiment_object <- function( object, save_folder, component,
                                    save_outfile, save_n_threads ) {
  component_folder <- file.path( save_folder, component )
  fs::dir_create( component_folder )
  qs::qsave( object,
             file = file.path( component_folder,
                               paste0( component, "_", save_outfile, ".qs" ) ),
             nthreads = save_n_threads )
}
# JSzitas/soothsayer documentation built on April 18, 2023, 12:59 a.m.