Efficient Differentiable n-d PDE solvers in JAX.
Ceyron/exponax
Installation • Documentation • Quickstart • Features • Background • Acknowledgements
`Exponax` is a suite for building Fourier spectral ETDRK time-steppers for semi-linear PDEs in 1d, 2d, and 3d. There are many pre-built dynamics and plenty of helpful utilities. It is extremely efficient, is differentiable (due to being fully written in JAX), and embeds seamlessly into deep learning.
pip install exponax
Requires Python 3.10+ and JAX 0.4.13+. 👉 JAX install guide.
Documentation is available at fkoehler.site/exponax.
1d Kuramoto-Sivashinsky Equation.
    import jax
    import exponax as ex
    import matplotlib.pyplot as plt

    ks_stepper = ex.stepper.KuramotoSivashinskyConservative(
        num_spatial_dims=1,
        domain_extent=100.0,
        num_points=200,
        dt=0.1,
    )

    u_0 = ex.ic.RandomTruncatedFourierSeries(
        num_spatial_dims=1, cutoff=5
    )(num_points=200, key=jax.random.PRNGKey(0))

    trajectory = ex.rollout(ks_stepper, 500, include_init=True)(u_0)

    plt.imshow(trajectory[:, 0, :].T, aspect='auto', cmap='RdBu', vmin=-2, vmax=2, origin="lower")
    plt.xlabel("Time"); plt.ylabel("Space"); plt.show()
For a next step, check out this tutorial on 1D Advection that explains the basics of `Exponax`.
- JAX as the computational backend:
    - Backend agnostic code - run on CPU, GPU, or TPU, in both single and double precision.
    - Automatic differentiation over the timesteppers - compute gradients of solutions with respect to initial conditions, parameters, etc.
    - Also helpful for tight integration with Deep Learning since each timestepper is just an Equinox Module.
    - Automatic Vectorization using `jax.vmap` (or `equinox.filter_vmap`), which allows advancing multiple states in time or instantiating multiple solvers at a time that operate efficiently in batch.
- Lightweight Design without custom types. There is no `grid` or `state` object. Everything is based on JAX arrays. Timesteppers are callable PyTrees.
- More than 46 pre-built dynamics across 1d, 2d, and 3d:
    - Linear PDEs (advection, diffusion, dispersion, etc.)
    - Nonlinear PDEs (Burgers, Kuramoto-Sivashinsky, Korteweg-de Vries, Navier-Stokes, etc.)
    - Reaction-Diffusion (Gray-Scott, Swift-Hohenberg, etc.)
- Collection of initial condition distributions (truncated Fourier series, Gaussian Random Fields, etc.)
- Utilities for spectral derivatives, grid creation, autoregressive rollout, interpolation, etc.
- Easily extendable to new PDEs by subclassing from the `BaseStepper` module.
- An alternative, reduced interface that allows defining PDE dynamics using normalized or difficulty-based identifiers.
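The differentiation and vectorization points above can be illustrated without the library itself. The following is a minimal sketch in plain JAX, not Exponax's implementation or API: `make_diffusion_step` is a hypothetical helper that solves 1d periodic diffusion exactly in Fourier space, and the state is a flat array of shape `(num_points,)` chosen for brevity. It shows a gradient taken through a time step with respect to a physical parameter, and a batched step via `jax.vmap`.

```python
import jax
import jax.numpy as jnp


def make_diffusion_step(domain_extent, num_points, dt):
    """Hypothetical helper: exact spectral step for u_t = nu * u_xx (periodic, 1d)."""
    # Angular wavenumbers in the layout used by the real-valued FFT
    k = 2 * jnp.pi * jnp.fft.rfftfreq(num_points, d=domain_extent / num_points)

    def step(u, nu):
        u_hat = jnp.fft.rfft(u)
        # Each Fourier mode decays independently by exp(-nu * k^2 * dt)
        u_hat_next = jnp.exp(-nu * k**2 * dt) * u_hat
        return jnp.fft.irfft(u_hat_next, n=num_points)

    return step


extent, num_points, dt = 2 * jnp.pi, 64, 0.1
step = make_diffusion_step(extent, num_points, dt)
x = jnp.linspace(0.0, extent, num_points, endpoint=False)
u_0 = jnp.sin(x)


# Automatic differentiation through the stepper with respect to the diffusivity
def loss(nu):
    return jnp.sum(step(u_0, nu) ** 2)


grad_nu = jax.grad(loss)(0.5)

# Automatic vectorization: advance a batch of states with one shared diffusivity
batch = jnp.stack([jnp.sin(x), jnp.cos(x), jnp.sin(2 * x)])
batch_next = jax.vmap(step, in_axes=(0, None))(batch, 0.5)
```

Because the stepper is an ordinary function of JAX arrays, `jax.grad`, `jax.vmap`, and `jax.jit` compose with it directly; the same idea carries over when the stepper is a callable PyTree.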
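Exponax supports the efficient solution of (semi-linear) partial differential equations on periodic domains in arbitrary dimensions. Those are PDEs of the form

$$ \partial u / \partial t = \mathcal{L} u + \mathcal{N}(u), $$

where $\mathcal{L}$ is a linear differential operator and $\mathcal{N}$ is a nonlinear differential operator.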
We focus on periodic domains on scaled hypercubes with a uniform Cartesian discretization. This allows using the Fast Fourier Transform, resulting in blazing fast simulations. For example, a dataset of trajectories for the 2d Kuramoto-Sivashinsky equation with 50 initial conditions over 200 time steps with a 128x128 discretization is created in less than a second on a modern GPU.
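To make the idea of a Fourier spectral exponential time-differencing (ETD) stepper concrete, here is a minimal sketch of the first-order scheme (ETDRK1) for the 1d viscous Burgers equation, u_t = nu * u_xx - 0.5 * (u^2)_x, written in plain JAX. It is an illustration under simplifying assumptions, not the library's implementation: the names `make_etdrk1_burgers` and `rollout` are hypothetical, no dealiasing is applied, and the coefficient is computed with `expm1` rather than the contour integral method of [2].

```python
import jax
import jax.numpy as jnp


def make_etdrk1_burgers(domain_extent, num_points, dt, nu):
    """Hypothetical helper: first-order ETD step for 1d viscous Burgers."""
    k = 2 * jnp.pi * jnp.fft.rfftfreq(num_points, d=domain_extent / num_points)
    lin = -nu * k**2  # Fourier symbol of the linear (diffusion) operator
    exp_term = jnp.exp(lin * dt)
    # phi = (exp(lin * dt) - 1) / lin, with the limit dt for the mean mode (lin == 0)
    safe_lin = jnp.where(lin == 0.0, 1.0, lin)
    phi = jnp.where(lin == 0.0, dt, jnp.expm1(lin * dt) / safe_lin)

    def step(u):
        u_hat = jnp.fft.rfft(u)
        # Nonlinear term -0.5 * d/dx (u^2): product in physical space, derivative in Fourier space
        nonlin_hat = -0.5 * 1j * k * jnp.fft.rfft(u**2)
        # Linear part is integrated exactly; nonlinear part is held constant over the step
        u_hat_next = exp_term * u_hat + phi * nonlin_hat
        return jnp.fft.irfft(u_hat_next, n=num_points)

    return step


def rollout(step, num_steps, u_0):
    """Apply `step` autoregressively and return the trajectory including u_0."""
    def body(u, _):
        u_next = step(u)
        return u_next, u_next

    _, later_states = jax.lax.scan(body, u_0, None, length=num_steps)
    return jnp.concatenate([u_0[None], later_states], axis=0)


extent, num_points = 2 * jnp.pi, 64
x = jnp.linspace(0.0, extent, num_points, endpoint=False)
burgers_step = make_etdrk1_burgers(extent, num_points, dt=0.01, nu=0.2)
trajectory = rollout(burgers_step, 100, jnp.sin(x))  # shape (101, 64)
```

The stiff linear term never restricts the time step here because it is handled by the exponential; only the nonlinear term is treated explicitly. Higher-order ETDRK schemes follow the same split with additional Runge-Kutta stages.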
[1] Cox, Steven M., and Paul C. Matthews. "Exponential time differencing for stiff systems." Journal of Computational Physics 176.2 (2002): 430-455.
[2] Kassam, Aly-Khan, and Lloyd N. Trefethen. "Fourth-order time-stepping for stiff PDEs." SIAM Journal on Scientific Computing 26.4 (2005): 1214-1233.
[3] Montanelli, Hadrien, and Niall Bootland. "Solving periodic semilinear stiff PDEs in 1D, 2D and 3D with exponential integrators." Mathematics and Computers in Simulation 178 (2020): 307-327.
This package is greatly inspired by the chebfun library in MATLAB, in particular the `spinX` (Stiff Pde INtegrator in X dimensions) module within it. These MATLAB utilities have been used extensively as a data generator in early works for supervised physics-informed ML, e.g., DeepHiddenPhysics and Fourier Neural Operators (the links show where in their public repos they use the `spinX` module). The approach of pre-sampling the solvers, writing out the trajectories, and then using them for supervised training worked for these problems, but of course limits the scope to purely supervised problems. Modern research ideas like correcting coarse solvers (see for instance the Solver-in-the-Loop paper or the ML-accelerated CFD paper) require the coarse solver to be differentiable. Some ideas of diverted chain training also require the fine solver to be differentiable. Even for applications without differentiable solvers, we still have the interface problem with legacy solvers (like the MATLAB ones). Hence, we cannot easily query them "on-the-fly" for something like active learning tasks, nor do they run efficiently on hardware accelerators (GPUs, TPUs, etc.). Additionally, they were not designed with batch execution (in the sense of vectorized application) in mind, which we get more or less for free with `jax.vmap`. With the reproducible randomness of JAX, we might not even have to ever write out a dataset and can re-create it in seconds!
This package also took much inspiration from FourierFlows.jl in the Julia ecosystem, especially for checking the implementation of the contour integral method of [2] and how to handle (de)aliasing.
This package was developed as part of the APEBench paper (accepted at NeurIPS 2024); we will soon add the citation here.
The main author (Felix Koehler) is a PhD student in the group of Prof. Thuerey at TUM, and his research is funded by the Munich Center for Machine Learning.
MIT, see here
fkoehler.site · GitHub @ceyron · X @felix_m_koehler · LinkedIn Felix Köhler