#' Generic cross-validation function for time series
#'
#' Generic cross-validation for univariate and multivariate time series
#'
#' @param y response time series; a vector or a matrix
#' @param x input covariates' matrix (optional) for ML models
#' @param fit_func a function for fitting the model (if validation of ML model)
#' @param predict_func a function for predicting values from the model (if validation of ML model)
#' @param fcast_func time series forecasting function (e.g forecast::thetaf)
#' @param fit_params a list; additional (model-specific) parameters to be passed
#' to \code{fit_func}
#' @param initial_window an integer; the initial number of consecutive values in each training set sample
#' @param horizon an integer; the number of consecutive values in test set sample
#' @param fixed_window a boolean; if FALSE, all training samples start at 1
#' @param type_forecast a string; "mean" for mean forecasts, "quantiles" for prediction-interval quantiles (experimental)
#' @param level a numeric vector; confidence levels for prediction intervals.
#' @param seed random seed for reproducibility of results (currently not used by the function body)
#' @param eval_metric a function of \code{(predicted, observed)} measuring the test errors; if not provided,
#' ME, RMSE, MAE, MPE and MAPE are computed
#' @param cl an integer; the number of cluster nodes (workers) for parallel execution
#' @param errorhandling specifies how a task evaluation error should be handled.
#' If value is "stop", then execution will be stopped if an error occurs. If value
#' is "remove", the result for that task will not be returned. If value is "pass",
#' then the error object generated by task evaluation will be included with the
#' rest of the results. The default value is "stop".
#' @param packages character vector of packages that the tasks depend on
#' @param verbose logical flag enabling verbose messages. This can be very useful for
#' troubleshooting.
#' @param show_progress show evolution of the algorithm
#' @param ... additional parameters
#'
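#' @return the per-slice results of \code{eval_metric}, combined row-wise across the time slices
#' (with the default metric, the columns are ME, RMSE, MAE, MPE and MAPE)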
#' @export
#'
#' @examples
#'
#'
#' require(forecast)
#' data("AirPassengers")
#'
#' # Example 1 -----
#'
#' res <- crossval_ts(y=AirPassengers, initial_window = 10,
#' horizon = 3, fcast_func = forecast::thetaf)
#' print(colMeans(res))
#'
#'
#' # Example 2 -----
#'
#' fcast_func <- function (y, h, ...)
#' {
#' forecast::forecast(forecast::auto.arima(y, ...),
#' h=h, ...)
#' }
#'
#' res <- crossval_ts(y=AirPassengers, initial_window = 10, horizon = 3,
#' fcast_func = fcast_func)
#' print(colMeans(res))
#'
#'
#' # Example 3 -----
#'
#' fcast_func <- function (y, h, ...)
#' {
#' forecast::forecast(forecast::ets(y, ...),
#' h=h, ...)
#' }
#'
#' res <- crossval_ts(y=AirPassengers,
#' initial_window = 10, horizon = 3, fcast_func = fcast_func)
#' print(colMeans(res))
#'
#'
#' # Example 4 -----
#'
#' xreg <- cbind(1, 1:length(AirPassengers))
#' res <- crossval_ts(y=AirPassengers, x=xreg, fit_func = crossval::fit_lm,
#' predict_func = crossval::predict_lm,
#' initial_window = 10,
#' horizon = 3,
#' fixed_window = TRUE)
#' print(colMeans(res))
#'
#'
#' # Example 5 -----
#'
#' res <- crossval_ts(y=AirPassengers, fcast_func = forecast::thetaf,
#' initial_window = 10,
#' horizon = 3,
#' fixed_window = TRUE, type_forecast="quantiles")
#' print(colMeans(res))
#'
#'
#' # Example 6 -----
#'
#' xreg <- cbind(1, 1:length(AirPassengers))
#' res <- crossval_ts(y=AirPassengers, x=xreg, fit_func = crossval::fit_lm,
#' predict_func = crossval::predict_lm,
#' initial_window = 10,
#' horizon = 3,
#' fixed_window = TRUE, type_forecast="quantiles")
#' print(colMeans(res))
#'
#'
#' # Example 7 -----
#'
#' x <- ts(matrix(rnorm(50), nrow = 25))
#'
#' fcast_func <- function(y, h = 5, type_forecast=c("mean", "median"))
#' {
#' type_forecast <- match.arg(type_forecast)
#'
#' if (type_forecast == "mean")
#' {
#' means <- colMeans(y)
#' return(list(mean = t(replicate(n = h, expr = means))))
#' } else {
#' medians <- apply(y, 2, median)
#' return(list(mean = t(replicate(n = h, expr = medians))))
#' }
#'
#' }
#'
#' print(fcast_func(x))
#'
#' res <- crossval::crossval_ts(y = x, fcast_func = fcast_func, fit_params = list(type_forecast = "median"))
#' colMeans(res)
#'
#' res <- crossval::crossval_ts(y = x, fcast_func = fcast_func, fit_params = list(type_forecast = "mean"))
#' colMeans(res)
#'
#'
crossval_ts <- function(y,
                        x = NULL,
                        fit_func = crossval::fit_lm,
                        predict_func = crossval::predict_lm,
                        fcast_func = NULL,
                        fit_params = NULL,
                        # parameters of funcs
                        initial_window = 5,
                        horizon = 3,
                        fixed_window = TRUE,
                        type_forecast = c("mean", "quantiles"), # "quantiles" is experimental
                        level = c(80, 95),
                        seed = 123,
                        eval_metric = NULL,
                        cl = NULL,
                        errorhandling = c('stop', 'remove', 'pass'),
                        packages = c("stats", "Rcpp"),
                        verbose = FALSE,
                        show_progress = TRUE,
                        ...) {
  # Rolling-origin cross-validation: for every train/test slice produced by
  # crossval::create_time_slices(), either call `fcast_func` (forecasting
  # interface) or `fit_func` + `predict_func` (ML interface), then score the
  # predictions with `eval_metric`. Per-slice scores are combined with rbind().

  type_forecast <- match.arg(type_forecast)
  errorhandling <- match.arg(errorhandling)

  # a non-NULL ncol() means `y` is a matrix-like (multivariate) series
  is_multivariate <- !is.null(ncol(y))
  n_y <- if (is_multivariate) nrow(y) else length(y)

  time_slices <-
    crossval::create_time_slices(
      y,
      initial_window = initial_window,
      horizon = horizon,
      fixed_window = fixed_window
    )
  n_slices <- length(time_slices$train)

  if (!is.null(x)) {
    # regression, ML model: covariates must be aligned with the response
    stopifnot(nrow(x) == n_y)
  }

  # Quantile levels implied by `level`, e.g. level = c(80, 95) gives
  # 0.025, 0.1, 0.5, 0.9, 0.975. Computed once, outside the loop, so that the
  # NA fallback can always rely on `nqs`.
  upper_qs <- 100 - (100 - level) / 2
  lower_qs <- rev(100 - upper_qs)
  qlist <- c(lower_qs, 50, upper_qs) / 100
  nqs <- length(qlist)

  # performance measures
  if (is.null(eval_metric)) {
    eval_metric <- function(predicted, observed) {
      error <- observed - predicted
      pe <- predicted / observed - 1
      res <- c(
        mean(error, na.rm = FALSE),
        sqrt(mean(error ^ 2, na.rm = FALSE)),
        mean(abs(error), na.rm = FALSE),
        mean(pe, na.rm = FALSE),
        mean(abs(pe), na.rm = FALSE)
      )
      names(res) <- c("ME", "RMSE", "MAE", "MPE", "MAPE")
      return(res)
    }
    eval_metric <- compiler::cmpfun(eval_metric)
  }

  # parallel backend; the cluster is released even if an error occurs
  if (!is.null(cl)) {
    cl_SOCK <- parallel::makeCluster(cl, type = "SOCK")
    on.exit(snow::stopCluster(cl_SOCK), add = TRUE)
    doSNOW::registerDoSNOW(cl_SOCK)
    `%op%` <- foreach::`%dopar%`
  } else {
    `%op%` <- foreach::`%do%`
  }

  # progress bar: only created when requested, always closed on exit
  pb <- NULL
  opts <- list()
  if (show_progress) {
    pb <- utils::txtProgressBar(min = 0,
                                max = n_slices,
                                style = 3)
    on.exit(close(pb), add = TRUE)
    progress <- function(n) {
      utils::setTxtProgressBar(pb, n)
    }
    # used by the doSNOW backend to report progress from the master process
    opts <- list(progress = progress)
  }
  # with the sequential backend the bar has to be advanced from the loop body
  update_in_loop <- show_progress && is.null(cl)

  if (!is.null(fcast_func)) {
    # 1 - interface for forecasting functions --------------------------------------------------

    if (is_multivariate && type_forecast == "quantiles") {
      warning("type_forecast = \"quantiles\" is not available for multivariate series; NA predictions are used",
              call. = FALSE)
    }

    i <- NULL
    res <- foreach::foreach(
      i = seq_len(n_slices),
      .packages = packages,
      .combine = rbind,
      .errorhandling = errorhandling,
      .options.snow = opts,
      .verbose = verbose
    ) %op% {
      train_index <- time_slices$train[[i]]
      test_index <- time_slices$test[[i]]

      if (!is_multivariate) {
        # 1 - 1 interface for forecasting functions: univariate --------------------------------------------------
        if (type_forecast == "mean") {
          preds <- try(do.call(
            what = fcast_func,
            args = c(list(y = y[train_index],
                          h = horizon), fit_params)
          )$mean, silent = FALSE)
        } else {
          # type_forecast == "quantiles": the forecast object must be computed
          # first, then its lower bounds (reversed), mean and upper bounds are
          # bound column-wise, e.g.
          #      q0.025     q0.1     q0.5     q0.9   q0.975
          # 11 529.2090 560.6045 619.9121 679.2196 710.6152
          fcast_obj <- try(do.call(
            what = fcast_func,
            args = c(list(y = y[train_index],
                          h = horizon,
                          level = level), fit_params)
          ), silent = FALSE)
          if (inherits(fcast_obj, "try-error")) {
            preds <- fcast_obj
          } else {
            preds <- try(cbind(
              fcast_obj$lower[, ncol(fcast_obj$lower):1, drop = FALSE],
              fcast_obj$mean,
              fcast_obj$upper
            ), silent = TRUE)
            if (!inherits(preds, "try-error") && isTRUE(ncol(preds) == nqs)) {
              colnames(preds) <- paste0("q", qlist)
            }
          }
        }

        # failed forecast: score NA predictions of the expected shape
        if (inherits(preds, "try-error") || is.null(preds)) {
          preds <- if (type_forecast == "mean") {
            rep(NA_real_, length(test_index))
          } else {
            matrix(NA_real_, nrow = length(test_index), ncol = nqs)
          }
        }

        # measure the error
        error_measure <- eval_metric(preds, y[test_index])
      } else {
        # 1 - 2 interface for forecasting functions: multivariate --------------------------------------------------
        preds <- if (type_forecast == "mean") {
          try(do.call(
            what = fcast_func,
            args = c(list(y = y[train_index, , drop = FALSE],
                          h = horizon), fit_params)
          )$mean, silent = FALSE)
        } else {
          NULL # quantiles not available yet
        }

        if (inherits(preds, "try-error") || is.null(preds)) {
          preds <- matrix(NA_real_,
                          nrow = length(test_index),
                          ncol = ncol(y))
        }

        # measure the error
        error_measure <-
          eval_metric(preds, y[test_index, , drop = FALSE])
      }

      if (update_in_loop) {
        utils::setTxtProgressBar(pb, i)
      }

      error_measure
    }
  } else {
    # if fcast_func is NULL, ML models are used
    stopifnot(!is.null(fit_func))
    stopifnot(!is.null(predict_func))

    # 2 - interface for ml functions --------------------------------------------------
    i <- NULL
    res <- foreach::foreach(
      i = seq_len(n_slices),
      .packages = packages,
      .combine = rbind,
      .errorhandling = errorhandling,
      .options.snow = opts,
      .verbose = verbose
    ) %op% {
      train_index <- time_slices$train[[i]]
      test_index <- time_slices$test[[i]]

      if (!is_multivariate) {
        # 2 - 1 interface for ml function: univariate --------------------------------------------------
        # drop = FALSE keeps `x` two-dimensional for single-column covariates
        # and single-row test sets
        fit_obj <-
          do.call(what = fit_func,
                  args = c(list(x = x[train_index, , drop = FALSE],
                                y = y[train_index]),
                           fit_params))

        # predict: try the `newdata` argument name first, then `newx`
        x_test <- x[test_index, , drop = FALSE]
        preds <- try(predict_func(fit_obj, newdata = x_test),
                     silent = TRUE)
        if (inherits(preds, "try-error")) {
          preds <- try(predict_func(fit_obj, newx = x_test),
                       silent = FALSE)
          if (inherits(preds, "try-error")) {
            preds <- rep(NA_real_, length(test_index))
          }
        }

        # measure the error
        error_measure <- eval_metric(preds, y[test_index])
      } else {
        # 2 - 2 interface for ml function: multivariate (ko so far) --------------------------------------------------
        stop("Not implemented")
      }

      if (update_in_loop) {
        utils::setTxtProgressBar(pb, i)
      }

      error_measure
    }
  }

  return(res)
}
# Byte-compile crossval_ts and keep the compiled version; without the
# assignment the value returned by cmpfun() is simply discarded.
crossval_ts <- compiler::cmpfun(crossval_ts)
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.