Optimal Transport Plans

This tutorial builds intuition for what SDOT computes, how the Newton solver works, and what quantities you can extract from an OT plan.


The Problem

Given:

  • f: a discrete measure, i.e. a finite collection of weighted Dirac masses at positions $y_1, \dots, y_n$ with masses $m_1, \dots, m_n$ (with $\sum_i m_i = 1$)
  • g: a continuous density on some domain $\Omega$

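Find the transport map $T : \Omega \to \{y_1, \dots, y_n\}$ that moves g onto f at minimum cost:

$$W_2^2(f,g) = \min_T \int_\Omega \|x - T(x)\|^2 \, g(x)\,dx \quad \text{s.t.} \quad \int_{T^{-1}(y_i)} g(x)\,dx = m_i \quad \forall i$$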
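As a small worked instance of this objective, the sketch below evaluates the cost of a map in one dimension, assuming g is the uniform density on [0, 1] so that every integral has a closed form. The helper names `second_moment` and `transport_cost` are illustrative and are not part of SDOT.

```python
def second_moment(a, lo, hi):
    """Integral of (x - a)^2 over [lo, hi] for the uniform density."""
    return ((hi - a) ** 3 - (lo - a) ** 3) / 3.0


def transport_cost(positions, edges):
    """Cost of the map sending every x in [edges[i], edges[i+1]] to positions[i]."""
    return sum(
        second_moment(y, lo, hi)
        for y, lo, hi in zip(positions, edges[:-1], edges[1:])
    )


positions = [0.2, 0.9]   # Dirac positions y_1, y_2
masses = [0.5, 0.5]      # target masses m_1, m_2

# In 1D the mass constraint fixes the cells of a monotone map: [0, 0.5] and [0.5, 1].
edges = [0.0, masses[0], 1.0]
cost = transport_cost(positions, edges)
```

Each cell has length 0.5, so the mass constraint holds, and `cost` evaluates to 1/30 (about 0.0333).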

Power Diagrams (Laguerre Cells)

The optimal transport map in the semi-discrete setting is always induced by a power diagram (also called a Laguerre tessellation): a partition of $\Omega$ into cells $C_i(w)$, each associated with a Dirac $y_i$, where membership is defined by a weighted distance:

$$x \in C_i(w) \iff \|x - y_i\|^2 - w_i \le \|x - y_j\|^2 - w_j \quad \forall j \ne i$$

When all weights $w_i = 0$ (or, more generally, when they are all equal), this reduces to the standard Voronoi diagram. The weights "inflate" or "deflate" each cell to match the target mass.

The transport plan then simply sends every point x in cell Ci to position yi.
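The membership rule can be sketched directly with NumPy: a point belongs to the cell whose weighted squared distance is smallest. This is a brute-force illustration for a handful of points, not how SDOT builds its power diagrams, and `laguerre_labels` is a hypothetical helper name.

```python
import numpy as np


def laguerre_labels(points, sites, weights):
    """Index of the Laguerre cell containing each point.

    points: (p, d), sites: (n, d), weights: (n,)
    """
    # Squared distances between every point and every site, shape (p, n).
    sq_dist = ((points[:, None, :] - sites[None, :, :]) ** 2).sum(axis=-1)
    # A point belongs to the cell minimising ||x - y_i||^2 - w_i.
    return np.argmin(sq_dist - weights[None, :], axis=1)


sites = np.array([[0.25, 0.5], [0.75, 0.5]])
points = np.array([[0.4, 0.5], [0.55, 0.5], [0.9, 0.5]])

voronoi = laguerre_labels(points, sites, np.zeros(2))            # zero weights
inflated = laguerre_labels(points, sites, np.array([0.1, 0.0]))  # grow cell 0
```

With zero weights the boundary sits at x = 0.5, so the labels are `[0, 1, 1]`. Raising the first weight to 0.1 moves the boundary to x = 0.6, and the middle point switches to cell 0, giving `[0, 0, 1]`.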


The Newton Solver

SDOT finds the optimal weights by solving:

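$$\int_{C_i(w)} g(x)\,dx = m_i \quad \forall i$$

These equations are the optimality conditions of a smooth concave maximization problem (the Kantorovich dual, which is strictly concave once the additive constant in w is fixed). A damped Newton method converges in a handful of iterations in practice (typically 5–20). The dominant cost is computing the power diagram and its integrals at each step, which takes O(n log n) time for a 2D power diagram.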

Reference: Kitagawa, Mérigot, Thibert — Convergence of a Newton algorithm for semi-discrete optimal transport, JEMS 2019.
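To make the iteration concrete, here is a minimal 1D sketch of a damped Newton solve, assuming a density on [0, 1] with known `pdf` and `cdf` and sorted Dirac positions. It is not SDOT's implementation: the function names are hypothetical, and the backtracking rule is only loosely modelled on the damped Newton scheme of the reference above.

```python
import numpy as np


def cell_edges(y, w):
    """Edges of the 1D Laguerre cells on [0, 1] for sorted sites y."""
    inner = 0.5 * (y[:-1] + y[1:]) + (w[:-1] - w[1:]) / (2.0 * np.diff(y))
    return np.concatenate(([0.0], inner, [1.0]))


def cell_masses(y, w, cdf):
    """Mass of each cell; a negative entry signals an empty or inverted cell."""
    return np.diff(cdf(np.clip(cell_edges(y, w), 0.0, 1.0)))


def newton_weights(y, m, pdf, cdf, tol=1e-10, max_iter=50):
    w = np.zeros_like(y)
    mass = cell_masses(y, w, cdf)
    eps = 0.5 * min(mass.min(), m.min())   # every cell must keep at least this mass
    for _ in range(max_iter):
        err = m - mass
        if np.abs(err).max() < tol:
            break
        # Jacobian of the masses w.r.t. the weights: a weighted path-graph Laplacian.
        c = pdf(cell_edges(y, w)[1:-1]) / (2.0 * np.diff(y))
        H = np.zeros((len(y), len(y)))
        idx = np.arange(len(y) - 1)
        H[idx, idx] += c
        H[idx + 1, idx + 1] += c
        H[idx, idx + 1] -= c
        H[idx + 1, idx] -= c
        # Weights are defined up to a constant: pin w[0] and solve the reduced system.
        step = np.zeros_like(w)
        step[1:] = np.linalg.solve(H[1:, 1:], err[1:])
        # Backtracking: halve the step until cells stay non-empty and the error shrinks.
        t = 1.0
        trial = cell_masses(y, w + step, cdf)
        while t > 1e-8 and not (
            trial.min() >= eps
            and np.linalg.norm(m - trial) <= (1 - t / 2) * np.linalg.norm(err)
        ):
            t *= 0.5
            trial = cell_masses(y, w + t * step, cdf)
        w, mass = w + t * step, trial
    return w


y = np.array([0.1, 0.4, 0.6, 0.9])   # sorted Dirac positions
m = np.full(4, 0.25)                 # equal target masses
pdf = lambda x: 2.0 * x              # example density g(x) = 2x on [0, 1]
cdf = lambda x: x ** 2

w = newton_weights(y, m, pdf, cdf)
masses = cell_masses(y, w, cdf)
```

For this density the cell boundaries must satisfy b² = 0.25, 0.5, 0.75, so the interior edges returned by `cell_edges(y, w)` approach the square roots of those values as the masses converge to `m`.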


Computing a Plan

```python
from sdot import SplineGrid, SumOfDiracs, optimal_transport_plan
import numpy as np

f = SumOfDiracs( np.random.rand( 300, 2 ) )
g = SplineGrid( np.random.rand( 10, 10 ) )

plan = optimal_transport_plan( f, g )
```

Extracting quantities

```python
plan.distance          # W2^2 transport cost (scalar)
plan.barycenters       # centroid of each Laguerre cell, shape (n, d)
plan.cell_masses       # mass of each cell,              shape (n,)
plan.brenier_potential # dual potential psi,             shape (n,)
plan.power_diagram     # the underlying PowerDiagram object
```

The Brenier potential $\psi$ is the dual variable: $T(x) = x - \tfrac{1}{2}\nabla\psi(x)$, and the transport cost is the Legendre transform of $\psi$ integrated against g.


The Wasserstein Distance

The scalar plan.distance equals:

$$W_2^2(f,g) = \int_\Omega \|x - T(x)\|^2 \, g(x)\,dx = \sum_i \int_{C_i} \|x - y_i\|^2 \, g(x)\,dx$$

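It also equals the Kantorovich dual objective evaluated at the optimal weights, written here with the same sign convention as the Laguerre cells (a solver that minimizes works with its negative):

$$W_2^2(f,g) = \sum_i m_i w_i + \int_\Omega \min_i \left( \|x - y_i\|^2 - w_i \right) g(x)\,dx$$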
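The primal/dual equality can be checked by hand in 1D. The sketch below assumes the uniform density on [0, 1] and the dual written as $\sum_i m_i w_i + \int \min_i(\|x - y_i\|^2 - w_i)\,g(x)\,dx$; the helper names are illustrative and not part of SDOT.

```python
def cell_edges(y, w):
    """Edges of the 1D Laguerre cells on [0, 1] (assumes every cell is non-empty)."""
    inner = [
        0.5 * (y[i] + y[i + 1]) + (w[i] - w[i + 1]) / (2.0 * (y[i + 1] - y[i]))
        for i in range(len(y) - 1)
    ]
    return [0.0] + inner + [1.0]


def second_moment(a, lo, hi):
    """Integral of (x - a)^2 over [lo, hi] for the uniform density."""
    return ((hi - a) ** 3 - (lo - a) ** 3) / 3.0


def dual_objective(y, m, w):
    edges = cell_edges(y, w)
    total = 0.0
    for i in range(len(y)):
        lo, hi = edges[i], edges[i + 1]
        # m_i w_i plus the integral of (x - y_i)^2 - w_i over cell i.
        total += m[i] * w[i] + second_moment(y[i], lo, hi) - w[i] * (hi - lo)
    return total


y = [0.2, 0.9]
m = [0.5, 0.5]
w_opt = [0.0, 0.07]   # places the cell boundary at x = 0.5, so both cells have mass 0.5

dual_at_optimum = dual_objective(y, m, w_opt)
dual_at_zero = dual_objective(y, m, [0.0, 0.0])
```

At the optimal weights the dual value matches the primal cost 1/30 for this configuration, while the unweighted (Voronoi) choice gives a strictly smaller value, as expected when the dual is maximized.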

For a shortcut when you only need the distance:

```python
from sdot import distance

d = distance( f, g )   # same as plan.distance, but doesn't store the full plan
```

Barycenters and the Lloyd Algorithm

The barycenters plan.barycenters[i] are the centroids of each transport cell:

$$b_i = \frac{1}{m_i} \int_{C_i(w)} x \, g(x)\,dx$$
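The centroid formula can be approximated by simple quadrature, which is handy for sanity-checking values on small problems. This sketch assumes a uniform density on the unit square sampled on a midpoint grid, with zero weights (Voronoi cells); `cell_masses_and_barycenters` is a hypothetical helper, not an SDOT function.

```python
import numpy as np


def cell_masses_and_barycenters(points, density, labels, n_cells):
    """Quadrature estimate of m_i and b_i from labelled nodes.

    points: (p, d) quadrature nodes, density: (p,) value of g times node area,
    labels: (p,) cell index of each node. Assumes no cell is empty.
    """
    masses = np.bincount(labels, weights=density, minlength=n_cells)
    moments = np.stack(
        [
            np.bincount(labels, weights=density * points[:, k], minlength=n_cells)
            for k in range(points.shape[1])
        ],
        axis=1,
    )
    return masses, moments / masses[:, None]


# Midpoint grid on the unit square with a uniform density.
side = 100
ticks = (np.arange(side) + 0.5) / side
points = np.stack(np.meshgrid(ticks, ticks, indexing="ij"), axis=-1).reshape(-1, 2)
density = np.full(len(points), 1.0 / side**2)

sites = np.array([[0.25, 0.5], [0.75, 0.5]])
sq_dist = ((points[:, None, :] - sites[None, :, :]) ** 2).sum(axis=-1)
labels = np.argmin(sq_dist, axis=1)   # zero weights: Voronoi cells

masses, barycenters = cell_masses_and_barycenters(points, density, labels, len(sites))
```

Here the two cells are the left and right halves of the square, so each mass is 0.5 and the barycenters coincide with the sites.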

Moving each Dirac to its barycenter and repeating gives the Lloyd algorithm, which in practice converges to a critical point (typically a local optimum) of the quantization error of g:

```python
positions = np.random.rand( 200, 2 )

for _ in range( 30 ):
    f    = SumOfDiracs( positions )
    plan = optimal_transport_plan( f, g )
    positions = plan.barycenters          # Lloyd step
```

Gradients

plan.distance (and distance(f, g)) is differentiable with respect to the Dirac positions and masses. With the cost $\|x - y_i\|^2$ used above, the gradient with respect to the position $y_i$ is:

$$\frac{\partial W_2^2}{\partial y_i} = 2\, m_i \,(y_i - b_i)$$

where bi is the barycenter of cell i. This has a clean geometric interpretation: the gradient points from the barycenter toward the Dirac, and vanishes exactly when the Dirac is at its barycenter (i.e., at the Lloyd fixed point).
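A finite-difference check makes the formula tangible. The sketch assumes the uniform density on [0, 1] in 1D, where the optimal cells of sorted Diracs are consecutive intervals whose lengths equal the masses; differentiating $\|x - y_i\|^2$ over cell i then gives $2\,m_i\,(y_i - b_i)$. All function names are illustrative.

```python
def second_moment(a, lo, hi):
    """Integral of (x - a)^2 over [lo, hi] for the uniform density."""
    return ((hi - a) ** 3 - (lo - a) ** 3) / 3.0


def w2_squared(positions, masses):
    """W_2^2 between sorted Diracs and the uniform density on [0, 1]."""
    cost, lo = 0.0, 0.0
    for y, m in zip(positions, masses):
        cost += second_moment(y, lo, lo + m)   # cell i is [lo, lo + m_i]
        lo += m
    return cost


def analytic_gradient(positions, masses):
    grads, lo = [], 0.0
    for y, m in zip(positions, masses):
        barycenter = lo + 0.5 * m              # centroid of a uniform interval
        grads.append(2.0 * m * (y - barycenter))
        lo += m
    return grads


positions = [0.1, 0.45, 0.8]
masses = [0.2, 0.3, 0.5]

h = 1e-6
numeric = []
for i in range(len(positions)):
    up, down = positions.copy(), positions.copy()
    up[i] += h
    down[i] -= h
    # Central difference of the cost with respect to position i.
    numeric.append((w2_squared(up, masses) - w2_squared(down, masses)) / (2 * h))

analytic = analytic_gradient(positions, masses)
```

The first Dirac already sits at its barycenter (0.1), so its gradient component vanishes, while the other two components are 0.06 and 0.05 up to rounding.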

In SDOT, this gradient is computed automatically through JAX or PyTorch automatic differentiation:

```python
import jax

# Differentiate the distance with respect to the Dirac positions.
grad_positions = jax.grad( lambda pos: distance( SumOfDiracs( pos ), g ) )( positions )
# grad_positions[i] ~ 2 * masses[i] * (positions[i] - barycenters[i])
```

SDOT — H. Leclerc, Q. Mérigot, T. Gallouët · LMO / INRIA PARMA