R/InitParams.R

Defines functions InitParams

# =====================================================================
# InitParams.R - Parameter Initialization for missoNet
# =====================================================================

# InitParams: validate the inputs and build the initialization object used
# by missoNet (missingness rates, normalized penalty factors, lambda upper
# bounds, centering/scaling vectors and a few data summaries).
#
# Arguments:
#   X                    n x p predictor matrix; every entry must be finite.
#   Y                    n x q response matrix with q >= 2; NA marks a missing
#                        entry and every observed entry must be finite.
#   rho                  NULL (use the per-column missing rate of Y), a scalar
#                        in [0, 1), or a length-q vector with entries in [0, 1).
#   lamB.pf              NULL or a p x q non-negative penalty-factor matrix.
#   lamTh.pf             NULL or a symmetric q x q non-negative penalty-factor
#                        matrix.
#   pdiag                NULL (decided from n_eff / q and the missing rate) or
#                        a flag; when not TRUE the diagonal of lamTh.pf is 0.
#   standardize          if TRUE, predictors are scaled by robust_sd().
#   standardize.response if TRUE, responses are scaled by robust_sd().
#
# Returns a named list (see `init.obj` at the end of the function).
#
# Penalty-factor convention: entries >= 1e10 are stored as 1e12, entries
# <= 1e-10 as 0, and all other entries are divided by their median.
InitParams <- function(X, Y, rho, lamB.pf, lamTh.pf, pdiag, 
                       standardize, standardize.response) {
  
  # -------------------- Input Validation --------------------
  n <- nrow(X); p <- ncol(X); q <- ncol(Y)
  
  if (any(!is.finite(X))) stop("Predictor matrix 'X' must contain only finite values.")
  if (n < 2) stop("Need at least 2 observations")
  if (p < 1 || q < 1) stop("Need at least 1 predictor and 1 response")
  if (q == 1) {
    stop("missoNet requires multiple responses. For single response, use Lasso/cocoLasso.")
  }
  # Catch misaligned inputs here rather than failing later inside
  # cor()/crossprod() with an unhelpful dimension error.
  if (nrow(Y) != n) {
    stop("'X' and 'Y' must have the same number of rows.")
  }
  if (any(!is.finite(Y[!is.na(Y)]))) {
    stop("Response matrix 'Y' must contain only finite values or NA.")
  }
  
  # -------------------- Data Characteristics --------------------
  
  miss_pattern <- is.na(Y)
  miss_by_col <- colMeans(miss_pattern)
  miss_by_row <- rowMeans(miss_pattern)
  overall_miss <- mean(miss_pattern)
  
  # Effective sample size per response
  n_eff_vec <- n * (1 - miss_by_col)
  n_eff <- mean(n_eff_vec)
  
  # Data dimensionality ratios
  np_ratio <- n_eff / p
  nq_ratio <- n_eff / q
  
  # Rough sparsity estimates: share of marginal |correlations| below 0.1.
  if (n > p && n > q) {
    cor_xy <- abs(cor(X, Y, use = "pairwise.complete.obs"))
    cor_yy <- abs(cor(Y, use = "pairwise.complete.obs"))
    # Only off-diagonal response correlations say anything about the sparsity
    # of Theta. Dropping the diagonal (NA) instead of zeroing it avoids
    # counting all q diagonal entries as "sparse", which inflated the estimate.
    diag(cor_yy) <- NA
    
    estimated_sparsity_B <- mean(cor_xy < 0.1, na.rm = TRUE)
    estimated_sparsity_Th <- mean(cor_yy < 0.1, na.rm = TRUE)
  } else {
    # Conservative estimates for high-dimensional cases
    estimated_sparsity_B <- 0.8
    estimated_sparsity_Th <- 0.7
  }
  # If no correlation could be computed (e.g. no pair of responses is ever
  # co-observed) mean() yields NaN; fall back to the conservative defaults so
  # the lambda_max percentile selection below never indexes with NaN.
  if (!is.finite(estimated_sparsity_B)) estimated_sparsity_B <- 0.8
  if (!is.finite(estimated_sparsity_Th)) estimated_sparsity_Th <- 0.7
  
  # -------------------- Diagonal Penalization --------------------
  
  if (is.null(pdiag)) {
    if (nq_ratio < 1.5) {
      pdiag <- TRUE
      if (n > 10) {
        message("Auto-enabling diagonal penalization (n_eff/q = ", 
                round(nq_ratio, 2), " < 1.5)")
      }
    } else if (nq_ratio < 3 && overall_miss > 0.2) {
      pdiag <- TRUE
      if (n > 10) {
        message("Auto-enabling diagonal penalization due to high missingness")
      }
    } else {
      pdiag <- FALSE
    }
  }
  
  # -------------------- Robust Standardization --------------------
  
  mx <- apply(X, 2, robust_mean)
  my <- apply(Y, 2, robust_mean, na.rm = TRUE)
  
  if (isTRUE(standardize)) {
    sdx <- apply(X, 2, robust_sd)
    # Near-constant columns keep scale 1 to avoid dividing by ~0.
    const_cols <- which(sdx <= .Machine$double.eps * 100)
    if (length(const_cols) > 0) {
      warning(length(const_cols), " near-constant predictor(s) detected")
      sdx[const_cols] <- 1
    }
  } else {
    sdx <- rep(1, p)
  }
  
  if (isTRUE(standardize.response)) {
    sdy <- apply(Y, 2, robust_sd, na.rm = TRUE)
    const_resp <- which(sdy <= .Machine$double.eps * 100)
    if (length(const_resp) > 0) {
      warning(length(const_resp), " near-constant response(s) detected")
      sdy[const_resp] <- 1
    }
  } else {
    sdy <- rep(1, q)
  }
  
  Xs <- robust_scale(X, mx, sdx)
  Ys <- robust_scale(Y, my, sdy)
  # Missing responses contribute 0 to the moment sums computed below.
  Z <- Ys; Z[is.na(Z)] <- 0
  
  # -------------------- Missing Probability Processing --------------------
  
  if (is.null(rho)) {
    rho.vec <- miss_by_col
    
    # Warn about problematic missingness patterns
    if (any(rho.vec > 0.5)) {
      warning("Extreme missingness (>50%) in ", sum(rho.vec > 0.5), 
              " response(s). Results may be unstable.")
    }
    
    # Check for systematic patterns. The mean is tested first: with no
    # missing data the coefficient of variation would be 0/0.
    row_miss_mean <- mean(miss_by_row)
    if (row_miss_mean > 0.1 && sd(miss_by_row) / row_miss_mean > 2) {
      warning("Detected systematic missingness patterns across observations. ",
              "Verify MAR assumption.")
    }
  } else if (length(rho) == 1) {
    # is.na() first so an NA rho gets the documented error, not a
    # "missing value where TRUE/FALSE needed" failure.
    if (is.na(rho) || rho < 0 || rho >= 1) stop("'rho' must be in [0, 1)")
    rho.vec <- rep(rho, q)
  } else if (length(rho) == q) {
    if (anyNA(rho) || any(rho < 0 | rho >= 1)) stop("All 'rho' elements must be in [0, 1)")
    rho.vec <- rho
  } else {
    stop("'rho' must be scalar or length ", q)
  }
  
  obs_prob <- floor_prob(1 - rho.vec, floor_val = max(1e-6, 1/n))
  
  # -------------------- Penalty Factor Processing --------------------
  
  # Validate a user-supplied penalty-factor matrix (or build the default all-ones
  # one), normalize it by the median of its "ordinary" entries, and restore the
  # special values (excluded -> 1e12, unpenalized -> 0). For the symmetric
  # (Theta) case the diagonal is zeroed unless pdiag is TRUE.
  process_penalty_matrix <- function(pf, target_dim, name, symmetric = FALSE) {
    if (is.null(pf)) {
      pf <- matrix(1, target_dim[1], target_dim[2])
      if (symmetric && !isTRUE(pdiag)) diag(pf) <- 0
      return(pf)
    }
    
    if (!is.matrix(pf) || !all(dim(pf) == target_dim)) {
      stop(sprintf("'%s' must be a %dx%d matrix", name, target_dim[1], target_dim[2]))
    }
    if (symmetric && !isSymmetric(pf)) {
      stop(sprintf("'%s' must be symmetric", name))
    }
    if (any(pf < 0, na.rm = TRUE)) {
      stop(sprintf("Negative values in '%s' not allowed", name))
    }
    
    # Special values are kept out of the median used for normalization
    exclude <- (pf >= 1e10)      # Variables to exclude
    no_penalty <- (pf <= 1e-10)  # Variables with no penalty
    
    tmp <- pf
    tmp[exclude | no_penalty] <- NA
    
    if (symmetric) {
      # For symmetric matrices, use lower triangle (each pair counted once).
      # NOTE(review): the diagonal is included here even when it is zeroed
      # below (pdiag not TRUE) - confirm this is intended.
      sel <- lower.tri(tmp, diag = TRUE)
      vals <- tmp[sel]
    } else {
      vals <- as.vector(tmp)
    }
    
    if (sum(!is.na(vals)) > 0) {
      scale_factor <- median(vals, na.rm = TRUE)
    } else {
      scale_factor <- 1
    }
    
    # Normalization
    pf <- pf / scale_factor
    
    # Restore special values
    pf[exclude] <- 1e12
    pf[no_penalty] <- 0
    
    if (symmetric && !isTRUE(pdiag)) diag(pf) <- 0
    
    return(pf)
  }
  
  lamB.pf <- process_penalty_matrix(lamB.pf, c(p, q), "lamB.penalty.factor", FALSE)
  lamTh.pf <- process_penalty_matrix(lamTh.pf, c(q, q), "lamTh.penalty.factor", TRUE)
  
  # -------------------- Lambda Max Computation --------------------
  
  # Pairwise observation probabilities (p x q: column j is obs_prob[j];
  # q x q: product of the two probabilities, obs_prob itself on the diagonal)
  rho.mat.1 <- matrix(obs_prob, nrow = p, ncol = q, byrow = TRUE)
  rho.mat.2 <- outer(obs_prob, obs_prob, `*`); diag(rho.mat.2) <- obs_prob
  
  # Moments corrected for missingness by dividing by observation probabilities
  til.xty <- crossprod(Xs, Z) / rho.mat.1
  
  # Small sample correction factor. Floored at 1: n = 2 passes validation, and
  # n - 2 = 0 would turn the covariance into Inf/NaN.
  ss_correction <- if (n < 50) max(n - 2, 1) else n - 1
  
  # Initial residual covariance
  residual.cov <- make_positive_definite(crossprod(Z) / rho.mat.2 / ss_correction)
  
  # Upper bound for the penalty: a percentile of |S| / W over entries with a
  # usable weight (1e-10 < W < 1e10), off-diagonal only for Theta, then scaled
  # by a dimensionality factor and clamped to [1e-4, 100].
  compute_lambda_max_smart <- function(S, W, type = "theta") {
    valid_idx <- W > 1e-10 & W < 1e10 & is.finite(S) & is.finite(W)
    
    if (type == "theta") {
      # For Theta: off-diagonal only
      valid_idx <- valid_idx & (row(S) != col(S))
    }
    
    if (sum(valid_idx) == 0) return(1.0)
    
    candidates <- abs(S[valid_idx]) / W[valid_idx]
    
    # Use multiple percentiles for robustness
    percentiles <- c(0.95, 0.99, 1.0)
    lambda_candidates <- quantile(candidates, percentiles, na.rm = TRUE)
    
    # Adaptive selection based on sparsity: sparsity >= 0.5 picks the 99th
    # percentile, otherwise the 95th. The index is capped at 2, so the
    # maximum (100th percentile) is currently never selected.
    if (type == "theta") {
      # Higher lambda for sparser problems
      lambda_max <- lambda_candidates[min(2, 1 + floor(2 * estimated_sparsity_Th))]
    } else {
      lambda_max <- lambda_candidates[min(2, 1 + floor(2 * estimated_sparsity_B))]
    }
    
    # Scale based on dimensionality
    if (type == "theta") {
      scale_factor <- (1 + log(max(1, q / n_eff))) * 1.2
    } else {
      scale_factor <- 1 + log(max(1, p * q / n_eff)) / 2
    }
    
    lambda_max <- lambda_max * scale_factor
    
    # Bounds for stability
    lambda_max <- max(lambda_max, 1e-4)
    lambda_max <- min(lambda_max, 100)
    
    return(as.numeric(lambda_max))
  }
  
  # Compute lambda max values
  lamTh.max <- compute_lambda_max_smart(residual.cov, lamTh.pf, type = "theta")
  lamB.max <- compute_lambda_max_smart(2 * til.xty / n, lamB.pf, type = "beta")
  
  # -------------------- Warm Start Strategy --------------------
  
  # Intelligent warm start decision
  warm.start <- TRUE  # Generally beneficial
  
  # Adaptive warm start strategy
  # if (p * q > 10000) {
  #   warm.start <- "aggressive"  # More aggressive for large problems
  # } else if (n < 20 || p * q < 50) {
  #   warm.start <- "conservative"  # Less aggressive for small problems
  # }
  
  # -------------------- Return Enhanced Init Object --------------------
  
  init.obj <- list(
    # Missingness information
    rho.vec = rho.vec,
    obs_prob = obs_prob,
    
    # Penalty factors
    lamB.pf = lamB.pf,
    lamTh.pf = lamTh.pf,
    
    # Lambda bounds
    lamB.max = lamB.max,
    lamTh.max = lamTh.max,
    
    # Standardization
    sdx = sdx, sdy = sdy,
    mx = mx, my = my,
    
    # Data characteristics
    n_eff = n_eff,
    n_eff_vec = n_eff_vec,
    np_ratio = np_ratio,
    nq_ratio = nq_ratio,
    overall_miss = overall_miss,
    estimated_sparsity_B = estimated_sparsity_B,
    estimated_sparsity_Th = estimated_sparsity_Th,
    
    # Warm start strategy
    warm.start = warm.start,
    
    # Adaptive flags
    penalize_diagonal = pdiag,
    high_dimensional = (np_ratio < 1 || nq_ratio < 1)
  )
  
  return(init.obj)
}

Try the missoNet package in your browser

Any scripts or data that you put into this service are public.

missoNet documentation built on Sept. 9, 2025, 5:55 p.m.