View source: R/dist_wass_ipot.R
ipot | R Documentation |
Because linear programming approaches to computing the Wasserstein distance
are computationally expensive, Cuturi (2013) proposed an entropic
regularization scheme as an efficient approximation to the original problem.
The scheme introduces a regularization parameter \lambda > 0 through the term
\lambda h(\Gamma) = \lambda \sum_{m,n} \Gamma_{m,n} \log (\Gamma_{m,n}).
The IPOT algorithm is known to be relatively robust to the choice of the
regularization parameter \lambda. Empirically, a very small number of
inner-loop iterations, such as L=1, is sufficient.
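To make the roles of \lambda and the inner loop count L concrete, here is a minimal base-R sketch of an IPOT-style iteration as it is commonly described: a proximal update of the plan wrapped around a few Sinkhorn-style scaling steps. The function name `ipot_sketch`, its arguments, and the use of the kernel exp(-C/\lambda) are assumptions made for illustration. This is not the package's implementation, and it runs a fixed number of outer iterations instead of checking convergence.

```r
# Hypothetical helper, NOT part of T4transport: an illustrative IPOT-style loop.
# C      : (M x N) cost matrix, e.g. pairwise distances raised to the power p
# wx, wy : marginal weights of length M and N, each summing to 1
# lambda : regularization parameter (assumed to enter as exp(-C / lambda))
# L      : number of inner scaling iterations per outer step
ipot_sketch <- function(C, wx, wy, lambda = 1, maxiter = 200, L = 1) {
  m <- nrow(C)
  n <- ncol(C)
  G <- exp(-C / lambda)        # Gibbs kernel of the cost matrix
  Gamma <- matrix(1, m, n)     # current plan, used as the proximal center
  b <- rep(1 / n, n)           # column scaling vector, kept across outer steps
  for (t in seq_len(maxiter)) {
    Q <- G * Gamma             # elementwise product of kernel and current plan
    for (l in seq_len(L)) {    # Sinkhorn-style scaling; L = 1 is often enough
      a <- wx / as.vector(Q %*% b)
      b <- wy / as.vector(crossprod(Q, a))
    }
    Gamma <- outer(a, b) * Q   # same as diag(a) %*% Q %*% diag(b)
  }
  Gamma
}

# Two points on each side with zero cost on the diagonal:
# the plan should concentrate its mass on the diagonal.
C <- matrix(c(0, 1, 1, 0), 2, 2)
plan <- ipot_sketch(C, wx = c(0.5, 0.5), wy = c(0.5, 0.5), maxiter = 100)
round(plan, 6)
```

Because each outer step multiplies the kernel by the previous plan, entries with a higher cost shrink geometrically, which is one way to see why the result depends only weakly on the exact value of \lambda.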
ipot(X, Y, p = 2, wx = NULL, wy = NULL, lambda = 1, ...)
ipotD(D, p = 2, wx = NULL, wy = NULL, lambda = 1, ...)
X |
an |
Y |
an |
p |
an exponent for the order of the distance (default: 2). |
wx |
a length- |
wy |
a length- |
lambda |
a regularization parameter (default: 1). |
... |
extra parameters including
|
D |
an |
a named list containing the \mathcal{W}_p distance value (accessed as
$distance in the example below), the number of iterations it took to
converge, and an (M\times N) nonnegative matrix for the optimal transport
plan (accessed as $plan).
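The link between a transport plan and the reported distance can be sketched in base R using the standard definition \mathcal{W}_p = (\sum_{m,n} \Gamma_{m,n} d_{m,n}^p)^{1/p}. The helpers `cross_dist` and `wp_from_plan` are hypothetical names introduced here for illustration. They are not part of the package, and the package may compute its value differently.

```r
# Hypothetical helpers, NOT part of T4transport.

# Euclidean distance between every row of X and every row of Y.
cross_dist <- function(X, Y) {
  m <- nrow(X)
  n <- nrow(Y)
  as.matrix(dist(rbind(X, Y)))[seq_len(m), m + seq_len(n), drop = FALSE]
}

# Transport cost of a given plan under the standard W_p definition.
wp_from_plan <- function(plan, D, p = 2) {
  stopifnot(all(dim(plan) == dim(D)), all(plan >= 0))
  sum(plan * D^p)^(1 / p)
}

X <- matrix(c(0, 0, 1, 0), ncol = 2, byrow = TRUE)  # rows (0,0) and (1,0)
Y <- X                                              # identical point set
D <- cross_dist(X, Y)

plan_match <- diag(0.5, 2)                     # each point sent to itself
plan_indep <- outer(c(0.5, 0.5), c(0.5, 0.5))  # feasible but not optimal

wp_from_plan(plan_match, D)  # 0
wp_from_plan(plan_indep, D)  # sqrt(0.5)
```

Both plans have the same marginals, so the gap between the two values shows how much the choice of plan matters for the distance.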
Xie Y, Wang X, Wang R, Zha H (2020). "A Fast Proximal Point Method for Computing Exact Wasserstein Distance."
#-------------------------------------------------------------------
# Wasserstein Distance between Samples from Two Bivariate Normal
#
# * class 1 : samples from Gaussian with mean=(-1, -1)
# * class 2 : samples from Gaussian with mean=(+1, +1)
#-------------------------------------------------------------------
## LOAD THE PACKAGE PROVIDING wasserstein() AND ipot()
library(T4transport)
## SMALL EXAMPLE
set.seed(100)
m = 20
n = 30
X = matrix(rnorm(m*2, mean=-1),ncol=2) # m obs. for X
Y = matrix(rnorm(n*2, mean=+1),ncol=2) # n obs. for Y
## COMPARE WITH WASSERSTEIN
outw = wasserstein(X, Y)
ipt1 = ipot(X, Y, lambda=1)
ipt2 = ipot(X, Y, lambda=10)
## VISUALIZE : SHOW THE PLAN AND DISTANCE
pmw = paste0("wasserstein plan ; dist=",round(outw$distance,2))
pm1 = paste0("ipot lbd=1 ; dist=",round(ipt1$distance,2))
pm2 = paste0("ipot lbd=10; dist=",round(ipt2$distance,2))
opar <- par(no.readonly=TRUE)
par(mfrow=c(1,3))
image(outw$plan, axes=FALSE, main=pmw)
image(ipt1$plan, axes=FALSE, main=pm1)
image(ipt2$plan, axes=FALSE, main=pm2)
par(opar)