#'@title Smoothed Poisson Matrix Factorization
#'@description Smoothed Poisson Matrix Factorization/ smoothed topic model
#'@param X count matrix
#'@param K number of factors/ranks
#'@param init initialization methods, 'lee','scd' from package NNLM, or 'uniform' randomly initialize; or provide init as a list with L_init and F_init.
#'@param init_loss loss function of the initialization method, either mkl or mse.
#'@param maxiter maximum iterations
#'@param tol stop criteria: iterations stop once the absolute change in mean KL between two iterations is at most tol.
#'@param fix_F if TRUE, F will not be updated.
#'@param bmsm_control_l,bmsm_control_f control parameters of BMSM for the loadings and the factors, see bmsm_control_default()
#'@param nug_control_l,nug_control_f control parameters of smashgen for the loadings and the factors, see nug_control_default()
#'@param ebpm_method point_gamma or two_gamma
#'@param ebpm_control_l,ebpm_control_f control parameters of ebpm for the loadings and the factors, see ebpm_control_default()
#'@param smooth_f,smooth_l whether to get smooth estimate of loadings or factors.
#'@param nugget whether to assume nugget effects
#'@param printevery print the mean KL every printevery iterations
#'@return a list with res (the fitted state: ql, qf, gl, gf, a, b, nugget_l, nugget_f);
#' EL, EF: posterior of loadings and factors, and d (the L, FF and s outputs of poisson2multinom);
#' KL: the mean KL recorded at initialization and after every iteration.
#'@import mixsqp
#'@import NNLM
#'@import ebpm
#'@import Matrix
#'@export
stm <- function(X,K,
                init = 'scd',init_loss = 'mkl',maxiter=100,tol=1e-3,
                fix_F = FALSE,
                bmsm_control_l=list(), bmsm_control_f=list(),
                nug_control_l=list(), nug_control_f=list(),
                ebpm_method='point_gamma',
                ebpm_control_l=list(), ebpm_control_f=list(),
                smooth_f=TRUE,smooth_l=FALSE,
                nugget=FALSE,
                printevery=10){

  n <- dim(X)[1]
  p <- dim(X)[2]

  res <- init_stm(X,K,init,init_loss)

  # KL[1] is the value at initialization, KL[iter + 1] the value after iteration iter.
  KL <- numeric(maxiter + 1)
  KL[1] <- mKL(X,tcrossprod(res$ql$El,res$qf$Ef))

  # Only the non-zero entries of X are visited below, through their (i, j, x) triplets.
  # NOTE(review): Matrix() may pick a symmetric storage class for a symmetric X, in which
  # case summary() would list one triangle only -- confirm X is never symmetric.
  X <- Matrix::Matrix(X,sparse = TRUE)
  X_idx <- summary(X)

  converged <- FALSE
  n_iter <- 0

  for(iter in seq_len(maxiter)){

    n_iter <- iter
    b_k_max <- 0

    for(k in seq_len(K)){

      # Log-weight of factor k at every non-zero entry, relative to the offset res$a;
      # res$b holds the sum over factors of exp() of these log-weights.
      b_k <- res$ql$Elogl[X_idx$i,k] + res$qf$Elogf[X_idx$j,k] - res$a

      # Counts attributed to factor k; its row and column sums feed the updates.
      EZ_k <- Matrix::sparseMatrix(i=X_idx$i,j=X_idx$j,x = X_idx$x*exp(b_k)/res$b,dims = c(n,p))

      l_seq <- rowSums(EZ_k)
      l_scale <- sum(res$qf$Ef[,k])

      # Update L
      if(smooth_l){
        lk_hat <- update_smooth(l_seq, l_scale, nugget,bmsm_control_l,nug_control_l)
        res$ql$El[,k] <- lk_hat$E
        res$ql$Elogl[,k] <- lk_hat$Elog
        res$nugget_l[k] <- lk_hat$nugget
        res$gl[[k]] <- lk_hat$pi_weights
      }else{
        lk_hat <- update_nsmooth(l_seq,l_scale,ebpm_control_l,ebpm_method)
        res$ql$El[,k] <- lk_hat$posterior$mean
        res$ql$Elogl[,k] <- lk_hat$posterior$mean_log
        res$gl[[k]] <- lk_hat$fitted_g
      }

      # Update F (uses the loadings that were just updated for the scale)
      if(!fix_F){
        f_seq <- colSums(EZ_k)
        f_scale <- sum(res$ql$El[,k])
        if(smooth_f){
          fk_hat <- update_smooth(f_seq, f_scale,nugget,bmsm_control_f,nug_control_f)
          res$qf$Ef[,k] <- fk_hat$E
          res$qf$Elogf[,k] <- fk_hat$Elog
          res$nugget_f[k] <- fk_hat$nugget
          # exact element name; the previous `pi_weight` relied on partial matching of `$`
          res$gf[[k]] <- fk_hat$pi_weights
        }else{
          fk_hat <- update_nsmooth(f_seq,f_scale,ebpm_control_f,ebpm_method)
          res$qf$Ef[,k] <- fk_hat$posterior$mean
          res$qf$Elogf[,k] <- fk_hat$posterior$mean_log
          res$gf[[k]] <- fk_hat$fitted_g
        }
      }

      # Swap the old contribution of factor k in the normaliser for the new one.
      b_k_new <- res$ql$Elogl[X_idx$i,k] + res$qf$Elogf[X_idx$j, k] - res$a
      res$b <- res$b - exp(b_k) + exp(b_k_new)
      b_k_max <- pmax(b_k_new, b_k_max)

    }

    # Re-centre: move the largest relative log-weight into the offset a and rescale b to match.
    res$b <- res$b/exp(b_k_max)
    res$a <- b_k_max + res$a

    KL[iter+1] <- mKL(X,tcrossprod(res$ql$El,res$qf$Ef))

    if(iter%%printevery==0){
      print(sprintf('At iter %d, mean KL: %f',iter,KL[iter+1]))
    }

    if(abs(KL[iter+1]-KL[iter])<=tol){
      converged <- TRUE
      break
    }

  }

  # Warn only when the loop ran out of iterations without meeting the tolerance.
  if(!converged){
    warning('Reached maximum iterations')
  }

  ldf <- poisson2multinom(res$qf$Ef,res$ql$El)
  fit <- list(res = res,EL = ldf$L,EF = ldf$FF,d=ldf$s,
              KL = KL[seq_len(n_iter + 1)])
  return(fit)

}
#'@title initialize the stm model
#'@param X input data matrix
#'@param K number of topics
#'@param init init methods ('scd', 'lee', 'uniform' or 'kmeans'), or a list of init L and F (L_init, F_init)
#'@param init_loss mkl or mse
#'@return a list with ql (El, Elogl), qf (Ef, Elogf), empty lists gl and gf,
#' a and b (per non-zero entry of X: the row maximum over factors of Elogl + Elogf,
#' and the sum over factors of exp(Elogl + Elogf - a)), and zero vectors nugget_l, nugget_f.
#'@export
init_stm <- function(X,K,init,init_loss){

  # Dimensions are needed by the 'uniform' and 'kmeans' initializers.
  n <- nrow(X)
  p <- ncol(X)

  if(is.list(init)){
    L_init <- init$L_init
    F_init <- init$F_init

    # Only F supplied: fit L by NMF with H initialized at t(F_init).
    if(is.null(L_init)){
      X_init_fit <- NNLM::nnmf(as.matrix(X),K,method='lee',
                               loss='mse',show.warning = FALSE,
                               init = list(H=t(F_init)),
                               verbose = FALSE,max.iter = 50)
      L_init <- X_init_fit$W
    }

  }else if(init%in%c('scd','lee')){

    X_init_fit <- NNLM::nnmf(as.matrix(X),K,method=init,loss=init_loss,
                             show.warning = FALSE,verbose = FALSE,max.iter = 50)
    L_init <- X_init_fit$W
    F_init <- t(X_init_fit$H)

  }else if(init == 'uniform'){

    L_init <- matrix(runif(n*K),nrow=n,ncol=K)
    F_init <- matrix(runif(K*p),nrow=p,ncol=K)
    # Rescale both so that median(L_init)*median(F_init) equals median(X).
    ratio <- median(X)/(median(L_init)*median(F_init))
    L_init <- L_init*sqrt(ratio)
    F_init <- F_init*sqrt(ratio)

  }else if(init == 'kmeans'){

    kmeans.init <- kmeans(as.matrix(X),K,nstart=5)
    # Every row of L_init is the same vector derived from the cluster sizes.
    L_init <- rep(1,n)%o%normalize(as.vector(table(kmeans.init$cluster)))
    F_init <- t(kmeans.init$centers)
    row.names(F_init) <- NULL

  }else{
    stop("Unknown init method '", init,
         "'; use 'scd', 'lee', 'uniform', 'kmeans' or a list with L_init and F_init")
  }

  # adjust scale of L and F, mainly for stability.
  ratio <- adjLF(L_init,F_init)
  L_init <- ratio$L_init
  F_init <- ratio$F_init

  # Small constant keeps the logs finite where the initial values are zero.
  Elogl <- log(L_init+1e-10)
  Elogf <- log(F_init+1e-10)

  ql <- list(El = L_init, Elogl = Elogl)
  qf <- list(Ef = F_init, Elogf = Elogf)

  X <- Matrix::Matrix(X,sparse = TRUE)
  d <- summary(X)
  # drop = FALSE keeps a matrix when K == 1 or X has a single non-zero entry,
  # so the row-wise reductions below stay well defined.
  temp <- Elogl[d$i, , drop = FALSE] + Elogf[d$j, , drop = FALSE]
  a <- rowMax(temp)
  b <- rowSums(exp(temp-a))

  gl <- list()
  gf <- list()

  return(list(ql=ql,qf=qf,gl=gl,gf=gf,
              a=a,b=b,
              nugget_l = rep(0,K),
              nugget_f = rep(0,K)))

}
# Row-wise maximum of a matrix-like X, ignoring NAs: the columns of X are
# handed to pmax.int() as separate arguments.
rowMax <- function(X) {
  columns <- as.list(as.data.frame(X))
  pmax_args <- c(list(na.rm = TRUE), columns)
  do.call(pmax.int, pmax_args)
}
# Fill an n x p x K array with X[i, j] times the softmax over k of
# Elogl[i, k] + (k-th row of Elogf)[j], and return it.
# X: n x p matrix; EZ: array to be filled; ql_hat, qf_hat: lists holding Elogl and Elogf.
# NOTE(review): Elogf is indexed as [k, ] here (factors in rows), whereas stm() and
# init_stm() index it as [, k] -- confirm which layout callers of this helper pass.
Calc_EZ <- function(X, K, EZ, ql_hat, qf_hat) {
  out_dim <- c(nrow(X), ncol(X), K)

  for (k in 1:K) {
    log_l <- ql_hat$Elogl[, k]
    log_f <- qf_hat$Elogf[k, ]
    EZ[, , k] <- outer(log_l, log_f, "+")
  }

  # Responsibilities across the third dimension, then weight by the counts
  # (as.vector(X) is recycled over the K slices).
  weights <- softmax3d(EZ)
  EZ <- as.vector(weights) * as.vector(X)
  dim(EZ) <- out_dim
  EZ
}
# Softmax over the trailing dimension(s) of an array: for every (i, j) the
# returned values are exp(x[i, j, ]) normalised to sum to one.
# x: array with at least three dimensions. Returns an array with the same dim.
# Cells whose result is undefined (NA input, all scores -Inf, or a +Inf score)
# are set to 0.
softmax3d <- function(x) {
  d <- dim(x)
  if (length(d) < 3L) {
    stop("'x' must be an array with at least three dimensions")
  }

  # One row per (i, j) cell, one column per entry of the trailing dimension(s).
  scores <- matrix(as.vector(x), nrow = d[1L] * d[2L], ncol = prod(d[-(1:2)]))

  # Subtract the per-cell maximum before exponentiating so that very large or
  # very negative scores neither overflow nor underflow to an all-zero row.
  cols <- lapply(seq_len(ncol(scores)), function(k) scores[, k])
  shift <- do.call(pmax, c(cols, list(na.rm = TRUE)))
  shift[!is.finite(shift)] <- 0

  score_exp <- exp(scores - shift)
  probs <- score_exp / rowSums(score_exp)
  probs[is.na(probs)] <- 0
  dim(probs) <- d
  probs
}
# Smoothed update of one loading/factor column.
# x: non-negative sequence; sf: scale passed to the smoother;
# nugget: if TRUE use smash.gen.poiss() with nug_control, otherwise BMSM() with bmsm_control.
# Returns a list with E, Elog, pi_weights (NULL on the nugget path) and
# nugget (0 on the BMSM path).
update_smooth <- function(x, sf, nugget, bmsm_control = list(), nug_control = list()) {

  if (min(x) < 0) {
    stop("negative values in x not permitted")
  }

  if (nugget) {
    # Merge the user-supplied settings over the defaults, rejecting unknown names.
    defaults <- nug_control_default()
    unknown <- !(names(nug_control) %in% names(defaults))
    if (any(unknown)) {
      stop("Argument \"nug_control\" contains unknown parameter names")
    }
    ctrl <- modifyList(defaults, nug_control, keep.null = TRUE)

    fit <- smash.gen.poiss(x, s = sf,
                           filter.number = ctrl$filter.number,
                           family = ctrl$family,
                           nugget = ctrl$nugget,
                           robust = ctrl$robust,
                           robust.q = ctrl$robust.q,
                           transformation = ctrl$transformation,
                           method = ctrl$method,
                           nug.init = ctrl$nug.init,
                           ash.pm = ctrl$ash.pm,
                           eps = ctrl$eps,
                           maxiter = ctrl$maxiter,
                           tol = ctrl$tol)

    lambda <- fit$lambda.est
    # With 'lik_expansion' the fit's mu.est is used as the log-scale estimate;
    # otherwise the log of the estimate is taken.
    log_lambda <- if (ctrl$transformation == 'lik_expansion') fit$mu.est else log(lambda)

    return(list("E" = lambda,
                'Elog' = log_lambda,
                "pi_weights" = NULL,
                "nugget" = fit$nugget.est))
  }

  fit <- BMSM(x, sf, bmsm_control)
  list("E" = fit$E,
       'Elog' = fit$Elog,
       "pi_weights" = fit$pi_weights,
       "nugget" = 0)
}
# Non-smoothed (empirical Bayes Poisson mean) update of one loading/factor column.
# x: sequence of counts; s: scale; ebpm_control: list overriding entries of
# ebpm_control_default(); ebpm_method: 'point_gamma' or 'two_gamma'.
# Returns the fit object of the selected ebpm function.
# Stops if ebpm_control has unknown names or ebpm_method is not supported.
update_nsmooth <- function(x,s,ebpm_control = list(),ebpm_method){

  control0 <- ebpm_control_default()
  if (any(!is.element(names(ebpm_control),names(control0))))
    stop("Argument \"ebpm_control\" contains unknown parameter names")
  control1 <- modifyList(control0,ebpm_control,keep.null = TRUE)

  g_init <- control1$g_init
  fix_g <- control1$fix_g
  control <- control1$control
  pi0 <- control1$pi0

  # An unsupported method previously fell through both branches and failed with
  # "object 'out' not found"; report it explicitly instead.
  switch(ebpm_method,
         point_gamma = ebpm::ebpm_point_gamma(x,s,g_init,fix_g,pi0,control),
         two_gamma = ebpm::ebpm_two_gamma(x,s,g_init,fix_g,pi0,control),
         stop("Unknown ebpm_method '", ebpm_method,
              "'; expected 'point_gamma' or 'two_gamma'"))

}
#'@title Default parameters of ebpm
#'@return a named list with entries pi0, g_init, fix_g and control.
#'@export
ebpm_control_default <- function() {
  # Entries are read back by name in update_nsmooth().
  defaults <- list(
    pi0 = "estimate",
    g_init = NULL,
    fix_g = FALSE,
    control = NULL
  )
  defaults
}
#'@title Default parameters of smash gen
#'@return a named list with entries nugget, robust, robust.q, transformation, method,
#' nug.init, ash.pm, eps, maxiter, tol, filter.number and family.
#' filter.number and family give the wavelet basis, see wavethresh package for more details.
#'@export
nug_control_default <- function(){
  list(nugget = NULL,
       # TRUE rather than T: T is an ordinary variable that can be reassigned.
       robust = TRUE,
       robust.q = 0.99,
       transformation = 'lik_expansion',
       method = 'ti.thresh',
       nug.init = NULL,
       ash.pm = FALSE,
       eps = 'estimate',
       maxiter = 10,
       tol = 1e-2,
       filter.number = 1,
       family = "DaubExPhase")
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.