R/plot_tune_xrnet.R

Defines functions plot.tune_xrnet

Documented in plot.tune_xrnet

#' Plot k-fold cross-validation error grid
#'
#' @description Generates plots to visualize the mean cross-validation error.
#' If no external data was used in the model fit, a plot of the cross-validated
#' error with standard error bars is generated for all penalty values. If
#' external data was used in the model fit, a contour plot of the
#' cross-validated errors is created. Error curves can also be generated for a
#' fixed value of the primary penalty on x (p) or the external penalty (pext)
#' when external data is used.
#'
#' @param x A tune_xrnet class object
#' @param p (optional) penalty value for x (for generating an error curve across
#' external penalties). Use value "opt" to use the optimal penalty value.
#' @param pext (optional) penalty value for external (for generating an error
#' curve across primary penalties). Use value "opt" to use the optimal penalty
#' value.
#' @param ... Additional named graphics parameters. They are passed to
#' \code{graphics::plot} when an error curve is drawn and to
#' \code{graphics::filled.contour} when the contour plot is drawn, and take
#' precedence over the defaults set by this function (e.g. \code{xlab},
#' \code{ylab}, \code{ylim}, \code{col}). The plotted data and, for the contour
#' plot, \code{plot.axes} cannot be overridden.
#'
#' @return \code{NULL}, invisibly. The function is called for its side effect
#' of drawing a plot on the active graphics device.
#'
#' @details The parameter values p and pext can be used to generate profiled
#' error curves by fixing either the penalty on x or the penalty on external to
#' a fixed value. You cannot specify both at the same time as this would only
#' return a single point. A numeric value of p or pext must be exactly equal to
#' one of the penalty values stored in the fitted model. Both arguments are
#' ignored when the model was fit without external data.
#'
#' @examples
#'
#' ## load example data
#' data(GaussianExample)
#'
#' ## 5-fold cross validation
#' cv_xrnet <- tune_xrnet(
#'   x = x_linear,
#'   y = y_linear,
#'   external = ext_linear,
#'   family = "gaussian",
#'   control = xrnet_control(tolerance = 1e-6)
#' )
#'
#' ## contour plot of cross-validated error
#' plot(cv_xrnet)
#'
#' ## error curve of external penalties at optimal penalty value
#' plot(cv_xrnet, p = "opt")
#' @export
#' @importFrom graphics filled.contour axis points
#' @importFrom grDevices colorRampPalette

plot.tune_xrnet <- function(x, p = NULL, pext = NULL, ...) {
  # Only named entries of `...` can be merged with the defaults set below.
  user_args <- list(...)
  user_args <- user_args[nzchar(names(user_args))]

  # The fitted model only carries `alphas` when external data was used.
  has_external <- !is.null(x$fitted_model$alphas)

  if (has_external && !is.null(p) && !is.null(pext)) {
    stop(
      "Please specify only one of 'p' or 'pext', ",
      "cannot specify both at the same time",
      call. = FALSE
    )
  }

  # Translate a user-supplied penalty ("opt" or a single value) into its
  # position on the fitted penalty grid, failing with a message that names
  # the offending argument.
  penalty_index <- function(value, opt_value, grid, arg_name) {
    if (identical(value, "opt")) {
      value <- opt_value
    }
    if (length(value) != 1L) {
      stop(
        "'", arg_name, "' must be a single penalty value or \"opt\"",
        call. = FALSE
      )
    }
    idx <- match(value, grid)
    if (is.na(idx)) {
      stop(
        "The penalty value '", arg_name, "' is not in the fitted model",
        call. = FALSE
      )
    }
    idx
  }

  if (!has_external || !is.null(p) || !is.null(pext)) {
    if (!has_external) {
      # No external data: a single error curve over the primary penalties.
      xval <- log(as.numeric(rownames(x$cv_mean)))
      cverr <- x$cv_mean[, 1]
      cvsd <- x$cv_sd[, 1]
      xlab <- "log(Penalty)"
      xopt_val <- log(x$opt_penalty)
    } else if (!is.null(p)) {
      # Primary penalty fixed: profile the error across external penalties.
      p_idx <- penalty_index(p, x$opt_penalty, x$fitted_model$penalty, "p")
      xval <- log(as.numeric(colnames(x$cv_mean)))
      cverr <- x$cv_mean[p_idx, ]
      cvsd <- x$cv_sd[p_idx, ]
      xlab <- "log(External Penalty)"
      xopt_val <- log(x$opt_penalty_ext)
    } else {
      # External penalty fixed: profile the error across primary penalties.
      pext_idx <- penalty_index(
        pext, x$opt_penalty_ext, x$fitted_model$penalty_ext, "pext"
      )
      xval <- log(as.numeric(rownames(x$cv_mean)))
      cverr <- x$cv_mean[, pext_idx]
      cvsd <- x$cv_sd[, pext_idx]
      xlab <- "log(Penalty)"
      xopt_val <- log(x$opt_penalty)
    }

    # Defaults first, then user-supplied graphics parameters override them.
    # The plotted data itself is not overridable.
    plot_args <- list(
      x = xval,
      y = cverr,
      ylab = paste0("Mean CV Error (", x$loss, ")"),
      xlab = xlab,
      ylim = range(c(cverr - cvsd, cverr + cvsd)),
      type = "n"
    )
    extra <- user_args[!names(user_args) %in% c("x", "y")]
    plot_args[names(extra)] <- extra
    do.call(graphics::plot, plot_args)

    # Error bars (+/- one cv_sd), then the mean error on top of them.
    graphics::arrows(
      xval,
      cverr - cvsd,
      xval,
      cverr + cvsd,
      length = 0.025,
      angle = 90,
      code = 3,
      col = "lightgray"
    )
    graphics::points(
      x = xval,
      y = cverr,
      col = "dodgerblue4",
      pch = 16
    )
    # Vertical marker at the optimal penalty on the plotted axis.
    graphics::abline(v = xopt_val, col = "firebrick")
  } else {
    # Reverse both dimensions of the error grid before plotting;
    # filled.contour() requires x and y in ascending order.
    # NOTE(review): this presumes the stored penalties are in decreasing
    # order -- confirm against tune_xrnet().
    cvgrid <- x$cv_mean
    cvgrid <- cvgrid[
      rev(seq_len(nrow(cvgrid))),
      rev(seq_len(ncol(cvgrid))),
      drop = FALSE
    ]
    minx <- log(x$opt_penalty_ext)
    miny <- log(x$opt_penalty)

    contour_colors <- c(
      "#014636", "#016C59", "#02818A", "#3690C0",
      "#67A9CF", "#A6BDDB", "#D0D1E6", "#ECE2F0", "#FFF7FB"
    )

    # plot.axes is kept as an unevaluated expression: filled.contour() must
    # evaluate it itself, after the plotting region has been set up. do.call()
    # evaluates it in this function's frame, so minx/miny are found here.
    contour_args <- list(
      x = log(as.numeric(colnames(cvgrid))),
      y = log(as.numeric(rownames(cvgrid))),
      z = t(cvgrid),
      col = grDevices::colorRampPalette(contour_colors)(25),
      xlab = "log(External Penalty)",
      ylab = "log(Penalty)",
      plot.axes = quote({
        graphics::axis(1)
        graphics::axis(2)
        graphics::points(minx, miny, col = "red", pch = 16)
      })
    )
    reserved <- c("x", "y", "z", "plot.axes")
    extra <- user_args[!names(user_args) %in% reserved]
    contour_args[names(extra)] <- extra
    do.call(graphics::filled.contour, contour_args)
  }

  invisible(NULL)
}
USCbiostats/hierr documentation built on Dec. 12, 2023, 12:55 p.m.