wasserstein_bary: Compute Wasserstein barycenters

View source: R/barycenterMain.R

Description

This function computes the Wasserstein barycenter of a list of suitable objects and returns the barycenter in a prespecified form.
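
For orientation, the 2-Wasserstein barycenter of measures \mu_1, ..., \mu_N with weights \lambda_1, ..., \lambda_N is commonly defined as a minimiser of the Fréchet functional. The notation below is the standard textbook formulation and is an assumption here; it is not taken from the package source.

```latex
\[
\mu^\ast \in \operatorname*{arg\,min}_{\mu} F(\mu), \qquad
F(\mu) = \sum_{i=1}^{N} \lambda_i \, W_2^2(\mu, \mu_i), \qquad
\lambda_i \ge 0, \quad \sum_{i=1}^{N} \lambda_i = 1 .
\]
```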

Usage

wasserstein_bary(
  data.list,
  frechet.weights = NULL,
  method = "alternating",
  return_type = "wpp",
  supp.size = NULL,
  output.supp = NULL,
  shared = FALSE,
  sample.size = NULL,
  maxIter = 10,
  weights_maxIter = 100,
  pos_maxIter = 100,
  stepsize = 0.1,
  thresh = 10^(-12),
  regular = 10^-3,
  warmstart = TRUE,
  warmstartlength = 2,
  showIter = FALSE,
  threads = 1
)

Arguments

data.list

A list of objects whose barycenter should be computed. Each element should be one of the following: a matrix representing an image; a path to a file containing an image; a wpp-object; a pp-object; a list containing an entry named 'positions' with the support of the measure and an entry named 'weights' with the weights of the support points; or a list containing an entry named 'positions' specifying the support of a measure with uniform weights.

frechet.weights

A real vector summing to 1, specifying the weights in the Frechet functional. Should be of the same length as data.list.
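
As a small illustration of the constraint above, raw importance scores can be rescaled so that they sum to 1 before being passed as frechet.weights. The variable names are hypothetical and the sketch uses base R only.

```r
# Hypothetical raw importance scores, one per element of data.list
raw_scores <- c(2, 1, 1)

# Rescale so that the weights sum to 1
frechet_weights <- raw_scores / sum(raw_scores)

# The vector must have one entry per measure and sum to 1
stopifnot(length(frechet_weights) == length(raw_scores),
          isTRUE(all.equal(sum(frechet_weights), 1)))
```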

method

A string specifying the method to be used. This also determines which of the other parameters are used in this function call. See Details for the methods currently available.

return_type

A string specifying the format of the output. The currently available options are "default", which gives a list with entries 'positions' and 'weights'; "wpp", which gives a wpp-object; and "image_mat", which gives a matrix of the same dimensions as the input matrices (only for the "regular" method).

supp.size

A positive integer specifying the size of the support used to approximate the barycenter in the "alternating" method.

output.supp

An Mxd matrix specifying the support set on which the optimal weights of the barycenter should be approximated when method = "fixed_support". Each row of the matrix represents one support point in R^d.
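
A support matrix in this layout (one row per point, one column per coordinate) can be built with base R. The helper name make_grid_support is hypothetical and not part of the package; it simply places a regular grid on the unit square, which is an assumption about the scale of the data.

```r
# Hypothetical helper (not part of the package): regular grid on [0, 1]^2
make_grid_support <- function(n_x, n_y) {
  grid <- expand.grid(x = seq(0, 1, length.out = n_x),
                      y = seq(0, 1, length.out = n_y))
  # One support point per row, so M = n_x * n_y and d = 2
  unname(as.matrix(grid))
}

support <- make_grid_support(4, 3)
dim(support)
```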

shared

A boolean flag specifying whether all measures have the same support set and the weights of the barycenter should be optimised over this set as well.

sample.size

A positive integer specifying the number of samples drawn in the stochastic approximation of the barycenter for method "sampling".
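
The resampling idea behind the "sampling" method can be sketched in base R: draw support points with probability given by the weights and give each draw equal mass. This is only an illustration with hypothetical variable names, not the package's implementation.

```r
set.seed(1)

# Toy weighted measure: 5 support points in R^2 with weights summing to 1
positions <- matrix(runif(10), nrow = 5, ncol = 2)
weights <- c(0.4, 0.3, 0.1, 0.1, 0.1)
sample_size <- 8  # plays the role of sample.size

# Draw support indices with probability given by the weights
idx <- sample(nrow(positions), size = sample_size, replace = TRUE, prob = weights)

# Empirical measure: sampled positions, each carrying mass 1 / sample_size
empirical <- list(positions = positions[idx, , drop = FALSE],
                  weights = rep(1 / sample_size, sample_size))
```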

maxIter

A positive integer specifying the maximum number of "outer" iterations. The total number of iteration steps performed can reach up to maxIter * (weights_maxIter + pos_maxIter).
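
With the defaults listed under Usage, the formula above evaluates to 2000 steps, which the following lines reproduce:

```r
# Default values from the Usage section
maxIter <- 10
weights_maxIter <- 100
pos_maxIter <- 100

# Iteration budget according to the formula above
max_steps <- maxIter * (weights_maxIter + pos_maxIter)
max_steps
```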

weights_maxIter

A positive integer specifying the maximum number of iterations on the weights of the barycenter.

pos_maxIter

A positive integer specifying the maximum number of iterations on the support of the barycenter.

stepsize

A positive number specifying the stepsize in the position iterations.
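
To give a sense of what a position step with a stepsize can look like, the sketch below follows the barycentric-projection update described by Cuturi & Doucet (2014): each support point moves a fraction theta of the way towards the weighted average of the data points it is coupled with. The transport plans are supplied by hand here, all names are hypothetical, and the package's actual update may differ.

```r
# X: current support (m x d); a: barycenter weights (length m)
# Y_list: data supports (n_k x d); plan_list: transport plans (m x n_k)
# lambda: Frechet weights; theta: stepsize in (0, 1]
position_step <- function(X, a, Y_list, plan_list, lambda, theta) {
  target <- matrix(0, nrow(X), ncol(X))
  for (k in seq_along(Y_list)) {
    # Row i becomes the plan-weighted mean of the points of measure k
    target <- target + lambda[k] * (plan_list[[k]] %*% Y_list[[k]]) / a
  }
  # Damped move from the current positions towards the target
  (1 - theta) * X + theta * target
}

# One barycenter point at the origin, coupled with two toy measures
X <- matrix(c(0, 0), nrow = 1)
Y_list <- list(rbind(c(0, 0), c(2, 0)), rbind(c(0, 2)))
plan_list <- list(matrix(c(0.5, 0.5), nrow = 1), matrix(1, nrow = 1))
X_new <- position_step(X, a = 1, Y_list, plan_list,
                       lambda = c(0.5, 0.5), theta = 0.1)
```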

thresh

A positive number specifying the termination threshold: the algorithm stops once the change between iterations falls below this value.

regular

A positive number specifying the regularisation parameter in the "regular" method.

warmstart

A boolean specifying whether the algorithm should use a warmstart based on a stochastic subgradient descent.

warmstartlength

A positive integer specifying the length of the warmstart. The number of steps in the SGD in the warmstart is 'length(data.list)*warmstartlength'.

showIter

A boolean specifying whether the number of "outer" iterations performed should be shown at the end.

threads

A positive integer specifying the number of threads used for parallel computing.

Details

This is the main function of this package. It computes or approximates 2-Wasserstein barycenters using one of the methods outlined below.
"lp". Here the barycenter problem is posed as a linear program, which this method builds and solves. While this gives exact solutions, the method is very runtime-intensive and should only be used on small datasets (see Anderes et al. (2016) and Borgwardt & Patterson (2020) for details).
"regular". This method solves the entropy-regularised fixed-support barycenter problem. A penalisation term is added to the problem, which yields a strictly convex problem that can be solved with Sinkhorn's algorithm (for details see Benamou et al. (2015)). Additionally, it is assumed that all the measures have the same support set, and instead of an exact (regularised) barycenter, the method finds the best solution having the same support as the data. This is quite reasonable when the dataset consists of images and the barycenter should be an image as well. The choice of the regularisation parameter "regular" is a delicate issue. Large values reduce the runtime but yield "blurry" barycenters. Small values yield sharper results but increase the runtime and may cause numerical instabilities. The choice of this parameter depends on the dataset at hand and will typically require tuning.
"fixed_support". This method computes the best approximation of the barycenter that is supported on a pre-specified support set (as supplied by the parameter "output.supp"). Unlike the "regular" method, this set does not need to coincide with any of the support sets of the data measures. See Cuturi & Doucet (2014) for details.
"alternating". This method computes the best approximation of the barycenter with a given support size. It alternates between finding the best positions for given weights and finding the best weights for these positions. See Cuturi & Doucet (2014) for details.
"sampling". This method uses the SUA method of Heinemann et al. (2020) to generate a stochastic approximation of the barycenter. It replaces the original measures by empirical measures obtained from samples of size 'sample.size' from each data measure.
The unregularised optimal transport problems that arise in each step of the iterative methods are solved using a fast network simplex implementation (a modification of the LEMON Library by Nicolas Bonneel).
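
The idea behind the "regular" method can be illustrated with a minimal base-R sketch of the iterative Bregman projections of Benamou et al. (2015) on a shared support. The function name sinkhorn_barycenter and its arguments are hypothetical, eps only plays a role analogous to the "regular" parameter, and this is not the package's implementation.

```r
# P: M x N matrix whose columns are probability vectors on a shared support
# C: M x M matrix of squared distances between the support points
# lambda: Frechet weights (length N, summing to 1); eps: regularisation strength
sinkhorn_barycenter <- function(P, C, lambda, eps = 0.05, n_iter = 200) {
  K <- exp(-C / eps)                           # Gibbs kernel
  V <- matrix(1, nrow(P), ncol(P))             # one scaling vector per measure
  b <- rep(1 / nrow(P), nrow(P))
  for (iter in seq_len(n_iter)) {
    U <- P / (K %*% V)                         # match the data marginals
    KtU <- t(K) %*% U
    b <- exp(as.vector(log(KtU) %*% lambda))   # weighted geometric mean
    V <- b / KtU                               # match the common barycenter marginal
  }
  b
}

# Two point masses at the ends of a 5-point grid on [0, 1]
x <- seq(0, 1, length.out = 5)
C <- outer(x, x, function(p, q) (p - q)^2)
P <- cbind(c(1, 0, 0, 0, 0), c(0, 0, 0, 0, 1))
b <- sinkhorn_barycenter(P, C, lambda = c(0.5, 0.5))
round(b, 3)
```

In this toy case the result is a symmetric, blurred bump centred on the middle grid point. Decreasing eps sharpens it, but very small values make exp(-C / eps) underflow, which is one source of the numerical instabilities mentioned above.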

References

E Anderes, S Borgwardt, and J Miller (2016). Discrete Wasserstein barycenters: Optimal transport for discrete data. Mathematical Methods of Operations Research, 84(2):389-409.
S Borgwardt and S Patterson (2020). Improved linear programs for discrete barycenters. INFORMS Journal on Optimization 2(1):14-33.
J-D Benamou, G Carlier, M Cuturi, L Nenna, and G Peyré (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing 37(2):A1111-A1138.
M Cuturi and A Doucet (2014). Fast Computation of Wasserstein Barycenters. Proceedings of the 31st International Conference on Machine Learning, PMLR 32(2):685-693.
F Heinemann, A Munk, and Y Zemel (2020). Randomised Wasserstein barycenter computation: Resampling with statistical guarantees. arXiv preprint.
N Bonneel (2018). Fast Network Simplex for Optimal Transport. GitHub repository, nbonneel/network_simplex.
N Bonneel, M van de Panne, S Paris, and W Heidrich (2011). Displacement interpolation using Lagrangian mass transport. ACM Transactions on Graphics (SIGGRAPH ASIA 2011) 30(6).

Examples

## Basic examples
library(WSGeometry)

# Build a list holding K measures in each of the four supported input formats
K<-1
N<-4*K
M<-9
d<-2
data.list<-vector("list",N)

# Image given as a matrix
for (i in 1:K){
  U<-runif(M)
  U<-U/sum(U)
  data.list[[i]]<-matrix(U,sqrt(M))
}

# wpp-object from the transport package
for (i in (K+1):(2*K)){
  U<-runif(M)
  U<-U/sum(U)
  pos<-matrix(runif(d*M),M,d)
  data.list[[i]]<-transport::wpp(pos,U)
}

# Point pattern: list with positions only (uniform weights)
for (i in (2*K+1):(3*K)){
  pos<-matrix(runif(d*M),M,d)
  data.list[[i]]<-list(positions=pos)
}

# Weighted point pattern: list with positions and weights
for (i in (3*K+1):(4*K)){
  U<-runif(M)
  U<-U/sum(U)
  pos<-matrix(runif(d*M),M,d)
  data.list[[i]]<-list(positions=pos,weights=U)
}

system.time(res1<-wasserstein_bary(data.list,return_type = "wpp",method="lp"))
frechet_func(res1,data.list)

system.time(res2<-wasserstein_bary(data.list,return_type = "wpp",method="alternating",
supp.size = M*N-N+1,warmstartlength = 3,pos_maxIter = 100,weights_maxIter = 100))
frechet_func(res2,data.list)

system.time(res3<-wasserstein_bary(data.list,return_type = "wpp",
method="fixed_support",warmstartlength = 3,weights_maxIter = 100,output.supp = res1$coordinates))
frechet_func(res3,data.list)
system.time(res4<-wasserstein_bary(data.list,return_type = "wpp",
method="sampling",sample.size=8,warmstartlength = 3,pos_maxIter = 100))
frechet_func(res4,data.list)

##Visual Examples
###ellipses
set.seed(420)
N<-20
supp.size<-10^2
L<-sqrt(supp.size)
d<-2
data.list<-vector("list",N)
image.list<-vector("list",N)
for (i in 1:N){
  t.vec<-seq(0,2*pi,length.out=supp.size)
  pos<-cbind(cos(t.vec)*runif(1,0.2,1),sin(t.vec)*runif(1,0.2,1))
  theta<-runif(1,0,2*pi)
  rotation<-matrix(c(cos(theta),sin(theta),-1*sin(theta),cos(theta)),2,2)
  pos<-pos%*%rotation
  pos<-pos+1
  pos<-pos/2
  W<-rep(1/supp.size,supp.size)
  data.list[[i]]<-transport::wpp(pos,W)
  I<-bin2d(data.list[[i]]$coordinates,data.list[[i]]$mass,c(L*2,L*2))
  I<-smear(I,1,1)
  I<-I/sum(I)
  image.list[[i]]<-I
}

system.time(res1<-wasserstein_bary(data.list,return_type = "wpp",method="alternating"
,supp.size = supp.size,warmstartlength = 0,pos_maxIter = 10,weights_maxIter = 10,maxIter = 10))
plot(res1)
system.time(res2<-wasserstein_bary(data.list,return_type = "wpp",
method="fixed_support",warmstartlength = 0,weights_maxIter = 50,maxIter=1,
output.supp = grid_positions(2*L,2*L)))
plot(res2)
system.time(res3<-wasserstein_bary(data.list,return_type = "wpp",method="sampling",
sample.size=400,warmstartlength = 0,pos_maxIter = 100,stepsize = 1,maxIter=1))
plot(res3)

system.time(res4<-wasserstein_bary(image.list,return_type = "wpp",
method="regular",stepsize = 1,weights_maxIter = 50))
plot(res4)
system.time(res5<-wasserstein_bary(image.list,return_type = "wpp",
method="fixed_support",shared=TRUE,warmstartlength = 0,weights_maxIter = 50,maxIter=1,
output.supp = grid_positions(2*L,2*L)))
plot(res5)

WSGeometry documentation built on Dec. 15, 2021, 1:08 a.m.