Stack Overflow
67 questions
0 votes, 0 answers, 101 views

Is it expensive to keep recreating a Flax network, such as

    class QNetwork(nn.Module):
        dim: int

        @nn.compact
        def __call__(self, x):
            x = nn.Dense(120)(x)
            x = nn.relu(x)
            ...
User: joel (reputation 8,122)
1 vote, 1 answer, 99 views

    from jax import numpy as jnp
    from jax import random
    from flax import nnx
    import optax
    from matplotlib import pyplot as plt

    if __name__ == '__main__':
        shape = (2,55,1)
        epochs = 123
        rngs = ...
User: user137146
0 votes, 1 answer, 99 views

Duplicating my question here: https://github.com/google/flax/discussions/4825
I want to have a JAX or NNX jitted function that consumes and returns GPU-sharded tensors. However, inside the function, I ...
User: David Braun
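A sketch of the general pattern, assuming a 1-D device mesh with a hypothetical axis name `data` and that the leading array axis is the one to shard. The excerpt is cut off, so what the asker does inside the function is unknown. `jax.device_put` places the input with a `NamedSharding`, `jax.lax.with_sharding_constraint` pins an intermediate's layout, and `out_shardings` on `jax.jit` fixes the layout of the returned array. It also runs on a single CPU device, where the mesh has one entry.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all visible devices; "data" is a hypothetical axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))  # split the leading axis across devices

n = len(devices)
# Leading dimension is a multiple of the device count so it divides evenly.
x = jax.device_put(jnp.arange(8 * n, dtype=jnp.float32).reshape(4 * n, 2), sharding)


def step(x):
    y = jnp.tanh(x) * 2.0
    # Keep the intermediate laid out the same way as the input.
    y = jax.lax.with_sharding_constraint(y, sharding)
    return y + 1.0


# out_shardings pins the layout of the result returned to the caller.
step_sharded = jax.jit(step, out_shardings=sharding)
y = step_sharded(x)
```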
1 vote, 2 answers, 268 views

In the following code, when I remove the vmap, I get the expected randomized behavior. However, with vmap, I no longer do. Isn't this supposed to be one of the features of nnx.vmap?

    import jax
    import ...
User: Jackpap (reputation 8,086)
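A plain-JAX sketch of the underlying mechanism (not `nnx.vmap` itself). A random key is just a value, so if the same key is broadcast to every element of a `vmap`, every element draws the same numbers. Splitting the key into one key per element restores independent samples. In NNX, whether a module's RNG streams are split or broadcast across the vmapped axis is configured separately (this is what `nnx.split_rngs` is for); we assume, without seeing the full code, that the RNG state is being broadcast in the asker's case.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
n = 4


def noisy(k):
    return jax.random.normal(k, (3,))


# The same key for every batch element: all rows come out identical.
same = jax.vmap(lambda _: noisy(key))(jnp.arange(n))

# One key per batch element: rows are independent samples.
keys = jax.random.split(key, n)
different = jax.vmap(noisy)(keys)
```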
0 votes, 1 answer, 54 views

I want to train an LLM on a TPUv4-32 using JAX/Flax. The dataset is stored in a mounted Google Cloud Storage bucket. The dataset (Red-Pajama-v2) consists of 5000 shards, which are stored in .json.gz files: ~/...
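On a multi-host TPU slice each host runs its own JAX process. One common approach (a sketch, not the only option) is for each process to read a disjoint subset of the shard files, chosen with `jax.process_index()` and `jax.process_count()`. The file names and the `shards_for_host` helper are hypothetical, since the real paths are elided in the question.

```python
import jax


def shards_for_host(files, process_index, process_count):
    """Hypothetical helper: round-robin assignment, so host i reads files
    i, i + count, i + 2 * count, ..."""
    return files[process_index::process_count]


# Hypothetical shard names standing in for the elided .json.gz paths.
files = [f"shard_{i:05d}.json.gz" for i in range(5000)]

# On a single machine this is process 0 of 1, so it gets every file.
mine = shards_for_host(files, jax.process_index(), jax.process_count())
```

Round-robin keeps per-host file counts within one of each other; if shard sizes vary a lot, balancing by size may be worth considering.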
0 votes, 1 answer, 89 views

I'm doing some experiments with Flax NNX (not Linen!). What I'm trying to do is compute the weights of a network using another network: a hypernetwork receives some input parameters W and outputs a ...
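A plain-JAX sketch of a hypernetwork, swapped in for NNX for brevity. A hypothetical linear hypernetwork maps a conditioning vector (`cond`, standing in for the question's input parameters `W`) to a flat vector, which is reshaped into the kernel and bias of a target linear layer. Because the target weights are just a differentiable function of the hypernetwork's parameters, `jax.grad` propagates gradients through them. All sizes and names are assumptions.

```python
import jax
import jax.numpy as jnp

IN, OUT, COND = 3, 2, 4  # hypothetical sizes


def init_hyper(key):
    # Linear hypernetwork: conditioning vector -> all target-layer weights.
    n_target = IN * OUT + OUT  # kernel entries + bias entries
    return {"w": 0.1 * jax.random.normal(key, (COND, n_target)),
            "b": jnp.zeros((n_target,))}


def target_params(hyper, cond):
    flat = cond @ hyper["w"] + hyper["b"]
    kernel = flat[: IN * OUT].reshape(IN, OUT)
    bias = flat[IN * OUT:]
    return kernel, bias


def forward(hyper, cond, x):
    # The target network is an ordinary linear layer with generated weights.
    kernel, bias = target_params(hyper, cond)
    return x @ kernel + bias


def loss(hyper, cond, x, y):
    return jnp.mean((forward(hyper, cond, x) - y) ** 2)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
hyper = init_hyper(k1)
cond = jax.random.normal(k2, (COND,))
x = jax.random.normal(k3, (5, IN))
y = jnp.zeros((5, OUT))

# Gradients flow through the generated weights back to the hypernetwork.
grads = jax.grad(loss)(hyper, cond, x, y)
```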
1 vote, 0 answers, 75 views

Description
I have a deterministic program that uses JAX and is heavy on linear algebra operations.
I ran this code on CPU, using three different CPUs: two macOS systems (one on Sequoia (M1 Pro), ...
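The cause in the question is unknown, but one plausible contributor is that floating-point addition is not associative. Different CPUs, or the BLAS/XLA kernels tuned for them (for example with different SIMD widths), may sum in different orders. A minimal float32 illustration:

```python
import jax.numpy as jnp

a = jnp.float32(1e8)
b = jnp.float32(-1e8)
c = jnp.float32(1.0)

left = (a + b) + c   # 0 + 1 = 1.0
right = a + (b + c)  # -1e8 + 1 rounds back to -1e8 in float32, so the sum is 0.0
```

When comparing runs across machines, a tolerance-based check such as `jnp.allclose(x, y, rtol=1e-5, atol=1e-6)` is usually more appropriate than exact equality; those tolerances are illustrative, not recommended values.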
2 votes, 1 answer, 108 views

As of writing, this code does not pass the PyRight type checker:

    import jax
    import jax.numpy as jnp
    import jax.typing as jt
    import flax.linen as nn

    class MLP(nn.Module):
        @nn.compact
        def ...
1 vote, 1 answer, 290 views

I am doing a project with RNNs using JAX and Flax, and I have noticed some behavior that I do not really understand. My code is basically an optimization loop where the user provides the initial ...
User: yousef elbrolosy
1 vote, 1 answer, 279 views

I have a neural network (nnx.Module) written in Flax's NNX. I want to train this network efficiently using lax.scan instead of a for loop. However, as scan doesn't allow in place changes, how can I ...
User: elderly
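A plain-JAX sketch of a `lax.scan` training loop. The carry is the parameter pytree, each scanned element is one batch, and every step returns new parameters instead of mutating them in place. The linear model, plain SGD update, and synthetic data are assumptions for illustration. For an `nnx.Module`, we believe the usual route is to convert it to a pure state pytree with `nnx.split`, scan over that, and rebuild the module with `nnx.merge`; check the NNX docs for your version.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    return x @ params["w"] + params["b"]


def loss_fn(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)


def train_step(params, batch, lr=0.1):
    # scan calls this as train_step(carry, x) and expects (new_carry, output).
    x, y = batch
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Functional update: build new params rather than changing them in place.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


# Synthetic data: 100 batches of 8 examples with 3 features.
xs = jax.random.normal(jax.random.PRNGKey(0), (100, 8, 3))
true_w = jnp.array([[1.0], [-2.0], [0.5]])
ys = xs @ true_w  # shape (100, 8, 1)

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
final_params, losses = jax.lax.scan(train_step, params, (xs, ys))
```

`losses` holds the per-step loss (computed before each update), stacked by `scan` into an array of length 100.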
1 vote, 1 answer, 372 views

I'm trying to work out how to do transfer learning with flax.nnx. Below is my attempt to freeze the kernel of my nnx.Linear instance and optimize the bias. I think maybe I'm not correctly setting up ...
User: jworrell
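One way to freeze part of a model, sketched in plain JAX rather than NNX: keep the frozen weights and the trainable weights in separate pytrees and differentiate only with respect to the trainable one. `jax.grad` differentiates with respect to its first argument by default, so the kernel never receives a gradient. The tiny linear model and data are hypothetical. With optax, `optax.multi_transform` using `optax.set_to_zero()` for the frozen labels is a common alternative.

```python
import jax
import jax.numpy as jnp


def predict(frozen, trainable, x):
    return x @ frozen["kernel"] + trainable["bias"]


def loss_fn(trainable, frozen, x, y):
    # Trainable params come first, so jax.grad differentiates only them.
    return jnp.mean((predict(frozen, trainable, x) - y) ** 2)


frozen = {"kernel": jnp.array([[1.0], [2.0]])}  # "pretrained" kernel, kept fixed
trainable = {"bias": jnp.zeros((1,))}           # the only thing that is updated

x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([[3.0], [4.0]])  # with the fixed kernel, the best bias is 2.0

for _ in range(200):
    grads = jax.grad(loss_fn)(trainable, frozen, x, y)
    trainable = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, trainable, grads)
```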
2 votes, 1 answer, 203 views

Why can't I use the VS Code debugger to debug JAX code, specifically pure functions? I understand that JAX provides its own debugging tools, but the VS Code debugger is quite comfortable. Is this ...
User: akshat
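Under `jax.jit`, the Python body runs only while tracing, with abstract values, so a VS Code breakpoint there shows tracers rather than numbers. Two standard tools are sketched below: `jax.debug.print` prints runtime values from inside jitted code, and the `jax.disable_jit()` context manager runs jitted functions eagerly so an ordinary IDE breakpoint sees concrete arrays. `jax.debug.breakpoint()` is another option.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    y = jnp.sin(x) * 2.0
    # Prints concrete values at run time, even when f is compiled.
    jax.debug.print("y = {y}", y=y)
    return y.sum()


# Inside this context, f runs op by op on concrete arrays, so an IDE
# breakpoint placed inside f would show real numbers.
with jax.disable_jit():
    result = f(jnp.arange(3.0))
```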
0 votes, 1 answer, 286 views

I am trying to figure out how to use nnx.split_rngs. Can somebody give a version of the code below that uses nnx.split_rngs with jax.tree.map to produce an arbitrary number of Linear layers with ...
User: jworrell
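A plain-JAX analogue of what the question asks for (it does not use `nnx.split_rngs`): split one key into `num_layers` keys and initialize one layer per key. `jax.vmap` over the keys yields a single pytree whose leaves carry a leading layer axis (convenient for `lax.scan`), while `jax.tree_util.tree_map` over a list of keys yields a list of independent layers. The `init_linear` helper and its sizes are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_linear(key, in_dim=4, out_dim=4):
    """Hypothetical helper: parameters of one linear layer."""
    return {
        "kernel": jax.random.normal(key, (in_dim, out_dim)) / jnp.sqrt(in_dim),
        "bias": jnp.zeros((out_dim,)),
    }


num_layers = 5  # any number works
keys = jax.random.split(jax.random.PRNGKey(0), num_layers)

# Stacked: every leaf gains a leading axis of size num_layers.
stacked = jax.vmap(init_linear)(keys)

# List of layers: tree_map treats each key array as a leaf.
layers = jax.tree_util.tree_map(init_linear, list(keys))
```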
1 vote, 1 answer, 90 views

I'm currently using Flax for neural network implementations. My model takes two inputs: x and θ. It first processes x through an LSTM, then concatenates the LSTM's output with θ, or more precisely, ...
User: Dan Leonte
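The excerpt is cut off after "or more precisely", so both common variants are sketched, with a random array standing in for the LSTM output (all sizes hypothetical): concatenate θ with the last time step, or broadcast θ along the time axis and concatenate it to every step.

```python
import jax
import jax.numpy as jnp

batch, time, hidden, theta_dim = 2, 7, 16, 3  # hypothetical sizes

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
# Stand-in for the LSTM output over the sequence: (batch, time, hidden).
lstm_out = jax.random.normal(k1, (batch, time, hidden))
theta = jax.random.normal(k2, (batch, theta_dim))

# Variant 1: take the last time step, then concatenate along features.
last = lstm_out[:, -1, :]                              # (batch, hidden)
joined_last = jnp.concatenate([last, theta], axis=-1)  # (batch, hidden + theta_dim)

# Variant 2: repeat theta at every time step, then concatenate.
theta_t = jnp.broadcast_to(theta[:, None, :], (batch, time, theta_dim))
joined_seq = jnp.concatenate([lstm_out, theta_t], axis=-1)  # (batch, time, hidden + theta_dim)
```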
0 votes, 0 answers, 76 views

I have set up a snippet on Colab here with

    jax.__version__  # 0.4.33 9Feb2025
    orbax.checkpoint.__version__  # 0.6.4 9Feb2025

It's quite difficult to follow the flax/orbax changes in the ...
User: Jean-Eric


