GalacticDynamics/unxt
Unxt provides unitful quantities and calculations in JAX, built on Equinox and Quax.
Unxt supports JAX's compelling features:

- JIT compilation (`jit`)
- vectorization (`vmap`, etc.)
- auto-differentiation (`grad`, `jacobian`, `hessian`)
- GPU/TPU/multi-host acceleration
And best of all, `unxt` doesn't force you to use special unit-compatible re-exports of JAX libraries. You can use `unxt` with existing JAX code, and with `quax`'s simple decorator, JAX will work with `unxt.Quantity`.
Install the package with pip:

```bash
pip install unxt
```

Using uv:

```bash
uv add unxt
```

From source, using pip:

```bash
pip install git+https://github.com/GalacticDynamics/unxt.git
```

Building from source:

```bash
cd /path/to/parent
git clone https://github.com/GalacticDynamics/unxt.git
cd unxt
pip install -e .  # editable mode
```
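```python
import jax
import jax.numpy as jnp
import unxt as u

# The float64 / int64 outputs shown in this README assume JAX's 64-bit mode.
jax.config.update("jax_enable_x64", True)

x = u.Quantity(jnp.arange(1, 5, dtype=float), "km")
print(x)
# Quantity['length'](Array([1., 2., 3., 4.], dtype=float64), unit='km')
```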
The constituent value and unit are accessible as attributes:
```python
print(x.value)
# Array([1., 2., 3., 4.], dtype=float64)
print(x.unit)
# Unit("km")
```
`Quantity` objects obey the rules of unitful arithmetic.
```python
# Addition / Subtraction
print(x + x)
# Quantity['length'](Array([2., 4., 6., 8.], dtype=float64), unit='km')

# Multiplication / Division
print(2 * x)
# Quantity['length'](Array([2., 4., 6., 8.], dtype=float64), unit='km')

y = u.Quantity(jnp.arange(4, 8, dtype=float), "yr")
print(x / y)
# Quantity['speed'](Array([0.25 , 0.4 , 0.5 , 0.57142857], dtype=float64), unit='km / yr')

# Exponentiation
print(x**2)
# Quantity['area'](Array([ 1., 4., 9., 16.], dtype=float64), unit='km2')

# Unit checking on operations
try:
    x + y
except Exception as e:
    print(e)
# 'yr' (time) and 'km' (length) are not convertible
```
Quantities can be converted to different units:
```python
print(u.uconvert("m", x))  # via function
# Quantity['length'](Array([1000., 2000., 3000., 4000.], dtype=float64), unit='m')
print(x.uconvert("m"))  # via method
# Quantity['length'](Array([1000., 2000., 3000., 4000.], dtype=float64), unit='m')
```
Since `Quantity` is parametric, it can do runtime dimension checking!

```python
LengthQuantity = u.Quantity["length"]
print(LengthQuantity(2, "km"))
# Quantity['length'](Array(2, dtype=int64, weak_type=True), unit='km')

try:
    LengthQuantity(2, "s")
except ValueError as e:
    print(e)
# Physical type mismatch.
```
`unxt` is built on `quax`, which enables custom array-ish objects in JAX. For convenience we use the `quaxed` library, which is just a `quax.quaxify` wrapper around `jax` to avoid boilerplate code.
Note: Using `quaxed` is optional. You can use `quaxify` directly, and even apply it to the top-level function instead of to individual functions.
```python
from quaxed import grad, vmap
import quaxed.numpy as jnp

print(jnp.square(x))
# Quantity['area'](Array([ 1., 4., 9., 16.], dtype=float64), unit='km2')

print(jnp.power(x, 3))
# Quantity['volume'](Array([ 1., 8., 27., 64.], dtype=float64), unit='km3')

print(vmap(grad(lambda x: x**3))(x))
# Quantity['area'](Array([ 3., 12., 27., 48.], dtype=float64), unit='km2')
```
See the documentation for more examples and details of JIT and AD.
If you found this library to be useful and want to support the development and maintenance of lower-level code libraries for the scientific community, please consider citing this work.
We welcome contributions!