Is it expensive to keep recreating a Flax network, such as:

class QNetwork(nn.Module):
    dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(120)(x)
        x = nn.relu(x)
        ...
Duplicating my question here: https://github.com/google/flax/discussions/4825

I want to have a JAX or NNX jitted function that consumes and returns GPU-sharded tensors. However, inside the function, I ...
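A sketch of the jitted-function-with-sharded-IO setup, assuming a 1-D device mesh (on a multi-GPU host this shards across GPUs; on a single-device CPU host it degenerates to a 1-device mesh):

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# build a 1-D mesh over whatever devices are available
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

# place the input according to the sharding before calling the function
x = jax.device_put(jnp.arange(8.0), sharding)

@partial(jax.jit, out_shardings=sharding)
def f(x):
    # inside jit you compute as usual; out_shardings constrains the
    # placement of the returned array
    return x * 2.0

y = f(x)
```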
In the following code, when I remove the vmap, I get the right randomized behavior. However, with vmap, I don't anymore. Isn't this supposed to be one of the features of nnx.vmap?

import jax
import ...
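The usual cause of this in plain JAX (and the idea behind splitting RNG state for `nnx.vmap`) is that every mapped lane must receive its own key; reusing one key makes all lanes produce identical samples. A plain-JAX sketch of the fix:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

def sample(k):
    return jax.random.normal(k, ())

# splitting the key gives every vmapped lane its own RNG stream;
# passing the same key to every lane would make all samples identical
keys = jax.random.split(key, 4)
samples = jax.vmap(sample)(keys)
```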
I want to train an LLM on TPUv4-32 using JAX/Flax. The dataset is stored in a mounted Google Cloud Storage bucket. The dataset (Red-Pajama-v2) consists of 5000 shards, which are stored in .json.gz files: ~/...
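A stdlib-only sketch of streaming records from one such shard (the shard filename and record fields here are hypothetical, mimicking the line-delimited .json.gz layout described above):

```python
import gzip
import json
import os
import tempfile

# hypothetical shard: one JSON record per line, gzip-compressed
shard_dir = tempfile.mkdtemp()
shard_path = os.path.join(shard_dir, "shard-0000.json.gz")
with gzip.open(shard_path, "wt", encoding="utf-8") as f:
    for i in range(3):
        f.write(json.dumps({"id": i, "text": f"document {i}"}) + "\n")

def iter_shard(path):
    # stream records without decompressing the whole shard into memory
    with gzip.open(path, "rt", encoding="utf-8") as f:
        for line in f:
            yield json.loads(line)

records = list(iter_shard(shard_path))
```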
I'm doing some experiments with Flax NNX (not Linen!). What I'm trying to do is compute the weights of a network using another network: a hypernetwork receives some input parameters W and outputs a ...
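The core of the hypernetwork pattern, sketched in plain JAX (all sizes and the linear hypernetwork form are assumptions): the hypernetwork emits a flat weight vector, which is reshaped into the target network's parameters and applied functionally, so gradients flow back into the hypernetwork only.

```python
import jax
import jax.numpy as jnp

# target net: a single dense layer with 3 inputs, 2 outputs (assumed sizes)
t_in, t_out = 3, 2
n_flat = t_in * t_out + t_out  # kernel + bias, flattened

def hypernet(h_params, w):
    # linear hypernetwork: maps conditioning input w to the target's flat weights
    return h_params["A"] @ w + h_params["b"]

def target_apply(flat, x):
    kernel = flat[: t_in * t_out].reshape(t_in, t_out)
    bias = flat[t_in * t_out :]
    return x @ kernel + bias

key = jax.random.PRNGKey(0)
h_params = {
    "A": jax.random.normal(key, (n_flat, 4)) * 0.1,
    "b": jnp.zeros(n_flat),
}
w = jnp.ones(4)          # hypernetwork input
x = jnp.ones((5, t_in))  # target network input

y = target_apply(hypernet(h_params, w), x)

# gradients flow through both networks, so only h_params are trained
grads = jax.grad(lambda hp: target_apply(hypernet(hp, w), x).sum())(h_params)
```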
Description

I have a deterministic program that uses JAX and is heavy on linear algebra operations. I ran this code on CPU, using three different CPUs: two macOS systems (one on Sequoia (M1 Pro), ...
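One common source of cross-CPU differences (offered as a possible explanation, not a diagnosis of this specific program): IEEE-754 addition is not associative, so a different reduction order in the BLAS kernel or a different SIMD width changes the last bits of the result.

```python
# IEEE-754 addition is not associative: regrouping the same three
# terms produces results that differ in the last bits
a = (0.1 + 0.2) + 0.3
b = 0.1 + (0.2 + 0.3)
diff = a - b  # tiny but nonzero in double precision
```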
As of writing, this code does not pass the Pyright type checker:

import jax
import jax.numpy as jnp
import jax.typing as jt
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def ...
I am doing a project with RNNs using JAX and Flax, and I have noticed some behavior that I don't really understand. My code is basically an optimization loop where the user provides the initial ...
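For reference, a minimal pure-JAX RNN (assumed shapes, since the original loop is truncated): `lax.scan` threads the hidden state through the sequence, which is the usual building block such optimization loops are built on.

```python
import jax
import jax.numpy as jnp

hidden, feat, steps = 8, 3, 10
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
W_h = jax.random.normal(k1, (hidden, hidden)) * 0.1
W_x = jax.random.normal(k2, (hidden, feat)) * 0.1
xs = jax.random.normal(k3, (steps, feat))

def rnn_step(h, x):
    # carry is the hidden state; scan threads it through the sequence
    h = jnp.tanh(W_h @ h + W_x @ x)
    return h, h  # (new carry, per-step output)

h0 = jnp.zeros(hidden)
h_final, hs = jax.lax.scan(rnn_step, h0, xs)
```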
I have a neural network (an nnx.Module) written in Flax NNX. I want to train this network efficiently using lax.scan instead of a for loop. However, as scan doesn't allow in-place changes, how can I ...
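The usual approach is to make the state an explicit scan carry (with NNX this is typically done by splitting the module into graphdef and state and merging inside the scanned function). A plain-JAX analogue of the pattern, with an assumed toy linear-regression task:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
w_true, b_true = jnp.array([[2.0], [-1.0]]), jnp.array([0.5])
xs = jax.random.normal(key, (10, 16, 2))  # 10 batches of 16 examples
ys = xs @ w_true + b_true

params0 = {"w": jnp.zeros((2, 1)), "b": jnp.zeros(1)}

def train_step(params, batch):
    # params are the scan carry: updated functionally, never in place
    x, y = batch
    def loss_fn(p):
        return jnp.mean((x @ p["w"] + p["b"] - y) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss  # (new carry, stacked per-step output)

params, losses = jax.lax.scan(train_step, params0, (xs, ys))
```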
I'm trying to work out how to do transfer learning with flax.nnx. Below is my attempt to freeze the kernel of my nnx.Linear instance and optimize the bias. I think maybe I'm not correctly setting up ...
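One way to express the freeze in plain JAX (a sketch of the masking idea, not the nnx.Optimizer setup from the truncated code): multiply each gradient leaf by a 0/1 mask before the update, so the kernel never moves while the bias trains.

```python
import jax
import jax.numpy as jnp

params = {
    "kernel": jnp.ones((3, 2)),
    "bias": jnp.zeros(2),
}
# 1.0 = trainable, 0.0 = frozen (here: freeze kernel, train bias)
mask = {"kernel": 0.0, "bias": 1.0}

x = jnp.ones((4, 3))
y = jnp.ones((4, 2))

def loss_fn(p):
    return jnp.mean((x @ p["kernel"] + p["bias"] - y) ** 2)

grads = jax.grad(loss_fn)(params)
# masked SGD step: frozen leaves get a zero update
new_params = jax.tree_util.tree_map(
    lambda p, g, m: p - 0.1 * m * g, params, grads, mask
)
```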
Why can't I use the VS Code debugger to debug JAX code, specifically pure functions? I understand that JAX provides its own debugging framework, but the VS Code debugger is quite comfortable. Is this ...
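The reason is that jitted functions run as compiled XLA programs, not Python bytecode, so a Python debugger never sees the inner lines. A common workaround is to run eagerly under `jax.disable_jit`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) * 2.0

x = jnp.ones(3)
y_compiled = f(x)

# under disable_jit the function runs eagerly, line by line, so an
# ordinary debugger (e.g. VS Code's) can step into it and hit breakpoints
with jax.disable_jit():
    y_eager = f(x)
```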
I am trying to figure out how to use nnx.split_rngs. Can somebody give a version of the code below that uses nnx.split_rngs with jax.tree.map to produce an arbitrary number of Linear layers with ...
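The underlying idea that `nnx.split_rngs` automates, sketched in plain JAX (layer sizes assumed): split the key once per layer, then vmap the initializer over the keys so each layer gets independent parameters stacked along a leading axis.

```python
import jax
import jax.numpy as jnp

d_in, d_out, n_layers = 3, 4, 5

def init_layer(k):
    # one independent key per layer yields independent parameters
    return {
        "kernel": jax.random.normal(k, (d_in, d_out)) * 0.01,
        "bias": jnp.zeros(d_out),
    }

# vmap stacks each leaf along a new leading "layer" axis
keys = jax.random.split(jax.random.PRNGKey(0), n_layers)
stacked = jax.vmap(init_layer)(keys)
```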
I'm currently using Flax for neural network implementations. My model takes two inputs: x and θ. It first processes x through an LSTM, then concatenates the LSTM's output with θ — or more precisely, ...
I have set up a snippet on Colab here, with:

jax.__version__               # 0.4.33, 9 Feb 2025
orbax.checkpoint.__version__  # 0.6.4, 9 Feb 2025

It is quite difficult to follow the flax/orbax changes in the ...