R/lagwalk.R

Defines functions refit.RW model_sum.RW report.RW tidy.RW glance.RW residuals.RW fitted.RW generate.RW forecast.RW SNAIVE RW train_lagwalk

Documented in fitted.RW forecast.RW generate.RW glance.RW refit.RW report.RW residuals.RW RW SNAIVE tidy.RW

# Estimate a lag walk model (random walk / seasonal naive, optionally with
# drift).
#
# Arguments:
#   .data     Training data; must contain exactly one measured variable.
#   specials  Parsed specials. `specials$lag[[1]]` is the lag order and
#             `specials$drift[[1]]` (if present) holds the drift flag and an
#             optional fixed drift value, in that order.
#   ...       Not used.
#
# Returns an object of class "RW": a list with the drift estimate (`b`), its
# standard error (`b.se`), the lag, the residual variance (`sigma2`), fitted
# values, residuals, the start/interval of the training index (`time`) and the
# last `lag` observations (`future`) used as the base for forecasts.
train_lagwalk <- function(.data, specials, ...) {
  if (length(measured_vars(.data)) > 1) {
    abort("Only univariate responses are supported by lagwalks.")
  }

  y <- unclass(.data)[[measured_vars(.data)]]
  n <- length(y)

  if (all(is.na(y))) {
    abort("All observations are missing, a model cannot be estimated without data.")
  }

  # When no drift special is supplied both lookups yield NULL, so drift
  # defaults to FALSE and there is no fixed coefficient.
  drift <- specials$drift[[1]][[1]] %||% FALSE
  fixed <- specials$drift[[1]][[2]]
  lag <- specials$lag[[1]]

  # Missing observations beyond the first lag window are replaced by the value
  # one lag earlier, so a gap in the data does not produce a missing fitted
  # value one lag later.
  y_na <- which(is.na(y))
  y_na <- y_na[y_na > lag]
  # NOTE(review): for a plain vector stats::lag() appears to only adjust the
  # tsp attribute, leaving the values aligned with `y`; the shift by `lag` is
  # applied when `fitted` is built below - confirm.
  fits <- stats::lag(y, -lag)
  for (i in y_na) {
    # Test element i only; computing is.na() over the whole vector on every
    # iteration made this loop quadratic in the series length.
    if (is.na(fits[i])) {
      fits[i] <- fits[i - lag]
    }
  }

  # Shift by `lag`: the first `lag` fitted values are unavailable.
  fitted <- c(rep(NA, min(lag, n)), utils::head(fits, -lag))

  # Initial model estimation or re-estimation of RW model (with drift).
  if (drift) {
    if (!rlang::is_null(fixed)) {
      b <- fixed
      b.se <- dbl() # updated in refit.RW.
    } else {
      # The drift is the mean of the lagged differences; an intercept-only
      # regression provides both the estimate and its standard error.
      fit <- summary(stats::lm(y - fitted ~ 1, na.action = stats::na.exclude))
      b <- fit$coefficients[1, 1]
      b.se <- fit$coefficients[1, 2]
    }
    fitted <- fitted + b
  } else {
    # No drift model.
    b <- b.se <- dbl()
  }

  res <- y - fitted
  sigma <- stats::sd(res, na.rm = TRUE)

  structure(
    list(
      b = b,
      b.se = b.se,
      lag = lag,
      sigma2 = sigma^2,
      .fitted = fitted,
      .resid = res,
      time = list(start = unclass(.data)[[index_var(.data)]][[1]], interval = interval(.data)),
      # The last `lag` observations, left-padded with NA when the series is
      # shorter than the lag.
      future = y[c(rep(NA, max(0, lag - n)), seq_len(min(n, lag)) + n - min(n, lag))]
    ),
    class = "RW"
  )
}

#' Random walk models
#'
#' \code{RW()} returns a random walk model, which is equivalent to an ARIMA(0,1,0)
#' model with an optional drift coefficient included using \code{drift()}. \code{NAIVE()} is simply
#' an alias of \code{RW()} provided for convenience. \code{SNAIVE()} returns a
#' seasonal naive model, which is equivalent to an ARIMA(0,0,0)(0,1,0)m model where m is the
#' seasonal period.
#'
#' The random walk with drift model is \deqn{Y_t=c + Y_{t-1} + Z_t}{Y[t]=c +
#' Y[t-1] + Z[t]} where \eqn{Z_t}{Z[t]} is a normal iid error. Forecasts are
#' given by \deqn{Y_n(h)=ch+Y_n}{Y[n+h]=ch+Y[n]}. If there is no drift (as in
#' \code{NAIVE()} without \code{drift()}), the drift parameter c=0. Forecast standard errors allow for
#' uncertainty in estimating the drift parameter (unlike the corresponding
#' forecasts obtained by fitting an ARIMA model directly).
#'
#' The seasonal naive model is \deqn{Y_t= Y_{t-m} + Z_t}{Y[t]=Y[t-m] + Z[t]}
#' where \eqn{Z_t}{Z[t]} is a normal iid error.
#'
#' @aliases report.RW
#'
#' @param formula Model specification (see "Specials" section).
#' @param ... Not used.
#'
#' @section Specials:
#'
#' \subsection{lag}{
#' The `lag` special is used to specify the lag order for the random walk process.
#' If left out, this special will automatically be included.
#'
#' \preformatted{
#' lag(lag = NULL)
#' }
#'
#' \tabular{ll}{
#'   `lag`        \tab The lag order for the random walk process. If `lag = m`, forecasts will return the observation from `m` time periods ago. This can also be provided as text indicating the duration of the lag window (for example, annual seasonal lags would be "1 year").
#' }
#' }
#'
#' \subsection{drift}{
#' The `drift` special can be used to include a drift/trend component into the model. By default, drift is not included unless `drift()` is included in the formula.
#'
#' \preformatted{
#' drift(drift = TRUE)
#' }
#'
#' \tabular{ll}{
#'   `drift`      \tab If `drift = TRUE`, a drift term will be included in the model.
#' }
#' }
#'
#' @return A model specification.
#'
#' @seealso
#' [Forecasting: Principles and Practice, Some simple forecasting methods (section 5.2)](https://otexts.com/fpp3/simple-methods.html)
#'
#' @examples
#' library(tsibbledata)
#' aus_production %>%
#'   model(rw = RW(Beer ~ drift()))
#' @export
RW <- function(formula, ...) {
  rw_model <- new_model_class("RW",
    train = train_lagwalk,
    specials = new_specials(
      lag = function(lag = NULL) {
        # Without an explicit lag the model is an ordinary (lag 1) random walk.
        if (is.null(lag)) {
          lag <- 1
        }
        # Resolve the lag before validating it, as SNAIVE() does. A lag given
        # as text (for example "1 year") must be converted to a numeric period
        # first; checking and rounding the raw text would fail in round().
        lag <- get_frequencies(lag, self$data, .auto = "smallest")
        if (!rlang::is_integerish(lag)) {
          warn("Non-integer lag orders for random walk models are not supported. Rounding to the nearest integer.")
          lag <- round(lag)
        }
        lag
      },
      # Returns the drift flag and an optional fixed drift coefficient; the
      # training function reads them by position.
      drift = function(drift = TRUE, fixed = NULL) {
        list(drift = drift, fixed = fixed)
      },
      xreg = no_xreg,
      .required_specials = c("lag")
    ),
    check = all_tsbl_checks
  )
  new_model_definition(rw_model, !!enquo(formula), ...)
}

#' @rdname RW
#'
#' @examples
#'
#' as_tsibble(Nile) %>%
#'   model(NAIVE(value))
#' @export
NAIVE <- RW # NAIVE() is an alias: both names are bound to the same function.

#' @rdname RW
#'
#' @examples
#' library(tsibbledata)
#' aus_production %>%
#'   model(snaive = SNAIVE(Beer ~ lag("year")))
#' @export
SNAIVE <- function(formula, ...) {
  seasonal_specials <- new_specials(
    lag = function(lag = NULL) {
      # Resolve the requested lag to a period for this data set.
      period <- get_frequencies(lag, self$data, .auto = "smallest")
      # A period of one carries no seasonality.
      if (period == 1) {
        abort("Non-seasonal model specification provided, use RW() or provide a different lag specification.")
      }
      if (rlang::is_integerish(period)) {
        period
      } else {
        warn("Non-integer lag orders for random walk models are not supported. Rounding to the nearest integer.")
        round(period)
      }
    },
    # Returns the drift flag and an optional fixed drift coefficient.
    drift = function(drift = TRUE, fixed = NULL) {
      list(drift = drift, fixed = fixed)
    },
    xreg = no_xreg,
    .required_specials = c("lag")
  )

  snaive_model <- new_model_class("RW",
    train = train_lagwalk,
    specials = seasonal_specials,
    check = all_tsbl_checks
  )
  new_model_definition(snaive_model, !!enquo(formula), ...)
}

#' @inherit forecast.ARIMA
#' @inheritParams forecast.ETS
#' @importFrom stats qnorm time
#' @importFrom utils tail
#'
#' @examples
#' as_tsibble(Nile) %>%
#'   model(NAIVE(value)) %>%
#'   forecast()
#'
#' library(tsibbledata)
#' aus_production %>%
#'   model(snaive = SNAIVE(Beer ~ lag("year"))) %>%
#'   forecast()
#' @export
forecast.RW <- function(object, new_data, specials = NULL, simulate = FALSE, bootstrap = FALSE, times = 5000, ...) {
  h <- NROW(new_data)
  lag <- object$lag
  # Number of lag windows ahead for each horizon: 1 for the first `lag` steps,
  # 2 for the next `lag` steps, and so on. Built from seq_len() so that an
  # empty `new_data` gives an empty result instead of a malformed sequence.
  steps <- (seq_len(h) - 1) %/% lag + 1

  b <- object$b
  b.se <- object$b.se
  if (is_empty(b)) {
    # No drift in the model.
    b <- b.se <- 0
  } else if (is_empty(b.se)) {
    # A fixed drift has no estimated standard error. Treat it as known exactly
    # rather than letting the empty value collapse `se` to length zero.
    b.se <- 0
  }

  # Produce forecasts
  if (simulate || bootstrap) { # Compute prediction intervals using simulations
    sim <- map(seq_len(times), function(x) {
      generate(object, new_data, bootstrap = bootstrap)[[".sim"]]
    }) %>%
      transpose() %>%
      map(as.numeric)
    distributional::dist_sample(sim)
  } else {
    # Point forecasts repeat the last observed lag window, plus accumulated
    # drift.
    fc <- rep_len(object$future, h) + steps * b
    res <- residuals(object)
    # One degree of freedom is removed when a non-zero drift is present.
    mse <- sum(res^2, na.rm = TRUE) / (sum(!is.na(res)) - (b != 0))
    if (is.nan(mse)) mse <- NA
    # Adjust prediction intervals to allow for drift coefficient standard error
    se <- sqrt(mse * steps + (steps * b.se)^2)
    distributional::dist_normal(fc, se)
  }
}

#' @inherit generate.ETS
#'
#' @examples
#' as_tsibble(Nile) %>%
#'   model(NAIVE(value)) %>%
#'   generate()
#'
#' library(tsibbledata)
#' aus_production %>%
#'   model(snaive = SNAIVE(Beer ~ lag("year"))) %>%
#'   generate()
#' @export
generate.RW <- function(x, new_data, bootstrap = FALSE, ...) {
  if (!is_regular(new_data)) {
    abort("Simulation new_data must be regularly spaced")
  }

  lag <- x$lag
  if (is_empty(x$b)) {
    # No drift in the model.
    b <- 0
  } else if (is_empty(x$b.se)) {
    # A fixed drift has no standard error; use it as given instead of drawing
    # from a normal distribution with an empty `sd`, which yields NA.
    b <- x$b
  } else {
    # Draw the drift from its sampling distribution for this simulated path.
    b <- stats::rnorm(1, mean = x$b, sd = x$b.se)
  }
  fits <- c(x$.fitted, x$future)

  # Locate the first simulated time point relative to the training index.
  start_idx <- min(new_data[[index_var(new_data)]])
  start_pos <- match(start_idx, seq(x$time$start, by = default_time_units(x$time$interval), length.out = length(fits)))

  future <- fits[start_pos + seq_len(lag) - 1]

  if (any(is.na(future))) {
    abort("The first lag window for simulation must be within the model's training set.")
  }

  # Innovations supplied in `new_data` are used as-is; otherwise they are
  # resampled from centred residuals (bootstrap) or drawn from a normal
  # distribution with the model's residual variance.
  if (!(".innov" %in% names(new_data))) {
    if (bootstrap) {
      new_data$.innov <- sample(stats::na.omit(residuals(x) - mean(residuals(x), na.rm = TRUE)),
        NROW(new_data),
        replace = TRUE
      )
    } else {
      new_data$.innov <- stats::rnorm(NROW(new_data), sd = sqrt(x$sigma2))
    }
  }

  sim_rw <- function(e) {
    # Cumulate errors (plus drift) separately within each position of the lag
    # window.
    lag_grp <- rep_len(seq_len(lag), length(e))
    cumulative_e <- unsplit(lapply(split(e + b, lag_grp), cumsum), lag_grp)
    # Repeat the starting window to the full simulation length. The length is
    # taken from `e`, not from the split list, so no recycling warning arises
    # when the horizon is not a multiple of the lag.
    rep_len(future, length(e)) + cumulative_e
  }

  transmute(group_by_key(new_data), ".sim" := sim_rw(!!sym(".innov")))
}

#' @inherit fitted.ARIMA
#'
#' @examples
#' as_tsibble(Nile) %>%
#'   model(NAIVE(value)) %>%
#'   fitted()
#'
#' library(tsibbledata)
#' aus_production %>%
#'   model(snaive = SNAIVE(Beer ~ lag("year"))) %>%
#'   fitted()
#' @export
fitted.RW <- function(object, ...) {
  # Fitted values are stored on the model object under ".fitted".
  fitted_values <- object[[".fitted"]]
  fitted_values
}

#' @inherit residuals.ARIMA
#'
#' @examples
#' as_tsibble(Nile) %>%
#'   model(NAIVE(value)) %>%
#'   residuals()
#'
#' library(tsibbledata)
#' aus_production %>%
#'   model(snaive = SNAIVE(Beer ~ lag("year"))) %>%
#'   residuals()
#' @export
residuals.RW <- function(object, ...) {
  # Residuals are stored on the model object under ".resid".
  resid_values <- object[[".resid"]]
  resid_values
}

#' Glance a lag walk model
#'
#' Construct a single row summary of the lag walk model.
#' Contains the variance of residuals (`sigma2`).
#'
#' @inheritParams generics::glance
#'
#' @return A one row tibble summarising the model's fit.
#'
#' @examples
#' as_tsibble(Nile) %>%
#'   model(NAIVE(value)) %>%
#'   glance()
#'
#' library(tsibbledata)
#' aus_production %>%
#'   model(snaive = SNAIVE(Beer ~ lag("year"))) %>%
#'   glance()
#' @export
glance.RW <- function(x, ...) {
  # The residual variance is the only summary statistic reported.
  resid_variance <- x[["sigma2"]]
  tibble(sigma2 = resid_variance)
}

#' @inherit tidy.ARIMA
#'
#' @examples
#' as_tsibble(Nile) %>%
#'   model(NAIVE(value)) %>%
#'   tidy()
#'
#' library(tsibbledata)
#' aus_production %>%
#'   model(snaive = SNAIVE(Beer ~ lag("year"))) %>%
#'   tidy()
#' @export
tidy.RW <- function(x, ...) {
  # The model has a drift term exactly when a drift coefficient is stored.
  drift <- !is_empty(x$b)
  b.se <- x$b.se
  if (drift && is_empty(b.se)) {
    # A fixed drift has no estimated standard error. Report it as NA so the
    # drift row is kept; an empty column would recycle the whole result to
    # zero rows.
    b.se <- NA_real_
  }
  statistic <- x$b / b.se
  tibble(
    term = if (drift) "b" else chr(),
    estimate = x$b, std.error = b.se,
    statistic = statistic,
    # Two-sided p-value from a t distribution.
    p.value = 2 * stats::pt(abs(statistic), length(x$.resid) - x$lag - drift, lower.tail = FALSE)
  )
}

#' @export
report.RW <- function(object, ...) {
  # Prints the drift (only when the model has one) and the residual variance.
  drift_coef <- object[["b"]]
  cat("\n")
  if (!is_empty(drift_coef)) {
    drift_line <- paste0(
      "Drift: ", round(drift_coef, 4),
      " (se: ", round(object[["b.se"]], 4), ")\n"
    )
    cat(drift_line)
  }
  variance_line <- paste0("sigma^2: ", round(object[["sigma2"]], 4), " \n")
  cat(variance_line)
}

#' @importFrom stats coef
#' @export
model_sum.RW <- function(x) {
  # Short label for the model: a lag other than one is seasonal (SNAIVE);
  # a lag-one model is RW with drift, NAIVE without.
  has_drift <- !is_empty(x[["b"]])
  method <- if (x[["lag"]] != 1) {
    "SNAIVE"
  } else if (has_drift) {
    "RW"
  } else {
    "NAIVE"
  }
  if (has_drift) paste(method, "w/ drift") else method
}

#' Refit a lag walk model
#'
#' Applies a fitted random walk model to a new dataset.
#' 
#' The models `NAIVE` and `SNAIVE` have no specific model parameters. Using `refit`
#' for one of these models will provide the same estimation results as fitting
#' the model directly with `fabletools::model(NAIVE(...))` (or `fabletools::model(SNAIVE(...))`).
#'
#' @inheritParams refit.ARIMA
#' @param reestimate If `TRUE`, the lag walk model will be re-estimated 
#' to suit the new data. 
#' 
#' @examples
#' lung_deaths_male <- as_tsibble(mdeaths)
#' lung_deaths_female <- as_tsibble(fdeaths)
#'
#' fit <- lung_deaths_male %>%
#'   model(RW(value ~ drift()))
#'
#' report(fit)
#'
#' fit %>%
#'   refit(lung_deaths_female) %>%
#'   report()
#' @export
refit.RW <- function(object, new_data, specials = NULL, reestimate = FALSE, ...) {
  # The lag order always carries over from the original model.
  specials$lag <- object$lag

  if (reestimate) {
    # Estimate everything afresh on the new data.
    updated <- train_lagwalk(new_data, specials, ...)
  } else {
    # Keep the original drift by passing it to training as a fixed value.
    if (!rlang::is_empty(object$b)) {
      specials$drift[[1]][[2]] <- object$b
    }
    updated <- train_lagwalk(new_data, specials, ...)

    # Training leaves b.se empty for a fixed drift, so carry over the original
    # value (a number, or an empty numeric when there is no drift).
    updated$b.se <- object$b.se
  }

  updated
}

Try the fable package in your browser

Any scripts or data that you put into this service are public.

fable documentation built on March 31, 2023, 8:13 p.m.