R/predict_flexCountReg.R

Defines functions predict.flexCountReg

Documented in predict.flexCountReg

#' Predictions for flexCountReg models
#'
#' @name predict.flexCountReg
#' @param object a model object estimated using this R package.
#' @param newdata optional dataframe for which to generate predictions.
#' @param ... optional arguments passed to the function. This includes `method`.
#' 
#' @note optional parameter `newdata`: a dataframe that has all of the variables
#'   in the \code{formula} and \code{rpar_formula}.
#' @note optional parameter `method`: Only valid for random parameters models
#'   (`countreg.rp`). Options include \code{Simulated} (default),
#'   \code{Individual}, or \code{Exact}.
#'
#' @description
#' Generates predictions for the expected count (lambda) for observations.
#' 
#' For \strong{countreg.rp} (Random Parameters) models, three methods are
#' available:
#' \itemize{
#'   \item \strong{Simulated}: Uses Halton draws to simulate the random
#'   parameters and averages the outcomes. This is a simulation-based
#'   approximation.
#'   \item \strong{Individual}: Estimates observation-specific coefficients
#'   (conditional on observed outcomes) using Empirical Bayes. Requires the
#'   outcome variable to be present in \code{data}.
#'   \item \strong{Exact}: Uses the analytical Moment Generating Functions
#'   (MGFs) of the random parameter distributions to calculate the exact
#'   expected value. This method is faster and removes simulation error.
#' }
#' 
#' For \strong{countreg}, \strong{poisLindRE}, and \strong{RENB} models, the
#' function calculates the expected value \eqn{\mu = \exp(X\beta)} (with
#' appropriate adjustments for specific families like PLN or underreporting).
#' 
#' @references
#' Wood, J.S., Gayah, V. (2025). Out-of-sample prediction and interpretation for
#' random parameter generalized linear models. \emph{Accident Analysis and
#' Prevention}, 220, 108147.
#' 
#' @returns A numeric vector of predicted expected counts for each observation 
#'  in the provided data. If no data is provided, the predictions for the data 
#'  used in estimating the model are provided.
#'
#' @import randtoolbox stats modelr rlang 
#' @importFrom utils head tail
#' @include corr_haltons.R halton_dists.R  helpers.R
#' 
#' @examples
#' \donttest{
#' # Load data and create a dummy variable
#' data("washington_roads")
#' washington_roads$AADT10kplus <- ifelse(washington_roads$AADT > 10000, 1, 0)
#' 
#' # =========================================================================
#' # 1. Fixed Parameter Model (countreg)
#' # =========================================================================
#' nb2_fixed <- countreg(Total_crashes ~ lnaadt + lnlength + speed50,
#'                       data = washington_roads, 
#'                       family = "NB2")
#' pred_fixed <- predict(nb2_fixed, data = washington_roads)
#' 
#' # =========================================================================
#' # 2. Random Parameters Model (countreg.rp)
#' # =========================================================================
#' rp_nb2 <- countreg.rp(Total_crashes ~ lnaadt + lnlength,
#'                       rpar_formula = ~ -1 + speed50,
#'                       data = washington_roads,
#'                       family = "NB2",
#'                       rpardists = c(speed50 = "n"),
#'                       ndraws = 100)
#' 
#' # Method A: Simulated (Default)
#' pred_sim <- predict(rp_nb2, data = washington_roads, method = "Simulated")
#' 
#' # Method B: Exact (Analytical MGF)
#' pred_exact <- predict(rp_nb2, data = washington_roads, method = "Exact")
#' 
#' # =========================================================================
#' # 3. Random Effects Models (poisLindRE / RENB)
#' # =========================================================================
#' pl_re <- poisLind.re(Total_crashes ~ lnaadt + lnlength,
#'                      data = washington_roads,
#'                      group_var = "ID")
#' pred_pl_re <- predict(pl_re, data = washington_roads)
#' }
#' @export
predict.flexCountReg <- function(object, newdata = NULL, ...){
  
  # Extract optional parameters from '...'
  additional_args  <- list(...)
  
  # 1. FIXED: Handle data argument logic to support 'newdata'
  if (!is.null(newdata)) {
    data <- as.data.frame(newdata)
  } else if (!is.null(additional_args$data)) {
    # Fallback if user explicitly passed 'data=' in ...
    data <- as.data.frame(additional_args$data)
  } else {
    # Default to training data
    data <- object$data 
  }
  
  model <- object$model
  
  if (!is.null(object$modelType)) {
    modtype <- object$modelType
  } else if (!is.null(model$modelType)) {
    modtype <- model$modelType
  } else {
    modtype <- "countreg" 
  }
  
  # Handle method argument
  if (is.null(additional_args$method)) {
    if(modtype == "countreg.rp") {
      method <- "Simulated"
    } else {
      method <- "Standard" 
    }
  } else {
    method <- additional_args$method
  }
  
  # =========================================================================
  # 1. RANDOM PARAMETERS MODELS (countreg.rp)
  # =========================================================================
  if(modtype == "countreg.rp"){
    
    # --- Setup & Matrix Generation ---
    form <- model$form
    family <- if(!is.null(model$family)) model$family else "NB2" 
    rpar_formula <- model$rpar_formula
    rpardists <- model$rpardists
    correlated <- model$correlated
    scrambled <- model$scrambled
    ndraws <- max(model$ndraws, 500)
    
    # Ensure response is removed for prediction
    formula_fixed <- delete.response(terms(model$formula))
    data <- as.data.frame(data)
    
    if (!is.null(model$offset)){
      if(length(model$offset) > 1 || !model$offset %in% names(data)) {
        X_offset <- rep(0, nrow(data))
      } else {
        X_offset <- data[[model$offset]]
      }
    } else {
      X_offset <- rep(0, nrow(data))
    }
    
    # Matrices
    X_Fixed <- as.matrix(modelr::model_matrix(data, formula_fixed))
    X_rand <- as.matrix(modelr::model_matrix(data, rpar_formula))
    
    # INTERCEPT CLEANUP
    if("(Intercept)" %in% colnames(X_rand) && !is.null(rpardists)){
      if(!any(grepl("intercept", names(rpardists), ignore.case=TRUE))){
        X_rand <- X_rand[ , colnames(X_rand) != "(Intercept)", drop=FALSE]
      }
    }
    
    # Distribution Parameter Matrices
    vec_1 <- NULL
    
    # --- Extract Coefficients ---
    coefs <- unlist(model$estimate, recursive = TRUE, use.names = FALSE)
    N_fixed <- ncol(X_Fixed)
    N_rand <- ncol(X_rand)
    
    current_idx <- 0
    
    # Fixed
    fixed_coefs <- coefs[(current_idx + 1):(current_idx + N_fixed)]
    current_idx <- current_idx + N_fixed
    
    # Random Means
    random_coefs_means <- coefs[(current_idx + 1):(current_idx + N_rand)]
    current_idx <- current_idx + N_rand
    
    # Random Vars
    n_var_params <- if(correlated) N_rand * (N_rand + 1) / 2 else N_rand
    rand_var_params <- coefs[(current_idx + 1):(current_idx + n_var_params)]
    current_idx <- current_idx + n_var_params
    
    # Heterogeneity
    if (!is.null(model$het_mean_formula)) {
      N_het_mean <- ncol(model.matrix(model$het_mean_formula, data)) 
      current_idx <- current_idx + N_het_mean
    }
    if (!is.null(model$het_var_formula)) {
      N_het_var <- ncol(model.matrix(model$het_var_formula, data))
      current_idx <- current_idx + N_het_var
    }
    
    # Distribution Params (for PLN correction)
    params <- get_params(family)
    if (!is.null(model$dis_param_formula_1)) {
      X_dis_1 <- 
        as.matrix(modelr::model_matrix(data, model$dis_param_formula_1))
      N_dis_1 <- ncol(X_dis_1)
      p_dis_1 <- coefs[(current_idx + 1):(current_idx + N_dis_1)]
      current_idx <- current_idx + N_dis_1
      vec_1 <- exp(as.vector(X_dis_1 %*% p_dis_1))
    } else if (!is.null(params[[1]])) {
      p_dis_1 <- coefs[(current_idx + 1):(current_idx + 1)]
      current_idx <- current_idx + 1
      vec_1 <- rep(exp(p_dis_1), nrow(data))
    }
    
    # --- Prediction Method Implementation ---
    
    if (method == 'Exact') {
      # Base prediction
      pred_exact <- exp(as.vector(X_Fixed %*% fixed_coefs) + X_offset)
      
      # Correlated Normal
      if (correlated) {
        L <- matrix(0, N_rand, N_rand)
        L[lower.tri(L, diag = TRUE)] <- rand_var_params
        var_rand <- rowSums((X_rand %*% L)^2)
        mu_rand <- as.vector(X_rand %*% random_coefs_means)
        pred_exact <- pred_exact * exp(mu_rand + 0.5 * var_rand)
        
      } else {
        # Independent
        for (i in 1:N_rand) {
          dist_type <- rpardists[i]
          x_col <- X_rand[, i]
          mu <- random_coefs_means[i]
          sigma <- rand_var_params[i]
          
          correction <- rep(1, nrow(data))
          
          if (dist_type == "n") { # Normal
            correction <- exp(x_col * mu + 0.5 * (x_col^2) * (sigma^2))
          } else if (dist_type == "ln") { # Lognormal
            arg_w <- -exp(mu) * x_col * sigma^2
            valid_idx <- arg_w >= -1/exp(1)
            if (any(valid_idx)) {
              x_val <- x_col[valid_idx]
              w_val <- lamW::lambertW0(arg_w[valid_idx])
              num <- x_val * exp(mu - w_val) - (w_val^2)/(2 * sigma^2)
              den <- sigma * sqrt(abs(x_val * exp(mu - w_val) - (1/sigma^2)))
              correction[valid_idx] <- exp(num) / den
            }
            if(any(!valid_idx)) correction[!valid_idx] <- NA
          } else if (dist_type == "t") { # Triangular
            nz <- abs(x_col) > 1e-8
            if (any(nz)) {
              x_nz <- x_col[nz]
              term <- exp(x_nz * (mu - sigma))
              correction[nz] <- 
                (exp(2 * sigma * x_nz) - 2 * exp(sigma * x_nz) + 1) * term / 
                (sigma^2 * x_nz^2)
            }
          } else if (dist_type == "u") { # Uniform
            nz <- abs(x_col) > 1e-8
            if (any(nz)) {
              x_nz <- x_col[nz]
              correction[nz] <- 
                (exp(x_nz * (mu + sigma)) - exp(x_nz * (mu - sigma))) / 
                (2 * sigma * x_nz)
            }
          } else if (dist_type == "g") { # Gamma
            term_base <- 1 - (sigma^2 * x_col) / mu
            if(any(term_base <= 0)) 
              warning("Gamma exact prediction undefined for some observations.")
            correction <- 
              ifelse(term_base > 0, term_base^(-(mu^2)/(sigma^2)), NA)
          }
          pred_exact <- pred_exact * correction
        }
      }
      
      # PLN Correction
      if(family == "PLN" && !is.null(vec_1)){
        pred_exact <- pred_exact * exp((vec_1^2)/2)
      }
      
      return(pred_exact)
      
    } else {
      # Simulated or Individual
      hdraws <- 
        as.matrix(randtoolbox::halton(ndraws, N_rand, mixed = scrambled))
      
      # Re-extract heterogeneity for generation function
      het_mean_coefs <- NULL
      het_var_coefs <- NULL
      X_het_mean <- NULL
      X_het_var <- NULL
      if (!is.null(model$het_mean_formula)) {
        X_het_mean <- model.matrix(model$het_mean_formula, data)
        if("(Intercept)" %in% colnames(X_het_mean)) 
          X_het_mean <- X_het_mean[,-1,drop=FALSE]
      }
      
      draws_info <- generate_random_draws(
        hdraws = hdraws, 
        random_coefs_means = random_coefs_means, 
        rand_var_params = rand_var_params, 
        rpardists = rpardists, 
        rpar = colnames(X_rand), 
        X_rand = X_rand,
        het_mean_coefs = het_mean_coefs, 
        X_het_mean = X_het_mean,
        het_var_coefs = het_var_coefs, 
        X_het_var = X_het_var,
        correlated = correlated)
      
      rpar_mat <- exp(draws_info$xb_rand_mat)
      mu_fixed <- exp(as.vector(X_Fixed %*% fixed_coefs) + X_offset)
      pred_mat <- sweep(rpar_mat, 1, mu_fixed, "*")
      
      if(family == "PLN" && !is.null(vec_1)){
        adj_factor <- exp((vec_1^2)/2)
        pred_mat <- sweep(pred_mat, 1, adj_factor, "*")
      }
      
      if (method == 'Simulated') {
        return(rowMeans(pred_mat))
      } else if (method == 'Individual') {
        y_name <- all.vars(model$formula)[1]
        if(!y_name %in% names(data)) 
          warning("Method 'Individual' requires outcome variable.")
        # Logic for individual method omitted (placeholder)
        return(rowMeans(pred_mat)) 
      }
    }
  }
  
  # =========================================================================
  # 2. RANDOM EFFECTS MODELS (poisLindRE / RENB)
  # =========================================================================
  else if (modtype %in% c("poisLindRE", "RENB")) {
    
    beta_pred <- model$beta_pred
    
    # Use delete.response to handle new data without outcome
    form_fixed <- delete.response(terms(model$formula))
    X <- as.matrix(modelr::model_matrix(data, form_fixed))
    
    # Handle Offset
    if (!is.null(model$offset)){
      if(length(model$offset) > 1 || !model$offset %in% names(data)) {
        X_offset <- rep(0, nrow(data))
      } else {
        X_offset <- data[[model$offset]]
      }
    } else {
      X_offset <- rep(0, nrow(data))
    }
    
    mu <- exp(as.vector(X %*% beta_pred) + X_offset)
    return(mu)
  }
  
  # =========================================================================
  # 3. FIXED PARAMETER MODELS (countreg)
  # =========================================================================
  else {
    
    # 2. FIXED: Use delete.response to prevent error when 'y' is missing in new
    # data
    form_fixed <- delete.response(terms(model$formula))
    X <- as.matrix(modelr::model_matrix(data, form_fixed))
    
    coefs <- unlist(model$estimate, recursive = TRUE, use.names = FALSE)
    N_x <- ncol(X)
    
    beta_pred <- as.vector(coefs[1:N_x])
    
    if (!is.null(model$offset)){
      if(length(model$offset) > 1 || !model$offset %in% names(data)) {
        X_offset <- rep(0, nrow(data))
      } else {
        X_offset <- data[[model$offset]]
      }
    } else {
      X_offset <- rep(0, nrow(data))
    }
    
    pred_base <- exp(as.vector(X %*% beta_pred) + X_offset)
    
    # --- Adjustments ---
    
    alpha <- NULL
    if (!is.null(model$dis_param_formula_1)){
      alpha_X <- 
        as.matrix(modelr::model_matrix(data, model$dis_param_formula_1))
      N_alpha <- ncol(alpha_X)
      alpha_pars <- coefs[(N_x+1):(N_x+N_alpha)]
      alpha <- exp(as.vector(alpha_X %*% alpha_pars))
    } else {
      params <- get_params(model$family)
      if(!is.null(params[[1]])){
        alpha <- rep(exp(coefs[N_x+1]), nrow(data))
      }
    }
    
    mu_adj <- rep(1, nrow(data))
    if(!is.null(model$underreport_formula)){
      X_und <- as.matrix(modelr::model_matrix(data, model$underreport_formula))
      N_und <- ncol(X_und)
      total_pars <- length(coefs)
      und_pars <- coefs[(total_pars - N_und + 1):total_pars]
      und_lin <- as.vector(X_und %*% und_pars)
      
      if(model$underreport_family == "logit"){
        mu_adj <- 1/(1+exp(-und_lin))
      } else {
        mu_adj <- stats::pnorm(und_lin, lower.tail = FALSE)
      }
    }
    
    predictions <- pred_base * mu_adj
    
    if (modtype == "countreg" && model$family == "PLN" && !is.null(alpha)){
      predictions <- predictions * exp((alpha^2)/2)
    }
    
    return(predictions)
  }
}

Try the flexCountReg package in your browser

Any scripts or data that you put into this service are public.

flexCountReg documentation built on Jan. 20, 2026, 1:06 a.m.