
Neural Emulator Architectures in JAX.


Ceyron/pdequinox


A collection of neural architectures for emulating Partial Differential Equations (PDEs) in JAX, agnostic to the spatial dimension (1D, 2D, 3D) and boundary conditions (Dirichlet, Neumann, Periodic). This package is built on top of Equinox.

Installation

pip install pdequinox

Requires Python 3.10+ and JAX 0.4.13+. See the JAX install guide.

Documentation

The documentation is available at fkoehler.site/pdequinox.

Quickstart

Train a UNet to become an emulator for the 1D Poisson equation.

    import jax
    import jax.numpy as jnp
    import equinox as eqx
    import optax  # `pip install optax`
    import pdequinox as pdeqx
    from tqdm import tqdm  # `pip install tqdm`

    # Sample pairs of force fields (inputs) and displacement fields (targets)
    force_fields, displacement_fields = pdeqx.sample_data.poisson_1d_dirichlet(
        key=jax.random.PRNGKey(0)
    )

    # Split into training and test sets
    force_fields_train = force_fields[:800]
    force_fields_test = force_fields[800:]
    displacement_fields_train = displacement_fields[:800]
    displacement_fields_test = displacement_fields[800:]

    unet = pdeqx.arch.ClassicUNet(1, 1, 1, key=jax.random.PRNGKey(1))

    def loss_fn(model, x, y):
        # The model acts on a single sample, so vmap over the batch axis
        y_pred = jax.vmap(model)(x)
        return jnp.mean((y_pred - y) ** 2)

    opt = optax.adam(3e-4)
    opt_state = opt.init(eqx.filter(unet, eqx.is_array))

    @eqx.filter_jit
    def update_fn(model, state, x, y):
        loss, grad = eqx.filter_value_and_grad(loss_fn)(model, x, y)
        updates, new_state = opt.update(grad, state, model)
        new_model = eqx.apply_updates(model, updates)
        return new_model, new_state, loss

    loss_history = []
    shuffle_key = jax.random.PRNGKey(151)
    for epoch in tqdm(range(100)):
        # Use a fresh key per epoch so that each epoch shuffles differently
        shuffle_key, subkey = jax.random.split(shuffle_key)

        for batch in pdeqx.dataloader(
            (force_fields_train, displacement_fields_train),
            batch_size=32,
            key=subkey,
        ):
            unet, opt_state, loss = update_fn(
                unet,
                opt_state,
                *batch,
            )
            loss_history.append(loss)
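For context, the mapping the emulator is trained on (force field to displacement field) can be sketched with a classical finite-difference solve. This is only an illustration under assumptions: it takes the problem to be -u'' = f on (0, 1) with homogeneous Dirichlet boundaries, discretized on the interior points. The helper `solve_poisson_1d_dirichlet` is hypothetical, and the way `pdeqx.sample_data.poisson_1d_dirichlet` actually generates its data (sign convention, scaling, resolution) may differ.

```python
import jax.numpy as jnp


def solve_poisson_1d_dirichlet(force, domain_extent=1.0):
    """Hypothetical reference solver (not part of pdequinox).

    Solves -u'' = f on (0, L) with u(0) = u(L) = 0, where `force` holds
    the values of f on the interior grid points only.
    """
    num_points = force.shape[-1]
    dx = domain_extent / (num_points + 1)
    # Tridiagonal matrix of the negative second derivative (stencil -1, 2, -1)
    neg_laplace = (
        2.0 * jnp.eye(num_points)
        - jnp.eye(num_points, k=1)
        - jnp.eye(num_points, k=-1)
    ) / dx**2
    return jnp.linalg.solve(neg_laplace, force)


num_points = 64
# Interior points only: Dirichlet boundaries are not degrees of freedom
x = jnp.linspace(0.0, 1.0, num_points + 2)[1:-1]

# Manufactured solution: f = pi^2 sin(pi x) gives u = sin(pi x)
force = jnp.pi**2 * jnp.sin(jnp.pi * x)
displacement = solve_poisson_1d_dirichlet(force)
max_error = float(jnp.max(jnp.abs(displacement - jnp.sin(jnp.pi * x))))
```

A trained emulator replaces the linear solve with a single forward pass of the network.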

Background

Neural Emulators are networks trained to efficiently predict physical phenomena, often associated with PDEs. These range from simple cases such as a linear advection equation all the way to more complicated Navier-Stokes cases. If we work on Uniform Cartesian grids* (which this package assumes), one can borrow plenty of architectures from image-to-image tasks in computer vision (e.g., for segmentation). This includes, among others, ConvNets, U-Nets, and Fourier Neural Operators.

It is interesting to note that most of these architectures resemble classical numerical methods or at least share similarities with them. For example, ConvNets (or convolutions in general) are related to finite differences, while U-Nets resemble multigrid methods. Fourier Neural Operators are related to spectral methods. The difference is that the emulators' free parameters are found based on a (data-driven) numerical optimization, not a symbolic manipulation of the differential equations.
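The relation to classical methods can be made concrete with a small sketch. It approximates the second derivative of a periodic signal twice: once with a fixed three-point finite-difference stencil applied as a convolution (the kind of kernel a ConvNet could learn), and once by multiplying Fourier coefficients (the kind of operation a Fourier Neural Operator parameterizes). The grid size and test function are arbitrary choices for illustration and are not taken from the package.

```python
import jax.numpy as jnp

num_points = 128
domain_extent = 2 * jnp.pi
dx = domain_extent / num_points
# Periodic grid: the left boundary is included, the right one is not
x = jnp.arange(num_points) * dx
u = jnp.sin(x)

# Finite differences: the stencil [1, -2, 1] / dx^2 is a convolution kernel.
# "wrap" padding realizes the periodic boundary condition.
stencil = jnp.array([1.0, -2.0, 1.0]) / dx**2
u_padded = jnp.pad(u, 1, mode="wrap")
d2u_fd = jnp.convolve(u_padded, stencil, mode="valid")

# Spectral method: differentiation is a pointwise multiplication by -(k^2)
# in Fourier space.
wavenumbers = 2 * jnp.pi * jnp.fft.rfftfreq(num_points, d=dx)
d2u_spectral = jnp.fft.irfft(-(wavenumbers**2) * jnp.fft.rfft(u), n=num_points)

# Both should approximate the exact second derivative -sin(x)
exact = -jnp.sin(x)
fd_error = float(jnp.max(jnp.abs(d2u_fd - exact)))
spectral_error = float(jnp.max(jnp.abs(d2u_spectral - exact)))
```

In a neural emulator, the stencil entries or the Fourier multipliers are not derived from the equation but fitted to data.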

(*) This means that we essentially have a pixel or voxel grid on which space is discretized. Hence, the space can only be the scaled unit cube $\Omega = (0,L)^D$.

Features

  • Based on JAX:
    • One of the best Automatic Differentiation engines (forward & reverse)
    • Automatic vectorization
    • Backend-agnostic code (run on CPU, GPU, and TPU)
  • Based on Equinox:
    • Single-Batch by design
    • Integration into the Equinox SciML ecosystem
  • Agnostic to the spatial dimension (works for 1D, 2D, and 3D)
  • Agnostic to the boundary condition (works for Dirichlet, Neumann, and periodic BCs)
  • Composability
  • Tools to count parameters and assess receptive fields
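Two of the points above, the single-batch design and parameter counting, can be sketched in plain JAX. The `toy_emulator` and `init_params` functions below are hypothetical stand-ins and not pdequinox architectures. The sketch assumes a channels-first layout `(channels, points)` for a single sample and counts parameters by summing array sizes over the parameter pytree; the package's own counting tools may work differently.

```python
import jax
import jax.numpy as jnp


def init_params(key, in_channels, out_channels, kernel_size):
    """Hypothetical parameter pytree for a single 1D convolution."""
    return {
        "weight": jax.random.normal(key, (out_channels, in_channels, kernel_size)),
        "bias": jnp.zeros((out_channels, 1)),
    }


def toy_emulator(params, u):
    """Periodic 1D convolution on ONE sample of shape (channels, points)."""
    pad = params["weight"].shape[-1] // 2
    u_padded = jnp.pad(u, ((0, 0), (pad, pad)), mode="wrap")
    # conv_general_dilated expects a leading batch axis; add and strip it
    out = jax.lax.conv_general_dilated(
        u_padded[None], params["weight"], window_strides=(1,), padding="VALID"
    )[0]
    return out + params["bias"]


params = init_params(jax.random.PRNGKey(0), 1, 2, 3)

# The function is written for a single sample ...
single_out = toy_emulator(params, jnp.ones((1, 64)))

# ... and jax.vmap adds the batch axis without changing the function
batch_out = jax.vmap(toy_emulator, in_axes=(None, 0))(params, jnp.ones((16, 1, 64)))

# Parameter count: sum the sizes of all arrays in the pytree
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
```

Writing the model for one sample keeps shape handling simple; batching is added at the call site, as the Quickstart does with `jax.vmap(model)(x)`.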

Boundary Conditions

This package assumes that the boundary condition is baked into the neural emulator. Hence, most components allow setting boundary_mode, which can be "dirichlet", "neumann", or "periodic". This affects what is considered a degree of freedom in the grid.

[Figure: three_boundary_conditions]

Dirichlet boundaries fully eliminate degrees of freedom on the boundary. Periodic boundaries only keep one end of the domain as a degree of freedom (this package follows the convention that the left boundary is the degree of freedom). Neumann boundaries keep both ends as degrees of freedom.
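The three conventions can be illustrated by selecting grid points on a 1D domain. The helper `degrees_of_freedom` below is hypothetical and not part of the pdequinox API. It assumes the domain (0, L) is divided into `num_intervals` equal cells and simply returns the coordinates that count as degrees of freedom under each boundary mode.

```python
import jax.numpy as jnp


def degrees_of_freedom(num_intervals, domain_extent, boundary_mode):
    """Hypothetical helper: coordinates of the degrees of freedom in 1D."""
    # All grid points, including both boundary points
    full_grid = jnp.linspace(0.0, domain_extent, num_intervals + 1)
    if boundary_mode == "dirichlet":
        # Boundary values are prescribed, so only interior points remain
        return full_grid[1:-1]
    if boundary_mode == "periodic":
        # Left and right ends coincide; keep the left one
        return full_grid[:-1]
    if boundary_mode == "neumann":
        # Boundary values are unknown, so both ends are kept
        return full_grid
    raise ValueError(f"Unknown boundary_mode: {boundary_mode}")


dirichlet_points = degrees_of_freedom(8, 1.0, "dirichlet")
periodic_points = degrees_of_freedom(8, 1.0, "periodic")
neumann_points = degrees_of_freedom(8, 1.0, "neumann")
```

For eight intervals this gives 7, 8, and 9 degrees of freedom respectively, which is why the same physical resolution leads to different array sizes depending on the boundary condition.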

Related Work

Similar packages that provide a collection of emulator architectures are PDEBench and PDEArena. With a focus on Physics-informed Neural Networks and Neural Operators, there are also DeepXDE and NVIDIA Modulus.

Citation

This package was developed as part of the APEBench paper (arxiv.org/abs/2411.00180) (accepted at NeurIPS 2024). If you find it useful for your research, please consider citing it:

    @article{koehler2024apebench,
      title={{APEBench}: A Benchmark for Autoregressive Neural Emulators of {PDE}s},
      author={Felix Koehler and Simon Niedermayr and R{\"u}diger Westermann and Nils Thuerey},
      journal={Advances in Neural Information Processing Systems (NeurIPS)},
      volume={38},
      year={2024}
    }

(Feel free to also give the project a star on GitHub if you like it.)

Here you can find the APEBench benchmark suite.

Funding

The main author (Felix Koehler) is a PhD student in the group of Prof. Thuerey at TUM, and his research is funded by the Munich Center for Machine Learning.

License

MIT, see here.


fkoehler.site · GitHub @ceyron · X @felix_m_koehler · LinkedIn Felix Köhler

