R/plot.missoNet.R

Defines functions plot.scatter plot.cv.scatter plot.heatmap plot.cv.heatmap plot.missoNet

Documented in plot.missoNet

#' Plot methods for \code{missoNet} and cross-validated fits
#'
#' Visualize either the cross-validation (CV) error surface or the
#' goodness-of-fit (GoF) surface over the \eqn{\lambda_B}–\eqn{\lambda_\Theta}
#' grid for objects returned by \code{\link{missoNet}} or \code{\link{cv.missoNet}}.
#' Two display types are supported:
#' a 2D heatmap (default) and a 3D scatter surface.
#'
#' @section What gets plotted:
#' \itemize{
#'   \item \strong{CV objects} (created by \code{\link{cv.missoNet}} or any
#'   \code{missoNet} object that carries CV results): the color encodes the
#'   mean CV error for each \eqn{(\lambda_B, \lambda_\Theta)} pair. The
#'   \emph{minimum-error} solution is outlined; if 1-SE solutions were
#'   computed, they are also marked (dashed/dotted outlines).
#'
#'   \item \strong{Non-CV objects} (created by \code{\link{missoNet}} without CV):
#'   the color encodes the GoF value over the grid; the selected
#'   \emph{minimum} (best) solution is outlined.
#' }
#'
#' @section Axes and scales:
#' For heatmaps, axes are the raw \eqn{\lambda} values; rows are
#' \eqn{\lambda_\Theta} and columns are \eqn{\lambda_B}.
#' For 3D scatter plots, both \eqn{\lambda} axes are shown on the
#' \eqn{\log_{10}} scale for readability.
#'
#' @section Color mapping:
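#' A viridis-like palette is used. Color breaks are placed at the 0, 10, 25,
#' 50, 75, 90, and 100 percent quantiles of the CV error or GoF values to
#' enhance contrast across the grid. Grid cells that were not evaluated are
#' drawn in a background color (white in heatmaps, light gray in scatter plots).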
#'
#' @param x A fitted object returned by \code{\link{missoNet}} or
#'   \code{\link{cv.missoNet}}.
#' @param type Character string specifying the plot type.
#'   One of \code{"heatmap"} (default) or \code{"scatter"}.
#' @param detailed.axes Logical; if \code{TRUE} (default) show dense axis labels.
#'   If \code{FALSE}, a sparser labeling is used to avoid clutter.
#' @param plt.surf Logical; for \code{type = "scatter"} only, draw light surface
#'   grid lines and highlight the minimum point. Ignored for heatmaps. Default \code{TRUE}.
#' @param ... Additional graphical arguments forwarded to
#'   \code{\link[ComplexHeatmap]{Heatmap}} when \code{type = "heatmap"}, or to
#'   \code{\link[scatterplot3d]{scatterplot3d}} when \code{type = "scatter"}.
#'
#' @return
#' \itemize{
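#'   \item For \code{type = "heatmap"}: a \code{ComplexHeatmap} \code{Heatmap}
#'   object, returned visibly. It is rendered when printed (for example, when
#'   auto-printed at the console) or when passed to \code{ComplexHeatmap::draw()}.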
#'   \item For \code{type = "scatter"}: a \code{"scatterplot3d"} object,
#'   returned \emph{invisibly}.
#' }
#'
#' @details
#' This S3 method detects whether \code{x} contains cross-validation results and
#' chooses an appropriate plotting backend:
#' \itemize{
#'   \item \strong{Heatmap}: uses \code{\link[ComplexHeatmap]{Heatmap}} with a
#'   viridis-like color ramp (via \code{\link[circlize]{colorRamp2}}). The
#'   selected \eqn{(\lambda_B, \lambda_\Theta)} cell is outlined in white and
#'   marked with a white dot; 1-SE choices (if present) are outlined with a
#'   yellow dashed border (\code{est.1se.beta}) and a cyan dotted border
#'   (\code{est.1se.theta}).
#'   \item \strong{Scatter}: uses \code{\link[scatterplot3d]{scatterplot3d}}
#'   to draw the error/GoF surface on \eqn{\log_{10}} scales. When
#'   \code{plt.surf = TRUE}, light lattice lines are added, and the minimum is
#'   marked.
#' }
#'
#' @section Dependencies:
#' Requires \pkg{ComplexHeatmap}, \pkg{circlize}, \pkg{scatterplot3d}, and \pkg{grid}.
#'
#' @seealso \code{\link{missoNet}}, \code{\link{cv.missoNet}},
#'   \code{\link[ComplexHeatmap]{Heatmap}}, \code{\link[scatterplot3d]{scatterplot3d}}
#'
#' @author Yixiao Zeng \email{yixiao.zeng@@mail.mcgill.ca}, Celia M. T. Greenwood
#'
#' @examples
#' sim <- generateData(n = 150, p = 10, q = 8, rho = 0.1, missing.type = "MCAR")
#'
#' \donttest{
#' ## Fit a model without CV (plots GoF surface)
#' fit <- missoNet(X = sim$X, Y = sim$Z, verbose = 0)
#' plot(fit, type = "heatmap")                        # GoF heatmap
#' plot(fit, type = "scatter", plt.surf = TRUE)       # GoF 3D scatter
#'
#' ## Cross-validation (plots CV error surface)
#' cvfit <- cv.missoNet(X = sim$X, Y = sim$Z, verbose = 0)
#' plot(cvfit, type = "heatmap", detailed.axes = FALSE)
#' plot(cvfit, type = "scatter", plt.surf = FALSE)
#' }
#'
#' @keywords hplot models
#' @aliases plot.cv.missoNet
#' @method plot missoNet
#' @export
#'
#' @importFrom ComplexHeatmap Heatmap
#' @importFrom scatterplot3d scatterplot3d
#' @importFrom circlize colorRamp2
#' @importFrom grid gpar grid.rect grid.points unit

plot.missoNet <- function(x, type = c("heatmap", "scatter"), detailed.axes = TRUE, plt.surf = TRUE, ...) {
  type <- match.arg(type)

  # Objects that carry mean CV errors are drawn as a CV-error surface;
  # everything else is drawn as a goodness-of-fit surface.
  has.cv <- !is.null(x$param_set$cv.errors.mean)

  # plt.surf only matters for the 3D scatter backends.
  if (type == "heatmap") {
    if (has.cv) {
      plot.cv.heatmap(x, detailed.axes, ...)
    } else {
      plot.heatmap(x, detailed.axes, ...)
    }
  } else if (has.cv) {
    plot.cv.scatter(x, detailed.axes, plt.surf, ...)
  } else {
    plot.scatter(x, detailed.axes, plt.surf, ...)
  }
}

# ============================================================================
#                           INTERNAL HELPER FUNCTIONS
# ============================================================================

#' Heatmap of the mean CV error over the lambda grid
#'
#' @param cv.missoNet.obj Object carrying \code{param_set$cv.errors.mean}.
#' @param detailed.axes Logical; label every lambda value (\code{TRUE}) or at
#'   most 11 evenly spaced values per axis (\code{FALSE}).
#' @param ... Forwarded to \code{ComplexHeatmap::Heatmap}.
#' @return A \code{Heatmap} object.
#' @noRd
plot.cv.heatmap <- function(cv.missoNet.obj, detailed.axes, ...) {
  surf <- .missoNet_cv_surface(cv.missoNet.obj)
  .missoNet_heatmap(
    surf, detailed.axes,
    legend.name = "CV Error",
    highlights = list(
      # lambda.min: solid white outline plus a white dot
      list(est = cv.missoNet.obj$est.min,
           gp = grid::gpar(lwd = 2, col = "white", fill = NA),
           point = TRUE),
      # 1-SE choice for lambda.beta: dashed yellow outline
      list(est = cv.missoNet.obj$est.1se.beta,
           gp = grid::gpar(lwd = 2, lty = "dashed", col = "yellow", fill = NA),
           point = FALSE),
      # 1-SE choice for lambda.theta: dotted cyan outline
      list(est = cv.missoNet.obj$est.1se.theta,
           gp = grid::gpar(lwd = 2, lty = "dotted", col = "cyan", fill = NA),
           point = FALSE)
    ),
    ...
  )
}

#' Heatmap of the goodness-of-fit values over the lambda grid
#'
#' @param missoNet.obj Object carrying \code{param_set$gof}.
#' @param detailed.axes Logical; see \code{plot.cv.heatmap}.
#' @param ... Forwarded to \code{ComplexHeatmap::Heatmap}.
#' @return A \code{Heatmap} object.
#' @noRd
plot.heatmap <- function(missoNet.obj, detailed.axes, ...) {
  surf <- .missoNet_gof_surface(missoNet.obj)
  .missoNet_heatmap(
    surf, detailed.axes,
    legend.name = "GoF",
    highlights = list(
      list(est = missoNet.obj$est.min,
           gp = grid::gpar(lwd = 2, col = "white", fill = NA),
           point = TRUE)
    ),
    ...
  )
}

#' 3D scatter of the mean CV error on log10 lambda axes
#'
#' @param cv.missoNet.obj Object carrying \code{param_set$cv.errors.mean}.
#' @param detailed.axes Logical; \code{TRUE} requests ~10 ticks per axis,
#'   otherwise the current \code{par("lab")} is used.
#' @param plt.surf Logical; add lattice lines and mark the minimum.
#' @param ... Forwarded to \code{scatterplot3d::scatterplot3d}.
#' @return The \code{scatterplot3d} object, invisibly.
#' @noRd
plot.cv.scatter <- function(cv.missoNet.obj, detailed.axes = TRUE, plt.surf = TRUE, ...) {
  .missoNet_scatter(
    .missoNet_cv_surface(cv.missoNet.obj), detailed.axes, plt.surf,
    z.title = "CV Error",
    plot.title = "Cross-Validation Error Surface",
    ...
  )
}

#' 3D scatter of the goodness-of-fit values on log10 lambda axes
#'
#' @param missoNet.obj Object carrying \code{param_set$gof}.
#' @inheritParams plot.cv.scatter
#' @return The \code{scatterplot3d} object, invisibly.
#' @noRd
plot.scatter <- function(missoNet.obj, detailed.axes = TRUE, plt.surf = TRUE, ...) {
  .missoNet_scatter(
    .missoNet_gof_surface(missoNet.obj), detailed.axes, plt.surf,
    z.title = "GoF",
    plot.title = "Goodness-of-Fit Surface",
    ...
  )
}

# ---- Shared private helpers -------------------------------------------------

#' Extract the CV-error surface from a fitted object (errors if absent)
#' @noRd
.missoNet_cv_surface <- function(obj) {
  ps <- obj$param_set
  if (is.null(ps$cv.errors.mean)) {
    stop("No cross-validation errors found in the object")
  }
  .missoNet_grid_matrix(ps$cv.errors.mean, obj$lambda.beta.seq,
                        obj$lambda.theta.seq, ps$cv.grid.beta,
                        ps$cv.grid.theta)
}

#' Extract the goodness-of-fit surface from a fitted object (errors if absent)
#' @noRd
.missoNet_gof_surface <- function(obj) {
  ps <- obj$param_set
  if (is.null(ps$gof)) {
    stop("No goodness-of-fit values found in the object")
  }
  .missoNet_grid_matrix(ps$gof, obj$lambda.beta.seq, obj$lambda.theta.seq,
                        ps$gof.grid.beta, ps$gof.grid.theta)
}

#' Scatter per-grid-point values onto the full lambda grid
#'
#' \code{values} is stored in blocks of \code{length(unique(grid.theta))}
#' entries, one block per unique \code{grid.beta} value. Cells of the full
#' \code{lambda.Theta} x \code{lambda.Beta} grid that were not evaluated are
#' filled with a background value strictly below every finite value.
#'
#' @return A list with \code{values} (matrix; rows follow \code{lambda.Theta},
#'   columns follow \code{rev(lambda.Beta)}), \code{observed} (logical matrix
#'   of the same shape), \code{bg}, \code{raw} (the input values), and the
#'   two lambda sequences.
#' @noRd
.missoNet_grid_matrix <- function(values, lambda.Beta, lambda.Theta,
                                  grid.beta, grid.theta) {
  eval.beta <- unique(grid.beta)
  eval.theta <- unique(grid.theta)
  n.theta <- length(lambda.Theta)
  n.beta <- length(lambda.Beta)
  block <- length(eval.theta)
  theta.hit <- lambda.Theta %in% eval.theta
  bg <- .missoNet_background(values)

  # Preallocate instead of growing the matrix with cbind() in a loop.
  mat <- matrix(bg, nrow = n.theta, ncol = n.beta)
  observed <- matrix(FALSE, nrow = n.theta, ncol = n.beta)
  for (l in seq_len(n.beta)) {
    id <- match(lambda.Beta[l], eval.beta)
    if (is.na(id)) next
    column <- rep(bg, n.theta)
    column[theta.hit] <- values[(id - 1L) * block + seq_len(block)]
    mat[, l] <- column
    observed[theta.hit, l] <- TRUE
  }

  # Reverse the lambda.Beta column order for display. drop = FALSE keeps a
  # proper matrix even with a single lambda value (t(apply(..., rev)) did not).
  flip <- rev(seq_len(n.beta))
  list(values = mat[, flip, drop = FALSE],
       observed = observed[, flip, drop = FALSE],
       bg = bg,
       raw = values,
       lambda.Beta = lambda.Beta,
       lambda.Theta = lambda.Theta)
}

#' Background fill value strictly below every finite input value
#'
#' Equals 0.9 * min for a positive minimum; for a zero or negative minimum
#' (where 0.9 * min would sit on or above the data) it steps below instead,
#' so the background never shares a color break with real values.
#' @noRd
.missoNet_background <- function(values) {
  finite <- values[is.finite(values)]
  if (length(finite) == 0L) {
    stop("No finite values available to plot")
  }
  lo <- min(finite)
  offset <- 0.1 * abs(lo)
  if (offset == 0) offset <- 0.1 * (max(finite) - lo)
  if (offset == 0) offset <- 1
  lo - offset
}

#' Viridis-like color function with quantile breaks plus a background color
#' @noRd
.missoNet_palette <- function(values, bg, bg.col) {
  breaks <- quantile(values, c(0, 0.1, 0.25, 0.5, 0.75, 0.9, 1), na.rm = TRUE)
  circlize::colorRamp2(
    c(bg, breaks),
    c(bg.col, "#440154", "#414487", "#2A788E", "#22A884", "#7AD151",
      "#FDE725", "#FDE725")
  )
}

#' Axis labels: all lambda values, or at most 11 evenly spaced ones
#' @noRd
.missoNet_axis_labels <- function(lambda, detailed) {
  labs <- sprintf("%.4f", lambda)
  if (detailed) return(labs)
  n <- length(lambda)
  keep <- unique(round(seq(1, n, length.out = min(11, n))))
  labs[-keep] <- " "
  labs
}

#' Locate an estimate's (row, column) cell in a grid matrix, or NULL
#'
#' Rows follow \code{lambda.Theta}; columns follow \code{rev(lambda.Beta)}.
#' Returns NULL when the estimate is missing or its lambdas are not on the
#' grid, so callers can simply skip the highlight.
#' @noRd
.missoNet_cell <- function(est, lambda.Beta, lambda.Theta) {
  if (is.null(est)) return(NULL)
  i <- match(est$lambda.theta, lambda.Theta)
  j <- match(est$lambda.beta, rev(lambda.Beta))
  if (length(i) != 1L || length(j) != 1L || anyNA(c(i, j))) return(NULL)
  c(i = i, j = j)
}

#' (row, column) of the smallest evaluated value, ignoring background cells
#' @noRd
.missoNet_min_cell <- function(surf) {
  vals <- surf$values
  vals[!surf$observed] <- NA
  k <- which.min(vals)
  if (length(k) == 0L) return(NULL)
  arrayInd(k, dim(vals))[1L, ]
}

#' Draw a grid surface as a ComplexHeatmap with highlighted cells
#'
#' @param highlights List of \code{list(est, gp, point)}; each estimate found
#'   on the grid is outlined with \code{gp}, plus a white dot if \code{point}.
#' @noRd
.missoNet_heatmap <- function(surf, detailed.axes, legend.name, highlights, ...) {
  mat <- surf$values
  rownames(mat) <- .missoNet_axis_labels(surf$lambda.Theta, detailed.axes)
  colnames(mat) <- rev(.missoNet_axis_labels(surf$lambda.Beta, detailed.axes))
  col <- .missoNet_palette(surf$raw, surf$bg, "white")

  # Resolve each highlight to a cell once, instead of searching the lambda
  # sequences inside every cell_fun call (which also crashed on misses).
  marks <- lapply(highlights, function(h) {
    cell <- .missoNet_cell(h$est, surf$lambda.Beta, surf$lambda.Theta)
    if (is.null(cell)) NULL else list(cell = cell, gp = h$gp, point = h$point)
  })
  marks <- Filter(Negate(is.null), marks)

  ComplexHeatmap::Heatmap(
    mat,
    col = col,
    border = TRUE,
    row_title = expression(lambda[Theta]),
    row_title_side = "right",
    row_title_rot = 0,
    row_title_gp = grid::gpar(fontsize = 10, fontface = "bold"),
    row_names_side = "left",
    row_names_gp = grid::gpar(fontsize = 8),
    column_title = expression(lambda[Beta]),
    column_title_gp = grid::gpar(fontsize = 10, fontface = "bold"),
    column_names_gp = grid::gpar(fontsize = 8),
    cluster_rows = FALSE,
    cluster_columns = FALSE,
    name = legend.name,
    heatmap_legend_param = list(
      title_gp = grid::gpar(fontsize = 10, fontface = "bold"),
      labels_gp = grid::gpar(fontsize = 8),
      grid_height = grid::unit(4, "mm"),
      grid_width = grid::unit(4, "mm")
    ),
    ...,
    cell_fun = function(j, i, x, y, width, height, fill) {
      for (m in marks) {
        if (i == m$cell[["i"]] && j == m$cell[["j"]]) {
          grid::grid.rect(x = x, y = y, width = width, height = height,
                          gp = m$gp)
          if (m$point) {
            # Centre dot makes the selected cell visible at small sizes.
            grid::grid.points(x = x, y = y, pch = 19,
                              size = grid::unit(3, "mm"),
                              gp = grid::gpar(col = "white"))
          }
        }
      }
    }
  )
}

#' Draw a grid surface as a scatterplot3d on log10 lambda axes
#' @return The \code{scatterplot3d} object, invisibly.
#' @noRd
.missoNet_scatter <- function(surf, detailed.axes, plt.surf, z.title,
                              plot.title, ...) {
  mat <- surf$values
  lab <- if (isTRUE(detailed.axes)) c(10, 10, 10) else par("lab")
  col <- .missoNet_palette(surf$raw, surf$bg, "gray90")

  log.beta <- rev(log10(surf$lambda.Beta))  # column order of `mat`
  log.theta <- log10(surf$lambda.Theta)     # row order of `mat`
  z <- as.vector(mat)

  s3d <- scatterplot3d::scatterplot3d(
    rep(log.beta, each = length(log.theta)),
    rep(log.theta, times = length(log.beta)),
    z,
    type = "p",
    pch = 19,
    cex.symbols = 1.2,
    color = col(z),
    lab = lab,
    zlab = z.title,
    xlab = expression(log[10](lambda[Beta])),
    ylab = expression(log[10](lambda[Theta])),
    main = plot.title,
    grid = TRUE,
    box = TRUE,
    angle = 40,
    ...
  )

  if (isTRUE(plt.surf)) {
    # Light lattice lines along each lambda.beta column and lambda.theta row.
    for (j in seq_along(log.beta)) {
      s3d$points3d(rep(log.beta[j], length(log.theta)), log.theta, mat[, j],
                   type = "l", lty = 1, col = "gray70", lwd = 0.5)
    }
    for (i in seq_along(log.theta)) {
      s3d$points3d(log.beta, rep(log.theta[i], length(log.beta)), mat[i, ],
                   type = "l", lty = 1, col = "gray70", lwd = 0.5)
    }

    # Mark the true minimum; background padding cells are excluded.
    best <- .missoNet_min_cell(surf)
    if (!is.null(best)) {
      s3d$points3d(log.beta[best[2]], log.theta[best[1]],
                   mat[best[1], best[2]], pch = 18, cex = 2, col = "red")
    }
  }

  invisible(s3d)
}

Try the missoNet package in your browser

Any scripts or data that you put into this service are public.

missoNet documentation built on Sept. 9, 2025, 5:55 p.m.