#' Return model hyperparameters across validation datasets
#'
#' The purpose of this function is to support investigation into
#' the stability of hyperparameters in the nested cross-validation and across
#' forecast horizons.
#'
#' @param forecast_model An object of class 'forecast_model' from \code{\link{train_model}}.
#' @param hyper_function A user-defined function for retrieving model hyperparameters. It is
#' called once for every trained model in \code{forecast_model}, with that model as its only
#' argument, and should return a 1-row data.frame of hyperparameter results. See the
#' example below for details.
#' @return An S3 object of class 'forecast_model_hyper': A data.frame of model-specific hyperparameters.
#' Each call to \code{hyper_function} contributes its result alongside the columns \code{model},
#' \code{horizon} (named \code{model_forecast_horizon} when the model's \code{method} attribute is
#' "direct"), \code{window_length}, \code{window_number}, \code{valid_window_start},
#' \code{valid_window_stop}, and \code{valid_window_midpoint}. The attributes \code{outcome_col}
#' and \code{outcome_name} are copied from \code{forecast_model}; the attribute \code{hyper_names}
#' holds the union of the column names returned by \code{hyper_function} across all models.
#'
#' @section Methods and related functions:
#' The output of \code{return_hyper()} has the following generic S3 methods
#'
#' \itemize{
#' \item \code{\link[=plot.forecast_model_hyper]{plot}}
#' }
#' @example /R/examples/example_return_hyper.R
#' @export
return_hyper <- function(forecast_model, hyper_function) {

  if (missing(forecast_model) || !methods::is(forecast_model, "forecast_model")) {
    stop("The 'forecast_model' argument takes an object of class 'forecast_model' as input. Run train_model() first.")
  }

  if (missing(hyper_function) || !is.function(hyper_function)) {
    stop("The 'hyper_function' argument should be a user-defined function that returns a 1-row data.frame of hyperparameter results.")
  }

  outcome_col <- attributes(forecast_model)$outcome_col
  outcome_name <- attributes(forecast_model)$outcome_name
  horizon <- attributes(forecast_model)$horizons
  method <- attributes(forecast_model)$method
  model_name <- attributes(forecast_model)$model_name  # Loop-invariant, so looked up once.

  # Seq along model forecast horizon > window_number. Each window yields a list holding its
  # output row(s) and the hyperparameter names that hyper_function returned for that model, so
  # the names can be collected afterwards without assigning into an enclosing scope.
  results_by_horizon <- lapply(seq_along(forecast_model), function(i) {

    lapply(seq_along(forecast_model[[i]]), function(j) {

      data_results <- forecast_model[[i]][[j]]

      data_hyper <- hyper_function(data_results$model)

      data_window <- data.frame("model" = model_name,
                                "horizon" = horizon[i],
                                "window_length" = data_results$window_length,
                                "window_number" = j,
                                "valid_window_start" = min(data_results$valid_indices),
                                "valid_window_stop" = max(data_results$valid_indices),
                                "valid_window_midpoint" = mean(data_results$valid_indices))

      if (method == "direct") {
        names(data_window)[names(data_window) == "horizon"] <- "model_forecast_horizon"
      }

      list("data" = cbind(data_window, data_hyper),
           "hyper_names" = names(data_hyper))
    })  # End loop across validation windows.
  })  # End loop across model forecast horizon.

  # Flatten horizon > window into a single list with one element per trained model.
  results <- unlist(results_by_horizon, recursive = FALSE)

  data_out <- dplyr::bind_rows(lapply(results, function(result) result[["data"]]))

  # Union of the hyperparameter names over every model, in first-seen order. bind_rows() keeps
  # columns that only some models returned, so the attribute has to list them all; taking the
  # names from a single model would hide the others from plot(). This is NULL when
  # 'forecast_model' holds no models, in which case the attribute assignment below is a no-op.
  hyper_names <- unique(unlist(lapply(results, function(result) result[["hyper_names"]])))

  attr(data_out, "outcome_col") <- outcome_col
  attr(data_out, "outcome_name") <- outcome_name
  attr(data_out, "hyper_names") <- hyper_names

  class(data_out) <- c("forecast_model_hyper", class(data_out))

  return(data_out)
}
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
#' Plot hyperparameters
#'
#' Plot hyperparameter stability and relationship with error metrics across validation datasets and horizons.
#'
#' @param x An object of class 'forecast_model_hyper' from \code{return_hyper()}.
#' @param data_results An object of class 'training_results' from
#' \code{predict.forecast_model()}. Its \code{method} attribute determines whether the horizon
#' column of \code{x} is named \code{model_forecast_horizon} or \code{horizon}.
#' @param data_error An object of class 'validation_error' from
#' \code{return_error()}. Its \code{error_by_window} element and \code{error_metrics} attribute
#' are used when \code{type = "error"}.
#' @param type Select plot type; 'stability' is the default. Only the first element is used and
#' it must match one of 'stability' or 'error'.
#' @param horizons Optional. A numeric vector to filter results by horizon.
#' @param windows Optional. A numeric vector to filter results by validation window number.
#' @param ... Not used.
#' @return Hyper-parameter plots of class 'ggplot'. With \code{type = "error"} and both numeric
#' and factor/character hyperparameters present, a list of two 'ggplot' objects (numeric first)
#' is returned; with no plottable hyperparameters the result is \code{NULL}.
#' @example /R/examples/example_return_hyper.R
#' @export
plot.forecast_model_hyper <- function(x, data_results, data_error,
                                      type = c("stability", "error"),
                                      horizons = NULL,
                                      windows = NULL, ...) { # nocov start

  if (missing(data_results) || !methods::is(data_results, "training_results")) {
    stop("The 'data_results' argument takes an object of class 'training_results' as input. Run predict.forecast_model() first.")
  }

  if (missing(data_error) || !methods::is(data_error, "validation_error")) {
    stop("The 'data_error' argument takes an object of class 'validation_error' as input. Run return_error() first.")
  }

  # Fail fast on an unknown plot type instead of silently returning NULL. Only the first
  # element is considered so that a longer vector keeps selecting its first entry.
  type <- match.arg(type[1], c("stability", "error"))

  data_plot <- x

  method <- attributes(data_results)$method
  hyper_names <- attributes(data_plot)$hyper_names

  # Changing 'model_forecast_horizon' to 'horizon' for standardizing plot code with multi-output models.
  # Only 'data_plot' needs the rename here: 'data_results' is not used beyond its 'method'
  # attribute, and the error data is renamed where it is extracted for the error plot.
  if (method == "direct") {
    names(data_plot)[names(data_plot) == "model_forecast_horizon"] <- "horizon"
  }

  # Split the hyperparameter names by column class. Columns are inspected one at a time with
  # '[[' so that a single hyperparameter is still treated as a column; 'data_plot[, hyper_names]'
  # would drop to a bare vector and be iterated element by element.
  hyper_names_of_class <- function(classes) {
    is_match <- vapply(hyper_names, function(hyper_name) {
      inherits(data_plot[[hyper_name]], classes)
    }, logical(1), USE.NAMES = FALSE)
    hyper_names[is_match]
  }

  hyper_num <- hyper_names_of_class(c("numeric", "double", "integer"))
  hyper_cat <- hyper_names_of_class(c("factor", "character"))

  if (is.null(horizons)) {
    horizons <- unique(data_plot$horizon)
  }

  if (is.null(windows)) {
    windows <- unique(data_plot$window_number)
  }

  data_plot <- data_plot[data_plot$horizon %in% horizons & data_plot$window_number %in% windows, ]

  # Plot group label: the last 3 characters of the "00"-prefixed horizon (i.e., zero-padded so
  # that the ordered factor sorts numerically), followed by the window length.
  data_plot$group <- paste0("00", data_plot$horizon)
  data_plot$group <- substr(data_plot$group, nchar(data_plot$group) - 2, nchar(data_plot$group))
  data_plot$group <- paste0(data_plot$group, "-", data_plot$window_length)
  data_plot$group <- ordered(data_plot$group)

  # Long format: one row per hyperparameter ('hyper') and its value ('value').
  if (length(hyper_num) > 0) {
    data_hyper_num <- tidyr::gather(data_plot, "hyper", "value",
                                    -!!names(data_plot)[!names(data_plot) %in% hyper_num])
  }

  if (length(hyper_cat) > 0) {
    data_hyper_cat <- tidyr::gather(data_plot, "hyper", "value",
                                    -!!names(data_plot)[!names(data_plot) %in% hyper_cat])
  }
  #----------------------------------------------------------------------------

  if (type == "stability") {

    p <- ggplot()

    if (length(hyper_num) > 0) {

      # Lines connect windows, so they are only drawn with more than one window.
      if (length(unique(data_hyper_num$window_number)) > 1) {
        p <- p + geom_line(data = data_hyper_num,
                           aes(x = ordered(.data$window_number),
                               y = .data$value,
                               group = .data$group), alpha = .5)
      }

      p <- p + geom_point(data = data_hyper_num,
                          aes(x = ordered(.data$window_number),
                              y = .data$value,
                              group = .data$group))
    }

    if (length(hyper_cat) > 0) {
      p <- p + geom_bar(data = data_hyper_cat,
                        aes(x = ordered(.data$window_number),
                            fill = ordered(.data$value)),
                        position = position_dodge(), alpha = .5)
    }

    p <- p + facet_grid(hyper ~ horizon, scales = "free")
    p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
    p <- p + xlab("Window number") + ylab("Hyperparameter value/count") +
      labs(color = "Horizon - Window", fill = "Hyper") + ggtitle("Hyperparameter Stability Across Validation Windows")

    return(p)
  }
  #----------------------------------------------------------------------------
  # type == "error": match.arg() above guarantees this is the only remaining plot type.

  error_metrics <- attributes(data_error)$error_metrics

  data_error_merge <- data_error$error_by_window

  names(data_error_merge)[names(data_error_merge) == "model_forecast_horizon"] <- "horizon"

  # all_of() selects strictly by the names stored in the character vector, which avoids the
  # ambiguity of passing a bare external vector to select().
  data_error_merge <- dplyr::select(data_error_merge,
                                    dplyr::all_of(c("model", "horizon", "window_number", error_metrics)))

  # The join keys need a common type; 'model' may be a factor on either side.
  data_error_merge$model <- as.character(data_error_merge$model)

  if (length(hyper_num) > 0) {

    data_hyper_num$model <- as.character(data_hyper_num$model)

    data_hyper_num <- dplyr::inner_join(data_error_merge, data_hyper_num, by = c("model", "horizon", "window_number"))

    # Long format: one row per error metric ('error_metric') and its value ('error').
    data_hyper_num <- tidyr::gather(data_hyper_num, "error_metric", "error",
                                    -!!names(data_hyper_num)[!names(data_hyper_num) %in% error_metrics])
  }

  if (length(hyper_cat) > 0) {

    data_hyper_cat$model <- as.character(data_hyper_cat$model)

    data_hyper_cat <- dplyr::inner_join(data_hyper_cat, data_error_merge, by = c("model", "horizon", "window_number"))

    data_hyper_cat <- tidyr::gather(data_hyper_cat, "error_metric", "error",
                                    -!!names(data_hyper_cat)[!names(data_hyper_cat) %in% error_metrics])
  }

  p <- ggplot()

  if (length(hyper_num) > 0) {

    if (length(unique(data_hyper_num$window_number)) > 1) {
      p <- p + geom_line(data = data_hyper_num,
                         aes(x = .data$value,
                             y = .data$error,
                             color = ordered(.data$horizon)), alpha = .5, show.legend = FALSE)
    }

    p <- p + geom_point(data = data_hyper_num,
                        aes(x = .data$value,
                            y = .data$error,
                            color = ordered(.data$horizon)), alpha = .5)

    p <- p + scale_color_viridis_d()
    p <- p + scale_fill_viridis_d()
    p <- p + facet_grid(error_metric ~ hyper, scales = "free")
    p <- p + theme_bw() + theme(panel.spacing = unit(0, "lines"))
    p <- p + xlab("Hyperparameter value") + ylab("Error metric") +
      labs(color = "Horizon") + ggtitle("Forecast Error and Hyperparameter Values")
  }  # End numeric hyperparameter plot.

  if (length(hyper_cat) > 0) {

    p_cat <- ggplot()

    p_cat <- p_cat + geom_col(data = data_hyper_cat,
                              aes(x = ordered(.data$value), y = .data$error,
                                  fill = ordered(.data$group)
                              ),
                              position = position_dodge())

    p_cat <- p_cat + scale_fill_viridis_d()
    p_cat <- p_cat + facet_grid(error_metric ~ hyper, scales = "free")
    p_cat <- p_cat + theme_bw() + theme(panel.spacing = unit(0, "lines"))
    p_cat <- p_cat + xlab("Hyperparameter value") + ylab("Error metric") +
      labs(fill = "Horizon + validation window") + ggtitle("Forecast Error and Hyperparameter Values")
  }  # End categorical hyperparameter plot.

  # Numeric and categorical hyperparameters are drawn on separate plots; return whichever exist.
  if (length(hyper_num) > 0 && length(hyper_cat) > 0) {
    return(list(p, p_cat))
  } else if (length(hyper_num) > 0) {
    return(p)
  } else if (length(hyper_cat) > 0) {
    return(p_cat)
  }
} # nocov end
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.