# R/crossval_ts.R

# Defines functions: crossval_ts

# Documented in: crossval_ts

#' Generic cross-validation function for time series
#'
#' Generic cross-validation for univariate time series. The series is split
#' into consecutive training/test windows by
#' \code{crossval::create_time_slices}. On each window, predictions for the
#' test part are obtained from the training part and compared with the
#' observed test values through \code{eval_metric}.
#'
#' Two interfaces are available:
#' \itemize{
#'   \item \emph{Forecasting interface} (used when \code{fcast_func} is not
#'   \code{NULL}): \code{fcast_func} is called with arguments \code{y} (the
#'   training values), \code{h = horizon}, \code{level} (only when
#'   \code{type_forecast = "quantiles"}) and \code{...}. Its result must have
#'   a \code{mean} element, and \code{lower} and \code{upper} elements when
#'   \code{type_forecast = "quantiles"}.
#'   \item \emph{Machine learning interface} (used when \code{fcast_func} is
#'   \code{NULL}): \code{fit_func} is called with \code{x}, \code{y} and
#'   \code{fit_params}; \code{predict_func} is then called on the fitted
#'   object with the test rows of \code{x} passed as \code{newdata} and, if
#'   that call fails, as \code{newx}. \code{x} is required and
#'   \code{type_forecast} is ignored.
#' }
#'
#' When a forecast or a prediction fails, the error is caught and the
#' predictions for that window are replaced by \code{NA}s. Errors raised by
#' \code{fit_func} are not caught; they are handled according to
#' \code{errorhandling}.
#'
#' If \code{y} has columns (multivariate series), the result of each window is
#' \code{0}: multivariate time series are not implemented yet.
#'
#' @param y response time series; a vector
#' @param x input covariates' matrix, with as many rows as there are values in
#' \code{y}. Optional with \code{fcast_func}, required otherwise
#' @param fit_func a function for fitting the model
#' @param predict_func a function for predicting values from the model
#' @param fcast_func time series forecasting function
#' @param fit_params a list; additional (model-specific) parameters to be passed
#' to \code{fit_func}
#' @param initial_window an integer; the initial number of consecutive values in each training set sample
#' @param horizon an integer; the number of consecutive values in test set sample
#' @param fixed_window a boolean; if FALSE, all training samples start at 1
#' @param type_forecast a string; "mean" for the mean forecast, "quantiles" for
#' the lower bounds, mean forecast and upper bounds of the prediction intervals
#' (columns named "q" followed by the corresponding probability). Only used with
#' \code{fcast_func}
#' @param level a numeric vector; confidence levels (in percent) for prediction
#' intervals. Only used when \code{type_forecast = "quantiles"}
#' @param seed random seed for reproducibility of results (currently not used by
#' this function)
#' @param eval_metric a function of \code{(predicted, observed)} measuring the
#' test errors; if not provided, a function returning ME, RMSE, MAE, MPE and
#' MAPE is used
#' @param cl an integer; the number of clusters for parallel execution. If
#' \code{NULL} (default), the execution is sequential
#' @param errorhandling specifies how a task evaluation error should be handled.
#' If value is "stop", then execution will be stopped if an error occurs. If value
#' is "remove", the result for that task will not be returned. If value is "pass",
#' then the error object generated by task evaluation will be included with the
#' rest of the results. The default value is "stop".
#' @param packages character vector of packages that the tasks depend on
#' @param verbose logical flag enabling verbose messages. This can be very useful for
#' troubleshooting.
#' @param show_progress show evolution of the algorithm (text progress bar)
#' @param ... additional parameters, passed to \code{fcast_func}
#'
#' @return The values returned by \code{eval_metric} on each time slice,
#' combined with \code{rbind}. With the default \code{eval_metric}, the columns
#' are ME, RMSE, MAE, MPE and MAPE.
#' @export
#'
#' @examples
#'
#'
#' require(forecast)
#' data("AirPassengers")
#'
#' # Example 1 -----
#'
#' res <- crossval_ts(y=AirPassengers, initial_window = 10, fcast_func = thetaf)
#' print(colMeans(res))
#'
#'
#' # Example 2 -----
#'
#' fcast_func <- function (y, h, ...)
#' {
#'       forecast::forecast(forecast::auto.arima(y, ...),
#'       h=h, ...)
#' }
#'
#' res <- crossval_ts(y=AirPassengers, initial_window = 10, fcast_func = fcast_func)
#' print(colMeans(res))
#'
#'
#' # Example 3 -----
#'
#' fcast_func <- function (y, h, ...)
#' {
#'       forecast::forecast(forecast::ets(y, ...),
#'       h=h, ...)
#' }
#'
#' res <- crossval_ts(y=AirPassengers, initial_window = 10, fcast_func = fcast_func)
#' print(colMeans(res))
#'
#'
#' # Example 4 -----
#'
#' xreg <- cbind(1, 1:length(AirPassengers))
#' res <- crossval_ts(y=AirPassengers, x=xreg, fit_func = crossval::fit_lm,
#' predict_func = crossval::predict_lm,
#' initial_window = 10,
#' horizon = 3,
#' fixed_window = TRUE)
#' print(colMeans(res))
#'
#'
#' # Example 5 -----
#'
#' res <- crossval_ts(y=AirPassengers, x=xreg, fcast_func = thetaf,
#' initial_window = 10,
#' horizon = 3,
#' fixed_window = TRUE, type_forecast="quantiles")
#' print(colMeans(res))
#'
#'
#' # Example 6 -----
#'
#' xreg <- cbind(1, 1:length(AirPassengers))
#' res <- crossval_ts(y=AirPassengers, x=xreg, fit_func = crossval::fit_lm,
#' predict_func = crossval::predict_lm,
#' initial_window = 10,
#' horizon = 3,
#' fixed_window = TRUE, type_forecast="quantiles")
#' print(colMeans(res))
#'
#'
crossval_ts <- function(y,
                        x = NULL,
                        fit_func = crossval::fit_lm,
                        predict_func = crossval::predict_lm,
                        fcast_func = NULL,
                        fit_params = NULL,
                        # parameters of funcs
                        initial_window = 5,
                        horizon = 3,
                        fixed_window = TRUE,
                        type_forecast = c("mean", "quantiles"),
                        level = c(80, 95),
                        seed = 123,
                        eval_metric = NULL,
                        cl = NULL,
                        errorhandling = c('stop', 'remove', 'pass'),
                        packages = c("stats", "Rcpp"),
                        verbose = FALSE,
                        show_progress = TRUE,
                        ...) {

  # 0 - arguments and shared objects ------------------------------------------

  type_forecast <- match.arg(type_forecast)
  errorhandling <- match.arg(errorhandling)
  use_fcast <- !is.null(fcast_func)
  run_parallel <- !is.null(cl)

  # `...` is captured once as a plain list so that the helpers below can
  # forward it to `fcast_func` through `do.call()`.
  dots <- list(...)

  if (!is.null(x)) {
    stopifnot(nrow(x) == length(y))
  } else if (!use_fcast) {
    # the ML interface subsets `x` by rows; fail early, before any cluster
    # is started, instead of failing inside the loop
    stop("'x' must be provided when 'fcast_func' is NULL", call. = FALSE)
  }

  time_slices <-
    crossval::create_time_slices(
      y,
      initial_window = initial_window,
      horizon = horizon,
      fixed_window = fixed_window
    )
  n_slices <- length(time_slices$train)

  if (is.null(eval_metric)) {
    # default metric: mean error, root mean squared error, mean absolute
    # error, mean percentage error and mean absolute percentage error
    eval_metric <- function(predicted, observed) {
      error <- observed - predicted
      pe <- predicted / observed - 1

      c(
        ME = mean(error),
        RMSE = sqrt(mean(error ^ 2)),
        MAE = mean(abs(error)),
        MPE = mean(pe),
        MAPE = mean(abs(pe))
      )
    }
  }

  # Probabilities labelling the columns in "quantiles" mode: lower bounds
  # (widest interval first), the 50% column holding the mean forecast, then
  # upper bounds. Loop-invariant, hence computed once.
  upper_qs <- 100 - (100 - level) / 2
  quantile_probs <- c(rev(100 - upper_qs), 50, upper_qs) / 100

  # 1 - helpers evaluating one time slice --------------------------------------

  # NA placeholder with the shape of a successful forecast of `n` values.
  na_forecast <- function(n) {
    if (type_forecast == "mean") {
      rep(NA_real_, n)
    } else {
      matrix(
        NA_real_,
        nrow = n,
        ncol = length(quantile_probs),
        dimnames = list(NULL, paste0("q", quantile_probs))
      )
    }
  }

  # Forecasting interface: predictions of `fcast_func` for one slice.
  forecast_slice <- function(train_index, test_index) {
    fcast_args <- c(
      list(y = y[train_index], h = horizon),
      if (type_forecast == "quantiles") list(level = level),
      dots
    )

    # Both the call and the extraction of the forecast components are
    # protected, so that any failure ends up in the NA fallback below.
    preds <- try({
      fcast <- do.call(what = fcast_func, args = fcast_args)
      if (type_forecast == "mean") {
        fcast$mean
      } else {
        lower <- fcast$lower
        bands <- cbind(lower[, rev(seq_len(ncol(lower)))],
                       fcast$mean,
                       fcast$upper)
        colnames(bands) <- paste0("q", quantile_probs)
        bands
      }
    }, silent = FALSE)

    if (inherits(preds, "try-error")) {
      preds <- na_forecast(length(test_index))
    }

    preds
  }

  # ML interface: fit on the training rows, predict on the test rows.
  ml_slice <- function(train_index, test_index) {
    fit_obj <-
      do.call(what = fit_func,
              args = c(list(x = x[train_index, ],
                            y = y[train_index]),
                       fit_params))

    x_test <- x[test_index, ]

    # prediction functions name their new-data argument either `newdata`
    # or `newx`; try the former, then the latter
    preds <- try(predict_func(fit_obj, newdata = x_test),
                 silent = TRUE)

    if (inherits(preds, "try-error")) {
      preds <- try(predict_func(fit_obj, newx = x_test),
                   silent = FALSE)
    }

    if (inherits(preds, "try-error")) {
      preds <- rep(NA_real_, length(test_index))
    }

    preds
  }

  # Error measure for the i-th time slice.
  run_slice <- function(i) {
    if (!is.null(ncol(y))) {
      # multivariate time series
      # TODO
      return(0)
    }

    train_index <- time_slices$train[[i]]
    test_index <- time_slices$test[[i]]

    preds <- if (use_fcast) {
      forecast_slice(train_index, test_index)
    } else {
      ml_slice(train_index, test_index)
    }

    eval_metric(preds, y[test_index])
  }

  # 2 - resources, released on exit even when a slice fails -------------------

  if (run_parallel) {
    cl_SOCK <- parallel::makeCluster(cl, type = "SOCK")
    on.exit(snow::stopCluster(cl_SOCK), add = TRUE)
    doSNOW::registerDoSNOW(cl_SOCK)
  }

  # a text progress bar requires max > min, i.e. at least one slice
  show_bar <- show_progress && n_slices > 0
  if (show_bar) {
    pb <- utils::txtProgressBar(min = 0,
                                max = n_slices,
                                style = 3)
    on.exit(close(pb), add = TRUE)
  }

  # 3 - loop over the time slices ----------------------------------------------

  # `i` is the foreach iteration variable; it is initialised here so that
  # the symbol is defined in this function.
  i <- NULL

  if (run_parallel) {

    # 3 - 1 parallel execution --------------------------------------------------

    # the progress bar is updated by the master process, through the
    # `progress` option of the snow backend
    snow_opts <- if (show_bar) {
      list(progress = function(n) utils::setTxtProgressBar(pb, n))
    } else {
      list()
    }

    `%op%` <- foreach::`%dopar%`

    res <- foreach::foreach(
      i = seq_len(n_slices),
      .packages = packages,
      .combine = rbind,
      .errorhandling = errorhandling,
      .options.snow = snow_opts,
      .verbose = verbose
    ) %op% {
      run_slice(i)
    }

  } else {

    # 3 - 2 sequential execution ------------------------------------------------

    `%op%` <- foreach::`%do%`

    res <- foreach::foreach(
      i = seq_len(n_slices),
      .packages = packages,
      .combine = rbind,
      .errorhandling = errorhandling,
      .verbose = verbose
    ) %op% {
      slice_result <- run_slice(i)

      if (show_bar) {
        utils::setTxtProgressBar(pb, i)
      }

      slice_result
    }

  }

  res

}
# Byte-compile the function. `cmpfun()` returns the compiled closure instead of
# modifying its argument, so the result is assigned back to the name;
# otherwise this statement has no effect.
crossval_ts <- compiler::cmpfun(crossval_ts)
# thierrymoudiki/crossval documentation built on Aug. 17, 2020, 5:51 a.m.