R/relax.glasso.R

Defines functions relax.glasso

#' Relaxed graphical lasso for precision matrix estimation
#'
#' @description
#' Re-estimates the residual precision matrix with the sparsity pattern fixed
#' by a previous estimate. Off-diagonal entries of `est$Theta` that are
#' numerically zero receive a very large penalty so they remain zero, all
#' other off-diagonal entries receive a negligible penalty, and the diagonal
#' is never penalized.
#'
#' @param X Predictor matrix (n x p).
#' @param Y Response matrix (n x q), possibly containing missing values.
#' @param init.obj Initialization object; its `rho.vec` element holds the
#'   per-response missing rates (length q, each in [0, 1)). If `rho.vec` is
#'   absent, no missingness is assumed and a warning is issued.
#' @param est Previous estimate: a list with `mu` (length q), `Beta` (p x q)
#'   and `Theta` (q x q). A dimensionless vector is accepted for `Beta` or
#'   `Theta` only when one of the target dimensions is 1 and the length
#'   matches; it is then reshaped to the target dimensions.
#' @param eps Small positive constant used as the floor on observation
#'   probabilities and fallback residual variances, `1 / eps` as the ceiling
#'   on fallback variances, and the minimum eigenvalue enforced on the
#'   result (default: 1e-8).
#' @param Theta.thr Convergence threshold passed to the glasso solver
#'   (default: 1e-3).
#' @param Theta.maxit Maximum number of glasso iterations (default: 1000).
#'
#' @return Symmetric positive definite precision matrix (q x q).
#'
#' @details
#' 1. Residuals `E = Y - mu - X %*% Beta` are formed, and their covariance is
#'    computed by `getResCov()` using pairwise observation probabilities. If
#'    that fails, a pairwise-complete `cov()` is used with a warning.
#' 2. Off-diagonal entries of `est$Theta` with absolute value at most
#'    `max(1e-8, 1e-6 * max(abs(est$Theta)))` are treated as zeros.
#' 3. `run_glasso()` is called with penalty 1e12 on those entries, 1e-12 on
#'    the other off-diagonal entries and 0 on the diagonal. The identified
#'    zeros are then set exactly to zero.
#' 4. If the solver fails, a warning is issued and `diag(1 / diag(S))` is
#'    used instead. It is rescaled (factor bounded to [0.1, 10]) so that its
#'    median diagonal matches that of `est$Theta`.
#' 5. The result is passed through `make_symmetric()`. If any entry is
#'    non-finite, the identity is returned instead. Otherwise the diagonal is
#'    shifted when needed so that the smallest eigenvalue is at least `eps`.
#'
#' @noRd
relax.glasso <- function(X, Y, init.obj, est,
                         eps = 1e-8,
                         Theta.thr = 1e-3,
                         Theta.maxit = 1000) {

  # -------------------- Helpers --------------------

  # Reshape a dimensionless vector only when the target shape is unambiguous
  # (one of the target dimensions is 1); anything else is returned untouched.
  vec.to.mat <- function(x, nr, nc) {
    if (is.null(dim(x)) && (nr == 1L || nc == 1L) && length(x) == nr * nc) {
      x <- matrix(x, nrow = nr, ncol = nc)
    }
    x
  }

  # TRUE only for objects with exactly two dimensions equal to (nr, nc).
  # A bare `all(dim(x) == c(nr, nc))` is vacuously TRUE when dim(x) is NULL.
  dims.ok <- function(x, nr, nc) {
    length(dim(x)) == 2L && all(dim(x) == c(nr, nc))
  }

  # -------------------- Input Validation --------------------
  if (!is.matrix(X) || !is.matrix(Y)) {
    stop("X and Y must be matrices")
  }

  n <- nrow(X)
  p <- ncol(X)
  q <- ncol(Y)

  if (nrow(Y) != n) {
    stop("relax.net: X and Y must have the same number of rows")
  }

  if (is.null(est$mu) || is.null(est$Beta) || is.null(est$Theta)) {
    stop("relax.net: Estimation must contain mu, Beta, and Theta components")
  }

  if (length(est$mu) != q) {
    stop("relax.net: Length of intercept mu must equal number of columns in Y")
  }

  Beta <- vec.to.mat(est$Beta, p, q)
  if (!dims.ok(Beta, p, q)) {
    stop("relax.net: Dimensions of Beta must be (ncol(X), ncol(Y))")
  }

  Theta.prev <- vec.to.mat(est$Theta, q, q)
  if (!dims.ok(Theta.prev, q, q)) {
    stop("relax.net: Theta must be a q x q matrix")
  }

  # -------------------- Setup --------------------

  rho.vec <- init.obj$rho.vec
  if (is.null(rho.vec)) {
    warning("relax.net: Missing rates not found in init.obj, assuming no missing data")
    rho.vec <- rep(0, q)
  }

  if (length(rho.vec) != q) {
    stop(sprintf("relax.net: Length of rho.vec (%d) must equal number of responses (%d)",
                 length(rho.vec), q))
  }

  # Check anyNA() first: an NA rate would otherwise make the condition NA
  # and `if` would fail with an uninformative error.
  if (anyNA(rho.vec) || any(rho.vec < 0) || any(rho.vec >= 1)) {
    stop("relax.net: Missing rates must be in [0, 1)")
  }

  # Per-response observation probabilities, floored at eps. Off-diagonal
  # entries of rho.mat are products of two responses' probabilities; the
  # diagonal holds the single-response probabilities.
  obs.prob <- pmax(1 - rho.vec, eps)
  rho.mat <- outer(obs.prob, obs.prob)
  diag(rho.mat) <- obs.prob

  # -------------------- Compute Residuals --------------------

  Y.centered <- sweep(Y, 2, est$mu, FUN = "-", check.margin = FALSE)
  E <- Y.centered - X %*% Beta

  # -------------------- Compute Residual Covariance --------------------

  residual.cov <- tryCatch(
    getResCov(E, n, rho.mat),
    error = function(e) {
      warning("relax.net: Failed to compute residual covariance, using simple estimate: ",
              conditionMessage(e))
      # Fallback: pairwise-complete covariance with non-finite entries zeroed
      # and the diagonal floored at eps.
      cov_mat <- cov(E, use = "pairwise.complete.obs")
      cov_mat[!is.finite(cov_mat)] <- 0
      diag(cov_mat) <- pmax(diag(cov_mat), eps)
      cov_mat
    }
  )

  if (!is.matrix(residual.cov) || !all(dim(residual.cov) == c(q, q))) {
    stop("relax.net: Invalid residual covariance matrix")
  }

  # -------------------- Identify Zero Pattern --------------------

  # Threshold for "zero" scales with the largest entry of the previous Theta.
  zero.thr <- max(1e-8, 1e-6 * max(abs(Theta.prev)))

  # Upper-triangle zeros; NA comparisons (e.g. from NA entries) count as
  # non-zero. The mirrored mask covers both (i, j) and (j, i).
  zero.upper <- upper.tri(Theta.prev) & abs(Theta.prev) <= zero.thr
  zero.upper[is.na(zero.upper)] <- FALSE
  n.zero <- sum(zero.upper)
  zero.sym <- zero.upper | t(zero.upper)

  # -------------------- Setup Penalty Matrix --------------------

  BIG <- 1e12
  lamTh.mat <- matrix(1e-12, nrow = q, ncol = q)  # negligible baseline penalty
  lamTh.mat[zero.sym] <- BIG                      # pin identified zeros
  diag(lamTh.mat) <- 0                            # never penalize diagonal

  # Report when more than half of the off-diagonal pairs are pinned to zero.
  if (n.zero > q * (q - 1) / 4) {
    message(sprintf("relax.net: Note - %d/%d off-diagonal elements set to zero",
                    n.zero, q * (q - 1) / 2))
  }

  # -------------------- Run Relaxed Glasso --------------------

  Theta.out <- tryCatch({
    result <- run_glasso(
      S = residual.cov,
      rho = lamTh.mat,
      thr = Theta.thr,
      maxIt = Theta.maxit
    )

    if (is.null(result) || is.null(result$wi) || !dims.ok(result$wi, q, q)) {
      stop("relax.net: Glasso returned invalid result")
    }

    # The solver only penalizes the pinned entries heavily; enforce exact
    # zeros afterwards.
    Theta.new <- result$wi
    Theta.new[zero.sym] <- 0
    Theta.new

  }, error = function(e) {
    warning("relax.net: Glasso failed, using diagonal approximation: ",
            conditionMessage(e))

    # Clamp variances to [eps, 1/eps] before inverting them.
    diag.vals <- pmin(pmax(diag(residual.cov), eps), 1 / eps)

    # nrow = q is required: for q == 1, diag(<length-1 numeric>) would build
    # a floor(x) x floor(x) identity matrix instead of a 1 x 1 matrix.
    Theta.fallback <- diag(1 / diag.vals, nrow = q)

    # Match the median diagonal of the previous estimate, with the factor
    # bounded to [0.1, 10]; skip rescaling if the ratio is undefined.
    scale.factor <- median(diag(Theta.prev)) / median(diag(Theta.fallback))
    if (!is.na(scale.factor)) {
      Theta.fallback <- Theta.fallback * max(0.1, min(10, scale.factor))
    }

    Theta.fallback
  })

  # -------------------- Final Processing --------------------

  Theta.final <- make_symmetric(Theta.out)

  if (!all(is.finite(Theta.final))) {
    warning("relax.net: Non-finite values in Theta, returning regularized diagonal")
    Theta.final <- diag(q)
  }

  # Shift the diagonal so the smallest eigenvalue becomes 2 * eps if it is
  # currently below eps.
  min.eig <- min(eigen(Theta.final, symmetric = TRUE, only.values = TRUE)$values)
  if (min.eig < eps) {
    reg.amount <- eps - min.eig + eps
    Theta.final <- Theta.final + reg.amount * diag(q)
    message(sprintf("relax.net: Added %.2e to diagonal for positive definiteness", reg.amount))
  }

  Theta.final
}

Try the missoNet package in your browser

Any scripts or data that you put into this service are public.

missoNet documentation built on Sept. 9, 2025, 5:55 p.m.