Skip to contents

For a collection of empirical measures \(\lbrace \mu_k\rbrace_{k=1}^K\), this function computes the Procrustes-Wasserstein (PW) barycenter (Adamo et al. 2025), which accounts for both transport between the measures and their alignment under the action of the orthogonal group.

Usage

pwbary(atoms, marginals = NULL, weights = NULL, num_support = 100, ...)

Arguments

atoms

a length-\(K\) list where each element is an \((N_k \times P)\) matrix of atoms.

marginals

marginal distributions for the empirical measures; if NULL (default), uniform weights are assigned to the atoms of every measure. Otherwise, it should be a length-\(K\) list where each element is a length-\(N_k\) vector of nonnegative weights that sum to 1.

weights

weights for each individual measure; if NULL (default), all measures are weighted equally. Otherwise, it should be a length-\(K\) vector.

num_support

the number of support points \(M\) for the PW barycenter (default: 100).

...

extra parameters including

abstol

stopping criterion for iterations (default: 1e-6).

maxiter

maximum number of iterations (default: 10).

Value

a list containing the following elements:

support

an \((M \times P)\) matrix of the PW barycenter's support points.

weight

a length-\(M\) vector of the barycenter's weights, with all entries equal to \(1/M\).

References

Adamo D, Corneli M, Vuillien M, Vila E (2025). “An in Depth Look at the Procrustes-Wasserstein Distance: Properties and Barycenters.” In Forty-Second International Conference on Machine Learning.

Examples

if (FALSE) { # \dontrun{
#-------------------------------------------------------------------
#         Free-Support PW Barycenter of Multiple Gaussians
#
# * class 1 : samples from N((0,0),  diag(c(4,1/4)))
# * class 2 : samples from N((10,0), diag(c(1/4,4)))
# * class 3 : the class-2 samples multiplied by one fixed random
#             orthogonal matrix (so mean and covariance are
#             transformed as well)
#
#  We draw 10 empirical measures (50 points each) per class and
#  compare their barycenters under the regular and PW geometries.
#  Classes 2 and 3 differ only by an orthogonal transformation,
#  which is exactly what the PW geometry is designed to account for.
#-------------------------------------------------------------------
## GENERATE DATA
set.seed(10)

#  one fixed random orthogonal matrix: the Q factor of a QR
#  decomposition (drawn first, before any sample, to keep the
#  random-number stream of the original example)
random_rot <- qr.Q(qr(matrix(runif(4), ncol = 2)))

#  class 1 : anisotropic Gaussian centered at the origin
input_1 <- vector("list", length = 10L)
for (i in seq_len(10L)) {
  input_1[[i]] <- cbind(rnorm(50, sd = 2), rnorm(50, sd = 0.5))
}

#  classes 2 and 3 share the same draws; class 3 is their image under
#  `random_rot`, so the two classes differ only by that transformation
input_2 <- vector("list", length = 10L)
input_3 <- vector("list", length = 10L)
for (j in seq_len(10L)) {
  base_draw <- cbind(rnorm(50, sd = 0.5), rnorm(50, sd = 2))
  base_draw[, 1] <- base_draw[, 1] + 10

  input_2[[j]] <- base_draw
  input_3[[j]] <- base_draw %*% random_rot
}

## COMPUTE
#  regular Wasserstein barycenters
regular_1 <- rbaryGD(input_1, num_support = 50)
regular_2 <- rbaryGD(input_2, num_support = 50)
regular_3 <- rbaryGD(input_3, num_support = 50)

#  Procrustes-Wasserstein barycenters
pw_1 <- pwbary(input_1, num_support = 50)
pw_2 <- pwbary(input_2, num_support = 50)
pw_3 <- pwbary(input_3, num_support = 50)

## VISUALIZE
#  save graphical parameters so they can be restored at the end
opar <- par(no.readonly = TRUE)
par(mfrow = c(3, 1))

#  set the x- and y-limits for display
lim_x <- c(-12, 12)
lim_y <- c(-10, 5)

#  plot prototypical measures per class
plot(input_1[[1]], pch = 19, cex = 0.5, col = "gray80",
     main = "3 types of measures", xlab = "", ylab = "",
     xlim = lim_x, ylim = lim_y)
points(input_2[[1]], pch = 19, cex = 0.5, col = "gray50")
points(input_3[[1]], pch = 19, cex = 0.5, col = "gray10")

#  plot regular barycenters
plot(regular_1$support, pch = 19, cex = 0.5, col = "blue",
     main = "Regular Wasserstein barycenters",
     xlab = "", ylab = "", xlim = lim_x, ylim = lim_y)
points(regular_2$support, pch = 19, cex = 0.5, col = "cyan")
points(regular_3$support, pch = 19, cex = 0.5, col = "red")

#  plot PW barycenters
plot(pw_1$support, pch = 19, cex = 0.5, col = "blue",
     main = "Procrustes-Wasserstein barycenters",
     xlab = "", ylab = "", xlim = lim_x, ylim = lim_y)
points(pw_2$support, pch = 19, cex = 0.5, col = "cyan")
points(pw_3$support, pch = 19, cex = 0.5, col = "red")
par(opar)
} # }