#' Use a generic R model to infill and project data
#'
#' `predict_general_mdl()` uses a general model object from R to fit a model and
#' use that model to infill and project the dependent variable. It is flexible
#' to allow for many general models to be used through the function. However,
#' they need to fit certain criteria:
#' * The model accepts `formula` and `data` arguments. All other arguments can be
#' passed anonymously through `...`.
#' * The returned object passed through [stats::family()] returns an inverse
#' link function in its list as `linkinv`.
#' * Must have a `predict.model` generic that accepts the `se.fit = TRUE` argument
#' and returns confidence intervals.
#'
#' As an example, [stats::lm()] and [stats::glm()] fit the
#' above criteria and convenient wrappers for those models are
#' provided in augury, but additional model functions can be used in
#' `predict_general_mdl()` if they fit the criteria.
#'
#' The function also allows for inputting of data type and source information
#' directly into the data frame if the `type_col` and `source_col` are specified
#' respectively.
#'
#' @param df Data frame of model data.
#' @param model An R function that outputs a model object with a `predict.model` generic,
#' where [stats::family()] contains an inverse link function `linkinv` and
#' `predict.model()` accepts the `se.fit = TRUE` argument and returns confidence
#' intervals. This includes [stats::lm], [stats::glm], and [lme4::lmer].
#' @param formula A formula that will be supplied to the model, such as `y~x`.
#' @param ... Other arguments passed to the model function.
#' @param ret Character vector specifying what values the function returns. Defaults
#' to returning a data frame, but can return a vector of model error, the
#' model itself or a list with all 3 as components.
#' @param scale Either `NULL` or a numeric value. If a numeric value is provided,
#' the response variable is scaled by the value passed to scale prior to model
#' fitting and prior to any probit transformation, so can be used to put the
#' response onto a 0 to 1 scale. Scaling is done by dividing the response by
#' the scale and using the [scale_transform()] function. The response, as well
#' as the fitted values and confidence bounds are unscaled prior to error
#' calculation and returning to the user.
#' @param probit Logical value on whether or not to probit transform the response
#' prior to model fitting. Probit transformation is performed after any scaling
#' determined by `scale` but prior to model fitting. The response, as well as
#' the fitted values and confidence bounds are untransformed prior to error
#' calculation and returning to the user.
#' @param test_col Name of logical column specifying which response values to remove
#' for testing the model's predictive accuracy. If `NULL`, ignored. See [model_error()]
#' for details on the methods and metrics returned.
#' @param test_period Length of period to test for RMChE. If `NULL`, beginning and end
#' points of each group in `group_col` are compared. Otherwise, `test_period` must
#' be set to an integer `n` and for each group, comparisons are made between
#' the end point and `n` periods prior.
#' @param test_period_flex Logical value indicating if `test_period` is less than
#' the full length of the series, should change error still be calculated for that
#' point. Defaults to `FALSE`.
#' @param group_col Column name(s) of group(s) to use in [dplyr::group_by()] when
#' supplying type, calculating mean absolute scaled error on data involving
#' time series, and if `group_models`, then fitting and predicting models too.
#' If `NULL`, not used. Defaults to `"iso3"`.
#' @param group_models Logical, if `TRUE`, fits and predicts models individually onto
#' each `group_col`. If `FALSE`, a general model is fit across the entire data
#' frame.
#' @param obs_filter String value of the form "`logical operator` `integer`"
#' that specifies the number of observations required to fit the model and
#' replace observations with predicted values. This is done in
#' conjunction with `group_col`. So, if `group_col = "iso3"` and
#' `obs_filter = ">= 5"`, then for this model, predictions will only be used
#' for `iso3` values that have 5 or more observations. Possible logical operators
#' to use are `>`, `>=`, `<`, `<=`, `==`, and `!=`.
#'
#' If `group_models = FALSE`, then `obs_filter` is only used to determine when
#' predicted values replace observed values but **is not** used to restrict values
#' from being used in model fitting. If `group_models = TRUE`, then a model
#' is only fit for a group if they meet the `obs_filter` requirements. This provides
#' speed benefits, particularly when running INLA time series using `predict_inla()`.
#' @param sort_col Column name(s) to use to [dplyr::arrange()] the data prior to
#' supplying type and calculating mean absolute scaled error on data involving
#' time series. If `NULL`, not used. Defaults to `"year"`.
#' @param sort_descending Logical value on whether the sorted values from `sort_col`
#' should be sorted in descending order. Defaults to `FALSE`.
#' @param pred_col Column name to store predicted value.
#' @param pred_upper_col Column name to store upper bound of confidence interval
#' generated by the `predict_...` function. This stores the full set of generated
#' values for the upper bound.
#' @param pred_lower_col Column name to store lower bound of confidence interval
#' generated by the `predict_...` function. This stores the full set of generated
#' values for the lower bound.
#' @param upper_col Column name that contains upper bound information, including
#' upper bound of the input data to the model. Values from `pred_upper_col`
#' are put into this column in the exact same way the response is filled by `pred`
#' based on `replace_obs` (by default, only when there is a missing value in the response).
#' @param lower_col Column name that contains lower bound information, including
#' lower bound of the input data to the model. Values from `pred_lower_col`
#' are put into this column in the exact same way the response is filled by `pred`
#' based on `replace_obs` (by default, only when there is a missing value in the response).
#' @param filter_na Character value specifying how, if at all, to filter `NA`
#' values from the dataset prior to applying the model. By default, all
#' observations with missing values are removed, although it can also remove
#' rows only if they have missing dependent or independent variables, or no
#' filtering at all.
#' @param type_col Column name specifying data type.
#' @param types Vector of length 3 that provides the type to provide to data
#' produced in the model. These values are only used to fill in type values
#' where the dependent variable is missing. The first value is given to missing
#' observations that precede the first observation, the second to those between
#' the first and last observations, and the third to those following the final observation.
#' @param source_col Column name containing source information for the data frame.
#' If provided, the argument in `source` is used to fill in where predictions
#' have filled in missing data.
#' @param source Source to add to missing values.
#' @param scenario_detail_col Column name containing scenario_detail information
#' for the data frame. If provided, the argument in `scenario_detail` is used
#' to fill in where predictions have filled in missing data.
#' @param scenario_detail Scenario details to add to missing values (usually the
#' name of the model being used to generate the projection, optionally with
#' relevant parameters).
#' @param replace_obs Character value specifying how, if at all, observations should
#' be replaced by fitted values. Defaults to replacing only missing values,
#' but can be used to replace all values or none.
#' @param error_correct Logical value indicating whether or not mean error
#' should be used to adjust predicted values. If `TRUE`, the mean error between
#' observed and predicted data points will be used to adjust predictions. If
#' `error_correct_cols` is not `NULL`, mean error will be used within those
#' groups instead of overall mean error.
#' @param error_correct_cols Column names of data frame to group by when applying
#' error correction to the predicted values.
#' @param shift_trend Logical value specifying whether or not to shift predictions
#' so that the trend matches up to the last observation. If `error_correct` and
#' `shift_trend` are both `TRUE`, `shift_trend` takes precedence.
#'
#' @return Depending on the value passed to `ret`, either a data frame with
#' predicted data, a vector of errors from [model_error()], a fitted model, or a list with all 3.
#'
#' @export
predict_general_mdl <- function(df,
                                model,
                                formula,
                                ...,
                                ret = c("df", "all", "error", "model"),
                                scale = NULL,
                                probit = FALSE,
                                test_col = NULL,
                                test_period = NULL,
                                test_period_flex = NULL,
                                group_col = "iso3",
                                group_models = FALSE,
                                obs_filter = NULL,
                                sort_col = "year",
                                sort_descending = FALSE,
                                pred_col = "pred",
                                pred_upper_col = "pred_upper",
                                pred_lower_col = "pred_lower",
                                upper_col = "upper",
                                lower_col = "lower",
                                filter_na = c("all", "response", "predictors", "none"),
                                type_col = NULL,
                                types = c("imputed", "imputed", "projected"),
                                source_col = NULL,
                                source = NULL,
                                scenario_detail_col = NULL,
                                scenario_detail = NULL,
                                replace_obs = c("missing", "all", "none"),
                                error_correct = FALSE,
                                error_correct_cols = NULL,
                                shift_trend = FALSE) {
  # Assertions and error checking
  df <- assert_df(df)
  assert_model(model)
  formula_vars <- parse_formula(formula)
  assert_columns(df, formula_vars, test_col, group_col, sort_col, type_col, source_col)
  assert_group_models(group_col, group_models)
  # The first formula variable is treated as the response throughout
  response <- formula_vars[1]
  assert_columns_unique(response, pred_col, pred_upper_col, pred_lower_col, upper_col, lower_col, test_col, group_col, sort_col, type_col, source_col)
  ret <- rlang::arg_match(ret)
  assert_test_col(df, test_col)
  assert_string(pred_col, 1)
  assert_string(pred_upper_col, 1)
  assert_string(pred_lower_col, 1)
  assert_string(upper_col, 1)
  assert_string(lower_col, 1)
  filter_na <- rlang::arg_match(filter_na)
  assert_string(types, 3)
  assert_string(source, 1)
  replace_obs <- rlang::arg_match(replace_obs)
  obs_filter <- parse_obs_filter(obs_filter, response)

  # Scale response variable
  if (!is.null(scale)) {
    df <- scale_transform(df, response, scale = scale)
  }

  # Transform response variable to probit space
  if (probit) {
    df <- probit_transform(df, response)
  }

  # `fit_general_model()` recognises the value "mdl" as the signal to skip
  # generating fitted values, whereas the user-facing option is "model", so
  # translate here to keep that shortcut working.
  fit_ret <- if (ret == "model") "mdl" else ret

  mdl_df <- fit_general_model(df = df,
                              model = model,
                              formula = formula,
                              ...,
                              formula_vars = formula_vars,
                              test_col = test_col,
                              group_col = group_col,
                              group_models = group_models,
                              obs_filter = obs_filter,
                              sort_col = sort_col,
                              sort_descending = sort_descending,
                              pred_col = pred_col,
                              pred_upper_col = pred_upper_col,
                              pred_lower_col = pred_lower_col,
                              filter_na = filter_na,
                              ret = fit_ret,
                              error_correct = error_correct,
                              error_correct_cols = error_correct_cols,
                              shift_trend = shift_trend)

  mdl <- mdl_df[["mdl"]]
  df <- mdl_df[["df"]]

  # Return model now. This must test "model" (the value `ret` can actually
  # take after `arg_match()`); otherwise the function would fall through and
  # return `NULL` rather than the fitted model.
  if (ret == "model") {
    return(mdl)
  }

  # Untransform variables
  fitted_cols <- c(response, pred_col, pred_upper_col, pred_lower_col)
  if (probit) {
    df <- probit_transform(df,
                           fitted_cols,
                           inverse = TRUE)
  }

  # Unscale variables
  if (!is.null(scale)) {
    df <- scale_transform(df,
                          fitted_cols,
                          scale = scale,
                          divide = FALSE)
  }

  # Get error if being returned; computed before predictions are merged into
  # the response column
  if (ret %in% c("all", "error")) {
    err <- model_error(df = df,
                       response = response,
                       test_col = test_col,
                       test_period = test_period,
                       test_period_flex = test_period_flex,
                       group_col = group_col,
                       sort_col = sort_col,
                       sort_descending = sort_descending,
                       pred_col = pred_col,
                       pred_upper_col = pred_upper_col,
                       pred_lower_col = pred_lower_col)

    if (ret == "error") {
      return(err)
    }
  }

  # Merge predictions into observations
  df <- merge_prediction(df = df,
                         response = response,
                         group_col = group_col,
                         obs_filter = obs_filter,
                         sort_col = sort_col,
                         sort_descending = sort_descending,
                         pred_col = pred_col,
                         pred_upper_col = pred_upper_col,
                         pred_lower_col = pred_lower_col,
                         upper_col = upper_col,
                         lower_col = lower_col,
                         type_col = type_col,
                         types = types,
                         source_col = source_col,
                         source = source,
                         scenario_detail_col = scenario_detail_col,
                         scenario_detail = scenario_detail,
                         replace_obs = replace_obs)

  if (ret == "df") {
    return(df)
  }

  # Only `ret == "all"` reaches this point
  list(df = df,
       error = err,
       model = mdl)
}
#' Generate prediction from model object
#'
#' `predict_general_data()` asks a fitted model for predictions (with standard
#' errors) on a data frame and writes the results back into that data frame.
#' The fitted values and the bounds, taken as the fit plus and minus two
#' standard errors, are each passed through the inverse link function of the
#' model family before being stored.
#'
#' @inheritParams predict_general_mdl
#' @param model A fitted model object for which [stats::family()] provides a
#'     `linkinv` function and [stats::predict()] accepts `se.fit = TRUE`.
#'
#' @return A data frame: `df` with `pred_col`, `pred_upper_col` and
#'     `pred_lower_col` added or overwritten.
predict_general_data <- function(df,
                                 model,
                                 pred_col,
                                 pred_upper_col,
                                 pred_lower_col) {
  # Inverse link used to back-transform the fit and its bounds
  back_transform <- stats::family(model)[["linkinv"]]

  prediction <- stats::predict(model, newdata = df, se.fit = TRUE)
  fit <- prediction[["fit"]]
  std_err <- prediction[["se.fit"]]

  # Bounds are built around the fit first and back-transformed afterwards
  upper <- fit + 2 * std_err
  lower <- fit - 2 * std_err

  df[[pred_col]] <- back_transform(fit)
  df[[pred_upper_col]] <- back_transform(upper)
  df[[pred_lower_col]] <- back_transform(lower)
  df
}
#' Fit general model to data
#'
#' Used within `predict_general_mdl()`, this function fits the model to the data
#' frame, working out whether the model is being fit across the entire data
#' frame or being fit to each group individually. Data is filtered prior to
#' fitting, model(s) are fit, and then fitted values are generated on the
#' original data frame.
#'
#' If fitting models individually to each group, `mdl` is always returned as
#' `NULL`, as there is instead a large group of models. Otherwise, the single
#' fitted model is returned as `mdl`. In both cases a list of `mdl` and `df` is
#' returned and used within `predict_general_mdl()`.
#'
#' @inheritParams predict_general_mdl
#' @param formula_vars Variables included in the model formula, generated by
#'     `all.vars(formula)`.
#' @param ret Return type requested by the caller. If equal to `"mdl"`, error
#'     correction is skipped and, when a single model is fit, no fitted values
#'     are generated (`df` is returned as `NULL`).
#'
#' @return List of `mdl` (fitted model) and `df` (data frame with fitted values
#'     and confidence bounds generated from the model).
fit_general_model <- function(df,
                              model,
                              formula,
                              ...,
                              formula_vars,
                              test_col,
                              group_col,
                              group_models,
                              obs_filter,
                              sort_col,
                              sort_descending,
                              pred_col,
                              pred_upper_col,
                              pred_lower_col,
                              filter_na,
                              ret,
                              error_correct,
                              error_correct_cols,
                              shift_trend) {
  # Filter data for modeling; grouping is only forwarded to the filter when
  # models are fit per group
  group_col_mdl <- if (group_models) group_col else NULL

  data <- get_model_data(df = df,
                         formula_vars = formula_vars,
                         test_col = test_col,
                         group_col = group_col_mdl,
                         filter_na = filter_na)

  if (group_models) {
    # Split data frames. `across(all_of())` is used rather than
    # `.data[[group_col]]` so that more than one grouping column can be
    # supplied, matching how `group_col` is used in the single model branch.
    # NOTE(review): the two lists are paired by position below; this presumably
    # relies on `get_model_data()` keeping rows for every group in `df` --
    # confirm.
    data <- dplyr::group_by(data, dplyr::across(dplyr::all_of(group_col))) %>%
      dplyr::group_split()

    df <- dplyr::group_by(df, dplyr::across(dplyr::all_of(group_col))) %>%
      dplyr::group_split()

    # build and apply models
    df <- purrr::map2_dfr(data, df, function(x, y) {
      # `obs_filter` is an expression string; a model is only fit for the group
      # when filtering on it leaves no rows
      obs_check <- dplyr::filter(y, eval(parse(text = obs_filter)))
      if (nrow(obs_check) == 0) {
        mdl <- model(formula = formula,
                     data = x,
                     ...)
        predict_general_data(df = y,
                             model = mdl,
                             pred_col = pred_col,
                             pred_upper_col = pred_upper_col,
                             pred_lower_col = pred_lower_col)
      } else {
        # group returned untouched, without prediction columns
        y
      }
    })

    # in case no models fit
    df <- augury_add_columns(df, c(pred_col, pred_lower_col, pred_upper_col))
    mdl <- NULL # not returning all models together for grouped models
  } else { # single model fitting
    mdl <- model(formula = formula,
                 data = data,
                 ...)

    if (ret == "mdl") {
      # only the model is wanted, so fitted values are not generated
      df <- NULL
    } else {
      # Evaluate the observation check within each group, then regroup on the
      # result so predictions are generated only for rows where it is FALSE
      df <- dplyr::group_by(df, dplyr::across(group_col)) %>%
        dplyr::mutate("augury_temp_obs_check" := eval(parse(text = obs_filter))) %>%
        dplyr::group_by(.data[["augury_temp_obs_check"]]) %>%
        dplyr::group_modify(function(x, ...) {
          if (!unique(x[["augury_temp_obs_check"]])) {
            x <- predict_general_data(df = x,
                                      model = mdl,
                                      pred_col = pred_col,
                                      pred_upper_col = pred_upper_col,
                                      pred_lower_col = pred_lower_col)
          }
          # the grouping column is dropped from the returned data frame here
          dplyr::select(x, -"augury_temp_obs_check")
        },
        .keep = TRUE) %>%
        dplyr::ungroup() %>%
        dplyr::select(-"augury_temp_obs_check")
    }
  }

  # use error correction if applicable
  if (ret != "mdl") {
    df <- error_correct_fn(df = df,
                           response = formula_vars[1],
                           group_col = group_col,
                           sort_col = sort_col,
                           sort_descending = sort_descending,
                           pred_col = pred_col,
                           pred_upper_col = pred_upper_col,
                           pred_lower_col = pred_lower_col,
                           test_col = test_col,
                           error_correct = error_correct,
                           error_correct_cols = error_correct_cols,
                           shift_trend = shift_trend)
  }

  list(df = df, mdl = mdl)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.