R/mvnmix.R

Defines functions sim.mvn.mixture plot.mvn.mixture print.mvn.mixture mvnmix getMeanVar toPar toTheta

Documented in mvnmix

## Pack mixture parameters into one flat parameter vector
##   theta = (mu_1, ..., mu_k, Sigma_1, ..., Sigma_k, p_1, ..., p_{k-1})
## mu    : k x D matrix, one component mean per row
## Sigma : k x D^2 matrix, one (vectorised) covariance matrix per row
## p     : mixture probabilities; the entry at position nrow(mu) is dropped
##         because it is implied by the others (see toPar).
toTheta <- function(mu,Sigma,p) {
  mean.part <- as.vector(t(mu))
  var.part <- as.vector(t(Sigma))
  drop.last <- -nrow(mu)
  c(mean.part, var.part, p[drop.last])
}
## Unpack a flat parameter vector (layout as produced by toTheta) into a list
##   mu    : k x D matrix, row j holding the mean of component j
##   Sigma : k x D^2 matrix, row j holding the D^2 covariance entries of
##           component j in the order they are stored in theta
##   p     : the k mixture probabilities; the last one is 1 minus the sum of
##           the k-1 values stored at the tail of theta
toPar <- function(theta, D, k) {
  nsig <- D^2
  sig.offset <- k*D
  mean.rows <- lapply(1:k, function(j) theta[(j-1)*D + 1:D])
  var.rows <- lapply(1:k, function(j) theta[sig.offset + (j-1)*nsig + 1:nsig])
  p.free <- tail(theta, k-1)
  list(mu=do.call(rbind, mean.rows),
       Sigma=do.call(rbind, var.rows),
       p=c(p.free, 1-sum(p.free)))
}
## Extract the mean vector and covariance matrix of each mixture component
## from a fitted mvn.mixture object (uses object$pars, object$thetas,
## object$D and object$k).
##   k    : optional component index; if supplied only that component's
##          list(mean=, var=) is returned
##   iter : optional EM iteration; parameters are then read from row `iter`
##          of object$thetas, which mvnmix only stores when extra=TRUE
## Returns a list with one list(mean=, var=) per component (or a single such
## element when k is given).
getMeanVar <- function(object,k,iter,...) {
  if (missing(iter)) {
    theta <- object$pars
  } else {
    ## Index with the caller's `iter` explicitly. Evaluating
    ## thetas[iter,] inside with(object, ...) resolved `iter` to object$iter
    ## (the final iteration count stored by mvnmix), ignoring the argument.
    if (is.null(object$thetas))
      stop("parameter history not available; refit with 'extra=TRUE'")
    theta <- object$thetas[iter, ]
  }
  pp <- toPar(theta, object$D, object$k)
  res <- lapply(seq_len(object$k), function(i)
    list(mean=pp$mu[i, ], var=matrix(pp$Sigma[i, ], ncol=object$D)))
  if (missing(k)) res else res[[k]]
}



#' Estimate mixture latent variable model
#'
#' Estimate mixture latent variable model
#'
#' Estimate parameters in a mixture of multivariate normal distributions
#' (each component with its own mean vector and covariance matrix) via the
#' EM algorithm.
#'
#' @param data \code{data.frame}, matrix or numeric vector (a vector is
#' treated as a single variable)
#' @param k Number of mixture components (at least 2)
#' @param theta Optional starting values, ordered as
#' \code{(mu1, ..., muk, Sigma1, ..., Sigmak, p1, ..., p[k-1])}
#' @param steps Maximum number of iterations
#' @param tol Convergence tolerance of EM algorithm (iterations stop when the
#' squared change of the parameter vector falls below \code{tol})
#' @param lambda Regularisation parameter. Added to diagonal of covariance matrix (to avoid
#' singularities)
#' @param mu Initial centres (if unspecified, centres are chosen by
#' \code{init}). Only used for the first start.
#' @param silent Turn on/off output messages
#' @param extra Extra debug information (parameter and membership history
#' per iteration). Disabled when \code{n.start>1}.
#' @param n.start Number of restarts. The run with the largest
#' log-likelihood is returned.
#' @param init Function, or name of a function, used to choose initial
#' centres. It is called as \code{init(data, k)} and should return row
#' indices of \code{data}. If no function of the given name exists, the
#' centres are sampled at random among the observations.
#' @param ... Additional arguments (currently unused)
#' @return A \code{mvn.mixture} object (also inheriting from \code{lvm.mixture})
#' @author Klaus K. Holst
#' @seealso \code{mixture}
#' @keywords models regression
#' @examples
#'
#' data(faithful)
#' set.seed(1)
#' M1 <- mvnmix(faithful[,"waiting",drop=FALSE],k=2)
#' M2 <- mvnmix(faithful,k=2)
#' if (interactive()) {
#'     par(mfrow=c(2,1))
#'     plot(M1,col=c("orange","blue"),ylim=c(0,0.05))
#'     plot(M2,col=c("orange","blue"))
#' }
#'
#' @export mvnmix
mvnmix <- function(data, k=2, theta, steps=500,
            tol=1e-16, lambda=0,
            mu=NULL,
            silent=TRUE, extra=FALSE,
            n.start=1,
            init="kmpp", ...
            )  {

    if (k<2) stop("Only one cluster")
    ## theta = (mu1, ..., muk, Sigma1, ..., Sigmak, p1, ..., p[k-1])
    if (is.vector(data)) data <- matrix(data,ncol=1)
    if (is.data.frame(data)) data <- as.matrix(data)
    D <- ncol(data)
    if (n.start>1) extra <- FALSE

    ## missing() is only reliable before the argument is assigned, and theta
    ## is overwritten below. Record it (and any user-supplied value) now so
    ## every restart is initialised independently.
    theta.missing <- missing(theta)
    if (!theta.missing) theta.start <- theta

    logllmax <- -Inf
    for (ii in seq(n.start)) {
      if (ii>1) mu <- NULL ## user-supplied centres are only used for the first start
      if (theta.missing) {
          if (!is.null(mu)) {
              mus <- mu
          } else {
              if (is.function(init)) {
                  idx <- init(data, k)
              } else if (!exists(init)) { ## Random select centres
                  idx <- sample(NROW(data),k)
              } else {
                  idx <- do.call(init, list(data, k))
              }
              ## Concatenate the selected rows: (centre 1, ..., centre k)
              mus <- unlist(lapply(idx, function(i) cbind(data)[i,,drop=TRUE]))
          }
          ## All components start from the empirical covariance and equal weights
          Sigmas <- rep(as.vector(cov(data)),k)
          ps <- rep(1/k,k-1)
          theta <- c(mus,Sigmas,ps)
      } else {
          theta <- theta.start
      }

      ## Reset the iteration counter and convergence criterion for every
      ## restart (previously set once, so restarts after the first never iterated).
      i <- 0
      E <- tol
      theta0 <- theta
      if (!silent)
          cat(i,":\t", paste(formatC(theta0),collapse=" "),"\n")
      thetas <- members <- c()
      while ((i<steps) && (E>=tol)) {
          if (extra)
              thetas <- rbind(thetas, theta)
          pp <- toPar(theta,D,k)
          mus <- pp$mu; Sigmas <- pp$Sigma; ps <- pp$p
          ## E(xpectation step): log(p_j) + log phi_j(y_i), one column per component
          lgammas <- c()
          for (j in 1:k) {
              C <- matrix(Sigmas[j,],ncol=D); diag(C) <- diag(C)+lambda ## Assure C is not singular
              lgammas <- cbind(lgammas,
                               log(ps[j]) + lava::dmvn0(data,mus[j,],C,log=TRUE))
          }
          ## Log-sum-exp: subtract the row-wise maximum before exponentiating
          ## so observations far from every centre do not underflow to 0/0
          llmax <- apply(lgammas,1,max)
          gammas <- exp(lgammas-llmax)
          denom <- rowSums(gammas)
          gammas <- gammas/denom # Posterior (each row sums to one)
          ## Observed-data log-likelihood at the current parameters
          logll <- sum(llmax + log(denom))
          sqrtgammas <- sqrt(gammas)
          ## M(aximization step): posterior-weighted means, covariances and weights
          mus.new <- c()
          Sigmas.new <- c()
          for (j in 1:k) {
              mu.new <- colSums(gammas[,j]*data)/sum(gammas[,j])
              mus.new <- rbind(mus.new, mu.new)
              wcy <- sqrtgammas[,j]*t(t(data)-mus.new[j,])
              Sigma.new <- t(wcy)%*%wcy/sum(gammas[,j])
              Sigmas.new <- rbind(Sigmas.new, as.vector(Sigma.new))
          }
          ps.new <- colMeans(gammas)
          theta.old <- theta
          if (extra)
              members <- cbind(members,
                               apply(gammas,1,function(x) order(x,decreasing=TRUE)[1]))
          theta <- toTheta(mus.new,Sigmas.new,ps.new)
          E <- sum((theta-theta.old)^2)
          i <- i+1
          if (!silent)
              cat(i,":\t", paste(formatC(theta),collapse=" "),
                  ",\t\te=",formatC(E), "\n",sep="")
      }
      iter <- i
      if (n.start>1) {
          ## Keep the restart with the largest log-likelihood (NaN never wins
          ## over a finite value; the first run is always kept as a fallback)
          better <- !is.na(logll) && (is.na(logllmax) || logll>logllmax)
          if (ii==1 || better) {
              logllmax <- logll
              theta.keep <- theta
              gammas.keep <- gammas
              E.keep <- E
              iter.keep <- iter
          }
      }
    }
    if (n.start>1) {
        theta <- theta.keep
        gammas <- gammas.keep
        E <- E.keep
        iter <- iter.keep
    }

    myvars <- colnames(data)
    if (is.null(myvars)) myvars <- colnames(data) <- paste("y",1:NCOL(data),sep="")
    data <- as.data.frame(data)
    ## One saturated lvm (free means and pairwise covariances) per component
    m <- lvm(myvars,silent=TRUE); m <- covariance(m,myvars,pairwise=TRUE)
    models <- datas <- c()
    for (i in 1:k) {
        models <- c(models, list(m))
        datas <- c(datas, list(data))
    }

    membership <- apply(gammas,1,function(x) order(x,decreasing=TRUE)[1])
    res <- list(pars=theta, thetas=thetas , gammas=gammas, member=membership,
                members=members, k=k, D=D, data=data, E=E,
                prob=rbind(colMeans(gammas)),
              iter=iter,
              models=models,
              multigroup=multigroup(models,datas)
              )
    class(res) <- c("mvn.mixture","lvm.mixture")

    ## Positions of each component's means and (co)variances in res$theta
    parpos <- c()
    npar1 <- D+D*(D-1)/2
    for (i in 1:k)
        parpos <- c(parpos, list(c(seq_len(D)+(i-1)*D, k*D + seq_len(npar1)+
                                                       (i-1)*(npar1))))

    theta <- c(unlist(lapply(getMeanVar(res),function(x) x$mean)),
             unlist(lapply(getMeanVar(res),function(x) c(diag(x$var),unlist(x$var[upper.tri(x$var)])))))
    res$theta <- rbind(theta)
    res$parpos <- parpos
    res$opt <- list(estimate=theta)
    if (requireNamespace('mets', quietly=TRUE))
        res$vcov <- solve(information(res,type="E"))
    return(res)
}


##' @export
print.mvn.mixture <- function(x,...) {
  ## Print, for every cluster, its mixture probability, centre and
  ## covariance matrix. The unpacked parameter list (toPar) is returned
  ## invisibly.
  est <- toPar(x$pars, x$D, x$k)
  indent <- strrep(" ", 12)
  show.cluster <- function(j) {
    cat("Cluster ", j, " (p=", formatC(est$p[j]), "):\n", sep="")
    cat(rep("-", 50), "\n", sep="")
    centre <- paste(formatC(est$mu[j, ]), collapse=" ")
    cat("\tcenter = \n ", indent, centre, sep="")
    cat("\n\tvariance = \t")
    V <- matrix(formatC(est$Sigma[j, ], flag=" "), ncol=x$D)
    dimnames(V) <- list(rep(indent, x$D), rep("", x$D))
    print(V, quote=FALSE)
    cat("\n")
  }
  for (j in 1:x$k) show.cluster(j)
  invisible(est)
}

##' @export
plot.mvn.mixture <- function(x, label=2,iter,col,alpha=0.5,nonpar=TRUE,...) {
  ## Plot a fitted mvn.mixture object.
  ##  D==1: kernel density of the data (an empty frame if nonpar=FALSE),
  ##        a rug of the observations (coloured by cluster unless label is
  ##        NULL) and the fitted mixture density;
  ##  D==2: one ellipse per component (package 'ellipse') and the points
  ##        coloured by cluster;
  ##  D==3: 3D scatter and ellipsoids (package 'rgl').
  ## If `iter` is given, component parameters are taken from that EM
  ## iteration (requires a fit with extra=TRUE).
  opts <- list(...)
  if (missing(col)) col <- 1:x$k
  ## [[ ]] avoids partial matching (e.g. `cex` picking up `cex.axis`)
  lwd <- opts[["lwd"]]; if (is.null(lwd)) lwd <- 2
  cex <- opts[["cex"]]; if (is.null(cex)) cex <- 0.9
  y <- as.matrix(x$data)
  if (is.vector(y)) y <- matrix(y,ncol=1)
  pp <- getMeanVar(x,iter=iter)
  D <- ncol(y)
  probs <- colSums(x$gammas)/nrow(x$gammas)

  if (D==1) {
    if (nonpar)
      plot(density(as.vector(y)), main="", ...)
    else
      plot(density(y), main="", type="n", col="lightgray", ...)
    if (!is.null(label)) {
      for (i in 1:x$k) {
        rug(y[x$member==i], col=col[i])
      }
    }
    else
      rug(y)
    cc <- par("usr")
    mycurve <- function(xx) {
      a <- 0
      for (i in 1:(x$k))
        a <- a+probs[i]*dnorm(xx,pp[[i]]$mean,sqrt(pp[[i]]$var[1]))
      a
    }
    ## `lwd` may already be in `...`: swallow it there so that curve() (and
    ## the plot.xy() it calls) does not receive the argument twice.
    draw.curve <- function(..., lwd, curve.lwd)
      curve(mycurve, from=cc[1], to=cc[2], add=TRUE, lwd=curve.lwd, ...)
    draw.curve(..., curve.lwd=lwd)
  }
  if (D==2) {
    if (!requireNamespace("ellipse")) stop("ellipse required")
    plot(y, type="n", ...)

    for (i in 1:x$k) {
      C1 <- with(pp[[i]], ellipse::ellipse(var, centre=mean))
      lines(C1, col=col[i], lwd=lwd)
    }

    if (!is.null(label)) {
      for (i in 1:x$k) {
        ## drop=FALSE: a single-member cluster must stay a 1 x 2 matrix,
        ## otherwise points() would plot index against value
        if (label==1 || missing(iter)) {
          pot <- y[which(x$member==i),,drop=FALSE]
        } else {
          pot <- y[which(x$members[,iter]==i),,drop=FALSE]
        }
        points(pot, cex=cex, pch=16, col=Col(col[i],alpha))
      }
    }
    else
      points(y, cex=cex)
  }
  if (D==3) {
    if (!requireNamespace("rgl")) stop("rgl required")
    rgl::plot3d(y, type="n", box=FALSE)
    for (i in 1:x$k) {
      pot <- y[which(x$member==i),,drop=FALSE]
      rgl::plot3d(pot, type="s", radius=0.1, col=col[i], add=TRUE)
      ee <- rgl::ellipse3d(pp[[i]]$var,centre=pp[[i]]$mean)
      rgl::plot3d(ee, col=col[i], alpha=alpha, add = TRUE)
    }
  }
  if (D>3)
    warning("plot only implemented for 1-, 2- and 3-dimensional data")
}

##' @export
sim.mvn.mixture <- function(x,n,...) {
    ## Simulate n observations from a fitted mvn.mixture. Component sizes are
    ## drawn from a multinomial with the estimated mixture probabilities, then
    ## each component's observations are drawn with mets::rmvn. Rows are
    ## ordered by component; a matrix is returned.
    pars <- getMeanVar(x)
    K <- length(pars)
    ## Mixture probabilities come from the EM parameter vector (x$pars), as in
    ## print.mvn.mixture. x$theta / coef(x) hold the means and variances and
    ## are not guaranteed to end with the probabilities.
    p <- toPar(x$pars, x$D, x$k)$p
    ng <- as.vector(rmultinom(1,n,p))
    draws <- lapply(seq_len(K), function(i)
        mets::rmvn(ng[i],pars[[i]]$mean,pars[[i]]$var))
    do.call(rbind, draws)
}

Try the lava package in your browser

Any scripts or data that you put into this service are public.

lava documentation built on Nov. 5, 2023, 1:10 a.m.