Getting Started
Installation using pip
pip install sdot
SDOT requires at least one backend. Install the one you prefer (or both):
pip install jax jaxlib # JAX backend
pip install torch # PyTorch backend
GPU support: follow the standard JAX/PyTorch GPU installation instructions for your platform (CUDA, Metal, …). SDOT will automatically use the device your tensors are on.
Internally, SDOT relies on C++, CUDA and Metal code, loaded through dynamic libraries. The packages come bundled with prebuilt libraries for the most common configurations.
For more exotic cases, SDOT may have to generate and compile some code. This requires a compiler toolchain (SDOT uses xmake as its build system).
Installation from the sources
Alternatively, to contribute or to get the latest version, install from the sources:
git clone git@github.com:sdot-team/sdot.git
cd sdot
pip install -r requirements.txt
pip install -e .
Hello, World
import numpy as np
import sdot
f = sdot.SumOfDiracs( np.random.rand( 1000, 2 ) ) # 1000 equal-weight diracs, 2D
g = sdot.SplineGrid( np.random.rand( 10, 10 ) ) # 10×10 spline density (auto-normalized)
print( sdot.distance( f, g ) )
SDOT infers the backend, dimension, dtype (integer inputs are accepted too), and device automatically.
The Driver
SDOT abstracts the underlying framework (JAX or PyTorch) behind a single driver object. You don't have to import JAX or PyTorch directly in your code — everything can go through sdot.driver.
Choice of the backend
SDOT looks for a backend in this order:
- Modules already imported in the current session (jax, torch)
- The SDOT_FRAMEWORK environment variable
- First available library (jax, then torch)
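The lookup order above can be sketched in plain Python. This is only an illustration, not SDOT's actual code: `pick_backend` and its parameters are hypothetical names. It also assumes that jax is preferred when both frameworks are already imported, and that the value of SDOT_FRAMEWORK is used as given.

```python
import importlib.util
import os
import sys

BACKENDS = ("jax", "torch")  # preference order used in steps 1 and 3 (assumed for step 1)


def pick_backend(loaded=None, env=None, is_installed=None):
    """Hypothetical helper mirroring the documented lookup order."""
    loaded = sys.modules if loaded is None else loaded
    env = os.environ if env is None else env
    if is_installed is None:
        is_installed = lambda name: importlib.util.find_spec(name) is not None

    # 1. A framework already imported in the current session wins.
    for name in BACKENDS:
        if name in loaded:
            return name

    # 2. Otherwise use the SDOT_FRAMEWORK environment variable, if set.
    requested = env.get("SDOT_FRAMEWORK")
    if requested:
        return requested

    # 3. Otherwise fall back to the first installed library.
    for name in BACKENDS:
        if is_installed(name):
            return name

    raise ImportError("no backend found: install jax or torch")


if __name__ == "__main__":
    print(pick_backend())
```

In this sketch, importing torch before calling the function takes precedence over SDOT_FRAMEWORK=jax, which matches the first rule of the list.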
Configuration
These parameters can also be forced explicitly, including at runtime:
import sdot
sdot.driver.framework = "jax" # or "torch"
sdot.driver.dtype = "FP64" # FP32, FP64
sdot.driver.device = "cuda:0" # or "cpu", "mps", …
Or via environment variables:
export SDOT_DTYPE=FP32
export SDOT_DEVICE=cuda:0
All arrays returned by SDOT live in the same framework, so they slot directly into your JAX/PyTorch computation graph.
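When exporting variables in a shell is inconvenient (notebooks, test runners), the same environment variables can be set from Python. The sketch below assumes that SDOT reads them when it is first imported, which is not stated above, so the assignments are placed before the import.

```python
import os

# Keep any value the user already exported; otherwise provide a default.
os.environ.setdefault("SDOT_DTYPE", "FP64")
os.environ.setdefault("SDOT_DEVICE", "cpu")

# Assumption: the variables must be in place before SDOT is loaded.
# import sdot
```

`setdefault` leaves an existing shell export untouched, so a value such as SDOT_DEVICE=cuda:0 set outside the script still takes effect.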