jax.jit
- jax.jit(fun: Callable, /, *, in_shardings: Any = UnspecifiedValue, out_shardings: Any = UnspecifiedValue, static_argnums: int | Sequence[int] | None = None, static_argnames: str | Iterable[str] | None = None, donate_argnums: int | Sequence[int] | None = None, donate_argnames: str | Iterable[str] | None = None, keep_unused: bool = False, device: xc.Device | None = None, backend: str | None = None, inline: bool = False, abstracted_axes: Any | None = None, compiler_options: dict[str, Any] | None = None) → pjit.JitWrapped
- jax.jit(*, in_shardings: Any = UnspecifiedValue, out_shardings: Any = UnspecifiedValue, static_argnums: int | Sequence[int] | None = None, static_argnames: str | Iterable[str] | None = None, donate_argnums: int | Sequence[int] | None = None, donate_argnames: str | Iterable[str] | None = None, keep_unused: bool = False, device: xc.Device | None = None, backend: str | None = None, inline: bool = False, abstracted_axes: Any | None = None, compiler_options: dict[str, Any] | None = None) → Callable[[Callable], pjit.JitWrapped]
Sets up fun for just-in-time compilation with XLA.

- Parameters:
  - fun – Function to be jitted. fun should be a pure function. The arguments and return value of fun should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be any hashable type. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined. JAX keeps a weak reference to fun for use as a compilation cache key, so the object fun must be weakly-referenceable. Starting in JAX v0.8.1, when fun is omitted, the return value will be a partially-evaluated function to allow the decorator factory pattern (see Examples below).
  - in_shardings – optional, a Sharding or pytree with Sharding leaves and structure that is a tree prefix of the positional arguments tuple to fun. If provided, the positional arguments passed to fun must have shardings that are compatible with in_shardings or an error is raised, and the compiled computation has input shardings corresponding to in_shardings. If not provided, the compiled computation's input shardings are inferred from argument shardings.
  - out_shardings – optional, a Sharding or pytree with Sharding leaves and structure that is a tree prefix of the output of fun. If provided, it has the same effect as applying jax.lax.with_sharding_constraint() to the output of fun.
  - static_argnums – optional, an int or collection of ints that specify which positional arguments to treat as static (trace- and compile-time constant).
    Static arguments should be hashable (both __hash__ and __eq__ must be implemented) and immutable. Otherwise, they can be arbitrary Python objects. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not array-like or containers thereof must be marked as static.

    If neither static_argnums nor static_argnames is provided, no arguments are treated as static. If static_argnums is not provided but static_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to static_argnames (or vice versa). If both static_argnums and static_argnames are provided, inspect.signature is not used, and only actual parameters listed in either static_argnums or static_argnames will be treated as static.
  - static_argnames – optional, a string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on static_argnums for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.
  - donate_argnums – optional, an int or collection of ints specifying which positional argument buffers can be overwritten by the computation and marked deleted in the caller. It is safe to donate argument buffers if you no longer need them once the computation has started. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. By default, no argument buffers are donated.

    If neither donate_argnums nor donate_argnames is provided, no arguments are donated. If donate_argnums is not provided but donate_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual parameters listed in either donate_argnums or donate_argnames will be donated.

    For more details on buffer donation see the FAQ.
  - donate_argnames – optional, a string or collection of strings specifying which named arguments are donated to the computation. See the comment on donate_argnums for details. If not provided but donate_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.
  - keep_unused – optional boolean. If False (the default), arguments that JAX determines to be unused by fun may be dropped from resulting compiled XLA executables. Such arguments will be neither transferred to the device nor provided to the underlying executable. If True, unused arguments will not be pruned.
  - device – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA's DeviceAssignment logic and is usually to use jax.devices()[0].
  - backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
  - inline – Optional boolean. Specify whether this function should be inlined into enclosing jaxprs. Default False.
- Returns:
  A wrapped version of fun, set up for just-in-time compilation.
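To make the static-argument rules concrete, here is a small sketch. The function power and the list traced_with are hypothetical, introduced only for illustration. Appending to a Python list inside the function is a side effect that runs only while JAX traces, so the list records each retrace: a repeated call with the same static value and the same argument shapes reuses the cached executable, while a new static value triggers recompilation.

```python
import jax
import jax.numpy as jnp

# Hypothetical bookkeeping: records the static value of n on every trace.
traced_with = []

def power(x, n):
    traced_with.append(n)  # Python side effect, runs only during tracing
    for _ in range(n):     # n must be a concrete Python int, hence static
        x = x * x
    return x

# Mark the second positional argument static by index.
power_by_num = jax.jit(power, static_argnums=1)

x = jnp.arange(4.0)
a = power_by_num(x, 2)  # first call with n=2: traces and compiles
b = power_by_num(x, 2)  # same static value, shape and dtype: cache hit
c = power_by_num(x, 3)  # new static value: retraces and recompiles
print(traced_with)  # [2, 3]

# Equivalent marking by name; inspect.signature maps "n" to position 1,
# so n may be passed either positionally or by keyword.
power_by_name = jax.jit(power, static_argnames="n")
d = power_by_name(x, n=3)
```

Marking n static is what allows the Python for loop to run: with a traced n, range(n) would fail because the loop count would not be known at trace time.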
Examples
In the following example, selu can be compiled into a single fused kernel by XLA:

```python
>>> import jax
>>>
>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...     return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>>
>>> key = jax.random.key(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x))
[-0.54485  0.27744 -0.29255 -0.91421 -0.62452 -0.24748
 -0.85743 -0.78232  0.76827  0.59566 ]
```
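One way to look at what XLA actually receives is JAX's ahead-of-time API: a jitted function's lower() method traces and lowers it for concrete argument shapes, and compile() on the result produces an executable that can be called directly. The sketch below assumes a reasonably recent JAX release; the exact text of the lowered module varies by version and backend, so it is printed only in part.

```python
import jax
import jax.numpy as jnp

@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jax.random.normal(jax.random.key(0), (10,))

lowered = selu.lower(x)         # trace and lower for this shape and dtype
print(lowered.as_text()[:300])  # start of the module handed to XLA
compiled = lowered.compile()    # XLA compilation, done ahead of time
y = compiled(x)                 # run the compiled executable
```

compiled.as_text() returns the backend's optimized HLO, which is where XLA's fusion decisions for the where/exp expression can be inspected.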
Starting in JAX v0.8.1, jit() supports the decorator factory pattern for specifying optional keywords:

```python
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> @jax.jit(static_argnames=['n'])
... def g(x, n):
...     for i in range(n):
...         x = x ** 2
...     return x
>>>
>>> g(jnp.arange(4), 3)
Array([   0,    1,  256, 6561], dtype=int32)
```
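The same keywords are accepted by the direct-call form jax.jit(fun, ...). Below is a minimal sketch of donate_argnums; sgd_step is a hypothetical update function chosen because its output has the same shape and dtype as the donated input, which is what lets XLA recycle that buffer. Whether the donation is actually honored depends on the backend; where a donated buffer cannot be used, JAX may emit a warning and the computation still runs.

```python
import jax
import jax.numpy as jnp

def sgd_step(params, grads, lr):  # hypothetical update function
    return params - lr * grads

# Donate positional argument 0 (params): its buffer may be reused for the
# output and is then marked deleted in the caller.
step = jax.jit(sgd_step, donate_argnums=0)

params = jnp.ones((1024,))
grads = jnp.full((1024,), 0.5)
new_params = step(params, grads, 0.1)

# Do not read the donated array again. Rebinding the name guarantees that
# only the fresh result is used from here on.
params = new_params
```

Rebinding the same name, as in a training loop that writes params = step(params, ...), is the usual way to make sure a donated array is never touched after the call.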
For compatibility with older JAX versions, a common pattern is to use functools.partial():

```python
>>> import jax
>>> import jax.numpy as jnp
>>> from functools import partial
>>>
>>> @partial(jax.jit, static_argnames=['n'])
... def g(x, n):
...     for i in range(n):
...         x = x ** 2
...     return x
>>>
>>> g(jnp.arange(4), 3)
Array([   0,    1,  256, 6561], dtype=int32)
```
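The functools.partial pattern works the same way for in_shardings and out_shardings. The sketch below builds a one-axis Mesh over the available devices (on a single-device machine this is a one-device mesh, so it runs anywhere); the axis name "x" and the function total are arbitrary choices for illustration. A single NamedSharding passed as in_shardings is a tree prefix of the arguments tuple, so it applies to every positional argument.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis named "x" spanning all local devices.
mesh = Mesh(np.array(jax.devices()), ("x",))
row_sharding = NamedSharding(mesh, P("x"))  # split axis 0 across "x"
replicated = NamedSharding(mesh, P())       # full copy on every device

@partial(jax.jit, in_shardings=row_sharding, out_shardings=replicated)
def total(x):
    return jnp.sum(x, axis=0)

n = len(jax.devices())
# 2 * n rows, so axis 0 divides evenly across the mesh axis.
x = jax.device_put(jnp.arange(8.0 * n).reshape(2 * n, 4), row_sharding)
y = total(x)
print(y.sharding)
```

Because out_shardings requests a replicated layout, the compiled computation is expected to leave the full (4,) result on every device, as if jax.lax.with_sharding_constraint(..., replicated) had been applied to the output.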
