jax.lax.stop_gradient
- jax.lax.stop_gradient(x)
Stops gradient computation.
Operationally, stop_gradient is the identity function; that is, it returns the argument x unchanged. However, stop_gradient prevents the flow of gradients during forward- or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them. For some discussion of where this is useful, refer to Stopping gradients.
- Parameters:
x (T) – array or pytree of arrays
- Returns:
input value is returned unchanged, but within autodiff will be treated as a constant.
- Return type:
T
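The identity-but-constant behavior described above, including the stopping of nested gradient computations, can be sketched as follows (a minimal illustration, not part of the official documentation):

```python
import jax
import jax.numpy as jnp

def f(x):
    # stop_gradient blocks differentiation through x at every nesting level
    return jax.lax.stop_gradient(x) ** 3

x = jnp.float32(2.0)
assert f(x) == 8.0                      # identity under normal evaluation
assert jax.grad(f)(x) == 0.0            # first-order gradient is zero
assert jax.grad(jax.grad(f))(x) == 0.0  # nested gradients are also stopped
```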
Examples
Consider a simple function that returns the square of the input value:
>>> def f1(x):
...   return x ** 2
>>> x = jnp.float32(3.0)
>>> f1(x)
Array(9.0, dtype=float32)
>>> jax.grad(f1)(x)
Array(6.0, dtype=float32)
The same function with stop_gradient around x will be equivalent under normal evaluation, but return a zero gradient because x is effectively treated as a constant:
>>> def f2(x):
...   return jax.lax.stop_gradient(x) ** 2
>>> f2(x)
Array(9.0, dtype=float32)
>>> jax.grad(f2)(x)
Array(0.0, dtype=float32)
This is used in a number of places within the JAX codebase; for example, jax.nn.softmax() internally normalizes the input by its maximum value, and this maximum value is wrapped in stop_gradient for efficiency. Refer to Stopping gradients for more discussion of the applicability of stop_gradient.
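The softmax pattern mentioned above can be sketched as follows. This is a simplified illustration, not the actual jax.nn.softmax() implementation: subtracting the maximum improves numerical stability without changing the result, and since that shift cancels in the softmax gradient anyway, wrapping it in stop_gradient avoids differentiating through the max computation:

```python
import jax
import jax.numpy as jnp

def softmax_sketch(x):
    # Subtract the max for numerical stability; stop_gradient means autodiff
    # treats the max as a constant, since the shift cancels in the gradient.
    x_max = jax.lax.stop_gradient(jnp.max(x, axis=-1, keepdims=True))
    shifted = jnp.exp(x - x_max)
    return shifted / jnp.sum(shifted, axis=-1, keepdims=True)

x = jnp.array([1.0, 2.0, 3.0])
print(jnp.allclose(softmax_sketch(x), jax.nn.softmax(x)))  # True
```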
