ipot: Wasserstein Distance by Inexact Proximal Point Method

View source: R/dist_wass_ipot.R


Description

Because linear programming approaches to computing the Wasserstein distance are computationally expensive, Cuturi (2013) proposed an entropic regularization scheme as an efficient approximation to the original problem. This introduces a regularization parameter \lambda > 0 in the term

\lambda h(\Gamma) = \lambda \sum_{m,n} \Gamma_{m,n} \log (\Gamma_{m,n}).
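Since h is a negative entropy, the term can be evaluated directly for any plan. The following is a minimal base-R sketch that assumes the usual convention 0 log 0 = 0. `entropic_term` is a hypothetical helper for illustration and is not part of T4transport.

```r
# Hypothetical helper (not a T4transport function): computes
# lambda * h(Gamma) = lambda * sum_{m,n} Gamma[m,n] * log(Gamma[m,n]),
# using the convention 0 * log(0) = 0 for empty cells of the plan.
entropic_term <- function(Gamma, lambda = 1) {
  stopifnot(lambda > 0, all(Gamma >= 0))
  pos <- Gamma > 0                       # skip zeros to avoid 0 * -Inf = NaN
  lambda * sum(Gamma[pos] * log(Gamma[pos]))
}

# Independent (product) coupling of two uniform 2-point marginals
Gamma_unif <- matrix(0.25, nrow = 2, ncol = 2)
# Coupling that sends each point's mass to a single target
Gamma_diag <- diag(0.5, 2)

entropic_term(Gamma_unif)   # equals log(0.25)
entropic_term(Gamma_diag)   # equals log(0.5)
```

The more spread-out plan gives the more negative value. Adding this term to a transport cost that is being minimized therefore favors smoother plans.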

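The IPOT algorithm (Xie et al., 2020) instead targets the exact, unregularized distance. It runs a proximal point iteration whose proximal term is the Bregman divergence generated by h, and it solves each subproblem only inexactly with L inner iterations. IPOT is known to be relatively robust to the choice of the regularization parameter \lambda. Empirically, a very small number of inner-loop iterations, such as L=1, is sufficient.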
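The outer/inner structure of IPOT can be sketched in base R as shown below. This is an illustrative reimplementation following Xie et al. (2020), not the T4transport source. `ipot_sketch` is hypothetical, and the following are all assumptions about how the package may differ: treating `lambda` as the proximal step size in the kernel exp(-C/lambda), using cost C = D^p, starting from the independent coupling, and stopping when the plan changes by less than `abstol`. Each outer step multiplies the previous plan elementwise by the kernel and then runs `L` Sinkhorn-type scaling updates toward the marginals.

```r
# Illustrative IPOT sketch (hypothetical; NOT the T4transport implementation).
# Assumptions: cost C = D^p, Gibbs kernel exp(-C / lambda) with lambda acting
# as the proximal step size, and distance = (sum(plan * C))^(1/p).
ipot_sketch <- function(D, p = 2, wx = NULL, wy = NULL, lambda = 1,
                        maxiter = 496, abstol = 1e-10, L = 1) {
  stopifnot(lambda > 0, L >= 1)
  M <- nrow(D); N <- ncol(D)
  if (is.null(wx)) wx <- rep(1 / M, M)    # uniform marginals by default
  if (is.null(wy)) wy <- rep(1 / N, N)
  C <- D^p
  G <- exp(-C / lambda)                    # may underflow to 0 if lambda is tiny
  plan <- outer(wx, wy)                    # start from the independent coupling
  b <- rep(1, N)                           # column scaling, warm-started across steps
  iter <- maxiter
  for (it in seq_len(maxiter)) {
    Q <- G * plan                          # proximal step: kernel times previous plan
    for (l in seq_len(L)) {                # L inexact (Sinkhorn-type) inner updates
      a <- wx / as.vector(Q %*% b)
      b <- wy / as.vector(crossprod(Q, a))
    }
    plan_new <- a * Q * rep(b, each = M)   # same as diag(a) %*% Q %*% diag(b)
    change <- max(abs(plan_new - plan))
    plan <- plan_new
    if (change < abstol) {                 # stop once the plan stops moving
      iter <- it
      break
    }
  }
  list(distance = sum(plan * C)^(1 / p), iteration = iter, plan = plan)
}

# Usage on a toy 1-D problem: X = (0, 1), Y = (2, 3), uniform weights.
# The optimal plan matches 0 -> 2 and 1 -> 3, so the exact W_2 distance is 2.
D <- abs(outer(c(0, 1), c(2, 3), "-"))
res <- ipot_sketch(D, p = 2, lambda = 1)
res$distance    # close to 2
res$iteration
```

Warm-starting `b` across outer steps is what makes a single inner update (L=1) workable in practice. The drawback of this sketch is that very small `lambda` can underflow the kernel, which makes the scaling step divide by zero.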

Usage

ipot(X, Y, p = 2, wx = NULL, wy = NULL, lambda = 1, ...)

ipotD(D, p = 2, wx = NULL, wy = NULL, lambda = 1, ...)

Arguments

X

an (M\times P) matrix of row observations.

Y

an (N\times P) matrix of row observations.

p

an exponent for the order of the distance (default: 2).

wx

a length-M marginal density that sums to 1. If NULL (default), uniform weights are used.

wy

a length-N marginal density that sums to 1. If NULL (default), uniform weights are used.

lambda

a regularization parameter (default: 1).

...

extra parameters including

maxiter

maximum number of iterations (default: 496).

abstol

stopping criterion for iterations (default: 1e-10).

L

number of inner-loop iterations; a small value suffices (default: 1).

D

an (M\times N) distance matrix d(x_m, y_n) between two sets of observations.
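`ipotD` takes a precomputed distance matrix in place of raw observations. A minimal base-R way to build the Euclidean cross-distance matrix from X and Y is sketched below. `cross_dist` is a hypothetical helper. The expectation that `ipotD(D, p = 2)` then matches `ipot(X, Y, p = 2)` rests on the assumption that `ipot` uses Euclidean distance internally.

```r
# Hypothetical helper: build the (M x N) Euclidean cross-distance matrix
# D[m, n] = ||x_m - y_n|| from row observations, without forming the full
# (M+N) x (M+N) matrix that dist() would produce.
cross_dist <- function(X, Y) {
  X <- as.matrix(X); Y <- as.matrix(Y)
  stopifnot(ncol(X) == ncol(Y))
  sq <- outer(rowSums(X^2), rowSums(Y^2), "+") - 2 * tcrossprod(X, Y)
  sqrt(pmax(sq, 0))   # clip tiny negative values caused by rounding
}

set.seed(100)
X <- matrix(rnorm(20 * 2, mean = -1), ncol = 2)
Y <- matrix(rnorm(30 * 2, mean = +1), ncol = 2)
D <- cross_dist(X, Y)   # 20 x 30 matrix, usable as the D argument
```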

Value

a named list containing

distance

the computed \mathcal{W}_p distance value.

iteration

the number of iterations it took to converge.

plan

an (M\times N) nonnegative matrix for the optimal transport plan.
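A returned plan should be, up to the stopping tolerance, a coupling of `wx` and `wy`, and the transport cost it induces is (\sum_{m,n} \Gamma_{m,n} d(x_m, y_n)^p)^{1/p}. The base-R sketch below checks both properties for any plan, for instance `ipot(X, Y)$plan` paired with the matching distance matrix. `plan_summary` is hypothetical. Treating `distance` as exactly this plan cost is an assumption about the package.

```r
# Hypothetical helper: check that a plan couples (wx, wy) and evaluate its
# transport cost (sum(plan * D^p))^(1/p).
plan_summary <- function(plan, D, p = 2, wx = NULL, wy = NULL, tol = 1e-8) {
  M <- nrow(D); N <- ncol(D)
  if (is.null(wx)) wx <- rep(1 / M, M)
  if (is.null(wy)) wy <- rep(1 / N, N)
  list(
    nonnegative = all(plan >= 0),
    row_ok = max(abs(rowSums(plan) - wx)) < tol,   # first marginal
    col_ok = max(abs(colSums(plan) - wy)) < tol,   # second marginal
    cost = sum(plan * D^p)^(1 / p)
  )
}

# Toy 1-D example: X = (0, 1), Y = (2, 3), uniform weights.
D <- abs(outer(c(0, 1), c(2, 3), "-"))
monotone <- diag(0.5, 2)                    # 0 -> 2 and 1 -> 3
crossing <- matrix(c(0, 0.5, 0.5, 0), 2)    # 0 -> 3 and 1 -> 2
plan_summary(monotone, D)$cost   # 2, the optimal W_2 value
plan_summary(crossing, D)$cost   # sqrt(5), a feasible but worse plan
```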

References

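Xie Y, Wang X, Wang R, Zha H (2020). "A Fast Proximal Point Method for Computing Exact Wasserstein Distance." In Proceedings of the 35th Uncertainty in Artificial Intelligence Conference, volume 115 of Proceedings of Machine Learning Research.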

Examples


#-------------------------------------------------------------------
#  Wasserstein Distance between Samples from Two Bivariate Normal
#
# * class 1 : samples from Gaussian with mean=(-1, -1)
# * class 2 : samples from Gaussian with mean=(+1, +1)
#-------------------------------------------------------------------
## SMALL EXAMPLE
library(T4transport)
set.seed(100)
m = 20
n = 30
X = matrix(rnorm(m*2, mean=-1),ncol=2) # m obs. for X
Y = matrix(rnorm(n*2, mean=+1),ncol=2) # n obs. for Y

## COMPARE WITH WASSERSTEIN 
outw = wasserstein(X, Y)
ipt1 = ipot(X, Y, lambda=1)
ipt2 = ipot(X, Y, lambda=10)

## VISUALIZE : SHOW THE PLAN AND DISTANCE
pmw = paste0("wasserstein plan ; dist=",round(outw$distance,2))
pm1 = paste0("ipot lbd=1 ; dist=",round(ipt1$distance,2))
pm2 = paste0("ipot lbd=10; dist=",round(ipt2$distance,2))

opar <- par(no.readonly=TRUE)
par(mfrow=c(1,3))
image(outw$plan, axes=FALSE, main=pmw)
image(ipt1$plan, axes=FALSE, main=pm1)
image(ipt2$plan, axes=FALSE, main=pm2)
par(opar)



T4transport documentation built on April 12, 2023, 12:37 p.m.