Neural Emulator Architectures in JAX.
Ceyron/pdequinox
PDE Emulator Architectures for Equinox.
Installation • Documentation • Quickstart • Background • Features • Boundary Conditions • Related • Citation
A collection of neural architectures for emulating Partial Differential Equations (PDEs) in JAX, agnostic to the spatial dimension (1D, 2D, 3D) and boundary conditions (Dirichlet, Neumann, Periodic). This package is built on top of Equinox.
pip install pdequinox
Requires Python 3.10+ and JAX 0.4.13+. 👉 JAX install guide.
The documentation is available at fkoehler.site/pdequinox.
Train a UNet to become an emulator for the 1D Poisson equation.
    import jax
    import jax.numpy as jnp
    import equinox as eqx
    import optax  # `pip install optax`
    import pdequinox as pdeqx
    from tqdm import tqdm  # `pip install tqdm`

    # Sample pairs of force fields (inputs) and displacement fields (targets)
    force_fields, displacement_fields = pdeqx.sample_data.poisson_1d_dirichlet(
        key=jax.random.PRNGKey(0)
    )

    # Split into the first 800 samples for training and the rest for testing
    force_fields_train = force_fields[:800]
    force_fields_test = force_fields[800:]
    displacement_fields_train = displacement_fields[:800]
    displacement_fields_test = displacement_fields[800:]

    unet = pdeqx.arch.ClassicUNet(1, 1, 1, key=jax.random.PRNGKey(1))


    def loss_fn(model, x, y):
        # The model acts on a single sample, so vmap over the batch axis
        y_pred = jax.vmap(model)(x)
        return jnp.mean((y_pred - y) ** 2)


    opt = optax.adam(3e-4)
    opt_state = opt.init(eqx.filter(unet, eqx.is_array))


    @eqx.filter_jit
    def update_fn(model, state, x, y):
        loss, grad = eqx.filter_value_and_grad(loss_fn)(model, x, y)
        updates, new_state = opt.update(grad, state, model)
        new_model = eqx.apply_updates(model, updates)
        return new_model, new_state, loss


    loss_history = []
    shuffle_key = jax.random.PRNGKey(151)
    for epoch in tqdm(range(100)):
        # Use a fresh key per epoch so each epoch sees a new shuffle
        shuffle_key, subkey = jax.random.split(shuffle_key)

        for batch in pdeqx.dataloader(
            (force_fields_train, displacement_fields_train),
            batch_size=32,
            key=subkey,
        ):
            unet, opt_state, loss = update_fn(
                unet,
                opt_state,
                *batch,
            )
            loss_history.append(loss)
Neural emulators are networks trained to efficiently predict physical phenomena, often associated with PDEs. These range from simple cases such as the linear advection equation to more complicated ones such as the Navier-Stokes equations. If we work on Uniform Cartesian grids* (which this package assumes), we can borrow plenty of architectures from image-to-image tasks in computer vision (e.g., for segmentation). This includes:
- Standard Feedforward ConvNets
- Convolutional ResNets (He et al.)
- U-Nets (Ronneberger et al.)
- Dilated ResNets (Yu et al., Stachenfeld et al.)
- Fourier Neural Operators (Li et al.)
It is interesting to note that most of these architectures resemble classical numerical methods or at least share similarities with them. For example, ConvNets (or convolutions in general) are related to finite differences, while U-Nets resemble multigrid methods. Fourier Neural Operators are related to spectral methods. The difference is that the emulators' free parameters are found by a (data-driven) numerical optimization, not by a symbolic manipulation of the differential equations.
(*) This means that we essentially have a pixel or voxel grid on which space is discretized. Hence, the space can only be the scaled unit cube.
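To make the link between convolutions and finite differences concrete, here is a minimal sketch in plain JAX (not using this package). It assumes a uniform 1D grid and a hand-picked, fixed kernel; a ConvNet would instead learn its kernel values from data. The grid spacing `dx` and the test field are illustrative choices.

```python
import jax.numpy as jnp

dx = 0.5
x = jnp.arange(10) * dx  # uniform 1D grid
u = x**2                 # sample field whose second derivative is 2 everywhere

# Fixed 3-point stencil [1, -2, 1] / dx^2: the second-order central difference
stencil = jnp.array([1.0, -2.0, 1.0]) / dx**2

# mode="valid" keeps only the interior points where the full stencil fits
d2u = jnp.convolve(u, stencil, mode="valid")
```

Sliding the stencil over the grid is exactly a convolution with a three-entry kernel, which is why a convolutional layer with kernel size 3 can represent this operator.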
- Based on JAX:
  - One of the best Automatic Differentiation engines (forward & reverse)
  - Automatic vectorization
  - Backend-agnostic code (run on CPU, GPU, and TPU)
- Based on Equinox:
  - Single-Batch by design
  - Integration into the Equinox SciML ecosystem
- Agnostic to the spatial dimension (works for 1D, 2D, and 3D)
- Agnostic to the boundary condition (works for Dirichlet, Neumann, and periodic BCs)
- Composability
- Tools to count parameters and assess receptive fields
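The "single-batch by design" and "automatic vectorization" points can be illustrated with a small sketch in plain JAX. The function `emulator` below is a hypothetical stand-in for a network, not part of this package; it assumes a single sample laid out as `(channels, *spatial)` and is written without any batch axis. `jax.vmap` then adds the batch axis, and the same function works for 1D and 2D inputs.

```python
import jax
import jax.numpy as jnp


def emulator(u):
    # Hypothetical stand-in for a network acting on ONE sample of shape
    # (channels, *spatial): subtract the spatial mean of each channel.
    spatial_axes = tuple(range(1, u.ndim))
    return u - jnp.mean(u, axis=spatial_axes, keepdims=True)


batch_1d = jnp.ones((8, 1, 64))      # (batch, channels, points)
batch_2d = jnp.ones((8, 1, 32, 32))  # (batch, channels, height, width)

# vmap maps the single-sample function over the leading batch axis
out_1d = jax.vmap(emulator)(batch_1d)
out_2d = jax.vmap(emulator)(batch_2d)
```

This mirrors the `jax.vmap(model)(x)` call in the Quickstart loss function.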
This package assumes that the boundary condition is baked into the neural emulator. Hence, most components allow setting `boundary_mode`, which can be `"dirichlet"`, `"neumann"`, or `"periodic"`. This affects what is considered a degree of freedom in the grid.
Dirichlet boundaries fully eliminate degrees of freedom on the boundary. Periodic boundaries only keep one end of the domain as a degree of freedom (this package follows the convention that the left boundary is the degree of freedom). Neumann boundaries keep both ends as degrees of freedom.
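The convention above can be sketched for one axis as follows. This is an illustration under stated assumptions, not the package's implementation: `degrees_of_freedom` is a hypothetical helper, and `full_grid` is assumed to hold all nodes of the axis including both boundary points.

```python
import jax.numpy as jnp


def degrees_of_freedom(full_grid, boundary_mode):
    """Select the nodes that count as degrees of freedom along one axis.

    Hypothetical helper; `full_grid` includes both boundary nodes.
    """
    if boundary_mode == "dirichlet":
        # Boundary values are prescribed, so both end nodes are dropped
        return full_grid[1:-1]
    if boundary_mode == "periodic":
        # Left and right ends coincide; keep only the left one
        return full_grid[:-1]
    if boundary_mode == "neumann":
        # Boundary values are unknown, so both end nodes are kept
        return full_grid
    raise ValueError(f"Unknown boundary_mode: {boundary_mode}")


full_grid = jnp.linspace(0.0, 1.0, 11)  # 11 nodes including both ends
counts = {
    mode: degrees_of_freedom(full_grid, mode).shape[0]
    for mode in ("dirichlet", "periodic", "neumann")
}
```

For the same physical domain and spacing, the three modes therefore lead to arrays of different lengths, which is worth keeping in mind when preparing data for an emulator.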
Similar packages that provide a collection of emulator architectures are PDEBench and PDEArena. With a focus on Physics-informed Neural Networks and Neural Operators, there are also DeepXDE and NVIDIA Modulus.
This package was developed as part of the APEBench paper (arxiv.org/abs/2411.00180) (accepted at NeurIPS 2024). If you find it useful for your research, please consider citing it:
    @article{koehler2024apebench,
      title={{APEBench}: A Benchmark for Autoregressive Neural Emulators of {PDE}s},
      author={Felix Koehler and Simon Niedermayr and R{\"u}diger Westermann and Nils Thuerey},
      journal={Advances in Neural Information Processing Systems (NeurIPS)},
      volume={38},
      year={2024}
    }
(Feel free to also give the project a star on GitHub if you like it.)
Here you can find the APEBench benchmark suite.
The main author (Felix Koehler) is a PhD student in the group of Prof. Thuerey at TUM and his research is funded by the Munich Center for Machine Learning.
MIT, see here
fkoehler.site · GitHub @ceyron · X @felix_m_koehler · LinkedIn Felix Köhler