jax.make_jaxpr
- jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[False] = False, abstracted_axes: Any | None = None) → Callable[..., core.ClosedJaxpr]
- jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[True] = False, abstracted_axes: Any | None = None) → Callable[..., tuple[core.ClosedJaxpr, Any]]
Create a function that returns the jaxpr of fun given example args.

- Parameters:
  - fun – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.
  - static_argnums – See the jax.jit() docstring.
  - axis_env – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap().
  - return_shape – Optional boolean, defaults to False. If True, the wrapped function returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree with the same structure as the output of fun and where the leaves are objects with shape and dtype attributes representing the corresponding types of the output leaves.
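A minimal sketch of the parameters in use, assuming a recent JAX release with default 32-bit types. The functions scale, loss and center are hypothetical illustrations, not part of JAX. A static argument is baked into the trace as a constant and so does not appear among the jaxpr inputs; container arguments are flattened into one input per leaf; and axis_env lets a collective such as jax.lax.psum be traced without an enclosing jax.pmap().

```python
import jax
import jax.numpy as jnp


def scale(x, n):
    # Hypothetical helper: multiply an array by a Python integer.
    return x * n


# n is static, so it becomes a constant in the trace; only x is an input.
static_jaxpr = jax.make_jaxpr(scale, static_argnums=1)(jnp.ones(3), 2)
# Without static_argnums, n is traced as well and becomes a second input.
dynamic_jaxpr = jax.make_jaxpr(scale)(jnp.ones(3), 2)
print(len(static_jaxpr.jaxpr.invars), len(dynamic_jaxpr.jaxpr.invars))


def loss(params, x):
    # Hypothetical function taking a dict of arrays plus an array.
    return jnp.sum(params["w"] * x + params["b"])


# The dict is flattened into its leaves, so the jaxpr has three inputs.
params = {"w": jnp.ones(3), "b": jnp.zeros(3)}
loss_jaxpr = jax.make_jaxpr(loss)(params, jnp.ones(3))
print([aval.shape for aval in loss_jaxpr.in_avals])


def center(x):
    # The collective names axis "i", which exists here only through axis_env.
    return x - jax.lax.psum(x, "i")


# axis_env supplies the (name, size) pair that jax.pmap would normally set up.
center_jaxpr = jax.make_jaxpr(center, axis_env=[("i", 4)])(2.0)
print(center_jaxpr)
```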
- Returns:
  A wrapped version of fun that, when applied to example arguments, returns a ClosedJaxpr representation of fun on those arguments. If the argument return_shape is True, then the returned function instead returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of fun.
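A sketch of the return_shape=True form, assuming a recent JAX release with default 32-bit types. The function stats is hypothetical, chosen because its output is a dict of arrays with different shapes: the second element of the returned pair mirrors that dict, while the ClosedJaxpr itself exposes the abstract input and output types (in_avals, out_avals) and its list of equations.

```python
import jax
import jax.numpy as jnp


def stats(x):
    # Hypothetical function returning a pytree (a dict) with two outputs.
    return {"mean": jnp.mean(x, axis=0), "total": jnp.sum(x)}


closed_jaxpr, out_shape = jax.make_jaxpr(stats, return_shape=True)(jnp.ones((4, 3)))

# out_shape has the same dict structure as the output of stats; each leaf
# carries shape and dtype attributes.
for name, leaf in out_shape.items():
    print(name, leaf.shape, leaf.dtype)

# The ClosedJaxpr records abstract input/output types and its equations.
print(closed_jaxpr.in_avals)
print(closed_jaxpr.out_avals)
print([eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns])
```

The exact primitive names printed on the last line may differ between JAX versions, since some jax.numpy functions are themselves traced through jit.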
A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr() adapts a function to return its jaxpr, which we can inspect to understand what JAX is doing internally. The jaxpr returned is a trace of fun abstracted to ShapedArray level. Other levels of abstraction exist internally.

We do not describe the semantics of the jaxpr language in detail here, but instead give a few examples.

>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> print(f(3.0))
-0.83602
>>> jax.make_jaxpr(f)(3.0)
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda ; a:f32[]. let
    b:f32[] = cos a
    c:f32[] = sin a
    _:f32[] = sin b
    d:f32[] = cos b
    e:f32[] = mul 1.0:f32[] d
    f:f32[] = neg e
    g:f32[] = mul f c
  in (g,) }
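Because the trace is abstracted to ShapedArray level, the jaxpr depends only on the shapes and dtypes of the example arguments, not on their values. A small check of this, assuming a recent JAX release with default 32-bit types and reusing the f from the examples above; the function g is a hypothetical illustration of what the abstraction rules out.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(jnp.cos(x))


# Two scalars with different values but the same type trace to the same jaxpr.
text_a = str(jax.make_jaxpr(f)(3.0))
text_b = str(jax.make_jaxpr(f)(-1.5))
print(text_a == text_b)

# A length-3 vector changes the abstract type, so the jaxpr is typed f32[3].
vec_jaxpr = jax.make_jaxpr(f)(jnp.ones(3))
print(vec_jaxpr.in_avals)


def g(x):
    # Python control flow needs a concrete value, which an abstract trace lacks.
    return x if x > 0 else -x


try:
    jax.make_jaxpr(g)(1.0)
    branch_traced = True
except jax.errors.ConcretizationTypeError:
    branch_traced = False
print(branch_traced)
```

Value-dependent branching therefore has to be expressed with primitives such as jax.lax.cond or jax.numpy.where, which appear in the jaxpr as ordinary equations.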
