#' Create time-contiguous validation datasets for model evaluation
#'
#' Flexibly create blocks of time-contiguous validation datasets to assess the forecast accuracy
#' of trained models at various times in the past. These validation datasets are similar to
#' the outer loop of a nested cross-validation model training setup.
#'
#' @param lagged_df An object of class 'lagged_df' or 'grouped_lagged_df' from \code{\link{create_lagged_df}}.
#' @param window_length An integer that defines the length of the contiguous validation dataset in dataset rows/dates.
#' If dates were given in \code{create_lagged_df()}, the validation window is 'window_length' * 'date frequency' in calendar time.
#' Setting \code{window_length = 0} trains the model on (a) the entire dataset or (b) between a single \code{window_start} and
#' \code{window_stop} value. Specifying multiple \code{window_start} and \code{window_stop} values with vectors of
#' length > 1 overrides \code{window_length}.
#' @param window_start Optional. A row index or date identifying the row/date to start creating contiguous validation datasets. A
#' vector of start rows/dates can be supplied for greater control. The length and order of \code{window_start} should match \code{window_stop}.
#' If \code{length(window_start) > 1}, \code{window_length}, \code{skip}, and \code{include_partial_window} are ignored.
#' @param window_stop Optional. An index or date identifying the row/date to stop creating contiguous validation datasets. A
#' vector of stop rows/dates can be supplied for greater control. The length and order of \code{window_stop} should match \code{window_start}.
#' If \code{length(window_stop) > 1}, \code{window_length}, \code{skip}, and \code{include_partial_window} are ignored.
#' @param skip An integer giving a fixed number of dataset rows/dates to skip between validation datasets. If dates were given
#' in \code{create_lagged_df()}, the time between validation windows is \code{skip} * 'date frequency'.
#' @param include_partial_window Boolean. If \code{TRUE}, keep validation datasets that are shorter than \code{window_length}.
#' @return An S3 object of class 'windows': A data.frame giving the indices for the validation datasets.
#'
#' @section Methods and related functions:
#'
#' The output of \code{create_windows()} is passed into
#'
#' \itemize{
#' \item \code{\link{train_model}}
#' }
#'
#' and has the following generic S3 methods
#'
#' \itemize{
#' \item \code{\link[=plot.windows]{plot}}
#' }
#' @example /R/examples/example_create_windows.R
#' @importFrom rlang !!
#' @importFrom rlang !!!
#' @export
create_windows <- function(lagged_df, window_length = 12L,
                           window_start = NULL, window_stop = NULL, skip = 0,
                           include_partial_window = TRUE) {

  #----------------------------------------------------------------------------
  # Input validation.
  if (!methods::is(lagged_df, "lagged_df")) {
    stop("This function takes an object of class 'lagged_df' as input. Run create_lagged_df() first.")
  }

  # Negative or missing window lengths cannot describe a validation window, so
  # they are rejected up front instead of failing later inside seq().
  if (length(window_length) != 1 || !methods::is(window_length, "numeric") ||
      is.na(window_length) || window_length < 0) {
    stop("The 'window_length' argument needs to be a single positive integer.")
  }
  #----------------------------------------------------------------------------
  # Metadata stored as attributes on the input object.
  data_attributes <- attributes(lagged_df)

  outcome_col <- data_attributes$outcome_col
  outcome_name <- data_attributes$outcome_name
  date_indices <- data_attributes$date_indices  # NULL selects the row-index code path.
  frequency <- data_attributes$frequency
  data_start <- data_attributes$data_start
  data_stop <- data_attributes$data_stop
  #----------------------------------------------------------------------------
  # Missing window boundaries default to the dataset boundaries.
  if (is.null(window_start)) {
    window_start <- data_start
  }
  if (is.null(window_stop)) {
    window_stop <- data_stop
  }

  # With a date index, both boundaries need to be 'Date' or both need to be 'POSIXt'.
  if (!is.null(date_indices)) {
    both_date <- methods::is(window_start, "Date") && methods::is(window_stop, "Date")
    both_posix <- methods::is(window_start, "POSIXt") && methods::is(window_stop, "POSIXt")

    if (!xor(both_date, both_posix)) {
      stop("Dates were provided with the input dataset created with 'create_lagged_df()'; Enter a vector of window start dates of class 'Date' or 'POSIXt'.")
    }
  }

  single_window_range <- length(window_start) == 1 && length(window_stop) == 1

  if (single_window_range) {  # A single start and stop row/date.

    if (!(window_start >= data_start)) {
      stop(paste0("The start of all validation windows needs to occur on or after row/date ", data_start, " which is the beginning of the dataset."))
    }

    if (!(window_stop <= data_stop)) {
      stop(paste0("The end of all validation windows needs to occur on or before row/date ", data_stop, " which is the end of the dataset."))
    }

    if (is.null(date_indices) && window_length > (as.numeric(window_stop - window_start) + 1)) {
      stop(paste0("The window length is wider than 'window_stop - window_start'. Set 'window_length = 0' to get 1 validation window for this period."))
    }

  } else {  # A vector of multiple start and stop rows/dates.

    if (length(window_start) != length(window_stop)) {
      stop(paste0("length(window_start) != length(window_stop); each validation window needs a start and stop date."))
    }

    if (!all(window_stop >= window_start)) {
      stop(paste0("'window_stop' needs to be greater than 'window_start' for all validation windows"))
    }
  }
  #----------------------------------------------------------------------------
  # Build the data.frame of validation windows: one row per window.
  if (!single_window_range) {

    # User-defined windows are used as given; 'window_length', 'skip' and
    # 'include_partial_window' play no role here.
    window_matrices <- data.frame("start" = window_start, "stop" = window_stop, "window_length" = "custom")

  } else if (window_length == 0) {

    # One window spanning the requested range; no nested cross-validation windows needed.
    window_matrices <- data.frame("start" = window_start, "stop" = window_stop, "window_length" = window_length)

  } else if (is.null(date_indices)) {  # Row-index based windows.

    # The last row at which a full-length validation window can still begin.
    max_train_index <- window_stop - window_length + 1

    # Row i of this matrix is the full-length window that starts at row index i.
    start_index <- 1:max_train_index
    stop_index <- start_index + window_length - 1
    window_matrix <- cbind("start" = start_index, "stop" = stop_index, "window_length" = window_length)

    # Keep every (window_length + skip)-th window beginning at 'window_start'.
    window_matrix <- window_matrix[seq(window_start, nrow(window_matrix), window_length + skip), , drop = FALSE]

    # The partial window is an additional row that represents the final, partial validation window.
    if (isTRUE(include_partial_window)) {

      window_matrix_partial <- window_matrix[nrow(window_matrix), , drop = FALSE]
      window_matrix_partial[, "start"] <- window_matrix_partial[, "stop"] + 1 + skip
      window_matrix_partial[, "stop"] <- window_stop

      # Drop the partial window when 'skip' pushes its start past 'window_stop'.
      if (window_matrix_partial[, "start"] <= window_matrix_partial[, "stop"]) {
        window_matrix <- rbind(window_matrix, window_matrix_partial)
        rownames(window_matrix) <- NULL
      }
    }

    window_matrices <- as.data.frame(window_matrix)

  } else {  # Date based windows.

    all_dates <- seq(window_start, window_stop, frequency)

    # Positions in 'all_dates' at which each window begins.
    start_positions <- seq(1, length(all_dates), by = window_length + skip)

    start_dates <- all_dates[start_positions]
    # Indexing past the end of 'all_dates' yields NA, which marks a partial final window.
    stop_dates <- all_dates[start_positions + window_length - 1]

    window_matrices <- data.frame("start" = start_dates, "stop" = stop_dates, "window_length" = window_length)

    if (isTRUE(include_partial_window)) {
      # The partial window ends at the requested 'window_stop', mirroring the
      # row-index code path, so it never extends beyond the user's stop date.
      window_matrices$stop[is.na(window_matrices$stop)] <- window_stop
    } else {
      window_matrices <- window_matrices[stats::complete.cases(window_matrices), ]
    }
  }
  #----------------------------------------------------------------------------
  # Attach metadata alongside the data.frame's own attributes.
  attr(window_matrices, "skip") <- skip
  attr(window_matrices, "outcome_col") <- outcome_col
  attr(window_matrices, "outcome_name") <- outcome_name

  class(window_matrices) <- c("windows", class(window_matrices))

  return(window_matrices)
}
#------------------------------------------------------------------------------
#' Plot validation datasets
#'
#' Plot validation datasets across time.
#'
#' @param x An object of class 'windows' from \code{create_windows()}.
#' @param lagged_df An object of class 'lagged_df' from \code{create_lagged_df()}.
#' @param show_labels Boolean. If \code{TRUE}, show validation dataset IDs on the plot.
#' @param group_filter Optional. A string for filtering plot results for grouped time series (e.g., \code{"group_col_1 == 'A'"}).
#' This string is passed to \code{dplyr::filter()} internally.
#' @param ... Not used.
#' @return A plot of the outer-loop nested cross-validation windows of class 'ggplot'.
#' @example /R/examples/example_plot_windows.R
#' @export
plot.windows <- function(x, lagged_df, show_labels = TRUE, group_filter = NULL, ...) { # nocov start

  #----------------------------------------------------------------------------
  if (!methods::is(lagged_df, "lagged_df")) {
    stop("The 'lagged_df' argument takes an object of class 'lagged_df' as input. Run create_lagged_df() first.")
  }

  windows <- x
  rm(x)
  #----------------------------------------------------------------------------
  # Metadata stored as attributes on the 'lagged_df' object.
  method <- attributes(lagged_df)$method
  outcome <- attributes(lagged_df)$outcome
  outcome_col <- attributes(lagged_df)$outcome_col
  outcome_name <- attributes(lagged_df)$outcome_name
  outcome_levels <- attributes(lagged_df)$outcome_levels  # NULL is treated as a numeric outcome below.
  row_indices <- attributes(lagged_df)$row_indices
  date_indices <- attributes(lagged_df)$date_indices
  frequency <- attributes(lagged_df)$frequency
  groups <- attributes(lagged_df)$groups
  #----------------------------------------------------------------------------
  # If there are multiple horizons in the lagged_df, select the first dataset and columns for plotting.
  if (method == "direct") {

    data_plot <- dplyr::select(lagged_df[[1]], !!outcome_name, !!groups)

  } else if (method == "multi_output") {

    data_plot <- dplyr::bind_cols(outcome, dplyr::select(lagged_df[[1]], !!groups))
  }

  if (is.null(date_indices)) {  # Index-based x-axis in plot.
    data_plot$index <- row_indices
  } else {  # Date-based x-axis in plot.
    data_plot$index <- date_indices
  }

  if (!is.null(group_filter)) {
    # NOTE(review): 'group_filter' is parsed and evaluated as R code inside dplyr::filter(),
    # as documented for this argument; it should only ever receive trusted input.
    data_plot <- dplyr::filter(data_plot, eval(parse(text = group_filter)))
  }
  #----------------------------------------------------------------------------
  # Create different line segments in ggplot with `color = ggplot_color_group`.
  data_plot$ggplot_color_group <- apply(data_plot[, groups, drop = FALSE], 1, function(group_values) {
    paste(group_values, collapse = "-")
  })

  # Window IDs used as plot labels; seq_len() stays well-defined for a zero-row input.
  windows$window <- seq_len(nrow(windows))

  # Find the x and y coordinates for the plot labels. The labels for factor outcomes are calculated
  # in the ggplot2 code.
  if (isTRUE(show_labels)) {

    data_plot_group <- windows %>%
      dplyr::group_by(.data$window_length, .data$window) %>%
      dplyr::summarise("index" = .data$start + ((.data$stop - .data$start) / 2))  # Window midpoint for plot label.

    if (is.null(outcome_levels)) {  # Numeric outcome.
      # Labels sit halfway between the smallest and largest value in the first plot column.
      outcome_values <- data_plot[, 1]
      data_plot_group$label_height <- (max(outcome_values, na.rm = TRUE) + min(outcome_values, na.rm = TRUE)) / 2
    }

    data_plot_group <- data_plot_group[!is.na(data_plot_group$window), ]
  }
  #----------------------------------------------------------------------------
  # Fill in date gaps with NAs so ggplot doesn't connect line segments where there were no entries recorded.
  if (!is.null(groups)) {  # Grouped time series.

    data_plot_template <- expand.grid("index" = seq(min(date_indices, na.rm = TRUE), max(date_indices, na.rm = TRUE), by = frequency),
                                      "ggplot_color_group" = unique(data_plot$ggplot_color_group),
                                      stringsAsFactors = FALSE)

    data_plot <- dplyr::left_join(data_plot_template, data_plot, by = c("index", "ggplot_color_group"))

    data_plot$ggplot_color_group <- factor(data_plot$ggplot_color_group, levels = unique(data_plot$ggplot_color_group), ordered = TRUE)

    # Create a dataset of points for those instances where the outcomes are NA before and after a given instance.
    # Points are needed because ggplot will not plot a 1-instance geom_line().
    data_plot_point <- data_plot %>%
      dplyr::group_by(.data$ggplot_color_group) %>%
      dplyr::mutate("lag" = dplyr::lag(!!rlang::sym(outcome_name), 1),
                    "lead" = dplyr::lead(!!rlang::sym(outcome_name), 1)) %>%
      dplyr::filter(is.na(.data$lag) & is.na(.data$lead))

    data_plot_point$ggplot_color_group <- factor(data_plot_point$ggplot_color_group, levels = levels(data_plot$ggplot_color_group), ordered = TRUE)

  } else {  # Time series without groups.

    data_plot$ggplot_color_group <- factor(data_plot$ggplot_color_group, levels = unique(data_plot$ggplot_color_group), ordered = TRUE)
  }
  #----------------------------------------------------------------------------
  p <- ggplot()

  # Shaded rectangles mark the validation windows.
  p <- p + geom_rect(data = windows, aes(xmin = .data$start, xmax = .data$stop,
                                         ymin = -Inf, ymax = Inf), fill = "grey85", show.legend = FALSE)

  # The outcome column is looked up by name with the `.data` pronoun so that
  # non-syntactic column names work and no text is parsed as code.
  if (is.null(outcome_levels)) {  # Numeric outcome.

    p <- p + geom_line(data = data_plot, aes(x = .data$index, y = .data[[outcome_name]],
                                             color = .data$ggplot_color_group), size = 1.05)

  } else {  # Factor outcome.

    data_plot <- data_plot[!is.na(data_plot[[outcome_name]]), ]  # Removes NA from the legend from geom_tile().
    data_plot[[outcome_name]] <- factor(data_plot[[outcome_name]], levels = outcome_levels, ordered = TRUE)

    p <- p + geom_tile(data = data_plot, aes(x = .data$index, y = .data$ggplot_color_group,
                                             fill = .data[[outcome_name]]))
  }

  if (!is.null(groups) && is.null(outcome_levels)) {  # Numeric outcome with groups.
    if (nrow(data_plot_point) >= 1) {
      p <- p + geom_point(data = data_plot_point, aes(x = .data$index, y = .data[[outcome_name]],
                                                      color = .data$ggplot_color_group), show.legend = FALSE)
    }
  }

  if (isTRUE(show_labels)) {

    if (!is.null(outcome_levels)) {  # Factor outcome.
      # Labels are drawn on the first level of the color-group axis.
      data_plot_group$label_height <- ordered(levels(ordered(data_plot$ggplot_color_group))[1])
    }

    p <- p + geom_label(data = data_plot_group, aes(x = .data$index, y = .data$label_height,
                                                    label = .data$window), color = "black", size = 4)
  }

  p <- p + theme_bw()

  if (is.null(groups) && is.null(outcome_levels)) {  # Numeric outcome without groups.
    p <- p + theme(legend.position = "none")
  }

  if (is.null(outcome_levels)) {  # Numeric outcome.

    p <- p + xlab("Dataset index") + ylab("Outcome") + labs(color = "Groups") + ggtitle("Validation Windows")

  } else {  # Factor outcome.

    if (length(levels(data_plot$ggplot_color_group)) == 1) {
      p <- p + xlab("Dataset index") + ylab("Outcome") + labs(fill = "Outcome") + ggtitle("Validation Windows")
    } else {
      p <- p + xlab("Dataset index") + ylab("Groups") + labs(fill = "Outcome") + ggtitle("Validation Windows")
    }
  }

  return(suppressWarnings(p))
} # nocov end
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.