R/cp-methods.R

Defines functions summary.cp_model_nls

Documented in summary.cp_model_nls

# Names referenced through non-standard evaluation (e.g. inside
# ggplot2::aes()) that R CMD check would otherwise report as "no visible
# binding for global variable". "y" and ".data" are used in the aes() calls of
# plot.cp_model_nls(); the remaining names are presumably used by NSE code in
# other files of the package (not visible here).
utils::globalVariables(c(
  "target",
  "y_mean",
  "y_sd",
  "unsys",
  "fit",
  "y_pred",
  "y",
  ".data"
))

#-------------------------------------------------------------------------------
# Summarize Nonlinear Model (cp_model_nls)
#' Summarize a Cross-Price Demand Model (Nonlinear)
#'
#' Collects coefficient estimates, confidence intervals, fit statistics and
#' natural-scale parameter values from a nonlinear cross-price model.
#'
#' @param object A cross-price model object from fit_cp_nls with return_all=TRUE.
#' @param inv_fun Optional function to inverse-transform predictions (e.g., ll4_inv).
#'   Default is `identity`. Accepted for API consistency; the summary itself
#'   does not currently use it.
#' @param inverse_fun `r lifecycle::badge("deprecated")` Use `inv_fun` instead.
#' @param ... Additional arguments (unused).
#' @return A list of class `summary.cp_model_nls` containing model summary
#'   information. If `object$model` is `NULL` (fitting failed), the list has
#'   empty coefficients, `NA` fit statistics and an `error_message` element.
#'   `r_squared` is `NA` when it cannot be computed (e.g. non-positive `y`
#'   under the exponential equation, or a data/fitted length mismatch).
#' @importFrom nlstools confint2
#' @importFrom stats residuals AIC BIC
#' @export
summary.cp_model_nls <- function(object, inv_fun = identity, inverse_fun = deprecated(), ...) {
  # Handle deprecated inverse_fun argument
  if (lifecycle::is_present(inverse_fun)) {
    lifecycle::deprecate_warn(
      "0.2.0",
      "summary.cp_model_nls(inverse_fun)",
      "summary.cp_model_nls(inv_fun)"
    )
    inv_fun <- inverse_fun
  }
  model <- object$model
  equation <- object$equation
  method <- object$method
  equation_text <- .cp_nls_equation_text(equation)
  model_type_info <- .cp_nls_method_text(method)

  # No fitted model: return an empty summary that print() flags as an error.
  if (is.null(model)) {
    result <- list(
      call = NA,
      coefficients = data.frame(
        Estimate = numeric(0),
        "Std. Error" = numeric(0),
        "t value" = numeric(0),
        "Pr(>|t|)" = numeric(0),
        check.names = FALSE
      ),
      conf_int = NULL,
      equation = equation,
      equation_text = equation_text,
      method = method,
      method_description = model_type_info,
      r_squared = NA,
      aic = NA,
      bic = NA,
      transform = "none",
      residuals = numeric(0),
      derived_metrics = NULL,
      data = object$data,
      error_message = "Model fitting failed - no valid model available"
    )

    class(result) <- "summary.cp_model_nls"
    return(result)
  }

  # Use provided data or attempt to extract from the model environment.
  data <- if (!is.null(object$data)) {
    object$data
  } else {
    model_env <- if (inherits(model, "nls")) {
      model$m$getEnv()
    } else {
      environment(model)
    }
    data.frame(x = model_env$x, y = model_env$y)
  }

  # Coefficient table; models other than nls/nlsr get estimates only.
  if (inherits(model, c("nls", "nlsLM"))) {
    coef_summary <- summary(model)$coefficients
  } else if (inherits(model, "nlsr")) {
    coef_summary <- summary(model)$coefficients
    colnames(coef_summary) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
  } else {
    coef_summary <- data.frame(
      Estimate = coef(model),
      `Std. Error` = NA,
      `t value` = NA,
      `Pr(>|t|)` = NA
    )
  }

  # R-squared on the fitted response scale (log10(y) for the exponential
  # equation, y otherwise). The computation lives in a helper because a
  # return() written directly inside the tryCatch() expression would exit
  # summary() itself rather than just the R-squared calculation.
  r_squared <- tryCatch(
    .cp_nls_r_squared(data$y, fitted(model), equation),
    error = function(e) NA_real_
  )
  transform_info <- if (identical(equation, "exponential")) "log10(y)" else "none"

  # Fit statistics
  aic <- tryCatch(AIC(model), error = function(e) NA)
  bic <- tryCatch(BIC(model), error = function(e) NA)

  conf_int <- tryCatch(
    {
      if (requireNamespace("nlstools", quietly = TRUE)) {
        ci <- nlstools::confint2(model)
        as.data.frame(ci)
      } else {
        NULL
      }
    },
    error = function(e) NULL
  )

  if (inherits(object, "cp_model_nls") && !is.null(data)) {
    coefs <- coef(model)
    # Extract log10-parameterized coefficients and compute natural-scale values
    log10_qalone <- coefs["log10_qalone"]
    I_param <- coefs["I"]
    log10_beta <- coefs["log10_beta"]

    se_log10_qalone <- .cp_nls_std_error(coef_summary, "log10_qalone")
    se_I <- .cp_nls_std_error(coef_summary, "I")
    se_log10_beta <- .cp_nls_std_error(coef_summary, "log10_beta")

    # Delta method for p -> 10^p: SE(10^p) ~= ln(10) * 10^p * SE(p).
    ln10 <- log(10)
    qalone_nat <- 10^log10_qalone
    beta_nat <- 10^log10_beta
    qalone_se <- if (is.finite(se_log10_qalone)) ln10 * qalone_nat * se_log10_qalone else NA_real_
    beta_se <- if (is.finite(se_log10_beta)) ln10 * beta_nat * se_log10_beta else NA_real_

    derived_metrics <- list(
      log10_qalone = log10_qalone,
      I = I_param,
      log10_beta = log10_beta,
      log10_qalone_se = se_log10_qalone,
      I_se = se_I,
      log10_beta_se = se_log10_beta,
      # Natural-scale values (back-transformed)
      qalone = qalone_nat,
      beta = beta_nat,
      qalone_se = qalone_se,
      beta_se = beta_se
    )
  } else {
    derived_metrics <- NULL
  }

  result <- list(
    call = if (is.null(model$call)) NA else model$call,
    coefficients = coef_summary,
    conf_int = conf_int,
    equation = equation,
    equation_text = equation_text,
    method = method,
    method_description = model_type_info,
    r_squared = r_squared,
    aic = aic,
    bic = bic,
    transform = transform_info,
    residuals = residuals(model),
    derived_metrics = derived_metrics,
    data = data
  )

  class(result) <- "summary.cp_model_nls"
  return(result)
}

# Human-readable functional form for an equation type; unknown, NULL or
# non-scalar inputs map to "Unknown equation type" instead of erroring.
.cp_nls_equation_text <- function(equation) {
  if (!is.character(equation) || length(equation) != 1L || is.na(equation)) {
    return("Unknown equation type")
  }
  switch(
    equation,
    exponential = "log10(y) ~ log10_qalone + I * exp(-(10^log10_beta) * x)",
    exponentiated = "y ~ (10^log10_qalone) * 10^(I * exp(-(10^log10_beta) * x))",
    additive = "y ~ (10^log10_qalone) + I * exp(-(10^log10_beta) * x)",
    "Unknown equation type"
  )
}

# Human-readable description of the fitting method; unknown, NULL or
# non-scalar inputs map to "Unknown model type".
.cp_nls_method_text <- function(method) {
  if (!is.character(method) || length(method) != 1L || is.na(method)) {
    return("Unknown model type")
  }
  switch(
    method,
    nls_multstart = "Multiple starting values optimization with nls.multstart",
    nlsLM = "Levenberg-Marquardt nonlinear least-squares algorithm",
    wrapnlsr = "Nonlinear least squares with 'nlsr' package",
    "Unknown model type"
  )
}

# R-squared of `y_hat` against `y` on the model's response scale. For the
# exponential equation the response is log10(y), so non-positive y gives NA.
# Mismatched lengths (e.g. rows dropped before fitting) also give NA rather
# than a silently recycled value.
.cp_nls_r_squared <- function(y, y_hat, equation) {
  if (identical(equation, "exponential")) {
    if (any(y <= 0, na.rm = TRUE)) {
      return(NA_real_)
    }
    y <- log10(y)
  }
  if (length(y) != length(y_hat)) {
    return(NA_real_)
  }
  1 - (sum((y - y_hat)^2) / sum((y - mean(y))^2))
}

# Standard error for `term` from a coefficient table, or NA when the table,
# its "Std. Error" column, or the requested row is absent.
.cp_nls_std_error <- function(coef_summary, term) {
  if (!(is.matrix(coef_summary) || is.data.frame(coef_summary))) {
    return(NA_real_)
  }
  if (!("Std. Error" %in% colnames(coef_summary)) ||
    !(term %in% rownames(coef_summary))) {
    return(NA_real_)
  }
  suppressWarnings(as.numeric(coef_summary[term, "Std. Error"]))
}

#' Print method for summary.cp_model_nls objects
#'
#' Prints the model specification, coefficient table, confidence intervals,
#' fit statistics and natural-scale parameter interpretation. When the summary
#' carries an `error_message` (no model was fit), only the specification and
#' the error are printed.
#'
#' @param x A `summary.*` object
#' @param ... Unused
#' @return Invisibly returns the input object \code{x}.
#' @export
print.summary.cp_model_nls <- function(x, ...) {
  # TRUE only for a single, non-missing value; keeps NULL/empty statistics
  # from breaking the if() checks below.
  has_value <- function(value) length(value) == 1L && !is.na(value)

  cat("Cross-Price Demand Model Summary\n")
  cat("================================\n\n")

  # Check for error message and display it prominently if present
  if (!is.null(x$error_message)) {
    cat("ERROR: ", x$error_message, "\n\n")
  }

  cat("Model Specification:\n")
  cat("Equation type:", x$equation, "\n")
  cat("Functional form:", x$equation_text, "\n")
  cat("Fitting method:", x$method, "\n")
  cat("Method details:", x$method_description, "\n")

  if (!is.null(x$error_message)) {
    cat("\nNote: No valid model was fit. Parameter estimates unavailable.\n")
    # Return the object (invisibly) as documented, not NULL.
    return(invisible(x))
  }

  if (!is.null(x$transform) && !identical(x$transform, "none")) {
    cat("Transformation applied:", x$transform, "\n")
  }
  cat("\nCoefficients:\n")
  printCoefmat(x$coefficients)

  if (!is.null(x$conf_int)) {
    cat("\nConfidence Intervals:\n")
    print(x$conf_int, digits = 4)
  }
  cat("\nFit Statistics:\n")
  cat("R-squared:", format(x$r_squared, digits = 4), "\n")
  if (has_value(x$aic)) {
    cat("AIC:", format(x$aic, digits = 4), "\n")
  }
  if (has_value(x$bic)) {
    cat("BIC:", format(x$bic, digits = 4), "\n")
  }

  if (!is.null(x$derived_metrics)) {
    cat("\nParameter Interpretation (natural scale):\n")
    cat(
      "qalone (Q_alone):",
      format(x$derived_metrics$qalone, digits = 4),
      " - consumption at zero alternative price\n"
    )
    cat(
      "I:",
      format(x$derived_metrics$I, digits = 4),
      " - interaction parameter (substitution direction)\n"
    )
    cat(
      "beta:",
      format(x$derived_metrics$beta, digits = 4),
      " - sensitivity parameter (sensitivity of relation to price)\n"
    )
    cat("\nOptimizer parameters (log10 scale):\n")
    cat(
      "log10_qalone:",
      format(x$derived_metrics$log10_qalone, digits = 4),
      "\n"
    )
    cat(
      "log10_beta:",
      format(x$derived_metrics$log10_beta, digits = 4),
      "\n"
    )
  }
  invisible(x)
}


#' Plot a Cross-Price Demand Model (Nonlinear)
#'
#' Draws the fitted prediction curve over the observed data points. When the
#' model object has no fitted model, only the data points are drawn (with a
#' warning). If `data` has a `target` column, only rows with `target == "alt"`
#' are used; rows with a missing `target` are dropped.
#'
#' @param x A cross-price model object from fit_cp_nls with return_all=TRUE.
#' @param data Optional data frame with x and y; if NULL, uses object$data.
#' @param inv_fun Optional function to inverse-transform predictions. Default is `identity`.
#' @param n_points Number of points used for prediction curve.
#' @param title Optional plot title.
#' @param xlab X-axis label.
#' @param ylab Y-axis label.
#' @param x_trans Transformation for x-axis: "identity", "log10", or "pseudo_log".
#' @param y_trans Transformation for y-axis: "identity", "log10", or "pseudo_log".
#' @param point_size Size of data points.
#' @param inverse_fun `r lifecycle::badge("deprecated")` Use `inv_fun` instead.
#' @param ... Additional arguments (passed to predict).
#' @return A ggplot2 object.
#' @importFrom scales log10_trans pseudo_log_trans identity_trans
#' @export
plot.cp_model_nls <- function(
  x,
  data = NULL,
  inv_fun = identity,
  n_points = 100,
  title = NULL,
  xlab = "Price",
  ylab = "Consumption",
  x_trans = "identity",
  y_trans = "identity",
  point_size = 3,
  inverse_fun = deprecated(),
  ...
) {
  # Handle deprecated inverse_fun argument
  if (lifecycle::is_present(inverse_fun)) {
    lifecycle::deprecate_warn(
      "0.2.0",
      "plot.cp_model_nls(inverse_fun)",
      "plot.cp_model_nls(inv_fun)"
    )
    inv_fun <- inverse_fun
  }
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("Package 'ggplot2' is required.")
  }
  if (!requireNamespace("scales", quietly = TRUE)) {
    stop("Package 'scales' is required.")
  }

  # Use provided data or fallback
  if (is.null(data)) {
    if (!is.null(x$data)) {
      data <- x$data
    } else {
      stop(
        "No data provided and no data found in model object. Please provide data."
      )
    }
  }

  # Defensive: ensure data has x and y
  if (!all(c("x", "y") %in% names(data))) {
    stop("Data must contain columns 'x' and 'y'")
  }

  # If model is NULL, plot only the data points and warn
  if (is.null(x$model)) {
    warning("Model fitting failed; plotting data points only.")
    p <- ggplot2::ggplot(data, ggplot2::aes(x = x, y = y)) +
      ggplot2::geom_point(
        shape = 21,
        size = point_size,
        fill = "white"
      ) +
      ggplot2::labs(
        x = xlab,
        y = ylab,
        title = if (is.null(title)) "No model fit: data only" else title
      ) +
      ggplot2::theme_bw()
    return(p)
  }

  # %in% (unlike ==) never yields NA, so missing targets cannot turn into
  # all-NA rows in the subset.
  if ("target" %in% names(data)) {
    data <- data[data$target %in% "alt", , drop = FALSE]
    if (nrow(data) == 0) {
      stop("No data with target = 'alt' found in provided data")
    }
  }

  allowed_trans <- c("identity", "log10", "pseudo_log")
  x_trans <- match.arg(x_trans, allowed_trans)
  y_trans <- match.arg(y_trans, allowed_trans)

  get_trans <- function(trans_name) {
    switch(
      trans_name,
      log10 = scales::log10_trans(),
      pseudo_log = scales::pseudo_log_trans(),
      identity = scales::identity_trans()
    )
  }
  x_trans_obj <- get_trans(x_trans)
  y_trans_obj <- get_trans(y_trans)

  if (x_trans == "log10" && any(data$x <= 0, na.rm = TRUE)) {
    # Keep only finite-positive x; an NA in the mask would create NA rows.
    data <- data[!is.na(data$x) & data$x > 0, , drop = FALSE]
    warning("Filtered out non-positive x values for log10 transformation")
    if (nrow(data) == 0) {
      stop("No positive x values left after filtering for log10 transformation")
    }
  }

  x_range <- range(data$x, na.rm = TRUE)
  if (!all(is.finite(x_range))) {
    stop("Cannot determine a valid x range from the provided data")
  }

  # Log-spaced prediction grid for log10 axes so the curve is evenly sampled.
  if (x_trans == "log10") {
    min_x <- max(0.001, x_range[1])
    pred_x <- exp(seq(log(min_x), log(x_range[2]), length.out = n_points))
  } else {
    pred_x <- seq(x_range[1], x_range[2], length.out = n_points)
  }

  new_x <- data.frame(x = pred_x)
  preds <- predict(x, newdata = new_x, inv_fun = inv_fun, ...)

  y_col <- if (
    !identical(inv_fun, identity) && "y_pred_untransformed" %in% names(preds)
  ) {
    "y_pred_untransformed"
  } else {
    "y_pred"
  }

  p <- ggplot2::ggplot() +
    ggplot2::geom_line(
      data = preds,
      ggplot2::aes(x = x, y = .data[[y_col]]),
      color = "blue",
      linewidth = 1
    ) +
    ggplot2::geom_point(
      data = data,
      ggplot2::aes(x = x, y = y),
      shape = 21,
      size = point_size,
      fill = "white"
    ) +
    ggplot2::scale_x_continuous(trans = x_trans_obj) +
    ggplot2::scale_y_continuous(trans = y_trans_obj)

  if (x_trans == "log10") {
    p <- p + ggplot2::annotation_logticks(sides = "b")
  }
  if (y_trans == "log10") {
    p <- p + ggplot2::annotation_logticks(sides = "l")
  }

  p <- p +
    ggplot2::labs(x = xlab, y = ylab, title = title) +
    ggplot2::theme_bw()
  return(p)
}


#' Predict from a Cross-Price Demand Model (Nonlinear)
#'
#' Evaluates the fitted cross-price equation at the `x` values in `newdata`,
#' using the log10-parameterized coefficients `log10_qalone`, `I` and
#' `log10_beta`.
#'
#' @param object A cross-price model object from fit_cp_nls with return_all=TRUE.
#' @param newdata A data frame containing an 'x' column.
#' @param inv_fun Optional inverse transformation function. Default is `identity`.
#'   When supplied, its result on `y_pred` is added as `y_pred_untransformed`.
#' @param inverse_fun `r lifecycle::badge("deprecated")` Use `inv_fun` instead.
#' @param ... Additional arguments.
#' @return A tibble with `x` and predicted `y_pred` (natural scale). For the
#'   exponential equation it also has `y_pred_log10`.
#' @export
predict.cp_model_nls <- function(
  object,
  newdata = NULL,
  inv_fun = identity,
  inverse_fun = deprecated(),
  ...
) {
  # Handle deprecated inverse_fun argument
  if (lifecycle::is_present(inverse_fun)) {
    lifecycle::deprecate_warn(
      "0.2.0",
      "predict.cp_model_nls(inverse_fun)",
      "predict.cp_model_nls(inv_fun)"
    )
    inv_fun <- inverse_fun
  }
  if (!inherits(object, "cp_model_nls")) {
    stop("Object must be of class 'cp_model_nls'")
  }
  if (is.null(newdata)) {
    stop("'newdata' must be provided as a data frame with an 'x' column")
  }
  if (!("x" %in% names(newdata))) {
    stop("'newdata' must contain a column named 'x'")
  }

  equation <- object$equation
  if (!is.character(equation) || length(equation) != 1L || is.na(equation)) {
    stop("Unsupported equation type: ", paste(format(equation), collapse = ", "))
  }
  model <- object$model
  coefs <- tryCatch(coef(model), error = function(e) {
    stop("Could not extract coefficients from model: ", e$message)
  })

  if (!all(c("log10_qalone", "I", "log10_beta") %in% names(coefs))) {
    stop("Missing required coefficients: log10_qalone, I, or log10_beta")
  }

  # Extract log10-parameterized coefficients. unname() stops the coefficient
  # names from leaking into the predictions (R keeps the names of a length-1
  # operand when the result also has length 1).
  log10_qalone <- unname(coefs["log10_qalone"])
  I_param <- unname(coefs["I"])
  log10_beta <- unname(coefs["log10_beta"])

  # Compute natural-scale parameters
  qalone <- 10^log10_qalone
  beta <- 10^log10_beta

  x_vals <- newdata$x

  # Predictions per EQUATIONS_CONTRACT.md (log10 parameterization):
  # - Exponentiated: y = (10^log10_qalone) * 10^(I * exp(-(10^log10_beta) * x))
  # - Exponential (on log10 scale): log10(y) = log10_qalone + I * exp(-(10^log10_beta) * x)
  # - Additive: y = (10^log10_qalone) + I * exp(-(10^log10_beta) * x)
  y_pred_internal <- switch(
    equation,
    exponentiated = qalone * 10^(I_param * exp(-beta * x_vals)),
    exponential = log10_qalone + I_param * exp(-beta * x_vals), # Returns log10(y)
    additive = qalone + I_param * exp(-beta * x_vals),
    stop("Unsupported equation type: ", equation)
  )

  if (equation == "exponential") {
    log10_y_pred <- y_pred_internal
    y_pred <- 10^log10_y_pred
    result <- tibble::tibble(x = x_vals, y_pred = y_pred, y_pred_log10 = log10_y_pred)
  } else {
    result <- tibble::tibble(x = x_vals, y_pred = y_pred_internal)
  }

  # Best effort: a failing inverse transform only warns and leaves the column out.
  if (!identical(inv_fun, identity)) {
    tryCatch(
      {
        result$y_pred_untransformed <- inv_fun(result$y_pred)
      },
      error = function(e) {
        warning("Failed to apply inverse transformation: ", e$message)
      }
    )
  }
  return(result)
}

#-------------------------------------------------------------------------------
# Predict Methods for Linear Models
#' Predict method for cp_model_lm objects.
#'
#' Delegates to `predict()` on the stored linear model and pairs each
#' prediction with its `x` value.
#'
#' @param object A cp_model_lm object.
#' @param newdata Data frame containing new x values.
#' @param ... Additional arguments.
#' @return Data frame with predictions.
#' @export
predict.cp_model_lm <- function(object, newdata = NULL, ...) {
  # newdata is mandatory: predictions are always reported against its x.
  if (!is.null(newdata)) {
    fitted_values <- predict(object$model, newdata = newdata, ...)
    return(tibble::tibble(x = newdata$x, y_pred = fitted_values))
  }
  stop("newdata must be provided")
}

#' Predict from a Mixed-Effects Cross-Price Demand Model
#'
#' Generates predictions from a mixed-effects cross-price demand model (of class
#' \code{cp_model_lmer}). The function supports two modes:
#'
#' \describe{
#'   \item{\code{"fixed"}}{Returns predictions based solely on the fixed-effects component
#'       (using \code{re.form = NA}).}
#'   \item{\code{"random"}}{Returns subject-specific predictions (fixed plus random effects)
#'       (using \code{re.form = NULL}).}
#' }
#'
#' @param object A \code{cp_model_lmer} object (as returned by \code{fit_cp_linear(type = "mixed", ...)}).
#' @param newdata A data frame containing at least an \code{x} column. For \code{pred_type = "random"},
#'   an \code{id} column is required. If absent, the function extracts unique ids from \code{object$data}
#'   and repeats every row of \code{newdata} once per id (all columns are kept). If no ids are
#'   available, a default id of 1 is used (with a warning).
#' @param pred_type Character string specifying the type of prediction: either \code{"fixed"} (population-level)
#'   or \code{"random"} (subject-specific). The default is \code{"fixed"}.
#' @param ... Additional arguments passed to the underlying \code{predict} function.
#'
#' @return A data frame containing all columns of \code{newdata} (plus an \code{id} column when
#'   ids were expanded or defaulted) and a column \code{y_pred} with the corresponding predictions.
#'
#' @examples
#' \donttest{
#' data(etm)
#' fit <- fit_cp_linear(etm, type = "mixed")
#' new_prices <- data.frame(x = c(2, 4, 8, 16, 32, 64))
#' predict(fit, newdata = new_prices, pred_type = "fixed")
#' }
#'
#' @export
predict.cp_model_lmer <- function(
  object,
  newdata = NULL,
  pred_type = c("fixed", "random"),
  ...
) {
  if (is.null(newdata)) {
    stop("newdata must be provided")
  }
  if (!("x" %in% names(newdata))) {
    stop("newdata must contain a column named 'x'")
  }

  pred_type <- match.arg(pred_type)

  # If subject-specific predictions are requested, ensure newdata has an 'id' column.
  if (pred_type == "random" && !("id" %in% names(newdata))) {
    if (!is.null(object$data) && "id" %in% names(object$data)) {
      ids <- unique(object$data$id)
      # Repeat every row of newdata once per id (row order varies fastest,
      # as with expand.grid), keeping all of newdata's columns and the ids'
      # original type.
      base <- as.data.frame(newdata)
      n_rows <- nrow(base)
      expanded <- base[rep(seq_len(n_rows), times = length(ids)), , drop = FALSE]
      expanded$id <- rep(ids, each = n_rows)
      rownames(expanded) <- NULL
      newdata <- expanded
    } else {
      warning(
        "newdata does not contain 'id' and object$data does not provide ids; assigning default id = 1"
      )
      newdata$id <- 1
    }
  }

  # re.form = NA drops the random effects; re.form = NULL includes them.
  re_form <- if (pred_type == "fixed") NA else NULL
  preds <- predict(object$model, newdata = newdata, re.form = re_form, ...)

  out <- as.data.frame(newdata)
  out$y_pred <- preds
  return(out)
}


#-------------------------------------------------------------------------------
# Summary Methods for Linear Models
#' Summary method for cp_model_lm objects.
#'
#' Wraps `summary()` of the stored linear model, keeping the coefficient
#' table, R-squared values and residuals alongside the model's formula and
#' fitting method.
#'
#' @param object A cp_model_lm object.
#' @param ... Additional arguments.
#' @return A list summarizing the linear model.
#' @export
summary.cp_model_lm <- function(object, ...) {
  fit <- object$model
  fit_summary <- summary(fit)
  structure(
    list(
      call = fit$call,
      coefficients = fit_summary$coefficients,
      formula = object$formula,
      method = object$method,
      r_squared = fit_summary$r.squared,
      adj_r_squared = fit_summary$adj.r.squared,
      residuals = residuals(fit)
    ),
    class = "summary.cp_model_lm"
  )
}

#' Summary method for cp_model_lmer objects.
#'
#' Collects fixed effects, a relabelled random-effects table, Nakagawa R2
#' values (when the `performance` package is available) and AIC/BIC.
#'
#' @param object A cp_model_lmer object.
#' @param ... Additional arguments.
#' @return A list summarizing the mixed-effects model. `random_effects` has
#'   columns `Group`, `Term`, `Term2`, `Variance` and `Std.Dev`; for
#'   covariance rows (non-missing `Term2`) the last two hold the covariance
#'   and the correlation. `r2_marginal`/`r2_conditional` are `NA` when they
#'   cannot be computed.
#' @importFrom lme4 VarCorr
#' @importFrom performance r2_nakagawa
#' @importFrom stats AIC BIC
#' @export
summary.cp_model_lmer <- function(object, ...) {
  if (!requireNamespace("lme4", quietly = TRUE)) {
    stop("Package 'lme4' is required")
  }
  model_summary <- summary(object$model)
  fixed_effects <- model_summary$coefficients

  # as.data.frame(VarCorr()) has columns grp, var1, var2, vcov, sdcor. Map
  # them by name so each label matches its content (a positional rename
  # previously labelled var2 as "Variance" and vcov as "Std.Dev").
  vc <- as.data.frame(lme4::VarCorr(object$model))
  rand_effects <- data.frame(
    Group = vc$grp,
    Term = vc$var1,
    Term2 = vc$var2,
    Variance = vc$vcov,
    Std.Dev = vc$sdcor,
    stringsAsFactors = FALSE
  )

  r2 <- tryCatch(
    {
      if (requireNamespace("performance", quietly = TRUE)) {
        performance::r2_nakagawa(object$model)
      } else {
        NULL
      }
    },
    error = function(e) NULL
  )
  # Fall back to NA unless r2 is a list holding a single value for `name`
  # (guards against NULL or atomic results from r2_nakagawa()).
  pick_r2 <- function(name) {
    value <- if (is.list(r2)) r2[[name]] else NULL
    if (length(value) == 1L) value else NA
  }

  aic <- AIC(object$model)
  bic <- BIC(object$model)

  result <- list(
    call = object$model@call,
    coefficients = fixed_effects,
    random_effects = rand_effects,
    formula = object$formula,
    method = object$method,
    r2_marginal = pick_r2("R2_marginal"),
    r2_conditional = pick_r2("R2_conditional"),
    AIC = aic,
    BIC = bic,
    residuals = residuals(object$model)
  )
  class(result) <- "summary.cp_model_lmer"
  return(result)
}

#-------------------------------------------------------------------------------
# Broom Methods
#' Convert a cross-price model to a tidy data frame of coefficients
#'
#' This function extracts model coefficients from a cross-price demand model
#' into a tidy data frame format, following the conventions of the broom package.
#' It handles cases where model fitting failed gracefully, returning an empty
#' data frame with the expected structure. If the summary has no usable
#' four-column coefficient table, the estimates from the underlying model are
#' returned with `NA` inference columns.
#'
#' @param x A model object from fit_cp_nls or fit_cp_linear
#' @param ... Additional arguments (unused)
#'
#' @return A data frame with one row per coefficient, containing columns:
#'   \item{term}{The name of the model parameter}
#'   \item{estimate}{The estimated coefficient value}
#'   \item{std.error}{The standard error of the coefficient}
#'   \item{statistic}{The t-statistic for the coefficient}
#'   \item{p.value}{The p-value for the coefficient}
#'
#' @examples
#' \donttest{
#' data(etm)
#' fit <- fit_cp_nls(etm, equation = "exponentiated")
#' tidy(fit)
#' }
#'
#' @importFrom tibble rownames_to_column
#' @export
tidy.cp_model_nls <- function(x, ...) {
  if (is.null(x$model)) {
    return(tibble::tibble(
      term = character(0),
      estimate = numeric(0),
      std.error = numeric(0),
      statistic = numeric(0),
      p.value = numeric(0)
    ))
  }

  summ <- summary(x)
  coef_table <- summ$coefficients
  # Only a four-column table (estimate, SE, statistic, p) can be renamed.
  if (!is.null(coef_table) && NCOL(coef_table) == 4L) {
    stats::setNames(
      as.data.frame(coef_table),
      c("estimate", "std.error", "statistic", "p.value")
    ) |>
      tibble::rownames_to_column("term") |>
      tibble::as_tibble()
  } else {
    # Fallback: estimates come from the fitted model, not the wrapper object.
    coeffs <- stats::coef(x$model)
    tibble::tibble(
      term = names(coeffs),
      estimate = unname(coeffs),
      std.error = NA_real_,
      statistic = NA_real_,
      p.value = NA_real_
    )
  }
}

#' Get model summaries from a cross-price model
#'
#' Condenses the goodness-of-fit statistics and model descriptors reported by
#' `summary()` into a single-row data frame, in the style of the broom
#' package.
#'
#' @param x A model object from fit_cp_nls or fit_cp_linear
#' @param ... Additional arguments (unused)
#'
#' @return A one-row data frame with model summary statistics:
#'   \item{r.squared}{R-squared value indicating model fit}
#'   \item{aic}{Akaike Information Criterion}
#'   \item{bic}{Bayesian Information Criterion}
#'   \item{equation}{The equation type used in the model}
#'   \item{method}{The method used to fit the model}
#'   \item{transform}{The transformation applied to the data, if any}
#'
#' @examples
#' \donttest{
#' data(etm)
#' fit <- fit_cp_nls(etm, equation = "exponentiated")
#' glance(fit)
#' }
#'
#' @export
glance.cp_model_nls <- function(x, ...) {
  fit_summary <- summary(x)
  # Goodness-of-fit numbers first, then descriptors of how the model was fit.
  fit_stats <- list(
    r.squared = fit_summary$r_squared,
    aic = fit_summary$aic,
    bic = fit_summary$bic
  )
  model_info <- list(
    equation = fit_summary$equation,
    method = fit_summary$method,
    transform = fit_summary$transform
  )
  do.call(tibble::tibble, c(fit_stats, model_info))
}

#' Confidence Intervals for Cross-Price NLS Model Parameters
#'
#' Computes confidence intervals for parameters from a nonlinear cross-price
#' demand model using `nlstools::confint2()`.
#'
#' @param object A `cp_model_nls` object from [fit_cp_nls()].
#' @param parm Character vector of parameter names (or numeric indices into
#'   the coefficient vector) to compute CIs for. Default includes all
#'   parameters.
#' @param level Confidence level (default 0.95).
#' @param method Character. Method for computing intervals passed to
#'   `nlstools::confint2()`:
#'   - `"asymptotic"` (default): Wald-type asymptotic intervals
#'   - `"profile"`: Profile-t confidence intervals
#' @param ... Additional arguments passed to `nlstools::confint2()`.
#'
#' @return A tibble with columns: `term`, `estimate`, `conf.low`, `conf.high`,
#'   `level`, `method`. If interval computation fails, a warning is issued and
#'   the requested terms are returned with `NA` bounds.
#'
#' @details
#' This method wraps `nlstools::confint2()` to provide confidence intervals
#' for the log10-parameterized coefficients (`log10_qalone`, `I`, `log10_beta`).
#'
#' For back-transformed natural-scale confidence intervals, apply the
#' transformation: `10^conf.low` and `10^conf.high` for log10-scale parameters.
#'
#' @examples
#' \donttest{
#' data(etm)
#' fit <- fit_cp_nls(etm, equation = "exponentiated")
#' confint(fit)
#' }
#'
#' @importFrom nlstools confint2
#' @export
confint.cp_model_nls <- function(
  object,
  parm = NULL,
  level = 0.95,
  method = c("asymptotic", "profile"),
  ...
) {
  method <- match.arg(method)

  # is.na() is checked first so an NA level gets this message instead of an
  # "missing value where TRUE/FALSE needed" error.
  if (!is.numeric(level) || length(level) != 1 || is.na(level) ||
    level <= 0 || level >= 1) {
    stop("`level` must be a single number between 0 and 1.", call. = FALSE)
  }

  empty_ci <- function() {
    tibble::tibble(
      term = character(),
      estimate = numeric(),
      conf.low = numeric(),
      conf.high = numeric(),
      level = numeric(),
      method = character()
    )
  }

  if (is.null(object$model)) {
    warning("No model found in object. Model fitting may have failed.", call. = FALSE)
    return(empty_ci())
  }

  # Get estimates
  coefs <- stats::coef(object$model)
  terms <- names(coefs)

  # Resolve parm: numeric indices select coefficients by position.
  if (!is.null(parm)) {
    if (is.numeric(parm)) {
      parm <- terms[parm]
    }
    if (!any(terms %in% parm)) {
      warning("No requested parameters found in model.", call. = FALSE)
      return(empty_ci())
    }
  }

  # Compute confidence intervals using nlstools
  ci_result <- tryCatch(
    nlstools::confint2(object$model, level = level, method = method, ...),
    error = function(e) {
      warning(
        "Confidence interval computation failed: ", conditionMessage(e),
        call. = FALSE
      )
      return(NULL)
    }
  )

  if (is.null(ci_result)) {
    # Honour parm on the failure path too.
    keep <- if (is.null(parm)) rep(TRUE, length(terms)) else terms %in% parm
    return(tibble::tibble(
      term = terms[keep],
      estimate = unname(coefs[keep]),
      conf.low = NA_real_,
      conf.high = NA_real_,
      level = level,
      method = method
    ))
  }

  ci_df <- as.data.frame(ci_result)
  ci_terms <- rownames(ci_df)

  # Filter parameters if parm is specified
  if (!is.null(parm)) {
    keep <- ci_terms %in% parm
    ci_df <- ci_df[keep, , drop = FALSE]
    ci_terms <- ci_terms[keep]
  }

  # Estimates are looked up by name so they align with the CI rows.
  tibble::tibble(
    term = ci_terms,
    estimate = unname(coefs[ci_terms]),
    conf.low = ci_df[, 1],
    conf.high = ci_df[, 2],
    level = level,
    method = method
  )
}

#' Extract coefficients from a linear cross-price model in tidy format
#'
#' Delegates to `broom::tidy()` on the stored linear model. If no model is
#' stored, returns an empty tibble with the standard columns.
#'
#' @param x A cp_model_lm object.
#' @param ... Additional arguments (unused).
#' @return A tibble with columns: term, estimate, std.error, statistic, p.value.
#' @export
tidy.cp_model_lm <- function(x, ...) {
  fitted_model <- x$model
  if (!is.null(fitted_model)) {
    return(broom::tidy(fitted_model))
  }
  tibble::tibble(
    term = character(0),
    estimate = numeric(0),
    std.error = numeric(0),
    statistic = numeric(0),
    p.value = numeric(0)
  )
}

#' Get model summaries from a linear cross-price model
#'
#' Delegates to `broom::glance()` on the stored linear model. If no model is
#' stored, returns a single placeholder row of `NA` statistics.
#'
#' @param x A cp_model_lm object.
#' @param ... Additional arguments (unused).
#' @return A tibble with model summary statistics.
#' @export
glance.cp_model_lm <- function(x, ...) {
  fitted_model <- x$model
  if (!is.null(fitted_model)) {
    return(broom::glance(fitted_model))
  }
  tibble::tibble(
    r.squared = NA_real_,
    adj.r.squared = NA_real_,
    sigma = NA_real_,
    statistic = NA_real_,
    p.value = NA_real_,
    df = NA_integer_,
    nobs = NA_integer_
  )
}

#' Extract coefficients from a mixed-effects cross-price model in tidy format
#'
#' @param x A cp_model_lmer object.
#' @param effects Which effects to return: "fixed" (default), "random"
#'   (per-group random-effect values; passed to broom.mixed as "ran_vals"),
#'   "ran_pars", "ran_vals", or "ran_coefs".
#' @param ... Additional arguments passed to broom.mixed::tidy.
#' @return A tibble with tidy coefficient information.
#' @export
tidy.cp_model_lmer <- function(
  x,
  effects = c("fixed", "random", "ran_pars", "ran_vals", "ran_coefs"),
  ...
) {
  effects <- match.arg(effects)
  if (is.null(x$model)) {
    return(tibble::tibble(
      effect = character(0),
      term = character(0),
      estimate = numeric(0),
      std.error = numeric(0)
    ))
  }
  if (!requireNamespace("broom.mixed", quietly = TRUE)) {
    missing_package_error("broom.mixed", reason = "to tidy mixed-effects models")
  }
  # broom.mixed does not recognise "random"; its name for the per-group
  # random-effect values is "ran_vals". Keep "random" as a friendly alias.
  broom_effects <- if (identical(effects, "random")) "ran_vals" else effects
  broom.mixed::tidy(x$model, effects = broom_effects, ...)
}

#' Get model summaries from a mixed-effects cross-price model
#'
#' Delegates to `broom.mixed::glance()` on the stored mixed model. If no model
#' is stored, returns a single placeholder row of `NA` statistics.
#'
#' @param x A cp_model_lmer object.
#' @param ... Additional arguments passed to broom.mixed::glance.
#' @return A tibble with model summary statistics.
#' @export
glance.cp_model_lmer <- function(x, ...) {
  mixed_model <- x$model
  if (!is.null(mixed_model)) {
    if (!requireNamespace("broom.mixed", quietly = TRUE)) {
      missing_package_error("broom.mixed", reason = "to glance mixed-effects models")
    }
    return(broom.mixed::glance(mixed_model, ...))
  }
  tibble::tibble(
    nobs = NA_integer_,
    sigma = NA_real_,
    logLik = NA_real_,
    AIC = NA_real_,
    BIC = NA_real_,
    deviance = NA_real_,
    df.residual = NA_integer_
  )
}

#-------------------------------------------------------------------------------
# Print Methods
#' Print method for summary.cp_model_lm objects.
#'
#' Writes the formula, fitting method, coefficient table and R-squared
#' statistics to the console.
#'
#' @param x A `summary.*` object
#' @param ... Unused
#' @return Invisibly returns the input object \code{x}.
#' @export
print.summary.cp_model_lm <- function(x, ...) {
  fmt4 <- function(value) format(value, digits = 4)

  # Header and model description
  cat("Linear Cross-Price Demand Model Summary\n")
  cat("=======================================\n\n")
  formula_text <- deparse(x$formula)
  cat("Formula:", formula_text, "\n")
  cat("Method:", x$method, "\n\n")

  # Coefficient table
  cat("Coefficients:\n")
  printCoefmat(x$coefficients)

  # Goodness of fit
  r2_text <- fmt4(x$r_squared)
  adj_r2_text <- fmt4(x$adj_r_squared)
  cat("\nR-squared:", r2_text, "  Adjusted R-squared:", adj_r2_text, "\n")

  invisible(x)
}

#' Print method for summary.cp_model_lmer objects.
#'
#' Writes the formula, fitting method, fixed and random effects, available R2
#' values and AIC/BIC to the console. R2 lines are omitted when the value is
#' missing, `NULL` or not a single number.
#'
#' @param x A `summary.*` object
#' @param ... Unused
#' @return Invisibly returns the input object \code{x}.
#' @export
print.summary.cp_model_lmer <- function(x, ...) {
  # TRUE only for a single, non-missing value; a NULL or empty R2 would
  # otherwise make the if() conditions below error.
  has_value <- function(value) length(value) == 1L && !is.na(value)

  cat("Mixed-Effects Linear Cross-Price Demand Model Summary\n")
  cat("====================================================\n\n")
  cat("Formula:", deparse(x$formula), "\n")
  cat("Method:", x$method, "\n\n")
  cat("Fixed Effects:\n")
  printCoefmat(x$coefficients)
  cat("\nRandom Effects:\n")
  print(x$random_effects, row.names = FALSE)
  cat("\nModel Fit:\n")
  if (has_value(x$r2_marginal)) {
    cat(
      "R2 (marginal):",
      format(x$r2_marginal, digits = 4),
      "  [Fixed effects only]\n"
    )
  }
  if (has_value(x$r2_conditional)) {
    cat(
      "R2 (conditional):",
      format(x$r2_conditional, digits = 4),
      "  [Fixed + random effects]\n"
    )
  }
  cat("AIC:", format(x$AIC, digits = 4), "\n")
  cat("BIC:", format(x$BIC, digits = 4), "\n")
  cat("\nNote: R2 values for mixed models are approximate.\n")
  invisible(x)
}

#-------------------------------------------------------------------------------
# Plot Methods for Linear Models

#' Plot Method for Linear Cross-Price Demand Models
#'
#' Creates a ggplot2 visualization of a fitted linear cross-price demand model
#' (of class \code{cp_model_lm}). The plot overlays a prediction line over the
#' observed data points. Axis transformations (\code{"log10"} or
#' \code{"pseudo_log"}) can be applied.
#' If the model includes group effects, separate lines will be drawn for each group.
#'
#' @param x A \code{cp_model_lm} object (as returned by \code{fit_cp_linear(type = "fixed", ...)}).
#' @param data Optional data frame containing columns \code{x} and \code{y} to plot.
#'   If not provided, the function uses \code{object$data} if available.
#' @param inv_fun Function applied to the predicted values (\code{y_pred}) before
#'   plotting. Default is `identity`, which leaves predictions unchanged.
#'   Not typically needed for linear models but included for API consistency.
#' @param n_points Number of points to create in the prediction grid. Default is \code{100}.
#' @param title Optional title for the plot; default is \code{NULL}.
#' @param xlab Label for the x-axis. Default is \code{"Price"}.
#' @param ylab Label for the y-axis. Default is \code{"Consumption"}.
#' @param x_trans Transformation for the x-axis; one of \code{"identity"}, \code{"log10"}, or
#'   \code{"pseudo_log"} (\code{scales::pseudo_log_trans()}). Other values leave the axis
#'   untransformed. Default is \code{"identity"}.
#' @param y_trans Transformation for the y-axis; one of \code{"identity"}, \code{"log10"}, or
#'   \code{"pseudo_log"} (\code{scales::pseudo_log_trans()}). Other values leave the axis
#'   untransformed. Default is \code{"identity"}.
#' @param point_size Size of the data points in the plot. Default is \code{3}.
#' @param ... Additional arguments passed to the generic \code{predict} method.
#'
#' @return A ggplot2 object displaying the fitted model predictions and observed data.
#'
#' @export
plot.cp_model_lm <- function(
  x,
  data = NULL,
  inv_fun = identity,
  n_points = 100,
  title = NULL,
  xlab = "Price",
  ylab = "Consumption",
  x_trans = "identity",
  y_trans = "identity",
  point_size = 3,
  ...
) {
  object <- x

  # Use provided data, or fall back on object$data if available.
  if (is.null(data)) {
    if (is.null(object$data)) {
      stop(
        "No data provided and no data found in model object. Please supply a data frame."
      )
    }
    data <- object$data
  }

  # Keep only the alternative-product rows when a `target` column exists.
  # which() drops rows with an NA target instead of producing all-NA rows.
  if ("target" %in% names(data)) {
    data <- data[which(data$target == "alt"), ]
  }
  if (!all(c("x", "y") %in% names(data))) {
    stop("Data must contain columns 'x' and 'y'.")
  }

  # Group effects are flagged either as TRUE or by effect type; any() keeps
  # the right-hand side of `||` a single logical value.
  group_spec <- object$group_effects
  has_group_effects <- !is.null(group_spec) &&
    (isTRUE(group_spec) || any(group_spec %in% c("intercept", "interaction")))

  if (has_group_effects && !("group" %in% names(data))) {
    stop("Model includes group effects but 'group' variable not found in data.")
  }

  # Prediction grid over the observed x-range. For a log10 axis the grid is
  # evenly spaced in log space and non-positive x values are removed first.
  x_range <- range(data$x, na.rm = TRUE)
  if (x_trans == "log10") {
    if (any(data$x <= 0, na.rm = TRUE)) {
      data <- data[which(data$x > 0), ]
      warning("Filtered out non-positive x values for log10 transformation")
      x_range <- range(data$x, na.rm = TRUE)
    }
    min_x <- max(0.001, x_range[1])
    pred_x <- exp(seq(log(min_x), log(x_range[2]), length.out = n_points))
  } else {
    pred_x <- seq(x_range[1], x_range[2], length.out = n_points)
  }

  # One grid row per x value, crossed with every observed group if modelled.
  if (has_group_effects) {
    newdata <- expand.grid(x = pred_x, group = unique(data$group))
    # Keep the original factor levels (and their order) for the legend.
    if (is.factor(data$group)) {
      newdata$group <- factor(newdata$group, levels = levels(data$group))
    }
  } else {
    newdata <- data.frame(x = pred_x)
  }

  # The predict method returns a data frame with a `y_pred` column; pair it
  # with the grid and back-transform it with `inv_fun`.
  y_pred <- inv_fun(predict(object, newdata = newdata, ...)$y_pred)
  if (has_group_effects) {
    preds <- data.frame(x = newdata$x, group = newdata$group, y_pred = y_pred)
  } else {
    preds <- data.frame(x = newdata$x, y_pred = y_pred)
  }

  p <- ggplot2::ggplot()

  if (has_group_effects) {
    # One coloured line per group, points filled by group.
    p <- p +
      ggplot2::geom_line(
        data = preds,
        ggplot2::aes(x = x, y = y_pred, color = group, group = group),
        linewidth = 1
      ) +
      ggplot2::geom_point(
        data = data,
        ggplot2::aes(x = x, y = y, fill = group),
        shape = 21,
        size = point_size,
        color = "black",
        stroke = 0.5
      )
  } else {
    p <- p +
      ggplot2::geom_line(
        data = preds,
        ggplot2::aes(x = x, y = y_pred),
        color = "blue",
        linewidth = 1
      ) +
      ggplot2::geom_point(
        data = data,
        ggplot2::aes(x = x, y = y),
        shape = 21,
        size = point_size,
        fill = "white"
      )
  }

  p <- p +
    ggplot2::labs(x = xlab, y = ylab, title = title) +
    ggplot2::theme_bw()

  # Axis transformations; "pseudo_log" was documented but previously ignored.
  if (x_trans == "log10") {
    p <- p +
      ggplot2::scale_x_continuous(trans = scales::log10_trans()) +
      ggplot2::annotation_logticks(sides = "b")
  } else if (x_trans == "pseudo_log") {
    p <- p + ggplot2::scale_x_continuous(trans = scales::pseudo_log_trans())
  }
  if (y_trans == "log10") {
    p <- p +
      ggplot2::scale_y_continuous(trans = scales::log10_trans()) +
      ggplot2::annotation_logticks(sides = "l")
  } else if (y_trans == "pseudo_log") {
    p <- p + ggplot2::scale_y_continuous(trans = scales::pseudo_log_trans())
  }

  p
}


#' Plot Method for Mixed-Effects Cross-Price Demand Models
#'
#' Creates a ggplot2 visualization of a fitted mixed-effects cross-price demand model
#' (of class \code{cp_model_lmer}). This function allows you to plot:
#'
#' \describe{
#'   \item{\code{"fixed"}}{Only the population-level (fixed-effects) prediction.}
#'   \item{\code{"random"}}{Only the subject-specific predictions.}
#'   \item{\code{"all"}}{Both: the fixed-effects and the subject-specific predictions.}
#' }
#'
#' If the model includes group effects, separate lines will be drawn for each group.
#'
#' @param x A \code{cp_model_lmer} object (as returned by
#'   \code{fit_cp_linear(type = "mixed", ...)}).
#' @param data Optional data frame containing columns \code{x} and \code{y} to be plotted.
#'   If not provided, \code{object$data} is used.
#' @param inv_fun Function applied to the predicted values (\code{y_pred}) before
#'   plotting. Default is `identity`, which leaves predictions unchanged.
#'   Not typically needed for linear models but included for API consistency.
#' @param n_points Number of points to use in creating the prediction grid. Default is \code{100}.
#' @param title Optional title for the plot; default is \code{NULL}.
#' @param xlab Label for the x-axis. Default is \code{"Price"}.
#' @param ylab Label for the y-axis. Default is \code{"Consumption"}.
#' @param x_trans Transformation for the x-axis; one of \code{"identity"}, \code{"log10"}, or
#'   \code{"pseudo_log"} (\code{scales::pseudo_log_trans()}). Other values leave the axis
#'   untransformed. Default is \code{"identity"}.
#' @param y_trans Transformation for the y-axis; one of \code{"identity"}, \code{"log10"}, or
#'   \code{"pseudo_log"} (\code{scales::pseudo_log_trans()}). Other values leave the axis
#'   untransformed. Default is \code{"identity"}.
#' @param point_size Size of the observed data points. Default is \code{3}.
#' @param pred_type Character string specifying which prediction components to plot:
#'   \describe{
#'     \item{\code{"fixed"}}{Plot only the fixed-effects (population) prediction.}
#'     \item{\code{"random"}}{Plot only the subject-specific predictions.}
#'     \item{\code{"all"}}{Plot both the fixed-effects and the subject-specific predictions.}
#'   }
#'   The default is \code{"fixed"}.
#' @param ... Additional arguments passed to \code{\link{predict.cp_model_lmer}}.
#'
#' @return A ggplot2 object displaying the observed data points along with the prediction curves.
#'
#' @export
plot.cp_model_lmer <- function(
  x,
  data = NULL,
  inv_fun = identity,
  n_points = 100,
  title = NULL,
  xlab = "Price",
  ylab = "Consumption",
  x_trans = "identity",
  y_trans = "identity",
  point_size = 3,
  pred_type = c("fixed", "random", "all"),
  ...
) {
  object <- x # x is the formal parameter name required for plot methods
  pred_type <- match.arg(pred_type)

  # Use provided data or fall back on object$data if available.
  if (is.null(data)) {
    if (is.null(object$data)) {
      stop(
        "No data provided and no data found in model object. Please supply a data frame."
      )
    }
    data <- object$data
  }

  # Keep only the alternative-product rows when a `target` column exists.
  # which() drops rows with an NA target instead of producing all-NA rows.
  if ("target" %in% names(data)) {
    data <- data[which(data$target == "alt"), ]
  }
  if (!all(c("x", "y") %in% names(data))) {
    stop("Data must contain columns 'x' and 'y'.")
  }

  # Group effects are flagged either as TRUE or by effect type; any() keeps
  # the right-hand side of `||` a single logical value.
  group_spec <- object$group_effects
  has_group_effects <- !is.null(group_spec) &&
    (isTRUE(group_spec) || any(group_spec %in% c("intercept", "interaction")))

  if (has_group_effects && !("group" %in% names(data))) {
    stop("Model includes group effects but 'group' variable not found in data.")
  }

  # Prediction grid over the observed x-range. For a log10 axis the grid is
  # evenly spaced in log space and non-positive x values are removed first.
  x_range <- range(data$x, na.rm = TRUE)
  if (x_trans == "log10") {
    if (any(data$x <= 0, na.rm = TRUE)) {
      data <- data[which(data$x > 0), ]
      warning("Filtered out non-positive x values for log10 transformation")
      x_range <- range(data$x, na.rm = TRUE)
    }
    min_x <- max(1e-3, x_range[1])
    pred_x <- exp(seq(log(min_x), log(x_range[2]), length.out = n_points))
  } else {
    pred_x <- seq(x_range[1], x_range[2], length.out = n_points)
  }

  # Observed data, filled by group when groups are modelled.
  p <- ggplot2::ggplot() +
    ggplot2::geom_point(
      data = data,
      ggplot2::aes(x = x, y = y, fill = if (has_group_effects) group else NULL),
      shape = 21,
      size = point_size,
      color = "black",
      stroke = 0.5
    ) +
    ggplot2::labs(x = xlab, y = ylab, title = title) +
    ggplot2::theme_bw()

  # Population (fixed-effects) predictions.
  if (pred_type %in% c("fixed", "all")) {
    if (has_group_effects) {
      pop_newdata <- expand.grid(x = pred_x, group = unique(data$group))
      # Keep the original factor levels (and their order) for the legend.
      if (is.factor(data$group)) {
        pop_newdata$group <- factor(
          pop_newdata$group,
          levels = levels(data$group)
        )
      }
    } else {
      pop_newdata <- data.frame(x = pred_x)
    }

    preds_fixed <- predict(
      object,
      newdata = pop_newdata,
      pred_type = "fixed",
      ...
    )
    # Back-transform predictions (identity by default leaves them unchanged).
    preds_fixed$y_pred <- inv_fun(preds_fixed$y_pred)

    if (has_group_effects) {
      p <- p +
        ggplot2::geom_line(
          data = preds_fixed,
          ggplot2::aes(x = x, y = y_pred, color = group, group = group),
          linewidth = 1.2
        )
    } else {
      p <- p +
        ggplot2::geom_line(
          data = preds_fixed,
          ggplot2::aes(x = x, y = y_pred, group = 1),
          color = "black",
          linewidth = 1.2
        )
    }
  }

  # Subject-specific (random-effects) predictions; these need an `id` column.
  if (pred_type %in% c("random", "all")) {
    if ("id" %in% names(data)) {
      ids <- unique(data$id)

      if (has_group_effects) {
        # Cross x, id and group, then keep only id-group pairs seen in the data.
        rand_newdata <- expand.grid(
          x = pred_x,
          id = ids,
          group = unique(data$group)
        )
        id_group_pairs <- unique(data[, c("id", "group")])
        rand_newdata <- merge(rand_newdata, id_group_pairs)
        if (is.factor(data$group)) {
          rand_newdata$group <- factor(
            rand_newdata$group,
            levels = levels(data$group)
          )
        }
      } else {
        rand_newdata <- expand.grid(x = pred_x, id = ids)
      }

      preds_random <- predict(
        object,
        newdata = rand_newdata,
        pred_type = "random",
        ...
      )
      # Back-transform predictions (identity by default leaves them unchanged).
      preds_random$y_pred <- inv_fun(preds_random$y_pred)

      if (has_group_effects) {
        p <- p +
          ggplot2::geom_line(
            data = preds_random,
            ggplot2::aes(
              x = x,
              y = y_pred,
              group = interaction(id, group),
              color = group
            ),
            linewidth = 0.6,
            alpha = 0.5
          )
      } else {
        p <- p +
          ggplot2::geom_line(
            data = preds_random,
            ggplot2::aes(x = x, y = y_pred, group = id),
            color = "blue",
            linewidth = 0.6,
            alpha = 0.5
          )
      }
    } else {
      warning(
        "No 'id' column found in data. Cannot plot subject-specific predictions."
      )
    }
  }

  # Axis transformations; "pseudo_log" was documented but previously ignored.
  if (x_trans == "log10") {
    p <- p +
      ggplot2::scale_x_continuous(trans = scales::log10_trans()) +
      ggplot2::annotation_logticks(sides = "b")
  } else if (x_trans == "pseudo_log") {
    p <- p + ggplot2::scale_x_continuous(trans = scales::pseudo_log_trans())
  }
  if (y_trans == "log10") {
    p <- p +
      ggplot2::scale_y_continuous(trans = scales::log10_trans()) +
      ggplot2::annotation_logticks(sides = "l")
  } else if (y_trans == "pseudo_log") {
    p <- p + ggplot2::scale_y_continuous(trans = scales::pseudo_log_trans())
  }

  # Shared "Group" legend title for colour (lines) and fill (points).
  if (has_group_effects) {
    p <- p +
      ggplot2::guides(
        color = ggplot2::guide_legend(title = "Group"),
        fill = ggplot2::guide_legend(title = "Group")
      )
  }

  p
}


#' Extract Coefficients from Cross-Price Demand Models
#'
#' @description Methods to extract coefficients from various cross-price demand model objects.
#'   Each method first checks that \code{object} has the expected class and then
#'   reads the coefficients from the fitted model stored in \code{object$model}.
#'   The mixed-effects method can additionally return fixed effects only, or
#'   fixed and random effects separately.
#'
#' @name coef-methods
#' @rdname coef-methods
NULL

#' @describeIn coef-methods Extract coefficients from a nonlinear cross-price model
#' @param object A cp_model_nls object
#' @param ... Additional arguments (not used).
#' @return Named vector of coefficients
#' @export
coef.cp_model_nls <- function(object, ...) {
  # Guard for direct (non-dispatched) calls with the wrong object type.
  if (inherits(object, "cp_model_nls")) {
    # Coefficients come straight from the wrapped fitted model.
    return(coef(object$model))
  }
  stop("Object must be of class 'cp_model_nls'")
}

#' @describeIn coef-methods Extract coefficients from a linear cross-price model
#' @param object A cp_model_lm object
#' @param ... Additional arguments (not used).
#' @return A named numeric vector of model coefficients.
#' @export
coef.cp_model_lm <- function(object, ...) {
  # Guard for direct (non-dispatched) calls with the wrong object type.
  if (inherits(object, "cp_model_lm")) {
    # Coefficients come straight from the wrapped fitted model.
    return(coef(object$model))
  }
  stop("Object must be of class 'cp_model_lm'")
}

#' @describeIn coef-methods Extract coefficients from a mixed-effects cross-price model
#' @param fixed_only Logical; if TRUE, returns only fixed effects. Default is FALSE.
#' @param combine Logical; if TRUE and fixed_only=FALSE, returns fixed + random effects combined. Default is TRUE.
#' @return For \code{cp_model_lmer}: the fixed effects (\code{fixed_only = TRUE}),
#'   the combined coefficients (\code{combine = TRUE}), or a list with elements
#'   \code{fixed} and \code{random}.
#' @importFrom lme4 fixef
#' @export
coef.cp_model_lmer <- function(
  object,
  fixed_only = FALSE,
  combine = TRUE,
  ...
) {
  if (!inherits(object, "cp_model_lmer")) {
    stop("Object must be of class 'cp_model_lmer'")
  }
  if (!requireNamespace("lme4", quietly = TRUE)) {
    stop(
      "Package 'lme4' is required for extracting coefficients from mixed-effects models"
    )
  }

  mixed_fit <- object$model

  if (fixed_only) {
    return(lme4::fixef(mixed_fit))
  }

  if (!combine) {
    # Fixed and random components reported separately.
    return(list(
      fixed = lme4::fixef(mixed_fit),
      random = lme4::ranef(mixed_fit)
    ))
  }

  # Generic coef() dispatches on the fitted model's class (S3) to obtain the
  # combined fixed + random coefficients.
  coef(mixed_fit)
}


#' Extract Random Effects from Mixed-Effects Cross-Price Model
#'
#' Thin wrapper around \code{lme4::ranef()} applied to the fitted model stored
#' in \code{object$model}.
#'
#' @param object A cp_model_lmer object
#' @param ... Additional arguments passed to ranef
#' @return List of random effects
#' @importFrom lme4 ranef
#' @export
ranef.cp_model_lmer <- function(object, ...) {
  # Collect the first failing precondition (if any) and report it.
  failure <- if (!inherits(object, "cp_model_lmer")) {
    "Object must be of class 'cp_model_lmer'"
  } else if (!requireNamespace("lme4", quietly = TRUE)) {
    "Package 'lme4' is required for extracting random effects"
  }
  if (!is.null(failure)) {
    stop(failure)
  }

  lme4::ranef(object$model, ...)
}

#' Extract Fixed Effects from Mixed-Effects Cross-Price Model
#'
#' Thin wrapper around \code{lme4::fixef()} applied to the fitted model stored
#' in \code{object$model}.
#'
#' @param object A cp_model_lmer object
#' @param ... Additional arguments passed to fixef
#' @return Named vector of fixed effects
#' @importFrom lme4 fixef
#' @export
fixef.cp_model_lmer <- function(object, ...) {
  # Collect the first failing precondition (if any) and report it.
  failure <- if (!inherits(object, "cp_model_lmer")) {
    "Object must be of class 'cp_model_lmer'"
  } else if (!requireNamespace("lme4", quietly = TRUE)) {
    "Package 'lme4' is required for extracting fixed effects"
  }
  if (!is.null(failure)) {
    stop(failure)
  }

  lme4::fixef(object$model, ...)
}

#' Extract All Coefficient Types from Cross-Price Demand Models
#'
#' @description
#' A convenience function to extract coefficients from any type of cross-price
#' demand model in a unified format. For mixed effects models, returns a list
#' with different coefficient types.
#'
#' @param object A cross-price demand model object (cp_model_nls, cp_model_lm, or cp_model_lmer)
#' @param ... Additional arguments forwarded to the extraction method(s). For
#'   cp_model_lmer objects they are forwarded to each of \code{fixef()},
#'   \code{ranef()} and \code{coef()}.
#' @return For cp_model_nls and cp_model_lm, returns the model coefficients.
#'   For cp_model_lmer, returns a list with fixed, random, and combined coefficients.
#' @examples
#' \donttest{
#' data(etm, package = "beezdemand")
#' fit <- fit_cp_nls(etm, equation = "exponentiated")
#' extract_coefficients(fit)
#' }
#' @importFrom lme4 fixef ranef
#' @export
extract_coefficients <- function(object, ...) {
  # Single-equation models: plain coefficient vector.
  if (inherits(object, c("cp_model_nls", "cp_model_lm"))) {
    return(coef(object, ...))
  }

  # Mixed-effects models: every coefficient view, bundled together.
  if (inherits(object, "cp_model_lmer")) {
    fixed_part <- fixef(object, ...)
    random_part <- ranef(object, ...)
    combined_part <- coef(object, fixed_only = FALSE, combine = TRUE, ...)
    return(list(
      fixed = fixed_part,
      random = random_part,
      combined = combined_part
    ))
  }

  stop(
    "Unsupported model class. Must be one of: cp_model_nls, cp_model_lm, cp_model_lmer"
  )
}

#-------------------------------------------------------------------------------
# Posthoc Methods for Linear Models
#' Test for significant interaction in a cross-price demand model
#'
#' Looks for interaction terms (coefficient rows whose names contain ":") in
#' the coefficient table of \code{summary(object$model)}. It reports whether any
#' of them has a p-value below \code{alpha}.
#'
#' If the table already reports p-values in a \code{"Pr(>|t|)"} or
#' \code{"Pr(>|z|)"} column (for example an lmerTest fit), those are used.
#' Otherwise two-sided p-values are approximated from the t-values against a
#' standard normal distribution. That approximation ignores the finite degrees
#' of freedom, so it is anti-conservative: p-values come out too small.
#'
#' @param object A cp_model_lmer object from fit_cp_linear
#' @param alpha Significance level for testing (default: 0.05)
#' @return Logical indicating whether interaction is significant
#' @keywords internal
has_significant_interaction <- function(object, alpha = 0.05) {
  if (!inherits(object, "cp_model_lmer")) {
    stop("Object must be a cp_model_lmer object")
  }

  coefs <- coef(summary(object$model))

  # Interaction terms are named like "x:group2".
  interaction_rows <- grep(":", rownames(coefs))
  if (length(interaction_rows) == 0) {
    return(FALSE) # No interaction terms found
  }

  # Prefer p-values the model already reports (proper df, e.g. Satterthwaite).
  reported_p <- intersect(c("Pr(>|t|)", "Pr(>|z|)"), colnames(coefs))
  if (length(reported_p) > 0) {
    p_values <- coefs[interaction_rows, reported_p[1]]
  } else {
    # Column naming differs between model types.
    t_col <- if ("t value" %in% colnames(coefs)) {
      "t value"
    } else if ("t.value" %in% colnames(coefs)) {
      "t.value"
    } else {
      3 # t-values are typically the third column of a coefficient table
    }
    t_values <- coefs[interaction_rows, t_col]
    # Normal approximation; pnorm(-|t|) keeps precision in the far tail,
    # where 1 - pnorm(|t|) would round to 0.
    p_values <- 2 * pnorm(-abs(t_values))
  }

  any(p_values < alpha, na.rm = TRUE)
}

#' Run pairwise slope comparisons for cross-price demand model
#'
#' This function performs pairwise comparisons of slopes between groups in a
#' cross-price demand model, but only when a significant interaction is present.
#' The emmeans table showing estimated marginal means for slopes is always returned.
#'
#' Contrasts are flagged \code{"*"}, \code{"**"} and \code{"***"} when their
#' p-value is below \code{alpha}, \code{alpha / 5} and \code{alpha / 50}.
#' With the default alpha these cut-offs are 0.05, 0.01 and 0.001.
#'
#' @param object A cp_model_lmer object from fit_cp_linear
#' @param alpha Significance level for testing (default: 0.05)
#' @param adjust Method for p-value adjustment; see emmeans::contrast (default: "tukey")
#' @param ... Additional arguments passed to emmeans
#' @return List containing the emmeans table and optionally pairwise comparisons if interaction is significant.
#'   Carries attributes \code{adjust}, \code{type} (\code{"slopes"}) and \code{alpha}.
#' @examples
#' \donttest{
#' data(etm)
#' fit <- fit_cp_linear(etm, type = "mixed", group_effects = TRUE)
#' cp_posthoc_slopes(fit)
#' }
#' @importFrom emmeans emmeans emtrends contrast
#' @export
cp_posthoc_slopes <- function(object, alpha = 0.05, adjust = "tukey", ...) {
  if (!requireNamespace("emmeans", quietly = TRUE)) {
    stop("Package 'emmeans' is required for pairwise comparisons")
  }

  if (!inherits(object, "cp_model_lmer")) {
    stop("Object must be a cp_model_lmer object")
  }

  # Contrasts are only computed when an interaction is significant.
  has_interaction <- has_significant_interaction(object, alpha)

  # Per-group slope estimates are computed regardless of significance.
  if (isTRUE(object$log10x)) {
    # The model formula uses log10(x); x is named as the trend variable.
    # NOTE(review): emmeans `tran` usually describes a *response*
    # transformation, and emtrends() appears to accept a transformed
    # predictor such as var = "log10(x)" -- confirm which slope is intended.
    trends <- emmeans::emtrends(
      object$model,
      ~group,
      var = "x",
      tran = "log10",
      ...
    )
  } else {
    trends <- emmeans::emtrends(object$model, ~group, var = "x", ...)
  }

  result <- list(
    emmeans = as.data.frame(summary(trends)),
    significant_interaction = has_interaction
  )

  if (has_interaction) {
    contrasts <- emmeans::contrast(
      trends,
      method = "pairwise",
      adjust = adjust,
      ...
    )
    contrast_df <- as.data.frame(summary(contrasts))

    # Significance stars at alpha, alpha/5 and alpha/50 (0.05 / 0.01 / 0.001
    # by default). These match the codes legend printed for cp_posthoc objects.
    if ("p.value" %in% names(contrast_df)) {
      contrast_df$significance <- ""
      contrast_df$significance[contrast_df$p.value < alpha] <- "*"
      contrast_df$significance[contrast_df$p.value < alpha / 5] <- "**"
      contrast_df$significance[contrast_df$p.value < alpha / 50] <- "***"
    }

    result$contrasts <- contrast_df
  } else {
    result$message <- paste(
      "No significant interaction detected (alpha =",
      alpha,
      "). Pairwise slope comparisons not performed."
    )
  }

  class(result) <- c("cp_posthoc", class(result))
  attr(result, "adjust") <- adjust
  attr(result, "type") <- "slopes"
  # Recorded so printing can label the star cut-offs for a non-default alpha.
  attr(result, "alpha") <- alpha

  result
}

#' Run pairwise intercept comparisons for cross-price demand model
#'
#' This function performs pairwise comparisons of intercepts between groups in a
#' cross-price demand model, but only when a significant interaction is present.
#' The emmeans table showing estimated marginal means for intercepts is always returned.
#'
#' The intercept is evaluated at \code{x = 1} when the model uses
#' \code{log10(x)}, so that \code{log10(x) = 0}. Otherwise it is evaluated at
#' \code{x = 0}. Contrasts are flagged \code{"*"}, \code{"**"} and
#' \code{"***"} when their p-value is below \code{alpha}, \code{alpha / 5} and
#' \code{alpha / 50}. With the default alpha these cut-offs are 0.05, 0.01 and
#' 0.001.
#'
#' @param object A cp_model_lmer object from fit_cp_linear
#' @param alpha Significance level for testing (default: 0.05)
#' @param adjust Method for p-value adjustment; see emmeans::contrast (default: "tukey")
#' @param ... Additional arguments passed to emmeans
#' @return List containing the emmeans table and optionally pairwise comparisons if interaction is significant.
#'   Carries attributes \code{adjust}, \code{type} (\code{"intercepts"}) and \code{alpha}.
#' @examples
#' \donttest{
#' data(etm)
#' fit <- fit_cp_linear(etm, type = "mixed", group_effects = TRUE)
#' cp_posthoc_intercepts(fit)
#' }
#' @importFrom emmeans emmeans emtrends contrast
#' @export
cp_posthoc_intercepts <- function(object, alpha = 0.05, adjust = "tukey", ...) {
  if (!requireNamespace("emmeans", quietly = TRUE)) {
    stop("Package 'emmeans' is required for pairwise comparisons")
  }

  if (!inherits(object, "cp_model_lmer")) {
    stop("Object must be a cp_model_lmer object")
  }

  # Contrasts are only computed when an interaction is significant.
  has_interaction <- has_significant_interaction(object, alpha)

  # Evaluate where the price term vanishes: log10(1) = 0, or x = 0 when linear.
  intercept_at <- if (isTRUE(object$log10x)) list(x = 1) else list(x = 0)
  emm <- emmeans::emmeans(object$model, specs = ~group, at = intercept_at, ...)

  result <- list(
    emmeans = as.data.frame(summary(emm)),
    significant_interaction = has_interaction
  )

  if (has_interaction) {
    contrasts <- emmeans::contrast(
      emm,
      method = "pairwise",
      adjust = adjust,
      ...
    )
    contrast_df <- as.data.frame(summary(contrasts))

    # Significance stars at alpha, alpha/5 and alpha/50 (0.05 / 0.01 / 0.001
    # by default). These match the codes legend printed for cp_posthoc objects.
    if ("p.value" %in% names(contrast_df)) {
      contrast_df$significance <- ""
      contrast_df$significance[contrast_df$p.value < alpha] <- "*"
      contrast_df$significance[contrast_df$p.value < alpha / 5] <- "**"
      contrast_df$significance[contrast_df$p.value < alpha / 50] <- "***"
    }

    result$contrasts <- contrast_df
  } else {
    result$message <- paste(
      "No significant interaction detected (alpha =",
      alpha,
      "). Pairwise intercept comparisons not performed."
    )
  }

  class(result) <- c("cp_posthoc", class(result))
  attr(result, "adjust") <- adjust
  attr(result, "type") <- "intercepts"
  # Recorded so printing can label the star cut-offs for a non-default alpha.
  attr(result, "alpha") <- alpha

  result
}

#' Print method for cp_posthoc objects
#'
#' Prints a title based on the object's \code{type} attribute, the estimated
#' marginal means, whether the interaction was significant, and then either
#' the pairwise comparisons or the explanatory message. The significance-code
#' legend uses the object's \code{alpha} attribute when present (cut-offs
#' \code{alpha / 50}, \code{alpha / 5}, \code{alpha}). Without it, the legend
#' falls back to 0.001 / 0.01 / 0.05.
#'
#' @param x A cp_posthoc object
#' @param ... Additional arguments passed to print
#' @return Invisibly returns the input object \code{x}.
#' @export
print.cp_posthoc <- function(x, ...) {
  type <- attr(x, "type")
  if (is.null(type)) {
    type <- "Post-hoc"
  }

  title <- switch(
    type,
    "slopes" = "Slope Estimates and Comparisons",
    "intercepts" = "Intercept Estimates and Comparisons",
    "Estimates and Post-hoc Comparisons"
  )

  cat(title, "\n")
  cat(paste(rep("=", nchar(title)), collapse = ""), "\n\n")

  cat("Estimated Marginal Means:\n")
  print(x$emmeans, row.names = FALSE)
  cat("\n")

  # Scalar flag, so a plain if/else rather than vectorised ifelse().
  cat(
    "Significant interaction:",
    if (isTRUE(x$significant_interaction)) "Yes" else "No",
    "\n\n"
  )

  if (!is.null(x$contrasts)) {
    cat("Pairwise Comparisons:\n")
    # Strip any extra S3 class so printing does not recurse into this method.
    contrast_tbl <- as.data.frame(unclass(x$contrasts))
    print(contrast_tbl, row.names = FALSE)

    if ("significance" %in% names(contrast_tbl)) {
      # Legend cut-offs mirror the star thresholds alpha/50, alpha/5, alpha.
      alpha <- attr(x, "alpha")
      if (!is.numeric(alpha) || length(alpha) != 1L || is.na(alpha)) {
        alpha <- 0.05
      }
      cuts <- vapply(c(alpha / 50, alpha / 5, alpha), format, character(1))
      cat(
        "\n",
        sprintf(
          "Significance codes: 0 '***' %s '**' %s '*' %s '' 1",
          cuts[1], cuts[2], cuts[3]
        ),
        "\n",
        sep = ""
      )
    }
  } else if (!is.null(x$message)) {
    cat(x$message, "\n")
  }

  adjust <- attr(x, "adjust")
  if (!is.null(adjust)) {
    cat("P-value adjustment method:", adjust, "\n")
  }

  invisible(x)
}

Try the beezdemand package in your browser

Any scripts or data that you put into this service are public.

beezdemand documentation built on March 3, 2026, 9:07 a.m.