Procrustes-Wasserstein Barycenter
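For a collection of empirical measures \(\lbrace \mu_k\rbrace_{k=1}^K\), this function computes the Procrustes-Wasserstein (PW) barycenter (Adamo et al. 2025), which accounts for both measure transport and alignment through the action of the orthogonal group. In particular, measures that differ only by an orthogonal transformation (a rotation or reflection) are at zero PW distance from each other.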
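The alignment ingredient can be illustrated with the classical orthogonal Procrustes problem. Given two point sets whose rows are already matched, the orthogonal matrix that best maps one onto the other comes from a single SVD. The helper `procrustes_align` below is hypothetical and not part of the package. It assumes the coupling is known (row \(i\) matches row \(i\)), whereas `pwbary` must also estimate the transport plan, so this is only a sketch of one step of the PW geometry and not the function's internal routine. Applied to a 2-D cloud and an orthogonally transformed copy, similar to classes 2 and 3 in the Examples, it recovers the transformation up to floating-point error.

```r
# Hypothetical helper (not part of the package): orthogonal Procrustes.
# Returns the orthogonal Q minimizing ||X %*% Q - Y||_F, assuming that
# row i of X corresponds to row i of Y (identity coupling).
procrustes_align <- function(X, Y) {
  s <- svd(crossprod(X, Y))   # SVD of t(X) %*% Y = U D V^T
  s$u %*% t(s$v)              # optimal orthogonal matrix Q = U V^T
}

set.seed(10)
X <- cbind(rnorm(50, sd = 0.5) + 10, rnorm(50, sd = 2))
R <- qr.Q(qr(matrix(runif(4), ncol = 2)))  # random orthogonal matrix
Y <- X %*% R                               # rotated or reflected copy of X

Q <- procrustes_align(X, Y)
max(abs(Q - R))           # close to zero: the transformation is recovered
max(abs(X %*% Q - Y))     # alignment residual, also close to zero
```

Because \(X^\top X\) is symmetric positive definite here, the SVD of \(X^\top X R\) yields \(UV^\top = R\) exactly in exact arithmetic; with noisy or unmatched data the recovered \(Q\) is only the best orthogonal fit.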
Arguments
- atoms
a length-\(K\) list where each element is an \((N_k \times P)\) matrix of atoms.
- marginals
marginal distributions for the empirical measures; if NULL (default), uniform weights are set for all measures. Otherwise, it should be a length-\(K\) list where each element is a length-\(N_k\) vector of nonnegative weights that sum to 1.
- weights
weights for each individual measure; if NULL (default), all measures are weighted equally. Otherwise, it should be a length-\(K\) vector.
- num_support
the number of support points \(M\) for the PW barycenter (default: 100).
- ...
extra parameters including
- abstol
stopping criterion for iterations (default: 1e-6).
- maxiter
maximum number of iterations (default: 10).
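As a sketch of the expected input shapes, the code below builds \(K = 3\) measures in \(\mathbb{R}^2\) with different numbers of atoms \(N_k\), non-uniform marginals that each sum to 1, and a length-\(K\) weight vector, then checks them against the argument descriptions above. The call to `pwbary` passes `abstol` and `maxiter` through `...` as documented; the package name `T4transport` and all numeric values are assumptions for illustration, and the call is guarded so the snippet runs even where that package (or this function) is unavailable.

```r
set.seed(1)
K <- 3
N <- c(30, 40, 50)   # numbers of atoms N_k may differ across measures
P <- 2               # all measures share the same dimension P

# atoms: length-K list of (N_k x P) matrices
atoms <- lapply(N, function(n) matrix(rnorm(n * P), ncol = P))

# marginals: length-K list of nonnegative length-N_k vectors summing to 1
marginals <- lapply(N, function(n) {
  w <- runif(n)
  w / sum(w)
})

# weights: one weight per measure (length K)
weights <- c(0.5, 0.25, 0.25)

# sanity checks mirroring the argument descriptions
stopifnot(length(atoms) == K, length(marginals) == K, length(weights) == K)
stopifnot(all(vapply(atoms, ncol, integer(1)) == P))
stopifnot(all(mapply(function(a, m) nrow(a) == length(m), atoms, marginals)))
stopifnot(all(vapply(marginals,
                     function(m) all(m >= 0) && abs(sum(m) - 1) < 1e-12,
                     logical(1))))

# Assumption: pwbary is exported by the T4transport package.
if (requireNamespace("T4transport", quietly = TRUE) &&
    "pwbary" %in% getNamespaceExports("T4transport")) {
  out <- T4transport::pwbary(atoms, marginals = marginals, weights = weights,
                             num_support = 20, abstol = 1e-8, maxiter = 20)
  dim(out$support)   # per the Value section: an (M x P) = 20 x 2 matrix
}
```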
Value
a list with three elements:
- support
an \((M \times P)\) matrix of the PW barycenter's support points.
- weight
a length-\(M\) vector of the barycenter's weights, with all entries equal to \(1/M\).
References
Adamo D, Corneli M, Vuillien M, Vila E (2025). “An in Depth Look at the Procrustes-Wasserstein Distance: Properties and Barycenters.” In Forty-Second International Conference on Machine Learning.
Examples
if (FALSE) { # \dontrun{
#-------------------------------------------------------------------
# Free-Support PW Barycenter of Multiple Gaussians
#
# * class 1 : samples from N((0,0), diag(c(4,1/4)))
# * class 2 : samples from N((10,0), diag(c(1/4,4)))
# * class 3 : class-2 samples multiplied by a random orthogonal matrix
#
# We draw 10 empirical measures from each and compare
# their barycenters under the regular and PW geometries.
#-------------------------------------------------------------------
## GENERATE DATA
set.seed(10)
# prepare empty lists
input_1 = vector("list", length=10L)
input_2 = vector("list", length=10L)
input_3 = vector("list", length=10L)
# generate
random_rot = qr.Q(qr(matrix(runif(4), ncol=2)))
for (i in 1:10){
input_1[[i]] = cbind(rnorm(50, sd=2), rnorm(50, sd=0.5))
}
for (j in 1:10){
base_draw = cbind(rnorm(50, sd=0.5), rnorm(50, sd=2))
base_draw[,1] = base_draw[,1] + 10
input_2[[j]] = base_draw
input_3[[j]] = base_draw%*%random_rot
}
## COMPUTE
# regular Wasserstein barycenters
regular_1 = rbaryGD(input_1, num_support=50)
regular_2 = rbaryGD(input_2, num_support=50)
regular_3 = rbaryGD(input_3, num_support=50)
# Procrustes-Wasserstein barycenters
pw_1 = pwbary(input_1, num_support=50)
pw_2 = pwbary(input_2, num_support=50)
pw_3 = pwbary(input_3, num_support=50)
## VISUALIZE
opar <- par(no.readonly=TRUE)
par(mfrow=c(3,1))
# set the x- and y-limits for display
lim_x = c(-12, 12)
lim_y = c(-10, 5)
# plot prototypical measures per class
plot(input_1[[1]], pch=19, cex=0.5, col="gray80",
main="3 types of measures", xlab="", ylab="",
xlim=lim_x, ylim=lim_y)
points(input_2[[1]], pch=19, cex=0.5, col="gray50")
points(input_3[[1]], pch=19, cex=0.5, col="gray10")
# plot regular barycenters
plot(regular_1$support, pch=19, cex=0.5, col="blue",
main="Regular Wasserstein barycenters",
xlab="", ylab="", xlim=lim_x, ylim=lim_y)
points(regular_2$support, pch=19, cex=0.5, col="cyan")
points(regular_3$support, pch=19, cex=0.5, col="red")
# plot PW barycenters
plot(pw_1$support, pch=19, cex=0.5, col="blue",
main="Procrustes-Wasserstein barycenters",
xlab="", ylab="", xlim=lim_x, ylim=lim_y)
points(pw_2$support, pch=19, cex=0.5, col="cyan")
points(pw_3$support, pch=19, cex=0.5, col="red")
par(opar)
} # }