em_solve: EM algorithm

View source: R/stats.R


EM algorithm

Description

Run the EM (expectation-maximization) algorithm for a user-specified model: starting from initial parameter values, alternate the user-provided E and M steps until the observed-data log-likelihood converges or the maximum number of iterations is reached.
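
In outline, each iteration presumably proceeds as follows (a minimal sketch inferred from the arguments below, not the function's actual source):

  out.Estep <- em_Estep(data, params)
  params <- em_Mstep(data, params, out.Estep)
  logliks[i + 1] <- em_loglik(data, params)
  if(abs(logliks[i + 1] - logliks[i]) < thresh.cvg)
    break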

Usage

em_solve(
  data,
  params,
  em_Estep,
  em_Mstep,
  em_loglik,
  thresh.cvg = 10^(-3),
  nb.iters = 10^3,
  logliks = rep(NA, 1 + nb.iters),
  print_position = stdout()
)

Arguments

data

object containing the data; it is passed as-is to em_Estep, em_Mstep and em_loglik

params

list with the initial values of the parameters

em_Estep

function implementing the E step taking data and params as inputs; can be parallelized

em_Mstep

function implementing the M step taking data, params and the output of em_Estep as inputs; can be parallelized

em_loglik

function taking data and params as inputs, and returning the value of the observed-data log-likelihood as a numeric; can be parallelized

thresh.cvg

threshold on the absolute difference between the observed-data log-likelihoods of two successive iterations; convergence is reached when this difference falls below the threshold

nb.iters

maximum number of iterations

logliks

vector of length 1 + nb.iters in which the observed-data log-likelihood will be recorded, at initialization and then after each iteration

print_position

if not NULL, should be a connection to which a debugging log will be printed

Value

list with two components: params, the MLEs of the parameters, and logliks, the per-iteration values of the observed-data log-likelihood
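
For instance, the examples below access them as fit$params$means and fit$logliks.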

Author(s)

Timothee Flutre

Examples

## Not run:
## I. example of the EM algorithm for the univariate Gaussian mixture

## I.1. simulate some data
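## (simulDat returns shuffled draws from K univariate Gaussians whose means
## are "gap" apart, along with the true memberships and parameters)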
simulDat <- function(K=2, N=100, gap=6){
  means <- seq(0, gap*(K-1), gap)
  stdevs <- runif(n=K, min=0.5, max=1.5)
  tmp <- floor(rnorm(n=K-1, mean=floor(N/K), sd=5))
  ns <- c(tmp, N - sum(tmp))
  memberships <- as.factor(matrix(unlist(lapply(1:K, function(k){rep(k, ns[k])})),
                           ncol=1))
  data <- matrix(unlist(lapply(1:K, function(k){
    rnorm(n=ns[k], mean=means[k], sd=stdevs[k])
  })))
  new.order <- sample(1:N, N)
  data <- data[new.order]
  rownames(data) <- NULL
  memberships <- memberships[new.order]
  return(list(data=data, memberships=memberships,
              means=means, stdevs=stdevs, weights=ns/N))
}
set.seed(1859)
K <- 3
N <- 300
simul <- simulDat(K, N)
simul$means
simul$stdevs
simul$weights

## I.2. visualize the data
hist(simul$data, breaks=30, freq=FALSE, col="grey", border="white",
     main="Simulated data from univariate Gaussian mixture",
     ylab="", xlab="data", las=1,
     xlim=c(-4,15), ylim=c(0,0.28))

## I.3. define functions required to run the EM algorithm

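## function computing the observed-data log-likelihood of the mixture,
## i.e. sum_i log(sum_k w_k N(x_i | mu_k, sigma_k^2))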
loglik <- function(data, params){
  sum(sapply(data, function(datum){
    log(sum(unlist(Map(function(mu, sigma, weight){
      weight * dnorm(x=datum, mean=mu, sd=sigma)
    }, params$means, params$stdevs, params$weights))))
  }))
}
loglik(simul$data, simul[-c(1,2)])

## function performing the E step
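## i.e. computing the N x K matrix of posterior membership probabilities
## ("responsibilities") of each datum for each component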
stepE <- function(data, params){
  N <- length(data)
  K <- length(params$means)
  tmp <- matrix(unlist(lapply(data, function(datum){
    norm.const <- sum(unlist(Map(function(mu, sigma, weight){
      weight * dnorm(x=datum, mean=mu, sd=sigma)
    }, params$means, params$stdevs, params$weights)))
    unlist(Map(function(mu, sigma, weight){
        weight * dnorm(x=datum, mean=mu, sd=sigma) / norm.const
      }, params$means[-K], params$stdevs[-K], params$weights[-K]))
  })), ncol=K-1, byrow=TRUE)
  membership.probas <- cbind(tmp, apply(tmp, 1, function(x){1 - sum(x)}))
  names(membership.probas) <- NULL
  return(membership.probas)
}
head(mb.pr <- stepE(simul$data, simul[-c(1,2)]))

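## function performing the M step, i.e. computing the updated MLEs of the
## means, standard deviations and weights from the membership probabilities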
stepM <- function(data, params, out.stepE){
  N <- length(data)
  K <- length(params$means)
  sum.membership.probas <- apply(out.stepE, 2, sum)
  ## MLEs of the means
  new.means <- sapply(1:K, function(k){
    sum(unlist(Map("*", out.stepE[,k], data))) /
    sum.membership.probas[k]
  })
  ## MLEs of the standard deviations
  new.stdevs <- sapply(1:K, function(k){
      sqrt(sum(unlist(Map(function(p.ki, x.i){
      p.ki * (x.i - new.means[k])^2
    }, out.stepE[,k], data))) /
    sum.membership.probas[k])
  })
  ## MLEs of the weights
  new.weights <- sapply(1:K, function(k){
    1/N * sum.membership.probas[k]
  })
  return(list(means=new.means, stdevs=new.stdevs, weights=new.weights))
}
stepM(simul$data, simul[-c(1,2)], mb.pr)

## I.4. run the EM algorithm
params0 <- list(means=runif(n=K, min=min(simul$data), max=max(simul$data)),
                stdevs=rep(1, K),
                weights=rep(1/K, K))
fit <- em_solve(data=simul$data, params=params0,
                em_Estep=stepE, em_Mstep=stepM, em_loglik=loglik)

## I.5. plot the log likelihood per iteration
plot(fit$logliks, xlab="iterations", ylab="observed log-likelihood",
     main="Convergence of the EM algorithm", type="b")

## I.6. plot the data along with the inferred density
hist(simul$data, breaks=30, freq=FALSE, col="grey", border="white",
     main="Simulated data from univariate Gaussian mixture",
     ylab="", xlab="data", las=1,
     xlim=c(-4,15), ylim=c(0,0.28))
rx <- seq(from=min(simul$data), to=max(simul$data), by=0.1)
ds <- lapply(1:K, function(k){dnorm(x=rx, mean=fit$params$means[k], sd=fit$params$stdevs[k])})
f <- sapply(seq_along(rx), function(i){
  sum(sapply(1:K, function(k){fit$params$weights[k] * ds[[k]][i]}))
})
lines(rx, f, col="red", lwd=2)

## I.7. look at the classification of the data
mb.pr <- stepE(simul$data, fit$params)
memberships <- apply(mb.pr, 1, function(x){which(x > 0.8)})
table(memberships)
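
## a hard MAP classification always assigns each datum, unlike the 0.8
## threshold above; a sketch ("mb.map" is an ad hoc name), where the component
## labels may be permuted relative to the simulated ones:
mb.map <- apply(mb.pr, 1, which.max)
table(mb.map, simul$memberships)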

## II. example of the EM algorithm for the univariate linear mixed model
## y = X beta + Z u + epsilon
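## with u ~ N_Q(0, var.u * I_Q) and epsilon ~ N_N(0, var.epsilon * I_N)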

## II.1. simulate some data
simulDat <- function(I=4, Q=100, mu=50, min.y=20, cv.u=0.15, h2=0.5){
  N <- I * Q
  var.y <- ((mu - min.y) / 3)^2
  var.u <- (cv.u * mu)^2
  var.epsilon <- ((1 - h2) / h2) * var.u
  data <- data.frame(lev.u=rep(paste0("u", 1:Q), I),
                     fact=rep(paste0("fact", 1:I), each=Q),
                     resp=NA)
  X <- model.matrix(~ 1 + fact, data=data)
  sd.beta <- sqrt(var.y) # e.g., reuse the overall spread of y
  beta <- c(mu, rnorm(n=I-1, mean=0, sd=sd.beta))
  Z <- model.matrix(~ -1 + lev.u, data=data)
  G <- var.u * diag(Q)
  u <- MASS::mvrnorm(n=1, mu=rep(0,Q), Sigma=G)
  R <- var.epsilon * diag(N)
  epsilon <- MASS::mvrnorm(n=1, mu=rep(0,N), Sigma=R)
  y <- X %*% beta + Z %*% u + epsilon
  data$resp <- y
  return(list(data=data, X=X, beta=beta, Z=Z, u=u,
              var.u=var.u, var.epsilon=var.epsilon))
}
set.seed(1859)
I <- 4
Q <- 100
simul <- simulDat(I, Q)
simul$var.u
simul$var.epsilon

## II.2. visualize the data
hist(simul$data$resp, breaks="FD", col="grey", border="white",
     main="Simulated data from univariate linear mixed model",
     ylab="", xlab="response", las=1)

## II.3. define functions required to run the EM algorithm
## TODO

loglik <- function(data, params){
}

stepE <- function(data, params){
}

stepM <- function(data, params, out.stepE){
}
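
## one possible implementation (a sketch, not the package author's): treat u
## as the missing data and pass the data as a list with components y, X and Z;
## the ".lmm" names are ad hoc
loglik.lmm <- function(data, params){
  N <- nrow(data$X)
  ## log-density of y ~ N(X beta, var.u * Z Z' + var.e * I_N)
  V <- params$var.u * tcrossprod(data$Z) + params$var.e * diag(N)
  r <- drop(data$y - data$X %*% params$beta)
  -0.5 * (N * log(2*pi) + as.numeric(determinant(V)$modulus) +
          drop(crossprod(r, solve(V, r))))
}

stepE.lmm <- function(data, params){
  Q <- ncol(data$Z)
  ## conditional distribution of u given y is N(u.hat, C.inv)
  C <- crossprod(data$Z) / params$var.e + diag(Q) / params$var.u
  C.inv <- solve(C)
  u.hat <- drop(C.inv %*% crossprod(data$Z, data$y - data$X %*% params$beta) /
                params$var.e)
  return(list(u.hat=u.hat, C.inv=C.inv))
}

stepM.lmm <- function(data, params, out.stepE){
  N <- nrow(data$X)
  Q <- ncol(data$Z)
  ## MLE of beta given E[u | y]
  new.beta <- drop(solve(crossprod(data$X),
                         crossprod(data$X, data$y - data$Z %*% out.stepE$u.hat)))
  resid <- drop(data$y - data$X %*% new.beta - data$Z %*% out.stepE$u.hat)
  ## E[u' u | y] = u.hat' u.hat + tr(C.inv)
  new.var.u <- (sum(out.stepE$u.hat^2) + sum(diag(out.stepE$C.inv))) / Q
  ## E[epsilon' epsilon | y] = resid' resid + tr(Z C.inv Z')
  new.var.e <- (sum(resid^2) + sum(out.stepE$C.inv * crossprod(data$Z))) / N
  return(list(beta=new.beta, var.u=new.var.u, var.e=new.var.e))
}

## e.g.:
data.lmm <- list(y=drop(simul$data$resp), X=simul$X, Z=simul$Z)
params0.lmm <- list(beta=c(mean(data.lmm$y), rep(0, ncol(data.lmm$X) - 1)),
                    var.u=1, var.e=1)
fit.lmm <- em_solve(data=data.lmm, params=params0.lmm, em_Estep=stepE.lmm,
                    em_Mstep=stepM.lmm, em_loglik=loglik.lmm)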

## II.4. run the EM algorithm (once the TODO functions above are implemented)
params0 <- list()
fit <- em_solve(data=simul$data, params=params0,
                em_Estep=stepE, em_Mstep=stepM, em_loglik=loglik)

## II.5. plot the log likelihood per iteration
plot(fit$logliks, xlab="iterations", ylab="observed log-likelihood",
     main="Convergence of the EM algorithm", type="b")

## III. other possible examples of the EM algorithm
## * factor analysis
## * hidden Markov models
## * multivariate Student distribution
## * robust regression models
## * censored/truncated models
## * ...

## End(Not run)
