
API — OT Plan

distance( f, g, metric=None )

Computes the squared Wasserstein-2 distance $W_2^2(f, g)$ between the source distribution f and the target distribution g.

Returns a scalar tensor in the active framework (JAX or PyTorch). The result is differentiable with respect to the positions and weights of f and, where the target type supports it, with respect to the parameters of g.
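For the default squared Euclidean metric, standard semi-discrete optimal transport theory gives these gradients in closed form, which can serve as a sanity check on the derivatives produced by automatic differentiation. Write $f = \sum_i m_i \delta_{y_i}$, let $b_i$ be the barycenter of the part of g transported to $y_i$, and let $\psi$ be the optimal dual potential. This assumes that g has a density; SDOT's sign convention for $\psi$ is not stated here and may differ.

$$
\frac{\partial W_2^2(f, g)}{\partial y_i} = 2\, m_i \, (y_i - b_i),
\qquad
\frac{\partial W_2^2(f, g)}{\partial m_i} = \psi_i \quad \text{(up to a constant shared by all } i\text{)}.
$$

The additive constant appears because the weights $m_i$ are constrained to sum to the total mass of g.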

Parameters

| Name | Type | Description |
|---|---|---|
| f | source distribution | SumOfDiracs, SumOfWeightedDiracs, or their batch/1d variants |
| g | target distribution | SplineGrid, PolynomialGrid, SumOfGaussians, … |
| metric | metric object or None | Ground metric (default: Norm2, the squared Euclidean distance) |

Returns: scalar $W_2^2(f, g)$

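```python
import numpy as np
from sdot import SumOfDiracs, PolynomialGrid, distance

# Source f: 200 Diracs at random positions in 2D.
# Target g: a PolynomialGrid with a single coefficient equal to 1.
d = distance(
    SumOfDiracs( np.random.rand( 200, 2 ) ),
    PolynomialGrid( values = [[[1]]] ),
)
```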
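As a hand-checkable reference value, worked out independently of SDOT, take f to be a single Dirac of unit mass at a point c. For this illustration only, assume that g is the uniform density on the unit square [0,1]². All of g is then transported to c, so $W_2^2(f, g) = \int_{[0,1]^2} \|x - c\|^2 \, dx = 1/6 + \|c - (1/2, 1/2)\|^2$. The NumPy sketch below compares this closed form with a Monte Carlo estimate. Both helper functions are hypothetical and are not part of SDOT.

```python
import numpy as np

def w22_single_dirac_closed_form(c):
    """W2^2 between a unit-mass Dirac at c and the uniform density on [0,1]^2.

    Each coordinate contributes Var(U[0,1]) + (1/2 - c_k)^2 = 1/12 + (1/2 - c_k)^2.
    """
    c = np.asarray(c, dtype=float)
    return 2.0 / 12.0 + float(np.sum((c - 0.5) ** 2))

def w22_single_dirac_monte_carlo(c, n_samples=200_000, seed=0):
    """Monte Carlo estimate of the same quantity: mean squared distance to c."""
    rng = np.random.default_rng(seed)
    x = rng.random((n_samples, 2))  # samples of the (assumed) uniform target g
    return float(np.mean(np.sum((x - np.asarray(c, dtype=float)) ** 2, axis=1)))

if __name__ == "__main__":
    for c in [(0.5, 0.5), (0.2, 0.7)]:
        print(c, w22_single_dirac_closed_form(c), w22_single_dirac_monte_carlo(c))
```

A distance call on the same one-Dirac configuration could be compared against these numbers, provided the chosen target really is the uniform density on [0,1]².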

optimal_transport_plan( f, g, metric=None )

Computes the full optimal transport plan and returns an OtPlan object holding the transport cost together with the per-cell quantities of the plan.

Parameters: same as distance.

Returns: OtPlan
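The quantities stored in an OtPlan come from the semi-discrete dual problem. The formulas below use the usual convention for the squared Euclidean metric with $f = \sum_i m_i \delta_{y_i}$. This convention is an assumption here, and SDOT's sign convention for $\psi$ may differ. The Laguerre (power) cells and the dual functional are:

$$
\mathrm{Lag}_i(\psi) = \{\, x : \|x - y_i\|^2 - \psi_i \le \|x - y_j\|^2 - \psi_j \ \ \forall j \,\},
$$

$$
F(\psi) = \sum_i \psi_i\, m_i + \int \min_i \big( \|x - y_i\|^2 - \psi_i \big)\, \mathrm{d}g(x),
\qquad
\frac{\partial F}{\partial \psi_i} = m_i - g\big(\mathrm{Lag}_i(\psi)\big).
$$

Then $W_2^2(f, g) = \max_\psi F(\psi)$. At the maximizer every cell mass $g(\mathrm{Lag}_i)$ equals $m_i$, and the cell barycenters are $b_i = \frac{1}{m_i} \int_{\mathrm{Lag}_i} x \, \mathrm{d}g(x)$.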


OtPlan

The result of optimal_transport_plan.

| Attribute | Shape | Description |
|---|---|---|
| .distance | scalar | $W_2^2$ transport cost |
| .barycenters | (n, d) | Centroid of each Laguerre cell |
| .cell_masses | (n,) | Integral of g over each cell |
| .brenier_potential | (n,) | Dual variable ψ (optimal weights) |
| .power_diagram | PowerDiagram | Underlying power diagram object |

Here n is the number of Diracs in f and d is the dimension of the space. All attributes except .power_diagram are tensors in the active framework, and .distance is differentiable.
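To make these attributes concrete, here is a minimal NumPy sketch that does not use SDOT. It approximates the target g by random samples, where g is assumed to be uniform on [0,1]². Each sample is assigned to the Laguerre cell that minimizes |x − y_i|² − ψ_i, and plain gradient ascent on ψ runs until every cell carries the mass of its Dirac. The sketch then returns quantities analogous to .distance, .barycenters, .cell_masses and .brenier_potential. The sampling, step size, sign convention for ψ, and the function name `sampled_ot_plan` are all illustrative assumptions. SDOT itself works with an actual power diagram (.power_diagram) rather than with samples.

```python
import numpy as np

def sampled_ot_plan(y, mass, n_samples=20_000, step=0.1, n_iter=500, seed=0):
    """Hypothetical helper: approximate semi-discrete OT from weighted Diracs
    (y, mass) to the uniform density on [0,1]^2, using target samples instead
    of exact cells.

    Assumed convention: cell i = {x : |x - y_i|^2 - psi_i <= |x - y_j|^2 - psi_j for all j}.
    """
    y = np.asarray(y, dtype=float)
    mass = np.asarray(mass, dtype=float)
    n = len(y)
    rng = np.random.default_rng(seed)
    x = rng.random((n_samples, 2))                                   # samples of g
    sq_dist = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # (n_samples, n)

    psi = np.zeros(n)
    for _ in range(n_iter):
        labels = np.argmin(sq_dist - psi, axis=1)  # Laguerre cell of each sample
        cell_masses = np.bincount(labels, minlength=n) / n_samples
        # Gradient of the dual w.r.t. psi_i is mass_i - cell_mass_i:
        # raising psi_i grows cell i, lowering it shrinks the cell.
        psi += step * (mass - cell_masses)

    labels = np.argmin(sq_dist - psi, axis=1)
    counts = np.bincount(labels, minlength=n)
    cell_masses = counts / n_samples
    sums = np.stack([np.bincount(labels, weights=x[:, k], minlength=n) for k in range(2)], axis=1)
    barycenters = sums / np.maximum(counts, 1)[:, None]  # an empty cell would give 0, not NaN
    # Transport cost: each sample travels to the Dirac of its cell.
    distance = float(np.mean(sq_dist[np.arange(n_samples), labels]))
    return distance, barycenters, cell_masses, psi

if __name__ == "__main__":
    y = [[0.2, 0.2], [0.8, 0.2], [0.5, 0.5], [0.2, 0.8], [0.8, 0.8]]
    d, bary, masses, psi = sampled_ot_plan(y, np.full(5, 0.2))
    print("distance   ", d)
    print("cell_masses", masses)  # each close to 0.2 after the ascent
    print("barycenters", bary)
```

The fixed step works here because the Dirac positions are well separated. Closely spaced Diracs make the cell masses much more sensitive to ψ and would call for a smaller step or a Newton-type update.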

SDOT — H. Leclerc, Q. Mérigot, T. Gallouët · LMO / INRIA PARMA