GalacticDynamics/unxt
Unxt provides unitful quantities and calculations in JAX, built on Equinox and Quax.
Unxt supports JAX's compelling features:

- JIT compilation (`jit`)
- vectorization (`vmap`, etc.)
- auto-differentiation (`grad`, `jacobian`, `hessian`)
- GPU/TPU/multi-host acceleration
And best of all, `unxt` doesn't force you to use special unit-compatible re-exports of JAX libraries. You can use `unxt` with existing JAX code, and with `quax`'s simple decorator, JAX will work with `unxt.Quantity`.
Install with pip:

```bash
pip install unxt
```

Using uv:

```bash
uv add unxt
```

From source, using pip:

```bash
pip install git+https://github.com/GalacticDynamics/unxt.git
```

Building from source:

```bash
cd /path/to/parent
git clone https://github.com/GalacticDynamics/unxt.git
cd unxt
pip install -e .  # editable mode
```
```python
import unxt as u
import jax.numpy as jnp

x = u.Quantity(jnp.arange(1, 5, dtype=float), "km")
print(x)
# Quantity['length']([1., 2., 3., 4.], unit='km')
```
The constituent value and unit are accessible as attributes:
```python
repr(x.value)
# Array([1., 2., 3., 4.], dtype=float64)  (float32 unless JAX's x64 mode is enabled)
repr(x.unit)
# Unit("km")
```
`Quantity` objects obey the rules of unitful arithmetic.
```python
# Addition / Subtraction
print(x + x)
# Quantity['length']([2., 4., 6., 8.], unit='km')

# Multiplication / Division
print(2 * x)
# Quantity['length']([2., 4., 6., 8.], unit='km')

y = u.Quantity(jnp.arange(4, 8, dtype=float), "yr")
print(x / y)
# Quantity['speed']([0.25, 0.4 , 0.5 , 0.57142857], unit='km / yr')

# Exponentiation
print(x**2)
# Quantity['area']([ 1.,  4.,  9., 16.], unit='km2')

# Unit checking on operations
try:
    x + y
except Exception as e:
    print(e)
# 'yr' (time) and 'km' (length) are not convertible
```
Quantities can be converted to different units:
```python
print(u.uconvert("m", x))  # via function
# Quantity['length']([1000., 2000., 3000., 4000.], unit='m')

print(x.uconvert("m"))  # via method
# Quantity['length']([1000., 2000., 3000., 4000.], unit='m')
```
Since `Quantity` is parametric, it can do runtime dimension checking!
```python
LengthQuantity = u.Quantity["length"]
print(LengthQuantity(2, "km"))
# Quantity['length'](2, unit='km')

try:
    LengthQuantity(2, "s")
except ValueError as e:
    print(e)  # Physical type mismatch.
```
`unxt` is built on `quax`, which enables custom array-ish objects in JAX. For convenience we use the `quaxed` library, which is just a `quax.quaxify` wrapper around `jax` to avoid boilerplate code.
Note: Using `quaxed` is optional. You can use `quaxify` directly, and even apply it to the top-level function instead of individual functions.
```python
from quaxed import grad, vmap
import quaxed.numpy as jnp

print(jnp.square(x))
# Quantity['area']([ 1.,  4.,  9., 16.], unit='km2')

print(jnp.power(x, 3))
# Quantity['volume']([ 1.,  8., 27., 64.], unit='km3')

print(vmap(grad(lambda x: x**3))(x))
# Quantity['area']([ 3., 12., 27., 48.], unit='km2')
```
See the documentation for more examples and for details on JIT compilation and auto-differentiation (AD).
If you found this library to be useful and want to support the development and maintenance of lower-level code libraries for the scientific community, please consider citing this work.

We welcome contributions! Contributions are how open source projects improve and grow.

To contribute to `unxt`, please fork the repository, make a development branch, develop on that branch, then open a pull request from the branch in your fork to main.

To report bugs, request features, or suggest other ideas, please open an issue.

For more information, see CONTRIBUTING.md.
About
Unitful Quantities in JAX