Optimal Transport Plans
This tutorial builds intuition for what SDOT computes, how the Newton solver works, and what quantities you can extract from an OT plan.
The Problem
Given:
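- f: a discrete measure $f = \sum_{i=1}^n m_i \delta_{x_i}$, that is, a finite collection of weighted Dirac masses at positions $x_1, \dots, x_n \in \mathbb{R}^d$ with masses $m_1, \dots, m_n > 0$ (with $\sum_i m_i = \int_\Omega g(y)\,dy$, so both measures carry the same total mass)
- g: a continuous density on some domain $\Omega \subset \mathbb{R}^d$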
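Find the transport map $T : \Omega \to \{x_1, \dots, x_n\}$ that sends exactly the mass $m_i$ of g to each $x_i$ (i.e. $\int_{T^{-1}(x_i)} g(y)\,dy = m_i$) while minimizing the quadratic transport cost
$$\int_\Omega \|y - T(y)\|^2 \, g(y)\,dy.$$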
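The inputs can be pictured with plain arrays. The sketch below is only an illustration under assumptions of our own: f is stored as a position array plus a mass array, and g as a piecewise-constant density on a 10 x 10 grid over the unit square (this is not necessarily how SplineGrid interprets its values). It rescales g so that the mass-balance condition holds.

```python
import numpy as np

# Hypothetical discretization (not SDOT's data structures): n Diracs with equal
# masses, and g stored as a piecewise-constant density on an r x r grid over [0, 1]^2.
rng = np.random.default_rng(0)
n, r = 300, 10
positions = rng.random((n, 2))          # x_i
masses = np.full(n, 1.0 / n)            # m_i, summing to 1
values = rng.random((r, r))             # raw, unnormalized grid values of g
cell_area = 1.0 / r ** 2

# Rescale g so that its integral equals the total Dirac mass (mass balance).
values *= masses.sum() / (values.sum() * cell_area)

total_g = values.sum() * cell_area      # integral of g over the unit square
print(masses.sum(), total_g)
```

Without this rescaling there is no map T satisfying the mass constraints, so the problem would have no solution.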
Power Diagrams (Laguerre Cells)
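In the semi-discrete setting, the optimal transport map is always described by a power diagram (also called a Laguerre tessellation): a partition of $\Omega$ into one cell per Dirac,
$$\mathrm{Lag}_i(w) = \{\, y \in \Omega : \|y - x_i\|^2 - w_i \le \|y - x_j\|^2 - w_j \ \text{for all } j \,\},$$
determined by a weight vector $w = (w_1, \dots, w_n)$.
When all weights $w_i$ are equal, the power diagram reduces to the Voronoi diagram of the positions $x_i$; increasing $w_i$ enlarges cell $i$ at the expense of its neighbors.
The transport plan then simply sends every point $y \in \mathrm{Lag}_i(w)$ to $x_i$, i.e. $T(y) = x_i$ on $\mathrm{Lag}_i(w)$.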
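A minimal NumPy sketch of the cell assignment, assuming a uniform density g on the unit square represented by Monte Carlo samples (the helper laguerre_labels is hypothetical, not part of SDOT): each sample goes to the index minimizing $\|y - x_i\|^2 - w_i$. With equal weights this is the nearest-Dirac (Voronoi) rule, and raising one weight can only grow that cell.

```python
import numpy as np

def laguerre_labels(y, x, w):
    # Index i minimizing |y - x_i|^2 - w_i for every sample y; shape (k,)
    sq_dist = ((y[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1)  # (k, n)
    return np.argmin(sq_dist - w[None, :], axis=1)

rng = np.random.default_rng(0)
x = rng.random((5, 2))              # 5 Dirac positions in the unit square
y = rng.random((10_000, 2))         # samples from a uniform density g on [0, 1]^2

w_equal = np.zeros(5)               # equal weights: the Voronoi diagram
w_bigger = w_equal.copy()
w_bigger[0] = 0.1                   # raise w_0 to enlarge cell 0

vor = laguerre_labels(y, x, w_equal)
lag = laguerre_labels(y, x, w_bigger)
T = x[lag]                          # transport map: each sample goes to its cell's Dirac

print(np.bincount(vor, minlength=5) / len(y))  # Monte Carlo cell masses (Voronoi)
print(np.bincount(lag, minlength=5) / len(y))  # cell 0 has grown
```

Monte Carlo sampling only approximates the cell integrals; an exact solver integrates g over the polygonal cells instead.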
The Newton Solver
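SDOT finds the optimal weights $w$ by solving the mass-balance equations
$$G_i(w) := \int_{\mathrm{Lag}_i(w)} g(y)\,dy = m_i, \qquad i = 1, \dots, n.$$
The weights are only determined up to a common additive constant, and these equations are the optimality conditions of a concave dual problem (Kantorovich duality). Newton's method applies to this system: a damped version, which shortens the step so that no cell becomes empty, converges in a handful of iterations in practice (typically 5–20). The main cost of each iteration is computing the power diagram and integrating g over its cells (for G) and over the facets between neighboring cells (for the Jacobian of G); building the diagram runs in O(n log n) time.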
Reference: Kitagawa, Mérigot, Thibert — Convergence of a Newton algorithm for semi-discrete optimal transport, JEMS 2019.
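To make the iteration concrete, here is a self-contained 1D toy version; it is a sketch, not SDOT's implementation. Assumptions: $\Omega = [0, 1]$, a hypothetical density $g(y) = 0.5 + y$, sorted positions, and equal target masses. Cells are then intervals, the Jacobian is the 1D case of the facet formula $\partial G_i / \partial w_j = -\frac{1}{2\|x_i - x_j\|}\int_{\mathrm{Lag}_i \cap \mathrm{Lag}_j} g$ for neighbors $j$, and the step is halved until the smallest cell keeps mass at least $\varepsilon_0$ and the error shrinks by a factor $1 - t/2$, in the spirit of the damping rule of Kitagawa, Mérigot and Thibert.

```python
import numpy as np

# Hypothetical 1D setting: Omega = [0, 1] with density g(y) = 0.5 + y (total mass 1).
def density(y):
    return 0.5 + y

def cdf(y):
    # Mass of g on [0, y], with y clipped to [0, 1]
    y = np.clip(y, 0.0, 1.0)
    return 0.5 * y + 0.5 * y ** 2

def boundaries(x, w):
    # Solve |y - x_i|^2 - w_i = |y - x_{i+1}|^2 - w_{i+1} for y (x sorted increasingly)
    return 0.5 * (x[:-1] + x[1:]) + (w[:-1] - w[1:]) / (2.0 * (x[1:] - x[:-1]))

def cell_masses(x, w):
    # G_i(w); exact whenever all returned values are > 0 (cells are then the
    # consecutive intervals between the boundaries)
    edges = np.concatenate(([0.0], boundaries(x, w), [1.0]))
    return np.diff(cdf(edges))

def jacobian(x, w):
    # dG_i/dw_j: weighted graph Laplacian; the edge between cells i and i+1
    # has weight g(boundary) / (2 |x_{i+1} - x_i|)
    n = len(x)
    c = density(boundaries(x, w)) / (2.0 * (x[1:] - x[:-1]))
    J = np.zeros((n, n))
    i = np.arange(n - 1)
    J[i, i] += c
    J[i + 1, i + 1] += c
    J[i, i + 1] -= c
    J[i + 1, i] -= c
    return J

def damped_newton(x, m, tol=1e-10, max_iter=50):
    w = np.zeros(len(x))                 # w = 0 gives Voronoi cells, all non-empty
    G = cell_masses(x, w)
    eps0 = 0.5 * min(m.min(), G.min())   # smallest cell mass we allow
    for it in range(max_iter):
        err = np.linalg.norm(G - m)
        if err < tol:
            return w, it
        J = jacobian(x, w)
        dw = np.zeros_like(w)            # weights are defined up to a constant: pin w_0
        dw[1:] = np.linalg.solve(J[1:, 1:], (m - G)[1:])
        t = 1.0
        while True:                      # halve the step until it is acceptable
            G_new = cell_masses(x, w + t * dw)
            if G_new.min() >= eps0 and np.linalg.norm(G_new - m) <= (1 - t / 2) * err:
                break
            t /= 2
            if t < 1e-12:
                raise RuntimeError("line search failed")
        w, G = w + t * dw, G_new
    raise RuntimeError("Newton did not converge")

rng = np.random.default_rng(0)
x = np.sort(rng.random(20))              # Dirac positions
m = np.full(20, 1.0 / 20)                # prescribed masses, summing to 1
w, iterations = damped_newton(x, m)
print(iterations, np.abs(cell_masses(x, w) - m).max())
```

Pinning $w_0$ removes the constant null direction of the Jacobian, so the reduced system is invertible.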
Computing a Plan
```python
import numpy as np
from sdot import SplineGrid, SumOfDiracs, optimal_transport_plan

f = SumOfDiracs( np.random.rand( 300, 2 ) )  # 300 Diracs at random positions in [0, 1]^2
g = SplineGrid( np.random.rand( 10, 10 ) )    # density defined by a 10 x 10 grid of random values
plan = optimal_transport_plan( f, g )
```
Extracting quantities
```python
plan.distance           # W2^2 transport cost (scalar)
plan.barycenters        # centroid of each Laguerre cell, shape (n, d)
plan.cell_masses        # mass of each cell, shape (n,)
plan.brenier_potential  # dual potential psi, shape (n,)
plan.power_diagram      # the underlying PowerDiagram object
```
The Brenier potential
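This part of the tutorial does not spell out how plan.brenier_potential is normalized. One standard way to see its link with the weights: expanding $\|y - x_i\|^2 - w_i = \|y\|^2 - 2(\langle y, x_i\rangle + \psi_i)$ with $\psi_i = (w_i - \|x_i\|^2)/2$ shows that the Laguerre cells are exactly the regions where each affine piece of the convex function $u(y) = \max_i (\langle y, x_i\rangle + \psi_i)$ is active. Hence $\nabla u(y) = x_i$ on cell $i$, and the transport map is the gradient of a convex function, as Brenier's theorem predicts. The sketch below checks this numerically; the convention for psi is an assumption and may differ from SDOT's by a sign, a factor or an additive constant.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((5, 2))                      # Dirac positions x_i
w = 0.05 * rng.standard_normal(5)           # some weights w_i (optimal or not)
psi = 0.5 * (w - (x ** 2).sum(axis=1))      # assumed convention: psi_i = (w_i - |x_i|^2) / 2

def u(y):
    # Candidate Brenier potential: convex, piecewise affine, u(y) = max_i <y, x_i> + psi_i
    return np.max(y @ x.T + psi, axis=-1)

y = rng.random((1000, 2))
lag = np.argmin(((y[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1) - w, axis=1)  # Laguerre cell
active = np.argmax(y @ x.T + psi, axis=1)   # affine piece of u active at y

h = 1e-7                                    # central finite differences of u
grad_u = np.stack([(u(y + h * e) - u(y - h * e)) / (2 * h) for e in np.eye(2)], axis=1)
ok = np.all(np.abs(grad_u - x[lag]) < 1e-6, axis=1)   # does grad u(y) equal T(y) = x_i?
print((lag == active).mean(), ok.mean())
```

A few samples lying within h of a cell boundary may fail the finite-difference comparison; everywhere else $\nabla u$ coincides with the transport map.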
The Wasserstein Distance
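The scalar plan.distance equals the squared Wasserstein distance, i.e. the transport cost of the optimal partition:
$$W_2^2(f, g) = \sum_{i=1}^n \int_{\mathrm{Lag}_i(w^*)} \|y - x_i\|^2 \, g(y)\,dy,$$
where $w^*$ are the optimal weights. By Kantorovich duality, it also equals the maximum of the concave dual objective, attained at $w^*$:
$$W_2^2(f, g) = \Phi(w^*) = \max_w \Phi(w), \qquad \Phi(w) = \sum_{i=1}^n m_i w_i + \int_\Omega \min_{1 \le i \le n} \left( \|y - x_i\|^2 - w_i \right) g(y)\,dy.$$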
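The primal/dual identity can be checked in a 1D toy case where the optimal cells are known in closed form: with g uniform on [0, 1], equal masses 1/n and sorted positions, the optimal cells are the intervals [k/n, (k+1)/n]. The sketch below (an illustration, not SDOT code) recovers the weights from these boundaries, evaluates the primal cost exactly, and evaluates the dual $\Phi(w) = \sum_i m_i w_i + \int \min_i (\|y - x_i\|^2 - w_i)\, g(y)\,dy$ by midpoint quadrature. The two agree at the optimal weights, while other weights (here w = 0) give a smaller dual value.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 6
x = np.sort(rng.random(n))          # sorted Dirac positions in [0, 1]
m = np.full(n, 1.0 / n)             # equal masses; g = uniform density on [0, 1]

# Optimal cells are [k/n, (k+1)/n]. Recover weights from the boundary condition
# (x_k + x_{k+1})/2 + (w_k - w_{k+1}) / (2 (x_{k+1} - x_k)) = (k+1)/n, with w_0 = 0.
w = np.zeros(n)
for k in range(n - 1):
    d = x[k + 1] - x[k]
    w[k + 1] = w[k] - 2 * d * ((k + 1) / n - 0.5 * (x[k] + x[k + 1]))

# Primal cost: sum_k of the integral of (y - x_k)^2 over [k/n, (k+1)/n] (exact)
a, b = np.arange(n) / n, np.arange(1, n + 1) / n
primal = np.sum(((b - x) ** 3 - (a - x) ** 3) / 3)

def dual(w, num=200_000):
    # Phi(w) = sum_i m_i w_i + integral over [0, 1] of min_i (|y - x_i|^2 - w_i) dy
    y = (np.arange(num) + 0.5) / num                    # midpoint quadrature nodes
    inner = np.min((y[:, None] - x[None, :]) ** 2 - w[None, :], axis=1)
    return m @ w + inner.mean()

print(primal, dual(w), dual(np.zeros(n)))
```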
For a shortcut when you only need the distance:
```python
from sdot import distance

d = distance( f, g )  # same value as plan.distance, but the full plan is not stored
```
Barycenters and the Lloyd Algorithm
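The barycenters plan.barycenters[i] are the g-weighted centroids of the transport cells:
$$b_i = \frac{1}{m_i} \int_{\mathrm{Lag}_i} y \, g(y)\,dy,$$
where $\mathrm{Lag}_i$ is the cell sent to $x_i$ (it carries mass $m_i$ at the optimum).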
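A Monte Carlo sketch of the barycenter formula, under assumptions of our own: g is uniform on the unit square and, since no OT solve is done here, the weights are zero so the cells are plain Voronoi cells. Each cell integral becomes a sum over the samples falling in that cell. The sketch also checks the property the Lloyd step relies on: sending each cell to its barycenter instead of to $x_i$ never increases the quadratic cost.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
x = rng.random((n, 2))              # Dirac positions x_i
w = np.zeros(n)                     # assumed weights (zero: Voronoi cells)
y = rng.random((20_000, 2))         # samples of a uniform density g on [0, 1]^2

labels = np.argmin(((y[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1) - w, axis=1)
counts = np.bincount(labels, minlength=n)
cell_mass = counts / len(y)         # Monte Carlo estimate of the integral of g over Lag_i

# b_i = (1 / mass_i) * integral of y g(y) over Lag_i  ->  mean of the samples in cell i
bary = np.stack([np.bincount(labels, weights=y[:, k], minlength=n) for k in range(2)], axis=1)
bary /= counts[:, None]

# Per-cell quadratic cost when the cell is sent to x_i versus to b_i
cost_x = np.bincount(labels, weights=((y - x[labels]) ** 2).sum(axis=1), minlength=n) / len(y)
cost_b = np.bincount(labels, weights=((y - bary[labels]) ** 2).sum(axis=1), minlength=n) / len(y)
print(cost_x.sum(), cost_b.sum())   # moving to barycenters never increases the cost
```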
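Moving each Dirac to its barycenter and repeating gives the Lloyd algorithm. No step can increase the transport cost: the old cells still transport g onto the moved Diracs, and within each cell the barycenter minimizes the quadratic cost. The iterates typically converge to a locally optimal quantization of g by n Diracs with the prescribed masses.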
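```python
import numpy as np
from sdot import SumOfDiracs, optimal_transport_plan

positions = np.random.rand( 200, 2 )
for _ in range( 30 ):
    f = SumOfDiracs( positions )              # g is the density defined above
    plan = optimal_transport_plan( f, g )
    positions = plan.barycenters              # Lloyd step: move each Dirac to its barycenter
```
Gradients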
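plan.distance (and distance(f, g)) are differentiable with respect to the Dirac positions and masses. The gradient with respect to the position $x_i$ is
$$\frac{\partial W_2^2(f, g)}{\partial x_i} = 2\, m_i \, (x_i - b_i),$$
where $b_i$ is the barycenter of the cell of $x_i$ (plan.barycenters[i]).
In SDOT, this gradient is computed automatically through JAX or PyTorch autodiff: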
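```python
import jax
from sdot import SumOfDiracs, distance

# differentiate the transport cost with respect to the Dirac positions (g, positions from above)
grad_positions = jax.grad( lambda pos: distance( SumOfDiracs( pos ), g ) )( positions )
# expected: grad_positions[i] ≈ 2 * masses[i] * (positions[i] - barycenters[i])
```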
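The gradient formula can be verified without SDOT in 1D: with g uniform on [0, 1] and equal masses 1/n, the optimal transport sends the k-th smallest Dirac to the interval [k/n, (k+1)/n], so $W_2^2$ has a closed form and $b_i$ is the midpoint of that interval. The sketch below (hypothetical helper names) compares $2 m_i (x_i - b_i)$ with central finite differences.

```python
import numpy as np

def w2_squared_1d(x):
    # Exact W2^2 between (1/n) sum_i delta_{x_i} and the uniform density on [0, 1]:
    # the k-th smallest point receives the interval [k/n, (k+1)/n]
    n = len(x)
    xs = np.sort(x)
    a = np.arange(n) / n                 # left ends of the intervals
    bnd = a + 1.0 / n                    # right ends
    # integral of (y - xs)^2 over [a, bnd] = ((bnd - xs)^3 - (a - xs)^3) / 3
    return np.sum(((bnd - xs) ** 3 - (a - xs) ** 3) / 3.0)

def w2_gradient_1d(x):
    # Closed form 2 * m_i * (x_i - b_i), with b_i the midpoint of x_i's interval
    n = len(x)
    ranks = np.empty(n, dtype=int)
    ranks[np.argsort(x)] = np.arange(n)
    b = (ranks + 0.5) / n
    return 2.0 / n * (x - b)

rng = np.random.default_rng(1)
x = rng.random(8)
h = 1e-6
fd = np.array([(w2_squared_1d(x + h * e) - w2_squared_1d(x - h * e)) / (2 * h)
               for e in np.eye(len(x))])
print(np.abs(fd - w2_gradient_1d(x)).max())
```

The step h is small enough not to reorder the points, so each finite difference sees the same assignment of intervals.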