jax.extend.linear_util module
| Represents a function `f` to which `transforms` are to be applied. |
| Memoization decorator for functions taking a WrappedFun as first argument. |
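The two entries above describe a wrapper that pairs a function with the transforms still to be applied to it, and a decorator that memoizes calls whose first argument is such a wrapper. The following is a minimal, hypothetical pure-Python sketch of those two ideas only. `WrappedFunSketch`, `cache_sketch`, and the "a transform takes a function and returns a function" convention are illustrative assumptions. They are not the actual `jax.extend.linear_util` API, whose internals differ.

```python
from typing import Any, Callable, Dict, Tuple

Transform = Callable[[Callable[..., Any]], Callable[..., Any]]


class WrappedFunSketch:
    """Hypothetical stand-in: a function f plus transforms to apply lazily."""

    def __init__(self, f: Callable[..., Any], transforms: Tuple[Transform, ...] = ()):
        self.f = f
        self.transforms = transforms

    def wrap(self, transform: Transform) -> "WrappedFunSketch":
        # Return a new wrapper; the original is left unchanged.
        return WrappedFunSketch(self.f, self.transforms + (transform,))

    def call_wrapped(self, *args: Any) -> Any:
        # Apply transforms in the order they were added; the first one
        # added ends up innermost around f.
        g = self.f
        for t in self.transforms:
            g = t(g)
        return g(*args)

    # Structural equality/hash so that separately built but equivalent
    # wrappers share a cache entry.
    def __eq__(self, other: object) -> bool:
        return (isinstance(other, WrappedFunSketch)
                and self.f is other.f
                and self.transforms == other.transforms)

    def __hash__(self) -> int:
        return hash((self.f, self.transforms))


def cache_sketch(call: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical memoizer for functions taking a wrapper as first argument."""
    memo: Dict[Any, Any] = {}

    def memoized(fun: WrappedFunSketch, *args: Any) -> Any:
        key = (fun, args)  # args must be hashable in this sketch
        if key not in memo:
            memo[key] = call(fun, *args)
        return memo[key]

    return memoized


# Usage example.
calls = []


@cache_sketch
def trace(fun: WrappedFunSketch, x: int) -> int:
    calls.append(x)  # records only calls that miss the cache
    return fun.call_wrapped(x)


def double(g: Callable[[int], int]) -> Callable[[int], int]:
    return lambda x: 2 * g(x)


def add_one(x: int) -> int:
    return x + 1


wf = WrappedFunSketch(add_one).wrap(double)
print(trace(wf, 3))   # 8 (computed)
print(trace(wf, 3))   # 8 (cached)
wf2 = WrappedFunSketch(add_one).wrap(double)
print(trace(wf2, 3))  # 8 (cached, because wf2 == wf)
print(len(calls))     # 1
```

Keying the cache on the wrapper's function and transform tuple, rather than on object identity, is the design choice that lets two independently constructed but equivalent wrappers reuse one result; the cost is that every extra argument must be hashable.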