Stainless neural networks in JAX

francois-rozet/inox

Inox is a minimal JAX library for neural networks with an intuitive PyTorch-like syntax. As with Equinox, modules are represented as PyTrees, which enables complex architectures, easy manipulations, and functional transformations.
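To make the PyTree idea concrete, here is a minimal sketch in plain JAX. It is not Inox's implementation: the `Linear` class is hypothetical and is registered by hand with `jax.tree_util.register_pytree_node_class`. Once registered, `jax.grad` can differentiate with respect to the module object itself.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Linear:
    """Hypothetical minimal module registered as a PyTree (not Inox's code)."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return x @ self.weight + self.bias

    def tree_flatten(self):
        # Children are the array fields; no auxiliary (static) data is needed here.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


def loss(layer, x):
    return jnp.sum(layer(x) ** 2)


layer = Linear(jnp.ones((3, 2)), jnp.zeros(2))
x = jnp.arange(3.0)

# The gradient has the same structure as the module: another Linear instance.
grads = jax.grad(loss)(layer, x)
```

Here `grads.weight` and `grads.bias` hold the gradients of the corresponding fields. Inox modules are PyTrees without any hand-written `tree_flatten`/`tree_unflatten` methods, as the `MLP` example in "Getting started" shows.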

Inox aims to be a leaner version of Equinox by retaining only its core features: PyTrees and lifted transformations. In addition, Inox takes inspiration from other projects like NNX and Serket to provide a versatile interface. Despite these differences, Inox remains compatible with the Equinox ecosystem, and its components (modules, transformations, ...) are for the most part interchangeable with those of Equinox.

Inox means "stainless steel" in French 🔪

Installation

The inox package is available on PyPI, which means it is installable via pip.

pip install inox

Alternatively, if you need the latest features, you can install it from the repository.

pip install git+https://github.com/francois-rozet/inox

Getting started

Modules are defined with an intuitive PyTorch-like syntax,

import jax
import inox.nn as nn

init_key, data_key = jax.random.split(jax.random.key(0))

class MLP(nn.Module):
    def __init__(self, key):
        keys = jax.random.split(key, 3)

        self.l1 = nn.Linear(3, 64, key=keys[0])
        self.l2 = nn.Linear(64, 64, key=keys[1])
        self.l3 = nn.Linear(64, 3, key=keys[2])
        self.relu = nn.ReLU()

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        return x

model = MLP(init_key)

and are compatible with JAX transformations.

X = jax.random.normal(data_key, (1024, 3))
Y = jax.numpy.sort(X, axis=-1)

@jax.jit
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = jax.grad(loss_fn)(model, X, Y)

However, if a tree contains strings or boolean flags, it becomes incompatible with JAX transformations. For this reason, Inox provides lifted transformations that consider all non-array leaves as static.

import inox  # `import inox.nn as nn` only binds `nn`, not `inox`

model.name = 'stainless'  # not an array

@inox.jit
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = inox.grad(loss_fn)(model, X, Y)
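Conceptually, a lifted transformation splits its inputs into array leaves, which JAX traces, and all other leaves, which are handed to JAX as hashable static data. The hypothetical `static_aware_jit` below sketches this principle for `jax.jit` in plain JAX. It is not Inox's implementation, and it assumes every non-array leaf is hashable.

```python
import functools

import jax
import jax.numpy as jnp


def static_aware_jit(fn):
    """Hypothetical lifted jit: jax.Array leaves are traced, all other leaves are static."""

    @functools.partial(jax.jit, static_argnums=1)
    def inner(arrays, static):
        treedef, is_array, statics = static
        # Put the static leaves back into the slots left empty (None) in `arrays`.
        leaves = [x if a else s for x, s, a in zip(arrays, statics, is_array)]
        args = jax.tree_util.tree_unflatten(treedef, leaves)
        return fn(*args)

    def wrapper(*args):
        leaves, treedef = jax.tree_util.tree_flatten(args)
        # Only jax.Array leaves are dynamic; NumPy arrays would need to be added here.
        is_array = tuple(isinstance(x, jax.Array) for x in leaves)
        arrays = [x if a else None for x, a in zip(leaves, is_array)]
        statics = tuple(None if a else x for x, a in zip(leaves, is_array))
        # The static part must be hashable: PyTreeDef, bools, strings, ints, ...
        return inner(arrays, (treedef, is_array, statics))

    return wrapper


def loss_fn(model, x):
    # `model["name"]` is still a plain Python string at trace time.
    scale = 2.0 if model["name"] == "stainless" else 1.0
    return scale * jnp.sum(model["w"] * x)


model = {"w": jnp.ones(3), "name": "stainless"}
value = static_aware_jit(loss_fn)(model, jnp.arange(3.0))
```

Because the static leaves are part of the compilation cache key, changing one of them (for instance `model["name"]`) triggers a new trace instead of an error.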

Inox also provides a partition mechanism to split the static definition of a module (structure, strings, flags, ...) from its dynamic content (parameters, indices, statistics, ...), which is convenient for updating parameters.

model.mask = jax.numpy.array([1, 0, 1])  # not a parameter

static, params, others = model.partition(nn.Parameter)

@jax.jit
def loss_fn(params, others, x, y):
    model = static(params, others)  # rebuild the module from its parts
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = jax.grad(loss_fn)(params, others, X, Y)  # gradients w.r.t. params only
params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

model = static(params, others)
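The partition-then-rebuild pattern can also be reproduced in plain JAX. In the sketch below, `partition`, `combine`, and `is_trainable` are hypothetical helpers, not Inox's API. Since plain arrays carry no `nn.Parameter` marker, the sketch assumes that floating-point arrays are trainable and that every other leaf, like the integer `mask`, is not.

```python
import jax
import jax.numpy as jnp


def partition(tree, is_param):
    """Hypothetical helper: split the leaves of `tree` into (treedef, params, others)."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    params = [x if is_param(x) else None for x in leaves]
    others = [None if is_param(x) else x for x in leaves]
    return treedef, params, others


def combine(treedef, params, others):
    """Inverse of `partition`: merge the two leaf lists and rebuild the tree."""
    leaves = [p if o is None else o for p, o in zip(params, others)]
    return jax.tree_util.tree_unflatten(treedef, leaves)


def is_trainable(x):
    # Assumption for this sketch: floating-point arrays are the trainable parameters.
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.floating)


model = {"w": jnp.array([1.0, 2.0, 3.0]), "mask": jnp.array([1, 0, 1])}
treedef, params, others = partition(model, is_trainable)


@jax.jit
def loss_fn(params, others):
    m = combine(treedef, params, others)
    return jnp.sum((m["w"] * m["mask"]) ** 2)


# Differentiate and update only `params`; `others` passes through unchanged.
grads = jax.grad(loss_fn)(params, others)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
model = combine(treedef, params, others)
```

As with `others` in the Inox example above, the non-trainable `mask` is carried along but never updated.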

For more information, check out the documentation and tutorials at inox.readthedocs.io.

Contributing

If you have a question or an issue, or would like to contribute, please read our contributing guidelines.

