pwdist: Procrustes-Wasserstein Distance

View source: R/dist_pwdist.R

pwdistR Documentation

Procrustes-Wasserstein Distance

Description

Given two empirical measures

\mu = \sum_{m=1}^M \mu_m \delta_{X_m}\quad\textrm{and}\quad \nu = \sum_{n=1}^N \nu_n \delta_{Y_n}

in \mathbb{R}^P, the Procrustes-Wasserstein (PW) distance is defined as follows:

PW_2^2(\mu, \nu) = \min_{Q\in \mathcal{O}(P)} W_2^2(\mu, Q_\# \nu),

where \mathcal{O}(P) is the orthogonal group and Q_\#\nu is the pushforward via Q.

Usage

pwdist(X, Y, wx = NULL, wy = NULL, ...)

Arguments

X

an (M\times P) matrix of row observations.

Y

an (N\times P) matrix of row observations.

wx

a length-M marginal density that sums to 1. If NULL (default), uniform weight is set.

wy

a length-N marginal density that sums to 1. If NULL (default), uniform weight is set.

...

extra parameters including

maxiter

maximum number of iterations (default: 496).

abstol

stopping criterion for iterations (default: 1e-10).

Value

a named list containing

distance

the computed PW distance value.

plan

an (M\times N) nonnegative matrix for the optimal transport plan.

alignment

an optimal alignment matrix of size (P\times P) in \mathcal{O}(P).

References

\insertRef

adamo_2025_DepthLookProcrustesWassersteinT4transport

Examples

## Not run: 
#-------------------------------------------------------------------
#                           Description
#
# * class 1 : samples from N((0,0),  diag(c(4,1/4)))
# * class 2 : samples from N((10,0), diag(c(1/4,4)))
# * class 3 : samples from N((10,0), diag(c(1/4,4))) randomly rotated
#
#  We draw 10 empirical measures from each and compare 
#  the regular Wasserstein and PW distance.
#-------------------------------------------------------------------
## GENERATE DATA
set.seed(10)

#  prepare empty lists
inputs = vector("list", length=30)

#  generate
random_rot = qr.Q(qr(matrix(runif(4), ncol=2)))
for (i in 1:10){
  inputs[[i]] = matrix(rnorm(50*2), ncol=2)
}
for (j in 11:20){
  base_draw = matrix(rnorm(50*2), ncol=2)
  base_draw[,1] = base_draw[,1] + 10
  
  inputs[[j]] = base_draw
  inputs[[j+10]] = base_draw%*%random_rot
}

## COMPUTE
#  empty arrays
dist_RW = array(0, c(30, 30))
dist_PW = array(0, c(30, 30))

#  compute pairwise distances
for (i in 1:29){
  for (j in (i+1):30){
  dist_RW[i,j] <- dist_RW[j,i] <- wasserstein(inputs[[i]], inputs[[j]])$distance
  dist_PW[i,j] <- dist_PW[j,i] <- pwdist(inputs[[i]], inputs[[j]])$distance
  }
}

## VISUALIZE
opar <- par(no.readonly=TRUE)
par(mfrow=c(1,2), pty="s")
image(dist_RW, xaxt="n", yaxt="n", main="Regular Wasserstein distance")
image(dist_PW, xaxt="n", yaxt="n", main="PW distance")
par(opar)

## End(Not run)


T4transport documentation built on Jan. 11, 2026, 9:07 a.m.