R/SZVD_ADMM.R

Defines functions vec_shrink SZVD_ADMM

Documented in SZVD_ADMM vec_shrink

#' Alternating Direction Method of Multipliers for SZVD
#'
#' Iteratively solves the problem
#' \deqn{\min(-1/2*x^TB^Tx + \gamma p(y): ||x||_2 \leq 1, DNx = y)}
#'
#' @param B Between class covariance matrix for objective (in space defined by N).
#' @param N basis matrix for null space of covariance matrix W.
#' @param D penalty dictionary/basis.
#' @param sols0 initial solutions sols0$x, sols0$y, sols0$z
#' @param pen_scal penalty scaling term.
#' @param gamma l1 regularization parameter
#' @param beta penalty term controlling the splitting constraint.
#' @param tol tol$abs = absolute error, tol$rel = relative error to be
#'        achieved to declare convergence of the algorithm.
#' @param maxits maximum number of iterations of the algorithm to run.
#' @param quiet toggles between displaying intermediate statistics.
#' @return \code{SZVD_ADMM} returns an object of \code{\link{class}} "\code{SZVD_ADMM}" including a list
#' with the following named components
#'
#' \describe{
#'   \item{\code{x,y,z}}{Iterates at termination.}
#'   \item{\code{its}}{Number of iterations required to converge.}
#'   \item{\code{errtol}}{Stopping error bound at termination}
#' }
#' @seealso Used by: \code{\link{SZVDcv}}.
#' @details
#' This function is used by other functions and should only be called explicitly for
#' debugging purposes.
#' @keywords internal
SZVD_ADMM <- function(B,  N, D, sols0, pen_scal, gamma, beta, tol, maxits, quiet = TRUE){
  ######################################################################################
  # Precomputes quantities that will be used repeatedly by the algorithm.
  ######################################################################################

  # Dimension of decision variables.
  p = dim(D)[1];

  # Compute D*N.
  DN = D%*%N

  ##################################################################
  # Initialize x solution and constants for x update.
  ##################################################################
  if (dim(B)[2] == 1){ ## K = 2 case.

    # Compute (DN)'*(mu1-mu2)
    w = t(DN) %*% B

    # Compute initial x if omitted.
    if (missing(sols0)){
      x = w/norm(w,'f')
    } else{
      x = sols0$x
    }

    # constant for x update step.
    Xup = beta - t(w)%*%w
    Xup = Xup[1,1]


  }  else{    ## K > 2 Case.

    # Dimension of the null-space of W.
    l= dim(N)[2]

    # Compute initial x if omitted.
    if (missing(sols0)){
      x = eigen((B+t(B))/2)$vectors[,1]
    }
    else{
      x = sols0$x
    }

    # Take Cholesky of beta I - B (for use in update of x)
    L = t(chol((beta*diag(l) - B)));

    # Presolve the x-update system.
    #Xup = solve(beta*diag(l) - B) %*% t(DN)
  }

  ##############################################################################################
  # Initialize decision variables y and z.
  ##############################################################################################

  # If omitted, set initial y equal to unpenalized ZVD, and z = 0 to start.
  if (missing(sols0)){
    y = DN %*%x
    z = as.matrix(rep(0,p))
  }
  else{
    y = sols0$y
    z = sols0$z
  }


  ##############################################################################################
  ## Call the algorithm.
  ##############################################################################################

  for (iter in 1:maxits){


    ##############################################################################################
    # Step 1: Perform shrinkage to update y_k+1.
    ##############################################################################################

    # Save previous iterate.
    yold = y

    # Call soft-thresholding.
    y = vec_shrink(beta*DN %*% x + z, gamma * pen_scal)

    # Normalize y (if necessary).
    tmp = max(c(0, norm(as.matrix(y), 'f') - beta ))
    y = y/(beta + tmp);

    # Truncate complex part (if have +0i terms appearing.)
    y = Re(y)



    ##############################################################################################
    # Step 2: Update x_k+1 by solving
    # x_k+1 = argmin { -x'*A*x + beta/2 l2(x - y_k+1 + z_k)^2}
    ##############################################################################################

    # Compute RHS.
    b = t(DN)%*%(beta*y - z)

    if (dim(B)[2]==1){    ## K=2

      # Update using Sherman-Morrison formula.
      x = (b + ((t(b)%*%w)[1,1]/Xup)*w)/beta

    }
    else{  ## K > 2.
      # Update using by solving the system LL'x = b.
      btmp = forwardsolve(L, b)
      x = backsolve(t(L), btmp)

      #x = Xup %*% (beta*y - z)
    }

    # Truncate complex part (if have +0i terms appearing.)
    x = Re(x)

    ##############################################################################################
    #  Step 3: Update the Lagrange multipliers
    # (according to the formula z_k+1 = z_k + beta*(N*x_k+1 - y_k+1) ).
    ##############################################################################################
    zold = z
    z = Re(z + beta*(DN%*%x - y))



    ##############################################################################################
    ### Check stopping criteria.
    ##############################################################################################


    ### Primal constraint violation.

    # Primal residual.
    r = DN%*%x-y;

    # l2 norm of the residual.
    dr = norm(as.matrix(r),'f');

    ### Dual constraint violation.

    # Dual residual.
    s = beta*(y - yold)

    # l2 norm of the residual.
    ds = norm(as.matrix(s), 'f')

    ###  Check if the stopping criteria are satisfied.

    # Compute absolute and relative tolerances.
    ep = sqrt(p)*tol$abs + tol$rel*max(c(norm(as.matrix(x), 'f'), norm(as.matrix(y), 'f') ))
    es = sqrt(p)*tol$abs + tol$rel*norm(as.matrix(y), 'f')

    # Display current iteration stats.
    if (quiet==FALSE){
      print(sprintf("it = %g, dr = %e, dr-ep = %e, ds = %e, ds-es = %e, normy = %e", iter, dr, dr-ep, ds, ds-es, norm(as.matrix(y), 'F')), quote=F)
    }

    # Check if the residual norms are less than the given tolerance.
    if (max(c(dr - ep, ds - es, 0), na.rm=TRUE)==0 & iter > 10){
      break # The algorithm has converged.
    }

  }


  ##############################################################################################
  # Output results.
  ##############################################################################################
  results = list(x=x, y=y, its = iter, z=z, errtol = min(c(ep,es)))
  return(results)
}

#' Softmax for SZVD ADMM iterations
#'
#' Applies softmax the soft thresholding shrinkage operator to v with tolerance a.
#' That is, output is the vector with entries with absolute value v_i - a if
#' |v_i| > a and zero otherwise, with sign pattern matching that of v.
#'
#' @param v Vector to be thresholded.
#' @param a Vector of tolerances.
#' @return thresholded v vector.
#'
#' \describe{
#'   \item{\code{x,y,z}}{Iterates at termination.}
#'   \item{\code{its}}{Number of iterations required to converge.}
#'   \item{\code{errtol}}{Stopping error bound at termination}
#' }
#' @seealso Used by: \code{\link{SZVD_ADMM}}.
#' @details
#' This function is used by other functions and should only be called explicitly for
#' debugging purposes.
#' @keywords internal
vec_shrink <- function(v,a){
  s = sign(v)*pmax(abs(v)-a, rep(0, length(v)))
  return(s)
}

Try the accSDA package in your browser

Any scripts or data that you put into this service are public.

accSDA documentation built on Sept. 5, 2022, 5:05 p.m.