Inox is a minimal JAX library for neural networks with an intuitive PyTorch-like syntax. As with Equinox, modules are represented as PyTrees, which enables complex architectures, easy manipulations, and functional transformations.
Inox aims to be a leaner version of Equinox by retaining only its core features: PyTrees and lifted transformations. In addition, Inox takes inspiration from other projects like NNX and Serket to provide a versatile interface. Despite the differences, Inox remains compatible with the Equinox ecosystem, and its components (modules, transformations, ...) are for the most part interchangeable with those of Equinox.
Inox means "stainless steel" in French 🔪
The `inox` package is available on PyPI, which means it is installable via `pip`.

```
pip install inox
```
Alternatively, if you need the latest features, you can install it from the repository.
```
pip install git+https://github.com/francois-rozet/inox
```
Modules are defined with an intuitive PyTorch-like syntax,
```python
import jax
import inox
import inox.nn as nn

init_key, data_key = jax.random.split(jax.random.key(0))

class MLP(nn.Module):
    def __init__(self, key):
        keys = jax.random.split(key, 3)

        self.l1 = nn.Linear(3, 64, key=keys[0])
        self.l2 = nn.Linear(64, 64, key=keys[1])
        self.l3 = nn.Linear(64, 3, key=keys[2])
        self.relu = nn.ReLU()

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        return x

model = MLP(init_key)
```
and are compatible with JAX transformations.
```python
X = jax.random.normal(data_key, (1024, 3))
Y = jax.numpy.sort(X, axis=-1)

@jax.jit
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = jax.grad(loss_fn)(model, X, Y)
```
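Since `model` and `grads` are PyTrees with the same structure, standard tree utilities apply to them directly. As a minimal sketch, a naive gradient-descent step is a leaf-wise tree map (the 0.01 learning rate is an arbitrary choice):

```python
# model and grads share the same tree structure, so a naive
# gradient-descent step combines them leaf by leaf.
model = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, model, grads)
```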
However, if a tree contains strings or boolean flags, it becomes incompatible with JAX transformations. For this reason, Inox provides lifted transformations that consider all non-array leaves as static.
```python
model.name = 'stainless'  # not an array

@inox.jit
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = inox.grad(loss_fn)(model, X, Y)
```
Inox also provides a partition mechanism to split the static definition of a module (structure, strings, flags, ...) from its dynamic content (parameters, indices, statistics, ...), which is convenient for updating parameters.
```python
model.mask = jax.numpy.array([1, 0, 1])  # not a parameter

static, params, others = model.partition(nn.Parameter)

@jax.jit
def loss_fn(params, others, x, y):
    model = static(params, others)
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = jax.grad(loss_fn)(params, others, X, Y)

# Apply a gradient-descent step to the parameters only.
params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
model = static(params, others)
```
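Iterating this update gives a bare-bones training procedure. The following is a minimal sketch that reuses the jitted `loss_fn` above; the step count and learning rate are arbitrary choices:

```python
# A minimal training loop built on the partition mechanism.
static, params, others = model.partition(nn.Parameter)

for _ in range(1024):
    grads = jax.grad(loss_fn)(params, others, X, Y)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

# Recombine the static definition with the trained parameters.
model = static(params, others)
```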
For more information, check out the documentation and tutorials at inox.readthedocs.io.
If you have a question or an issue, or if you would like to contribute, please read our contributing guidelines.