
Unitful Quantities in JAX


GalacticDynamics/unxt



Unxt provides unitful quantities and calculations in JAX, built on Equinox and Quax.

Unxt supports JAX's compelling features:

  • JIT compilation (jit)
  • vectorization (vmap, etc.)
  • auto-differentiation (grad, jacobian, hessian)
  • GPU/TPU/multi-host acceleration

And best of all, unxt doesn't force you to use special unit-compatible re-exports of JAX libraries. You can use unxt with existing JAX code, and with quax's simple decorator, JAX will work with unxt.Quantity.

Installation


    pip install unxt

using uv

    uv add unxt

from source, using pip

    pip install git+https://github.com/GalacticDynamics/unxt.git

building from source

    cd /path/to/parent
    git clone https://github.com/GalacticDynamics/unxt.git
    cd unxt
    pip install -e .  # editable mode

Read The Docs

Quick example

    import unxt as u
    import jax.numpy as jnp

    x = u.Quantity(jnp.arange(1, 5, dtype=float), "km")
    print(x)
    # Quantity['length'](Array([1., 2., 3., 4.], dtype=float64), unit='km')

The constituent value and unit are accessible as attributes:

    print(x.value)
    # Array([1., 2., 3., 4.], dtype=float64)
    print(x.unit)
    # Unit("km")

Quantity objects obey the rules of unitful arithmetic.

    # Addition / Subtraction
    print(x + x)
    # Quantity['length'](Array([2., 4., 6., 8.], dtype=float64), unit='km')

    # Multiplication / Division
    print(2 * x)
    # Quantity['length'](Array([2., 4., 6., 8.], dtype=float64), unit='km')
    y = u.Quantity(jnp.arange(4, 8, dtype=float), "yr")
    print(x / y)
    # Quantity['speed'](Array([0.25      , 0.4       , 0.5       , 0.57142857], dtype=float64), unit='km / yr')

    # Exponentiation
    print(x**2)
    # Quantity['area'](Array([ 1.,  4.,  9., 16.], dtype=float64), unit='km2')

    # Unit checking on operations
    try:
        x + y
    except Exception as e:
        print(e)
    # 'yr' (time) and 'km' (length) are not convertible

Quantities can be converted to different units:

    print(u.uconvert("m", x))  # via function
    # Quantity['length'](Array([1000., 2000., 3000., 4000.], dtype=float64), unit='m')
    print(x.uconvert("m"))  # via method
    # Quantity['length'](Array([1000., 2000., 3000., 4000.], dtype=float64), unit='m')

Since Quantity is parametric, it can do runtime dimension checking!

    LengthQuantity = u.Quantity["length"]
    print(LengthQuantity(2, "km"))
    # Quantity['length'](Array(2, dtype=int64, weak_type=True), unit='km')

    try:
        LengthQuantity(2, "s")
    except ValueError as e:
        print(e)
    # Physical type mismatch.

unxt is built on quax, which enables custom array-ish objects in JAX. For convenience we use the quaxed library, which is just a quax.quaxify wrapper around jax to avoid boilerplate code.

Note

Using quaxed is optional. You can directly use quaxify, and even apply it to the top-level function instead of individual functions.

    from quaxed import grad, vmap
    import quaxed.numpy as jnp

    print(jnp.square(x))
    # Quantity['area'](Array([ 1.,  4.,  9., 16.], dtype=float64), unit='km2')
    print(jnp.power(x, 3))
    # Quantity['volume'](Array([ 1.,  8., 27., 64.], dtype=float64), unit='km3')
    print(vmap(grad(lambda x: x**3))(x))
    # Quantity['area'](Array([ 3., 12., 27., 48.], dtype=float64), unit='km2')

See the documentation for more examples and details of JIT and AD.

Citation


If you found this library to be useful and want to support the development and maintenance of lower-level code libraries for the scientific community, please consider citing this work.

Development


We welcome contributions!

