Nothing
#' Plot methods for \code{missoNet} and cross-validated fits
#'
#' Visualize either the cross-validation (CV) error surface or the
#' goodness-of-fit (GoF) surface over the \eqn{\lambda_B}–\eqn{\lambda_\Theta}
#' grid for objects returned by \code{\link{missoNet}} or \code{\link{cv.missoNet}}.
#' Two display types are supported:
#' a 2D heatmap (default) and a 3D scatter surface.
#'
#' @section What gets plotted:
#' \itemize{
#' \item \strong{CV objects} (created by \code{\link{cv.missoNet}} or any
#' \code{missoNet} object that carries CV results): the color encodes the
#' mean CV error for each \eqn{(\lambda_B, \lambda_\Theta)} pair. The
#' \emph{minimum-error} solution is outlined; if 1-SE solutions were
#' computed, they are also marked (dashed/dotted outlines).
#'
#' \item \strong{Non-CV objects} (created by \code{\link{missoNet}} without CV):
#' the color encodes the GoF value over the grid; the selected
#' \emph{minimum} (best) solution is outlined.
#' }
#'
#' @section Axes and scales:
#' For heatmaps, axes are the raw \eqn{\lambda} values; rows are
#' \eqn{\lambda_\Theta} and columns are \eqn{\lambda_B}.
#' For 3D scatter plots, both \eqn{\lambda} axes are shown on the
#' \eqn{\log_{10}} scale for readability.
#'
#' @section Color mapping:
#' A viridis-like palette is used. Breaks are based on distribution quantiles
#' of the CV error or GoF values to enhance contrast across the grid.
#'
#' @param x A fitted object returned by \code{\link{missoNet}} or
#' \code{\link{cv.missoNet}}.
#' @param type Character string specifying the plot type.
#' One of \code{"heatmap"} (default) or \code{"scatter"}.
#' @param detailed.axes Logical; if \code{TRUE} (default) show dense axis labels.
#' If \code{FALSE}, a sparser labeling is used to avoid clutter.
#' @param plt.surf Logical; for \code{type = "scatter"} only, draw light surface
#' grid lines and highlight the minimum point. Ignored for heatmaps. Default \code{TRUE}.
#' @param ... Additional graphical arguments forwarded to
#' \code{\link[ComplexHeatmap]{Heatmap}} when \code{type = "heatmap"}, or to
#' \code{\link[scatterplot3d]{scatterplot3d}} when \code{type = "scatter"}.
#'
#' @return
#' \itemize{
#' \item For \code{type = "heatmap"}: a \code{Heatmap} object from
#' \pkg{ComplexHeatmap}, returned visibly so that it is drawn when auto-printed.
#' \item For \code{type = "scatter"}: a \code{"scatterplot3d"} object,
#' returned \emph{invisibly}.
#' }
#'
#' @details
#' This S3 method detects whether \code{x} contains cross-validation results and
#' chooses an appropriate plotting backend:
#' \itemize{
#' \item \strong{Heatmap}: uses \code{\link[ComplexHeatmap]{Heatmap}} with a
#' viridis-like color ramp (via \code{\link[circlize]{colorRamp2}}). The
#' selected \eqn{(\lambda_B, \lambda_\Theta)} is outlined in white; 1-SE
#' choices (if present) are highlighted with dashed/dotted outlines.
#' \item \strong{Scatter}: uses \code{\link[scatterplot3d]{scatterplot3d}}
#' to draw the error/GoF surface on \eqn{\log_{10}} scales. When
#' \code{plt.surf = TRUE}, light lattice lines are added, and the minimum is
#' marked.
#' }
#'
#' @section Dependencies:
#' Requires \pkg{ComplexHeatmap}, \pkg{circlize}, \pkg{scatterplot3d}, and \pkg{grid}.
#'
#' @seealso \code{\link{missoNet}}, \code{\link{cv.missoNet}},
#' \code{\link[ComplexHeatmap]{Heatmap}}, \code{\link[scatterplot3d]{scatterplot3d}}
#'
#' @author Yixiao Zeng \email{yixiao.zeng@@mail.mcgill.ca}, Celia M. T. Greenwood
#'
#' @examples
#' sim <- generateData(n = 150, p = 10, q = 8, rho = 0.1, missing.type = "MCAR")
#'
#' \donttest{
#' ## Fit a model without CV (plots GoF surface)
#' fit <- missoNet(X = sim$X, Y = sim$Z, verbose = 0)
#' plot(fit, type = "heatmap") # GoF heatmap
#' plot(fit, type = "scatter", plt.surf = TRUE) # GoF 3D scatter
#'
#' ## Cross-validation (plots CV error surface)
#' cvfit <- cv.missoNet(X = sim$X, Y = sim$Z, verbose = 0)
#' plot(cvfit, type = "heatmap", detailed.axes = FALSE)
#' plot(cvfit, type = "scatter", plt.surf = FALSE)
#' }
#'
#' @keywords hplot models
#' @aliases plot.cv.missoNet
#' @method plot missoNet
#' @export
#'
#' @importFrom ComplexHeatmap Heatmap
#' @importFrom scatterplot3d scatterplot3d
#' @importFrom circlize colorRamp2
#' @importFrom grid gpar grid.rect grid.points unit
plot.missoNet <- function(x, type = c("heatmap", "scatter"), detailed.axes = TRUE, plt.surf = TRUE, ...) {
  type <- match.arg(type)

  # The object is treated as cross-validated when it stores mean CV errors.
  has_cv <- !is.null(x$param_set$cv.errors.mean)

  # Every branch ends in a direct call to one backend, so whatever that
  # backend returns (and whether it is visible) is passed straight through.
  if (has_cv && type == "heatmap") {
    plot.cv.heatmap(x, detailed.axes, ...)
  } else if (has_cv) {
    plot.cv.scatter(x, detailed.axes, plt.surf, ...)
  } else if (type == "heatmap") {
    plot.heatmap(x, detailed.axes, ...)
  } else {
    plot.scatter(x, detailed.axes, plt.surf, ...)
  }
}
# ============================================================================
# INTERNAL HELPER FUNCTIONS
# ============================================================================
#' Heatmap of the mean cross-validation error over the lambda grid
#'
#' Internal backend for \code{plot.missoNet()}. The mean CV errors are laid out
#' on the full \code{lambda.theta.seq} x \code{lambda.beta.seq} grid (rows are
#' lambda.theta, columns are lambda.beta in reversed sequence order). Grid
#' cells that have no CV result are filled with a background value that is
#' mapped to white. The selected solutions are outlined on top.
#'
#' @param cv.missoNet.obj Fitted object providing \code{lambda.beta.seq},
#'   \code{lambda.theta.seq}, \code{param_set$cv.grid.beta},
#'   \code{param_set$cv.grid.theta}, \code{param_set$cv.errors.mean},
#'   \code{est.min} and, optionally, \code{est.1se.beta} and
#'   \code{est.1se.theta}.
#' @param detailed.axes Logical; if \code{TRUE} every row and column is
#'   labelled, otherwise at most 11 evenly spaced positions per axis are.
#' @param ... Further arguments passed to \code{ComplexHeatmap::Heatmap()}.
#'
#' @return The object returned by \code{ComplexHeatmap::Heatmap()}.
#' @noRd
plot.cv.heatmap <- function(cv.missoNet.obj, detailed.axes, ...) {
  cv.err <- cv.missoNet.obj$param_set$cv.errors.mean
  if (is.null(cv.err)) {
    stop("No cross-validation errors found in the object")
  }

  lambda.Beta <- cv.missoNet.obj$lambda.beta.seq
  lambda.Theta <- cv.missoNet.obj$lambda.theta.seq
  cv.lamB <- unique(cv.missoNet.obj$param_set$cv.grid.beta)
  cv.lamTh <- unique(cv.missoNet.obj$param_set$cv.grid.theta)
  n.beta <- length(lambda.Beta)
  n.theta <- length(lambda.Theta)
  n.cv.theta <- length(cv.lamTh)

  # Placeholder for cells without a CV result: 0.9 times the smallest error.
  cvm.bg <- min(cv.err, na.rm = TRUE) * 0.9

  # Preallocate the full grid instead of growing it column by column.
  cvm <- matrix(cvm.bg, nrow = n.theta, ncol = n.beta)
  theta.in.cv <- lambda.Theta %in% cv.lamTh
  for (l in seq_len(n.beta)) {
    cv.lamB.id <- match(lambda.Beta[l], cv.lamB)
    if (is.na(cv.lamB.id)) {
      next
    }
    # The indexing assumes cv.errors.mean is stored in blocks of
    # length(cv.lamTh) values, one block per unique lambda.beta.
    # NOTE(review): confirm this ordering against the fitting code.
    col.vals <- cvm[, l]
    col.vals[theta.in.cv] <- cv.err[(cv.lamB.id - 1) * n.cv.theta + seq_len(n.cv.theta)]
    cvm[, l] <- col.vals
  }

  # Axis labels: either all of them, or blanks with a few evenly spaced values.
  sparse_labels <- function(lambda) {
    labs <- rep(" ", length(lambda))
    pos <- unique(round(seq(1, length(lambda), length.out = min(11, length(lambda)))))
    labs[pos] <- sprintf("%.4f", lambda[pos])
    labs
  }
  if (detailed.axes) {
    rownames(cvm) <- sprintf("%.4f", lambda.Theta)
    colnames(cvm) <- sprintf("%.4f", lambda.Beta)
  } else {
    rownames(cvm) <- sparse_labels(lambda.Theta)
    colnames(cvm) <- sparse_labels(lambda.Beta)
  }

  # Reverse the column (lambda.beta) order. Direct indexing with drop = FALSE
  # keeps the matrix shape even when there is a single row or column.
  cvm <- cvm[, rev(seq_len(n.beta)), drop = FALSE]

  quantiles <- quantile(cv.err, c(0, 0.1, 0.25, 0.5, 0.75, 0.9, 1), na.rm = TRUE)
  col.fun <- circlize::colorRamp2(
    c(cvm.bg, quantiles),
    c("white", "#440154", "#414487", "#2A788E", "#22A884", "#7AD151", "#FDE725", "#FDE725")
  )

  # Index of `target` in `grid`, tolerant of floating-point round-off;
  # NA when there is no usable match.
  nearest <- function(grid, target) {
    if (length(grid) == 0L || length(target) != 1L || !is.finite(target)) {
      return(NA_integer_)
    }
    k <- unname(which.min(abs(grid - target)))
    if (length(k) == 1L &&
        isTRUE(all.equal(as.numeric(grid[k]), as.numeric(target)))) {
      k
    } else {
      NA_integer_
    }
  }

  # Row/column of a selected solution in the plotted (column-reversed) matrix.
  locate <- function(est) {
    if (is.null(est)) {
      return(NULL)
    }
    row <- nearest(lambda.Theta, est$lambda.theta)
    col <- nearest(rev(lambda.Beta), est$lambda.beta)
    if (is.na(row) || is.na(col)) NULL else c(row = row, col = col)
  }

  # Resolve the highlighted cells once, not once per drawn cell. Solutions
  # that are absent or cannot be located on the grid are simply not marked.
  marks <- list(
    # lambda.min: solid white outline plus a centre dot
    list(cell = locate(cv.missoNet.obj$est.min), dot = TRUE,
         gp = grid::gpar(lwd = 2, col = "white", fill = NA)),
    # 1se.beta: dashed yellow outline
    list(cell = locate(cv.missoNet.obj$est.1se.beta), dot = FALSE,
         gp = grid::gpar(lwd = 2, lty = "dashed", col = "yellow", fill = NA)),
    # 1se.theta: dotted cyan outline
    list(cell = locate(cv.missoNet.obj$est.1se.theta), dot = FALSE,
         gp = grid::gpar(lwd = 2, lty = "dotted", col = "cyan", fill = NA))
  )
  marks <- Filter(function(m) !is.null(m$cell), marks)

  ComplexHeatmap::Heatmap(
    cvm,
    col = col.fun,
    border = TRUE,
    row_title = expression(lambda[Theta]),
    row_title_side = "right",
    row_title_rot = 0,
    row_title_gp = grid::gpar(fontsize = 10, fontface = "bold"),
    row_names_side = "left",
    row_names_gp = grid::gpar(fontsize = 8),
    column_title = expression(lambda[Beta]),
    column_title_gp = grid::gpar(fontsize = 10, fontface = "bold"),
    column_names_gp = grid::gpar(fontsize = 8),
    cluster_rows = FALSE,
    cluster_columns = FALSE,
    name = "CV Error",
    heatmap_legend_param = list(
      title_gp = grid::gpar(fontsize = 10, fontface = "bold"),
      labels_gp = grid::gpar(fontsize = 8),
      grid_height = grid::unit(4, "mm"),
      grid_width = grid::unit(4, "mm")
    ),
    ...,
    cell_fun = function(j, i, x, y, width, height, fill) {
      for (m in marks) {
        if (i == m$cell[["row"]] && j == m$cell[["col"]]) {
          grid::grid.rect(x = x, y = y, width = width, height = height,
                          gp = m$gp)
          if (m$dot) {
            grid::grid.points(x = x, y = y, pch = 19, size = grid::unit(3, "mm"),
                              gp = grid::gpar(col = "white"))
          }
        }
      }
    }
  )
}
#' Heatmap of the goodness-of-fit values over the lambda grid
#'
#' Internal backend for \code{plot.missoNet()} when the object carries no CV
#' results. The GoF values are laid out on the full \code{lambda.theta.seq} x
#' \code{lambda.beta.seq} grid (rows are lambda.theta, columns are lambda.beta
#' in reversed sequence order). Grid cells without a GoF value are filled with
#' a background value that is mapped to white. The selected solution
#' (\code{est.min}) is outlined.
#'
#' @param missoNet.obj Fitted object providing \code{lambda.beta.seq},
#'   \code{lambda.theta.seq}, \code{param_set$gof.grid.beta},
#'   \code{param_set$gof.grid.theta}, \code{param_set$gof} and \code{est.min}.
#' @param detailed.axes Logical; if \code{TRUE} every row and column is
#'   labelled, otherwise at most 11 evenly spaced positions per axis are.
#' @param ... Further arguments passed to \code{ComplexHeatmap::Heatmap()}.
#'
#' @return The object returned by \code{ComplexHeatmap::Heatmap()}.
#' @noRd
plot.heatmap <- function(missoNet.obj, detailed.axes, ...) {
  gof <- missoNet.obj$param_set$gof
  if (is.null(gof)) {
    stop("No goodness-of-fit values found in the object")
  }

  lambda.Beta <- missoNet.obj$lambda.beta.seq
  lambda.Theta <- missoNet.obj$lambda.theta.seq
  GoF.lamB <- unique(missoNet.obj$param_set$gof.grid.beta)
  GoF.lamTh <- unique(missoNet.obj$param_set$gof.grid.theta)
  n.beta <- length(lambda.Beta)
  n.theta <- length(lambda.Theta)
  n.gof.theta <- length(GoF.lamTh)

  # Placeholder for cells without a GoF value: 0.9 times the smallest GoF.
  GoF.bg <- min(gof, na.rm = TRUE) * 0.9

  # Preallocate the full grid instead of growing it column by column.
  GoF <- matrix(GoF.bg, nrow = n.theta, ncol = n.beta)
  theta.in.grid <- lambda.Theta %in% GoF.lamTh
  for (l in seq_len(n.beta)) {
    GoF.lamB.id <- match(lambda.Beta[l], GoF.lamB)
    if (is.na(GoF.lamB.id)) {
      next
    }
    # The indexing assumes param_set$gof is stored in blocks of
    # length(GoF.lamTh) values, one block per unique lambda.beta.
    # NOTE(review): confirm this ordering against the fitting code.
    col.vals <- GoF[, l]
    col.vals[theta.in.grid] <- gof[(GoF.lamB.id - 1) * n.gof.theta + seq_len(n.gof.theta)]
    GoF[, l] <- col.vals
  }

  # Axis labels: either all of them, or blanks with a few evenly spaced values.
  sparse_labels <- function(lambda) {
    labs <- rep(" ", length(lambda))
    pos <- unique(round(seq(1, length(lambda), length.out = min(11, length(lambda)))))
    labs[pos] <- sprintf("%.4f", lambda[pos])
    labs
  }
  if (detailed.axes) {
    rownames(GoF) <- sprintf("%.4f", lambda.Theta)
    colnames(GoF) <- sprintf("%.4f", lambda.Beta)
  } else {
    rownames(GoF) <- sparse_labels(lambda.Theta)
    colnames(GoF) <- sparse_labels(lambda.Beta)
  }

  # Reverse the column (lambda.beta) order. Direct indexing with drop = FALSE
  # keeps the matrix shape even when there is a single row or column.
  GoF <- GoF[, rev(seq_len(n.beta)), drop = FALSE]

  quantiles <- quantile(gof, c(0, 0.1, 0.25, 0.5, 0.75, 0.9, 1), na.rm = TRUE)
  col.fun <- circlize::colorRamp2(
    c(GoF.bg, quantiles),
    c("white", "#440154", "#414487", "#2A788E", "#22A884", "#7AD151", "#FDE725", "#FDE725")
  )

  # Index of `target` in `grid`, tolerant of floating-point round-off;
  # NA when there is no usable match.
  nearest <- function(grid, target) {
    if (length(grid) == 0L || length(target) != 1L || !is.finite(target)) {
      return(NA_integer_)
    }
    k <- unname(which.min(abs(grid - target)))
    if (length(k) == 1L &&
        isTRUE(all.equal(as.numeric(grid[k]), as.numeric(target)))) {
      k
    } else {
      NA_integer_
    }
  }

  # Resolve the selected cell once, not once per drawn cell. If it cannot be
  # located on the grid the heatmap is drawn without a highlight.
  est <- missoNet.obj$est.min
  min.row <- if (is.null(est)) NA_integer_ else nearest(lambda.Theta, est$lambda.theta)
  min.col <- if (is.null(est)) NA_integer_ else nearest(rev(lambda.Beta), est$lambda.beta)
  has.min <- !is.na(min.row) && !is.na(min.col)

  ComplexHeatmap::Heatmap(
    GoF,
    col = col.fun,
    border = TRUE,
    row_title = expression(lambda[Theta]),
    row_title_side = "right",
    row_title_rot = 0,
    row_title_gp = grid::gpar(fontsize = 10, fontface = "bold"),
    row_names_side = "left",
    row_names_gp = grid::gpar(fontsize = 8),
    column_title = expression(lambda[Beta]),
    column_title_gp = grid::gpar(fontsize = 10, fontface = "bold"),
    column_names_gp = grid::gpar(fontsize = 8),
    cluster_rows = FALSE,
    cluster_columns = FALSE,
    name = "GoF",
    heatmap_legend_param = list(
      title_gp = grid::gpar(fontsize = 10, fontface = "bold"),
      labels_gp = grid::gpar(fontsize = 8),
      grid_height = grid::unit(4, "mm"),
      grid_width = grid::unit(4, "mm")
    ),
    ...,
    cell_fun = function(j, i, x, y, width, height, fill) {
      # Selected solution: solid white outline plus a centre dot
      if (has.min && i == min.row && j == min.col) {
        grid::grid.rect(x = x, y = y, width = width, height = height,
                        gp = grid::gpar(lwd = 2, col = "white", fill = NA))
        grid::grid.points(x = x, y = y, pch = 19, size = grid::unit(3, "mm"),
                          gp = grid::gpar(col = "white"))
      }
    }
  )
}
#' 3D scatter plot of the mean cross-validation error over the lambda grid
#'
#' Internal backend for \code{plot.missoNet()}. Every point of the full
#' \code{lambda.theta.seq} x \code{lambda.beta.seq} grid is drawn at
#' (log10 lambda.beta, log10 lambda.theta, error). Grid cells without a CV
#' result are drawn at a background height and coloured "gray90".
#'
#' @param cv.missoNet.obj Fitted object providing \code{lambda.beta.seq},
#'   \code{lambda.theta.seq}, \code{param_set$cv.grid.beta},
#'   \code{param_set$cv.grid.theta} and \code{param_set$cv.errors.mean}.
#' @param detailed.axes Logical; if \code{TRUE} request 10 tick intervals per
#'   axis (\code{lab = c(10, 10, 10)}), otherwise use the current
#'   \code{par("lab")}.
#' @param plt.surf Logical; if \code{TRUE} add gray lattice lines along both
#'   lambda directions and mark the smallest evaluated CV error with a red
#'   diamond.
#' @param ... Further arguments passed to
#'   \code{scatterplot3d::scatterplot3d()}.
#'
#' @return The \code{scatterplot3d} object, invisibly.
#' @noRd
plot.cv.scatter <- function(cv.missoNet.obj, detailed.axes = TRUE, plt.surf = TRUE, ...) {
  cv.err <- cv.missoNet.obj$param_set$cv.errors.mean
  if (is.null(cv.err)) {
    stop("No cross-validation errors found in the object")
  }

  lambda.Beta <- cv.missoNet.obj$lambda.beta.seq
  lambda.Theta <- cv.missoNet.obj$lambda.theta.seq
  cv.lamB <- unique(cv.missoNet.obj$param_set$cv.grid.beta)
  cv.lamTh <- unique(cv.missoNet.obj$param_set$cv.grid.theta)
  n.beta <- length(lambda.Beta)
  n.theta <- length(lambda.Theta)
  n.cv.theta <- length(cv.lamTh)

  # Placeholder for cells without a CV result: 0.9 times the smallest error.
  cvm.bg <- min(cv.err, na.rm = TRUE) * 0.9

  # Preallocate the full grid, and remember which cells hold a real CV error
  # so that placeholders can be excluded from the minimum search below.
  cvm <- matrix(cvm.bg, nrow = n.theta, ncol = n.beta)
  has.cv <- matrix(FALSE, nrow = n.theta, ncol = n.beta)
  theta.in.cv <- lambda.Theta %in% cv.lamTh
  for (l in seq_len(n.beta)) {
    cv.lamB.id <- match(lambda.Beta[l], cv.lamB)
    if (is.na(cv.lamB.id)) {
      next
    }
    # The indexing assumes cv.errors.mean is stored in blocks of
    # length(cv.lamTh) values, one block per unique lambda.beta.
    # NOTE(review): confirm this ordering against the fitting code.
    col.vals <- cvm[, l]
    col.vals[theta.in.cv] <- cv.err[(cv.lamB.id - 1) * n.cv.theta + seq_len(n.cv.theta)]
    cvm[, l] <- col.vals
    has.cv[theta.in.cv, l] <- TRUE
  }

  # Reverse the column (lambda.beta) order. Direct indexing with drop = FALSE
  # keeps the matrix shape even when there is a single row or column.
  flip <- rev(seq_len(n.beta))
  cvm <- cvm[, flip, drop = FALSE]
  has.cv <- has.cv[, flip, drop = FALSE]

  if (isTRUE(detailed.axes)) {
    lab <- c(10, 10, 10)
  } else {
    lab <- par("lab")
  }

  # Colour palette: quantile breaks, with the placeholder value in light gray.
  quantiles <- quantile(cv.err, c(0, 0.1, 0.25, 0.5, 0.75, 0.9, 1), na.rm = TRUE)
  col.fun <- circlize::colorRamp2(
    c(cvm.bg, quantiles),
    c("gray90", "#440154", "#414487", "#2A788E", "#22A884", "#7AD151", "#FDE725", "#FDE725")
  )

  # Axis coordinates matching the column-reversed matrix.
  lambda.Beta.log <- rev(log10(lambda.Beta))
  lambda.Theta.log <- log10(lambda.Theta)

  # Long format in column-major order: lambda.theta varies fastest.
  cvm.long <- as.vector(cvm)
  lambda.Beta.long <- rep(lambda.Beta.log, each = n.theta)
  lambda.Theta.long <- rep(lambda.Theta.log, n.beta)

  s3d <- scatterplot3d::scatterplot3d(
    lambda.Beta.long, lambda.Theta.long, cvm.long,
    type = "p",
    pch = 19,
    cex.symbols = 1.2,
    color = col.fun(cvm.long),
    lab = lab,
    zlab = "CV Error",
    xlab = expression(log[10](lambda[Beta])),
    ylab = expression(log[10](lambda[Theta])),
    main = "Cross-Validation Error Surface",
    grid = TRUE,
    box = TRUE,
    angle = 40,
    ...
  )

  if (isTRUE(plt.surf)) {
    # Lattice lines at fixed lambda.beta ...
    for (j in seq_len(n.beta)) {
      s3d$points3d(rep(lambda.Beta.log[j], n.theta),
                   lambda.Theta.log, cvm[, j],
                   type = "l", lty = 1, col = "gray70", lwd = 0.5)
    }
    # ... and at fixed lambda.theta.
    for (i in seq_len(n.theta)) {
      s3d$points3d(lambda.Beta.log,
                   rep(lambda.Theta.log[i], n.beta),
                   cvm[i, ],
                   type = "l", lty = 1, col = "gray70", lwd = 0.5)
    }

    # Mark the minimum among evaluated cells only. The placeholder value lies
    # below every positive error, so it would otherwise win the search.
    evaluated <- cvm
    evaluated[!has.cv] <- NA
    min.idx <- which.min(evaluated)
    if (length(min.idx) == 1L) {
      min.pos <- arrayInd(min.idx, dim(cvm))
      s3d$points3d(lambda.Beta.log[min.pos[1, 2]],
                   lambda.Theta.log[min.pos[1, 1]],
                   cvm[min.idx],
                   pch = 18, cex = 2, col = "red")
    }
  }

  invisible(s3d)
}
#' 3D scatter plot of the goodness-of-fit values over the lambda grid
#'
#' Internal backend for \code{plot.missoNet()} when the object carries no CV
#' results. Every point of the full \code{lambda.theta.seq} x
#' \code{lambda.beta.seq} grid is drawn at
#' (log10 lambda.beta, log10 lambda.theta, GoF). Grid cells without a GoF
#' value are drawn at a background height and coloured "gray90".
#'
#' @param missoNet.obj Fitted object providing \code{lambda.beta.seq},
#'   \code{lambda.theta.seq}, \code{param_set$gof.grid.beta},
#'   \code{param_set$gof.grid.theta} and \code{param_set$gof}.
#' @param detailed.axes Logical; if \code{TRUE} request 10 tick intervals per
#'   axis (\code{lab = c(10, 10, 10)}), otherwise use the current
#'   \code{par("lab")}.
#' @param plt.surf Logical; if \code{TRUE} add gray lattice lines along both
#'   lambda directions and mark the smallest evaluated GoF value with a red
#'   diamond.
#' @param ... Further arguments passed to
#'   \code{scatterplot3d::scatterplot3d()}.
#'
#' @return The \code{scatterplot3d} object, invisibly.
#' @noRd
plot.scatter <- function(missoNet.obj, detailed.axes = TRUE, plt.surf = TRUE, ...) {
  gof <- missoNet.obj$param_set$gof
  if (is.null(gof)) {
    stop("No goodness-of-fit values found in the object")
  }

  lambda.Beta <- missoNet.obj$lambda.beta.seq
  lambda.Theta <- missoNet.obj$lambda.theta.seq
  GoF.lamB <- unique(missoNet.obj$param_set$gof.grid.beta)
  GoF.lamTh <- unique(missoNet.obj$param_set$gof.grid.theta)
  n.beta <- length(lambda.Beta)
  n.theta <- length(lambda.Theta)
  n.gof.theta <- length(GoF.lamTh)

  # Placeholder for cells without a GoF value: 0.9 times the smallest GoF.
  GoF.bg <- min(gof, na.rm = TRUE) * 0.9

  # Preallocate the full grid, and remember which cells hold a real GoF value
  # so that placeholders can be excluded from the minimum search below.
  GoF <- matrix(GoF.bg, nrow = n.theta, ncol = n.beta)
  has.gof <- matrix(FALSE, nrow = n.theta, ncol = n.beta)
  theta.in.grid <- lambda.Theta %in% GoF.lamTh
  for (l in seq_len(n.beta)) {
    GoF.lamB.id <- match(lambda.Beta[l], GoF.lamB)
    if (is.na(GoF.lamB.id)) {
      next
    }
    # The indexing assumes param_set$gof is stored in blocks of
    # length(GoF.lamTh) values, one block per unique lambda.beta.
    # NOTE(review): confirm this ordering against the fitting code.
    col.vals <- GoF[, l]
    col.vals[theta.in.grid] <- gof[(GoF.lamB.id - 1) * n.gof.theta + seq_len(n.gof.theta)]
    GoF[, l] <- col.vals
    has.gof[theta.in.grid, l] <- TRUE
  }

  # Reverse the column (lambda.beta) order. Direct indexing with drop = FALSE
  # keeps the matrix shape even when there is a single row or column.
  flip <- rev(seq_len(n.beta))
  GoF <- GoF[, flip, drop = FALSE]
  has.gof <- has.gof[, flip, drop = FALSE]

  if (isTRUE(detailed.axes)) {
    lab <- c(10, 10, 10)
  } else {
    lab <- par("lab")
  }

  # Colour palette: quantile breaks, with the placeholder value in light gray.
  quantiles <- quantile(gof, c(0, 0.1, 0.25, 0.5, 0.75, 0.9, 1), na.rm = TRUE)
  col.fun <- circlize::colorRamp2(
    c(GoF.bg, quantiles),
    c("gray90", "#440154", "#414487", "#2A788E", "#22A884", "#7AD151", "#FDE725", "#FDE725")
  )

  # Axis coordinates matching the column-reversed matrix.
  lambda.Beta.log <- rev(log10(lambda.Beta))
  lambda.Theta.log <- log10(lambda.Theta)

  # Long format in column-major order: lambda.theta varies fastest.
  GoF.long <- as.vector(GoF)
  lambda.Beta.long <- rep(lambda.Beta.log, each = n.theta)
  lambda.Theta.long <- rep(lambda.Theta.log, n.beta)

  s3d <- scatterplot3d::scatterplot3d(
    lambda.Beta.long, lambda.Theta.long, GoF.long,
    type = "p",
    pch = 19,
    cex.symbols = 1.2,
    color = col.fun(GoF.long),
    lab = lab,
    zlab = "GoF",
    xlab = expression(log[10](lambda[Beta])),
    ylab = expression(log[10](lambda[Theta])),
    main = "Goodness-of-Fit Surface",
    grid = TRUE,
    box = TRUE,
    angle = 40,
    ...
  )

  if (isTRUE(plt.surf)) {
    # Lattice lines at fixed lambda.beta ...
    for (j in seq_len(n.beta)) {
      s3d$points3d(rep(lambda.Beta.log[j], n.theta),
                   lambda.Theta.log, GoF[, j],
                   type = "l", lty = 1, col = "gray70", lwd = 0.5)
    }
    # ... and at fixed lambda.theta.
    for (i in seq_len(n.theta)) {
      s3d$points3d(lambda.Beta.log,
                   rep(lambda.Theta.log[i], n.beta),
                   GoF[i, ],
                   type = "l", lty = 1, col = "gray70", lwd = 0.5)
    }

    # Mark the minimum among evaluated cells only. The placeholder value lies
    # below every positive GoF value, so it would otherwise win the search.
    evaluated <- GoF
    evaluated[!has.gof] <- NA
    min.idx <- which.min(evaluated)
    if (length(min.idx) == 1L) {
      min.pos <- arrayInd(min.idx, dim(GoF))
      s3d$points3d(lambda.Beta.log[min.pos[1, 2]],
                   lambda.Theta.log[min.pos[1, 1]],
                   GoF[min.idx],
                   pch = 18, cex = 2, col = "red")
    }
  }

  invisible(s3d)
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.