

jax.lax.stop_gradient

jax.lax.stop_gradient(x)

Stops gradient computation.

Operationally, stop_gradient is the identity function; that is, it returns argument x unchanged. However, stop_gradient prevents the flow of gradients during forward- or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them. For some discussion of where this is useful, refer to Stopping gradients.

Parameters:

x (T) – array or pytree of arrays

Returns:

The input value is returned unchanged, but within autodiff it will be treated as a constant.

Return type:

T
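As the parameter type indicates, x may be an arbitrary pytree of arrays, in which case stop_gradient is applied to every array leaf. A minimal sketch, using a hypothetical parameter dict:

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.array(2.0), "b": jnp.array(1.0)}

def loss(p):
    # stop_gradient maps over the whole pytree, freezing every leaf.
    p_const = jax.lax.stop_gradient(p)
    return p["w"] * p_const["w"] + p_const["b"]

grads = jax.grad(loss)(params)
# Only the differentiable factor p["w"] contributes:
# grads["w"] == 2.0, grads["b"] == 0.0
```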

Examples

Consider a simple function that returns the square of the input value:

>>> import jax
>>> import jax.numpy as jnp
>>> def f1(x):
...     return x ** 2
>>> x = jnp.float32(3.0)
>>> f1(x)
Array(9.0, dtype=float32)
>>> jax.grad(f1)(x)
Array(6.0, dtype=float32)

The same function with stop_gradient around x will be equivalent under normal evaluation, but return a zero gradient because x is effectively treated as a constant:

>>> def f2(x):
...     return jax.lax.stop_gradient(x) ** 2
>>> f2(x)
Array(9.0, dtype=float32)
>>> jax.grad(f2)(x)
Array(0.0, dtype=float32)
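Because stop_gradient blocks gradients through all nested levels of differentiation, higher-order derivatives of a wrapped value are zero as well. A minimal sketch:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jax.lax.stop_gradient(x) ** 3

x = jnp.float32(2.0)

# First and second derivatives are both zero: stop_gradient stops
# gradients for every nested gradient computation.
first = jax.grad(f)(x)            # → 0.0
second = jax.grad(jax.grad(f))(x) # → 0.0
```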

This is used in a number of places within the JAX codebase; for example, jax.nn.softmax() internally normalizes the input by its maximum value, and this maximum value is wrapped in stop_gradient for efficiency. Refer to Stopping gradients for more discussion of the applicability of stop_gradient.
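The softmax pattern mentioned above can be sketched as follows. This is an illustrative re-implementation, not JAX's actual source: subtracting the max gives numerical stability, and since softmax is shift-invariant, the max term contributes nothing to the gradient, so wrapping it in stop_gradient saves backward-pass work without changing the result.

```python
import jax
import jax.numpy as jnp

def softmax_sketch(x):
    # The max is only a stabilizing shift; its gradient contribution
    # cancels, so we cut it out of the autodiff graph entirely.
    x_max = jax.lax.stop_gradient(jnp.max(x, axis=-1, keepdims=True))
    unnormalized = jnp.exp(x - x_max)
    return unnormalized / jnp.sum(unnormalized, axis=-1, keepdims=True)

x = jnp.array([1.0, 2.0, 3.0])
# Agrees with the library implementation up to floating-point tolerance.
matches = jnp.allclose(softmax_sketch(x), jax.nn.softmax(x))
```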

