# Nothing  (stray text left over from the web page this file was copied from; not R code)
# Factorize a matrix X into three factors U, S and V so that X is
# approximated by U %*% S %*% t(V) (presumably "non-negative matrix
# tri-factorization" -- confirm against the package documentation).
#
# Arguments:
#   X            Input matrix (validated by .checkNMTF()).
#   M            Optional mask with the same dimensions as X. When NULL,
#                .initNMTF() uses the NA pattern of X (1 = observed).
#   pseudocount  Replacement for NA/zero entries of X and zero entries of
#                the mask (applied by .initNMTF()).
#   initU, initS, initV
#                Optional initial factors; when NULL they are drawn from
#                runif() by .initNMTF().
#   fixU, fixS, fixV
#                If TRUE the corresponding factor is not updated.
#   L1_*, L2_*   Penalty weights forwarded to .updateU/.updateS/.updateV.
#   orthU, orthV Forwarded to .updateU/.updateV.
#   rank         c(number of columns of U, number of columns of V).
#   algorithm    Update rule. "Frobenius", "KL" and "IS" are mapped by
#                .initNMTF() to "Beta" with Beta = 2, 1 and 0.
#   Beta, root   Forwarded to the update functions.
#   thr          Iteration stops once the reconstruction error is no
#                longer greater than thr.
#   num.iter     Maximum number of iterations.
#   viz, figdir  If viz is TRUE, diagnostic images are drawn on the current
#                device (figdir = NULL) or written as PNG files to figdir.
#   verbose      If TRUE, print the relative change of every iteration.
#
# Value: list(U, S, V, rank, RecError, TrainRecError, TestRecError,
#   RelChange). The four error vectors are named "offset", 1, 2, ...; the
#   "offset" entry is the placeholder created by .initNMTF().
NMTF <- function(X, M=NULL,
  pseudocount=.Machine$double.eps,
  initU=NULL, initS=NULL, initV=NULL,
  fixU=FALSE, fixS=FALSE, fixV=FALSE,
  L1_U=1e-10, L1_S=1e-10, L1_V=1e-10,
  L2_U=1e-10, L2_S=1e-10, L2_V=1e-10,
  orthU=FALSE, orthV=FALSE,
  rank = c(3, 4),
  algorithm = c("Frobenius", "KL", "IS", "ALS", "PG", "COD", "Beta"),
  Beta = 2, root = FALSE, thr = 1e-10, num.iter = 100,
  viz = FALSE, figdir = NULL, verbose = FALSE){
  # Argument check
  algorithm <- match.arg(algorithm)
  .checkNMTF(X, M, pseudocount, initU, initS, initV,
    fixU, fixS, fixV, L1_U, L1_S, L1_V, L2_U, L2_S, L2_V, orthU, orthV,
    rank, Beta, root, thr, num.iter, viz, figdir, verbose)
  # Initialization
  int <- .initNMTF(X, M, pseudocount, rank, initU, initS, initV,
    algorithm, Beta, thr, verbose)
  X <- int$X
  M <- int$M
  pM <- int$pM
  M_NA <- int$M_NA
  U <- int$U
  S <- int$S
  V <- int$V
  RecError <- int$RecError
  TrainRecError <- int$TrainRecError
  TestRecError <- int$TestRecError
  RelChange <- int$RelChange
  Beta <- int$Beta
  algorithm <- int$algorithm

  # Draw the five diagnostic panels from the current state of the factors.
  draw_panels <- function() {
    .multiImagePlots(list(X, X_bar, U, S, t(V)))
  }
  # Run `draw` on a PNG device; on.exit() closes the device even when the
  # plotting code fails, so no graphics device is left open.
  plot_to_png <- function(path, draw) {
    png(filename = path)
    on.exit(dev.off(), add = TRUE)
    draw()
  }

  # The masks do not change during the iteration.
  train_mask <- 1 - M_NA + M
  test_mask <- M_NA - M

  # Reconstruction from the initial factors. It is refreshed at the end of
  # every iteration, and it also keeps the final plots working when the
  # loop body never runs (e.g. num.iter = 0).
  X_bar <- U %*% S %*% t(V)

  # Iteration
  iter <- 1
  # RecError[1] is only the "offset" placeholder (thr * 10), so it is not
  # compared with thr: for thr <= 0 the placeholder never exceeds thr and
  # the factors would otherwise be returned without a single update.
  while (((iter == 1) || (RecError[iter] > thr)) && (iter <= num.iter)) {
    # Error of the reconstruction before this round of updates
    pre_Error <- .recError(X, X_bar)
    # Update U
    if (!fixU) {
      U <- .updateU(X, pM, U, S, V, L1_U, L2_U, orthU, Beta, root, algorithm)
    }
    # Update V
    if (!fixV) {
      V <- .updateV(X, pM, U, S, V, L1_V, L2_V, orthV, Beta, root, algorithm)
    }
    # Update S
    if (!fixS) {
      S <- .updateS(X, pM, U, S, V, L1_S, L2_S, Beta, root, algorithm)
    }
    # After Update U, S, V
    iter <- iter + 1
    X_bar <- U %*% S %*% t(V)
    RecError[iter] <- .recError(X, X_bar)
    TrainRecError[iter] <- .recError(train_mask * X, train_mask * X_bar)
    TestRecError[iter] <- .recError(test_mask * X, test_mask * X_bar)
    RelChange[iter] <- abs(pre_Error - RecError[iter]) / RecError[iter]
    if (viz) {
      if (!is.null(figdir)) {
        plot_to_png(paste0(figdir, "/", iter-1, ".png"), draw_panels)
      } else {
        draw_panels()
      }
    }
    if (verbose) {
      cat(paste0(iter-1, " / ", num.iter, " |Previous Error - Error| / Error = ",
        RelChange[iter], "\n"))
    }
    if (is.nan(RelChange[iter])) {
      stop("NaN is generated. Please run again or change the parameters.\n")
    }
  }
  # After iteration
  if (viz) {
    if (!is.null(figdir)) {
      plot_to_png(paste0(figdir, "/finish.png"), function() image.plot2(X_bar))
      plot_to_png(paste0(figdir, "/original.png"), function() image.plot2(X))
    } else {
      draw_panels()
    }
  }
  names(RecError) <- c("offset", seq_len(iter-1))
  names(TrainRecError) <- c("offset", seq_len(iter-1))
  names(TestRecError) <- c("offset", seq_len(iter-1))
  names(RelChange) <- c("offset", seq_len(iter-1))
  # Output
  list(U = U, S = S, V = V, rank = rank,
    RecError = RecError,
    TrainRecError = TrainRecError,
    TestRecError = TestRecError,
    RelChange = RelChange)
}
# Validate the arguments of NMTF(). Returns nothing useful; it stops with
# an error as soon as one argument is unacceptable.
#
# Checks performed here:
#   * X is a matrix and, when given, M has the same dimensions;
#   * .checkZeroNA(X, M, type="matrix") (helper defined elsewhere);
#   * the initial factors, when given, are conformable with X and rank:
#       initU is nrow(X) x rank[1], initS is rank[1] x rank[2],
#       initV is ncol(X) x rank[2];
#   * flags are logical, numeric settings are numeric, penalties are not
#     negative, and figdir is a character string or NULL.
.checkNMTF <- function(X, M, pseudocount, initU, initS, initV,
  fixU, fixS, fixV, L1_U, L1_S, L1_V, L2_U, L2_S, L2_V, orthU, orthV,
  rank, Beta, root, thr, num.iter, viz, figdir, verbose){
  stopifnot(is.matrix(X))
  if (!is.null(M)) {
    if (!identical(dim(X), dim(M))) {
      stop("Please specify the dimensions of X and M are same")
    }
  }
  .checkZeroNA(X, M, type="matrix")
  stopifnot(is.numeric(pseudocount))
  # rank is indexed by the init* checks below, so it is validated first.
  stopifnot(is.numeric(rank))
  if (!is.null(initU)) {
    if (!identical(nrow(X), nrow(initU))) {
      stop("Please specify nrow(X) and nrow(initU) are same")
    }
    # U %*% S needs ncol(U) == nrow(S) == rank[1]; fail here with a clear
    # message instead of a non-conformable error inside the iteration.
    if (rank[1] != ncol(initU)) {
      stop("Please specify rank[1] and ncol(initU) are same")
    }
  }
  if (!is.null(initS)) {
    if (rank[1] != nrow(initS)) {
      stop("Please specify rank[1] and nrow(initS) are same")
    }
    if (rank[2] != ncol(initS)) {
      stop("Please specify rank[2] and ncol(initS) are same")
    }
  }
  if (!is.null(initV)) {
    if (!identical(ncol(X), nrow(initV))) {
      stop("Please specify ncol(X) and nrow(initV) are same")
    }
    # S %*% t(V) needs ncol(V) == ncol(S) == rank[2].
    if (rank[2] != ncol(initV)) {
      stop("Please specify rank[2] and ncol(initV) are same")
    }
  }
  stopifnot(is.logical(fixU))
  stopifnot(is.logical(fixS))
  stopifnot(is.logical(fixV))
  # The six penalty weights share one rule: negative values are rejected
  # (zero is accepted). The message text is kept as it was.
  penalties <- list(L1_U = L1_U, L1_S = L1_S, L1_V = L1_V,
    L2_U = L2_U, L2_S = L2_S, L2_V = L2_V)
  for (penalty_name in names(penalties)) {
    if (penalties[[penalty_name]] < 0) {
      stop("Please specify the ", penalty_name, " that larger than 0")
    }
  }
  stopifnot(is.logical(orthU))
  stopifnot(is.logical(orthV))
  stopifnot(is.numeric(Beta))
  stopifnot(is.logical(root))
  stopifnot(is.numeric(thr))
  stopifnot(is.numeric(num.iter))
  stopifnot(is.logical(viz))
  if (!is.character(figdir) && !is.null(figdir)) {
    stop("Please specify the figdir as a string or NULL")
  }
  stopifnot(is.logical(verbose))
}
# Prepare everything the NMTF iteration needs.
#
# Returns a list with:
#   X        input with NA and zero entries replaced by pseudocount
#   M        the supplied mask, or M_NA when M is NULL
#   pM       M with zero entries replaced by pseudocount
#   M_NA     matrix shaped like X holding 1 where X is observed, 0 where NA
#   U, S, V  the supplied initial factors, or runif() matrices of size
#            nrow(X) x rank[1], rank[1] x rank[2] and ncol(X) x rank[2]
#   RecError, TrainRecError, TestRecError, RelChange
#            length-one vectors holding the placeholder thr * 10
#   Beta, algorithm
#            "Frobenius"/"KL"/"IS" become algorithm "Beta" with
#            Beta = 2/1/0; other algorithms are passed through unchanged
.initNMTF <- function(X, M, pseudocount, rank, initU, initS, initV,
  algorithm, Beta, thr, verbose){
  # NA mask: start from X so the shape and attributes of X are kept
  unobserved <- is.na(X)
  M_NA <- X
  M_NA[] <- 1
  M_NA[unobserved] <- 0
  if (is.null(M)) {
    M <- M_NA
  }

  # Pseudo count: first the missing cells, then the exact zeros
  X[unobserved] <- pseudocount
  X[!is.na(X) & X == 0] <- pseudocount
  pM <- M
  pM[!is.na(pM) & pM == 0] <- pseudocount

  # Factors. The random draws happen in the order U, V, S.
  random_factor <- function(n_values, n_row, n_col) {
    matrix(runif(n_values), nrow = n_row, ncol = n_col)
  }
  U <- if (is.null(initU)) {
    random_factor(nrow(X) * rank[1], nrow(X), rank[1])
  } else {
    initU
  }
  V <- if (is.null(initV)) {
    random_factor(ncol(X) * rank[2], ncol(X), rank[2])
  } else {
    initV
  }
  S <- if (is.null(initS)) {
    random_factor(prod(rank), rank[1], rank[2])
  } else {
    initS
  }

  # Every trajectory starts with the same placeholder entry
  placeholder <- NULL
  placeholder[1] <- thr * 10

  # Algorithm: named divergences are special cases of "Beta"
  preset_beta <- switch(algorithm, Frobenius = 2, KL = 1, IS = 0)
  if (!is.null(preset_beta)) {
    Beta <- preset_beta
    algorithm <- "Beta"
  }

  if (verbose) {
    cat("Iterative step is running...\n")
  }
  list(X=X, M=M, pM=pM, M_NA=M_NA, U=U, S=S, V=V,
    RecError=placeholder,
    TrainRecError=placeholder,
    TestRecError=placeholder,
    RelChange=placeholder,
    Beta=Beta, algorithm=algorithm)
}
# One update step for the factor U while S and V are held fixed.
#
# X, pM        data matrix and mask; they are combined as pM * X
# U, S, V      current factors (the reconstruction is U %*% S %*% t(V))
# L1_U, L2_U   penalty weights added to the denominator of the "Beta"
#              update (not used by the orthU variant or other algorithms)
# orthU        selects the alternative "Beta" update built from U %*% t(U)
# Beta, root   exponent settings; .rho(Beta, root) (defined elsewhere)
#              gives the power applied to the multiplicative ratio
# algorithm    "Beta", "ALS", "PG" or "COD"; any other value returns U
#              unchanged
#
# Returns the updated U.
#
# NOTE(review): the "ALS" branch and the first term of the "PG" direction
# use X without the mask, while the other terms use pM * X -- confirm that
# this is intended.
.updateU <- function(X, pM, U, S, V, L1_U, L2_U, orthU, Beta, root, algorithm){
  # V %*% t(S) is needed by every branch and never depends on U, so it is
  # computed once (previously once per branch and once per COD column).
  VS <- V %*% t(S)
  if (algorithm == "Beta") {
    if (orthU) {
      UU <- U %*% t(U)
      numer <- (pM * X) %*% VS
      denom <- UU %*% (pM * X) %*% VS
    } else {
      # Current reconstruction, shared by numerator and denominator
      recon <- U %*% t(VS)
      numer <- (recon^(Beta-2) * (pM * X)) %*% VS
      denom <- recon^(Beta-1) %*% VS
      denom <- denom + L1_U + L2_U * U
    }
    U <- U * (numer / denom)^.rho(Beta, root)
  }
  if (algorithm == "ALS") {
    # ginv() is presumably MASS::ginv (generalized inverse) -- it is not
    # defined in this file.
    U <- .positive((X %*% VS) %*% ginv(crossprod(VS)))
  }
  if (algorithm == "PG") {
    USV <- pM * (U %*% S %*% t(V))
    VV <- t(V) %*% V
    USVVS <- USV %*% VS
    Pu <- U - U / (USVVS * (X %*% VS))
    numer <- sum(Pu * (USVVS - (pM * X) %*% VS))
    denom <- .trace((S %*% VV) %*% (t(S) %*% t(Pu) %*% Pu))
    etaU <- numer / denom
    U <- .positive(U - etaU * Pu)
  }
  if (algorithm == "COD") {
    # Independent of U, hence hoisted out of the column loop
    XVS <- (pM * X) %*% VS
    for (i in seq_len(ncol(U))) {
      # Must be recomputed: it uses the columns of U updated so far
      USV <- U %*% S %*% t(V)
      numer <- XVS[,i] - (USV %*% VS)[,i]
      denom <- as.numeric(crossprod(V %*% S[i, ]))
      U[,i] <- U[,i] + .positive(numer / denom)
    }
  }
  U
}
# One update step for the factor V while U and S are held fixed.
#
# X, pM        data matrix and mask
# U, S, V      current factors (the reconstruction is U %*% S %*% t(V))
# L1_V, L2_V   penalty weights added to the denominator of the "Beta"
#              update (not used by the orthV variant or other algorithms)
# orthV        selects the alternative "Beta" update built from V %*% t(V)
# Beta, root   exponent settings; .rho(Beta, root) (defined elsewhere)
#              gives the power applied to the multiplicative ratio
# algorithm    "Beta", "ALS", "PG" or "COD"; any other value returns V
#              unchanged
#
# Returns the updated V.
#
# NOTE(review): the "ALS", "PG" and "COD" branches use t(X) without the
# mask, whereas .updateU uses pM * X in the corresponding places -- confirm
# that this is intended.
.updateV <- function(X, pM, U, S, V, L1_V, L2_V, orthV, Beta, root, algorithm){
  if (algorithm == "Beta") {
    SU <- t(S) %*% t(U)
    if (orthV) {
      # V %*% t(V) is ncol(X) x ncol(X); it is only built when this branch
      # actually uses it.
      VV <- V %*% t(V)
      numer <- SU %*% (pM * X)
      denom <- SU %*% (pM * X) %*% VV
    } else {
      # Current reconstruction, shared by numerator and denominator
      recon <- t(SU) %*% t(V)
      numer <- SU %*% (recon^(Beta - 2) * (pM * X))
      denom <- SU %*% recon^(Beta - 1)
      denom <- denom + L1_V + L2_V * t(V)
    }
    # numer/denom are shaped like t(V), hence the transpose
    V <- V * t((numer / denom)^.rho(Beta, root))
  }
  if (algorithm == "ALS") {
    US <- U %*% S
    # ginv() is presumably MASS::ginv (generalized inverse) -- it is not
    # defined in this file.
    V <- .positive((t(X) %*% US) %*% ginv(crossprod(US)))
  }
  if (algorithm == "PG") {
    VSU <- (V %*% t(S) %*% t(U)) * t(pM)
    US <- U %*% S
    SUU <- t(S) %*% t(U) %*% U
    VSUUS <- VSU %*% US
    XUS <- t(X) %*% US
    Pv <- V - V / (VSUUS * XUS)
    numer <- sum(Pv * (VSUUS - XUS))
    denom <- .trace((S %*% t(Pv) %*% Pv) %*% SUU)
    etaV <- numer / denom
    V <- .positive(V - etaV * Pv)
  }
  if (algorithm == "COD") {
    # Both products are independent of V, hence hoisted out of the loop
    US <- U %*% S
    XUS <- t(X) %*% US
    for (j in seq_len(ncol(V))) {
      # Must be recomputed: it uses the columns of V updated so far
      VSU <- V %*% t(S) %*% t(U)
      numer <- XUS[,j] - (VSU %*% US)[,j]
      denom <- as.numeric(crossprod(U %*% S[,j]))
      V[,j] <- V[,j] + .positive(numer / denom)
    }
  }
  V
}
# One update step for the core factor S while U and V are held fixed.
#
# X, pM        data matrix and mask; they are combined as pM * X
# U, S, V      current factors (the reconstruction is U %*% S %*% t(V))
# L1_S, L2_S   penalty weights added to the denominator of the "Beta" update
# Beta, root   exponent settings; .rho(Beta, root) (defined elsewhere)
#              gives the power applied to the multiplicative ratio
# algorithm    "Beta", "ALS", "PG" or "COD"; any other value returns S
#              unchanged
#
# Returns the updated S.
.updateS <- function(X, pM, U, S, V, L1_S, L2_S, Beta, root, algorithm){
  if (algorithm == "Beta") {
    masked <- pM * X
    recon <- (U %*% S) %*% t(V)
    numer <- t(U) %*% (recon^(Beta-2) * masked) %*% V
    denom <- t(U) %*% recon^(Beta-1) %*% V
    denom <- denom + L1_S + L2_S * S
    S <- S * (numer / denom)^.rho(Beta, root)
  } else if (algorithm == "ALS") {
    gram_u <- t(U) %*% U
    projected <- t(U) %*% (pM * X) %*% V
    gram_v <- t(V) %*% V
    S <- .positive(ginv(gram_u) %*% projected %*% ginv(gram_v))
  } else if (algorithm == "PG") {
    masked_recon <- pM * (U %*% S %*% t(V))
    projected <- t(U) %*% (pM * X) %*% V
    gram_u <- t(U) %*% U
    gram_v <- t(V) %*% V
    # Masked reconstruction projected onto the factors, used twice below
    projected_recon <- t(U) %*% masked_recon %*% V
    direction <- S - S / (projected_recon * projected)
    numer <- sum(direction * (projected_recon - projected))
    denom <- .trace((gram_u %*% direction) %*% (gram_v %*% t(direction)))
    step <- numer / denom
    S <- .positive(S - step * direction)
  } else if (algorithm == "COD") {
    # Both matrices are computed once from the incoming S and are not
    # refreshed while the entries of S are being updated.
    projected <- t(U) %*% (pM * X) %*% V
    projected_recon <- t(U) %*% U %*% S %*% t(V) %*% V
    for (i in seq_len(ncol(U))) {
      u_sq <- U[,i] %*% U[,i]
      for (j in seq_len(ncol(V))) {
        residual <- projected[i,j] - projected_recon[i,j]
        denom_ij <- as.numeric(u_sq * (V[,j] %*% V[,j]))
        S[i,j] <- S[i,j] + .positive(residual / denom_ij)
      }
    }
  }
  S
}
# Trace of a matrix: the sum of the entries returned by diag(mat).
.trace <- function(mat){
  diagonal <- diag(mat)
  sum(diagonal)
}
# Draw the first five matrices of inputList on a 2 x 3 layout:
#   top row:    "X", "rec X", and an empty cell (plot.new())
#   bottom row: "U", "S", "t(V)"
# The layout set by layout() is left in place on the current device.
.multiImagePlots <- function(inputList){
  layout(matrix(1:6, nrow = 2, byrow = TRUE))
  panel <- function(k, title) {
    image.plot2(inputList[[k]], main = title)
  }
  panel(1, "X")
  panel(2, "rec X")
  plot.new()
  panel(3, "U")
  panel(4, "S")
  panel(5, "t(V)")
}
# Plot a matrix with image.plot() (not defined in this file; presumably
# fields::image.plot) after reversing its row order and transposing it.
# NOTE(review): presumably this makes the picture read like the printed
# matrix, with row 1 at the top -- confirm.
#
# A    matrix to draw
# ...  further arguments passed on to image.plot() (e.g. main)
image.plot2 <- function(A, ...){
  # drop = FALSE keeps a single-row A as a matrix. Without it the row
  # subset collapses to a plain vector and t() returns a 1 x n matrix
  # instead of the intended n x 1, so the image is drawn in the wrong
  # orientation. rev(seq_len()) also avoids the c(0, 1) index that
  # nrow(A):1 produces for a matrix without rows.
  flipped <- A[rev(seq_len(nrow(A))), , drop = FALSE]
  image.plot(t(flipped), ...)
}
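# The three lines below are boilerplate from the web page this file was
# copied from; they are not part of the R source.
# Any scripts or data that you put into this service are public.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.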