Flax: A neural network library and ecosystem for JAX designed for flexibility


Overview | Quick install | What does Flax look like? | Documentation

Released in 2024, Flax NNX is a new simplified Flax API that is designed to make it easier to create, inspect, debug, and analyze neural networks in JAX. It achieves this by adding first-class support for Python reference semantics. This allows users to express their models using regular Python objects, enabling reference sharing and mutability.
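As a minimal sketch of what reference semantics means in practice (assuming a recent Flax release with the `nnx` API; the layer sizes and the `TwoHeads` container are arbitrary choices for illustration):

```py
from flax import nnx

# An NNX module is a regular Python object whose state lives on it directly.
model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

# State can be inspected like any attribute...
print(model.kernel.value.shape)  # (2, 3)

# ...and mutated in place.
model.bias.value = model.bias.value + 1.0

# Reference sharing: two attributes can point at the very same sub-module.
class TwoHeads(nnx.Module):
  def __init__(self, backbone: nnx.Linear):
    self.head_a = backbone  # shared reference, not a copy
    self.head_b = backbone

shared = TwoHeads(model)
assert shared.head_a is shared.head_b
```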

Flax NNX evolved from the Flax Linen API, which was released in 2020 by engineers and researchers at Google Brain in close collaboration with the JAX team.

You can learn more about Flax NNX on the dedicated Flax documentation site.

Note: Flax Linen's documentation has its own site.

The Flax team's mission is to serve the growing JAX neural network research ecosystem, both within Alphabet and in the broader community, and to explore the use cases where JAX shines. We use GitHub for almost all of our coordination and planning, as well as for discussing upcoming design changes. We welcome feedback on any of our discussion, issue, and pull request threads.

You can make feature requests, let us know what you are working on, report issues, and ask questions in our Flax GitHub discussion forum.

We expect to improve Flax, but we don't anticipate significant breaking changes to the core API. We use Changelog entries and deprecation warnings when possible.

In case you want to reach us directly, we're at flax-dev@google.com.

Overview

Flax is a high-performance neural network library and ecosystem for JAX that is designed for flexibility: try new forms of training by forking an example and modifying the training loop, not by adding features to a framework.
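To make the "modify the training loop" idea concrete, here is a minimal training-step sketch in the NNX style. Because the exact `nnx.Optimizer` API has shifted across Flax versions, this sketch applies plain SGD by hand via `nnx.state`/`nnx.update`; the model, data, and learning rate are illustrative:

```py
import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 1, rngs=nnx.Rngs(0))

@nnx.jit  # nnx.jit propagates the in-place state updates through JAX's jit
def train_step(model, x, y):
  def loss_fn(model):
    return jnp.mean((model(x) - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  # Plain SGD: grads is a State pytree mirroring the model's Params.
  params = nnx.state(model, nnx.Param)
  new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
  nnx.update(model, new_params)
  return loss

x, y = jnp.ones((8, 2)), jnp.zeros((8, 1))
for _ in range(3):
  print(train_step(model, x, y))  # loss should decrease
```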

Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research.

Quick install

Flax uses JAX, so check out the JAX installation instructions for CPU, GPU, and TPU.

You will need Python 3.8 or later. Install Flax from PyPI:

```
pip install flax
```

To upgrade to the latest version of Flax, you can use:

```
pip install --upgrade git+https://github.com/google/flax.git
```

To install some additional dependencies (like matplotlib) that are required by parts of Flax but not installed by default, you can use:

pip install"flax[all]"

What does Flax look like?

We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN, and an autoencoder.

To learn more about the Module abstraction, check out our docs and our broad intro to the Module abstraction. For additional concrete demonstrations of best practices, refer to our guides and developer notes.

Example of an MLP:

```py
import jax
from flax import nnx

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)
```
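A quick usage sketch (the sizes and batch below are arbitrary choices for illustration):

```py
import jax.numpy as jnp
from flax import nnx

# nnx.Rngs(0) seeds both parameter initialization and dropout.
model = MLP(din=4, dmid=8, dout=2, rngs=nnx.Rngs(0))

# Eager call on a batch of three inputs.
y = model(jnp.ones((3, 4)))
print(y.shape)  # (3, 2)
```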

Example of a CNN:

```py
from functools import partial

from flax import nnx

class CNN(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x):
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x
```
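The hard-coded 3136 assumes 28x28 inputs (for example MNIST): with `'SAME'` padding the convolutions preserve spatial size, the two 2x2 poolings reduce 28x28 to 7x7, and 7 * 7 * 64 = 3136. A quick usage sketch under that assumption:

```py
import jax.numpy as jnp
from flax import nnx

model = CNN(rngs=nnx.Rngs(0))

# A batch of four 28x28 grayscale images in NHWC layout.
logits = model(jnp.ones((4, 28, 28, 1)))
print(logits.shape)  # (4, 10)
```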

Example of an autoencoder:

```py
import jax
from flax import nnx

Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)

class AutoEncoder(nnx.Module):
  def __init__(self, rngs):
    self.encoder = Encoder(rngs)
    self.decoder = Decoder(rngs)

  def __call__(self, x) -> jax.Array:
    return self.decoder(self.encoder(x))

  def encode(self, x) -> jax.Array:
    return self.encoder(x)
```
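A quick usage sketch (the batch size is arbitrary; input width 2 and latent width 10 follow from the definitions above):

```py
import jax.numpy as jnp
from flax import nnx

model = AutoEncoder(nnx.Rngs(0))
x = jnp.ones((8, 2))

print(model(x).shape)         # (8, 2), the reconstruction
print(model.encode(x).shape)  # (8, 10), the latent codes
```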

Citing Flax

To cite this repository:

```
@software{flax2020github,
  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
  title = {{F}lax: A neural network library and ecosystem for {JAX}},
  url = {http://github.com/google/flax},
  version = {0.10.4},
  year = {2024},
}
```

In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from flax/version.py, and the year corresponds to the project's open-source release.

Note

Flax is an open source project maintained by a dedicated team at Google DeepMind, but is not an official Google product.

