# R/get_B_ADMM.R
#
# Defines functions: get_B_ADMM
#
# Documented in: get_B_ADMM

#' The main function of the mSIM package
#'
#' Estimates the coefficient matrix \code{B} with an ADMM-style iteration.
#' Each pass updates three auxiliary copies of \code{B} -- \code{A}
#' (columns rescaled to unit Euclidean norm, with the sign chosen so that the
#' entry in the unpenalized row is non-negative), \code{C} (element-wise or
#' row-wise soft-thresholded) and \code{D} (output of \code{rank.norm}) --
#' then refits every column of \code{B} and updates the multipliers.
#'
#' One covariate row is exempt from the sparsity penalty. It is selected once,
#' before iterating, according to \code{select.method}.
#'
#' @import mgcv
#' @import MASS
#' @import glmnet
#' @import Matrix
#' @import splines
#' @import lbfgs
#'
#' @export get_B_ADMM
#'
#' @param Y The index (response) matrix, one column per response.
#' @param X The covariate matrix; must have the same number of rows as \code{Y}.
#' @param B The initial coefficient matrix (\code{ncol(X)} by \code{ncol(Y)}).
#'   When \code{NULL} it is initialised with \code{get.B.ridge(Y, X, lambda)}.
#' @param lambda The regularization coefficient (soft-threshold level is
#'   \code{lambda / alpha}).
#' @param rank The rank of the B matrix, passed to \code{rank.norm}.
#' @param alpha The step used in the optimization process: it scales the
#'   multiplier updates and the thresholds.
#' @param control1 A list of parameters for minimizing the loss function:
#'   \code{max.iter} (maximum number of iterations) and \code{tol} (both the
#'   primal and the dual error must fall to or below it to declare
#'   convergence). Missing entries fall back to the defaults.
#' @param control2 A list of parameters for the regularization: logical flags
#'   \code{ele.sparse}, \code{row.sparse} and \code{low.rank}. At most one of
#'   the two sparse penalties may be switched on. Missing entries fall back to
#'   the defaults.
#' @param select.method Selection method for the unpenalized covariate:
#'   \code{'linear'} picks the covariate with the largest sum of squared
#'   correlations with the columns of \code{Y}; \code{'nonlinear'} picks the one
#'   with the smallest summed residual-to-total sum-of-squares ratio of
#'   \code{pspline.gam} fits.
#' @param descent.method Optimization method used to update each column of
#'   \code{B}: \code{'bfgs'} or \code{'gd'}.
#' @param plot Whether or not to plot the primal and dual errors during the
#'   iteration. The previous graphical parameters are restored on exit.
#' @return A list that contains information such as the final B matrix and the
#'   training error: \code{B.final} (\code{col.norm(rank.norm(C, rank))}),
#'   \code{B.sparse} (\code{C}), \code{B.proj} (\code{A}), \code{B.lowrank}
#'   (\code{D}), \code{converge} (1 if the tolerance was reached, 0 otherwise),
#'   \code{iteration} (number of iterations actually performed),
#'   \code{pri.err} and \code{dual.err} (error histories) and
#'   \code{noPenIndex} (index of the unpenalized covariate).
#' @examples
#' X <- scale(X_train)
#' Y <- scale(Y_train)
#' test = get_B_ADMM(Y=Y, X=X, lambda=0.2, rank=3, alpha=1, control1 = list(max.iter=5e1, tol=1e-5),
#' control2=list(ele.sparse=FALSE, row.sparse=TRUE, low.rank=TRUE), select.method='linear', plot=FALSE)


get_B_ADMM <- function(Y, X, B = NULL, lambda, rank, alpha=1, control1=list(max.iter=1e2, tol=1e-2), control2=list(ele.sparse=FALSE, row.sparse=TRUE, low.rank=TRUE), select.method='linear', descent.method='bfgs', plot=FALSE){
  # Complete partially specified control lists so that a missing entry does not
  # surface later as "argument is of length zero".
  fill_defaults <- function(user, defaults) {
    absent <- setdiff(names(defaults), names(user))
    user[absent] <- defaults[absent]
    user
  }
  control1 <- fill_defaults(control1, list(max.iter = 1e2, tol = 1e-2))
  control2 <- fill_defaults(
    control2,
    list(ele.sparse = FALSE, row.sparse = TRUE, low.rank = TRUE)
  )

  # Fail early on an unknown method instead of hitting an undefined `fixed`
  # or `b` further down.
  select.method <- match.arg(select.method, c("linear", "nonlinear"))
  descent.method <- match.arg(descent.method, c("bfgs", "gd"))

  if (control2$ele.sparse && control2$row.sparse)
    stop('Only one sparse penalty is supported now')
  if (is.null(dim(Y)) || is.null(dim(X)))
    stop("'Y' and 'X' must be two-dimensional (matrix-like) objects")
  if (nrow(X) != nrow(Y))
    stop("'X' and 'Y' must have the same number of rows")

  n <- nrow(Y)
  q <- ncol(Y)
  p <- ncol(X)
  use.sparse <- control2$ele.sparse || control2$row.sparse

  if (is.null(B)) {
    B <- get.B.ridge(Y, X, lambda = lambda)
  }

  # Choose the covariate whose row of B is never penalized.
  if (select.method == "linear") {
    # Sum over responses of the squared correlation with each covariate.
    ratio <- unname(rowSums(cor(X, Y)^2))
    fixed <- which.max(ratio)
  } else {
    # Sum over responses of RSS / SST for a univariate smooth fit.
    sst <- vapply(
      seq_len(q),
      function(k) (n - 1) * var(Y[, k]),
      numeric(1)
    )
    ratio <- vapply(
      seq_len(p),
      function(i) {
        acc <- 0
        for (k in seq_len(q)) {
          nonlinear.fit <- pspline.gam(Y[, k], X[, i])
          acc <- acc + sum((Y[, k] - nonlinear.fit$fitted.values)^2) / sst[k]
        }
        acc
      },
      numeric(1)
    )
    fixed <- which.min(ratio)
  }

  C <- B
  A <- B
  D <- B
  W1 <- matrix(0, p, q)
  W2 <- matrix(0, p, q)
  W3 <- matrix(0, p, q)
  converge <- 0
  n.iter <- 1
  # Histories are grown by appending: max.iter may be very large (or Inf), so
  # preallocating to max.iter is not safe, and the number of passes is small.
  pri.err.save <- c()
  dual.err.save <- c()

  if (plot) {
    # Change the layout once and restore the caller's settings on exit.
    # `plot` the argument shadows the function name, but a call still resolves
    # to the plotting function because R skips non-function bindings.
    old.par <- par(mfrow = c(2, 1))
    on.exit(par(old.par), add = TRUE)
  }

  while (n.iter <= control1$max.iter) {
    B.prev <- B

    # update A: unit-norm columns, sign fixed by the unpenalized entry
    for (j in seq_len(q)) {
      eta <- B[, j] - W1[, j] / alpha
      eta.norm <- norm(eta, "2")
      if (!is.finite(eta.norm) || eta.norm == 0) {
        stop("column ", j, " of B - W1 / alpha has zero or non-finite norm ",
             "and cannot be normalized")
      }
      eta <- eta / eta.norm
      if (eta[fixed] != 0) {
        eta <- sign(eta[fixed]) * eta
      }
      A[, j] <- eta
    }

    # update C: soft-thresholding, leaving row `fixed` untouched
    newY <- B - W2 / alpha
    if (control2$ele.sparse) {
      C <- sign(newY) * pmax(abs(newY) - lambda / alpha, 0)
      C[fixed, ] <- newY[fixed, ]
    }
    if (control2$row.sparse) {
      C[fixed, ] <- newY[fixed, ]
      for (k in setdiff(seq_len(p), fixed)) {
        row.norm <- norm(newY[k, ], "2")
        if (isTRUE(row.norm == 0)) {
          # An all-zero row stays zero; avoids 0 / 0 = NaN when lambda is 0.
          C[k, ] <- 0
        } else {
          C[k, ] <- max(0, 1 - lambda / alpha / row.norm) * newY[k, ]
        }
      }
    }

    # update D
    if (control2$low.rank) {
      D <- rank.norm(B - W3 / alpha, rank)
    }

    # update B, one response column at a time
    for (j in seq_len(q)) {
      y <- Y[, j]
      fit0 <- si.smooth(y, X, C[, j])
      h0 <- si.h(fit0)
      if (descent.method == "gd") { # not very useful
        b <- grad.descent(y, X, B[, j], alpha, control2, h0, W1[, j], W2[, j], W3[, j], A[, j], C[, j], D[, j])
      } else {
        b <- bfgs.descent(y, X, B[, j], alpha, h0, W1[, j], W2[, j], W3[, j], A[, j], C[, j], D[, j])
      }
      B[, j] <- b$beta
    }

    # update W's
    W1 <- W1 + alpha * (A - B)
    if (use.sparse) {
      W2 <- W2 + alpha * (C - B)
    }
    if (control2$low.rank) {
      W3 <- W3 + alpha * (D - B)
    }

    # error: inactive penalties contribute nothing to the primal error
    pri.err <- (sum((A - B)^2) + sum((C - B)^2)*(control2$ele.sparse | control2$row.sparse) + sum((D - B)^2)*control2$low.rank) / (p*q)
    dual.err <- (alpha * norm(B.prev - B, 'F'))^2 / (p*q)
    pri.err.save <- c(pri.err.save, pri.err)
    dual.err.save <- c(dual.err.save, dual.err)

    if (plot) {
      plot(pri.err.save, type = 'l')
      plot(dual.err.save, type = 'l')
    }

    if (max(pri.err, dual.err) <= control1$tol) {
      converge <- 1
      break
    }

    n.iter <- n.iter + 1
  }

  # Report the number of passes actually run. The loop counter itself ends at
  # max.iter + 1 when the tolerance is never reached.
  n.done <- length(pri.err.save)

  list(
    B.final = col.norm(rank.norm(C, rank)),
    B.sparse = C,
    B.proj = A,
    B.lowrank = D,
    converge = converge,
    iteration = n.done,
    pri.err = pri.err.save,
    dual.err = dual.err.save,
    noPenIndex = fixed
  )
}
# echi/mSIM documentation built on Oct. 6, 2020, 11:09 a.m.