# R/reduced_rank_regression.R
#
# Defines functions: cca_rrr_cv_folds, cca_rrr_cv, cca_rrr, solve_rrr_cvxr, solve_rrr_admm
#
# Documented in: cca_rrr, cca_rrr_cv

# Required libraries
library(dplyr)
library(magrittr)
library(foreach)



# Helper: ADMM-based group sparse solver
#
# Iterates three updates on p x q matrices B (coefficients), Z (row-sparse
# copy) and U (scaled dual variable):
#   B-step: B <- (Sx + rho * I)^{-1} (X' tilde_Y / n + rho * (Z - U))
#   Z-step: row-wise soft-thresholding of B + U with threshold lambda / rho
#   U-step: U <- U + B - Z
#
# Arguments:
#   X         n x p predictor matrix.
#   tilde_Y   n x q response matrix.
#   Sx        p x p matrix used in the B-step (callers pass the X covariance).
#   lambda    Penalty level of the row-wise shrinkage.
#   rho       ADMM step parameter.
#   niter     Maximum number of iterations.
#   thresh    Stop once both residuals, divided by sqrt(p), fall below this.
#   verbose   If TRUE, print the residuals at every iteration.
#   thresh_0  Entries of the result with absolute value below this are zeroed.
#
# Returns the p x q matrix B from the last iteration, after zeroing entries
# smaller than thresh_0 (an all-zero matrix if niter is 0).
solve_rrr_admm <- function(X, tilde_Y, Sx, lambda, rho=1, niter=10, thresh, verbose = FALSE, thresh_0 = 1e-6) {
  p <- ncol(X)
  q <- ncol(tilde_Y)
  n <- nrow(X)

  # Loop invariants: the cross-product and the inverse used by every B-step
  # are computed exactly once.
  prod_xy <- crossprod(X, tilde_Y) / n
  invSx <- solve(Sx + rho * diag(p))

  # B starts at zero as well, so a zero-iteration run still returns a matrix.
  B <- U <- Z <- matrix(0, p, q)

  for (i in seq_len(niter)) {
    Z_old <- Z
    B <- invSx %*% (prod_xy + rho * (Z - U))

    # Row-wise group soft-thresholding of B + U.
    Z <- B + U
    norm_col <- sqrt(rowSums(Z^2))
    shrinkage <- pmax(0, 1 - (lambda / rho) / norm_col)
    # 0 / 0 (zero row with lambda = 0) yields NaN; such rows stay at zero.
    shrinkage[is.nan(shrinkage)] <- 0
    Z <- sweep(Z, 1, shrinkage, "*")

    U <- U + B - Z

    # norm() defaults to the one norm (maximum absolute column sum).
    primal <- norm(Z - B)
    dual <- norm(Z_old - Z)
    if (verbose) cat("ADMM iter", i, "Primal: ", primal, "Dual:", dual, "\n")
    if (max(primal, dual) / sqrt(p) < thresh) break
  }

  B_opt <- B
  B_opt[abs(B_opt) < thresh_0] <- 0
  B_opt
}


# Helper: CVXR-based group sparse solver
#
# Builds and solves, with CVXR, the convex problem
#   minimise over B:  (1 / n) * sum_squares(tilde_Y - X B)
#                     + lambda * sum(norm2(B, axis = 1))
# where B is a p x q variable (p = ncol(X), q = ncol(tilde_Y)).
#
# Stops with an error if the CVXR package is not installed. Returns the
# solver's value for B with entries below thresh_0 in absolute value set to 0.
solve_rrr_cvxr <- function(X, tilde_Y, lambda, thresh_0=1e-6) {
  cvxr_available <- requireNamespace("CVXR", quietly = TRUE)
  if (!cvxr_available) {
    stop("Package 'CVXR' must be installed to use the CVX solver.",
         call. = FALSE)
  }

  n_obs <- nrow(X)
  coef_var <- CVXR::Variable(ncol(X), ncol(tilde_Y))

  # Objective = data-fit term + group penalty.
  fit_term <- 1 / n_obs * CVXR::sum_squares(tilde_Y - X %*% coef_var)
  group_penalty <- lambda * sum(CVXR::norm2(coef_var, axis = 1))
  problem <- CVXR::Problem(CVXR::Minimize(fit_term + group_penalty))
  solution <- CVXR::solve(problem)

  coef_hat <- solution$getValue(coef_var)
  coef_hat[abs(coef_hat) < thresh_0] <- 0
  coef_hat
}


#' Canonical Correlation Analysis via Reduced Rank Regression (RRR)
#'
#' Estimates canonical directions using various RRR solvers and penalties.
#'
#' @param X Matrix of predictors (n x p).
#' @param Y Matrix of responses (n x q).
#' @param Sx Optional X covariance matrix. If `NULL`, `crossprod(X) / n` is used.
#' @param Sy Optional Y covariance matrix. If `NULL`, it is computed from `Y`
#'   (see `LW_Sy`).
#' @param lambda Regularization parameter.
#' @param r Rank of the solution.
#' @param highdim Boolean for high-dimensional regime. If `FALSE`, an
#'   unpenalised least-squares fit is used and `solver`/`lambda` are ignored.
#' @param solver Solver type: "CVX" (alias "CVXR") or "ADMM"; any other value
#'   (e.g. "rrr") uses `rrpack::cv.srrr`.
#' @param LW_Sy Whether to use Ledoit-Wolf shrinkage for Sy
#'   (`corpcor::cov.shrink`); otherwise `crossprod(Y) / n` is used.
#' @param standardize Logical; should X and Y be scaled.
#' @param rho ADMM parameter.
#' @param niter Maximum number of iterations for ADMM.
#' @param thresh Convergence threshold for ADMM.
#' @param thresh_0 Entries of the estimated coefficient matrix whose absolute
#'   value is below this are set to 0 (default 1e-6).
#' @param verbose Logical for verbose output.
#'
#' @return A list with elements:
#' \itemize{
#'   \item U: Canonical direction matrix for X (p x r)
#'   \item V: Canonical direction matrix for Y (q x r)
#'   \item loss: Mean squared entry of `Y V - X U`
#'   \item cor: Covariances between `X U[, i]` and `Y V[, i]`, i = 1..r
#' }
#' @export
cca_rrr <- function(X, Y, Sx=NULL, Sy=NULL,
                    lambda = 0,
                    r, highdim=TRUE,
                    solver="ADMM",
                    LW_Sy = TRUE,
                    standardize = TRUE,
                    rho=1,
                    niter=1e4,
                    thresh = 1e-4, thresh_0 = 1e-6,
                    verbose=FALSE) {

  n <- nrow(X)
  p <- ncol(X)
  q <- ncol(Y)

  if (standardize) {
    X <- scale(X)
    Y <- scale(Y)
  }

  if (is.null(Sx)) Sx <- crossprod(X) / n
  if (is.null(Sy)) {
    Sy <- if (LW_Sy) {
      as.matrix(corpcor::cov.shrink(Y, verbose = verbose))
    } else {
      crossprod(Y) / n
    }
  }

  # Whiten the responses with Sy^{-1/2} before regressing them on X.
  sqrt_inv_Sy <- compute_sqrt_inv(Sy)
  tilde_Y <- Y %*% sqrt_inv_Sy
  top <- seq_len(r)

  if (!highdim) {
    if (verbose) print("Not using highdim")
    Sxy <- crossprod(X, tilde_Y) / n
    B_OLS <- solve(Sx, Sxy)
    sqrt_Sx <- compute_sqrt(Sx)
    sqrt_inv_Sx <- compute_sqrt_inv(Sx)
    sol <- svd(sqrt_Sx %*% B_OLS)
    # drop = FALSE keeps p x 1 / q x 1 matrices when r = 1.
    V <- sqrt_inv_Sy %*% sol$v[, top, drop = FALSE]
    U <- sqrt_inv_Sx %*% sol$u[, top, drop = FALSE]
  } else {
    # "CVXR" is accepted as an alias of "CVX", matching cca_rrr_cv().
    if (solver %in% c("CVX", "CVXR")) {
      if (verbose) print("Using CVX solver")
      B_opt <- solve_rrr_cvxr(X, tilde_Y, lambda, thresh_0 = thresh_0)
    } else if (solver == "ADMM") {
      if (verbose) print("Using ADMM solver")
      # The per-iteration ADMM trace is deliberately kept off.
      B_opt <- solve_rrr_admm(X, tilde_Y, Sx, lambda = lambda, rho = rho,
                              niter = niter, thresh = thresh,
                              thresh_0 = thresh_0, verbose = FALSE)
    } else {
      if (verbose) print("Using gglasso solver")
      if (!requireNamespace("rrpack", quietly = TRUE)) {
        stop("Package 'rrpack' must be installed to use the rrpack solver.",
             call. = FALSE)
      }
      fit <- rrpack::cv.srrr(tilde_Y, X, nrank = r, method = "glasso", nfold = 2,
                             modstr = list("lamA" = rep(lambda, 10), "nlam" = 10))
      B_opt <- fit$coef
    }

    B_opt[abs(B_opt) < thresh_0] <- 0
    active_rows <- which(rowSums(B_opt^2) > 0)
    if (length(active_rows) >= r) {
      # drop = FALSE keeps matrix shape when a single row is active.
      sqrt_Sx <- compute_sqrt(Sx[active_rows, active_rows, drop = FALSE])
      sol <- svd(sqrt_Sx %*% B_opt[active_rows, , drop = FALSE])
      v_top <- sol$v[, top, drop = FALSE]
      d_top <- sol$d[top]
      # nrow = r is required: diag() on a length-one vector would otherwise
      # build an identity matrix of that size instead of a 1 x 1 diagonal.
      inv_D <- diag(ifelse(d_top < 1e-4, 0, 1 / d_top), nrow = r)
      V <- sqrt_inv_Sy %*% v_top
      U <- B_opt %*% v_top %*% inv_D
    } else {
      # Fewer active predictors than requested components: return zeros.
      U <- matrix(0, p, r)
      V <- matrix(0, q, r)
    }
  }

  loss <- mean((Y %*% V - X %*% U)^2)
  canon_corr <- vapply(
    top,
    function(i) as.numeric(stats::cov(X %*% U[, i], Y %*% V[, i])),
    numeric(1)
  )

  list(U = U, V = V, loss = loss, cor = canon_corr)
}




#' Cross-validated Canonical Correlation Analysis via RRR
#'
#' Performs cross-validation over `lambdas` to select the regularisation
#' parameter, then fits [cca_rrr()] on the full data with the selected value.
#'
#' @param X Matrix of predictors.
#' @param Y Matrix of responses.
#' @param r Rank of the solution.
#' @param kfolds Number of folds for cross-validation.
#' @param lambdas Sequence of lambda values for cross-validation.
#' @param parallelize Logical; should cross-validation be parallelized? Only
#'   honoured for the "CVX", "CVXR" and "ADMM" solvers.
#' @param solver Solver type: "rrr", "CVX", or "ADMM".
#' @param LW_Sy Whether to use Ledoit-Wolf shrinkage for Sy.
#' @param standardize Logical; should X and Y be scaled.
#' @param rho ADMM parameter.
#' @param niter Maximum number of iterations for ADMM.
#' @param thresh Convergence threshold.
#' @param verbose Logical for verbose output.
#' @param thresh_0 tolerance for declaring entries non-zero
#' @param nb_cores Number of cores to use for parallelization (default is all available cores minus 1)
#'
#' @return A list with elements:
#' \itemize{
#'   \item U: Canonical direction matrix for X (p x r)
#'   \item V: Canonical direction matrix for Y (q x r)
#'   \item lambda: Optimal regularisation parameter lambda chosen by CV
#'     (0.1 if no candidate is retained)
#'   \item rmse: Cross-validated mean squared prediction errors of the
#'     retained lambda values
#'   \item cor: Covariances between `X U[, i]` and `Y V[, i]`, i = 1..r
#' }
#' @importFrom foreach foreach %dopar% %do%
#' @export
cca_rrr_cv <- function(X, Y,
                       r=2,
                       lambdas=10^seq(-3, 1.5, length.out = 100),
                       kfolds=14,
                       solver="ADMM",
                       parallelize = FALSE,
                       LW_Sy = TRUE,
                       standardize=TRUE,
                       rho=1,
                       thresh_0=1e-6,
                       niter=1e4,
                       thresh = 1e-4, verbose=FALSE,
                       nb_cores = NULL) {

  if (standardize) {
    X <- scale(X)
    Y <- scale(Y)
  }
  n <- nrow(X)
  # Full-data X covariance; each fold derives its training covariance from it.
  Sx <- matmul(t(X), X) / n

  # Cross-validated error for one lambda. Data are already standardised above,
  # so the folds are fitted with standardize = FALSE.
  cv_function <- function(lambda) {
    cca_rrr_cv_folds(X, Y, Sx = Sx, Sy = NULL, kfolds = kfolds,
                     LW_Sy = LW_Sy,
                     lambda = lambda, r = r, solver = solver,
                     standardize = FALSE, rho = rho, niter = niter,
                     thresh = thresh, thresh_0 = thresh_0)
  }

  use_parallel <- parallelize && solver %in% c("CVX", "CVXR", "ADMM")

  if (use_parallel) {
    if (!requireNamespace("doParallel", quietly = TRUE)) {
      stop("Package 'doParallel' must be installed to use the parallelization option.",
           call. = FALSE)
    }
    if (!requireNamespace("crayon", quietly = TRUE)) {
      stop("Package 'crayon' must be installed to use the parallelization option.",
           call. = FALSE)
    }

    # --- GRACEFUL PARALLEL SETUP ---
    cl <- setup_parallel_backend(nb_cores)

    if (!is.null(cl)) {
      # Cluster created: register it and make sure it is stopped on exit.
      if (verbose) {
        cat(crayon::green("Parallel backend successfully registered.\n"))
      }
      doParallel::registerDoParallel(cl)
      on.exit(parallel::stopCluster(cl), add = TRUE)
    } else {
      # No cluster available: warn and fall back to the serial path.
      warning("All parallel setup attempts failed. Proceeding in serial mode.", immediate. = TRUE)
      use_parallel <- FALSE
    }
  }

  if (use_parallel) {
    if (solver %in% c("CVX", "CVXR")) {
      if (!requireNamespace("CVXR", quietly = TRUE)) {
        stop("Package 'CVXR' must be installed to use the CVXR/CVX solver.",
             call. = FALSE)
      }
      resultsx <- foreach(lambda=lambdas, .combine=rbind, .packages=c('CVXR','Matrix')) %dopar% {
        data.frame(lambda=lambda, rmse=cv_function(lambda))
      }
    } else {
      resultsx <- foreach(lambda=lambdas, .combine=rbind) %dopar% {
        data.frame(lambda=lambda, rmse=cv_function(lambda))
      }
    }
  } else {
    resultsx <- data.frame(lambda = lambdas)
    resultsx$rmse <- vapply(lambdas, cv_function, numeric(1))
  }

  # Missing or exactly-zero errors are replaced by a large sentinel value;
  # near-zero errors are discarded.
  resultsx <- resultsx %>%
    dplyr::mutate(rmse = ifelse(is.na(rmse) | rmse == 0, 1e8, rmse)) %>%
    dplyr::filter(rmse > 1e-5)

  opt_lambda <- resultsx$lambda[which.min(resultsx$rmse)]
  # which.min() yields integer(0) when no candidate survives the filter, and
  # ifelse() on a zero-length value would stay zero-length, so test the length.
  if (length(opt_lambda) == 0 || is.na(opt_lambda)) opt_lambda <- 0.1

  final <- cca_rrr(X, Y, Sx = NULL, Sy = NULL, lambda = opt_lambda, r = r,
                   highdim = TRUE, solver = solver,
                   standardize = FALSE, LW_Sy = LW_Sy, rho = rho, niter = niter,
                   thresh = thresh, thresh_0 = thresh_0, verbose = verbose)

  list(
    U = final$U,
    V = final$V,
    lambda = opt_lambda,
    rmse = resultsx$rmse,
    cor = vapply(
      seq_len(r),
      function(i) as.numeric(stats::cov(X %*% final$U[, i], Y %*% final$V[, i])),
      numeric(1)
    )
  )
}




# Internal helper: K-fold cross-validated prediction error of cca_rrr() for a
# single value of lambda.
#
# Rows are split with caret::createFolds(). For every fold, cca_rrr() is
# fitted on the remaining rows (highdim = TRUE, standardize = FALSE) and the
# mean squared entry of X_val U - Y_val V is recorded for the held-out rows.
#
# Arguments:
#   X, Y      Data matrices with the same number of rows.
#   Sx        Full-data X covariance, taken to be crossprod(X) / nrow(X), or
#             NULL. When given, each training covariance is obtained by
#             removing the held-out rows' contribution rather than recomputed.
#   Sy, standardize
#             Accepted for interface compatibility; not used here.
#   kfolds    Number of folds.
#   lambda, r, solver, rho, LW_Sy, niter, thresh_0, thresh
#             Passed on to cca_rrr().
#
# Returns the mean of the per-fold errors, ignoring folds whose fit failed
# (a message is emitted for each), or 1e8 if every fold failed.
cca_rrr_cv_folds <- function(X, Y, Sx, Sy, kfolds=5,
                             lambda=0.01,
                             r=2,
                             standardize=FALSE,
                             solver = "ADMM",
                             rho=1,
                             LW_Sy = TRUE,
                             niter=1e4,
                             thresh_0=1e-6,
                             thresh = 1e-4) {
  folds <- caret::createFolds(seq_len(nrow(Y)), k = kfolds, list = TRUE)
  n <- nrow(X)

  fold_error <- function(i) {
    held_out <- folds[[i]]
    # drop = FALSE keeps matrices even for a single-row fold or single column.
    X_train <- X[-held_out, , drop = FALSE]
    Y_train <- Y[-held_out, , drop = FALSE]
    X_val <- X[held_out, , drop = FALSE]
    Y_val <- Y[held_out, , drop = FALSE]
    n_train <- n - nrow(X_val)

    Sx_train <- if (is.null(Sx)) {
      NULL
    } else {
      # Downdate: subtract the held-out rows' cross-product from n * Sx.
      (n * Sx - crossprod(X_val)) / n_train
    }

    tryCatch({
      final <- cca_rrr(X_train, Y_train, Sx = Sx_train, Sy = NULL, highdim = TRUE,
                       lambda = lambda, r = r, solver = solver,
                       LW_Sy = LW_Sy, standardize = FALSE, rho = rho, niter = niter,
                       thresh = thresh, thresh_0 = thresh_0,
                       verbose = FALSE)
      mean((X_val %*% final$U - Y_val %*% final$V)^2)
    }, error = function(e) {
      message("Error in fold ", i, ": ", conditionMessage(e))
      NA_real_
    })
  }

  rmse <- vapply(seq_along(folds), fold_error, numeric(1))

  if (all(is.na(rmse))) return(1e8)
  mean(rmse, na.rm = TRUE)
}

# Try the ccar3 package in your browser.
#
# Any scripts or data that you put into this service are public.
#
# ccar3 documentation built on Sept. 16, 2025, 9:11 a.m.