jax.extend.linear_util module
| WrappedFun | Represents a function f to which transforms are to be applied. |
| cache | Memoization decorator for functions taking a WrappedFun as first argument. |
| transformation | Adds one more transformation to a WrappedFun. |
| transformation_with_aux | Adds one more transformation with auxiliary output to a WrappedFun. |
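
The entries above describe a function paired with a stack of pending transformations, plus a memoizing decorator keyed on that pairing. The plain-Python sketch below illustrates the idea only. It does not use JAX and is not how `jax.extend.linear_util` is implemented; `MiniWrappedFun`, `add_transform`, `mini_cache`, and `double_output` are hypothetical names invented for this example, and JAX's own mechanism differs in detail.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple

# Hypothetical illustration only: none of these names exist in JAX.
Transform = Callable[[Callable[..., Any]], Callable[..., Any]]


@dataclass(frozen=True)
class MiniWrappedFun:
    """A function f plus the transforms that are still to be applied to it."""
    f: Callable[..., Any]
    transforms: Tuple[Transform, ...] = ()

    def add_transform(self, t: Transform) -> "MiniWrappedFun":
        # Return a new object with one more transformation appended.
        return MiniWrappedFun(self.f, self.transforms + (t,))

    def call_wrapped(self, *args: Any, **kwargs: Any) -> Any:
        # Apply the transforms in the order they were added, then call.
        fun = self.f
        for t in self.transforms:
            fun = t(fun)
        return fun(*args, **kwargs)


def mini_cache(call: Callable[..., Any]) -> Callable[..., Any]:
    """Memoize call(fun, *args), keyed on the wrapped function and its args."""
    memo = {}

    def memoized(fun: MiniWrappedFun, *args: Any) -> Any:
        key = (fun, args)  # frozen dataclass instances are hashable
        if key not in memo:
            memo[key] = call(fun, *args)
        return memo[key]

    return memoized


def double_output(fun: Callable[..., Any]) -> Callable[..., Any]:
    # A toy transformation: scale the wrapped function's result by two.
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        return 2 * fun(*args, **kwargs)
    return wrapped


def add_one(x):
    return x + 1


calls = []  # records how often the body of `run` actually executes


@mini_cache
def run(fun: MiniWrappedFun, x: int) -> int:
    calls.append(x)
    return fun.call_wrapped(x)


wrapped = MiniWrappedFun(add_one).add_transform(double_output)
first = run(wrapped, 3)   # computes 2 * (3 + 1) = 8
second = run(wrapped, 3)  # served from the memo; `run` body is not re-executed
```

Making the wrapper immutable and hashable is what lets the memo use it as a dictionary key: two wrappers built from the same function and the same transforms compare equal and share a cache entry.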
