This package implements the EM (expectation-maximization) algorithm for Gaussian Mixture Models. For details of the algorithm, see the Wikipedia page on the EM algorithm and GMM model (https://en.wikipedia.org/wiki/EM_algorithm_and_GMM_model). The package fits a k-component Gaussian Mixture Model to data, where k is specified by the user. Initial points (starting values for the component means) must also be supplied as input. Any points can be used, but centers produced by running the k-means algorithm on the same data are highly recommended. The package depends on mvtnorm; ClusterR is additionally needed for the vignettes.
You can install this package using the following commands.
install.packages('devtools')
devtools::install_github('graysonma/biostat625hw4.GaussianMixtureModel', build_vignettes = TRUE)
Then, load the package with
library('biostat625hw4.GaussianMixtureModel')
The main user-facing function is 'GaussianMixtureModel'. It takes an N by p data matrix, the number of components k, and a k by p matrix of initial points as input arguments. You may also set the maximum number of EM iterations and the convergence criterion. For details of the arguments and return values, see the help page after loading the package.
library('biostat625hw4.GaussianMixtureModel')
?GaussianMixtureModel
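The k by p matrix of initial points can be built from a k-means fit, as recommended above. The following is a minimal sketch that uses base R's `kmeans` (from the stats package) rather than ClusterR, which the vignettes rely on. The simulated data, the seed, and the choice of `nstart = 10` are assumptions made for illustration; the argument name `initial_mu` is taken from the example in this README.

```r
# Simulate N = 500 points in p = 2 dimensions from two Gaussian clusters
# (illustrative data only; the seed is arbitrary)
set.seed(625)
N <- 500
z <- sample(2, N, replace = TRUE)
# each row gets mean 0 (cluster 1) or mean 5 (cluster 2) in both columns
X <- matrix(rnorm(N * 2, mean = ifelse(z == 1, 0, 5), sd = 1), nrow = N, ncol = 2)

# Run k-means on the same data and use its centers as the initial points
k <- 2
km <- kmeans(X, centers = k, nstart = 10)
initial_mu <- km$centers   # k by p matrix, one row per component

# With the package loaded, the centers can then be passed on:
# gmm <- GaussianMixtureModel(X, k, initial_mu = initial_mu)
```

Because k-means already places the centers near the cluster means, the EM iterations usually start close to a good solution.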
Here is an example of fitting a Gaussian Mixture Model on simulated data with the function.
# create samples from a two-component Gaussian Mixture
X = matrix(0, nrow = 500, ncol = 2)
z = sample(2, 500, replace = TRUE)
X[which(z == 1), ] = rnorm(sum(z == 1) * 2, mean = 0, sd = 1)
X[which(z == 2), ] = rnorm(sum(z == 2) * 2, mean = 5, sd = 1)
# fit a Gaussian Mixture Model to the data
# (the rows of initial_mu are the initial points: (0, 0) and (1, 1))
gmm = GaussianMixtureModel(X, 2, initial_mu = matrix(c(0, 1, 0, 1), nrow = 2, ncol = 2))
# prediction
centroids = gmm$mu
r = gmm$r
cluster = apply(r, 1, which.max) # fitted cluster
acc = max(sum(cluster == z), sum(cluster != z)) / 500 # accuracy, allowing the two cluster labels to be swapped
# plot
plot(X[, 1], X[, 2], col = c('orange', 'red')[cluster], cex = 0.5)
points(x = centroids[, 1], y = centroids[, 2], pch = 10, col = 'blue', cex = 2)
legend(x = 'topleft',
       legend = c('cluster 1', 'cluster 2', 'centers'),
       col = c('orange', 'red', 'blue'),
       pch = c(1, 1, 10))
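For readers curious about what the EM iterations do, here is a rough base R sketch of the E-step and M-step for a k-component Gaussian mixture. It is not the package's implementation: the function names `dmvn` and `em_gmm_sketch`, the argument names `max_iter` and `tol`, the starting covariances, and the log-likelihood stopping rule are all assumptions made for illustration. The density is written out by hand so the sketch has no dependencies, whereas the package itself relies on mvtnorm.

```r
# Multivariate normal density evaluated at each row of X (hypothetical helper)
dmvn <- function(X, mu, Sigma) {
  p <- ncol(X)
  centered <- sweep(X, 2, mu)                               # subtract the mean from every row
  maha <- rowSums((centered %*% solve(Sigma)) * centered)   # squared Mahalanobis distances
  exp(-0.5 * maha) / sqrt((2 * pi)^p * det(Sigma))
}

# Illustrative EM loop (hypothetical function, not the package's code)
em_gmm_sketch <- function(X, k, initial_mu, max_iter = 100, tol = 1e-6) {
  n <- nrow(X)
  mu <- initial_mu
  Sigma <- replicate(k, cov(X), simplify = FALSE)  # assumption: start at the sample covariance
  w <- rep(1 / k, k)                               # assumption: equal mixing weights
  loglik_old <- -Inf
  for (iter in seq_len(max_iter)) {
    # E-step: responsibility of component j for observation i
    dens <- sapply(seq_len(k), function(j) w[j] * dmvn(X, mu[j, ], Sigma[[j]]))
    r <- dens / rowSums(dens)
    # log-likelihood under the parameters used in this E-step
    loglik <- sum(log(rowSums(dens)))
    # M-step: responsibility-weighted weights, means, and covariances
    nk <- colSums(r)
    w <- nk / n
    for (j in seq_len(k)) {
      mu[j, ] <- colSums(r[, j] * X) / nk[j]
      centered <- sweep(X, 2, mu[j, ])
      Sigma[[j]] <- crossprod(centered * sqrt(r[, j])) / nk[j]
    }
    # stop once the log-likelihood no longer changes noticeably
    if (abs(loglik - loglik_old) < tol) break
    loglik_old <- loglik
  }
  list(mu = mu, Sigma = Sigma, w = w, r = r, loglik = loglik, iter = iter)
}

# Usage on data simulated like the example above (the seed is arbitrary)
set.seed(1)
z <- sample(2, 500, replace = TRUE)
X <- matrix(rnorm(1000, mean = ifelse(z == 1, 0, 5), sd = 1), nrow = 500, ncol = 2)
fit <- em_gmm_sketch(X, 2, initial_mu = rbind(c(0, 0), c(1, 1)), max_iter = 500)
cluster <- apply(fit$r, 1, which.max)   # hard assignment from the responsibilities
```

The returned list uses the names `mu` and `r` only to mirror the fields accessed in the example above; the actual return value of 'GaussianMixtureModel' is documented on its help page.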