Nothing
#' Plot Cross-Validation Results vs Eta
#'
#' @description
#' Plots cross-validated performance across \code{eta} for
#' \code{cv.coxkl}, \code{cv.coxkl_ridge}, or \code{cv.coxkl_enet} results.
#' The main CV curve is drawn as a solid purple line; a green dotted horizontal
#' reference line is placed at the value corresponding to \code{eta = 0}
#' (or the closest available \code{eta}), with a solid green point marking that
#' reference level. A dashed orange vertical line marks the \code{eta} with the
#' best CV performance (smallest loss, or largest C index).
#'
#' @param object A fitted cross-validation result of class \code{"cv.coxkl"},
#' \code{"cv.coxkl_ridge"}, or \code{"cv.coxkl_enet"}.
#' @param line_color Color for the CV performance curve. Default \code{"#7570B3"}.
#' @param baseline_color Color for the horizontal reference line and point.
#' Default \code{"#1B9E77"}.
#' @param ... Additional arguments (currently ignored).
#'
#' @details
#' The function reads the performance metric from the object:
#' \itemize{
#' \item For \code{"cv.coxkl"}: uses \code{object$internal_stat} (one row per \code{eta}).
#' \item For \code{"cv.coxkl_ridge"} and \code{"cv.coxkl_enet"}:
#' uses \code{object$integrated_stat.best_per_eta} (best \code{lambda} per \code{eta}).
#' }
#' The y-axis label is set to \dQuote{Loss} if \code{criteria} in the object is
#' \dQuote{V&VH} or \dQuote{LinPred}; otherwise it is \dQuote{C Index}. If
#' \code{criteria} has more than one element, only the first one is used.
#' The metric column is looked up by name (for example \code{VVH_Loss} or
#' \code{CIndex_pooled}); if none of the known names is present, the last
#' numeric column other than \code{eta} and \code{lambda} is used. Rows whose
#' \code{eta} or metric value is missing or non-finite are dropped.
#' The horizontal reference (\dQuote{baseline}) is taken from the plotted series at
#' \code{eta = 0} (or the nearest \code{eta} present in the results).
#'
#' @return A \code{ggplot} object (a \pkg{cowplot} composite with the legend
#' on top) showing cross-validation performance versus \code{eta}.
#'
#' @seealso \code{\link{cv.coxkl}}, \code{\link{cv.coxkl_ridge}}, \code{\link{cv.coxkl_enet}}
#'
#' @examples
#' data(ExampleData_lowdim)
#'
#' train_dat_lowdim <- ExampleData_lowdim$train
#' beta_external_good_lowdim <- ExampleData_lowdim$beta_external_good
#'
#' etas <- generate_eta(method = "exponential", n = 100, max_eta = 30)
#' cv_res <- cv.coxkl(z = train_dat_lowdim$z,
#' delta = train_dat_lowdim$status,
#' time = train_dat_lowdim$time,
#' stratum = train_dat_lowdim$stratum,
#' beta = beta_external_good_lowdim,
#' etas = etas,
#' nfolds = 5,
#' criteria = c("V&VH"),
#' seed = 1)
#' cv.plot(cv_res)
#'
#' @importFrom ggplot2 ggplot scale_color_manual scale_linetype_manual guides guide_legend coord_cartesian aes geom_line geom_point geom_segment labs theme_minimal theme element_blank element_line element_text
#' @importFrom grid unit
#' @importFrom cowplot plot_grid get_legend
#' @importFrom rlang .data
#' @export
cv.plot <- function(object,
                    line_color = "#7570B3",
                    baseline_color = "#1B9E77",
                    ...) {
  # Per-eta summary table -----------------------------------------------------
  if (inherits(object, "cv.coxkl")) {
    df <- object$internal_stat
  } else if (inherits(object, c("cv.coxkl_ridge", "cv.coxkl_enet"))) {
    df <- object$integrated_stat.best_per_eta
  } else {
    stop("Object must be a cv.coxkl, cv.coxkl_ridge, or cv.coxkl_enet.", call. = FALSE)
  }
  criteria <- object$criteria

  # Only the first criterion is consulted so that a length > 1 (or NULL)
  # `criteria` cannot make the scalar `if` conditions below error.
  is_loss <- isTRUE(criteria[1] %in% c("V&VH", "LinPred"))
  ylab <- if (is_loss) "Loss" else "C Index"

  # as.data.frame() also accepts a matrix or NULL; NULL becomes a 0-column
  # frame and is rejected by the eta check below.
  df <- as.data.frame(df)
  if (!("eta" %in% names(df))) {
    stop("CV results do not contain an 'eta' column.", call. = FALSE)
  }

  # Metric column: first known name present, else last numeric non-eta/lambda.
  loss_candidates <- c("VVH_Loss", "LinPred_Loss", "Loss", "loss")
  cindex_candidates <- c("CIndex_pooled", "CIndex_foldaverage", "CIndex", "cindex")
  candidates <- if (is_loss) loss_candidates else cindex_candidates
  metric_col <- candidates[candidates %in% names(df)][1]
  if (is.na(metric_col)) {
    num_cols <- names(df)[vapply(df, is.numeric, logical(1))]
    num_cols <- setdiff(num_cols, c("eta", "lambda"))
    if (length(num_cols) == 0L) {
      stop("Could not detect metric column in CV results.", call. = FALSE)
    }
    metric_col <- num_cols[length(num_cols)]
  }

  df$eta <- as.numeric(df$eta)
  df$metric <- as.numeric(df[[metric_col]])
  # Rows without a usable eta or metric value cannot be placed on the plot.
  df <- df[is.finite(df$eta) & is.finite(df$metric), , drop = FALSE]
  if (nrow(df) == 0L) {
    stop("No finite metric values to plot in CV results.", call. = FALSE)
  }
  df <- df[order(df$eta), , drop = FALSE]

  # Reference level: eta == 0 if present, otherwise the eta closest to 0.
  # which.min() returns the first minimum, i.e. the first eta == 0 if any.
  idx0 <- which.min(abs(df$eta))
  baseline_val <- df$metric[idx0]
  baseline_eta <- df$eta[idx0]
  xmin <- min(df$eta)
  xmax <- max(df$eta)

  # Best eta: smallest loss, or largest C index.
  opt_idx <- if (is_loss) which.min(df$metric) else which.max(df$metric)
  opt_eta <- df$eta[opt_idx]

  # Pad each y-limit by 0.5% of its magnitude. For positive values this equals
  # the former `* 0.995` / `* 1.005`, but it also widens (rather than clips)
  # the range when the metric is negative.
  yrange <- range(df$metric)
  ylow <- yrange[1] - abs(yrange[1]) * 0.005
  yhigh <- yrange[2] + abs(yrange[2]) * 0.005

  # Main panel ----------------------------------------------------------------
  g_main <- ggplot2::ggplot(df, ggplot2::aes(x = .data$eta, y = .data$metric, group = 1)) +
    ggplot2::geom_line(linewidth = 1, color = line_color) +
    ggplot2::geom_point(size = 1.3, color = line_color) +
    ggplot2::geom_segment(
      data = data.frame(xmin = xmin, xmax = xmax, y = baseline_val),
      ggplot2::aes(x = .data$xmin, xend = .data$xmax, y = .data$y, yend = .data$y),
      inherit.aes = FALSE, color = baseline_color, linetype = "dotted", linewidth = 1
    ) +
    ggplot2::geom_point(
      data = data.frame(eta = baseline_eta, metric = baseline_val),
      ggplot2::aes(x = .data$eta, y = .data$metric),
      inherit.aes = FALSE, color = baseline_color, shape = 16, size = 2.4
    ) +
    ggplot2::geom_segment(
      data = data.frame(x = opt_eta),
      ggplot2::aes(x = .data$x, xend = .data$x, y = ylow, yend = yhigh),
      inherit.aes = FALSE, color = "#D95F02", linewidth = 1, linetype = "dashed"
    ) +
    ggplot2::labs(x = expression(eta), y = ylab) +
    ggplot2::theme_minimal(base_size = 13) +
    ggplot2::theme(panel.grid = ggplot2::element_blank(),
                   panel.border = ggplot2::element_blank(),
                   axis.line = ggplot2::element_line(color = "black"),
                   axis.ticks.length = grid::unit(0.1, "cm"),
                   axis.ticks = ggplot2::element_line(color = "black"),
                   axis.text = ggplot2::element_text(size = 14),
                   legend.position = "none") +
    ggplot2::coord_cartesian(ylim = c(ylow, yhigh))

  # Legend --------------------------------------------------------------------
  # Drawn on a throwaway plot so its legend can be extracted and stacked above
  # the main panel; the main panel itself has legend.position = "none".
  legend_df <- data.frame(
    x = rep(c(0, 1), 2),
    y = rep(1, 4),
    Method = factor(rep(c("survkl", "Internal"), each = 2),
                    levels = c("survkl", "Internal"))
  )
  internal_df <- legend_df[legend_df$Method == "Internal", , drop = FALSE]
  g_legend <- ggplot2::ggplot(legend_df, ggplot2::aes(x = .data$x, y = .data$y,
                                                      color = .data$Method,
                                                      linetype = .data$Method)) +
    ggplot2::geom_line(linewidth = 1) +
    ggplot2::geom_point(
      data = internal_df,
      ggplot2::aes(x = 0.5, y = 1, color = .data$Method),
      inherit.aes = FALSE, shape = 16, size = 2.4
    ) +
    ggplot2::scale_color_manual(values = c("survkl" = line_color, "Internal" = baseline_color)) +
    ggplot2::scale_linetype_manual(values = c("survkl" = "solid", "Internal" = "dotted")) +
    ggplot2::theme_void(base_size = 13) +
    ggplot2::theme(legend.position = "top",
                   legend.title = ggplot2::element_blank(),
                   legend.text = ggplot2::element_text(size = 14),
                   legend.key.width = grid::unit(1.2, "lines"),
                   legend.key.height = grid::unit(0.6, "lines")) +
    ggplot2::guides(color = ggplot2::guide_legend(keywidth = 1.2, keyheight = 0.4, title = NULL),
                    linetype = ggplot2::guide_legend(keywidth = 1.2, keyheight = 0.4, title = NULL))

  cowplot::plot_grid(cowplot::get_legend(g_legend), g_main, ncol = 1, rel_heights = c(0.08, 1))
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.