Nothing
#' Predictions for flexCountReg models
#'
#' @name predict.flexCountReg
#' @param object a model object estimated using this R package.
#' @param newdata optional dataframe for which to generate predictions.
#' @param ... optional arguments passed to the function. This includes `method`.
#'
#' @note optional parameter `newdata`: a dataframe that has all of the variables
#' in the \code{formula} and \code{rpar_formula}.
#' @note optional parameter `method`: Only valid for random parameters models
#' (`countreg.rp`). Options include \code{Simulated} (default),
#' \code{Individual}, or \code{Exact}.
#'
#' @description
#' Generates predictions for the expected count (lambda) for observations.
#'
#' For \strong{countreg.rp} (Random Parameters) models, three methods are
#' available:
#' \itemize{
#' \item \strong{Simulated}: Uses Halton draws to simulate the random
#' parameters and averages the outcomes. This is a simulation-based
#' approximation.
#' \item \strong{Individual}: Estimates observation-specific coefficients
#' (conditional on observed outcomes) using Empirical Bayes. Requires the
#' outcome variable to be present in \code{data}.
#' \item \strong{Exact}: Uses the analytical Moment Generating Functions
#' (MGFs) of the random parameter distributions to calculate the exact
#' expected value. This method is faster and removes simulation error.
#' }
#'
#' For \strong{countreg}, \strong{poisLindRE}, and \strong{RENB} models, the
#' function calculates the expected value \eqn{\mu = \exp(X\beta)} (with
#' appropriate adjustments for specific families like PLN or underreporting).
#'
#' @references
#' Wood, J.S., Gayah, V. (2025). Out-of-sample prediction and interpretation for
#' random parameter generalized linear models. \emph{Accident Analysis and
#' Prevention}, 220, 108147.
#'
#' @returns A numeric vector of predicted expected counts for each observation
#' in the provided data. If no data is provided, the predictions for the data
#' used in estimating the model are provided.
#'
#' @import randtoolbox stats modelr rlang
#' @importFrom utils head tail
#' @include corr_haltons.R halton_dists.R helpers.R
#'
#' @examples
#' \donttest{
#' # Load data and create a dummy variable
#' data("washington_roads")
#' washington_roads$AADT10kplus <- ifelse(washington_roads$AADT > 10000, 1, 0)
#'
#' # =========================================================================
#' # 1. Fixed Parameter Model (countreg)
#' # =========================================================================
#' nb2_fixed <- countreg(Total_crashes ~ lnaadt + lnlength + speed50,
#' data = washington_roads,
#' family = "NB2")
#' pred_fixed <- predict(nb2_fixed, data = washington_roads)
#'
#' # =========================================================================
#' # 2. Random Parameters Model (countreg.rp)
#' # =========================================================================
#' rp_nb2 <- countreg.rp(Total_crashes ~ lnaadt + lnlength,
#' rpar_formula = ~ -1 + speed50,
#' data = washington_roads,
#' family = "NB2",
#' rpardists = c(speed50 = "n"),
#' ndraws = 100)
#'
#' # Method A: Simulated (Default)
#' pred_sim <- predict(rp_nb2, data = washington_roads, method = "Simulated")
#'
#' # Method B: Exact (Analytical MGF)
#' pred_exact <- predict(rp_nb2, data = washington_roads, method = "Exact")
#'
#' # =========================================================================
#' # 3. Random Effects Models (poisLindRE / RENB)
#' # =========================================================================
#' pl_re <- poisLind.re(Total_crashes ~ lnaadt + lnlength,
#' data = washington_roads,
#' group_var = "ID")
#' pred_pl_re <- predict(pl_re, data = washington_roads)
#' }
#' @export
predict.flexCountReg <- function(object, newdata = NULL, ...){
# Extract optional parameters from '...'
additional_args <- list(...)
# 1. FIXED: Handle data argument logic to support 'newdata'
if (!is.null(newdata)) {
data <- as.data.frame(newdata)
} else if (!is.null(additional_args$data)) {
# Fallback if user explicitly passed 'data=' in ...
data <- as.data.frame(additional_args$data)
} else {
# Default to training data
data <- object$data
}
model <- object$model
if (!is.null(object$modelType)) {
modtype <- object$modelType
} else if (!is.null(model$modelType)) {
modtype <- model$modelType
} else {
modtype <- "countreg"
}
# Handle method argument
if (is.null(additional_args$method)) {
if(modtype == "countreg.rp") {
method <- "Simulated"
} else {
method <- "Standard"
}
} else {
method <- additional_args$method
}
# =========================================================================
# 1. RANDOM PARAMETERS MODELS (countreg.rp)
# =========================================================================
if(modtype == "countreg.rp"){
# --- Setup & Matrix Generation ---
form <- model$form
family <- if(!is.null(model$family)) model$family else "NB2"
rpar_formula <- model$rpar_formula
rpardists <- model$rpardists
correlated <- model$correlated
scrambled <- model$scrambled
ndraws <- max(model$ndraws, 500)
# Ensure response is removed for prediction
formula_fixed <- delete.response(terms(model$formula))
data <- as.data.frame(data)
if (!is.null(model$offset)){
if(length(model$offset) > 1 || !model$offset %in% names(data)) {
X_offset <- rep(0, nrow(data))
} else {
X_offset <- data[[model$offset]]
}
} else {
X_offset <- rep(0, nrow(data))
}
# Matrices
X_Fixed <- as.matrix(modelr::model_matrix(data, formula_fixed))
X_rand <- as.matrix(modelr::model_matrix(data, rpar_formula))
# INTERCEPT CLEANUP
if("(Intercept)" %in% colnames(X_rand) && !is.null(rpardists)){
if(!any(grepl("intercept", names(rpardists), ignore.case=TRUE))){
X_rand <- X_rand[ , colnames(X_rand) != "(Intercept)", drop=FALSE]
}
}
# Distribution Parameter Matrices
vec_1 <- NULL
# --- Extract Coefficients ---
coefs <- unlist(model$estimate, recursive = TRUE, use.names = FALSE)
N_fixed <- ncol(X_Fixed)
N_rand <- ncol(X_rand)
current_idx <- 0
# Fixed
fixed_coefs <- coefs[(current_idx + 1):(current_idx + N_fixed)]
current_idx <- current_idx + N_fixed
# Random Means
random_coefs_means <- coefs[(current_idx + 1):(current_idx + N_rand)]
current_idx <- current_idx + N_rand
# Random Vars
n_var_params <- if(correlated) N_rand * (N_rand + 1) / 2 else N_rand
rand_var_params <- coefs[(current_idx + 1):(current_idx + n_var_params)]
current_idx <- current_idx + n_var_params
# Heterogeneity
if (!is.null(model$het_mean_formula)) {
N_het_mean <- ncol(model.matrix(model$het_mean_formula, data))
current_idx <- current_idx + N_het_mean
}
if (!is.null(model$het_var_formula)) {
N_het_var <- ncol(model.matrix(model$het_var_formula, data))
current_idx <- current_idx + N_het_var
}
# Distribution Params (for PLN correction)
params <- get_params(family)
if (!is.null(model$dis_param_formula_1)) {
X_dis_1 <-
as.matrix(modelr::model_matrix(data, model$dis_param_formula_1))
N_dis_1 <- ncol(X_dis_1)
p_dis_1 <- coefs[(current_idx + 1):(current_idx + N_dis_1)]
current_idx <- current_idx + N_dis_1
vec_1 <- exp(as.vector(X_dis_1 %*% p_dis_1))
} else if (!is.null(params[[1]])) {
p_dis_1 <- coefs[(current_idx + 1):(current_idx + 1)]
current_idx <- current_idx + 1
vec_1 <- rep(exp(p_dis_1), nrow(data))
}
# --- Prediction Method Implementation ---
if (method == 'Exact') {
# Base prediction
pred_exact <- exp(as.vector(X_Fixed %*% fixed_coefs) + X_offset)
# Correlated Normal
if (correlated) {
L <- matrix(0, N_rand, N_rand)
L[lower.tri(L, diag = TRUE)] <- rand_var_params
var_rand <- rowSums((X_rand %*% L)^2)
mu_rand <- as.vector(X_rand %*% random_coefs_means)
pred_exact <- pred_exact * exp(mu_rand + 0.5 * var_rand)
} else {
# Independent
for (i in 1:N_rand) {
dist_type <- rpardists[i]
x_col <- X_rand[, i]
mu <- random_coefs_means[i]
sigma <- rand_var_params[i]
correction <- rep(1, nrow(data))
if (dist_type == "n") { # Normal
correction <- exp(x_col * mu + 0.5 * (x_col^2) * (sigma^2))
} else if (dist_type == "ln") { # Lognormal
arg_w <- -exp(mu) * x_col * sigma^2
valid_idx <- arg_w >= -1/exp(1)
if (any(valid_idx)) {
x_val <- x_col[valid_idx]
w_val <- lamW::lambertW0(arg_w[valid_idx])
num <- x_val * exp(mu - w_val) - (w_val^2)/(2 * sigma^2)
den <- sigma * sqrt(abs(x_val * exp(mu - w_val) - (1/sigma^2)))
correction[valid_idx] <- exp(num) / den
}
if(any(!valid_idx)) correction[!valid_idx] <- NA
} else if (dist_type == "t") { # Triangular
nz <- abs(x_col) > 1e-8
if (any(nz)) {
x_nz <- x_col[nz]
term <- exp(x_nz * (mu - sigma))
correction[nz] <-
(exp(2 * sigma * x_nz) - 2 * exp(sigma * x_nz) + 1) * term /
(sigma^2 * x_nz^2)
}
} else if (dist_type == "u") { # Uniform
nz <- abs(x_col) > 1e-8
if (any(nz)) {
x_nz <- x_col[nz]
correction[nz] <-
(exp(x_nz * (mu + sigma)) - exp(x_nz * (mu - sigma))) /
(2 * sigma * x_nz)
}
} else if (dist_type == "g") { # Gamma
term_base <- 1 - (sigma^2 * x_col) / mu
if(any(term_base <= 0))
warning("Gamma exact prediction undefined for some observations.")
correction <-
ifelse(term_base > 0, term_base^(-(mu^2)/(sigma^2)), NA)
}
pred_exact <- pred_exact * correction
}
}
# PLN Correction
if(family == "PLN" && !is.null(vec_1)){
pred_exact <- pred_exact * exp((vec_1^2)/2)
}
return(pred_exact)
} else {
# Simulated or Individual
hdraws <-
as.matrix(randtoolbox::halton(ndraws, N_rand, mixed = scrambled))
# Re-extract heterogeneity for generation function
het_mean_coefs <- NULL
het_var_coefs <- NULL
X_het_mean <- NULL
X_het_var <- NULL
if (!is.null(model$het_mean_formula)) {
X_het_mean <- model.matrix(model$het_mean_formula, data)
if("(Intercept)" %in% colnames(X_het_mean))
X_het_mean <- X_het_mean[,-1,drop=FALSE]
}
draws_info <- generate_random_draws(
hdraws = hdraws,
random_coefs_means = random_coefs_means,
rand_var_params = rand_var_params,
rpardists = rpardists,
rpar = colnames(X_rand),
X_rand = X_rand,
het_mean_coefs = het_mean_coefs,
X_het_mean = X_het_mean,
het_var_coefs = het_var_coefs,
X_het_var = X_het_var,
correlated = correlated)
rpar_mat <- exp(draws_info$xb_rand_mat)
mu_fixed <- exp(as.vector(X_Fixed %*% fixed_coefs) + X_offset)
pred_mat <- sweep(rpar_mat, 1, mu_fixed, "*")
if(family == "PLN" && !is.null(vec_1)){
adj_factor <- exp((vec_1^2)/2)
pred_mat <- sweep(pred_mat, 1, adj_factor, "*")
}
if (method == 'Simulated') {
return(rowMeans(pred_mat))
} else if (method == 'Individual') {
y_name <- all.vars(model$formula)[1]
if(!y_name %in% names(data))
warning("Method 'Individual' requires outcome variable.")
# Logic for individual method omitted (placeholder)
return(rowMeans(pred_mat))
}
}
}
# =========================================================================
# 2. RANDOM EFFECTS MODELS (poisLindRE / RENB)
# =========================================================================
else if (modtype %in% c("poisLindRE", "RENB")) {
beta_pred <- model$beta_pred
# Use delete.response to handle new data without outcome
form_fixed <- delete.response(terms(model$formula))
X <- as.matrix(modelr::model_matrix(data, form_fixed))
# Handle Offset
if (!is.null(model$offset)){
if(length(model$offset) > 1 || !model$offset %in% names(data)) {
X_offset <- rep(0, nrow(data))
} else {
X_offset <- data[[model$offset]]
}
} else {
X_offset <- rep(0, nrow(data))
}
mu <- exp(as.vector(X %*% beta_pred) + X_offset)
return(mu)
}
# =========================================================================
# 3. FIXED PARAMETER MODELS (countreg)
# =========================================================================
else {
# 2. FIXED: Use delete.response to prevent error when 'y' is missing in new
# data
form_fixed <- delete.response(terms(model$formula))
X <- as.matrix(modelr::model_matrix(data, form_fixed))
coefs <- unlist(model$estimate, recursive = TRUE, use.names = FALSE)
N_x <- ncol(X)
beta_pred <- as.vector(coefs[1:N_x])
if (!is.null(model$offset)){
if(length(model$offset) > 1 || !model$offset %in% names(data)) {
X_offset <- rep(0, nrow(data))
} else {
X_offset <- data[[model$offset]]
}
} else {
X_offset <- rep(0, nrow(data))
}
pred_base <- exp(as.vector(X %*% beta_pred) + X_offset)
# --- Adjustments ---
alpha <- NULL
if (!is.null(model$dis_param_formula_1)){
alpha_X <-
as.matrix(modelr::model_matrix(data, model$dis_param_formula_1))
N_alpha <- ncol(alpha_X)
alpha_pars <- coefs[(N_x+1):(N_x+N_alpha)]
alpha <- exp(as.vector(alpha_X %*% alpha_pars))
} else {
params <- get_params(model$family)
if(!is.null(params[[1]])){
alpha <- rep(exp(coefs[N_x+1]), nrow(data))
}
}
mu_adj <- rep(1, nrow(data))
if(!is.null(model$underreport_formula)){
X_und <- as.matrix(modelr::model_matrix(data, model$underreport_formula))
N_und <- ncol(X_und)
total_pars <- length(coefs)
und_pars <- coefs[(total_pars - N_und + 1):total_pars]
und_lin <- as.vector(X_und %*% und_pars)
if(model$underreport_family == "logit"){
mu_adj <- 1/(1+exp(-und_lin))
} else {
mu_adj <- stats::pnorm(und_lin, lower.tail = FALSE)
}
}
predictions <- pred_base * mu_adj
if (modtype == "countreg" && model$family == "PLN" && !is.null(alpha)){
predictions <- predictions * exp((alpha^2)/2)
}
return(predictions)
}
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.