jax.make_jaxpr


jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[False] = False, abstracted_axes: Any | None = None) -> Callable[..., core.ClosedJaxpr]
jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[True] = False, abstracted_axes: Any | None = None) -> Callable[..., tuple[core.ClosedJaxpr, Any]]

Create a function that returns the jaxpr of fun given example args.

Parameters:
  • fun – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.

  • static_argnums – See the jax.jit() docstring.

  • axis_env – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap().

  • return_shape – Optional boolean, defaults to False. If True, the wrapped function returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree with the same structure as the output of fun and where the leaves are objects with shape and dtype attributes representing the corresponding types of the output leaves.
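The static_argnums and axis_env parameters can be illustrated with a minimal sketch. The functions repeat_double and normalize, and the axis name "i" with size 4, are hypothetical names chosen for this illustration and do not come from the JAX documentation.

```python
import jax

# Hypothetical example function: n drives a Python loop, so it has to be
# a concrete Python value during tracing and is marked static.
def repeat_double(x, n):
    for _ in range(n):
        x = x * 2.0
    return x

# Argument 1 (n) is static: it is not an input of the resulting jaxpr,
# and the Python loop is unrolled during tracing.
closed_scale = jax.make_jaxpr(repeat_double, static_argnums=1)(1.0, 3)
print(closed_scale)
num_mul = sum(eqn.primitive.name == "mul" for eqn in closed_scale.jaxpr.eqns)

# Hypothetical example function using a collective over a named axis.
def normalize(x):
    return x / jax.lax.psum(x, "i")

# axis_env declares the axis name "i" with size 4, so the collective can be
# traced without wrapping the function in jax.pmap().
closed_psum = jax.make_jaxpr(normalize, axis_env=[("i", 4)])(1.0)
print(closed_psum)
```

In this sketch the first jaxpr should take a single input, because n is baked in at trace time, and the second should contain a psum equation over the axis "i".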

Returns:

A wrapped version of fun that, when applied to example arguments, returns a ClosedJaxpr representation of fun on those arguments. If the argument return_shape is True, then the returned function instead returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output of fun.
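As a rough sketch of the return_shape=True behaviour described above, the example below traces a function that returns a dict. The function summarize and its output keys are hypothetical names made up for this illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function returning a dict (a standard Python container).
def summarize(x):
    return {"sum": jnp.sum(x), "doubled": 2 * x}

x = jnp.ones((2, 3), dtype=jnp.float32)

# With return_shape=True the wrapped function returns a pair:
# the ClosedJaxpr and a pytree mirroring the structure of the output.
closed, out_shape = jax.make_jaxpr(summarize, return_shape=True)(x)

print(closed)
# Each leaf carries shape and dtype attributes, but no array values.
print(out_shape["sum"].shape, out_shape["sum"].dtype)
print(out_shape["doubled"].shape, out_shape["doubled"].dtype)
```

The second element has the same dict structure as the output of summarize, which makes it possible to learn output shapes and dtypes without running the computation.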

A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr() adapts a function to return its jaxpr, which we can inspect to understand what JAX is doing internally. The jaxpr returned is a trace of fun abstracted to ShapedArray level. Other levels of abstraction exist internally.

We do not describe the semantics of the jaxpr language in detail here, but instead give a few examples.

>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> print(f(3.0))
-0.83602
>>> jax.make_jaxpr(f)(3.0)
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda ; a:f32[]. let
    b:f32[] = cos a
    c:f32[] = sin a
    _:f32[] = sin b
    d:f32[] = cos b
    e:f32[] = mul 1.0:f32[] d
    f:f32[] = neg e
    g:f32[] = mul f c
  in (g,) }
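Because the trace is abstracted to ShapedArray level, the jaxpr should depend on the shapes and dtypes of the example arguments rather than on their values. The sketch below is one way to check this and to walk the equations of the result programmatically. It assumes the ClosedJaxpr attributes jaxpr.eqns and in_avals, and the variable names are illustrative.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(jnp.cos(x))

closed = jax.make_jaxpr(f)(3.0)

# Each equation in the jaxpr records the primitive that was applied.
primitives = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(primitives)

# The abstract input value records only shape and dtype (a scalar here).
input_shape = closed.in_avals[0].shape
print(closed.in_avals)

# Two scalar floats with different values give the same printed jaxpr.
same_for_other_value = str(closed) == str(jax.make_jaxpr(f)(-7.5))

# An argument with a different shape gives a jaxpr with different types.
closed_vec = jax.make_jaxpr(f)(jnp.zeros(4, dtype=jnp.float32))
vec_shape = closed_vec.in_avals[0].shape
```

Walking closed.jaxpr.eqns in this way is useful when the printed form is too long to read, for example to count how often a given primitive appears.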