R/mean.R

Defines functions refit.model_mean model_sum.model_mean report.model_mean tidy.model_mean glance.model_mean residuals.model_mean fitted.model_mean interpolate.model_mean generate.model_mean forecast.model_mean MEAN train_mean

Documented in fitted.model_mean forecast.model_mean generate.model_mean glance.model_mean interpolate.model_mean MEAN refit.model_mean report.model_mean residuals.model_mean tidy.model_mean

#' @importFrom stats sd
train_mean <- function(.data, specials, ...) {
  # Estimate the average-method model for a single response series.
  # `specials$window[[1]]` holds the rolling window size, or NULL when the
  # mean is taken over the whole series.
  response <- measured_vars(.data)
  if (length(response) > 1) {
    abort("Only univariate responses are supported by MEAN.")
  }

  y <- unclass(.data)[[response]]
  if (all(is.na(y))) {
    abort("All observations are missing, a model cannot be estimated without data.")
  }

  window_size <- specials$window[[1]]
  if (is.null(window_size)) {
    # No window: one mean over all non-missing values, repeated for every
    # observation.
    y_mean <- mean(y, na.rm = TRUE)
    fits <- rep(y_mean, length(y))
  } else {
    # Rolling window: the final rolling mean is kept as the model's mean,
    # and the fitted values are the rolling means lagged with dplyr::lag().
    rolling <- slide_dbl(
      y, mean,
      na.rm = TRUE,
      .size = window_size, .partial = TRUE
    )
    y_mean <- rolling[length(rolling)]
    fits <- dplyr::lag(rolling)
  }

  res <- y - fits

  structure(
    list(
      fitted = fits,
      resid = res,
      mean = y_mean,
      sigma = sd(res, na.rm = TRUE),
      nobs = sum(!is.na(y)),
      # NA marks "no rolling window" so the element is always present.
      window = window_size %||% NA
    ),
    class = "model_mean"
  )
}

# Specials accepted on the right-hand side of a MEAN() formula.
# `window(size)` simply hands back the requested window size (NULL when no
# rolling window is requested); it is listed in `.required_specials`.
specials_mean <- new_specials(
  window = function(size = NULL) size,
  .required_specials = "window"
)

#' Mean models
#'
#' \code{MEAN()} specifies an iid (average method) model for the response
#' variable given in the formula.
#'
#' @aliases report.model_mean
#'
#' @param formula Model specification.
#' @param ... Not used.
#'
#' @section Specials:
#'
#' \subsection{window}{
#' The `window` special is used to request a rolling window for the mean.
#' \preformatted{
#' window(size = NULL)
#' }
#'
#' \tabular{ll}{
#'   `size`     \tab The size (number of observations) of the rolling window. If NULL (default), no rolling window is used.
#' }
#' }
#'
#' @return A model specification.
#'
#' @seealso
#' [Forecasting: Principles and Practice, Some simple forecasting methods (section 3.2)](https://otexts.com/fpp3/simple-methods.html)
#'
#' @examples
#' library(tsibbledata)
#' vic_elec %>%
#'   model(avg = MEAN(Demand))
#' @export
MEAN <- function(formula, ...) {
  # Build the model class first, then wrap the user's (quoted) formula in a
  # definition of that class.
  model_class <- new_model_class(
    "mean",
    train = train_mean,
    specials = specials_mean
  )
  new_model_definition(model_class, !!enquo(formula), ...)
}

#' @importFrom fabletools forecast
#' @importFrom stats qnorm time
#' @importFrom utils tail
#'
#' @inherit forecast.ARIMA
#'
#' @examples
#' library(tsibbledata)
#' vic_elec %>%
#'   model(avg = MEAN(Demand)) %>%
#'   forecast()
#' @export
forecast.model_mean <- function(object, new_data, specials = NULL, bootstrap = FALSE, times = 5000, ...) {
  h <- NROW(new_data)

  y_mean <- object$mean
  sigma <- object$sigma

  # Number of observations behind the estimated mean. Missing values do not
  # contribute to the mean, so use the stored count of non-missing
  # observations (the same count tidy() uses for the standard error) rather
  # than the length of the residual vector. Objects that carry no `nobs`
  # element fall back to the previous behaviour.
  n <- object$nobs
  if (is.null(n)) {
    n <- length(object$resid)
  }

  # Produce forecasts
  if (bootstrap) { # Compute prediction intervals using simulations
    # Draw `times` bootstrapped future paths, then transpose so the draws
    # are grouped per row of `new_data` before building the distribution.
    sim <- map(seq_len(times), function(x) {
      generate(object, new_data, bootstrap = TRUE)[[".sim"]]
    }) %>%
      transpose() %>%
      map(as.numeric)
    distributional::dist_sample(sim)
  } else {
    # Flat point forecast; the spread combines the residual standard
    # deviation with the uncertainty of the estimated mean.
    fc <- rep(y_mean, h)
    se <- sigma * sqrt(1 + 1 / n)
    distributional::dist_normal(fc, se)
  }
}

#' @inherit generate.ETS
#' @importFrom stats na.omit
#'
#' @examples
#' library(tsibbledata)
#' vic_elec %>%
#'   model(avg = MEAN(Demand)) %>%
#'   generate()
#' @export
generate.model_mean <- function(x, new_data, bootstrap = FALSE, ...) {
  f <- x$mean

  # Only create innovations when the caller has not supplied an `.innov`
  # column of their own.
  if (!(".innov" %in% names(new_data))) {
    if (bootstrap) {
      # Resample the centred, non-missing residuals with replacement.
      res <- residuals(x)
      new_data$.innov <- sample(na.omit(res) - mean(res, na.rm = TRUE),
        NROW(new_data),
        replace = TRUE
      )
    } else {
      new_data$.innov <- stats::rnorm(NROW(new_data), sd = x$sigma)
    }
  }

  # Inject the model mean by value with `!!`: a bare `f` would be looked up
  # in the data mask first, so a column called `f` in `new_data` would
  # silently replace the mean.
  transmute(group_by_key(new_data), ".sim" := (!!f) + (!!sym(".innov")))
}

#' @inherit interpolate.ARIMA
#'
#' @examples
#' library(tsibbledata)
#'
#' olympic_running %>%
#'   model(mean = MEAN(Time)) %>%
#'   interpolate(olympic_running)
#' @export
interpolate.model_mean <- function(object, new_data, specials, ...) {
  # Locate the response column and the positions that need filling.
  response <- measured_vars(new_data)
  values <- new_data[[response]]
  gaps <- is.na(values)

  if (is.na(object$window)) {
    # No rolling window: every gap receives the model's mean.
    fill <- object$mean
  } else {
    # Rolling window: recompute the lagged rolling means on the new data and
    # keep only the values that line up with the gaps.
    rolling <- slide_dbl(
      values, mean,
      na.rm = TRUE, .size = object$window, .partial = TRUE
    )
    fill <- dplyr::lag(rolling)[gaps]
  }

  new_data[[response]][gaps] <- fill
  new_data
}

#' @inherit fitted.ARIMA
#'
#' @examples
#' library(tsibbledata)
#' vic_elec %>%
#'   model(avg = MEAN(Demand)) %>%
#'   fitted()
#' @export
fitted.model_mean <- function(object, ...) {
  # The fitted values are stored on the model object; just hand them back.
  fitted_values <- object$fitted
  fitted_values
}

#' @inherit residuals.ARIMA
#'
#' @examples
#' library(tsibbledata)
#' vic_elec %>%
#'   model(avg = MEAN(Demand)) %>%
#'   residuals()
#' @export
residuals.model_mean <- function(object, ...) {
  # The residuals are stored on the model object; just hand them back.
  model_residuals <- object$resid
  model_residuals
}

#' Glance an average method model
#'
#' Construct a single row summary of the average method model.
#'
#' The summary contains the variance of the residuals (`sigma2`).
#'
#' @inheritParams generics::glance
#'
#' @return A one row tibble summarising the model's fit.
#'
#' @examples
#' library(tsibbledata)
#' vic_elec %>%
#'   model(avg = MEAN(Demand)) %>%
#'   glance()
#' @export
glance.model_mean <- function(x, ...) {
  # The model stores the residual standard deviation; report its square.
  residual_variance <- x$sigma^2
  tibble(sigma2 = residual_variance)
}

#' @inherit tidy.ARIMA
#'
#' @examples
#' library(tsibbledata)
#' vic_elec %>%
#'   model(avg = MEAN(Demand)) %>%
#'   tidy()
#' @export
tidy.model_mean <- function(x, ...) {
  n_obs <- x$nobs
  mu_hat <- x$mean

  # Standard error of the mean, its t statistic, and the two-sided p-value
  # from a t distribution with `n_obs - 1` degrees of freedom.
  mu_se <- x$sigma / sqrt(n_obs)
  t_stat <- mu_hat / mu_se
  p_two_sided <- 2 * stats::pt(abs(t_stat), n_obs - 1, lower.tail = FALSE)

  tibble(
    term = "mean",
    estimate = mu_hat,
    std.error = mu_se,
    statistic = t_stat,
    p.value = p_two_sided
  )
}

#' @export
report.model_mean <- function(object, ...) {
  # Print a blank line followed by the estimated mean and the residual
  # variance, each rounded to four decimal places.
  report_lines <- c(
    "\n",
    paste("Mean:", round(object$mean, 4), "\n"),
    paste("sigma^2:", round(object$sigma^2, 4), "\n")
  )
  cat(report_lines, sep = "")
}

#' @export
model_sum.model_mean <- function(x) {
  # Fixed short label for the model; the estimate itself is not included.
  "MEAN"
}

#' Refit a MEAN model
#'
#' Applies a fitted average method model to a new dataset.
#'
#' @inheritParams refit.ARIMA
#' @param reestimate If `TRUE`, the mean of the fitted model is re-estimated
#' from the new data.
#'
#' @examples
#' lung_deaths_male <- as_tsibble(mdeaths)
#' lung_deaths_female <- as_tsibble(fdeaths)
#'
#' fit <- lung_deaths_male %>%
#'   model(MEAN(value))
#'
#' report(fit)
#'
#' fit %>%
#'   refit(lung_deaths_female) %>%
#'   report()
#' @export
refit.model_mean <- function(object, new_data, specials = NULL, reestimate = FALSE, ...) {
  # Carry the fitted model's window setting over to the specials; an NA
  # window on the object means no rolling window was used.
  if (is.na(object$window)) {
    specials$window <- NULL
  } else {
    specials$window <- object$window
  }

  # Re-estimation is simply a fresh training run on the new data.
  if (reestimate) {
    return(train_mean(new_data, specials, ...))
  }

  y <- unclass(new_data)[[measured_vars(new_data)]]

  if (all(is.na(y))) {
    abort("All new observations are missing, model cannot be applied.")
  }

  if (!is_null(specials$window)) {
    warn("A rolling mean model cannot be refitted, the most recent mean from the fitted model will be used as a fixed estimate of the mean.")
  }

  # Keep the stored mean fixed and recompute everything that depends on the
  # data: fitted values, residuals, their spread and the observation count.
  object$fitted <- rep(object$mean, length(y))
  object$resid <- y - object$fitted
  object$sigma <- sd(object$resid, na.rm = TRUE)
  object$nobs <- sum(!is.na(y))
  object
}

Try the fable package in your browser

Any scripts or data that you put into this service are public.

fable documentation built on March 31, 2023, 8:13 p.m.