location_scatter_bary: Computes the 2-Wasserstein barycenter of location-scatter...

View source: R/location_scatter_bary.R

Description

This function computes the 2-Wasserstein barycenter of N measures from a location-scatter family, where each measure is specified by a mean vector and a covariance matrix. In particular, it can be used to compute the barycenter of Gaussian distributions.
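
For orientation, the barycenter of such a family can be characterized as follows. This is a sketch based on the fixed-point approach of Álvarez-Esteban et al. (2016), under the assumption that all N measures are weighted equally (the function signature exposes no weights argument). The symbols m_i, Σ_i, S and k are notation introduced here, not names used by the package.

```latex
% Barycenter mean: the average of the input means (uniform weights assumed).
\bar{m} = \frac{1}{N} \sum_{i=1}^{N} m_i

% Barycenter covariance S: a positive definite solution, when one exists, of
S = \frac{1}{N} \sum_{i=1}^{N} \left( S^{1/2} \Sigma_i S^{1/2} \right)^{1/2}

% Fixed-point iteration, started from a positive definite S_0:
S_{k+1} = S_k^{-1/2}
          \left( \frac{1}{N} \sum_{i=1}^{N} \left( S_k^{1/2} \Sigma_i S_k^{1/2} \right)^{1/2} \right)^{2}
          S_k^{-1/2}
```

In one dimension this reduces to averaging the standard deviations: the barycenter of N(5, 1) and N(15, 4) has mean 10 and standard deviation (1 + 2) / 2 = 1.5.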

Usage

location_scatter_bary(
  means,
  cov,
  thresh = 10^(-5),
  maxiter = 100,
  showIter = FALSE
)

Arguments

means

A list of mean vectors of the elements of the location-scatter family.

cov

A list of positive semidefinite covariance matrices of the elements of the location-scatter family.

thresh

A real number specifying the threshold for terminating the iterative algorithm.

maxiter

An integer specifying the maximum number of iterations; the algorithm terminates after this many iterations even if the specified threshold has not been reached.

showIter

A logical value specifying whether the number of iterations performed should be shown at the end.

Value

A list of two elements. The first element, "mean", gives the mean vector of the barycenter measure. The second element, "cov", gives the covariance matrix of the barycenter measure.
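
To illustrate how thresh and maxiter could interact in an iterative scheme of this kind, here is a minimal base-R sketch of a fixed-point iteration for the barycenter covariance, in the spirit of Álvarez-Esteban et al. (2016). It is not the package's implementation: the helpers sqrtm_psd and bary_cov_sketch are hypothetical names, uniform weights are assumed, and the stopping rule (Frobenius norm of the change between iterates) is an assumption as well.

```r
# Hypothetical helper: square root of a symmetric positive semidefinite
# matrix via its eigendecomposition (base R only).
sqrtm_psd <- function(A) {
  e <- eigen((A + t(A)) / 2, symmetric = TRUE)
  vals <- sqrt(pmax(e$values, 0))  # clip tiny negative eigenvalues
  e$vectors %*% diag(vals, nrow = length(vals)) %*% t(e$vectors)
}

# Hypothetical sketch, NOT part of WSGeometry: fixed-point iteration for the
# barycenter covariance with uniform weights (an assumption).
bary_cov_sketch <- function(covs, thresh = 1e-5, maxiter = 100) {
  S <- diag(nrow(covs[[1]]))  # start from the identity matrix
  for (it in seq_len(maxiter)) {
    R <- sqrtm_psd(S)
    Rinv <- solve(R)
    # Average of the square roots of R %*% C %*% R over all inputs
    M <- Reduce(`+`, lapply(covs, function(C) sqrtm_psd(R %*% C %*% R))) /
      length(covs)
    S_new <- Rinv %*% M %*% M %*% Rinv
    # Assumed stopping rule: change between iterates below thresh
    converged <- norm(S_new - S, "F") < thresh
    S <- S_new
    if (converged) break
  }
  S
}

# One-dimensional check with variances 1 and 4
bary_cov_sketch(list(matrix(1), matrix(4)))
```

In the one-dimensional case the result is approximately 2.25, the square of the averaged standard deviations (1 + 2) / 2. The barycenter mean needs no iteration: under uniform weights it is the average of the input means.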

References

PC Álvarez-Esteban, E del Barrio, JA Cuesta-Albertos, and C Matrán (2016). A fixed-point approach to barycenters in Wasserstein space. J. Math. Anal. Appl., 441(2):744–762.

Y Zemel and VM Panaretos (2019). Fréchet Means and Procrustes Analysis in Wasserstein Space. Bernoulli, 25(2):932–976.

Examples

# One-dimensional example
mean.list <- list(5, 15)
var.list <- list(1, 4)
res <- location_scatter_bary(mean.list, var.list)
x <- seq(0, 22, 10^-4)
y1 <- dnorm(x, mean = 5, sd = 1)
y2 <- dnorm(x, mean = 15, sd = 2)
y3 <- dnorm(x, mean = res$mean, sd = sqrt(res$cov))
plot(x, y1, type = "l", main = "Barycenter of two 1-d Gaussian distributions",
     ylab = "density", col = "blue")
lines(x, y2, col = "green")
lines(x, y3, col = "red")
legend(15, 0.4, legend = c("N(5,1)", "N(15,4)", "Barycenter"),
       col = c("blue", "green", "red"), lty = 1, cex = 0.9)

# Two-dimensional example
# The packages graphics and mvtnorm are required to run this example.
set.seed(2898581)
mean.list <- list(c(0,0), c(0,0), c(0,0))
COV <- 0.3 + rWishart(3, df = 2, Sigma = diag(2))
cov.list <- list(COV[,, 1], COV[,, 2], COV[,, 3])
res<-location_scatter_bary(mean.list, cov.list)

x <- y <- seq(-3, 3, .1)
z <- array(0.0, dim = c(length(x), length(y), 4))
for(i in seq_along(x))
 for(j in seq_along(y))
 {
   for(n in 1:3)
     z[i, j, n] <- mvtnorm::dmvnorm(c(x[i], y[j]), sigma = COV[, , n])
   
   z[i, j, 4] <- mvtnorm::dmvnorm(c(x[i], y[j]), sigma = res$cov)
 }

op <- par(mfrow = c(2, 2),  mai = c(0, 0, 0, 0))
for(n in 1:3)
{
 graphics::persp(x, y, z[, , n], theta = 30, phi = 30, expand = 0.5, col = "lightblue",
  zlab = "", ticktype = "detailed", shade = .75, lphi = 45, ltheta = 135)
 text(x = 0, y = 0.2, labels = paste("COV[,, ", n, "]", sep = ""))
}

graphics::persp(x, y, z[, , 4], theta = 30, phi = 30, expand = 0.5, col = "red", zlab = "",
ticktype = "detailed", shade = .75, lphi = 45, ltheta = 135)
text(x = 0, y = 0.2, labels = "barycenter")
par(op)

WSGeometry documentation built on Dec. 15, 2021, 1:08 a.m.