View source: R/location_scatter_bary.R
Description
Computes the 2-Wasserstein barycenter of N probability measures from a common location-scatter family, where each measure is specified by a mean vector and a covariance matrix. In particular, it can be used to compute the barycenter of Gaussian distributions.
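For reference, the quantities involved can be written out for Gaussian measures N(m_i, Σ_i), i = 1, …, N. The formulas below assume equal weights 1/N (the function has no weights argument) and at least one positive definite Σ_i. The fixed-point equation and the iteration follow Álvarez-Esteban et al. (2016); whether location_scatter_bary iterates exactly this map is an assumption.

```latex
W_2^2\big(N(m_1,\Sigma_1),\, N(m_2,\Sigma_2)\big)
  = \lVert m_1 - m_2 \rVert^2
  + \operatorname{tr}\Big(\Sigma_1 + \Sigma_2 - 2\big(\Sigma_2^{1/2}\Sigma_1\Sigma_2^{1/2}\big)^{1/2}\Big)

\bar m = \frac{1}{N}\sum_{i=1}^{N} m_i, \qquad
\bar\Sigma = \frac{1}{N}\sum_{i=1}^{N}\big(\bar\Sigma^{1/2}\Sigma_i\bar\Sigma^{1/2}\big)^{1/2}

S_{k+1} = S_k^{-1/2}\Big(\frac{1}{N}\sum_{i=1}^{N}\big(S_k^{1/2}\Sigma_i S_k^{1/2}\big)^{1/2}\Big)^{2} S_k^{-1/2}
```

The barycenter mean is simply the average of the means; only the covariance requires iteration.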
Usage
location_scatter_bary(
  means,
  cov,
  thresh = 10^(-5),
  maxiter = 100,
  showIter = FALSE
)
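Both means and cov are plain R lists with one entry per measure. If the parameters are stored as a matrix of mean vectors (one per row) and a p x p x N array of covariances, the layout returned by rWishart, a base-R sketch like the following builds the two lists. The object names and simulated values are illustrative only.

```r
set.seed(1)
N <- 3  # number of measures
p <- 2  # dimension
# Illustrative parameters: one mean vector per row, covariances stacked in a p x p x N array
M <- matrix(rnorm(N * p), nrow = N, ncol = p)
A <- 0.3 + rWishart(N, df = p, Sigma = diag(p))

# Convert to the list form expected by the means and cov arguments
mean.list <- lapply(seq_len(N), function(k) M[k, ])
cov.list  <- lapply(seq_len(N), function(k) A[, , k])
```

The resulting lists can then be passed as location_scatter_bary(mean.list, cov.list), as in the two-dimensional example.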
Arguments
means: A list of mean vectors, one for each element of the location-scatter family.
cov: A list of positive semidefinite covariance matrices, one for each element of the location-scatter family, in the same order as means.
thresh: A positive real number giving the convergence threshold at which the iterative algorithm terminates.
maxiter: The maximum number of iterations; the algorithm stops after maxiter iterations even if thresh has not been reached.
showIter: Logical; if TRUE, the number of iterations performed is reported when the algorithm finishes.
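The roles of thresh, maxiter and showIter can be illustrated with a minimal base-R sketch of the fixed-point iteration of Álvarez-Esteban et al. (2016), assuming equal weights. It is not the package's source: the helpers sqrtm_sym and ls_bary_sketch are hypothetical, and both the starting value (the average covariance) and the convergence measure (largest absolute change between successive covariance iterates) are assumptions.

```r
# Hypothetical sketch of the fixed-point iteration of Alvarez-Esteban et al. (2016),
# equal weights assumed; not the package's own implementation.

# Square root of a symmetric positive semidefinite matrix via eigendecomposition
sqrtm_sym <- function(S) {
  e <- eigen(S, symmetric = TRUE)
  vals <- pmax(e$values, 0)  # clip tiny negative eigenvalues caused by round-off
  e$vectors %*% diag(sqrt(vals), nrow = length(vals)) %*% t(e$vectors)
}

ls_bary_sketch <- function(means, cov, thresh = 10^(-5), maxiter = 100,
                           showIter = FALSE) {
  N <- length(cov)
  covs <- lapply(cov, as.matrix)  # lets 1-d variances be passed as plain numbers
  bary_mean <- Reduce(`+`, lapply(means, as.numeric)) / N  # average of the means
  S <- Reduce(`+`, covs) / N  # assumed starting value: average covariance
  iter <- 0
  repeat {
    iter <- iter + 1
    R <- sqrtm_sym(S)
    R_inv <- solve(R)
    # K = (1/N) * sum_i (S^{1/2} Sigma_i S^{1/2})^{1/2}
    K <- Reduce(`+`, lapply(covs, function(C) sqrtm_sym(R %*% C %*% R))) / N
    S_new <- R_inv %*% K %*% K %*% R_inv
    S_new <- (S_new + t(S_new)) / 2  # remove round-off asymmetry
    delta <- max(abs(S_new - S))     # assumed convergence measure
    S <- S_new
    if (delta < thresh || iter >= maxiter) break  # thresh and maxiter end the loop
  }
  if (showIter) message("Number of iterations: ", iter)
  list(mean = bary_mean, cov = S)
}
```

For example, ls_bary_sketch(list(5, 15), list(1, 4)) returns mean 10 and a 1 x 1 covariance matrix equal to 2.25, matching the closed form for one-dimensional Gaussians.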
Value
A list with two elements: "mean", the mean vector of the barycenter, and "cov", the covariance matrix of the barycenter.
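In one dimension the return value can be checked against a closed form: with equal weights, the 2-Wasserstein barycenter of N(m_i, s_i^2) is the Gaussian whose mean is the average of the m_i and whose standard deviation is the average of the s_i. The helper gauss_bary_1d below is hypothetical and returns a list shaped like the documented value.

```r
# Hypothetical helper: closed-form 2-Wasserstein barycenter of 1-d Gaussians N(m_i, s_i^2)
# with equal weights: the mean is the average mean, the sd is the average sd.
gauss_bary_1d <- function(means, vars) {
  m <- unlist(means)
  s <- sqrt(unlist(vars))
  list(mean = mean(m), cov = mean(s)^2)
}

bary <- gauss_bary_1d(list(5, 15), list(1, 4))
bary$mean  # 10
bary$cov   # 2.25, i.e. standard deviation 1.5
```

These are the inputs of the one-dimensional example, so res$mean and res$cov there would be expected to agree with 10 and 2.25 up to the convergence tolerance.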
References
PC Álvarez-Esteban, E del Barrio, JA Cuesta-Albertos, and C Matrán (2016). A fixed-point approach to barycenters in Wasserstein space. J. Math. Anal. Appl., 441(2):744–762.
Y Zemel and VM Panaretos (2019). Fréchet Means and Procrustes Analysis in Wasserstein Space. Bernoulli, 25(2):932–976.
Examples
# One-dimensional example
mean.list<-list(5,15)
var.list<-list(1,4)
res<-location_scatter_bary(mean.list,var.list)
x<-seq(0,22,10^-4)
y1<-dnorm(x,mean=5,sd=1)
y2<-dnorm(x,mean=15,sd=2)
y3<-dnorm(x,mean=res$mean,sd=sqrt(res$cov))
plot(x,y1,type="l",main = "Barycenter of two 1-d Gaussian distributions",
ylab = "density",col="blue")
lines(x,y2,col="green")
lines(x,y3,col="red")
legend(15,0.4, legend=c("N(5,1)", "N(15,4)","Barycenter"),
col=c("blue","green","red"),lty=1,cex=0.9)
# Two-dimensional example
# The graphics and mvtnorm packages are required to run this example
set.seed(2898581)
mean.list <- list(c(0,0), c(0,0), c(0,0))
COV <- 0.3 + rWishart(3, df = 2, Sigma = diag(2))
cov.list <- list(COV[,, 1], COV[,, 2], COV[,, 3])
res<-location_scatter_bary(mean.list, cov.list)
x <- y <- seq(-3, 3, .1)
z <- array(0.0, dim = c(length(x), length(y), 4))
for (i in seq_along(x)) {
  for (j in seq_along(y)) {
    # densities of the three input Gaussians
    for (n in 1:3)
      z[i, j, n] <- mvtnorm::dmvnorm(c(x[i], y[j]), sigma = COV[, , n])
    # density of the barycenter
    z[i, j, 4] <- mvtnorm::dmvnorm(c(x[i], y[j]), sigma = res$cov)
  }
}
op <- par(mfrow = c(2, 2), mai = c(0, 0, 0, 0))
for (n in 1:3) {
  graphics::persp(x, y, z[, , n], theta = 30, phi = 30, expand = 0.5, col = "lightblue",
                  zlab = "", ticktype = "detailed", shade = .75, lphi = 45, ltheta = 135)
  text(x = 0, y = 0.2, labels = paste("COV[,, ", n, "]", sep = ""))
}
graphics::persp(x, y, z[, , 4], theta = 30, phi = 30, expand = 0.5, col = "red", zlab = "",
                ticktype = "detailed", shade = .75, lphi = 45, ltheta = 135)
text(x = 0, y = 0.2, labels = "barycenter")
par(op)
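A returned covariance can be checked against the equal-weight fixed-point equation S = (1/N) Σ_i (S^{1/2} Σ_i S^{1/2})^{1/2} by measuring how far S is from its right-hand side. The base-R sketch below assumes equal weights; fixed_point_residual and sqrtm_sym are hypothetical helpers. It is exercised on diagonal covariances, where the barycenter is known in closed form: its square root is the average of the square roots.

```r
# Square root of a symmetric positive semidefinite matrix via eigendecomposition
sqrtm_sym <- function(S) {
  e <- eigen(S, symmetric = TRUE)
  vals <- pmax(e$values, 0)  # clip tiny negative eigenvalues caused by round-off
  e$vectors %*% diag(sqrt(vals), nrow = length(vals)) %*% t(e$vectors)
}

# Hypothetical helper: largest entrywise gap between S and the right-hand side of
# S = (1/N) * sum_i (S^{1/2} Sigma_i S^{1/2})^{1/2}  (equal weights assumed)
fixed_point_residual <- function(S, cov.list) {
  R <- sqrtm_sym(S)
  rhs <- Reduce(`+`, lapply(cov.list, function(C) sqrtm_sym(R %*% C %*% R))) /
    length(cov.list)
  max(abs(S - rhs))
}

# Diagonal (commuting) covariances: the barycenter's square root is the average square root
diag.covs <- list(diag(c(1, 4)), diag(c(9, 16)))
S.true <- diag(c(4, 9))  # ((1 + 3) / 2)^2 = 4 and ((2 + 4) / 2)^2 = 9
fixed_point_residual(S.true, diag.covs)   # close to 0
fixed_point_residual(diag(2), diag.covs)  # about 2: the identity is not the barycenter
```

After running the two-dimensional example, fixed_point_residual(res$cov, cov.list) should be close to zero if the returned matrix is the equal-weight fixed point.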