
jax.example_libraries.stax module

Stax is a small but flexible neural net specification library, written from scratch.

You likely do not mean to import this module! Stax is intended as an example library only. There are a number of other, much more fully featured neural network libraries for JAX, including Flax from Google and Haiku from DeepMind.

jax.example_libraries.stax.AvgPool(window_shape, strides=None, padding='VALID', spec=None)

Layer construction function for an average-pooling layer.
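A minimal sketch of the layer in use, assuming the default spec=None, under which the stax source treats the first axis as batch and the last as channels (NHWC for 2-D windows). The input shape and values are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# 2x2 average pooling with stride 2 (non-overlapping windows).
init_fun, apply_fun = stax.AvgPool((2, 2), strides=(2, 2))

rng = jax.random.PRNGKey(0)
input_shape = (1, 4, 4, 1)  # (batch, height, width, channels)
out_shape, params = init_fun(rng, input_shape)  # pooling layers have no parameters

x = jnp.arange(16.0).reshape(input_shape)
y = apply_fun(params, x)
# y[0, :, :, 0] holds the mean of each 2x2 window: [[2.5, 4.5], [10.5, 12.5]]
```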

jax.example_libraries.stax.BatchNorm(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, beta_init=<function zeros>, gamma_init=<function ones>)

Layer construction function for a batch normalization layer.
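A short sketch, assuming NHWC inputs so that the default axis=(0, 1, 2) normalizes over batch, height and width and leaves one beta/gamma entry per channel. In the stax source the mean and variance come from the batch passed to apply_fun; no running averages are kept. The shapes and the random input are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# Default axis=(0, 1, 2): statistics over batch, height and width.
init_fun, apply_fun = stax.BatchNorm()

rng = jax.random.PRNGKey(0)
input_shape = (8, 5, 5, 3)  # NHWC
out_shape, (beta, gamma) = init_fun(rng, input_shape)

# Illustrative input with mean ~2 and standard deviation ~4.
x = 2.0 + 4.0 * jax.random.normal(jax.random.PRNGKey(1), input_shape)
y = apply_fun((beta, gamma), x)

# At initialization beta = 0 and gamma = 1, so each channel of y is
# approximately standardized (epsilon shrinks the variance very slightly).
channel_mean = jnp.mean(y, axis=(0, 1, 2))
channel_var = jnp.var(y, axis=(0, 1, 2))
```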

jax.example_libraries.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

Layer construction function for a 2-D convolution layer. Equivalent to GeneralConv with dimension_numbers=('NHWC', 'HWIO', 'NHWC').
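A minimal sketch, assuming an MNIST-sized NHWC input; the channel count and kernel size are arbitrary choices. It shows the kernel and bias shapes that init_fun produces and that 'SAME' padding keeps the spatial size.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# 16 output channels, 3x3 kernels, 'SAME' padding so height and width are preserved.
init_fun, apply_fun = stax.Conv(16, (3, 3), padding='SAME')

rng = jax.random.PRNGKey(0)
input_shape = (2, 28, 28, 1)  # NHWC: batch, height, width, channels
out_shape, (W, b) = init_fun(rng, input_shape)

# W is laid out HWIO; b has shape (1, 1, 1, 16) so it broadcasts over the NHWC output.
x = jnp.ones(input_shape)
y = apply_fun((W, b), x)
```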

jax.example_libraries.stax.Conv1DTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

Layer construction function for a 1-D transposed-convolution layer. Equivalent to GeneralConvTranspose with dimension_numbers=('NHC', 'HIO', 'NHC').
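A sketch of using the layer to upsample a sequence, assuming inputs laid out as (batch, length, channels). With 'SAME' padding, the transposed convolution multiplies the length by the stride; the sizes below are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# 4 output channels, width-3 kernel, stride 2.
init_fun, apply_fun = stax.Conv1DTranspose(4, (3,), strides=(2,), padding='SAME')

rng = jax.random.PRNGKey(0)
input_shape = (1, 10, 8)  # NHC: batch, length, channels
out_shape, (W, b) = init_fun(rng, input_shape)

x = jnp.ones(input_shape)
y = apply_fun((W, b), x)  # length 10 -> 20
```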

jax.example_libraries.stax.ConvTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

Layer construction function for a 2-D transposed-convolution layer. Equivalent to GeneralConvTranspose with dimension_numbers=('NHWC', 'HWIO', 'NHWC').
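A sketch of 2x spatial upsampling, assuming an NHWC feature map; the sizes are illustrative. With 'SAME' padding and stride 2, each spatial dimension doubles.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# 4 output channels, 3x3 kernels, stride 2 in both spatial dimensions.
init_fun, apply_fun = stax.ConvTranspose(4, (3, 3), strides=(2, 2), padding='SAME')

rng = jax.random.PRNGKey(0)
input_shape = (1, 7, 7, 8)  # NHWC
out_shape, (W, b) = init_fun(rng, input_shape)

x = jnp.ones(input_shape)
y = apply_fun((W, b), x)  # 7x7 -> 14x14
```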

jax.example_libraries.stax.Dense(out_dim, W_init=<function variance_scaling.<locals>.init>, b_init=<function normal.<locals>.init>)

Layer construction function for a dense (fully-connected) layer.
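A minimal sketch, assuming flattened 784-feature inputs. Passing -1 for the batch dimension at init time, as the stax examples do, leaves the batch size unspecified; the apply function computes inputs @ W + b.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

init_fun, apply_fun = stax.Dense(10)

rng = jax.random.PRNGKey(0)
# -1 stands in for an unknown batch size; only the last dimension is used for W.
out_shape, (W, b) = init_fun(rng, (-1, 784))

x = jnp.ones((32, 784))
y = apply_fun((W, b), x)  # shape (32, 10)
```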

jax.example_libraries.stax.Dropout(rate, mode='train')

Layer construction function for a dropout layer with the given rate.
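A sketch of training-mode and evaluation-mode use. The mask is drawn from the PRNG key passed to apply_fun through the rng keyword. Reading the stax source, rate appears to act as the probability of keeping each unit (it is passed to jax.random.bernoulli, and survivors are divided by rate), so check this before porting code from libraries where the rate is a drop probability. The value 0.5 is used here, where both readings coincide.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

rate = 0.5
init_fun, apply_fun = stax.Dropout(rate)  # mode='train' by default
out_shape, params = init_fun(jax.random.PRNGKey(0), (4, 6))

x = jnp.ones((4, 6))
# Training mode: the random mask comes from the key passed as `rng`.
y = apply_fun(params, x, rng=jax.random.PRNGKey(1))
# Every entry of y is either 0 or x / rate (here 2.0).

# Any mode other than 'train' passes inputs through unchanged.
_, eval_apply = stax.Dropout(rate, mode='test')
y_eval = eval_apply(params, x, rng=jax.random.PRNGKey(2))
```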

jax.example_libraries.stax.FanInConcat(axis=-1)

Layer construction function for a fan-in concatenation layer.
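A minimal sketch: the layer takes a sequence of inputs (and, at init time, a sequence of their shapes) and concatenates them along axis. The two input shapes below are illustrative and differ only in the concatenated axis.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

init_fun, apply_fun = stax.FanInConcat(axis=-1)

# init_fun receives one shape per incoming branch.
out_shape, params = init_fun(jax.random.PRNGKey(0), [(2, 3), (2, 5)])

a = jnp.ones((2, 3))
b = jnp.zeros((2, 5))
y = apply_fun(params, [a, b])  # concatenated along the last axis: shape (2, 8)
```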

jax.example_libraries.stax.FanOut(num)

Layer construction function for a fan-out layer.
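A minimal sketch: FanOut(num) copies its single input into a list of num identical outputs (and its shape into num identical shapes), which a following parallel layer can route to separate branches. The shape is illustrative.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

init_fun, apply_fun = stax.FanOut(3)
out_shapes, params = init_fun(jax.random.PRNGKey(0), (2, 4))

x = jnp.arange(8.0).reshape(2, 4)
ys = apply_fun(params, x)  # a list holding the same input three times
```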

jax.example_libraries.stax.GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

Layer construction function for a general convolution layer.
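A sketch with channels-first layouts. dimension_numbers is a triple of layout strings for the input, kernel and output, in the style accepted by jax.lax.conv_general_dilated; in the stax source the kernel's 'I' size is taken from the input's 'C' dimension and its 'O' size from out_chan. The NCHW/OIHW choice and all sizes are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# Channels-first input and output, OIHW kernel.
dimension_numbers = ('NCHW', 'OIHW', 'NCHW')
init_fun, apply_fun = stax.GeneralConv(dimension_numbers, 8, (3, 3), padding='VALID')

rng = jax.random.PRNGKey(0)
input_shape = (2, 3, 10, 10)  # batch, channels, height, width
out_shape, (W, b) = init_fun(rng, input_shape)

x = jnp.ones(input_shape)
y = apply_fun((W, b), x)  # 'VALID' with a 3x3 kernel: 10x10 -> 8x8
```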

jax.example_libraries.stax.GeneralConvTranspose(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

Layer construction function for a general transposed-convolution layer.
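A sketch with channels-first layouts, assuming NCHW inputs and an OIHW kernel (illustrative choices). With a 2x2 kernel, stride 2 and 'VALID' padding, each spatial dimension doubles.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

dimension_numbers = ('NCHW', 'OIHW', 'NCHW')
init_fun, apply_fun = stax.GeneralConvTranspose(
    dimension_numbers, 2, (2, 2), strides=(2, 2), padding='VALID')

rng = jax.random.PRNGKey(0)
input_shape = (1, 4, 5, 5)  # batch, channels, height, width
out_shape, (W, b) = init_fun(rng, input_shape)

x = jnp.ones(input_shape)
y = apply_fun((W, b), x)  # 5x5 -> 10x10 with 2 output channels
```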

jax.example_libraries.stax.MaxPool(window_shape, strides=None, padding='VALID', spec=None)

Layer construction function for a max-pooling layer.
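A minimal sketch, assuming the default spec=None (batch first, channels last) and non-overlapping 2x2 windows; the input is a small illustrative ramp.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

init_fun, apply_fun = stax.MaxPool((2, 2), strides=(2, 2))

input_shape = (1, 4, 4, 1)  # NHWC
out_shape, params = init_fun(jax.random.PRNGKey(0), input_shape)

x = jnp.arange(16.0).reshape(input_shape)
y = apply_fun(params, x)
# y[0, :, :, 0] holds the maximum of each 2x2 window: [[5, 7], [13, 15]]
```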

jax.example_libraries.stax.SumPool(window_shape, strides=None, padding='VALID', spec=None)

Layer construction function for a sum-pooling layer.
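A sketch that leaves strides=None, which in the stax source defaults to a stride of 1 per window dimension, so the 2x2 windows overlap and a 4x4 input yields a 3x3 output. The layout assumption (spec=None, NHWC) and the input values are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# strides=None: stride 1 in each spatial dimension, so windows overlap.
init_fun, apply_fun = stax.SumPool((2, 2))

input_shape = (1, 4, 4, 1)  # NHWC
out_shape, params = init_fun(jax.random.PRNGKey(0), input_shape)

x = jnp.arange(16.0).reshape(input_shape)
y = apply_fun(params, x)
# y[0, i, j, 0] is the sum of the 2x2 window starting at (i, j).
```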

jax.example_libraries.stax.elementwise(fun, **fun_kwargs)

Layer that applies a scalar function elementwise on its inputs.
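A minimal sketch of wrapping an activation as a parameter-free layer; extra keyword arguments are forwarded to fun on every call. The module's own activation layers, such as stax.Relu and stax.Tanh, are defined this way in its source. The choice of jax.nn.leaky_relu with a slope of 0.1 is illustrative.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# negative_slope is forwarded to jax.nn.leaky_relu on every call.
init_fun, apply_fun = stax.elementwise(jax.nn.leaky_relu, negative_slope=0.1)

out_shape, params = init_fun(jax.random.PRNGKey(0), (3,))  # shape unchanged, no params
y = apply_fun(params, jnp.array([-1.0, 0.0, 2.0]))  # [-0.1, 0.0, 2.0]
```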

jax.example_libraries.stax.parallel(*layers)

Combinator for composing layers in parallel.

The layer resulting from this combinator is often used with the FanOut and FanInSum layers.

Parameters:

*layers – a sequence of layers, each an (init_fun, apply_fun) pair.

Returns:

A new layer, meaning an (init_fun, apply_fun) pair, representing the parallel composition of the given sequence of layers. In particular, the returned layer takes a sequence of inputs and returns a sequence of outputs with the same length as the argument layers.
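A sketch of the FanOut / parallel / FanInSum pattern as a residual connection: FanOut(2) duplicates the input, parallel sends one copy through a Dense layer and the other through stax.Identity, and FanInSum adds the two branches. The width of 4 is an arbitrary assumption; it must match the input width for the sum to be valid.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# Residual block: y = x + Dense(x).
init_fun, apply_fun = stax.serial(
    stax.FanOut(2),                               # x -> [x, x]
    stax.parallel(stax.Dense(4), stax.Identity),  # [x, x] -> [Dense(x), x]
    stax.FanInSum,                                # [a, b] -> a + b
)

out_shape, params = init_fun(jax.random.PRNGKey(0), (-1, 4))
x = jnp.ones((2, 4))
y = apply_fun(params, x)

# params[1] holds the parallel layer's parameters: ((W, b), ()) for (Dense, Identity).
W, b = params[1][0]
```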

jax.example_libraries.stax.serial(*layers)

Combinator for composing layers in serial.

Parameters:

*layers – a sequence of layers, each an (init_fun, apply_fun) pair.

Returns:

A new layer, meaning an (init_fun, apply_fun) pair, representing the serial composition of the given sequence of layers.
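A sketch of a small classifier built with serial, assuming flattened 784-feature inputs and 10 classes (illustrative sizes). init_fun threads the output shape of each layer into the next and returns a list with one parameter entry per layer; the composed apply_fun can be passed to jax.jit.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

init_fun, apply_fun = stax.serial(
    stax.Dense(32), stax.Relu,
    stax.Dense(10), stax.LogSoftmax,
)

rng = jax.random.PRNGKey(0)
out_shape, params = init_fun(rng, (-1, 784))  # -1: unspecified batch size

x = jnp.ones((8, 784))
log_probs = jax.jit(apply_fun)(params, x)  # one row of log-probabilities per example
```

Parameter-free layers such as stax.Relu contribute an empty tuple to params, so the list stays aligned with the layers passed to serial.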

jax.example_libraries.stax.shape_dependent(make_layer)

Combinator to delay construction of a layer's (init_fun, apply_fun) pair until input shapes are known.

Parameters:

make_layer – a one-argument function that takes an input shape as an argument (a tuple of positive integers) and returns an (init_fun, apply_fun) pair.

Returns:

A new layer, meaning an (init_fun, apply_fun) pair, representing the same layer as returned by make_layer but with its construction delayed until input shapes are known.
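A sketch using a hypothetical make_layer that builds a Dense layer as wide as its input's last dimension. In the stax source, make_layer is called with the init-time shape inside init_fun and again with inputs.shape inside apply_fun, so it should produce a layer whose parameters fit both calls.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

def make_layer(input_shape):
    # Hypothetical helper: a Dense layer whose width matches the input's last dimension.
    return stax.Dense(input_shape[-1])

init_fun, apply_fun = stax.shape_dependent(make_layer)

out_shape, (W, b) = init_fun(jax.random.PRNGKey(0), (2, 6))
y = apply_fun((W, b), jnp.ones((2, 6)))  # make_layer is rebuilt from inputs.shape here
```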

