R/plot.cv.savvyPR.R

#' Plot for a Cross-Validated Parity Regression Model
#'
#' @title Plot for a Cross-Validated Parity Regression Model
#' @description Generates various visualizations for a fitted cross-validated parity regression
#' model object. It supports plotting estimated coefficients, risk contributions, coefficient
#' paths, and cross-validation error curves based on the specified \code{plot_type}.
#'
#' @param x A fitted model object of class \code{"cv.savvyPR"} returned by \code{\link{cv.savvyPR}}.
#' @param plot_type Character string specifying the type of plot to generate. Can be \code{"estimated_coefficients"}, \code{"risk_contributions"}, \code{"cv_coefficients"}, or \code{"cv_errors"}.
#'                  Defaults to \code{"estimated_coefficients"}.
#' @param label Logical; if \code{TRUE}, labels are added based on the plot type:
#' \describe{
#'   \item{cv_coefficients}{Variable names are added to the coefficient paths.}
#'   \item{risk_contributions}{Numeric labels are added above the bars.}
#'   \item{estimated_coefficients}{Numeric values are added to the coefficient plot.}
#'   \item{cv_errors}{Numeric labels are added to the optimal error point.}
#' }
#' Default is \code{TRUE}.
#' @param xvar Character string specifying the x-axis variable for plotting coefficient paths. Options are \code{"norm"}, \code{"lambda"}, \code{"dev"}, and \code{"val"}.
#'              This argument is only used when \code{plot_type = "cv_coefficients"}.
#' @param max_vars_per_plot Integer specifying the maximum number of variables to plot per panel. Values above \code{10} are reset to \code{10}. Default is \code{10}.
#'              This argument is only used when \code{plot_type = "cv_coefficients"}.
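#' @param ... Additional arguments passed on to the internal plotting helper for the selected \code{plot_type} (\code{plotCoef}, \code{plotRiskContr}, \code{plotCVCoef}, or \code{plotCVErr}).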
#'
#' @details
#' This function offers four types of plots, depending on the value of \code{plot_type}:
#'
#' \describe{
#'   \item{\strong{Estimated Coefficients}}{Generates a line plot with points for the estimated
#'       coefficients of the optimally tuned cross-validated regression model.
#'       If an intercept term is included, it will be labeled as \code{beta_0};
#'       otherwise, coefficients are labeled sequentially based on the covariates.
#'       If \code{label = TRUE}, numeric values are displayed on the plot. This plot
#'       helps visualize the contribution of each predictor variable to the model.}
#'
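#'   \item{\strong{Risk Contributions}}{Generates two bar plots from the fitted risk parity
#'       model: one for the optimization variables (weights or target values) and one for
#'       the risk contributions. If \code{label = TRUE}, numeric labels are added above
#'       the bars. If the fit stores no \code{orp_fit} component (which usually happens
#'       when the optimal tuning parameter is 0), a warning is issued and no plot is drawn.}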
#'
#'   \item{\strong{Coefficient Paths}}{Generates a plot showing the coefficient paths against
#'       the x-axis variable selected by \code{xvar} (\code{"val"}, \code{"lambda"},
#'       \code{"norm"}, or \code{"dev"}). The tuning-parameter axis depends on the model type:
#'       \itemize{
#'           \item \strong{PR1/PR2: }Plots coefficient paths against \code{log(val)} values.
#'           \item \strong{PR3: }Plots coefficient paths against \code{log(lambda)} values.
#'       }
#'       Invalid combinations of \code{model_type} and \code{xvar} will result in an error:
#'       \itemize{
#'           \item \strong{PR1/PR2: }Cannot use \code{"lambda"} as \code{xvar}.
#'           \item \strong{PR3: }Cannot use \code{"val"} as \code{xvar}.
#'       }
#'       If \code{max_vars_per_plot} exceeds \code{10}, it is reset to \code{10}.
#'       The plot provides insight into how coefficients evolve across different
#'       regularization parameters.}
#'
#'   \item{\strong{Cross-Validation Errors}}{Generates a plot that shows the cross-validation
#'       error metric against the logarithm of the tuning parameter (\code{val} or
#'       \code{lambda}), depending on the model type. It adds a vertical dashed line
#'       to indicate the optimal parameter value.
#'       \itemize{
#'           \item \strong{PR1/PR2: }Plots cross-validation errors against \code{log(val)} values.
#'           \item \strong{PR3: }Plots cross-validation errors against \code{log(lambda)} values.
#'       }
#'    }
#' }
#'
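#' @return The plot produced by the helper for the selected \code{plot_type}, typically a
#'   \code{ggplot} object; multi-panel displays (such as the two risk-contribution bar plots)
#'   may instead be arranged with \code{gridExtra::grid.arrange}. For
#'   \code{plot_type = "risk_contributions"}, if no \code{orp_fit} is available, a warning
#'   is issued and no plot is returned.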
#'
#' @examples
#' \donttest{
#' # Example usage for `cv.savvyPR` with Correlated Data:
#' set.seed(123)
#' n <- 100
#' p <- 10
#' # Create highly correlated predictors to demonstrate parity regression
#' base_var <- rnorm(n)
#' x <- matrix(rnorm(n * p, sd = 0.1), n, p) + base_var
#' beta <- matrix(rnorm(p), p, 1)
#' y <- x %*% beta + rnorm(n, sd = 0.5)
#'
#' # Fit CV model using Budget method
#' cv_result1 <- cv.savvyPR(x, y, method = "budget", model_type = "PR1",
#'                          measure_type = "mse", intercept = FALSE)
#' plot(cv_result1, plot_type = "estimated_coefficients")
#' plot(cv_result1, plot_type = "risk_contributions", label = FALSE)
#' plot(cv_result1, plot_type = "cv_coefficients", xvar = "val", max_vars_per_plot = 10)
#' plot(cv_result1, plot_type = "cv_errors")
#'
#' # Fit CV models using the Target method (PR2) and the Budget method (PR3)
#' cv_result2 <- cv.savvyPR(x, y, method = "target", model_type = "PR2")
#' cv_result3 <- cv.savvyPR(x, y, method = "budget", model_type = "PR3")
#'
#' plot(cv_result2, plot_type = "cv_coefficients", xvar = "val",
#'      max_vars_per_plot = 5, label = FALSE)
#' plot(cv_result3, plot_type = "cv_coefficients", xvar = "lambda",
#'      max_vars_per_plot = 10, label = TRUE)
#' plot(cv_result2, plot_type = "cv_errors", label = FALSE)
#' plot(cv_result3, plot_type = "cv_errors", label = TRUE)
#' }
#'
#' @author Ziwei Chen, Vali Asimit and Pietro Millossovich\cr
#' Maintainer: Ziwei Chen <ziwei.chen.3@citystgeorges.ac.uk>
#'
#' @importFrom ggplot2 ggplot aes geom_line geom_point geom_bar geom_text scale_color_brewer coord_cartesian
#'             theme_minimal labs theme element_text element_line element_blank geom_errorbar geom_vline annotate ylim
#' @importFrom gridExtra grid.arrange
#'
#' @seealso \code{\link{cv.savvyPR}}
#' @method plot cv.savvyPR
#' @export
plot.cv.savvyPR <- function(x, plot_type = c("estimated_coefficients", "risk_contributions",
                                             "cv_coefficients", "cv_errors"), label = TRUE,
                            xvar = c("norm", "lambda", "dev", "val"), max_vars_per_plot = 10, ...) {

  plot_type <- match.arg(plot_type)

  if (plot_type == "estimated_coefficients") {
    plotCoef(x$coefficients, intercept = x$PR_fit$intercept, label = label, ...)

  } else if (plot_type == "risk_contributions") {
    if (!is.null(x$PR_fit$orp_fit)) {
      method_used <- if (is.null(x$method)) "budget" else x$method
      plotRiskContr(x$PR_fit$orp_fit, label = label, method = method_used, ...)
    } else {
      warning("No 'orp_fit' found in the cross-validated model. This usually occurs when the optimal tuning parameter is 0. Cannot plot risk contributions.")
    }

  } else if (plot_type == "cv_coefficients") {
    plotCVCoef(result_list = x, label = label, xvar = xvar, max_vars_per_plot = max_vars_per_plot, ...)

  } else if (plot_type == "cv_errors") {
    plotCVErr(cv_results = x, label = label, ...)
  }
}
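The argument rules documented above for `plot_type = "cv_coefficients"` (which `xvar` values each model type accepts, and the cap of 10 variables per panel) can be sketched as a small standalone check. The helper `check_cv_coef_args()` is hypothetical and not part of savvyPR; in particular, splitting the variables into consecutive groups, one per panel, is an assumption about how `plotCVCoef` arranges them.

```r
# Hypothetical helper (not part of savvyPR): validates the cv_coefficients
# arguments as documented and groups variable names into panels.
check_cv_coef_args <- function(model_type, xvar, max_vars_per_plot, var_names) {
  xvar <- match.arg(xvar, c("norm", "lambda", "dev", "val"))

  # Documented invalid combinations of model_type and xvar
  if (model_type %in% c("PR1", "PR2") && xvar == "lambda") {
    stop("xvar = \"lambda\" cannot be used with ", model_type, ".")
  }
  if (model_type == "PR3" && xvar == "val") {
    stop("xvar = \"val\" cannot be used with PR3.")
  }

  # Documented cap: values above 10 are reset to 10
  max_vars_per_plot <- min(max_vars_per_plot, 10)

  # Assumption: variables are split into consecutive groups, one per panel
  panel_id <- ceiling(seq_along(var_names) / max_vars_per_plot)
  panels <- unname(split(var_names, panel_id))

  list(xvar = xvar, max_vars_per_plot = max_vars_per_plot, panels = panels)
}

res <- check_cv_coef_args("PR2", "val", 25, paste0("x", 1:12))
res$max_vars_per_plot  # 10
lengths(res$panels)    # 10 2
```

Capping silently rather than raising an error mirrors the documented behavior ("reset to 10"), while the model-type checks stop early because those combinations have no meaningful axis.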

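The `risk_contributions` plot shows how much each optimization variable contributes to total risk. As an illustration, assuming the standard Euler decomposition used in risk parity, where the relative contribution of variable i is w_i (Sigma w)_i / (w' Sigma w), the shares can be computed as below. The function `risk_contributions()` is hypothetical, and `plotRiskContr()` may define or scale the contributions differently.

```r
# Hypothetical sketch: relative risk contributions via the standard Euler
# decomposition. plotRiskContr() may define or scale these differently.
risk_contributions <- function(w, Sigma) {
  marginal <- as.vector(Sigma %*% w)  # (Sigma w)_i for each variable
  total <- sum(w * marginal)          # portfolio variance w' Sigma w
  w * marginal / total                # relative shares that sum to 1
}

Sigma <- matrix(c(1, 0.5, 0.5, 2), nrow = 2)
w <- c(0.6, 0.4)
rc <- risk_contributions(w, Sigma)
sum(rc)  # 1

# Equal variances, no correlation, equal weights: equal risk shares
risk_contributions(c(0.5, 0.5), diag(2))  # 0.5 0.5
```

Under an equal-budget risk parity solution these shares would all equal 1/p, which is the pattern the second bar plot makes easy to check.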
savvyPR documentation built on April 7, 2026, 5:08 p.m.