R/plot.dm_growth_evaluation.R

Defines functions plot.dm_growth_evaluation

Documented in plot.dm_growth_evaluation

#' Plot growth-fitting evaluation statistics
#'
#' @description
#' Creates ggplot2-based comparison plots for objects returned by
#' [dm.growth.evaluate()].
#'
#' @param x An object of class \code{"dm_growth_evaluation"}. It is converted
#'   with \code{tibble::as_tibble()} and must contain the columns
#'   \code{method}, \code{series}, and \code{fit_id}, plus the column of the
#'   chosen metric (the \code{bias} column when \code{metric = "abs_bias"}).
#' @param metric Evaluation metric to plot. One of \code{"rmse"},
#'   \code{"mae"}, \code{"bias"}, \code{"abs_bias"}, \code{"r2"},
#'   \code{"correlation"}, \code{"nrmse"}, \code{"rss"},
#'   \code{"aic_approx"}, or \code{"bic_approx"}. \code{"abs_bias"} is
#'   computed as the absolute value of the \code{bias} column.
#' @param type Plot type. One of \code{"boxplot"}, \code{"mean"}, or
#'   \code{"heatmap"}.
#' @param order_methods Logical. If \code{TRUE}, methods are ordered by their
#'   average value for the chosen metric.
#' @param decreasing Logical or \code{NULL}. Direction used when ordering
#'   methods and, for \code{type = "heatmap"}, heatmap rows. If \code{NULL},
#'   metrics where smaller values are better (\code{"rmse"}, \code{"mae"},
#'   \code{"abs_bias"}, \code{"nrmse"}, \code{"rss"}, \code{"aic_approx"},
#'   \code{"bic_approx"}) are ordered ascending, and all other metrics are
#'   ordered descending.
#' @param show_points Logical. If \code{TRUE}, add individual fit points to
#'   \code{"boxplot"} and \code{"mean"} displays.
#' @param show_errorbar Logical. If \code{TRUE}, show standard-error bars when
#'   \code{type = "mean"}.
#' @param heatmap_label Character string used to label rows in the heatmap. One
#'   of \code{"series_fit"}, \code{"series"}, or \code{"fit_id"}. When the
#'   chosen label does not uniquely identify a fit within a method (for example
#'   \code{"series"} with several fits per series), the metric is averaged over
#'   all fits that share a tile instead of drawing overlapping tiles.
#' @param na.rm Logical. If \code{TRUE}, remove non-finite (missing, \code{NaN},
#'   or infinite) metric values before plotting. Regardless of this setting, an
#'   error is raised when the metric has no finite value at all.
#' @param ... Currently unused; accepted for compatibility with the
#'   \code{plot()} generic.
#'
#' @return A \code{ggplot} object.
#'
#' @seealso [dm.growth.evaluate()]
#'
#' @method plot dm_growth_evaluation
#' @export
plot.dm_growth_evaluation <- function(x,
                                      metric = c(
                                        "rmse", "mae", "bias", "abs_bias",
                                        "r2", "correlation", "nrmse",
                                        "rss", "aic_approx", "bic_approx"
                                      ),
                                      type = c("boxplot", "mean", "heatmap"),
                                      order_methods = TRUE,
                                      decreasing = NULL,
                                      show_points = TRUE,
                                      show_errorbar = TRUE,
                                      heatmap_label = c("series_fit", "series", "fit_id"),
                                      na.rm = TRUE,
                                      ...) {
  # Local bindings for column names used inside dplyr/ggplot2 data masking,
  # so R CMD check does not report "no visible binding" notes.
  avg_metric <- mean_metric <- method <- metric_value <- n <- row_label <- NULL
  sd_metric <- se_metric <- NULL
  metric <- match.arg(metric)
  type <- match.arg(type)
  heatmap_label <- match.arg(heatmap_label)

  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop(
      "Package 'ggplot2' is required for plot.dm_growth_evaluation().",
      call. = FALSE
    )
  }

  dat <- tibble::as_tibble(x)

  required_cols <- c("method", "series", "fit_id")
  missing_required <- setdiff(required_cols, names(dat))
  if (length(missing_required) > 0) {
    stop(
      "The evaluation table is missing required columns: ",
      paste(missing_required, collapse = ", "), ".",
      call. = FALSE
    )
  }

  # Copy the chosen metric into a common column so every plot type can map
  # the same aesthetic.
  if (metric == "abs_bias") {
    if (!"bias" %in% names(dat)) {
      stop("Column 'bias' is required to compute 'abs_bias'.", call. = FALSE)
    }
    dat$metric_value <- abs(dat$bias)
    metric_label <- "Absolute bias"
  } else {
    if (!metric %in% names(dat)) {
      stop(
        "Column '", metric, "' not found in evaluation table.",
        call. = FALSE
      )
    }
    dat$metric_value <- dat[[metric]]
    metric_label <- switch(
      metric,
      rmse = "RMSE",
      mae = "MAE",
      bias = "Bias",
      r2 = "R2",
      correlation = "Correlation",
      nrmse = "NRMSE",
      rss = "RSS",
      aic_approx = "Approx. AIC",
      bic_approx = "Approx. BIC",
      metric
    )
  }

  if (isTRUE(na.rm)) {
    dat <- dat[is.finite(dat$metric_value), , drop = FALSE]
  }

  # Checked on the values rather than the row count so that an all-NA metric
  # is rejected even when na.rm = FALSE (any() of zero rows is FALSE, so an
  # empty table is rejected too).
  if (!any(is.finite(dat$metric_value))) {
    stop(
      "No finite values available for metric '", metric, "'.",
      call. = FALSE
    )
  }

  smaller_is_better <- metric %in% c(
    "rmse", "mae", "abs_bias", "nrmse", "rss", "aic_approx", "bic_approx"
  )

  if (is.null(decreasing)) {
    decreasing <- !smaller_is_better
  }

  if (isTRUE(order_methods)) {
    method_order <- dat %>%
      dplyr::group_by(method) %>%
      dplyr::summarise(avg_metric = mean(metric_value, na.rm = TRUE), .groups = "drop") %>%
      dplyr::arrange(if (isTRUE(decreasing)) dplyr::desc(avg_metric) else avg_metric) %>%
      dplyr::pull(method)

    dat$method <- factor(dat$method, levels = method_order)
  }

  if (type == "boxplot") {
    p <- ggplot2::ggplot(dat, ggplot2::aes(x = method, y = metric_value)) +
      # Outliers are hidden because the raw points may be overlaid below.
      ggplot2::geom_boxplot(outlier.shape = NA) +
      ggplot2::labs(
        x = "Method",
        y = metric_label,
        title = paste("Comparison of", metric, "across growth-fitting methods")
      ) +
      ggplot2::theme_bw() +
      ggplot2::theme(
        axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)
      )

    if (isTRUE(show_points)) {
      p <- p + ggplot2::geom_jitter(width = 0.15, height = 0, alpha = 0.7, size = 1.8)
    }

    return(p)
  }

  if (type == "mean") {
    sum_dat <- dat %>%
      dplyr::group_by(method) %>%
      dplyr::summarise(
        n = sum(is.finite(metric_value)),
        mean_metric = mean(metric_value, na.rm = TRUE),
        sd_metric = stats::sd(metric_value, na.rm = TRUE),
        se_metric = sd_metric / sqrt(n),
        .groups = "drop"
      )

    if (isTRUE(order_methods)) {
      sum_dat$method <- factor(sum_dat$method, levels = levels(dat$method))
    }

    p <- ggplot2::ggplot(sum_dat, ggplot2::aes(x = method, y = mean_metric)) +
      ggplot2::geom_col() +
      ggplot2::labs(
        x = "Method",
        y = paste("Mean", metric_label),
        title = paste("Mean", metric, "by growth-fitting method")
      ) +
      ggplot2::theme_bw() +
      ggplot2::theme(
        axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)
      )

    if (isTRUE(show_errorbar)) {
      p <- p + ggplot2::geom_errorbar(
        ggplot2::aes(
          ymin = mean_metric - se_metric,
          ymax = mean_metric + se_metric
        ),
        width = 0.2
      )
    }

    if (isTRUE(show_points)) {
      # Raw values come from a different data set than the bars, so the
      # inherited aesthetics are switched off.
      p <- p + ggplot2::geom_point(
        data = dat,
        mapping = ggplot2::aes(x = method, y = metric_value),
        inherit.aes = FALSE,
        alpha = 0.5,
        position = ggplot2::position_jitter(width = 0.12, height = 0)
      )
    }

    return(p)
  }

  if (type == "heatmap") {
    dat$row_label <- switch(
      heatmap_label,
      series_fit = paste(dat$series, dat$fit_id, sep = " | "),
      series = dat$series,
      fit_id = dat$fit_id
    )

    # Rows are ordered by their mean metric across all methods.
    row_levels <- dat %>%
      dplyr::group_by(row_label) %>%
      dplyr::summarise(avg_metric = mean(metric_value, na.rm = TRUE), .groups = "drop") %>%
      dplyr::arrange(if (isTRUE(decreasing)) dplyr::desc(avg_metric) else avg_metric) %>%
      dplyr::pull(row_label)

    # Exactly one tile per (method, row) cell. Labels such as "series" can
    # cover several fits; without this step geom_tile() would stack those
    # fits and only the last drawn value would be visible.
    tile_dat <- dat %>%
      dplyr::group_by(method, row_label) %>%
      dplyr::summarise(metric_value = mean(metric_value, na.rm = TRUE), .groups = "drop")

    tile_dat$row_label <- factor(tile_dat$row_label, levels = row_levels)

    p <- ggplot2::ggplot(
      tile_dat,
      ggplot2::aes(x = method, y = row_label, fill = metric_value)
    ) +
      ggplot2::geom_tile() +
      ggplot2::labs(
        x = "Method",
        y = switch(
          heatmap_label,
          series_fit = "Series | Fit",
          series = "Series",
          fit_id = "Fit ID"
        ),
        fill = metric_label,
        title = paste("Heatmap of", metric, "across growth-fitting methods")
      ) +
      ggplot2::theme_bw() +
      ggplot2::theme(
        axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)
      )

    return(p)
  }

  # Unreachable after match.arg(type); kept as a defensive guard.
  stop("Unknown plot type.", call. = FALSE)
}

Try the dendRoAnalyst package in your browser

Any scripts or data that you put into this service are public.

dendRoAnalyst documentation built on May 20, 2026, 5:07 p.m.