R/cv.plot.R

Defines functions cv.plot

Documented in cv.plot

#' Plot Cross-Validation Results vs Eta
#'
#' @description
#' Plots cross-validated performance across \code{eta} for
#' \code{cv.coxkl}, \code{cv.coxkl_ridge}, or \code{cv.coxkl_enet} results.
#' The main CV curve is drawn as a solid purple line; a green dotted horizontal
#' reference line is placed at the value corresponding to \code{eta = 0}
#' (or the closest available \code{eta}), with a solid green point marking that
#' reference level. A dashed orange vertical line marks the \code{eta} with the
#' best cross-validated performance (smallest loss or largest C index).
#'
#' @param object A fitted cross-validation result of class \code{"cv.coxkl"},
#'   \code{"cv.coxkl_ridge"}, or \code{"cv.coxkl_enet"}.
#' @param line_color Color for the CV performance curve. Default \code{"#7570B3"}.
#' @param baseline_color Color for the horizontal reference line and point.
#'   Default \code{"#1B9E77"}.
#' @param ... Additional arguments (currently ignored).
#'
#' @details
#' The function reads the performance metric from the object:
#' \itemize{
#'   \item For \code{"cv.coxkl"}: uses \code{object$internal_stat} (one row per \code{eta}).
#'   \item For \code{"cv.coxkl_ridge"} and \code{"cv.coxkl_enet"}:
#'         uses \code{object$integrated_stat.best_per_eta} (best \code{lambda} per \code{eta}).
#' }
#' The metric column is the first match among a list of known column names for
#' the chosen criterion; if none is present, the last numeric column other than
#' \code{eta} and \code{lambda} is used. Rows whose \code{eta} is missing or
#' not finite are dropped before plotting.
#'
#' The y-axis label is set to \dQuote{Loss} if \code{criteria} in the object is
#' \dQuote{V&VH} or \dQuote{LinPred}; otherwise it is \dQuote{C Index}.
#' The horizontal reference (\dQuote{baseline}) is taken from the plotted series at
#' \code{eta = 0} (or the nearest \code{eta} present in the results).
#' The y-axis limits extend 0.5\% beyond the smallest and largest metric values,
#' always away from the data, so negative metric values are not clipped.
#' The legend shown above the plot labels the CV curve \dQuote{survkl} and the
#' reference line \dQuote{Internal}.
#'
#' @return A \code{ggplot} object (assembled with \code{cowplot::plot_grid})
#'   showing cross-validation performance versus \code{eta}.
#'
#' @seealso \code{\link{cv.coxkl}}, \code{\link{cv.coxkl_ridge}}, \code{\link{cv.coxkl_enet}}
#'
#' @examples
#' data(ExampleData_lowdim)
#'
#' train_dat_lowdim <- ExampleData_lowdim$train
#' beta_external_good_lowdim <- ExampleData_lowdim$beta_external_good
#'
#' etas <- generate_eta(method = "exponential", n = 100, max_eta = 30)
#' cv_res <- cv.coxkl(z = train_dat_lowdim$z,
#'                    delta = train_dat_lowdim$status,
#'                    time = train_dat_lowdim$time,
#'                    stratum = train_dat_lowdim$stratum,
#'                    beta = beta_external_good_lowdim,
#'                    etas = etas,
#'                    nfolds = 5,
#'                    criteria = "V&VH",
#'                    seed = 1)
#' cv.plot(cv_res)
#'
#' @importFrom ggplot2 ggplot scale_color_manual scale_linetype_manual guides guide_legend coord_cartesian aes geom_line geom_point geom_segment labs theme_minimal theme element_blank element_line element_text
#' @importFrom grid unit
#' @importFrom cowplot plot_grid get_legend
#' @importFrom rlang .data
#' @export
cv.plot <- function(object,
                    line_color = "#7570B3",
                    baseline_color = "#1B9E77",
                    ...) {
  # Pick the per-eta summary table that matches the class of the CV object.
  if (inherits(object, "cv.coxkl")) {
    df <- object$internal_stat
  } else if (inherits(object, c("cv.coxkl_ridge", "cv.coxkl_enet"))) {
    df <- object$integrated_stat.best_per_eta
  } else {
    stop("Object must be a cv.coxkl, cv.coxkl_ridge, or cv.coxkl_enet.", call. = FALSE)
  }
  criteria <- object$criteria

  # Fail early with a readable message instead of an opaque `if` error later.
  if (!is.character(criteria) || length(criteria) != 1L || is.na(criteria)) {
    stop("`object$criteria` must be a single character string.", call. = FALSE)
  }
  if (!is.data.frame(df) || !("eta" %in% names(df))) {
    stop("CV results must be a data frame with an `eta` column.", call. = FALSE)
  }

  is_loss <- criteria %in% c("V&VH", "LinPred")
  ylab <- if (is_loss) "Loss" else "C Index"

  # Known metric column names, in order of preference.
  candidates <- if (is_loss) {
    c("VVH_Loss", "LinPred_Loss", "Loss", "loss")
  } else {
    c("CIndex_pooled", "CIndex_foldaverage", "CIndex", "cindex")
  }
  matched <- candidates[candidates %in% names(df)]
  if (length(matched) > 0L) {
    metric_col <- matched[1L]
  } else {
    # Fallback: last numeric column that is not a tuning parameter.
    num_cols <- names(df)[vapply(df, is.numeric, logical(1))]
    num_cols <- setdiff(num_cols, c("eta", "lambda"))
    if (length(num_cols) == 0L) {
      stop("Could not detect metric column in CV results.", call. = FALSE)
    }
    metric_col <- num_cols[length(num_cols)]
  }

  df$eta <- as.numeric(df$eta)
  df$metric <- as.numeric(df[[metric_col]])
  # Rows without a usable eta cannot be placed on the x-axis; dropping them
  # also keeps the index lookups below free of NA.
  df <- df[is.finite(df$eta), , drop = FALSE]
  df <- df[order(df$eta), , drop = FALSE]
  if (!any(is.finite(df$metric))) {
    stop("CV results contain no rows with a finite eta and metric value.", call. = FALSE)
  }

  # Baseline = metric at eta == 0, or at the eta closest to 0 when 0 is absent.
  # which.min() returns the first minimum, so an exact zero is always preferred.
  idx0 <- which.min(abs(df$eta))
  baseline_val <- df$metric[idx0]
  baseline_eta <- df$eta[idx0]

  xmin <- min(df$eta)
  xmax <- max(df$eta)

  # Best eta: smallest loss or largest C index (NA metrics are ignored).
  opt_idx <- if (is_loss) which.min(df$metric) else which.max(df$metric)
  opt_eta <- df$eta[opt_idx]

  # Pad the y-range by 0.5% of each endpoint's magnitude. Using abs() makes the
  # padding always widen the range; plain multiplication by 0.995 / 1.005 would
  # shrink it (and clip the curve) when the metric values are negative.
  metric_range <- range(df$metric, finite = TRUE)
  ylow  <- metric_range[1L] - 0.005 * abs(metric_range[1L])
  yhigh <- metric_range[2L] + 0.005 * abs(metric_range[2L])

  baseline_seg <- data.frame(xmin = xmin, xmax = xmax, y = baseline_val)
  baseline_pt  <- data.frame(eta = baseline_eta, metric = baseline_val)
  opt_seg      <- data.frame(x = opt_eta, ylow = ylow, yhigh = yhigh)

  g_main <- ggplot2::ggplot(
    df,
    ggplot2::aes(x = .data$eta, y = .data$metric, group = 1)
  ) +
    ggplot2::geom_line(linewidth = 1, color = line_color) +
    ggplot2::geom_point(size = 1.3, color = line_color) +
    # Dotted horizontal reference at the baseline value.
    ggplot2::geom_segment(
      data = baseline_seg,
      ggplot2::aes(x = .data$xmin, xend = .data$xmax, y = .data$y, yend = .data$y),
      inherit.aes = FALSE, color = baseline_color, linetype = "dotted", linewidth = 1
    ) +
    ggplot2::geom_point(
      data = baseline_pt,
      ggplot2::aes(x = .data$eta, y = .data$metric),
      inherit.aes = FALSE, color = baseline_color, shape = 16, size = 2.4
    ) +
    # Dashed vertical marker at the best-performing eta.
    ggplot2::geom_segment(
      data = opt_seg,
      ggplot2::aes(x = .data$x, xend = .data$x, y = .data$ylow, yend = .data$yhigh),
      inherit.aes = FALSE, color = "#D95F02", linewidth = 1, linetype = "dashed"
    ) +
    ggplot2::labs(x = expression(eta), y = ylab) +
    ggplot2::theme_minimal(base_size = 13) +
    ggplot2::theme(panel.grid = ggplot2::element_blank(),
                   panel.border = ggplot2::element_blank(),
                   axis.line = ggplot2::element_line(color = "black"),
                   axis.ticks.length = grid::unit(0.1, "cm"),
                   axis.ticks = ggplot2::element_line(color = "black"),
                   axis.text = ggplot2::element_text(size = 14),
                   legend.position = "none") +
    ggplot2::coord_cartesian(ylim = c(ylow, yhigh))

  # The main plot has no legend of its own; build a dummy plot whose only
  # purpose is to supply a legend with the two line styles.
  legend_df <- data.frame(
    x = rep(c(0, 1), 2),
    y = rep(1, 4),
    Method = factor(rep(c("survkl", "Internal"), each = 2),
                    levels = c("survkl", "Internal"))
  )
  internal_df <- legend_df[legend_df$Method == "Internal", , drop = FALSE]

  g_legend <- ggplot2::ggplot(
    legend_df,
    ggplot2::aes(x = .data$x, y = .data$y,
                 color = .data$Method, linetype = .data$Method)
  ) +
    ggplot2::geom_line(linewidth = 1) +
    ggplot2::geom_point(
      data = internal_df,
      ggplot2::aes(x = 0.5, y = 1, color = .data$Method),
      inherit.aes = FALSE, shape = 16, size = 2.4
    ) +
    ggplot2::scale_color_manual(values = c("survkl" = line_color, "Internal" = baseline_color)) +
    ggplot2::scale_linetype_manual(values = c("survkl" = "solid", "Internal" = "dotted")) +
    ggplot2::theme_void(base_size = 13) +
    ggplot2::theme(legend.position = "top",
                   legend.title = ggplot2::element_blank(),
                   legend.text = ggplot2::element_text(size = 14),
                   legend.key.width = grid::unit(1.2, "lines"),
                   legend.key.height = grid::unit(0.6, "lines")) +
    ggplot2::guides(
      color = ggplot2::guide_legend(keywidth = 1.2, keyheight = 0.4, title = NULL),
      linetype = ggplot2::guide_legend(keywidth = 1.2, keyheight = 0.4, title = NULL)
    )

  # Stack the extracted legend above the main panel.
  cowplot::plot_grid(cowplot::get_legend(g_legend), g_main, ncol = 1, rel_heights = c(0.08, 1))
}

Try the survkl package in your browser

Any scripts or data that you put into this service are public.

survkl documentation built on April 22, 2026, 1:08 a.m.