#' Function to tune N-BEATS
#'
#' @param id A quoted column name that tracks the GluonTS FieldName "item_id"
#' @param freq A pandas timeseries frequency such as "5min" for 5-minutes or "D" for daily.
#' @param recipe A gluonts recipe
#' @param splits A train/test split object; `training(splits)` supplies the data for time-series cross-validation
#' @param horizon The forecast horizon
#' @param length The number of distinct hyperparameter for each tunable parameter except loss_function which is set to MASE
#' @param cv_slice_limit How many slice/folds in the tsCV
#' @param assess The number of samples used for each assessment resample
#' @param skip A integer indicating how many (if any) additional resamples to skip to thin the total amount of data points in the analysis resample.
#' @param initial The number of samples used for analysis/modeling in the initial resample.
#' @param epochs Number of epochs
#' @param lookback Lookback length. If NULL, will be randomly chosen
#' @param batch_size Number of examples in each batch
#' @param bagging_size The number of models that share the parameter combination of 'context_length' and 'loss_function'.
#' @param learn_rate Learning rate
#' @param loss_function Any of MASE, MAPE, sMAPE. Defaults to MASE when loss_function = NULL
#' @param scale Scales numeric data by id group using mean = 0, standard deviation = 1 transformation.
#'
#'
tune_nbeats <- function(id, freq, recipe, splits, horizon, length, cv_slice_limit, assess = "12 weeks",
skip = "4 weeks", initial = "12 months", epochs = NULL, lookback = NULL, bagging_size = NULL,
learn_rate = NULL, loss_function = NULL, scale = NULL, batch_size = NULL) {
  # Random hyperparameter grid with `length` candidate rows. Any argument the
  # caller supplies overrides the random draw for that parameter.
  nbeats_grid <- data.frame(
    epochs = if (is.null(epochs)) sample(100, size = length, replace = TRUE) else epochs,
    lookback_length = if (is.null(lookback)) sample(1:7, size = length, replace = TRUE) else lookback,
    batch_size = if (is.null(batch_size)) round(runif(length, min = 32, max = 1024), 0) else batch_size,
    learn_rate = if (is.null(learn_rate)) runif(length, min = 1e-5, max = 1e-2) else learn_rate,
    loss_function = if (is.null(loss_function)) "MASE" else loss_function,
    scale = if (is.null(scale)) sample(c(TRUE, FALSE), length, replace = TRUE) else scale,
    bagging_size = if (is.null(bagging_size)) sample(1:10, size = length, replace = TRUE) else bagging_size
  )
  nbeats_grid <- distinct(nbeats_grid)
  # Time-series cross-validation resamples built from the training portion of `splits`.
  resamples_tscv <- time_series_cv(
    data = training(splits),
    cumulative = TRUE,
    initial = initial,
    assess = assess,
    skip = skip,
    slice_limit = cv_slice_limit
  )
  nbeats_list <- list()
  cv_list <- list()
  wflw_return <- list()
  # Build a timestamped csv path inside `dir`, creating the directory if needed.
  # Spaces/colons from the timestamp are replaced so the name is filesystem-safe.
  make_log_path <- function(dir, file_prefix) {
    if (!dir.exists(dir)) {
      dir.create(dir)
    }
    stamp <- timestamp(prefix = "", suffix = "", quiet = TRUE)
    path <- paste0(dir, "/", file_prefix, "_", stamp, ".csv")
    path <- gsub(" ", "_", path)
    gsub(":", "_", path)
  }
  # Append `new_rows` to the csv at `path`, creating the file on first write.
  # Only the new rows are added each call, so the log never duplicates data.
  append_log <- function(new_rows, path) {
    if (!file.exists(path)) {
      message(str_glue("Writing data to {path}"))
      readr::write_csv(new_rows, path)
    } else {
      message(str_glue("Reading in old data and writing new to {path}"))
      old_file <- readr::read_csv(path)
      readr::write_csv(bind_rows(new_rows, old_file), path)
    }
  }
  # Grid log tracks which parameter values cause N-BEATS to crash: each row is
  # written *before* fitting, so after a crash the last logged row is the culprit.
  path_to_grid_file <- make_log_path("grid_log", "grid_log")
  # Accuracy log accumulates per-fold CV accuracy after each parameter set.
  path_to_file <- make_log_path("accuracy_log_nbeats", "log_accuracy")
  for (i in seq_len(nrow(nbeats_grid))) {
    append_log(nbeats_grid[i, , drop = FALSE], path_to_grid_file)
    message("")
    message(str_glue("Parameter set number {i} of {nrow(nbeats_grid)}"))
    message(str_glue("Epochs: {nbeats_grid$epochs[i]}"))
    message(str_glue("Lookback length: {nbeats_grid$lookback_length[i]}"))
    message(str_glue("Batch size: {nbeats_grid$batch_size[i]}"))
    message(str_glue("Learning rate: {nbeats_grid$learn_rate[i]}"))
    message(str_glue("Scale: {nbeats_grid$scale[i]}"))
    message(str_glue("Bagging: {nbeats_grid$bagging_size[i]}"))
    model_spec <- nbeats(
      id = id,
      freq = freq,
      prediction_length = horizon,
      epochs = nbeats_grid$epochs[i],
      # `:` binds tighter than `*`, so this expands to horizon, 2*horizon, ...
      lookback_length = 1:nbeats_grid$lookback_length[i] * horizon,
      batch_size = nbeats_grid$batch_size[i],
      learn_rate = nbeats_grid$learn_rate[i],
      loss_function = nbeats_grid$loss_function[i],
      scale = nbeats_grid$scale[i],
      bagging_size = nbeats_grid$bagging_size[i]
    ) %>%
      set_engine("gluonts_nbeats_ensemble")
    # Fit and score the candidate on every CV slice.
    for (j in seq_len(cv_slice_limit)) {
      wflw_fit_nbeats <- workflow() %>%
        add_model(model_spec) %>%
        add_recipe(recipe) %>%
        fit(training(resamples_tscv$splits[[j]]))
      cv_accuracy <- wflw_fit_nbeats %>%
        modeltime_table() %>%
        modeltime_accuracy(testing(resamples_tscv$splits[[j]])) %>%
        add_column(fold = paste0("fold_", j))
      cv_accuracy_summary <- cv_accuracy %>%
        group_by(.model_id, .model_desc, fold) %>%
        summarise(mae = mean(mae, na.rm = TRUE),
                  mape = mean(mape, na.rm = TRUE),
                  mase = mean(mase, na.rm = TRUE),
                  smape = mean(smape, na.rm = TRUE),
                  rmse = mean(rmse, na.rm = TRUE),
                  rsq = mean(rsq, na.rm = TRUE),
                  .groups = "drop")
      # Tag each accuracy row with its parameter-set index so the winning row
      # can be mapped back to the matching workflow in `wflw_return`.
      cv_list[[j]] <- cv_accuracy_summary %>%
        bind_cols(nbeats_grid[i, ]) %>%
        add_column(param_set = i)
    }
    nbeats_list[[i]] <- bind_rows(cv_list)
    # NOTE(review): only the fit from the *last* CV slice is kept per parameter
    # set (matches the original behavior) — confirm this is the intended model.
    wflw_return[[i]] <- wflw_fit_nbeats
    append_log(nbeats_list[[i]], path_to_file)
  }
  nbeats_list <- bind_rows(nbeats_list)
  # `nbeats_list` has one row per (parameter set, fold) while `wflw_return` has
  # one entry per parameter set, so index via the param_set tag of the row with
  # the lowest rmse. `which.min` also breaks ties and ignores NA.
  best_model_index <- nbeats_list$param_set[which.min(nbeats_list$rmse)]
  best_model <- wflw_return[[best_model_index]]
  # Return the full accuracy table and the best workflow.
  list(
    nbeats_list = nbeats_list,
    best_model = best_model
  )
}
# NOTE(review): the two lines below were website-embed boilerplate left over
# from a documentation scrape ("Add the following code to your website...");
# they are not R code and broke parsing, so they are commented out.