R/cv.missoNet.R

Defines functions cv.missoNet

Documented in cv.missoNet

#' Cross-validation for missoNet
#'
#' @description
#' Perform \eqn{k}-fold cross-validation to select the regularization pair
#' \code{(lambda.beta, lambda.theta)} for \code{\link{missoNet}}. For each fold,
#' the model is fit on the other \eqn{k-1} folds over a grid of lambda pairs and
#' evaluated on the held-out fold. The pair with the minimum mean CV error is
#' refit on the full data; optional 1-SE solutions provide more regularized
#' alternatives.
#'
#' @details
#' Internally, predictors \code{X} and responses \code{Y} can be standardized
#' for optimization; all reported estimates are re-scaled back to the original
#' data scale. Missingness in \code{Y} is handled via unbiased estimating
#' equations that use column-wise missingness probabilities, either estimated
#' from \code{Y} or supplied via \code{rho}. This is appropriate when each
#' response is missing independently of its unobserved value and of \code{X}
#' (e.g., MCAR).
#'
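#' As a summary of the fold loop in the code below (a description of this
#' implementation, not the paper's notation), write \eqn{\mathbf{Z}} for the
#' centered (and optionally scaled) responses with missing entries set to zero
#' and \eqn{\rho_k} for the missingness probability of response \eqn{k}. The
#' corrected cross-products are
#' \deqn{\left(\widetilde{\mathbf{X}^\top\mathbf{Y}}\right)_{jk} = \frac{(\mathbf{X}^\top\mathbf{Z})_{jk}}{1-\rho_k},}
#' \deqn{\left(\widetilde{\mathbf{Y}^\top\mathbf{Y}}\right)_{kl} = \frac{(\mathbf{Z}^\top\mathbf{Z})_{kl}}{(1-\rho_k)(1-\rho_l)} \quad (k \neq l), \qquad \left(\widetilde{\mathbf{Y}^\top\mathbf{Y}}\right)_{kk} = \frac{(\mathbf{Z}^\top\mathbf{Z})_{kk}}{1-\rho_k}.}
#' When responses are missing completely at random and independently of one
#' another, these have the same expectation as the complete-data
#' cross-products \eqn{\mathbf{X}^\top\mathbf{Y}} and \eqn{\mathbf{Y}^\top\mathbf{Y}}.
#'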
#' If \code{adaptive.search = TRUE}, a fast two-stage pre-optimization on the
#' full data (guided by eBIC) narrows the lambda grid, and fold errors are then
#' computed only on a neighborhood of its optimum. This can be substantially
#' faster on large grids but may occasionally miss the global optimum.
#' 
#' When \code{compute.1se = TRUE}, two additional, more regularized solutions
#' are reported, so three lambda pairs are identified in total:
#' \itemize{
#'   \item \strong{lambda.min}: the pair giving the minimum mean CV error.
#'   \item \strong{lambda.1se.beta}: the largest \eqn{\lambda_B} whose CV error
#'     is within one standard error of the minimum, with \eqn{\lambda_\Theta}
#'     fixed at its optimal value.
#'   \item \strong{lambda.1se.theta}: the largest \eqn{\lambda_\Theta} whose CV
#'     error is within one standard error of the minimum, with \eqn{\lambda_B}
#'     fixed at its optimal value.
#' }
#'
#' The 1-SE rules yield more regularized models whose CV error is close to the
#' minimum; such models may generalize better.
#'
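#' The following summarizes how the code below aggregates fold errors (a
#' description of this implementation, not of the paper). Writing
#' \eqn{e_k(\lambda)} for the validation error of fold \eqn{k} at the pair
#' \eqn{\lambda = (\lambda_B, \lambda_\Theta)}, with \eqn{K} = \code{kfold} and
#' \eqn{\mathrm{sd}} the sample standard deviation,
#' \deqn{\mathrm{CV}(\lambda) = \frac{1}{K}\sum_{k=1}^{K} e_k(\lambda), \qquad \mathrm{SE}(\lambda) = \frac{\mathrm{sd}\left(e_1(\lambda), \dots, e_K(\lambda)\right)}{\sqrt{K}},}
#' and the \code{lambda.1se.beta} rule reads
#' \deqn{\lambda_B^{\mathrm{1se}} = \max \left\lbrace \lambda_B : \mathrm{CV}(\lambda_B, \lambda_\Theta^{\min}) \le \mathrm{CV}(\lambda^{\min}) + \mathrm{SE}(\lambda^{\min}) \right\rbrace,}
#' where the maximum runs over the grid values paired with
#' \eqn{\lambda_\Theta^{\min}}; \code{lambda.1se.theta} is defined symmetrically.
#'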
#' @param X Numeric matrix (\eqn{n \times p}). Predictors (no missing values).
#' @param Y Numeric matrix (\eqn{n \times q}). Responses. Missing values should be
#'   coded as \code{NA}/\code{NaN}.
#' @param kfold Integer \eqn{\ge 2}. Number of folds (default \code{5}).
#' @param rho Optional numeric vector of length \eqn{q}. Working missingness
#'   probabilities (per response). If \code{NULL} (default), estimated from \code{Y}.
#' @param lambda.beta,lambda.theta Optional numeric vectors. Candidate
#'   regularization paths for \eqn{\mathbf{B}} and \eqn{\Theta}. If
#'   \code{NULL}, sequences are generated automatically from the data.
#'   Avoid supplying a single value because warm starts along a path are used.
#' @param lambda.beta.min.ratio,lambda.theta.min.ratio Optional numerics in \code{(0,1]}.
#'   Ratio of the smallest to the largest value when generating lambda sequences
#'   (ignored if the corresponding \code{lambda.*} is supplied).
#' @param n.lambda.beta,n.lambda.theta Optional integers. Lengths of the
#'   automatically generated lambda paths (ignored if the corresponding
#'   \code{lambda.*} is supplied).
#' @param beta.pen.factor Optional \eqn{p \times q} non-negative matrix of element-wise
#'   penalty multipliers for \eqn{\mathbf{B}}. \code{Inf} = maximum penalty;
#'   \code{0} = no penalty for the corresponding coefficient. Default: all 1s (equal penalty).
#' @param theta.pen.factor Optional \eqn{q \times q} non-negative matrix of element-wise
#'   penalty multipliers for \eqn{\Theta}. Off-diagonal entries control edge
#'   penalties; diagonal treatment is governed by \code{penalize.diagonal}. \code{Inf} =
#'   maximum penalty; \code{0} = no penalty for that element. Default: all 1s (equal penalty).
#' @param penalize.diagonal Logical or \code{NULL}. Whether to penalize diagonal entries
#'   of \eqn{\Theta}. If \code{NULL} (default) the choice is made automatically.
#' @param beta.max.iter,theta.max.iter Integers. Max iterations for the
#'   \eqn{\mathbf{B}} update (FISTA) and \eqn{\Theta} update (graphical lasso).
#'   Defaults: \code{10000}.
#' @param beta.tol,theta.tol Numerics \eqn{> 0}. Convergence tolerances for the
#'   \eqn{\mathbf{B}} and \eqn{\Theta} updates. Defaults: \code{1e-5}.
#' @param eta Numeric in \code{(0,1)}. Backtracking line-search parameter for the
#'   \eqn{\mathbf{B}} update (default \code{0.8}).
#' @param eps Numeric in \code{(0,1)}. Eigenvalue floor used to stabilize positive
#'   definiteness operations (default \code{1e-8}).
#' @param standardize Logical. Standardize columns of \code{X} internally? Default \code{TRUE}.
#' @param standardize.response Logical. Standardize columns of \code{Y} internally?
#'   Default \code{TRUE}.
#' @param compute.1se Logical. Also compute 1-SE solutions? Default \code{TRUE}.
#' @param relax.net (Experimental) Logical. If \code{TRUE}, refit active edges of \eqn{\Theta}
#'   without \eqn{\ell_1} penalty (de-biased network). Default \code{FALSE}.
#' @param adaptive.search (Experimental) Logical. Use adaptive two-stage lambda search? Default \code{FALSE}.
#' @param shuffle Logical. Randomly shuffle fold assignments? Default \code{TRUE}.
#' @param seed Optional integer seed (used when \code{shuffle = TRUE}).
#' @param parallel Logical. Evaluate folds in parallel on the cluster supplied in
#'   \code{cl}? If \code{cl} is \code{NULL} or unresponsive, folds are evaluated
#'   sequentially. Default \code{FALSE}.
#' @param cl Optional cluster from \code{parallel::makeCluster()} (required if
#'   \code{parallel = TRUE}).
#' @param verbose Integer in \code{0,1,2}. \code{0} = silent, \code{1} = progress (default),
#'   \code{2} = detailed tracing (not supported in parallel mode).
#'
#' @return
#' A list of class \code{"missoNet"} with components:
#' \describe{
#'   \item{est.min}{List of estimates at the CV minimum:
#'     \code{Beta} (\eqn{p \times q}), \code{Theta} (\eqn{q \times q}),
#'     intercept \code{mu} (length \eqn{q}), \code{lambda.beta}, \code{lambda.theta},
#'     \code{lambda.beta.idx}, \code{lambda.theta.idx}, and (if requested)
#'     \code{relax.net}.}
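#'   \item{est.1se.beta}{List of estimates at the 1-SE \code{lambda.beta}
#'     (if \code{compute.1se = TRUE}); \code{NULL} otherwise, or when the 1-SE
#'     value coincides with \code{est.min$lambda.beta}.}
#'   \item{est.1se.theta}{List of estimates at the 1-SE \code{lambda.theta}
#'     (if \code{compute.1se = TRUE}); \code{NULL} otherwise, or when the 1-SE
#'     value coincides with \code{est.min$lambda.theta}.}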
#'   \item{rho}{Length-\eqn{q} vector of working missingness probabilities.}
#'   \item{kfold}{Number of folds used.}
#'   \item{fold.index}{Integer vector of length \eqn{n} holding the (possibly
#'     shuffled) row indices; its names give the fold of each row
#'     (\code{"fold-1"}, \code{"fold-2"}, \dots).}
#'   \item{lambda.beta.seq, lambda.theta.seq}{Unique lambda values explored along
#'     the grid for \eqn{\mathbf{B}} and \eqn{\Theta}.}
#'   \item{penalize.diagonal}{Logical indicating whether the diagonal of
#'     \eqn{\Theta} was penalized.}
#'   \item{beta.pen.factor, theta.pen.factor}{Penalty factor matrices actually used.}
#'   \item{param_set}{List with CV diagnostics:
#'     \code{n}, \code{p}, \code{q}, \code{standardize}, \code{standardize.response},
#'     mean errors \code{cv.errors.mean}, bounds \code{cv.errors.upper} and
#'     \code{cv.errors.lower} (mean plus/minus one standard error), and the
#'     evaluated grids \code{cv.grid.beta}, \code{cv.grid.theta} (length equals
#'     the number of fitted models).}
#' }
#'
#' @examples
#' sim <- generateData(n = 120, p = 12, q = 6, rho = 0.1)
#' X <- sim$X; Y <- sim$Z
#'
#' \donttest{
#' # Basic 5-fold cross-validation
#' cvfit <- cv.missoNet(X = X, Y = Y, kfold = 5, verbose = 0)
#' 
#' # Extract optimal estimates
#' Beta.min <- cvfit$est.min$Beta
#' Theta.min <- cvfit$est.min$Theta
#' 
#' # Extract 1SE estimates (if computed)
#' if (!is.null(cvfit$est.1se.beta)) {
#'   Beta.1se <- cvfit$est.1se.beta$Beta
#' }
#' if (!is.null(cvfit$est.1se.theta)) {
#'   Theta.1se <- cvfit$est.1se.theta$Theta
#' }
#' 
#' # Make predictions
#' newX <- matrix(rnorm(10 * 12), 10, 12)
#' pred.min <- predict(cvfit, newx = newX, s = "lambda.min")
#' pred.1se <- predict(cvfit, newx = newX, s = "lambda.1se.beta")
#' 
#' # Parallel cross-validation
#' library(parallel)
#' cl <- makeCluster(max(1, min(detectCores() - 1, 2), na.rm = TRUE))
#' cvfit2 <- cv.missoNet(X = X, Y = Y, kfold = 5, 
#'                       parallel = TRUE, cl = cl)
#' stopCluster(cl)
#' 
#' # Adaptive search for efficiency
#' cvfit3 <- cv.missoNet(X = X, Y = Y, kfold = 5,
#'                       adaptive.search = TRUE)
#' 
#' # Reproducible CV with specific lambdas
#' cvfit4 <- cv.missoNet(X = X, Y = Y, kfold = 5,
#'                       lambda.beta = 10^seq(0, -2, length = 20),
#'                       lambda.theta = 10^seq(0, -2, length = 20),
#'                       seed = 486)
#' 
#' # Plot CV results
#' plot(cvfit, type = "heatmap")
#' plot(cvfit, type = "scatter")
#' }
#'
#' @seealso
#' \code{\link{missoNet}} for model fitting;
#' generic methods such as \code{plot()} and \code{predict()} for objects of class
#' \code{"missoNet"}.
#'
#' @references
#' Zeng, Y., et al. (2025). \emph{Multivariate regression with missing response data
#' for modelling regional DNA methylation QTLs}. arXiv:2507.05990.
#'
#' @author
#' Yixiao Zeng \email{yixiao.zeng@@mail.mcgill.ca}, Celia M. T. Greenwood
#' @export

cv.missoNet <- function(X, Y,
                        kfold = 5,
                        rho = NULL,
                        lambda.beta = NULL,
                        lambda.theta = NULL,
                        lambda.beta.min.ratio = NULL,
                        lambda.theta.min.ratio = NULL,
                        n.lambda.beta = NULL,
                        n.lambda.theta = NULL,
                        beta.pen.factor = NULL,
                        theta.pen.factor = NULL,
                        penalize.diagonal = NULL,
                        beta.max.iter = 10000,
                        beta.tol = 1e-5,
                        theta.max.iter = 10000,
                        theta.tol = 1e-5,
                        eta = 0.8,
                        eps = 1e-8,
                        standardize = TRUE,
                        standardize.response = TRUE,
                        compute.1se = TRUE,
                        relax.net = FALSE,
                        adaptive.search = FALSE,
                        shuffle = TRUE,
                        seed = NULL,
                        parallel = FALSE, 
                        cl = NULL,
                        verbose = 1) {
  
  ## ---------- Input validation ----------
  if (!is.matrix(X) && !is.data.frame(X)) stop("`X` must be a matrix or data.frame.")
  if (!is.matrix(Y) && !is.data.frame(Y)) stop("`Y` must be a matrix or data.frame.")
  
  X <- as.matrix(X)
  Y <- as.matrix(Y)
  
  stopifnot(
    "X and Y must have same number of rows" = nrow(X) == nrow(Y),
    "lambda.theta must be non-negative" = all(is.null(lambda.theta) | lambda.theta >= 0),
    "lambda.beta must be non-negative" = all(is.null(lambda.beta) | lambda.beta >= 0),
    "beta.max.iter must be positive" = beta.max.iter > 0 && is.finite(beta.max.iter),
    "theta.max.iter must be positive" = theta.max.iter > 0 && is.finite(theta.max.iter),
    "beta.tol must be positive" = beta.tol > 0 && is.finite(beta.tol),
    "theta.tol must be positive" = theta.tol > 0 && is.finite(theta.tol),
    "eps must be in (0, 1)" = eps > 0 && eps < 1,
    "eta must be in (0, 1)" = eta > 0 && eta < 1,
    "verbose must be 0, 1, or 2" = verbose %in% c(0, 1, 2)
  )
  
  n <- nrow(X); p <- ncol(X); q <- ncol(Y)
  
  if (!(is.numeric(kfold) && length(kfold) == 1 && !is.na(kfold) &&
        kfold >= 2 && kfold == round(kfold))) {
    stop("`kfold` must be a single integer >= 2.")
  }
  if (kfold > n) {
    if (verbose > 0) {
      warning(sprintf("kfold (%d) exceeds n (%d), setting kfold = n", kfold, n))
    }
    kfold <- n
  }
  
  check_data_condition <- function(X, Y, verbose) {
    # Check for multicollinearity
    if (ncol(X) <= nrow(X) && ncol(X) > 1) {
      cor_X <- suppressWarnings(cor(X, use = "complete.obs"))
      if (any(is.finite(cor_X))) {
        diag(cor_X) <- 0
        max_cor <- max(abs(cor_X), na.rm = TRUE)
        if (max_cor > 0.95) {
          warning("High collinearity detected (max correlation: ", 
                  round(max_cor, 3), "). Consider removing redundant predictors.")
        }
      }
    }
    
    # Check for near-zero variance
    var_X <- apply(X, 2, var, na.rm = TRUE)
    if (any(var_X <= .Machine$double.eps * 100, na.rm = TRUE)) {
      warning("Near-zero variance predictors detected. These may cause numerical issues.")
    }
    
    # Condition number warning (only when p <= n)
    if (ncol(X) <= nrow(X)) {
      svd.X <- svd(X, nu = 0, nv = 0)
      cond.num <- max(svd.X$d) / max(min(svd.X$d), .Machine$double.eps * 100)
      if (cond.num > 1e12) {
        warning("Predictor matrix `X` appears ill-conditioned (cond. no. = ",
                format(cond.num, scientific = TRUE), ").")
      }
    }
    
    return(invisible(NULL))
  }
  
  check_data_condition(X, Y, verbose)
  
  ## ---------- Header ----------
  if (verbose > 0) {
    cat("\n=============================================================\n")
    cat("               cv.missoNet - Cross-Validation\n")
    cat("=============================================================\n")
    cat("\n> Initializing model...\n")
  }
  
  ## ---------- Fold indices ----------
  if (shuffle) {
    if (!is.null(seed)) set.seed(seed)
    ind <- sample(n)
  } else {
    ind <- seq_len(n)
  }
  foldid <- unlist(lapply(seq_len(kfold), function(x) {
    rep(x, length((1 + floor((x - 1) * n / kfold)) : floor(x * n / kfold)))
  }))
  names(ind) <- paste0("fold-", foldid)
  
  ## ---------- Parameter initialization ----------
  init.obj <- InitParams(
    X = X, Y = Y, rho = rho,
    lamB.pf = beta.pen.factor,
    lamTh.pf = theta.pen.factor,
    pdiag = penalize.diagonal,
    standardize = standardize,
    standardize.response = standardize.response
  )
  
  # Relax tolerances when the effective sample size is smaller than max(p, q)
  if (init.obj$n_eff < max(p, q)) {
    beta.tol <- max(beta.tol, 1e-4)
    theta.tol <- max(theta.tol, 1e-4)
    if (verbose > 0) {
      cat("Note: relaxed tolerances due to limited effective n\n")
    }
  }
  
  ## ---------- Model configuration ----------
  if (verbose > 0) {
    cat("\n--- Model Configuration -------------------------------------\n")
    cat(sprintf("  Data dimensions:      n = %5d, p = %4d, q = %4d\n", n, p, q))
    cat(sprintf("  Missing rate (avg):   %5.1f%%\n", mean(init.obj$rho.vec) * 100))
    cat(sprintf("  Cross-validation:     %d-fold\n", kfold))
  }
  
  # High missingness warning
  if (mean(init.obj$rho.vec) > 0.5) {
    message("\n============================================================")
    message("HIGH MISSINGNESS WARNING")
    message("------------------------------------------------------------")
    message("Over 50% missing data detected. Recommendations:")
    message("- Consider imputation before analysis")
    message("- Use conservative regularization (larger lambda values)")
    message("- Interpret results with caution")
    message("============================================================\n")
  }
  
  ## ---------- Lambda preparation ----------
  if (isTRUE(adaptive.search)) {
    if (verbose > 0) cat("  Lambda grid:          adaptive (fast pre-test)\n")
    
    step.obj <- stepOptim(
      X = X, Y = Y, init.obj = init.obj, GoF = "eBIC", cv = TRUE,
      eta = eta, eps = eps,
      lamB.vec = lambda.beta, lamTh.vec = lambda.theta,
      n.lamB = n.lambda.beta, n.lamTh = n.lambda.theta,
      lamB.min.ratio = lambda.beta.min.ratio, lamTh.min.ratio = lambda.theta.min.ratio,
      lamB.scale.factor = NULL, lamTh.scale.factor = NULL,
      Beta.maxit = beta.max.iter, Beta.thr = beta.tol,
      Theta.maxit = theta.max.iter, Theta.thr = theta.tol,
      verbose = verbose
    )
    
    golden_ratio <- 1.618
    obs_prob_avg <- if (!is.null(init.obj$obs_prob)) mean(init.obj$obs_prob) else (1 - mean(init.obj$rho.vec))
    
    lamB.range <- c(
      max(step.obj$lamB.min / golden_ratio, max(step.obj$lamB.unique) * 0.001) / max(obs_prob_avg, 0.1),
      min(step.obj$lamB.min * golden_ratio^8, max(step.obj$lamB.unique) * 0.75)
    )
    lamTh.range <- c(
      max(step.obj$lamTh.min / golden_ratio^8, max(step.obj$lamTh.unique) * 0.001) / max(obs_prob_avg, 0.1),
      min(step.obj$lamTh.min * golden_ratio^8, max(step.obj$lamTh.unique) * 0.75)
    )
    
    lamB.vec.neighbor <- step.obj$lamB.unique[
      step.obj$lamB.unique >= lamB.range[1] & step.obj$lamB.unique <= lamB.range[2]
    ]
    lamTh.vec.neighbor <- step.obj$lamTh.unique[
      step.obj$lamTh.unique >= lamTh.range[1] & step.obj$lamTh.unique <= lamTh.range[2]
    ]
    
    # Fallback if neighborhood is empty
    if (!length(lamB.vec.neighbor)) {
      k <- max(5, ceiling(length(step.obj$lamB.unique) * 0.25))
      i0 <- step.obj$lamB.id
      lo <- max(1, i0 - floor(k / 2))
      hi <- min(length(step.obj$lamB.unique), lo + k - 1)
      lamB.vec.neighbor <- step.obj$lamB.unique[lo:hi]
    }
    if (!length(lamTh.vec.neighbor)) {
      k <- max(5, ceiling(length(step.obj$lamTh.unique) * 0.25))
      i0 <- step.obj$lamTh.id
      lo <- max(1, i0 - floor(k / 2))
      hi <- min(length(step.obj$lamTh.unique), lo + k - 1)
      lamTh.vec.neighbor <- step.obj$lamTh.unique[lo:hi]
    }
    
    # Create grid for main optimization
    lamB.vec <- rep(lamB.vec.neighbor, each = length(lamTh.vec.neighbor))
    lamTh.vec <- rep(lamTh.vec.neighbor, times = length(lamB.vec.neighbor))
    
    lambda.Beta  <- sort(unique(step.obj$lamB.unique),  decreasing = TRUE)
    lambda.Theta <- sort(unique(step.obj$lamTh.unique), decreasing = TRUE)
    
  } else {
    if (verbose > 0) cat("  Lambda grid:          standard (dense)\n")
    
    lambda.obj <- InitLambda(
      lamB = lambda.beta, lamTh = lambda.theta, init.obj = init.obj,
      n = n, p = p, q = q,
      n.lamB = n.lambda.beta, n.lamTh = n.lambda.theta,
      lamB.min.ratio = lambda.beta.min.ratio, lamTh.min.ratio = lambda.theta.min.ratio,
      lamB.scale.factor = NULL, lamTh.scale.factor = NULL
    )
    
    lamB.vec <- lambda.obj$lamB.vec
    lamTh.vec <- lambda.obj$lamTh.vec
    
    lambda.Beta  <- sort(unique(lamB.vec),  decreasing = TRUE)
    lambda.Theta <- sort(unique(lamTh.vec), decreasing = TRUE)
  }
  
  if (verbose > 0) {
    cat(sprintf("  Lambda grid size:     %d x %d = %d models\n", 
                length(unique(lamB.vec)), length(unique(lamTh.vec)), length(lamTh.vec)))
    cat("-------------------------------------------------------------\n")
  }
  
  ## ---------- Cross-validation ----------
  # Decide the execution mode first, independently of `verbose`
  if (parallel) {
    if (is.null(cl)) {
      warning("`parallel = TRUE` but no cluster `cl` supplied; falling back to sequential")
      parallel <- FALSE
    } else {
      # Check cluster health
      cluster_check <- tryCatch({
        parallel::clusterEvalQ(cl, 1)
        TRUE
      }, error = function(e) FALSE)
      
      if (!cluster_check) {
        warning("Parallel cluster appears unhealthy, falling back to sequential")
        parallel <- FALSE
        cl <- NULL
      }
    }
  }
  
  if (verbose > 0) {
    cat("\n--- Cross-Validation Progress -------------------------------\n")
    cat(sprintf("  Fitting %d lambda pair%s across %d fold%s\n",
                length(lamTh.vec),
                if (length(lamTh.vec) == 1) "" else "s",
                kfold,
                if (kfold == 1) "" else "s"))
    if (parallel) {
      nworkers <- length(cl)
      cat(sprintf("  Execution: parallel (%d worker%s)\n", 
                  nworkers, if (nworkers == 1) "" else "s"))
    } else {
      cat("  Execution: sequential\n")
    }
    cat("-------------------------------------------------------------\n\n")
  }
  
  if (!parallel || is.null(cl)) {
    ## ---------- Sequential execution ----------
    err <- matrix(0, kfold, length(lamTh.vec))
    
    # Memory-efficient warm start allocation
    if (length(lamTh.vec) > 100) {
      Beta.warm <- vector("list", length(lamTh.vec))
    } else {
      Beta.warm <- lapply(seq_along(lamTh.vec), function(x) matrix(0, p, q))
    }
    Beta.warm.fold <- lapply(seq_len(kfold), function(k) Beta.warm)
    
    for (k in seq_len(kfold)) {
      if (verbose > 0) cat(sprintf("  Fold %d of %d\n", k, kfold))
      
      foldind <- ind[(1 + floor((k - 1) * n / kfold)) : floor(k * n / kfold)]
      X.tr <- X[-foldind, , drop = FALSE]
      Y.tr <- Y[-foldind, , drop = FALSE]
      X.va <- X[foldind, , drop = FALSE]
      Y.va <- Y[foldind, , drop = FALSE]
      n.tr <- nrow(X.tr); n.va <- nrow(X.va)
      
      # Missingness (training / validation)
      if (is.null(rho)) {
        rho.vec    <- apply(Y.tr, 2, function(x) sum(is.na(x)) / n.tr)
        rho.vec.va <- apply(Y.va, 2, function(x) sum(is.na(x)) / n.va)
      } else {
        rho.vec <- rho.vec.va <- init.obj$rho.vec
      }
      
      # Rho matrices
      rho.mat.1     <- matrix(1 - rho.vec,    nrow = p, ncol = q, byrow = TRUE)              # p x q
      rho.mat.2     <- outer(1 - rho.vec, 1 - rho.vec, `*`); diag(rho.mat.2) <- 1 - rho.vec  # q x q
      rho.mat.va.1  <- matrix(1 - rho.vec.va, nrow = p, ncol = q, byrow = TRUE)
      rho.mat.va.2  <- outer(1 - rho.vec.va, 1 - rho.vec.va, `*`); diag(rho.mat.va.2) <- 1 - rho.vec.va
      
      # Scaling
      mx.tr <- apply(X.tr, 2, robust_mean)
      mx.va <- apply(X.va, 2, robust_mean)
      my.tr <- apply(Y.tr, 2, robust_mean, na.rm = TRUE)
      my.va <- apply(Y.va, 2, robust_mean, na.rm = TRUE)
      
      X.tr <- robust_scale(X.tr, center = mx.tr, scale = init.obj$sdx)
      X.va <- robust_scale(X.va, center = mx.va, scale = init.obj$sdx)
      Y.tr <- robust_scale(Y.tr, center = my.tr, scale = init.obj$sdy)
      Y.va <- robust_scale(Y.va, center = my.va, scale = init.obj$sdy)
      
      Z.tr <- Y.tr; Z.tr[is.na(Z.tr)] <- 0
      Z.va <- Y.va; Z.va[is.na(Z.va)] <- 0
      
      # Precompute train/validation info
      info <- list()
      info$n <- n.tr; info$p <- p; info$q <- q
      info$xtx <- crossprod(X.tr); info$xtx <- make_positive_definite(info$xtx)
      info$til.xty <- crossprod(X.tr, Z.tr) / rho.mat.1
      
      xtx.va     <- crossprod(X.va); xtx.va <- make_positive_definite(xtx.va)
      til.xty.va <- crossprod(X.va, Z.va) / rho.mat.va.1
      til.ytx.va <- t(til.xty.va)
      til.yty.va <- crossprod(Z.va) / rho.mat.va.2; til.yty.va <- make_positive_definite(til.yty.va)
      
      ## ---- Warm-start path over unique lamB as initial estimates ----
      lamB.uniq <- unique(lamB.vec)
      B.init <- lapply(seq_along(lamB.uniq), function(x) matrix(0, p, q))
      Beta <- matrix(0, p, q)
      
      residual.cov <- (crossprod(Z.tr) / rho.mat.2) / (n.tr - 1)
      residual.cov <- make_positive_definite(residual.cov)
      
      lamTh.mat <- lamTh.vec[1] * init.obj$lamTh.pf
      lamTh.mat[lamTh.mat == 0] <- 1e-12
      diag(lamTh.mat) <- 1e-12
      
      Theta.out <- tryCatch({
        run_glasso(S = residual.cov, rho = lamTh.mat, 
                   thr = min(0.001, theta.tol * 100), 
                   maxIt = max(1000, theta.max.iter / 10))
      }, error = function(e) {
        if (verbose > 0) cat("  Warning: glasso failed in warm start; using diagonal approximation\n")
        # `nrow = q` keeps a q x q matrix even when q = 1
        list(wi = diag(pmin(pmax(1/diag(residual.cov), eps), 1/eps), nrow = q))
      })
      Theta <- make_symmetric(Theta.out$wi)
      
      for (i in seq_along(lamB.uniq)) {
        lamB.mat <- lamB.uniq[i] * init.obj$lamB.pf
        lamB.mat[lamB.mat == 0] <- 1e-12
        
        Beta <- updateBeta(Theta = Theta, B0 = Beta,
                           n = info$n, xtx = info$xtx, xty = info$til.xty, lamB = lamB.mat,
                           eta = eta, tolin = min(0.001, beta.tol * 10), 
                           maxitrin = max(1000, ceiling(beta.max.iter / 10)))$Bhat
        E.tr <- Y.tr - X.tr %*% Beta
        residual.cov <- getResCov(E.tr, n.tr, rho.mat.2)
        
        # Update Theta
        Theta.out <- tryCatch({
          run_glasso(S = residual.cov, rho = lamTh.mat, 
                     thr = min(0.001, theta.tol * 100), 
                     maxIt = max(1000, theta.max.iter / 10))
        }, error = function(e) {
          list(wi = Theta)  # Keep previous Theta
        })
        Theta <- make_symmetric(Theta.out$wi)
        B.init[[i]] <- Beta
      }
      
      ## ---- Sequential training across the lambda grid with warm starts ----
      info.update <- list()
      info.update$B.init <- B.init[[1]]
      E.tr <- Y.tr - X.tr %*% info.update$B.init
      info.update$residual.cov <- getResCov(E.tr, n.tr, rho.mat.2)
      Beta.thr.rescale <- beta.tol
      lamB.crt <- lamB.vec[1]
      lamTh.lb <- min(lamTh.vec)
      
      if (verbose == 1) { 
        pb <- txtProgressBar(min = 0, max = length(lamTh.vec), style = 3, width = 50, char = "=") 
      }
      
      for (i in seq_along(lamTh.vec)) {
        # Update warm start if lamB changes
        if (lamB.vec[i] < lamB.crt) {
          info.update$B.init <- B.init[[which(lamB.uniq == lamB.vec[i])]]
          E.tr <- Y.tr - X.tr %*% info.update$B.init
          info.update$residual.cov <- getResCov(E.tr, n.tr, rho.mat.2)
          Beta.thr.rescale <- beta.tol
        }
        
        info.update$B.init <- update.missoNet(
          lamTh = lamTh.vec[i], lamB = lamB.vec[i],
          Beta.maxit = beta.max.iter, Beta.thr = Beta.thr.rescale,
          Theta.maxit = theta.max.iter, Theta.thr = theta.tol,
          verbose = verbose, eps = eps, eta = eta, init.obj = init.obj,
          info = info, info.update = info.update, under.cv = TRUE
        )
        
        err[k, i] <- getEvalErr(yty = til.yty.va, ytx = til.ytx.va,
                                xty = til.xty.va, xtx = xtx.va,
                                Beta = info.update$B.init, n = n.va)
        
        # Keep this fold's Beta as a warm start for the final refit
        Beta.warm.fold[[k]][[i]] <- info.update$B.init
        
        lamB.crt <- lamB.vec[i]
        
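        # Refresh the residual covariance and rescale the Beta tolerance to
        # warm-start the next pair; skipped at the smallest lambda.theta, after
        # which the path usually moves on to a new lambda.beta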
        if (lamTh.vec[i] > lamTh.lb) {
          E.tr <- Y.tr - X.tr %*% info.update$B.init
          info.update$residual.cov <- getResCov(E.tr, n.tr, rho.mat.2)
          Beta.thr.rescale <- max(beta.tol, beta.tol * norm(info.update$B.init, "F"))
        }
        
        if (verbose == 1 && exists("pb")) setTxtProgressBar(pb, i)
      }
      if (verbose == 1 && exists("pb")) close(pb)
    }
    
    # Average across folds
    for (k in seq_len(kfold)) {
      for (i in seq_along(lamTh.vec)) {
        if (is.null(Beta.warm[[i]])) Beta.warm[[i]] <- matrix(0, p, q)
        Beta.warm[[i]] <- Beta.warm[[i]] + Beta.warm.fold[[k]][[i]]
      }
    }
    Beta.warm <- lapply(Beta.warm, function(B) B / kfold)
    
  } else {
    ## ---------- Parallel execution ----------
    if (verbose > 0) {
      pbapply::pboptions(type = "txt", style = 3, char = "=", txt.width = 50, use_lb = TRUE, nout = kfold)
    } else {
      pbapply::pboptions(type = "none", use_lb = TRUE)
    }
    
    par.out <- pbapply::pblapply(seq_len(kfold), function(k) {
      parWrapper(k = k, X = X, Y = Y,
                 init.obj = init.obj, rho = rho, ind = ind, kfold = kfold,
                 lamTh.vec = lamTh.vec, lamB.vec = lamB.vec,
                 Beta.maxit = beta.max.iter, Beta.thr = beta.tol,
                 Theta.maxit = theta.max.iter, Theta.thr = theta.tol,
                 eps = eps, eta = eta)
    }, cl = cl)
    
    err <- matrix(0, kfold, length(lamTh.vec))
    Beta.warm <- lapply(seq_along(lamTh.vec), function(x) matrix(0, p, q))
    for (k in seq_len(kfold)) {
      err[k, ] <- par.out[[k]]$err.fold
      for (i in seq_along(lamTh.vec)) {
        Beta.warm[[i]] <- Beta.warm[[i]] + par.out[[k]]$Beta.warm.fold[[i]]
      }
    }
    Beta.warm <- lapply(Beta.warm, function(B) B / kfold)
    rm(par.out)
  }
  
  if (verbose > 0) cat("\n-------------------------------------------------------------\n\n")
  
  ## ---------- Aggregate CV errors ----------
  err.cv  <- colSums(err) / kfold
  err.sd  <- apply(err, 2, sd) / sqrt(kfold)
  err.up  <- err.cv + err.sd
  err.low <- err.cv - err.sd
  
  ## ---------- Select lambda.min and refit ----------
  cv.min   <- which.min(err.cv)
  lamTh.min <- lamTh.vec[cv.min]
  lamB.min  <- lamB.vec[cv.min]
  
  if (verbose > 0) cat("> Refitting at lambda.min ...\n\n")
  
  out.min <- update.missoNet(
    X = X, Y = Y, lamTh = lamTh.min, lamB = lamB.min,
    Beta.maxit = max(1000, beta.max.iter), Beta.thr = min(1e-3, beta.tol),
    Theta.maxit = max(1000, theta.max.iter), Theta.thr = min(1e-3, theta.tol),
    verbose = verbose, eps = eps, eta = eta,
    info = NULL, info.update = NULL, under.cv = FALSE,
    init.obj = init.obj, B.init = Beta.warm[[cv.min]]
  )
  out.min$lambda.beta   <- lamB.min
  out.min$lambda.theta  <- lamTh.min
  out.min$lambda.beta.idx  <- which(lambda.Beta  == lamB.min)[1]
  out.min$lambda.theta.idx <- which(lambda.Theta == lamTh.min)[1]
  out.min$Beta <- sweep(out.min$Beta / init.obj$sdx, 2, init.obj$sdy, `*`)
  out.min$mu   <- as.numeric(init.obj$my - crossprod(out.min$Beta, init.obj$mx))
  
  if (relax.net) {
    out.min$relax.net <- tryCatch({
      relax.glasso(X = X, Y = Y, init.obj = init.obj, est = out.min, eps = eps,
                   Theta.thr = min(1e-3, theta.tol * 100), Theta.maxit = max(1000, theta.max.iter / 10))
    }, error = function(e) {
      if (verbose > 0) warning("Relaxed fit failed at lambda.min: ", e$message)
      NULL
    })
  } else {
    out.min$relax.net <- NULL
  }
  
  ## ---------- Optional 1SE fits ----------
  outB.1se <- NULL
  outTh.1se <- NULL
  if (compute.1se) {
    if (verbose > 0) cat("> Computing 1SE models ...\n\n")
    
    # 1SE for Beta (fix Theta at min)
    new.lamB.vec <- lamB.vec[lamTh.vec == lamTh.min]
    new.err.cv   <- err.cv[lamTh.vec == lamTh.min]
    
    candB <- new.lamB.vec[new.err.cv <= err.up[cv.min]]
    lamB.1se <- if (length(candB)) max(candB) else lamB.min
    
    if (lamB.1se != lamB.min) {
      idx.1se.B <- which((lamTh.vec == lamTh.min) & (lamB.vec == lamB.1se))[1]
      outB.1se <- update.missoNet(
        X = X, Y = Y, lamTh = lamTh.min, lamB = lamB.1se,
        Beta.maxit = max(1000, beta.max.iter), Beta.thr = min(1e-3, beta.tol),
        Theta.maxit = max(1000, theta.max.iter), Theta.thr = min(1e-3, theta.tol),
        verbose = verbose, eps = eps, eta = eta,
        info = NULL, info.update = NULL, under.cv = FALSE,
        init.obj = init.obj,
        B.init = if (idx.1se.B <= length(Beta.warm)) Beta.warm[[idx.1se.B]] else NULL
      )
      outB.1se$lambda.beta  <- lamB.1se
      outB.1se$lambda.theta <- lamTh.min
      outB.1se$lambda.beta.idx  <- which(lambda.Beta  == lamB.1se)[1]
      outB.1se$lambda.theta.idx <- which(lambda.Theta == lamTh.min)[1]
      outB.1se$Beta <- sweep(outB.1se$Beta / init.obj$sdx, 2, init.obj$sdy, `*`)
      outB.1se$mu   <- as.numeric(init.obj$my - crossprod(outB.1se$Beta, init.obj$mx))
      if (relax.net) {
        outB.1se$relax.net <- tryCatch({
          relax.glasso(X = X, Y = Y, init.obj = init.obj, est = outB.1se, eps = eps,
                       Theta.thr = min(1e-3, theta.tol * 100), Theta.maxit = max(1000, theta.max.iter / 10))
        }, error = function(e) {
          if (verbose > 0) warning("Relaxed fit failed at lambda.1se.beta: ", e$message)
          NULL
        })
      } else {
        outB.1se$relax.net <- NULL
      }
    } else {
      if (verbose > 0) {
        cat("Note: 1-SE lambda.beta coincides with its lambda.min value; refit skipped (est.1se.beta is NULL)\n")
      }
    }
    
    # 1SE for Theta (fix Beta at min)
    new.lamTh.vec <- lamTh.vec[lamB.vec == lamB.min]
    new.err.cv    <- err.cv[lamB.vec == lamB.min]
    
    candTh <- new.lamTh.vec[new.err.cv <= err.up[cv.min]]
    lamTh.1se <- if (length(candTh)) max(candTh) else lamTh.min
    
    if (lamTh.1se != lamTh.min) {
      idx.1se.Th <- which((lamTh.vec == lamTh.1se) & (lamB.vec == lamB.min))[1]
      outTh.1se <- update.missoNet(
        X = X, Y = Y, lamTh = lamTh.1se, lamB = lamB.min,
        Beta.maxit = max(1000, beta.max.iter), Beta.thr = min(1e-3, beta.tol),
        Theta.maxit = max(1000, theta.max.iter), Theta.thr = min(1e-3, theta.tol),
        verbose = verbose, eps = eps, eta = eta,
        info = NULL, info.update = NULL, under.cv = FALSE,
        init.obj = init.obj,
        B.init = if (idx.1se.Th <= length(Beta.warm)) Beta.warm[[idx.1se.Th]] else NULL
      )
      outTh.1se$lambda.beta  <- lamB.min
      outTh.1se$lambda.theta <- lamTh.1se
      outTh.1se$lambda.beta.idx  <- which(lambda.Beta  == lamB.min)[1]
      outTh.1se$lambda.theta.idx <- which(lambda.Theta == lamTh.1se)[1]
      outTh.1se$Beta <- sweep(outTh.1se$Beta / init.obj$sdx, 2, init.obj$sdy, `*`)
      outTh.1se$mu   <- as.numeric(init.obj$my - crossprod(outTh.1se$Beta, init.obj$mx))
      if (relax.net) {
        outTh.1se$relax.net <- tryCatch({
          relax.glasso(X = X, Y = Y, init.obj = init.obj, est = outTh.1se, eps = eps,
                       Theta.thr = min(1e-3, theta.tol * 100), Theta.maxit = max(1000, theta.max.iter / 10))
        }, error = function(e) {
          if (verbose > 0) warning("Relaxed fit failed at lambda.1se.theta: ", e$message)
          NULL
        })
      } else {
        outTh.1se$relax.net <- NULL
      }
    } else {
      if (verbose > 0) {
        cat("Note: 1-SE lambda.theta coincides with its lambda.min value; refit skipped (est.1se.theta is NULL)\n")
      }
    }
  }
  
  ## ---------- Results summary ----------
  if (verbose > 0) {
    cat("\n--- Cross-Validation Results --------------------------------\n")
    cat(sprintf("  Optimal lambda.beta:   %.4e\n", lamB.min))
    cat(sprintf("  Optimal lambda.theta:  %.4e\n", lamTh.min))
    cat(sprintf("  Min CV error:          %.4f\n", min(err.cv)))
    
    # Sparsity information
    active_preds <- sum(rowSums(abs(out.min$Beta)) > 1e-8)
    active_edges <- sum(abs(out.min$Theta[upper.tri(out.min$Theta, diag = FALSE)]) > 1e-8)
    n.pairs <- q * (q - 1) / 2   # number of possible edges (0 when q = 1)
    cat(sprintf("  Active predictors:     %d / %d (%.1f%%)\n", 
                active_preds, p, 100 * active_preds / p))
    cat(sprintf("  Network edges:         %d / %d (%.1f%%)\n",
                active_edges, n.pairs, if (n.pairs > 0) 100 * active_edges / n.pairs else 0))
    cat("-------------------------------------------------------------\n\n")
    
    cat("=============================================================\n")
  }
  
  ## ---------- Output object ----------
  cv.obj <- list(
    est.min = out.min,
    est.1se.beta = outB.1se,
    est.1se.theta = outTh.1se,
    rho = init.obj$rho.vec,
    kfold = kfold,
    fold.index = ind,
    lambda.beta.seq  = lambda.Beta,
    lambda.theta.seq = lambda.Theta,
    penalize.diagonal = init.obj$penalize_diagonal,
    beta.pen.factor = init.obj$lamB.pf,
    theta.pen.factor = init.obj$lamTh.pf,
    
    param_set = list(
      n = n, p = p, q = q,
      standardize = standardize,
      standardize.response = standardize.response,
      cv.errors.mean = err.cv,
      cv.errors.upper = err.up,
      cv.errors.lower = err.low,
      cv.grid.beta = lamB.vec,
      cv.grid.theta = lamTh.vec
    )
  )
  class(cv.obj) <- c("missoNet", class(cv.obj))
  return(cv.obj)
}
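
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the missoNet API): the contiguous fold
# partition built inside cv.missoNet(), rewritten as a standalone helper so it
# can be inspected on its own. The name `sketch_fold_partition()` is
# hypothetical. Fold x receives positions floor((x - 1) * n / kfold) + 1, ...,
# floor(x * n / kfold) of the (optionally shuffled) index vector, as with
# `foldid` and `foldind` above, so fold sizes differ by at most one.
sketch_fold_partition <- function(n, kfold, shuffle = TRUE, seed = NULL) {
  stopifnot(kfold >= 2, kfold <= n)
  if (shuffle) {
    if (!is.null(seed)) set.seed(seed)
    ind <- sample(n)    # random permutation of the row indices
  } else {
    ind <- seq_len(n)   # keep the original row order
  }
  # Fold sizes floor(x * n / kfold) - floor((x - 1) * n / kfold), x = 1..kfold
  sizes <- diff(floor(seq(0, kfold) * n / kfold))
  names(ind) <- paste0("fold-", rep(seq_len(kfold), times = sizes))
  ind
}

# Example: with n = 10 and kfold = 3 (no shuffling), the folds hold 3, 3 and 4
# rows, and fold 2 holds out rows 4, 5 and 6.
# folds <- sketch_fold_partition(n = 10, kfold = 3, shuffle = FALSE)
# folds[names(folds) == "fold-2"]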

missoNet documentation built on Sept. 9, 2025, 5:55 p.m.