R/DEM1.R

Defines functions DEM1

Documented in DEM1

#' Distributed EM (DEM1) for a multivariate Gaussian mixture model
#'
#' The DEM1 algorithm is a divide and conquer algorithm, which is used to
#' solve the parameter estimation of multivariate Gaussian mixture model.
#' In every iteration the observations are processed in \code{M} disjoint
#' subsets, the E-step sufficient statistics of all subsets are summed, and
#' one M-step is applied to the totals. Iteration stops early once the
#' largest absolute change of the mixing weights, the means and the
#' covariances are all below \code{epsilon}. Progress (step, alpha, mu) is
#' printed with \code{cat} after every non-converged step.
#'
#' @param y is a data matrix (observations in rows, variables in columns)
#' @param M is the number of subsets; when \code{nrow(y)} is not a multiple
#'   of \code{M} the first subsets receive one extra observation
#' @param seed is the recommended way to specify seeds; it is passed to
#'   \code{set.seed}, so the global random number generator state is reset
#' @param alpha0 is the initial value of the mixing weight (length K)
#' @param mu0 is the initial value of the mean (K x p matrix, one row per
#'   component)
#' @param sigma0 is the initial value of the covariance (list of K
#'   p x p matrices)
#' @param i is the maximum number of iterations
#' @param epsilon is the threshold value used by the convergence check
#'
#' @return A list with elements \code{DEM1alpha} (mixing weights),
#'   \code{DEM1mu} (means), \code{DEM1sigma} (covariances) and
#'   \code{DEM1time} (the \code{system.time} of the iterations divided by
#'   \code{M}).
#' @export
#'
#' @examples
#' library(mvtnorm)
#' alpha1= c(rep(1/4,4))
#' mu1=matrix(0,nrow=4,ncol=4)
#' for (k in 1:4){
#' mu1[k,]=c(runif(4,(k-1)*3,k*3))
#' }
#' sigma1=list()
#' for (k in 1:4){
#' sigma1[[k]]= diag(4)*0.1
#' }
#' y= matrix(0,nrow=200,ncol=4)
#' for(k in 1:4){
#' y[c(((k-1)*200/4+1):(k*200/4)),] = rmvnorm(200/4,mu1[k,],sigma1[[k]])
#' }
#' M=5
#' seed=123
#' alpha0= alpha1
#' mu0=mu1
#' sigma0=sigma1
#' i=10
#' epsilon=0.005
#' DEM1(y,M,seed,alpha0,mu0,sigma0,i,epsilon)
DEM1 <- function(y, M, seed, alpha0, mu0, sigma0, i, epsilon) {
  y <- as.matrix(y)
  if (length(M) != 1 || is.na(M) || M < 1) {
    stop("M must be a single positive integer", call. = FALSE)
  }
  M <- as.integer(M)

  n <- nrow(y)
  p <- ncol(y)
  K <- length(alpha0)
  alpha <- alpha0
  mu <- mu0
  sigma <- sigma0

  # The partition depends only on `seed`, so it is drawn once instead of
  # being re-created (identically) in every iteration.
  set.seed(seed)
  perm <- sample(seq_len(n), n, replace = FALSE)

  # Contiguous chunks of the permutation; every observation lands in exactly
  # one subset even when n is not a multiple of M.
  sizes <- rep(n %/% M, M)
  extra <- seq_len(n %% M)
  sizes[extra] <- sizes[extra] + 1
  subsets <- split(perm, rep(seq_len(M), times = sizes))

  elapsed <- system.time(for (step in seq_len(i)) {
    # Sufficient statistics accumulated over all subsets.
    M1 <- numeric(K)                                   # sum of responsibilities
    M2 <- matrix(0, nrow = K, ncol = p)                # weighted sums of y
    M3 <- lapply(seq_len(K), function(k) matrix(0, nrow = p, ncol = p))

    for (idx in subsets) {
      # drop = FALSE keeps a matrix even for single-row subsets or p == 1.
      y1 <- y[idx, , drop = FALSE]

      # E-step on this subset: unnormalised then normalised responsibilities.
      weight <- matrix(0, nrow = length(idx), ncol = K)
      for (k in seq_len(K)) {
        weight[, k] <- alpha[k] * dmvnorm(y1, mu[k, ], sigma[[k]], log = FALSE)
      }
      prob <- weight / rowSums(weight)

      M1 <- M1 + colSums(prob)
      # Row k of crossprod(prob, y1) is sum_j prob[j, k] * y1[j, ].
      M2 <- M2 + crossprod(prob, y1)
      for (k in seq_len(K)) {
        # sum_j prob[j, k] * y1[j, ] %*% t(y1[j, ]) without the row loop.
        M3[[k]] <- M3[[k]] + crossprod(y1 * prob[, k], y1)
      }
    }

    oldalpha <- alpha
    oldmu <- mu
    oldsigma <- sigma

    # M-step on the pooled statistics.
    for (k in seq_len(K)) {
      alpha[k] <- M1[k] / n
      mu[k, ] <- M2[k, ] / M1[k]
      sigma[[k]] <- M3[[k]] / M1[k] - mu[k, ] %*% t(mu[k, ])
    }

    sigma_change <- vapply(
      seq_len(K),
      function(k) max(abs(sigma[[k]] - oldsigma[[k]])),
      numeric(1)
    )

    # Scalar condition, so use the short-circuiting operator.
    if (max(abs(alpha - oldalpha)) < epsilon &&
        max(abs(mu - oldmu)) < epsilon &&
        max(sigma_change) < epsilon) break

    cat(
      "step", step, "\n",
      "alpha", alpha, "\n",
      "mu", mu, "\n"
    )
  })

  time <- elapsed / M
  return(list(DEM1alpha = alpha, DEM1mu = mu, DEM1sigma = sigma, DEM1time = time))
}

Try the DEM package in your browser

Any scripts or data that you put into this service are public.

DEM documentation built on May 14, 2022, 9:05 a.m.