R/predict_forecast.R

Defines functions fit_forecast_model forecast_series trim_series get_forecast_data get_forecast_bounds predict_forecast_data predict_forecast

Documented in fit_forecast_model forecast_series get_forecast_data predict_forecast predict_forecast_data trim_series

#' Use a time series model to infill and project data
#'
#' `predict_forecast()` uses the forecast package's [forecast::forecast()] methods
#' to generate predictions on time series data. These use the longest contiguous
#' observed values to forecast out a certain number of periods. This function
#' automatically detects the latest observed values and the number of missing
#' values to forecast, and runs the provided forecasting function on the
#' observed data series.
#'
#' @param forecast_function An R function that outputs a forecast object coming from the
#'     forecast package. You can directly pass [forecast::forecast()] to the
#'     function, or you can pass other wrappers to it such as [forecast::holt()] or
#'     [forecast::ses()].
#' @param response Column name of response variable to be used as the input to the
#'     forecast function.
#' @param sort_col Column name of column to arrange data by in `dplyr::arrange()`,
#'     prior to filtering for latest contiguous time series and producing the
#'     forecast. Not used if `NULL`, defaults to `"year"`.
#' @param ... Additional arguments passed to the forecast function.
#'
#' @inherit predict_general_mdl params return
#'
#' @export
predict_forecast <- function(df,
                             forecast_function,
                             response,
                             ...,
                             ret = c("df", "all", "error", "model"),
                             scale = NULL,
                             probit = FALSE,
                             test_col = NULL,
                             test_period = NULL,
                             test_period_flex = NULL,
                             group_col = "iso3",
                             group_models = TRUE,
                             obs_filter = NULL,
                             sort_col = "year",
                             sort_descending = FALSE,
                             pred_col = "pred",
                             pred_upper_col = "pred_upper",
                             pred_lower_col = "pred_lower",
                             upper_col = "upper",
                             lower_col = "lower",
                             filter_na = c("all", "response", "predictors", "none"),
                             type_col = NULL,
                             types = "projected",
                             source_col = NULL,
                             source = NULL,
                             scenario_detail_col = NULL,
                             scenario_detail = NULL,
                             replace_obs = c("missing", "all", "none")) {
  # Assertions and error checking
  df <- assert_df(df)
  assert_function(forecast_function)
  assert_columns(df, response, test_col, group_col, sort_col, type_col, source_col)
  assert_columns_unique(response, pred_col, pred_upper_col, pred_lower_col,
                        upper_col, lower_col, test_col, group_col, sort_col,
                        type_col, source_col)
  assert_group_models(group_col, group_models)
  ret <- rlang::arg_match(ret)
  assert_test_col(df, test_col)
  assert_string(pred_col, 1)
  assert_string(upper_col, 1)
  assert_string(lower_col, 1)
  assert_string(pred_lower_col, 1)
  assert_string(pred_upper_col, 1)
  filter_na <- rlang::arg_match(filter_na)
  assert_string(types, 1)
  assert_string(source, 1)
  replace_obs <- rlang::arg_match(replace_obs)
  # Turn the user-supplied filter into an expression string for `response`
  obs_filter <- parse_obs_filter(obs_filter, response)

  # Move the response onto the modelling scale: scale first, then probit
  if (!is.null(scale)) {
    df <- scale_transform(df, response, scale = scale)
  }

  if (probit) {
    df <- probit_transform(df, response)
  }

  mdl_df <- fit_forecast_model(df = df,
                               forecast_function = forecast_function,
                               response = response,
                               ...,
                               test_col = test_col,
                               group_col = group_col,
                               group_models = group_models,
                               obs_filter = obs_filter,
                               sort_col = sort_col,
                               sort_descending = sort_descending,
                               pred_col = pred_col,
                               pred_upper_col = pred_upper_col,
                               pred_lower_col = pred_lower_col,
                               filter_na = filter_na,
                               ret = ret)

  # `mdl` is NULL when models are fit per group
  mdl <- mdl_df[["mdl"]]
  df <- mdl_df[["df"]]

  if (ret == "model") {
    return(mdl)
  }

  # Undo the transforms in reverse order: probit first, then the scaling
  if (probit) {
    df <- probit_transform(df,
                           c(response,
                             pred_col,
                             pred_upper_col,
                             pred_lower_col),
                           inverse = TRUE)
  }

  if (!is.null(scale)) {
    df <- scale_transform(df,
                          c(response,
                            pred_col,
                            pred_upper_col,
                            pred_lower_col),
                          scale = scale,
                          divide = FALSE)
  }

  # Get error if being returned
  if (ret %in% c("all", "error")) {
    err <- model_error(df = df,
                       response = response,
                       test_col = test_col,
                       test_period = test_period,
                       test_period_flex = test_period_flex,
                       group_col = group_col,
                       sort_col = sort_col,
                       sort_descending = sort_descending,
                       pred_col = pred_col,
                       pred_upper_col = pred_upper_col,
                       pred_lower_col = pred_lower_col)

    if (ret == "error") {
      return(err)
    }
  }

  # Merge predictions into observations. NOTE(review): only the third entry
  # of `types` is user-controlled; confirm the ordering against merge_prediction().
  df <- merge_prediction(df = df,
                         response = response,
                         group_col = group_col,
                         obs_filter = obs_filter,
                         sort_col = sort_col,
                         sort_descending = sort_descending,
                         pred_col = pred_col,
                         pred_upper_col = pred_upper_col,
                         pred_lower_col = pred_lower_col,
                         upper_col = upper_col,
                         lower_col = lower_col,
                         type_col = type_col,
                         types = c(NA_character_, NA_character_, types),
                         source_col = source_col,
                         source = source,
                         scenario_detail_col = scenario_detail_col,
                         scenario_detail = scenario_detail,
                         replace_obs = replace_obs)

  if (ret == "df") {
    return(df)
  }

  # Only "all" remains: "model" and "error" returned earlier
  list(df = df,
       error = err,
       model = mdl)
}

#' Generate prediction from model object
#'
#' `predict_forecast_data()` generates a prediction vector from a forecast object
#' and full data frame, putting this prediction back into the data frame.
#' When `sort_col` is given the data frame is sorted by it first. The forecast
#' point values and 95% bounds then fill the final rows, and every earlier row
#' gets `NA`.
#'
#' @inheritParams predict_general_mdl
#' @param forecast_obj Object of class `forecast` that is output from the `forecast::`
#'     family of functions.
#'
#' @return A data frame.
predict_forecast_data <- function(df,
                                  forecast_obj,
                                  sort_col,
                                  sort_descending,
                                  pred_col,
                                  pred_upper_col,
                                  pred_lower_col) {
  # Sort by `sort_col` so the forecast lines up with the trailing rows
  if (!is.null(sort_col)) {
    order_fn <- if (sort_descending) dplyr::desc else I
    df <- dplyr::arrange(
      df,
      dplyr::across(dplyr::all_of(sort_col), order_fn),
      .by_group = TRUE
    )
  }

  point_fc <- as.numeric(forecast_obj[["mean"]])
  # Rows before the forecast horizon carry no prediction
  lead_na <- rep(NA_real_, nrow(df) - length(point_fc))

  df[[pred_col]] <- c(lead_na, point_fc)
  df[[pred_upper_col]] <- c(lead_na, get_forecast_bounds(forecast_obj, "upper"))
  df[[pred_lower_col]] <- c(lead_na, get_forecast_bounds(forecast_obj, "lower"))
  df
}

#' Extract a prediction-interval bound from a forecast object
#'
#' Returns the `level` column (default `"95%"`) of the `bound` element
#' (`"upper"` or `"lower"`) of `x` as a plain numeric vector. If `x` has no
#' such element, or the element has no `level` column, a vector of `NA` is
#' returned. That vector has one entry per forecast step (`length(x[["mean"]])`),
#' so callers always get a correctly sized vector.
#'
#' @param x Object of class `forecast`.
#' @param bound Name of the bound element, `"upper"` or `"lower"`.
#' @param level Column name of the interval level to extract.
#'
#' @return Numeric vector with one value per forecast step.
#'
#' @noRd
get_forecast_bounds <- function(x, bound, level = "95%") {
  bounds <- x[[bound]]
  # Intervals absent or computed at other levels: fall back to NA rather than
  # returning NULL, which would leave the prediction column too short.
  if (is.null(bounds) || !(level %in% colnames(bounds))) {
    return(rep(NA_real_, length(x[["mean"]])))
  }
  as.numeric(bounds[, level])
}

#' Get data for forecast models
#'
#' Sorts the data by `sort_col` (when not `NULL`), then sets the response to
#' `NA` wherever `test_col` is `TRUE`. Finally it passes the response to
#' `trim_series()`, which keeps the latest contiguous run of observations
#' followed by the trailing missing values.
#'
#' @inheritParams predict_forecast
#'
#' @return A data series.
get_forecast_data <- function(df,
                              response,
                              sort_col,
                              sort_descending,
                              test_col) {
  if (!is.null(sort_col)) {
    order_fn <- if (sort_descending) dplyr::desc else I
    df <- dplyr::arrange(
      df,
      dplyr::across(dplyr::all_of(sort_col), order_fn),
      .by_group = TRUE
    )
  }

  series <- df[[response]]

  # Values in the test set are hidden from the model by treating them as missing
  if (!is.null(test_col)) {
    series[df[[test_col]]] <- NA_real_
  }

  trim_series(series)
}


#' Get latest data for forecasting
#'
#' Gets latest data for forecasting. The series is cut so that it starts at
#' the first value of the latest contiguous run of observations. Everything
#' after that point is kept, so the trailing `NA` values (the number of points
#' to forecast) are preserved.
#'
#' @param x Data series to reduce for forecasting
#'
#' @return Series with contiguous observations followed by NA values to forecast.
#'     If `x` has no observed values (including zero-length `x`) it is returned
#'     unchanged.
trim_series <- function(x) {
  observed <- which(!is.na(x))

  # Nothing to trim. Guarding here avoids `max()` on an empty vector (warning)
  # and `x[1:0]`, which would wrongly return `NA` for zero-length input.
  if (length(observed) == 0) {
    return(x)
  }

  last_obs <- max(observed) # latest observation
  gaps <- which(is.na(x[seq_len(last_obs)])) # missing values before it
  start_from <- if (length(gaps) > 0) max(gaps) + 1L else 1L

  x[start_from:length(x)]
}

#' Forecast data series
#'
#' Using series coming from `trim_series()`, it uses latest observed values
#' to forecast missing values. The forecast horizon `h` is the number of `NA`
#' values in `x`. Only the non-missing values are passed to `forecast_function`,
#' together with any extra arguments in `...`.
#'
#' @param x Series to forecast, coming from `trim_series()`
#' @inheritParams predict_forecast
#'
#' @return Forecast model.
forecast_series <- function(x,
                            forecast_function,
                            ...) {
  missing_mask <- is.na(x)
  horizon <- sum(missing_mask)
  assert_h(horizon)

  # Reuse the name `x`: forecasting functions may record the deparsed name of
  # the series argument in the object they return.
  x <- x[!missing_mask]
  forecast_function(x, h = horizon, ...)
}

#' Fit forecast model to data
#'
#' Used within `predict_forecast()`, this function fits the model to the data
#' frame. It works whether the model is fit across the entire data frame or to
#' each group individually. Data is filtered prior to fitting, model(s) are fit,
#' and then fitted values are generated on the original data frame.
#'
#' When `group_models = FALSE`, a single model is fit to the whole data frame
#' and returned as `mdl`. If `ret == "model"`, predictions are skipped and `df`
#' is `NULL`.
#'
#' When `group_models = TRUE`, one model is fit per group of `group_col`, and
#' `mdl` is always `NULL` because the per-group models are not kept. A group is
#' only modelled if `obs_filter` (an expression string evaluated inside
#' `dplyr::filter()`) selects none of its rows. Otherwise `augury_add_columns()`
#' is used to add the prediction columns without fitting.
#'
#' `filter_na` is accepted for interface consistency but is not used here.
#'
#' @inheritParams predict_forecast
#' @inheritParams fit_general_model
#'
#' @return List of `mdl` (fitted model) and `df` (data frame with fitted values
#'     and confidence bounds generated from the model).
fit_forecast_model <- function(df,
                               forecast_function,
                               response,
                               ...,
                               test_col,
                               group_col,
                               group_models,
                               obs_filter,
                               sort_col,
                               sort_descending,
                               pred_col,
                               pred_upper_col,
                               pred_lower_col,
                               filter_na,
                               ret) {

  if (!group_models) {
    # Filter data for modeling
    x <- get_forecast_data(df = df,
                           response = response,
                           sort_col = sort_col,
                           sort_descending = sort_descending,
                           test_col = test_col)

    # Build model
    mdl <- forecast_series(x,
                           forecast_function,
                           ...)

    if (ret == "model") {
      # Only the model is wanted, so skip generating predictions
      df <- NULL
    } else {
      # Get model predictions
      df <- predict_forecast_data(df = df,
                                  forecast_obj = mdl,
                                  sort_col = sort_col,
                                  sort_descending = sort_descending,
                                  pred_col = pred_col,
                                  pred_upper_col = pred_upper_col,
                                  pred_lower_col = pred_lower_col)
    }

  } else {
    # Fit one model per group and row-bind the per-group predictions
    df_list <- dplyr::group_by(df, dplyr::across(dplyr::all_of(group_col))) %>%
      dplyr::group_split()

    df <- purrr::map_dfr(df_list, function(group_df) {
      # Groups where the filter selects any row are not modelled
      obs_check <- dplyr::filter(group_df, eval(parse(text = obs_filter)))
      if (nrow(obs_check) == 0) {
        x <- get_forecast_data(df = group_df,
                               response = response,
                               sort_col = sort_col,
                               sort_descending = sort_descending,
                               test_col = test_col)
        mdl <- forecast_series(x,
                               forecast_function,
                               ...)
        predict_forecast_data(df = group_df,
                              forecast_obj = mdl,
                              sort_col = sort_col,
                              sort_descending = sort_descending,
                              pred_col = pred_col,
                              pred_upper_col = pred_upper_col,
                              pred_lower_col = pred_lower_col)
      } else {
        augury_add_columns(group_df, c(pred_col, pred_upper_col, pred_lower_col))
      }
    })

    # Per-group models are discarded
    mdl <- NULL
  }
  list(df = df, mdl = mdl)
}
caldwellst/augury documentation built on Oct. 10, 2024, 8:20 a.m.