Procrustes-Wasserstein Distance
Given two empirical measures $$\mu = \sum_{m=1}^M \mu_m \delta_{X_m}\quad\textrm{and}\quad \nu = \sum_{n=1}^N \nu_n \delta_{Y_n}$$ in \(\mathbb{R}^P\), the Procrustes-Wasserstein (PW) distance is defined as $$ PW_2^2(\mu, \nu) = \min_{Q\in \mathcal{O}(P)} W_2^2(\mu, Q_\# \nu), $$ where \(\mathcal{O}(P)\) is the orthogonal group and \(Q_\#\nu\) is the pushforward of \(\nu\) under \(Q\).
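For orientation, the inner Wasserstein term can be written out over couplings, which turns the definition into a joint minimization over an orthogonal matrix and a transport plan. This is the standard expansion of \(W_2^2\) for discrete measures, not a description of the solver used here; the symbols \(\Gamma\) and \(\Pi(\mu,\nu)\) are notation introduced for this illustration only. $$ PW_2^2(\mu, \nu) = \min_{Q\in \mathcal{O}(P)} \; \min_{\Gamma \in \Pi(\mu,\nu)} \; \sum_{m=1}^M \sum_{n=1}^N \Gamma_{mn} \, \| X_m - Q Y_n \|^2, $$ $$ \Pi(\mu,\nu) = \Big\{ \Gamma \in \mathbb{R}_{\geq 0}^{M\times N} \;:\; \sum_{n=1}^N \Gamma_{mn} = \mu_m, \;\; \sum_{m=1}^M \Gamma_{mn} = \nu_n \Big\}. $$ Since \(Q_\#\nu = \sum_{n=1}^N \nu_n \delta_{Q Y_n}\), taking \(Q\) to be the identity recovers the ordinary squared 2-Wasserstein distance, so \(PW_2^2(\mu,\nu) \leq W_2^2(\mu,\nu)\).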
Arguments
- X
an \((M\times P)\) matrix of row observations.
- Y
an \((N\times P)\) matrix of row observations.
- wx
a length-\(M\) marginal density that sums to \(1\). If NULL (default), uniform weights are used.
- wy
a length-\(N\) marginal density that sums to \(1\). If NULL (default), uniform weights are used.
- ...
extra parameters including
- maxiter
maximum number of iterations (default: 496).
- abstol
stopping criterion for iterations (default: 1e-10).
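The arguments above can be combined as in the following minimal sketch. The data and the raw weights are made up for illustration, and the call is guarded so that it only runs when a function named `pwdist` is already available in the session; the chosen values of `maxiter` and `abstol` are arbitrary, not recommendations.

```r
set.seed(2)
X <- matrix(rnorm(6 * 2), ncol = 2)   # M = 6 observations in R^2
Y <- matrix(rnorm(4 * 2), ncol = 2)   # N = 4 observations in R^2

# Nonnegative raw weights (illustrative), normalized so each sums to 1.
wx <- c(3, 1, 1, 1, 2, 2)
wx <- wx / sum(wx)
wy <- rep(1, nrow(Y))
wy <- wy / sum(wy)

# Guarded call: runs only if `pwdist` has been loaded into the session.
if (exists("pwdist", mode = "function")) {
  out <- pwdist(X, Y, wx = wx, wy = wy, maxiter = 100, abstol = 1e-8)
  print(out$distance)
}
```

Normalizing the weights explicitly keeps the inputs consistent with the requirement that `wx` and `wy` each sum to \(1\).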
Value
a named list containing
- distance
the computed PW distance value.
- plan
an \((M\times N)\) nonnegative matrix for the optimal transport plan.
- alignment
an optimal alignment matrix of size \((P\times P)\) in \(\mathcal{O}(P)\).
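To make the roles of the plan and the alignment concrete, the sketch below solves only the orthogonal Procrustes half of the problem in base R: for a fixed transport plan, the orthogonal matrix minimizing \(\sum_{m,n} \Gamma_{mn} \|X_m - Q Y_n\|^2\) is obtained from a singular value decomposition. The helper `procrustes_step` is hypothetical and written for illustration; it is not part of the documented interface, and nothing here is meant to describe how `pwdist` is implemented internally.

```r
# Hypothetical helper: best orthogonal Q for a FIXED plan (M x N).
# Minimizing sum_{m,n} plan[m,n] * ||X_m - Q Y_n||^2 over O(P) is the same
# as maximizing trace(t(Q) %*% A) with A = t(X) %*% plan %*% Y.
procrustes_step <- function(X, Y, plan) {
  A <- t(X) %*% plan %*% Y      # (P x P) weighted cross-moment matrix
  s <- svd(A)
  s$u %*% t(s$v)                # orthogonal polar factor of A
}

set.seed(1)
M <- 20
P <- 2
X <- matrix(rnorm(M * P), ncol = P)
R <- qr.Q(qr(matrix(rnorm(P * P), ncol = P)))  # random orthogonal matrix
Y <- X %*% R                                   # same points, rows moved by R
plan <- diag(1 / M, M)                         # identity coupling, uniform mass

Q <- procrustes_step(X, Y, plan)
aligned_Y <- Y %*% t(Q)                        # row n becomes Q %*% Y[n, ]
max(abs(aligned_Y - X))                        # numerically zero
```

With observations stored as rows, applying \(Q\) to every point of \(Y\) corresponds to `Y %*% t(Q)`. In this toy setup the correct coupling is known in advance, so one Procrustes step suffices; in general the plan is unknown and has to be optimized as well.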
References
Adamo D, Corneli M, Vuillien M, Vila E (2025). “An in Depth Look at the Procrustes-Wasserstein Distance: Properties and Barycenters.” In Forty-Second International Conference on Machine Learning.
Examples
if (FALSE) { # \dontrun{
#-------------------------------------------------------------------
# Description
#
# * class 1 : 50 samples from N((0,0), I)
# * class 2 : 50 samples from N((10,0), I)
# * class 3 : the class 2 samples multiplied by a random orthogonal matrix
#
# We draw 10 empirical measures from each and compare
# the regular Wasserstein and PW distance.
#-------------------------------------------------------------------
## GENERATE DATA
set.seed(10)
# prepare empty lists
inputs = vector("list", length=30)
# generate
random_rot = qr.Q(qr(matrix(runif(4), ncol=2)))
for (i in 1:10){
  inputs[[i]] = matrix(rnorm(50*2), ncol=2)
}
for (j in 11:20){
  base_draw = matrix(rnorm(50*2), ncol=2)
  base_draw[,1] = base_draw[,1] + 10
  inputs[[j]] = base_draw
  inputs[[j+10]] = base_draw%*%random_rot
}
## COMPUTE
# empty arrays
dist_RW = array(0, c(30, 30))
dist_PW = array(0, c(30, 30))
# compute pairwise distances
for (i in 1:29){
  for (j in (i+1):30){
    dist_RW[i,j] <- dist_RW[j,i] <- wasserstein(inputs[[i]], inputs[[j]])$distance
    dist_PW[i,j] <- dist_PW[j,i] <- pwdist(inputs[[i]], inputs[[j]])$distance
  }
}
## VISUALIZE
opar <- par(no.readonly=TRUE)
par(mfrow=c(1,2), pty="s")
image(dist_RW, xaxt="n", yaxt="n", main="Regular Wasserstein distance")
image(dist_PW, xaxt="n", yaxt="n", main="PW distance")
par(opar)
} # }