# R/SDAAP.R in accSDA: Accelerated Sparse Discriminant Analysis

#' Sparse Discriminant Analysis solved via Accelerated Proximal Gradient
#'
#' Applies an accelerated proximal gradient algorithm to the optimal scoring
#' formulation of sparse discriminant analysis proposed by
#' Clemmensen et al. 2011.
#'
#' @param Xt n by p data matrix (a matrix, not a data frame).
#' @param Yt n by K matrix of indicator variables (Yij = 1 if observation i
#'     is in class j). Each observation belongs to a single class, so every
#'     row has exactly one element equal to 1 and the rest equal to 0.
#'     Factor input is planned but not handled yet.
#' @param Om p by p parameter matrix Omega in the generalized elastic net
#'     penalty.
#' @param gam Regularization parameter for the elastic net penalty.
#' @param lam Regularization parameter for the l1 penalty, must be greater
#'     than zero.
#' @param q Desired number of discriminant vectors.
#' @param PGsteps Maximum number of inner proximal gradient iterations used
#'     when solving for beta.
#' @param PGtol Stopping tolerance for the inner APG method.
#' @param maxits Maximum number of outer (theta, beta) iterations run for
#'     each discriminant vector.
#' @param tol Stopping tolerance for the outer iterations; they stop once
#'     the relative changes in both beta and theta fall below it.
#' @return \code{SDAAP} returns an object of \code{\link{class}}
#' "\code{SDAAP}", a list with the following named components (more will be
#' added later to support the predict function):
#'
#' \describe{
#'   \item{\code{call}}{The matched call.}
#'   \item{\code{B}}{p by q matrix of discriminant vectors.}
#'   \item{\code{Q}}{K by q matrix of scoring vectors.}
#' }
#' @keywords internal
SDAAP <- function(x, ...) {
  # S3 generic: dispatch on the class of the first argument.
  UseMethod("SDAAP")
}

#' @rdname SDAAP
#' @method SDAAP default
SDAAP.default <- function(Xt, Yt, Om, gam, lam, q, PGsteps, PGtol, maxits, tol){
  # Alternates between an elastic-net step for beta (delegated to APG_EN2)
  # and a closed-form update of the scoring vector theta, once for each of
  # the q requested discriminant vectors. Returns a list of class "SDAAP"
  # holding the matched call, B (p by q) and Q (K by q).
  #
  # TODO: Handle Yt as a factor and generate dummy matrix from it

  # Capture the call up front so it can be stored in the result.
  cl <- match.call()

  # Get training data size
  nt <- nrow(Xt) # num. samples
  p <- ncol(Xt)  # num. features
  K <- ncol(Yt)  # num. classes

  # Structure to store elements for matrix A used
  # later on, i.e. precomputed values for speed.
  A <- structure(
    list(
      flag = NA,
      gom = NA,
      X = NA,
      n = NA,
      A = NA
    ),
    class = "Amat"
  )

  # Check if Omega is diagonal. diag() on a length-one vector would build an
  # identity matrix instead of a 1 by 1 matrix, so the size is given
  # explicitly to keep the p == 1 case correct.
  om_diag <- diag(Om)
  if (norm(diag(om_diag, nrow = length(om_diag)) - Om, type = "F") < 1e-15) {
    A$flag <- 1
    A$gom <- gam * om_diag
    A$X <- Xt
    A$n <- nt
    # The infinity norm of a diagonal matrix is its largest absolute diagonal
    # entry, so there is no need to materialise the p by p matrix.
    alpha <- 1 / (2 * (norm(Xt, type = "1") * norm(Xt, type = "I") / nt +
                         max(abs(A$gom))))
  } else {
    A$flag <- 0
    A$A <- 2 * (crossprod(Xt) / nt + gam * Om)
    alpha <- 1 / norm(A$A, type = "F")
  }

  D <- (1 / nt) * crossprod(Yt)
  R <- chol(D)

  Q <- matrix(1, K, q)
  B <- matrix(0, p, q)

  #------------------------------------------------------
  # Alternating direction method to update (theta,beta)
  #------------------------------------------------------
  # seq_len() keeps the loops empty when q or maxits is 0 (1:0 would not).
  for (j in seq_len(q)) {
    ###
    # Initialization
    ###
    # drop = FALSE keeps Qj a K by j matrix even when j == 1.
    Qj <- Q[, seq_len(j), drop = FALSE]

    # Removes from u its components along the columns of Qj, using the
    # inner product weighted by D.
    Mj <- function(u) {
      u - Qj %*% crossprod(Qj, D %*% u)
    }

    # Initialize theta: random start, projected and scaled so that
    # theta' D theta = 1.
    theta <- Mj(matrix(stats::runif(K), nrow = K, ncol = 1))
    theta <- theta / as.numeric(sqrt(crossprod(theta, D) %*% theta))

    # Initialize beta
    beta <- matrix(0, p, 1)

    for (its in seq_len(maxits)) {
      # Compute coefficient vector for elastic net step
      d <- 2 * crossprod(Xt, Yt %*% (theta / nt))

      # Update beta using proximal gradient step
      b_old <- beta
      betaOb <- APG_EN2(A, d, beta, lam, alpha, PGsteps, PGtol)
      beta <- betaOb$x

      # Update theta using the projected solution
      beta_norm <- norm(beta, type = "2")
      if (beta_norm > 1e-12) {
        # Solve D z = Yt' Xt beta through the Cholesky factor R (D = R'R).
        b <- crossprod(Yt, Xt %*% beta)
        y <- forwardsolve(t(R), b)
        z <- backsolve(R, y)
        tt <- Mj(z)
        t_old <- theta
        theta <- tt / sqrt(as.numeric(crossprod(tt, D) %*% tt))

        # Relative changes used by the stopping test
        db <- norm(beta - b_old, type = "2") / beta_norm
        dt <- norm(theta - t_old, type = "2") / norm(theta, type = "2")
      } else {
        # beta is numerically zero: zero out both iterates and stop
        beta <- beta * 0
        theta <- theta * 0
        db <- 0
        dt <- 0
      }

      if (max(db, dt) < tol) {
        # Converged
        break
      }
    }
    # Update Q and B
    Q[, j] <- theta
    B[, j] <- beta
  }

  # Return B and Q in a SDAAP object
  structure(
    list(
      call = cl,
      B = B,
      Q = Q
    ),
    class = "SDAAP"
  )
}

# TODO: Implement print, summary, predict and plot of SDAAP object
# Potentially also print for output from summary and predict.
# Skim over again to see if crossprod can be used for more speed.


# Try the accSDA package in your browser
#
# Any scripts or data that you put into this service are public.
#
# accSDA documentation built on May 2, 2019, 5:42 a.m.