R/fixed-methods.R

Defines functions augment.beezdemand_fixed confint.beezdemand_fixed glance.beezdemand_fixed predict.beezdemand_fixed coef.beezdemand_fixed tidy.beezdemand_fixed print.summary.beezdemand_fixed summary.beezdemand_fixed plot.beezdemand_fixed print.beezdemand_fixed

Documented in augment.beezdemand_fixed coef.beezdemand_fixed confint.beezdemand_fixed glance.beezdemand_fixed plot.beezdemand_fixed predict.beezdemand_fixed print.beezdemand_fixed print.summary.beezdemand_fixed summary.beezdemand_fixed tidy.beezdemand_fixed

#' Print Method for beezdemand_fixed
#'
#' Writes a compact overview of a fixed-effect demand fit to the console:
#' the originating call, the equation, the k specification, the aggregation
#' method (when one was used) and how many subjects converged or failed.
#'
#' @param x A beezdemand_fixed object
#' @param ... Additional arguments (ignored)
#' @return Invisibly returns the input object \code{x}.
#' @export
print.beezdemand_fixed <- function(x, ...) {
  # Banner plus the "Call:" heading, emitted back-to-back with no separator.
  cat(
    c(
      "\nFixed-Effect Demand Model\n",
      "==========================\n\n",
      "Call:\n"
    ),
    sep = ""
  )
  print(x$call)
  cat("\n")

  cat("Equation:", x$equation, "\n")
  cat("k:", x$k_spec, "\n")

  agg <- x$agg
  if (!is.null(agg)) {
    cat("Aggregation:", agg, "\n")
  }

  # Pieces of the convergence line; cat() joins them with single spaces.
  count_pieces <- list(
    "Subjects:", x$n_total,
    "(", x$n_success, "converged,",
    x$n_fail, "failed)\n"
  )
  do.call(cat, count_pieces)

  cat("\nUse summary() for parameter summaries, tidy() for tidy output.\n")

  invisible(x)
}

#' Plot Method for beezdemand_fixed
#'
#' Draws fitted demand curves (a single population curve for aggregated fits,
#' or one curve per subject otherwise), optionally overlaid with the observed
#' data. Requested \code{ids} that are not present in the fit are ignored with
#' a warning; an error is raised when none of them match.
#'
#' @param x A beezdemand_fixed object.
#' @param type Plot type: "demand", "population", "individual", or "both".
#'   Any value other than "demand" overrides \code{show_pred}.
#' @param ids Optional vector of subject IDs to plot. Defaults to all subjects.
#' @param style Plot styling, passed to \code{theme_beezdemand()}.
#' @param show_observed Logical; if TRUE, overlay observed data points where possible.
#' @param show_pred Which prediction layers to plot: "population", "individual",
#'   or "both".
#' @param x_trans X-axis transform: "log", "log10", "linear", or "pseudo_log".
#' @param y_trans Y-axis transform: "log", "log10", "linear", or "pseudo_log".
#' @param free_trans Value used to display free (x = 0) on log scales. Use NULL
#'   to drop x <= 0 values instead.
#' @param x_limits Optional numeric vector of length 2 for x-axis limits.
#' @param y_limits Optional numeric vector of length 2 for y-axis limits.
#' @param n_points Number of points to use for prediction curves when thinning.
#' @param x_lab Optional x-axis label.
#' @param y_lab Optional y-axis label.
#' @param xlab Deprecated alias for \code{x_lab}.
#' @param ylab Deprecated alias for \code{y_lab}.
#' @param facet Faceting specification (TRUE for \code{~id} or a formula).
#' @param observed_point_alpha Alpha for observed points.
#' @param observed_point_size Size for observed points.
#' @param pop_line_alpha Alpha for population curve.
#' @param pop_line_size Line size for population curve.
#' @param ind_line_alpha Alpha for individual curves.
#' @param ind_line_size Line size for individual curves.
#' @param subtitle Optional subtitle for the plot.
#' @param ... Additional arguments (currently unused).
#'
#' @return A ggplot2 object.
#' @export
#' @importFrom rlang .data
plot.beezdemand_fixed <- function(
  x,
  type = c("demand", "population", "individual", "both"),
  ids = NULL,
  style = c("modern", "apa"),
  show_observed = TRUE,
  show_pred = NULL,
  x_trans = c("log10", "log", "linear", "pseudo_log"),
  y_trans = NULL,
  free_trans = 0.01,
  x_limits = NULL,
  y_limits = NULL,
  n_points = 200,
  x_lab = NULL,
  y_lab = NULL,
  xlab = NULL,
  ylab = NULL,
  facet = NULL,
  observed_point_alpha = 0.5,
  observed_point_size = 1.8,
  pop_line_alpha = 0.9,
  pop_line_size = 1.0,
  ind_line_alpha = 0.35,
  ind_line_size = 0.7,
  subtitle = NULL,
  ...
) {
  type <- match.arg(type)
  style <- match.arg(style)
  x_trans <- match.arg(x_trans)
  if (is.null(y_trans)) {
    # Default y transform depends on the plot type and the fitted equation.
    y_trans <- beezdemand_default_y_trans(type = type, equation = x$equation)
  }
  y_trans <- match.arg(y_trans, c("log10", "log", "linear", "pseudo_log"))

  if (is.null(x$predictions) || is.null(x$data_used) || is.null(x$results)) {
    stop("Fit object lacks detailed predictions; refit with detailed = TRUE.")
  }

  labels <- beezdemand_normalize_plot_labels(x_lab, y_lab, xlab, ylab)
  x_lab <- labels$x_lab %||% "Price"
  y_lab <- labels$y_lab %||% "Consumption"

  # Aggregated fits (agg = "Mean"/"Pooled") carry one population curve;
  # everything else is one fit per subject.
  agg_lower <- if (is.null(x$agg)) NA_character_ else tolower(x$agg)
  has_population <- !is.na(agg_lower) && agg_lower %in% c("mean", "pooled")
  has_individual <- !has_population

  results <- x$results
  ids_all <- if ("id" %in% names(results)) {
    as.character(results$id)
  } else {
    as.character(seq_along(x$predictions))
  }

  if (is.null(show_pred)) {
    show_pred <- if (has_population) "population" else "individual"
  }
  # An explicit non-"demand" type decides which prediction layers are drawn.
  if (type != "demand") {
    show_pred <- type
  }

  show_pred <- beezdemand_normalize_show_pred(show_pred)

  if (has_population && any(show_pred %in% "individual")) {
    stop("Individual fits are not available when agg = 'Mean' or 'Pooled'.")
  }
  if (has_individual && any(show_pred %in% "population")) {
    stop("Population curve is not available for per-person fits.")
  }

  if (has_population && !is.null(ids)) {
    stop("ids are not available when agg = 'Mean' or 'Pooled'.")
  }

  if (has_individual) {
    if (is.null(ids)) {
      ids <- ids_all
    }
    ids <- as.character(ids)
    idx <- match(ids, ids_all)
    unknown_ids <- unique(ids[is.na(idx)])
    idx <- idx[!is.na(idx)]
    if (length(idx) == 0) {
      stop("No matching ids found in fit results.")
    }
    if (length(unknown_ids) > 0) {
      # Previously these were dropped silently, hiding typos in `ids`.
      warning(
        "Ignoring ids not found in fit results: ",
        paste(unknown_ids, collapse = ", "),
        call. = FALSE
      )
    }
  }

  y_trans_res <- beezdemand_resolve_y_trans(y_trans, y_is_log = FALSE)
  y_trans <- y_trans_res$y_trans

  # Evenly subsample a prediction curve down to at most `n_points` rows.
  thin_curve <- function(df) {
    if (nrow(df) > n_points) {
      df <- df[
        round(seq(1, nrow(df), length.out = n_points)),
        ,
        drop = FALSE
      ]
    }
    df
  }

  # Observed data for fit `i`; tagged with that fit's id when the stored data
  # has no id column, so colouring/faceting by id keeps working.
  observed_for <- function(i) {
    d <- x$data_used[[i]]
    if (!is.null(d) && !("id" %in% names(d))) {
      d$id <- rep(ids_all[i], nrow(d))
    }
    d
  }

  obs_df <- NULL
  pred_df <- NULL
  pop_df <- NULL
  subtitle_note <- FALSE
  free_trans_used <- FALSE

  if (any(show_pred %in% "population")) {
    pop_pred <- x$predictions[[1]]
    pop_df <- thin_curve(data.frame(
      id = ids_all[1],
      x = pop_pred$x,
      y = pop_pred$y
    ))

    free_pop <- beezdemand_apply_free_trans(pop_df, "x", x_trans, free_trans)
    pop_df <- free_pop$data
    free_trans_used <- free_trans_used || free_pop$replaced

    pop_y <- beezdemand_drop_nonpositive_y(pop_df, "y", y_trans)
    pop_df <- pop_y$data
    subtitle_note <- subtitle_note || pop_y$dropped
  }

  if (show_observed) {
    obs_df <- if (has_population) {
      observed_for(1)
    } else {
      do.call(rbind, lapply(idx, observed_for))
    }
    if (!is.null(obs_df)) {
      obs_df$id <- as.character(obs_df$id)
      free_obs <- beezdemand_apply_free_trans(obs_df, "x", x_trans, free_trans)
      obs_df <- free_obs$data
      free_trans_used <- free_trans_used || free_obs$replaced

      obs_y <- beezdemand_drop_nonpositive_y(obs_df, "y", y_trans)
      obs_df <- obs_y$data
      subtitle_note <- subtitle_note || obs_y$dropped
    }
  }

  if (any(show_pred %in% "individual")) {
    pred_df <- do.call(
      rbind,
      lapply(idx, function(i) {
        pred <- x$predictions[[i]]
        pred$id <- ids_all[i]
        thin_curve(pred)
      })
    )
    pred_df$id <- as.character(pred_df$id)

    free_pred <- beezdemand_apply_free_trans(pred_df, "x", x_trans, free_trans)
    pred_df <- free_pred$data
    free_trans_used <- free_trans_used || free_pred$replaced

    pred_y <- beezdemand_drop_nonpositive_y(pred_df, "y", y_trans)
    pred_df <- pred_y$data
    subtitle_note <- subtitle_note || pred_y$dropped
  }

  # Colour by subject only when more than one individual curve is drawn.
  color_by <- if (
    any(show_pred %in% "individual") && length(unique(pred_df$id)) > 1
  ) {
    "id"
  } else {
    NULL
  }

  p <- ggplot2::ggplot()

  if (!is.null(pop_df)) {
    p <- p +
      ggplot2::geom_line(
        data = pop_df,
        ggplot2::aes(x = .data$x, y = .data$y),
        linewidth = pop_line_size,
        alpha = pop_line_alpha,
        color = beezdemand_style_color(style, "primary")
      )
  }

  if (!is.null(pred_df)) {
    aes_pred <- if (is.null(color_by)) {
      ggplot2::aes(x = .data$x, y = .data$y, group = .data$id)
    } else {
      ggplot2::aes(
        x = .data$x,
        y = .data$y,
        color = .data[[color_by]],
        group = .data$id
      )
    }
    p <- p +
      ggplot2::geom_line(
        data = pred_df,
        mapping = aes_pred,
        linewidth = ind_line_size,
        alpha = ind_line_alpha
      )
  }

  if (show_observed && !is.null(obs_df) && nrow(obs_df) > 0) {
    aes_obs <- if (is.null(color_by)) {
      ggplot2::aes(x = .data$x, y = .data$y)
    } else {
      ggplot2::aes(x = .data$x, y = .data$y, color = .data[[color_by]])
    }
    p <- p +
      ggplot2::geom_point(
        data = obs_df,
        mapping = aes_obs,
        alpha = observed_point_alpha,
        size = observed_point_size
      )
  }

  if (!is.null(facet)) {
    if (isTRUE(facet)) {
      p <- p + ggplot2::facet_wrap(~id)
    } else if (is.character(facet)) {
      p <- p + ggplot2::facet_wrap(stats::as.formula(facet))
    } else if (inherits(facet, "formula")) {
      p <- p + ggplot2::facet_wrap(facet)
    }
  }

  if (isTRUE(subtitle_note)) {
    if (is.null(subtitle)) {
      subtitle <- "Zeros omitted on log scale."
    } else {
      subtitle <- paste(subtitle, "Zeros omitted on log scale.")
    }
  }
  beezdemand_warn_free_trans(free_trans_used, free_trans)

  x_limits <- beezdemand_resolve_limits(x_limits, x_trans, axis = "x")
  y_limits <- beezdemand_resolve_limits(y_limits, y_trans, axis = "y")

  p <- p +
    ggplot2::scale_x_continuous(
      trans = beezdemand_get_trans(x_trans),
      limits = x_limits,
      labels = beezdemand_axis_labels()
    ) +
    ggplot2::scale_y_continuous(
      trans = beezdemand_get_trans(y_trans),
      limits = y_limits,
      labels = beezdemand_axis_labels()
    ) +
    ggplot2::labs(
      title = if (type == "population") "Population Demand Curve" else NULL,
      subtitle = subtitle,
      x = x_lab,
      y = y_lab
    ) +
    theme_beezdemand(style = style)

  if (!is.null(color_by)) {
    scale_data <- if (!is.null(pred_df)) pred_df else obs_df
    p <- beezdemand_apply_color_scale(p, style, scale_data, color_by)
  }

  if (x_trans == "log10") {
    p <- p + ggplot2::annotation_logticks(sides = "b")
  }
  if (y_trans == "log10") {
    p <- p + ggplot2::annotation_logticks(sides = "l")
  }

  p
}

#' Summary Method for beezdemand_fixed
#'
#' Collects the per-subject coefficient table (the same table [tidy()]
#' returns), across-subject summaries of each parameter and, for the
#' `"hs"`, `"koff"` and `"simplified"` equations, model-based Pmax/Omax per
#' subject.
#'
#' @param object A beezdemand_fixed object
#' @param report_space Character. Reporting space for core parameters. One of:
#'   - `"natural"`: report natural-scale parameters (default)
#'   - `"log10"`: report `log10()`-scale parameters when defined
#' @param ... Additional arguments (ignored)
#' @return A summary.beezdemand_fixed object (inherits from beezdemand_summary)
#' @export
summary.beezdemand_fixed <- function(
  object,
  report_space = c("natural", "log10"),
  ...
) {
  report_space <- match.arg(report_space)

  results <- object$results
  has_results <- !is.null(results) &&
    is.data.frame(results) &&
    nrow(results) > 0

  if (has_results) {
    id_values <- beezdemand_fixed_id_values(results)
    param_specs <- beezdemand_fixed_param_specs(results)

    # Identical to the table tidy() builds; reuse it instead of duplicating.
    coefficients <- tidy.beezdemand_fixed(object, report_space = report_space)

    param_summary <- lapply(names(param_specs), function(term_name) {
      vals <- results[[param_specs[[term_name]]$estimate]]
      vals <- vals[!is.na(vals)]
      if (report_space == "log10" && term_name %in% c("Q0", "alpha", "k")) {
        # log10 is undefined for non-positive values, so drop them first.
        vals <- vals[vals > 0]
        if (length(vals) > 0) {
          vals <- log10(vals)
        }
      }
      if (length(vals) > 0) summary(vals) else NULL
    })
    names(param_summary) <- names(param_specs)

    # NROW() counts a NULL data_used entry as 0 rows instead of erroring.
    nobs <- if (!is.null(object$data_used)) {
      sum(vapply(object$data_used, NROW, integer(1)))
    } else {
      NA_integer_
    }
  } else {
    coefficients <- beezdemand_empty_coefficients()
    param_summary <- list()
    nobs <- NA_integer_
  }

  # Derived metrics (pmax/omax) per subject via the shared engine.
  derived_metrics <- beezdemand_empty_derived_metrics()
  pmax_method_info <- list()

  # isTRUE() guards against a NULL/NA equation producing a zero-length or NA
  # condition.
  supports_pmax <- has_results &&
    isTRUE(object$equation %in% c("hs", "koff", "simplified"))

  if (supports_pmax) {
    # [[ ]] avoids `$` partial matching on the spec list.
    q0_col <- param_specs[["Q0"]][["estimate"]]
    alpha_col <- param_specs[["alpha"]][["estimate"]]
    k_col <- param_specs[["k"]][["estimate"]]

    has_required <- !is.null(q0_col) && !is.null(alpha_col)
    # hs/koff need k; simplified does not.
    if (object$equation %in% c("hs", "koff")) {
      has_required <- has_required && !is.null(k_col)
    }

    if (has_required) {
      # NOTE(review): tidy()/coef() treat the results table as natural-scale,
      # but here Q0/alpha are declared to be on `param_space`; confirm which
      # is correct for fits with param_space = "log10".
      param_scale <- object$param_space %||% "natural"

      pmax_results <- lapply(seq_len(nrow(results)), function(i) {
        row <- results[i, ]

        # Observed prices for this subject, when retained.
        price_obs <- NULL
        if (!is.null(object$data_used) && length(object$data_used) >= i) {
          price_obs <- object$data_used[[i]]$x
        }

        # Build the parameter list (simplified has no k).
        params_i <- list(
          alpha = row[[alpha_col]],
          q0 = row[[q0_col]]
        )
        param_scales_i <- list(
          alpha = param_scale,
          q0 = param_scale
        )
        if (!is.null(k_col) && object$equation != "simplified") {
          params_i$k <- row[[k_col]]
          param_scales_i$k <- "natural"
        }

        beezdemand_calc_pmax_omax(
          model_type = object$equation,
          params = params_i,
          param_scales = param_scales_i,
          price_obs = price_obs,
          compute_observed = FALSE
        )
      })

      # Pull a scalar field from one engine result, substituting `missing`
      # when it is absent or not length one (keeps vapply() from failing).
      field <- function(res, name, missing) {
        val <- res[[name]]
        if (is.null(val) || length(val) != 1) missing else val
      }

      pmax_vals <- vapply(
        pmax_results,
        function(res) as.numeric(field(res, "pmax_model", NA_real_)),
        numeric(1)
      )
      omax_vals <- vapply(
        pmax_results,
        function(res) as.numeric(field(res, "omax_model", NA_real_)),
        numeric(1)
      )
      methods <- vapply(
        pmax_results,
        function(res) as.character(field(res, "method_model", NA_character_)),
        character(1)
      )

      # Store per-subject values alongside the results table.
      results$pmax_model <- pmax_vals
      results$omax_model <- omax_vals
      results$pmax_method <- methods

      derived_metrics <- dplyr::bind_rows(
        derived_metrics,
        tibble::tibble(
          metric = rep("pmax_model", length(pmax_vals)),
          estimate = as.numeric(pmax_vals),
          std.error = NA_real_,
          conf.low = NA_real_,
          conf.high = NA_real_,
          method = as.character(methods),
          component = "consumption",
          level = "individual",
          id = as.character(id_values)
        ),
        tibble::tibble(
          metric = rep("omax_model", length(omax_vals)),
          estimate = as.numeric(omax_vals),
          std.error = NA_real_,
          conf.low = NA_real_,
          conf.high = NA_real_,
          method = as.character(methods),
          component = "consumption",
          level = "individual",
          id = as.character(id_values)
        )
      )

      # Across-subject summaries of the derived metrics.
      pmax_summary <- if (any(!is.na(pmax_vals))) {
        summary(pmax_vals[!is.na(pmax_vals)])
      } else {
        NULL
      }
      omax_summary <- if (any(!is.na(omax_vals))) {
        summary(omax_vals[!is.na(omax_vals)])
      } else {
        NULL
      }

      param_summary$pmax_model <- pmax_summary
      param_summary$omax_model <- omax_summary

      # Report the most common computation method.
      if (length(methods) > 0) {
        method_table <- table(methods[!is.na(methods)])
        if (length(method_table) > 0) {
          pmax_method_info <- list(
            method_model = names(which.max(method_table)),
            n_analytic = sum(grepl("analytic", methods, ignore.case = TRUE)),
            n_numerical = sum(grepl("numerical", methods, ignore.case = TRUE))
          )
        }
      }
    }
  }

  structure(
    list(
      call = object$call,
      model_class = "beezdemand_fixed",
      backend = "legacy",
      equation = object$equation,
      param_space = object$param_space %||% "natural",
      report_space = report_space,
      k_spec = object$k_spec,
      agg = object$agg,
      nobs = nobs,
      n_subjects = object$n_total,
      n_success = object$n_success,
      n_fail = object$n_fail,
      converged = NA,
      logLik = NA_real_,
      AIC = NA_real_,
      BIC = NA_real_,
      coefficients = coefficients,
      derived_metrics = derived_metrics,
      param_summary = param_summary,
      pmax_method_info = pmax_method_info,
      results = results,
      notes = character(0)
    ),
    class = c("summary.beezdemand_fixed", "beezdemand_summary")
  )
}

#' Print Method for summary.beezdemand_fixed
#'
#' Prints fit counts, across-subject parameter summaries (median and range of
#' Q0 and alpha, labelled `log10(...)` when the summary was built with
#' `report_space = "log10"`), and the per-subject coefficient table for up to
#' `n` subjects.
#'
#' @param x A summary.beezdemand_fixed object
#' @param digits Number of significant digits used for the parameter summaries
#' @param n Number of subjects (ids) to print (default 20)
#' @param ... Additional arguments (ignored)
#' @return Invisibly returns the input object \code{x}.
#' @export
print.summary.beezdemand_fixed <- function(x, digits = 4, n = 20, ...) {
  cat("\n")
  cat("Fixed-Effect Demand Model Summary\n")
  cat(strrep("=", 50), "\n\n")

  cat("Equation:", x$equation, "\n")
  cat("k:", x$k_spec, "\n")
  if (!is.null(x$agg)) {
    cat("Aggregation:", x$agg, "\n")
  }
  cat("\n")

  cat("Fit Summary:\n")
  cat("  Total subjects:", x$n_subjects, "\n")
  cat("  Converged:", x$n_success, "\n")
  cat("  Failed:", x$n_fail, "\n")
  # Guard NULL as well as NA so hand-built summaries do not crash here.
  if (!is.null(x$nobs) && !is.na(x$nobs)) {
    cat("  Total observations:", x$nobs, "\n")
  }
  cat("\n")

  if (length(x$param_summary) > 0) {
    cat("Parameter Summary (across subjects):\n")
    on_log10 <- identical(x$report_space, "log10")

    for (term in c("Q0", "alpha")) {
      term_sum <- x$param_summary[[term]]
      if (is.null(term_sum)) {
        next
      }
      label <- if (on_log10) paste0("log10(", term, ")") else term
      # signif() honours `digits` as significant digits; round() to a fixed
      # number of decimals printed small alpha values (e.g. 2e-7) as 0.
      cat("  ", label, ":\n", sep = "")
      cat("    Median:", signif(term_sum[["Median"]], digits), "\n")
      cat(
        "    Range: [",
        signif(term_sum[["Min."]], digits),
        ",",
        signif(term_sum[["Max."]], digits),
        "]\n"
      )
    }
  }

  if (!is.null(x$coefficients) && nrow(x$coefficients) > 0) {
    n_to_print <- n
    if (is.null(n_to_print)) {
      n_to_print <- 20
    }

    coef_all <- x$coefficients |>
      dplyr::mutate(
        id = as.character(.data$id),
        term = as.character(.data$term)
      ) |>
      dplyr::arrange(.data$id, .data$term)

    ids <- unique(coef_all$id)
    ids_to_print <- ids
    if (is.finite(n_to_print) && length(ids) > n_to_print) {
      ids_to_print <- utils::head(ids, n_to_print)
    }

    coef_print <- coef_all |>
      dplyr::filter(.data$id %in% ids_to_print) |>
      dplyr::arrange(.data$id, .data$term)

    cat("\nPer-subject coefficients:\n")
    cat("-------------------------\n")
    print(coef_print, row.names = FALSE)

    if (is.finite(n_to_print) && length(ids) > n_to_print) {
      cat("\n(Showing first", n_to_print, "ids of", length(ids), ")\n")
    }
  }

  if (length(x$notes) > 0) {
    cat("\nNotes:\n")
    for (note in x$notes) {
      cat("  -", note, "\n")
    }
  }

  invisible(x)
}

#' Tidy Method for beezdemand_fixed
#'
#' Reshapes the per-subject results table into one row per subject and
#' parameter. Standard errors are filled from the matching SE column when the
#' results table has one; otherwise they are `NA`.
#'
#' @param x A beezdemand_fixed object
#' @param report_space Character. Reporting space for core parameters. One of
#'   `"natural"` (default) or `"log10"`.
#' @param ... Additional arguments (ignored)
#' @return A tibble of model coefficients with columns: id, term, estimate,
#'   std.error, statistic, p.value, component
#' @export
tidy.beezdemand_fixed <- function(
  x,
  report_space = c("natural", "log10"),
  ...
) {
  report_space <- match.arg(report_space)

  fit_tbl <- x$results
  usable <- !is.null(fit_tbl) && is.data.frame(fit_tbl) && nrow(fit_tbl) > 0
  if (!usable) {
    # Zero-row tibble with the full column layout.
    return(tibble::tibble(
      id = character(),
      term = character(),
      estimate = numeric(),
      std.error = numeric(),
      statistic = numeric(),
      p.value = numeric(),
      component = character(),
      estimate_scale = character(),
      term_display = character()
    ))
  }

  subject_ids <- beezdemand_fixed_id_values(fit_tbl)
  specs <- beezdemand_fixed_param_specs(fit_tbl)
  spec_names <- names(specs)

  # SE vector for one parameter, or a scalar NA when no SE column exists.
  se_values <- function(spec) {
    if (!is.na(spec$se) && spec$se %in% names(fit_tbl)) {
      return(fit_tbl[[spec$se]])
    }
    NA_real_
  }

  per_term <- lapply(seq_along(spec_names), function(j) {
    label <- spec_names[[j]]
    spec <- specs[[label]]
    tibble::tibble(
      id = subject_ids,
      term = label,
      estimate = fit_tbl[[spec$estimate]],
      std.error = se_values(spec),
      statistic = NA_real_,
      p.value = NA_real_,
      component = "fixed",
      estimate_scale = "natural",
      term_display = label
    )
  })

  beezdemand_transform_coef_table(
    coef_tbl = dplyr::bind_rows(per_term),
    report_space = report_space,
    internal_space = "natural"
  )
}

#' Extract Coefficients from Fixed-Effect Demand Fit
#'
#' Returns the coefficients of each stored per-id model fit. When no model
#' objects are stored, falls back to [tidy()] on the natural-scale results
#' table. Core terms (Q0, alpha, k; matched case-insensitively) are converted
#' between the internal parameter space and the requested space. log10 is
#' undefined for non-positive estimates, so those are reported as `NA` on the
#' log10 scale.
#'
#' @param object A `beezdemand_fixed` object.
#' @param report_space One of `"internal"`, `"natural"`, or `"log10"`. Default `"internal"`.
#' @param ... Unused.
#' @return A tibble with columns `id`, `term`, `estimate`, `estimate_scale`,
#'   `term_display` (plus `estimate_internal` when a conversion was applied).
#' @export
coef.beezdemand_fixed <- function(
  object,
  report_space = c("internal", "natural", "log10"),
  ...
) {
  report_space <- match.arg(report_space)

  if (is.null(object$fits) || length(object$fits) == 0) {
    # No per-id model objects: the results table is natural-scale, so
    # "internal" is reported as "natural".
    fallback_space <- if (report_space == "internal") "natural" else report_space
    return(tidy.beezdemand_fixed(object, report_space = fallback_space))
  }

  ids <- names(object$fits)
  rows <- lapply(seq_along(object$fits), function(i) {
    fit <- object$fits[[i]]
    id <- if (!is.null(ids) && length(ids) >= i) ids[[i]] else NA_character_
    if (is.null(fit) || inherits(fit, "try-error")) {
      return(NULL)
    }
    cf <- tryCatch(stats::coef(fit), error = function(e) NULL)
    # Unnamed or empty coefficient vectors cannot be mapped to terms.
    if (is.null(cf) || length(cf) == 0 || is.null(names(cf))) {
      return(NULL)
    }
    tibble::tibble(
      id = as.character(id),
      term = names(cf),
      estimate = as.numeric(cf)
    )
  })

  out <- dplyr::bind_rows(rows)
  if (nrow(out) == 0) {
    return(tibble::tibble(
      id = character(),
      term = character(),
      estimate = numeric(),
      estimate_scale = character(),
      term_display = character()
    ))
  }

  internal_space <- object$param_space %||% "natural"
  internal_space <- if (internal_space == "log10") "log10" else "natural"
  requested <- if (report_space == "internal") internal_space else report_space

  out$estimate_scale <- rep(internal_space, nrow(out))
  out$term_display <- out$term

  if (requested != internal_space) {
    out$estimate_internal <- out$estimate

    term_key <- tolower(out$term)
    is_core <- term_key %in% c("q0", "alpha", "k")
    natural_label <- c(q0 = "Q0", alpha = "alpha", k = "k")
    core_labels <- unname(natural_label[term_key[is_core]])
    core_est <- out$estimate[is_core]

    if (requested == "log10") {
      # Only take log10 of positive values (avoids NaN warnings). The rest
      # become NA instead of natural values mislabelled as log10.
      positive <- !is.na(core_est) & core_est > 0
      logged <- rep(NA_real_, length(core_est))
      logged[positive] <- log10(core_est[positive])
      out$estimate[is_core] <- logged
      out$estimate_scale[is_core] <- "log10"
      out$term_display[is_core] <- paste0("log10(", core_labels, ")")
    } else {
      out$estimate[is_core] <- 10^core_est
      out$estimate_scale[is_core] <- "natural"
      out$term_display[is_core] <- core_labels
    }
  }

  out
}

#' Predict Method for beezdemand_fixed
#'
#' Evaluates each subject's fitted demand curve at the requested prices.
#' The link scale depends on the equation: `log10(Q)` for `"hs"`/`"koff"`
#' and the natural log `log(Q)` for `"simplified"`/`"linear"`.
#'
#' @param object A `beezdemand_fixed` object.
#' @param newdata A data frame containing a price column matching the fitted
#'   object's `x_var`. If `NULL`, uses the unique observed prices when available.
#' @param type One of `"response"` (default) or `"link"`.
#' @param se.fit Logical; if `TRUE`, includes a `.se.fit` column (currently `NA`
#'   because vcov is not available from legacy fixed fits).
#' @param interval One of `"none"` (default) or `"confidence"`. When requested,
#'   `.lower`/`.upper` are returned as `NA` because vcov is unavailable.
#' @param level Confidence level when `interval = "confidence"`. Currently used
#'   only for validation.
#' @param ... Unused.
#'
#' @return A tibble containing the original `newdata` columns, plus `.fitted`
#'   and, when requested, `.se.fit` and `.lower`/`.upper`. If `newdata` does not
#'   include an id column, predictions are returned for all subjects (cross
#'   product of `newdata` × subjects) unless `k` is subject-specific (`k = "ind"`).
#' @export
predict.beezdemand_fixed <- function(
  object,
  newdata = NULL,
  type = c("response", "link"),
  se.fit = FALSE,
  interval = c("none", "confidence"),
  level = 0.95,
  ...
) {
  type <- match.arg(type)
  interval <- match.arg(interval)

  level_is_valid <- is.null(level) ||
    (is.numeric(level) &&
      length(level) == 1 &&
      !is.na(level) &&
      level > 0 &&
      level < 1)
  if (!level_is_valid) {
    stop("'level' must be a single number between 0 and 1.", call. = FALSE)
  }

  price_col <- object$x_var %||% "x"
  subject_col <- object$id_var %||% "id"

  fit_tbl <- object$results
  if (is.null(fit_tbl) || !is.data.frame(fit_tbl) || nrow(fit_tbl) == 0) {
    return(tibble::tibble())
  }

  subject_ids <- beezdemand_fixed_id_values(fit_tbl)
  if (all(is.na(subject_ids))) {
    subject_ids <- as.character(seq_len(nrow(fit_tbl)))
  }

  if (is.null(newdata)) {
    # Observed prices come from data_used first, then from predictions.
    grid_prices <- NULL
    for (store in list(object$data_used, object$predictions)) {
      if (
        !is.null(store) &&
          length(store) > 0 &&
          price_col %in% names(store[[1]])
      ) {
        grid_prices <- sort(unique(unlist(lapply(store, `[[`, price_col))))
        break
      }
    }
    if (length(grid_prices) == 0) {
      stop(
        "`newdata` is required when the fit object does not retain observed prices.",
        call. = FALSE
      )
    }
    newdata <- data.frame(grid_prices, stringsAsFactors = FALSE)
    names(newdata) <- price_col
  }

  if (!is.data.frame(newdata)) {
    newdata <- as.data.frame(newdata)
  }
  if (!(price_col %in% names(newdata))) {
    stop(
      "`newdata` must include the price column `",
      price_col,
      "`.",
      call. = FALSE
    )
  }

  if (!is.numeric(newdata[[price_col]])) {
    stop("`newdata[[", price_col, "]]` must be numeric.", call. = FALSE)
  }

  if (isTRUE(object$k_spec == "ind") && !(subject_col %in% names(newdata))) {
    stop(
      "Subject-specific k requires `newdata` to include the id column `",
      subject_col,
      "`.",
      call. = FALSE
    )
  }

  eq <- object$equation %||% "hs"
  specs <- beezdemand_fixed_param_specs(fit_tbl)

  # Numeric estimates for one parameter, or NAs when it was not estimated.
  column_values <- function(name) {
    col <- specs[[name]]$estimate
    if (is.null(col) || !(col %in% names(fit_tbl))) {
      return(rep(NA_real_, nrow(fit_tbl)))
    }
    as.numeric(fit_tbl[[col]])
  }

  pars <- tibble::tibble(
    !!subject_col := subject_ids,
    Q0 = column_values("Q0"),
    alpha = column_values("alpha"),
    k = column_values("k"),
    L = column_values("L"),
    b = column_values("b"),
    a = column_values("a")
  )

  if (subject_col %in% names(newdata)) {
    # Rows are matched to their own subject's parameters.
    newdata[[subject_col]] <- as.character(newdata[[subject_col]])
    row_match <- match(newdata[[subject_col]], pars[[subject_col]])
    if (anyNA(row_match)) {
      missing_ids <- unique(newdata[[subject_col]][is.na(row_match)])
      stop(
        "Unknown id values in `newdata`: ",
        paste(missing_ids, collapse = ", "),
        ".",
        call. = FALSE
      )
    }
    out <- tibble::as_tibble(newdata)
    pars_row <- pars[row_match, , drop = FALSE]
  } else {
    # Every price row is paired with every subject.
    cross <- expand.grid(
      subject_row = seq_len(nrow(pars)),
      data_row = seq_len(nrow(newdata))
    )
    out <- tibble::as_tibble(newdata[cross$data_row, , drop = FALSE])
    out[[subject_col]] <- pars[[subject_col]][cross$subject_row]
    pars_row <- pars[cross$subject_row, , drop = FALSE]
  }

  prices_out <- out[[price_col]]

  # log_fun(v) where v is finite and positive, NA elsewhere.
  log_if_positive <- function(v, log_fun) {
    ifelse(is.finite(v) & v > 0, log_fun(v), NA_real_)
  }

  fitted_link <- switch(
    as.character(eq),
    hs = ,
    koff = log_if_positive(pars_row$Q0, log10) +
      pars_row$k * (exp(-pars_row$alpha * pars_row$Q0 * prices_out) - 1),
    # Simplified: Q = Q0 * exp(-alpha * Q0 * P), so log(Q) = log(Q0) - alpha * Q0 * P
    simplified = log_if_positive(pars_row$Q0, log) -
      pars_row$alpha * pars_row$Q0 * prices_out,
    linear = log_if_positive(pars_row$L, log) +
      pars_row$b * log_if_positive(prices_out, log) -
      pars_row$a * prices_out,
    stop(
      "Unsupported equation `",
      eq,
      "` for `predict.beezdemand_fixed()`.",
      call. = FALSE
    )
  )

  out$.fitted <- if (type == "link") {
    fitted_link
  } else if (eq %in% c("linear", "simplified")) {
    exp(fitted_link)
  } else {
    10^fitted_link
  }

  if (isTRUE(se.fit) || interval != "none") {
    warning(
      "Standard errors/intervals are not available for `beezdemand_fixed` predictions ",
      "(vcov is unavailable from legacy fixed fits); returning NA for uncertainty columns.",
      call. = FALSE
    )
    out$.se.fit <- NA_real_
    if (interval != "none") {
      out$.lower <- NA_real_
      out$.upper <- NA_real_
    }
  }

  out
}

#' Glance Method for beezdemand_fixed
#'
#' One-row overview of a fixed-effect fit. Likelihood-based statistics are
#' `NA` because legacy fixed fits do not provide them. `nobs` is the total
#' row count across the stored per-subject data, or `NA` when none is stored.
#'
#' @param x A beezdemand_fixed object
#' @param ... Additional arguments (ignored)
#' @return A one-row tibble of model statistics
#' @export
glance.beezdemand_fixed <- function(x, ...) {
  used <- x$data_used
  total_obs <- NA_integer_
  if (!is.null(used)) {
    total_obs <- sum(vapply(used, nrow, integer(1)))
  }

  tibble::tibble(
    model_class = "beezdemand_fixed",
    backend = "legacy",
    equation = x$equation,
    k_spec = x$k_spec,
    nobs = total_obs,
    n_subjects = x$n_total,
    n_success = x$n_success,
    n_fail = x$n_fail,
    converged = NA,
    logLik = NA_real_,
    AIC = NA_real_,
    BIC = NA_real_
  )
}

#' Confidence Intervals for Fixed-Effect Demand Model Parameters
#'
#' Computes Wald-type confidence intervals for the per-subject parameters
#' (Q0, alpha, k, and alpha_star when present) of individual demand curve
#' fits, using the standard errors stored in the fit results.
#'
#' @param object A `beezdemand_fixed` object from [fit_demand_fixed()].
#' @param parm Character vector of parameter names to compute CIs for; any of
#'   `"Q0"`, `"alpha"`, `"k"`, `"alpha_star"`. Default includes all parameters
#'   available in the fit results. Names not found in the results are dropped
#'   with a warning.
#' @param level Confidence level (default 0.95).
#' @param ... Additional arguments (ignored).
#'
#' @return A tibble with columns: `id`, `term`, `estimate`, `conf.low`,
#'   `conf.high`, `level`, with one row per subject for each requested term.
#'
#' @details
#' Intervals use the asymptotic normal approximation
#' `estimate +/- z * SE` with `z = qnorm((1 + level) / 2)`. Estimates and
#' standard errors are read from `object$results` (columns `Q0d`/`Q0se`,
#' `Alpha`/`Alphase`, `K`, and `alpha_star`/`alpha_star_se`). When a standard
#' error is unavailable, the bounds are `NA`. This is always the case for `k`,
#' which has no standard-error column. Profile-likelihood intervals are not
#' computed.
#'
#' @examples
#' \donttest{
#' fit <- fit_demand_fixed(apt, equation = "hs", k = 2)
#' confint(fit)
#' confint(fit, level = 0.90)
#' confint(fit, parm = "Q0")
#' }
#'
#' @importFrom stats qnorm
#' @export
confint.beezdemand_fixed <- function(object, parm = NULL, level = 0.95, ...) {
  # is.na() must be checked before the range comparisons: `NA <= 0` is NA and
  # would make `if` fail with an uninformative error.
  if (!is.numeric(level) || length(level) != 1L || is.na(level) ||
    level <= 0 || level >= 1) {
    stop("`level` must be a single number between 0 and 1.", call. = FALSE)
  }

  # Zero-row result with the documented column types.
  empty_ci <- function() {
    tibble::tibble(
      id = character(),
      term = character(),
      estimate = numeric(),
      conf.low = numeric(),
      conf.high = numeric(),
      level = numeric()
    )
  }

  results <- object$results
  if (is.null(results) || !is.data.frame(results) || nrow(results) == 0) {
    return(empty_ci())
  }

  # Public parameter name -> estimate / standard-error columns in `results`.
  # `k` has no SE column, so its bounds are always NA.
  param_map <- list(
    Q0 = list(est = "Q0d", se = "Q0se"),
    alpha = list(est = "Alpha", se = "Alphase"),
    k = list(est = "K", se = NA_character_),
    alpha_star = list(est = "alpha_star", se = "alpha_star_se")
  )

  # Subject ids; fall back to row positions when no `id` column exists.
  ids <- if ("id" %in% names(results)) {
    as.character(results$id)
  } else {
    as.character(seq_len(nrow(results)))
  }

  # Parameters whose estimate column is present in the results.
  available_params <- names(param_map)[vapply(
    param_map,
    function(spec) spec$est %in% names(results),
    logical(1)
  )]

  if (is.null(parm)) {
    parm <- available_params
  } else {
    unknown <- setdiff(parm, available_params)
    parm <- intersect(parm, available_params)
    # Report partially-unmatched requests; a request with no matches at all
    # gets its own warning below.
    if (length(parm) > 0L && length(unknown) > 0L) {
      warning(
        "Ignoring parameter(s) not found in fit results: ",
        paste(unknown, collapse = ", "),
        call. = FALSE
      )
    }
  }

  if (length(parm) == 0) {
    warning("No requested parameters found in fit results.", call. = FALSE)
    return(empty_ci())
  }

  z <- stats::qnorm((1 + level) / 2)

  ci_rows <- lapply(parm, function(p) {
    spec <- param_map[[p]]
    est <- results[[spec$est]]
    se <- if (!is.na(spec$se) && spec$se %in% names(results)) {
      results[[spec$se]]
    } else {
      rep(NA_real_, length(est))
    }

    tibble::tibble(
      id = ids,
      term = p,
      estimate = est,
      conf.low = est - z * se,
      conf.high = est + z * se,
      level = level
    )
  })

  dplyr::bind_rows(ci_rows)
}


#' Augment a beezdemand_fixed Model with Fitted Values and Residuals
#'
#' @description
#' Returns the original data with fitted values and residuals from individual
#' demand curve fits. This enables easy model diagnostics and visualization
#' with the tidyverse.
#'
#' @param x An object of class \code{beezdemand_fixed}.
#' @param newdata Optional data frame of new data for prediction. If NULL,
#'   uses the per-subject data stored in \code{x$data_used}.
#' @param ... Additional arguments (currently unused).
#'
#' @return A tibble containing the data plus:
#'   \describe{
#'     \item{.fitted}{Fitted demand values on the response scale}
#'     \item{.resid}{Residuals (observed - fitted); \code{NA} when the data
#'       has no response column}
#'   }
#'
#' @details
#' Fitted values come from \code{predict()} on the model. For "hs" (and
#' "koff") equation models where fitting is done on the log10 scale, fitted
#' values are back-transformed to the natural scale.
#'
#' When \code{newdata} is NULL, \code{x$data_used} must be a list of data
#' frames named by subject id. Each subject's id column (\code{x$id_var},
#' default \code{"id"}) is set to that name, as character, before prediction.
#'
#' @examples
#' \donttest{
#' data(apt)
#' fit <- fit_demand_fixed(apt, y_var = "y", x_var = "x", id_var = "id")
#' augmented <- augment(fit)
#'
#' # Plot residuals by subject
#' library(ggplot2)
#' ggplot(augmented, aes(x = .fitted, y = .resid)) +
#'   geom_point(alpha = 0.5) +
#'   facet_wrap(~id) +
#'   geom_hline(yintercept = 0, linetype = "dashed")
#' }
#'
#' @importFrom tibble as_tibble
#' @export
augment.beezdemand_fixed <- function(x, newdata = NULL, ...) {
  y_var <- x$y_var %||% "y"
  id_var <- x$id_var %||% "id"

  # Observed - fitted, or NA (one per fitted value) when the response column
  # is absent, so the assignment works for any number of rows.
  resid_or_na <- function(data, fitted) {
    y_obs <- data[[y_var]]
    if (is.null(y_obs)) rep(NA_real_, length(fitted)) else y_obs - fitted
  }

  if (!is.null(newdata)) {
    preds <- predict(x, newdata = newdata)
    out <- tibble::as_tibble(newdata)
    out$.fitted <- preds$.fitted
    out$.resid <- resid_or_na(newdata, out$.fitted)
    return(out)
  }

  if (is.null(x$data_used) || length(x$data_used) == 0) {
    stop(
      "No data available. Provide 'newdata' or ensure model retains data_used.",
      call. = FALSE
    )
  }

  # Subject ids are taken from the list names; without them predictions
  # cannot be matched to subjects (an unnamed list would otherwise yield an
  # empty result).
  subj_ids <- names(x$data_used)
  if (is.null(subj_ids) || anyNA(subj_ids) || any(!nzchar(subj_ids))) {
    stop("`data_used` must be a list named by subject id.", call. = FALSE)
  }

  # Pair each data set with its name positionally (safe with duplicate names).
  per_subject <- Map(
    function(subj_data, subj_id) {
      subj_data[[id_var]] <- subj_id
      preds <- predict(x, newdata = subj_data)
      subj_data$.fitted <- preds$.fitted
      subj_data$.resid <- resid_or_na(subj_data, subj_data$.fitted)
      subj_data
    },
    x$data_used,
    subj_ids
  )

  tibble::as_tibble(dplyr::bind_rows(unname(per_subject)))
}

Try the beezdemand package in your browser

Any scripts or data that you put into this service are public.

beezdemand documentation built on March 3, 2026, 9:07 a.m.