Efficient Differentiable n-d PDE solvers in JAX.

Ceyron/exponax

Installation · Documentation · Quickstart · Features · Background · Acknowledgements

Exponax is a suite for building Fourier spectral ETDRK (exponential time differencing Runge-Kutta) time-steppers for semi-linear PDEs in 1d, 2d, and 3d. There are many pre-built dynamics and plenty of helpful utilities. It is extremely efficient, is differentiable (because it is fully written in JAX), and embeds seamlessly into deep learning.

Installation

```shell
pip install exponax
```

Requires Python 3.10+ and JAX 0.4.13+. 👉 JAX install guide.

Documentation

Documentation is available at fkoehler.site/exponax.

Quickstart

1d Kuramoto-Sivashinsky Equation.

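```python
import jax
import exponax as ex
import matplotlib.pyplot as plt

# 1d Kuramoto-Sivashinsky stepper: domain length 100, 200 grid points, dt = 0.1
ks_stepper = ex.stepper.KuramotoSivashinskyConservative(
    num_spatial_dims=1,
    domain_extent=100.0,
    num_points=200,
    dt=0.1,
)

# Random initial condition from a truncated Fourier series (cutoff=5)
u_0 = ex.ic.RandomTruncatedFourierSeries(
    num_spatial_dims=1,
    cutoff=5,
)(num_points=200, key=jax.random.PRNGKey(0))

# Autoregressive rollout of 500 steps, keeping the initial state as entry 0;
# the result is indexed as (time, channel, space)
trajectory = ex.rollout(ks_stepper, 500, include_init=True)(u_0)

# Space-time plot of the single channel, with time on the horizontal axis
plt.imshow(
    trajectory[:, 0, :].T,
    aspect="auto",
    cmap="RdBu",
    vmin=-2,
    vmax=2,
    origin="lower",
)
plt.xlabel("Time")
plt.ylabel("Space")
plt.show()
```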

As a next step, check out this tutorial on 1D Advection that explains the basics of Exponax.

Features

  1. JAX as the computational backend:
    1. Backend-agnostic code: run on CPU, GPU, or TPU, in both single and double precision.
    2. Automatic differentiation over the timesteppers: compute gradients of solutions with respect to initial conditions, parameters, etc.
    3. Also helpful for tight integration with deep learning, since each timestepper is just an Equinox Module.
    4. Automatic vectorization using jax.vmap (or equinox.filter_vmap), allowing you to advance multiple states in time or to instantiate multiple solvers at once that operate efficiently in batch.
  2. Lightweight design without custom types. There is no grid or state object. Everything is based on JAX arrays. Timesteppers are callable PyTrees.
  3. More than 46 pre-built dynamics across 1d, 2d, and 3d:
    1. Linear PDEs (advection, diffusion, dispersion, etc.)
    2. Nonlinear PDEs (Burgers, Kuramoto-Sivashinsky, Korteweg-de Vries, Navier-Stokes, etc.)
    3. Reaction-diffusion (Gray-Scott, Swift-Hohenberg, etc.)
  4. Collection of initial condition distributions (truncated Fourier series, Gaussian random fields, etc.)
  5. Utilities for spectral derivatives, grid creation, autoregressive rollout, interpolation, etc.
  6. Easily extendable to new PDEs by subclassing from the BaseStepper module.
  7. An alternative, reduced interface for defining PDE dynamics using normalized or difficulty-based identifiers.
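
The two JAX features above, vectorization and differentiation through a time stepper, can be illustrated without Exponax itself. The following is a minimal sketch under our own assumptions: a hypothetical stand-alone 1d diffusion stepper written as a pure function (not one of the Exponax classes in ex.stepper), batched with jax.vmap and differentiated with jax.grad. Because diffusion is linear and diagonal in Fourier space, each step applies the exact exponential of the linear operator.

```python
import jax
import jax.numpy as jnp

# Hypothetical minimal stepper, NOT the Exponax API: 1d diffusion
# u_t = NU * u_xx on the periodic domain [0, 2*pi). The linear operator is
# diagonal in Fourier space, so each step applies its exponential exactly.
DOMAIN_EXTENT = 2 * jnp.pi
NUM_POINTS = 64
NU = 0.1
DT = 0.05
NUM_STEPS = 10

# Angular wavenumbers matching jnp.fft.rfft on a uniform periodic grid
k = 2 * jnp.pi * jnp.fft.rfftfreq(NUM_POINTS, d=DOMAIN_EXTENT / NUM_POINTS)
exp_term = jnp.exp(-NU * k**2 * DT)
x = jnp.arange(NUM_POINTS) * (DOMAIN_EXTENT / NUM_POINTS)


def step(u):
    """Advance a single state of shape (NUM_POINTS,) by DT."""
    return jnp.fft.irfft(exp_term * jnp.fft.rfft(u), n=NUM_POINTS)


def rollout(u0):
    """Return the trajectory of shape (NUM_STEPS + 1, NUM_POINTS), including u0."""
    def body(u, _):
        u_next = step(u)
        return u_next, u_next

    _, traj = jax.lax.scan(body, u0, None, length=NUM_STEPS)
    return jnp.concatenate([u0[None], traj], axis=0)


# Vectorization: advance a batch of three initial conditions in one call
amplitudes = jnp.array([1.0, 0.5, 2.0])
u0_batch = amplitudes[:, None] * jnp.sin(x)[None, :]
trajectories = jax.vmap(rollout)(u0_batch)  # shape (3, NUM_STEPS + 1, NUM_POINTS)


# Differentiation: gradient of a final-state loss w.r.t. the IC amplitude
def loss(amplitude):
    u_final = rollout(amplitude * jnp.sin(x))[-1]
    return 0.5 * jnp.sum(u_final**2)


grad_at_one = jax.grad(loss)(1.0)
```

Since sin(x) is the k = 1 mode, every final state is the initial sine damped by exp(-NU * DT * NUM_STEPS), which makes for an easy sanity check. In Exponax the stepper would instead be an Equinox Module passed to ex.rollout, but the vmap/grad pattern is the same.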

Background

Exponax supports the efficient solution of (semi-linear) partial differential equations on periodic domains in arbitrary dimensions. Those are PDEs of the form

$$ \partial u/ \partial t = Lu + N(u), $$

where $L$ is a linear differential operator and $N$ is a nonlinear differential operator. The linear part can be solved exactly using a (matrix) exponential, and the nonlinear part is approximated using Runge-Kutta methods of various orders. These methods have long been known in various scientific disciplines and were unified for the first time by Cox & Matthews [1]. In particular, this package uses the complex contour integral method of Kassam & Trefethen [2] for numerical stability. The package is restricted to the original first-, second-, third-, and fourth-order methods. A recent study by Montanelli & Bootland [3] showed that the original ETDRK4 method is still one of the most efficient methods for these types of PDEs.
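
To make the exponential time differencing idea concrete, here is a sketch of the first-order scheme (ETDRK1), $u_{n+1} = e^{L \Delta t} u_n + \Delta t \, \varphi_1(L \Delta t) \, N(u_n)$ with $\varphi_1(z) = (e^z - 1)/z$, where $\varphi_1$ is evaluated by averaging over points on a circle around each $z$, in the spirit of the contour integral method of [2]. This is our own illustration, not Exponax's implementation: the helper names, the viscous Burgers nonlinearity, the 2/3-rule dealiasing mask, and all parameters are assumptions, and Exponax additionally offers the second- to fourth-order variants.

```python
import jax.numpy as jnp


def phi1_contour(z, num_points=32, radius=1.0):
    """phi_1(z) = (exp(z) - 1) / z, evaluated elementwise by a contour average.

    By the mean-value property of the entire function phi_1, its value at z
    equals its average over a circle centred at z. The trapezoidal rule on
    num_points equispaced points converges very fast and, unlike the direct
    formula, does not cancel catastrophically for small |z|.
    """
    z = jnp.asarray(z)
    theta = 2 * jnp.pi * (jnp.arange(num_points) + 0.5) / num_points
    w = z[..., None] + radius * jnp.exp(1j * theta)  # circle around each z
    return jnp.mean((jnp.exp(w) - 1) / w, axis=-1)


def etdrk1_step(u_hat, lin, nonlin_hat, dt):
    """One ETDRK1 step in Fourier space for du/dt = L u + N(u), with L diagonal."""
    z = lin * dt
    return jnp.exp(z) * u_hat + dt * phi1_contour(z) * nonlin_hat(u_hat)


# Hypothetical example problem: 1d viscous Burgers, u_t = NU u_xx - 0.5 (u^2)_x
NUM_POINTS = 128
DOMAIN_EXTENT = 2 * jnp.pi
NU = 0.05
DT = 0.01

k = 2 * jnp.pi * jnp.fft.rfftfreq(NUM_POINTS, d=DOMAIN_EXTENT / NUM_POINTS)
lin = -NU * k**2                # Fourier symbol of the linear part L
dealias = k <= (2 / 3) * k[-1]  # 2/3-rule mask against aliasing in u^2


def burgers_nonlin_hat(u_hat):
    """Fourier coefficients of N(u) = -0.5 d/dx (u^2), dealiased."""
    u = jnp.fft.irfft(dealias * u_hat, n=NUM_POINTS)
    return dealias * (-0.5j * k * jnp.fft.rfft(u**2))


x = jnp.arange(NUM_POINTS) * (DOMAIN_EXTENT / NUM_POINTS)
u_hat = jnp.fft.rfft(jnp.sin(x) + 0.3)
for _ in range(50):
    u_hat = etdrk1_step(u_hat, lin, burgers_nonlin_hat, DT)
u_final = jnp.fft.irfft(u_hat, n=NUM_POINTS)
```

The direct formula (exp(z) - 1) / z breaks down as z approaches 0; in single precision it even returns 0 for z = 1e-8, whereas the contour average gives approximately 1. Because the k = 0 mode has both L = 0 and N = 0, this step also conserves the spatial mean of u exactly.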

We focus on periodic domains on scaled hypercubes with a uniform Cartesian discretization. This allows using the Fast Fourier Transform, resulting in blazing fast simulations. For example, a dataset of trajectories for the 2d Kuramoto-Sivashinsky equation with 50 initial conditions over 200 time steps with a 128x128 discretization is created in less than a second on a modern GPU.
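
The uniform periodic grid matters because the FFT turns each spatial derivative into a multiplication by $i k$ along that axis, so, for example, a Laplacian on a scaled hypercube costs one forward and one inverse FFT. The sketch below uses a hypothetical helper, spectral_laplacian (not an Exponax function), that works in any number of dimensions using the real-FFT layout, and checks it on a 2d field with a known Laplacian.

```python
import jax.numpy as jnp


def spectral_laplacian(u, domain_extent):
    """Laplacian of a periodic field u on [0, domain_extent)^d, N points per axis.

    Uses the real FFT over all axes: full frequencies on the leading axes and
    only the non-negative half on the last axis (the rfftn layout).
    """
    n = u.shape[0]
    dx = domain_extent / n
    k_full = 2 * jnp.pi * jnp.fft.fftfreq(n, d=dx)
    k_half = 2 * jnp.pi * jnp.fft.rfftfreq(n, d=dx)
    k_axes = [k_full] * (u.ndim - 1) + [k_half]
    k_grids = jnp.meshgrid(*k_axes, indexing="ij")
    k_squared = sum(kg**2 for kg in k_grids)
    # d^2/dx_i^2 corresponds to multiplication by (i k_i)^2 = -k_i^2
    return jnp.fft.irfftn(-k_squared * jnp.fft.rfftn(u), s=u.shape)


# Check on a 2d field whose Laplacian is known in closed form
N = 32
DOMAIN_EXTENT = 5.0
xs = jnp.arange(N) * (DOMAIN_EXTENT / N)
X, Y = jnp.meshgrid(xs, xs, indexing="ij")
u = jnp.sin(2 * jnp.pi * X / DOMAIN_EXTENT) * jnp.cos(4 * jnp.pi * Y / DOMAIN_EXTENT)
lap_numeric = spectral_laplacian(u, DOMAIN_EXTENT)
lap_exact = -((2 * jnp.pi / DOMAIN_EXTENT) ** 2 + (4 * jnp.pi / DOMAIN_EXTENT) ** 2) * u
```

For smooth periodic fields such as this one, the spectral derivative is exact up to floating-point round-off, which is one reason Fourier pseudo-spectral methods pair so well with exponential integrators.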

[1] Cox, Steven M., and Paul C. Matthews. "Exponential time differencing for stiff systems." Journal of Computational Physics 176.2 (2002): 430-455.

[2] Kassam, Aly-Khan, and Lloyd N. Trefethen. "Fourth-order time-stepping for stiff PDEs." SIAM Journal on Scientific Computing 26.4 (2005): 1214-1233.

[3] Montanelli, Hadrien, and Niall Bootland. "Solving periodic semilinear stiff PDEs in 1D, 2D and 3D with exponential integrators." Mathematics and Computers in Simulation 178 (2020): 307-327.

Acknowledgements

Related & Motivation

This package is greatly inspired by the chebfun library in MATLAB, in particular the spinX (Stiff Pde INtegrator in X dimensions) module within it. These MATLAB utilities have been used extensively as data generators in early works on supervised physics-informed ML, e.g., DeepHiddenPhysics and Fourier Neural Operators (the links show where in their public repos they use the spinX module). The approach of pre-sampling the solvers, writing out the trajectories, and then using them for supervised training worked for these problems, but of course limits the scope to purely supervised problems. Modern research ideas like correcting coarse solvers (see, for instance, the Solver-in-the-Loop paper or the ML-accelerated CFD paper) require the coarse solver to be differentiable. Some ideas of diverted chain training also require the fine solver to be differentiable.

Even for applications that do not need differentiable solvers, we still have the interface problem with legacy solvers (like the MATLAB ones). Hence, we cannot easily query them "on the fly" for tasks such as active learning, nor do they run efficiently on hardware accelerators (GPUs, TPUs, etc.). Additionally, they were not designed with batch execution (in the sense of vectorized application) in mind, which we get more or less for free with jax.vmap. With the reproducible randomness of JAX, we might never even have to write out a dataset and can instead re-create it in seconds!
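
The last point can be illustrated with plain JAX. The sketch below uses a hypothetical initial-condition sampler (a simplified stand-in, not ex.ic.RandomTruncatedFourierSeries) that draws a random low-mode Fourier series from a PRNG key. A whole batch is produced from a single seed with jax.random.split and jax.vmap, so the "dataset" can be regenerated identically (on the same backend and JAX version) instead of being loaded from disk.

```python
import jax
import jax.numpy as jnp

NUM_POINTS = 64
CUTOFF = 5  # hypothetical: keep only Fourier modes 1..CUTOFF


def random_truncated_fourier_ic(key):
    """Hypothetical IC sampler (not the Exponax one): random low-mode Fourier series."""
    key_re, key_im = jax.random.split(key)
    coeffs = jnp.zeros(NUM_POINTS // 2 + 1, dtype=jnp.complex64)
    coeffs = coeffs.at[1 : CUTOFF + 1].set(
        jax.random.normal(key_re, (CUTOFF,)) + 1j * jax.random.normal(key_im, (CUTOFF,))
    )
    # Mode 0 stays zero, so every sample has zero spatial mean
    return jnp.fft.irfft(coeffs, n=NUM_POINTS)


def make_dataset(seed, num_samples):
    """Draw num_samples initial conditions in one vectorized call from a seed."""
    keys = jax.random.split(jax.random.PRNGKey(seed), num_samples)
    return jax.vmap(random_truncated_fourier_ic)(keys)


dataset_a = make_dataset(seed=0, num_samples=50)
dataset_b = make_dataset(seed=0, num_samples=50)  # regenerated, not loaded from disk
```

Only the seed and the sampler settings need to be stored; each sample is then fed through a stepper rollout in the same vectorized fashion.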

This package also took much inspiration from FourierFlows.jl in the Julia ecosystem, especially for checking the implementation of the contour integral method of [2] and how to handle (de)aliasing.

Citation

This package was developed as part of the APEBench paper (accepted at NeurIPS 2024); we will soon add the citation here.

Funding

The main author (Felix Koehler) is a PhD student in the group of Prof. Thuerey at TUM, and his research is funded by the Munich Center for Machine Learning.

License

MIT (see the license file in the repository).


fkoehler.site · GitHub @ceyron · X @felix_m_koehler · LinkedIn Felix Köhler
