JAX Internals: primitives
Introduction to JAX primitives
A JAX primitive is the basic computational unit of a JAX program. This document explains the interface that a JAX primitive must support to allow JAX to perform all its transformations (this is not a how-to guide).
For example, the multiply-add operation can be implemented in terms of the low-level jax.lax.* primitives (which are like XLA operator wrappers) or jax.extend.core.Primitive("multiply_add"), as demonstrated further below.
JAX can take sequences of such primitive operations and transform them via its composable transformations of Python functions, such as jax.jit(), jax.grad() and jax.vmap(). These transformations operate on JAX-traceable functions. When such a Python function is executed, the only operations it applies to the data are either:
Inspections of data attributes: Data information, such as shape or type; or
JAX primitives: These are the JAX special operations covered in this tutorial.
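You can observe this with jax.make_jaxpr, which traces a function using abstract arguments and records each primitive applied to the data. The sketch below is a minimal example, and the function name multiply_add is illustrative. The shape check inspects attributes only, so no primitive appears for it in the trace.

```python
import jax
import jax.numpy as jnp
from jax import lax


def multiply_add(x, y, z):
  # Inspecting attributes such as shape is allowed: it does not touch values.
  assert x.shape == y.shape == z.shape
  # The data itself is only touched by JAX primitives: lax.mul and lax.add.
  return lax.add(lax.mul(x, y), z)


x = jnp.ones(3, dtype=jnp.float32)
closed_jaxpr = jax.make_jaxpr(multiply_add)(x, x, x)
print(closed_jaxpr)

# Names of the primitives recorded in the trace, in order.
prim_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(prim_names)  # ['mul', 'add']
```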
JAX primitives know how to operate on both concrete data values and abstract JAX values. A JAX-traceable function can be invoked by JAX with abstract arguments. For example, the JAX abstract value ShapedArray(float32[2,2]) captures the type and the shape of values, but not the concrete data values.
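The public entry point jax.eval_shape shows this directly. It runs a function on an abstract argument that carries only a shape and dtype (here a jax.ShapeDtypeStruct), and it returns the abstract result without computing any values. The function f below is an arbitrary example.

```python
import jax
import jax.numpy as jnp


def f(x):
  return jnp.sin(x) @ x


# An abstract argument: a shape and dtype, with no data values.
abstract_x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
out = jax.eval_shape(f, abstract_x)
print(out.shape, out.dtype)
```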
The JAX-transformed functions must themselves be JAX-traceable functions, so that these transformations can be composed, as in jax.jit(jax.jacfwd(jax.grad(f))).
JAX provides pre-defined primitives corresponding to most XLA operations, including add, matmul, sin, cos, and indexing.
In addition, JAX offers an implementation of NumPy functions in terms of JAX primitives. This means that Python programs using JAX’s implementation of NumPy are JAX-traceable and, therefore, transformable. Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.
Furthermore, the set of JAX primitives is extensible, so instead of reimplementing a function in terms of pre-defined JAX primitives, you can define a new primitive that encapsulates the behavior of the function.
Consider the following example: you want to add support to JAX for a multiply-add function with three arguments, defined mathematically as multiply_add(x, y, z) = x * y + z. This function operates on three identically shaped tensors of floating-point values and performs the operations pointwise. There are two ways to do this, covered in the next two sections: by using existing JAX primitives, or by defining a new JAX primitive.
Using existing JAX primitives
The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other functions that are themselves written using JAX primitives, for example, those defined in the jax.lax module:
from jax._src.lax import lax
from jax._src import api


def multiply_add_lax(x, y, z):
  """Implementation of multiply-add using the `jax.lax` primitives."""
  return lax.add(lax.mul(x, y), z)


def square_add_lax(a, b):
  """A square-add function using the newly defined multiply-add."""
  return multiply_add_lax(a, a, b)


print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
square_add_lax = 14.0
grad(square_add_lax) = 4.0
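As a quick check of the gradient printed above, the sketch below differentiates with respect to both arguments. It uses the public jax.grad, which is the same function exposed above as api.grad. Analytically, d/da (a·a + b) = 2a = 4 at a = 2, and d/db (a·a + b) = 1.

```python
import jax
from jax import lax


def square_add(a, b):
  # square_add(a, b) = a * a + b, written with lax primitives only.
  return lax.add(lax.mul(a, a), b)


# argnums=(0, 1) returns one gradient per argument.
da, db = jax.grad(square_add, argnums=(0, 1))(2.0, 10.0)
print(da, db)  # 4.0 1.0
```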
To understand how JAX uses the primitives internally, add some helpers for tracing function calls:
#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0


def _trace(msg=None):
  """Print a message at current indentation."""
  if msg is not None:
    print("  " * _indentation + msg)


def _trace_indent(msg=None):
  """Print a message and then indent the rest."""
  global _indentation
  _trace(msg)
  _indentation = 1 + _indentation


def _trace_unindent(msg=None):
  """Unindent then print a message."""
  global _indentation
  _indentation = _indentation - 1
  _trace(msg)


def trace(name):
  """A decorator for functions to trace arguments and results."""

  def trace_func(func):  # pylint: disable=missing-docstring
    def pp(v):
      """Print certain values more succinctly"""
      vtype = str(type(v))
      if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
        return "<JaxComputationBuilder>"
      elif "jaxlib._jax_.XlaOp" in vtype:
        return "<XlaOp at 0x{:x}>".format(id(v))
      elif ("partial_eval.JaxprTracer" in vtype or
            "batching.BatchTracer" in vtype or
            "ad.JVPTracer" in vtype):
        return "Traced<{}>".format(v.aval)
      elif isinstance(v, tuple):
        return "({})".format(pp_values(v))
      else:
        return str(v)

    def pp_values(args):
      return ", ".join([pp(arg) for arg in args])

    @functools.wraps(func)
    def func_wrapper(*args):
      _trace_indent("call {}({})".format(name, pp_values(args)))
      res = func(*args)
      _trace_unindent("|<- {} = {}".format(name, pp(res)))
      return res

    return func_wrapper

  return trace_func


class expectNotImplementedError(object):
  """Context manager to check for NotImplementedError."""

  def __enter__(self):
    pass

  def __exit__(self, type, value, tb):
    global _indentation
    _indentation = 0
    if type is NotImplementedError:
      print("\nFound expected exception:")
      traceback.print_exc(limit=3)
      return True
    elif type is None:  # No exception
      assert False, "Expected NotImplementedError"
    else:
      return False
Instead of using jax.lax primitives directly, you can use other functions that are already written in terms of those primitives, such as those in jax.numpy:
import jax.numpy as jnp
import numpy as np


@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
  return jnp.add(jnp.multiply(x, y), z)


@trace("square_add_numpy")
def square_add_numpy(a, b):
  return multiply_add_numpy(a, a, b)


print("\nNormal evaluation:")
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
Normal evaluation:
call square_add_numpy(2.0, 10.0)
  call multiply_add_numpy(2.0, 2.0, 10.0)
  |<- multiply_add_numpy = 14.0
|<- square_add_numpy = 14.0
square_add_numpy = 14.0

Gradient evaluation:
call square_add_numpy(GradTracer(primal=2.0, typeof(tangent)=f32[]), 10.0)
  call multiply_add_numpy(GradTracer(primal=2.0, typeof(tangent)=f32[]), GradTracer(primal=2.0, typeof(tangent)=f32[]), 10.0)
  |<- multiply_add_numpy = GradTracer(primal=14.0, typeof(tangent)=f32[])
|<- square_add_numpy = GradTracer(primal=14.0, typeof(tangent)=f32[])
grad(square_add_numpy) = 4.0
Notice that when computing jax.grad(), JAX invokes square_add_numpy and multiply_add_numpy with special tracer arguments (shown as GradTracer(...) in the output above) instead of plain numbers. Remember that a JAX-traceable function must be able to operate not only on concrete arguments but also on the special abstract arguments that JAX may use to abstract the function execution.
The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives.
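The following sketch shows a function that violates this property next to one that satisfies it. The relu functions are hypothetical examples. Under jax.jit, Python's `if` needs a concrete boolean, which an abstract tracer cannot provide, so JAX raises a TracerBoolConversionError (a subclass of jax.errors.ConcretizationTypeError). The version written with jnp.where uses only JAX primitives and traces without error.

```python
import jax
import jax.numpy as jnp


def relu_python_if(x):
  # Python control flow on a traced value: not JAX-traceable under jit.
  if x > 0:
    return x
  return 0.0 * x


def relu_traceable(x):
  # jnp.where is built from JAX primitives, so it traces fine.
  return jnp.where(x > 0, x, 0.0)


try:
  jax.jit(relu_python_if)(3.0)
  failed = False
except jax.errors.ConcretizationTypeError:
  failed = True

ok_value = jax.jit(relu_traceable)(3.0)
print(failed, ok_value)  # True 3.0
```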
Defining new JAX primitives
The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, to demonstrate how JAX primitives work, pretend that you want to add a new primitive to JAX for the multiply-add functionality.
from jax.extend import core

multiply_add_p = core.Primitive("multiply_add")  # Create the primitive


@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
  """The JAX-traceable way to use the JAX primitive.

  Note that the traced arguments must be passed as positional arguments
  to `bind`.
  """
  return multiply_add_p.bind(x, y, z)


@trace("square_add_prim")
def square_add_prim(a, b):
  """A square-add function implemented using the new JAX-primitive."""
  return multiply_add_prim(a, a, b)
If you try to call the newly defined functions, you’ll get an error, because you haven’t yet told JAX anything about the semantics of the new primitive.
with expectNotImplementedError():
  square_add_prim(2., 10.)
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1924/2844449444.py", line 2, in <module>
    square_add_prim(2., 10.)
  File "/tmp/ipykernel_1924/3854395562.py", line 48, in func_wrapper
    res = func(*args)
          ^^^^^^^^^^^
  File "/tmp/ipykernel_1924/3275395289.py", line 17, in square_add_prim
    return multiply_add_prim(a, a, b)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Evaluation rule for 'multiply_add' not implemented
Primal evaluation rules
@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
  """Concrete implementation of the primitive.

  This function does not need to be JAX traceable.

  Args:
    x, y, z: The concrete arguments of the primitive. Will only be called with
      concrete values.

  Returns:
    the concrete result of the primitive.
  """
  # Note: you can use the ordinary (non-JAX) NumPy, which is not JAX-traceable.
  return np.add(np.multiply(x, y), z)


# Now, register the primal implementation with JAX:
multiply_add_p.def_impl(multiply_add_impl)
<function __main__.multiply_add_impl(x, y, z)>
assert square_add_prim(2., 10.) == 14.
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)
    call multiply_add_impl(2.0, 2.0, 10.0)
    |<- multiply_add_impl = 14.0
  |<- multiply_add_prim = 14.0
|<- square_add_prim = 14.0
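The sketch below repeats the same two steps in a self-contained form: it creates a primitive and registers only an impl rule, without the tracing helpers. The names ma_p and ma_impl are hypothetical. It shows that untransformed calls to bind go straight to the impl rule, for Python scalars as well as NumPy arrays.

```python
import numpy as np
from jax.extend import core

# Hypothetical stand-alone copy of the primitive, with only an impl rule.
ma_p = core.Primitive("multiply_add")


def ma_impl(x, y, z):
  # Only ever called with concrete values, so plain NumPy is fine here.
  return np.add(np.multiply(x, y), z)


ma_p.def_impl(ma_impl)

# Eager (untransformed) calls dispatch directly to the impl rule.
scalar_out = ma_p.bind(2.0, 2.0, 10.0)
array_out = ma_p.bind(np.arange(3.0), np.arange(3.0), np.ones(3))
print(scalar_out)  # 14.0
print(array_out)   # [1. 2. 5.]
```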
What happens when you use jit
Now, if you try to use jit, you’ll get a NotImplementedError:
with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
call square_add_prim(JitTracer(~float32[]), JitTracer(~float32[]))
  call multiply_add_prim(JitTracer(~float32[]), JitTracer(~float32[]), JitTracer(~float32[]))

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1924/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 197, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/pjit.py", line 253, in cache_miss
    p, args_flat = _infer_params(fun, jit_info, args, kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented
Abstract evaluation rules
To JIT the function, and for other transformations as well, JAX first evaluates it abstractly using only the shape and type of the arguments. This abstract evaluation serves multiple purposes:
Gets the sequence of JAX primitives that are used in the computation. This sequence will be compiled.
Computes the shape and dtype of all arrays (inputs, intermediates, and outputs) in the computation.
For example, the abstraction of a vector with 3 elements may be ShapedArray(float32[3]), which records its shape and dtype but not its values.
from jax import core


@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
  """Abstract evaluation of the primitive.

  This function does not need to be JAX traceable. It will be invoked with
  abstractions of the actual arguments

  Args:
    xs, ys, zs: Abstractions of the arguments.

  Result:
    a ShapedArray for the result of the primitive.
  """
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape
  return core.ShapedArray(xs.shape, xs.dtype)


# Now, register the abstract evaluation with JAX:
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
<function __main__.multiply_add_abstract_eval(xs, ys, zs)>
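With an impl rule and an abstract rule, a primitive can already be traced, even though it cannot be compiled yet. The self-contained sketch below uses a hypothetical copy of the primitive (ma_p) and the same jax.core.ShapedArray used above. It shows jax.make_jaxpr recording the primitive and jax.eval_shape propagating shapes through the abstract rule. Neither step needs a lowering rule.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import core
from jax.extend import core as jex_core

ma_p = jex_core.Primitive("multiply_add")  # hypothetical stand-alone copy
ma_p.def_impl(lambda x, y, z: np.add(np.multiply(x, y), z))


def ma_abstract_eval(xs, ys, zs):
  # Only shapes and dtypes are available here, never concrete values.
  assert xs.shape == ys.shape == zs.shape
  return core.ShapedArray(xs.shape, xs.dtype)


ma_p.def_abstract_eval(ma_abstract_eval)


def square_add(a, b):
  return ma_p.bind(a, a, b)


# Tracing to a jaxpr only needs the abstract rule.
x = jnp.ones(3, dtype=jnp.float32)
closed_jaxpr = jax.make_jaxpr(square_add)(x, x)
print(closed_jaxpr)
prim_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]

# eval_shape propagates shapes and dtypes through the abstract rule.
spec = jax.ShapeDtypeStruct((3,), jnp.float32)
out = jax.eval_shape(square_add, spec, spec)
print(out.shape, out.dtype)  # (3,) float32
```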
If you try jit again, you can see how the abstract evaluation proceeds, but you’ll get another error, this time about the missing compilation (lowering) rule:
with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
call square_add_prim(JitTracer(~float32[]), JitTracer(~float32[]))
  call multiply_add_prim(JitTracer(~float32[]), JitTracer(~float32[]), JitTracer(~float32[]))
    call multiply_add_abstract_eval(~float32[], ~float32[], ~float32[])
    |<- multiply_add_abstract_eval = float32[]
  |<- multiply_add_prim = JitTracer(float32[])
|<- square_add_prim = JitTracer(float32[])

Found expected exception:
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_1924/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 197, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/pjit.py", line 255, in cache_miss
    executable, pgle_profiler, const_args) = _run_python_pjit(
                                             ^^^^^^^^^^^^^^^^^
NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu
XLA compilation rules
JAX compilation works by compiling each primitive into a graph of XLA operations.
This is the biggest hurdle to adding new functionality to JAX, because the set of XLA operations is limited, and JAX already has pre-defined primitives for most of them. However, XLA includes a CustomCall operation that can be used to encapsulate arbitrary functionality defined using C++.
from jax._src.lib.mlir.dialects import hlo


@trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
  """The compilation to XLA of the primitive.

  Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
  the results of the function.

  Does not need to be a JAX-traceable function.
  """
  return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]


# Now, register the lowering rule with JAX.
# For GPU, refer to the https://docs.jax.dev/en/latest/Custom_Operation_for_GPUs.html
from jax.interpreters import mlir

mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
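When a primitive can be expressed with existing JAX operations, a shorter alternative to writing HLO ops by hand is mlir.lower_fun. It lowers a JAX-traceable Python function by tracing it. The self-contained sketch below assumes jax.interpreters.mlir.lower_fun is available in your JAX version, and it uses a hypothetical copy of the primitive (ma_p). Because no platform is passed, the rule is registered for all platforms. Note that this alternative gives up the main reason for a custom lowering, which is to emit something that existing primitives cannot express.

```python
import numpy as np
import jax
from jax import core
from jax.extend import core as jex_core
from jax.interpreters import mlir

ma_p = jex_core.Primitive("multiply_add")  # hypothetical stand-alone copy
ma_p.def_impl(lambda x, y, z: np.add(np.multiply(x, y), z))
ma_p.def_abstract_eval(lambda xs, ys, zs: core.ShapedArray(xs.shape, xs.dtype))

# Lower by tracing a JAX-traceable function built from existing primitives.
# multiple_results=False because the primitive returns a single array.
mlir.register_lowering(
    ma_p, mlir.lower_fun(lambda x, y, z: x * y + z, multiple_results=False))


@jax.jit
def square_add(a, b):
  return ma_p.bind(a, a, b)


result = square_add(2.0, 10.0)
print(result)  # 14.0

# The StableHLO text that jit compiles for these argument types.
hlo_text = square_add.lower(2.0, 10.0).as_text()
print(hlo_text)
```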
Now jax.jit succeeds. Notice below that JAX first evaluates the function abstractly, which triggers the multiply_add_abstract_eval function, and then compiles the set of primitives it has encountered, including multiply_add. At this point JAX invokes multiply_add_lowering.
assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
call square_add_prim(JitTracer(~float32[]), JitTracer(~float32[]))
  call multiply_add_prim(JitTracer(~float32[]), JitTracer(~float32[]), JitTracer(~float32[]))
    call multiply_add_abstract_eval(~float32[], ~float32[], ~float32[])
    |<- multiply_add_abstract_eval = float32[]
  |<- multiply_add_prim = JitTracer(float32[])
|<- square_add_prim = JitTracer(float32[])
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7dd314818290>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7dd3148165c0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7dd314816560>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7dd327f7fbd0>, platforms=('cpu',), backend=<jaxlib._jax.Client object at 0x7dd32c287920>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(2), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7dd314838350>, all_default_mem_kind=True, lowering_cache={}, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_to_location_cache=<jaxlib.mlir._mlir_libs._jax_mlir_ext.TracebackToLocationCache object at 0x7dd3149cef70>, canonical_name_cache={}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False, hoist_constants_as_args=False)), name_stack=NameStack(stack=()), traceback=None, primitive=multiply_add, avals_in=(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)), avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7dd314838b90>, tokens_out=None, const_lowering={}, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True,
cur_abstract_mesh=AbstractMesh((), axis_types=()), remove_size_one_mesh_axis=False, xla_metadata=None), platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1), Value(<block argument> of type 'tensor<f32>' at index: 2))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7dd327fc5ef0>]

Below is another use of jit, where you compile only with respect to the first argument (static_argnums=1 marks the second argument as static). Notice that the second argument to square_add_prim is concrete: it appears as 10.0 in the trace. Still, by the time it reaches multiply_add_abstract_eval as the third argument, it has been abstracted to ~float32[] like the other two. The abstract evaluation rule therefore sees only shapes and dtypes.
assert api.jit(lambda x, y: square_add_prim(x, y), static_argnums=1)(2., 10.) == 14.
call square_add_prim(JitTracer(~float32[]), 10.0)
  call multiply_add_prim(JitTracer(~float32[]), JitTracer(~float32[]), 10.0)
    call multiply_add_abstract_eval(~float32[], ~float32[], ~float32[])
    |<- multiply_add_abstract_eval = float32[]
  |<- multiply_add_prim = JitTracer(float32[])
|<- square_add_prim = JitTracer(float32[])
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7dd314818bf0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7dd314817060>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7dd314817000>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7dd3374c3b10>, platforms=('cpu',), backend=<jaxlib._jax.Client object at 0x7dd32c287920>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(2), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7dd3148398e0>, all_default_mem_kind=True, lowering_cache={}, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_to_location_cache=<jaxlib.mlir._mlir_libs._jax_mlir_ext.TracebackToLocationCache object at 0x7dd3149cee50>, canonical_name_cache={}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False, hoist_constants_as_args=False)), name_stack=NameStack(stack=()), traceback=None, primitive=multiply_add, avals_in=(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)), avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7dd31483a150>, tokens_out=None, const_lowering={}, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types=()),
remove_size_one_mesh_axis=False, xla_metadata=None), platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1), Value(<block argument> of type 'tensor<f32>' at index: 2))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7dd314845470>]

Forward differentiation
JAX implements forward differentiation in the form of a Jacobian-vector product (JVP) (you can learn more about it in Forward- and reverse-mode autodiff in JAX).
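For multiply_add, the JVP follows from the product rule. Substituting the point used in the examples below gives the tangent value that is checked later. That point is square_add_prim(a, b) = multiply_add(a, a, b) at (a, b) = (2, 10), with all tangents equal to 1.

```latex
\begin{aligned}
\mathrm{multiply\_add}(x, y, z) &= x\,y + z \\
\mathrm{jvp}\big((x, y, z), (\dot{x}, \dot{y}, \dot{z})\big)
  &= \big(x\,y + z,\; \dot{x}\,y + x\,\dot{y} + \dot{z}\big) \\
\text{at } x = y = 2,\ z = 10,\ \dot{x} = \dot{y} = \dot{z} = 1:\quad
  x\,y + z &= 2 \cdot 2 + 10 = 14 \\
  \dot{x}\,y + x\,\dot{y} + \dot{z} &= 1 \cdot 2 + 2 \cdot 1 + 1 = 5
\end{aligned}
```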
If you call jvp now, you’ll get an error because you have not yet told JAX how to differentiate the multiply_add primitive.
# The second argument is set to `(2., 10.)` values where you
# evaluate the Jacobian, and the third argument `(1., 1.)`
# contains the values of the tangents for the arguments.
with expectNotImplementedError():
  api.jvp(square_add_prim, (2., 10.), (1., 1.))
call square_add_prim(Traced<~float32[]>, Traced<~float32[]>)
  call multiply_add_prim(Traced<~float32[]>, Traced<~float32[]>, Traced<~float32[]>)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1924/459539105.py", line 5, in <module>
    api.jvp(square_add_prim, (2., 10.), (1., 1.))
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 197, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/api.py", line 1953, in jvp
    return _jvp(lu.wrap_init(fun, debug_info=debug_info("jvp", fun, primals, {})),
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Differentiation rule for 'multiply_add' not implemented

from jax.interpreters import ad


@trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
  """Evaluates the primal output and the tangents (Jacobian-vector product).

  Given values of the arguments and perturbation of the arguments (tangents),
  compute the output of the primitive and the perturbation of the output.

  This method must be JAX-traceable. JAX may invoke it with abstract values
  for the arguments and tangents.

  Args:
    arg_values: A tuple of arguments
    arg_tangents: A tuple with the tangents of the arguments. The tuple has
      the same length as the arg_values. Some of the tangents may also be the
      special value `ad.Zero` to specify a zero tangent

  Returns:
     A pair of the primal output and the tangent.
  """
  x, y, z = arg_values
  xt, yt, zt = arg_tangents
  _trace("Primal evaluation:")
  # Now, you have a JAX-traceable computation of the output.
  # Normally, you can use the multiply add (`ma`) primitive itself to compute the primal output.
  primal_out = multiply_add_prim(x, y, z)

  _trace("Tangent evaluation:")
  # You must use a JAX-traceable way to compute the tangent. It turns out that
  # the output tangent can be computed as (xt * y + x * yt + zt),
  # which you can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.

  # You do need to deal specially with `Zero`. Here, you just turn it into a
  # proper tensor of 0s (of the same shape as 'x').
  # An alternative would be to check for `Zero` and perform algebraic
  # simplification of the output tangent computation.
  def make_zero(tan):
    return lax.full_like(x, 0) if type(tan) is ad.Zero else tan

  output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
  return (primal_out, output_tangent)


# Register the forward differentiation rule with JAX:
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<~float32[]>, Traced<~float32[]>)
  call multiply_add_prim(Traced<~float32[]>, Traced<~float32[]>, Traced<~float32[]>)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, 1.0, 1.0)
        call multiply_add_impl(2.0, 1.0, 1.0)
        |<- multiply_add_impl = 3.0
      |<- multiply_add_prim = 3.0
      call multiply_add_prim(1.0, 2.0, 3.0)
        call multiply_add_impl(1.0, 2.0, 3.0)
        |<- multiply_add_impl = 5.0
      |<- multiply_add_prim = 5.0
    |<- multiply_add_value_and_jvp = (14.0, 5.0)
  |<- multiply_add_prim = Traced<float32[]>
|<- square_add_prim = Traced<float32[]>
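The JVP rule does not have to reuse the primitive for the tangent. Any JAX-traceable computation works. The self-contained sketch below is a hypothetical variant, not the rule above. It computes the tangent with ordinary arithmetic, replaces symbolic zeros with jnp.zeros_like, and compares the result against jax.jvp of a plain a * a + b reference function.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import core
from jax.extend import core as jex_core
from jax.interpreters import ad

ma_p = jex_core.Primitive("multiply_add")  # hypothetical stand-alone copy
ma_p.def_impl(lambda x, y, z: np.add(np.multiply(x, y), z))
ma_p.def_abstract_eval(lambda xs, ys, zs: core.ShapedArray(xs.shape, xs.dtype))


def ma_jvp(primals, tangents):
  x, y, z = primals

  def instantiate(t, like):
    # Replace a symbolic ad.Zero with explicit zeros of the primal's shape.
    return jnp.zeros_like(like) if type(t) is ad.Zero else t

  xt, yt, zt = (instantiate(t, p) for t, p in zip(tangents, primals))
  primal_out = ma_p.bind(x, y, z)
  # Product rule, written with ordinary (traceable) arithmetic.
  tangent_out = xt * y + x * yt + zt
  return primal_out, tangent_out


ad.primitive_jvps[ma_p] = ma_jvp


def square_add(a, b):
  return ma_p.bind(a, a, b)


def square_add_ref(a, b):
  return a * a + b


out, out_t = jax.jvp(square_add, (2.0, 10.0), (1.0, 1.0))
ref, ref_t = jax.jvp(square_add_ref, (2.0, 10.0), (1.0, 1.0))
print(out, out_t)  # 14.0 5.0
```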
JIT of forward differentiation
You can apply jit to the forward differentiation function:
assert api.jit(lambda arg_values, arg_tangents:
                   api.jvp(square_add_prim, arg_values, arg_tangents))(
         (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<~float32[]>, Traced<~float32[]>)
  call multiply_add_prim(Traced<~float32[]>, Traced<~float32[]>, Traced<~float32[]>)
    call multiply_add_value_and_jvp((JitTracer(~float32[]), JitTracer(~float32[]), JitTracer(~float32[])), (JitTracer(~float32[]), JitTracer(~float32[]), JitTracer(~float32[])))
      Primal evaluation:
      call multiply_add_prim(JitTracer(~float32[]), JitTracer(~float32[]), JitTracer(~float32[]))
        call multiply_add_abstract_eval(~float32[], ~float32[], ~float32[])
        |<- multiply_add_abstract_eval = float32[]
      |<- multiply_add_prim = JitTracer(float32[])
      Tangent evaluation:
      call multiply_add_prim(JitTracer(~float32[]), JitTracer(~float32[]), JitTracer(~float32[]))
        call multiply_add_abstract_eval(~float32[], ~float32[], ~float32[])
        |<- multiply_add_abstract_eval = float32[]
      |<- multiply_add_prim = JitTracer(float32[])
      call multiply_add_prim(JitTracer(~float32[]), JitTracer(~float32[]), JitTracer(float32[]))
        call multiply_add_abstract_eval(~float32[], ~float32[], float32[])
        |<- multiply_add_abstract_eval = float32[]
      |<- multiply_add_prim = JitTracer(float32[])
    |<- multiply_add_value_and_jvp = (JitTracer(float32[]), JitTracer(float32[]))
  |<- multiply_add_prim = Traced<float32[]>
|<- square_add_prim = Traced<float32[]>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7dd314873a10>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7dd31487e110>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7dd31487e0b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7dd31483b510>, platforms=('cpu',), backend=<jaxlib._jax.Client object at 0x7dd32c287920>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(2), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7dd31483b350>, all_default_mem_kind=True,
lowering_cache={}, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_to_location_cache=<jaxlib.mlir._mlir_libs._jax_mlir_ext.TracebackToLocationCache object at 0x7dd314888150>, canonical_name_cache={}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False, hoist_constants_as_args=False)), name_stack=NameStack(stack=()), traceback=None, primitive=multiply_add, avals_in=(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)), avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7dd31483bd70>, tokens_out=None, const_lowering={}, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types=()), remove_size_one_mesh_axis=False, xla_metadata=None), platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1), Value(<block argument> of type 'tensor<f32>' at index: 2))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7dd31488c730>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7dd314873a10>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7dd31487e110>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7dd31487e0b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7dd31483b510>, platforms=('cpu',), backend=<jaxlib._jax.Client object at 0x7dd32c287920>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(2), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7dd31483b350>,
all_default_mem_kind=True, lowering_cache={LoweringCacheKey(primitive=multiply_add, eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types=()), remove_size_one_mesh_axis=False, xla_metadata=None), avals_in=(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)), effects=frozenset(), params=FrozenDict({}), platforms=('cpu',)): LoweringCacheValue(func=<jaxlib.mlir.dialects.func.FuncOp object at 0x7dd31487e210>, output_types=[RankedTensorType(tensor<f32>)], const_args=(), const_arg_avals=(), inline=True)}, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_to_location_cache=<jaxlib.mlir._mlir_libs._jax_mlir_ext.TracebackToLocationCache object at 0x7dd314888150>, canonical_name_cache={}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False, hoist_constants_as_args=False)), name_stack=NameStack(stack=()), traceback=None, primitive=multiply_add, avals_in=(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[])), avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7dd31483bef0>, tokens_out=None, const_lowering={}, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types=()), remove_size_one_mesh_axis=False, xla_metadata=None), platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1), Value(<block argument> of type 'tensor<f32>' at index: 2))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7dd31488cef0>]

Notice that JAX first evaluates multiply_add_value_and_jvp abstractly. This in turn abstractly evaluates both the primal and the tangent computations, for a total of 3 invocations of the multiply_add primitive. JAX then compiles the 3 occurrences of the primitive. The trace shows only two calls to multiply_add_lowering, however. The second call’s lowering_cache already holds an entry for the all-weak-typed input types of the first call, so the occurrence that repeats those input types appears to reuse the cached lowering.
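You can count those 3 invocations without compiling anything by using jax.make_jaxpr, which performs the same tracing step that jit performs before lowering. The self-contained sketch below rebuilds a hypothetical copy of the primitive with a JVP rule shaped like multiply_add_value_and_jvp. It then counts the multiply_add equations in the traced jaxpr of the JVP.

```python
import numpy as np
import jax
from jax import core, lax
from jax.extend import core as jex_core
from jax.interpreters import ad

ma_p = jex_core.Primitive("multiply_add")  # hypothetical stand-alone copy
ma_p.def_impl(lambda x, y, z: np.add(np.multiply(x, y), z))
ma_p.def_abstract_eval(lambda xs, ys, zs: core.ShapedArray(xs.shape, xs.dtype))


def ma_jvp(primals, tangents):
  x, y, z = primals

  def make_zero(t):
    return lax.full_like(x, 0) if type(t) is ad.Zero else t

  xt, yt, zt = (make_zero(t) for t in tangents)
  primal_out = ma_p.bind(x, y, z)  # one use for the primal
  # xt*y + (x*yt + zt): two more uses for the tangent.
  tangent_out = ma_p.bind(xt, y, ma_p.bind(x, yt, zt))
  return primal_out, tangent_out


ad.primitive_jvps[ma_p] = ma_jvp


def square_add(a, b):
  return ma_p.bind(a, a, b)


def jvp_fun(primals, tangents):
  return jax.jvp(square_add, primals, tangents)


closed_jaxpr = jax.make_jaxpr(jvp_fun)((2.0, 10.0), (1.0, 1.0))
print(closed_jaxpr)
n_uses = sum(eqn.primitive.name == "multiply_add"
             for eqn in closed_jaxpr.jaxpr.eqns)
print(n_uses)  # 3
```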
Reverse differentiation
If you now attempt to use reverse differentiation, you’ll notice that JAX starts by using multiply_add_value_and_jvp to compute the forward differentiation for abstract values, but then runs into a NotImplementedError.
When computing the reverse differentiation, JAX first performs an abstract evaluation of the forward differentiation code multiply_add_value_and_jvp to obtain a trace of primitives that compute the output tangent.
Observe that JAX performs this abstract evaluation with concrete values for the differentiation point, and abstract values for the tangents.
Notice that JAX uses the special abstract tangent value Zero for the tangent corresponding to the third argument of multiply_add_prim. This reflects the fact that you do not differentiate w.r.t. the second argument to square_add_prim, which flows to the third argument to multiply_add_prim. Notice also that during the abstract evaluation of the tangent you pass the value 0.0 as the tangent for the third argument. This is because of the use of the make_zero function in the definition of multiply_add_value_and_jvp.
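The self-contained sketch below observes the Zero tangent directly. It uses a hypothetical copy of the primitive whose JVP rule records which incoming tangents are ad.Zero. Under jax.grad with respect to the first argument only, the tangent for the third argument of the primitive should arrive as ad.Zero, matching the trace below. Because no transposition rule is registered, the call still ends in a NotImplementedError after the JVP rule has run.

```python
import numpy as np
import jax
from jax import core, lax
from jax.extend import core as jex_core
from jax.interpreters import ad

ma_p = jex_core.Primitive("multiply_add")  # hypothetical stand-alone copy
ma_p.def_impl(lambda x, y, z: np.add(np.multiply(x, y), z))
ma_p.def_abstract_eval(lambda xs, ys, zs: core.ShapedArray(xs.shape, xs.dtype))

seen_zero = []  # one tuple per JVP call: which tangents were ad.Zero


def ma_jvp(primals, tangents):
  x, y, z = primals
  seen_zero.append(tuple(type(t) is ad.Zero for t in tangents))

  def make_zero(t):
    return lax.full_like(x, 0) if type(t) is ad.Zero else t

  xt, yt, zt = (make_zero(t) for t in tangents)
  primal_out = ma_p.bind(x, y, z)
  tangent_out = ma_p.bind(xt, y, ma_p.bind(x, yt, zt))
  return primal_out, tangent_out


ad.primitive_jvps[ma_p] = ma_jvp


def square_add(a, b):
  return ma_p.bind(a, a, b)


grad_failed = False
try:
  jax.grad(square_add)(2.0, 10.0)  # differentiates w.r.t. `a` only
except NotImplementedError:
  grad_failed = True  # missing transposition rule

print(seen_zero, grad_failed)
```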
# This is reverse differentiation w.r.t. the first argument of `square_add_prim`
with expectNotImplementedError():
  api.grad(square_add_prim)(2., 10.)
call square_add_prim(GradTracer(primal=2.0, typeof(tangent)=f32[]), 10.0)
  call multiply_add_prim(GradTracer(primal=2.0, typeof(tangent)=f32[]), GradTracer(primal=2.0, typeof(tangent)=f32[]), 10.0)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<~float32[]>, Traced<~float32[]>, Zero(~float32[])))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, Traced<~float32[]>, 0.0)
        call multiply_add_abstract_eval(~float32[], ~float32[], ~float32[])
        |<- multiply_add_abstract_eval = float32[]
      |<- multiply_add_prim = Traced<float32[]>
      call multiply_add_prim(Traced<~float32[]>, 2.0, Traced<float32[]>)
        call multiply_add_abstract_eval(~float32[], ~float32[], float32[])
        |<- multiply_add_abstract_eval = float32[]
      |<- multiply_add_prim = Traced<float32[]>
    |<- multiply_add_value_and_jvp = (14.0, Traced<float32[]>)
    call multiply_add_abstract_eval(~float32[], ~float32[], ~float32[])
    |<- multiply_add_abstract_eval = float32[]
    call multiply_add_abstract_eval(~float32[], ~float32[], float32[])
    |<- multiply_add_abstract_eval = float32[]
  |<- multiply_add_prim = GradTracer(primal=14.0, typeof(tangent)=f32[])
|<- square_add_prim = GradTracer(primal=14.0, typeof(tangent)=f32[])

Found expected exception:
Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/interpreters/ad.py", line 322, in get_primitive_transpose
    return primitive_transposes[p]
           ~~~~~~~~~~~~~~~~~~~~^^^
KeyError: multiply_add

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_1924/2155094905.py", line 3, in <module>
    api.grad(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 197, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/api.py", line 469, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented
The error above occurs because JAX is missing one piece it needs to reuse the forward-differentiation code for reverse differentiation: a transposition rule for multiply_add.
Transposition#
As previously explained, when computing reverse differentiation, JAX obtains a trace of primitives that compute the tangent using forward differentiation. Then, JAX interprets this trace abstractly backwards, and for each primitive it applies a transposition rule.
To understand what is going on, consider the simpler function f(x, y) = x * y + y. Suppose you need to differentiate at the point (2., 4.). JAX will produce the following JVP tangent calculation of ft from the input tangents xt and yt:
a = xt * 4.
b = 2. * yt
c = a + b
ft = c + yt
By construction, the tangent calculation is always linear in the input tangents. The only operator that could be non-linear in the tangents is multiplication, but whenever it appears in a tangent calculation, one of its operands is a constant.
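To see a tangent calculation like the one above produced by JAX itself, you can trace the tangent output of jax.jvp for the same f at the point (2., 4.). This is a minimal sketch using only public APIs (jax.jvp, jax.make_jaxpr); f_tangent is a helper name introduced here for illustration, and the printed jaxpr may differ in detail from the hand-written version depending on the JAX version. The two direct calls check the linearity claim: the tangent is 4 * xt + 3 * yt.

```python
import jax

def f(x, y):
  return x * y + y

def f_tangent(xt, yt):
  # Keep only the tangent output of the JVP at the fixed point (2., 4.).
  _, ft = jax.jvp(f, (2., 4.), (xt, yt))
  return ft

# The staged-out tangent calculation; its exact form depends on the JAX version.
print(jax.make_jaxpr(f_tangent)(1., 0.))

# Linearity check: ft = 4 * xt + 3 * yt, so these give 4. and 3.
print(f_tangent(1., 0.), f_tangent(0., 1.))
```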
JAX will produce the reverse differentiation computation by processing the JVP computation backwards. For each operation in the tangent computation, it accumulates the cotangents of the variables used by the operation, using the cotangent of the result of the operation:
# Initialize cotangents of inputs and intermediate variables:
xct = yct = act = bct = cct = 0.
# Initialize cotangent of the output:
fct = 1.
# Process `ft = c + yt`:
cct += fct
yct += fct
# Process `c = a + b`:
act += cct
bct += cct
# Process `b = 2. * yt`:
yct += 2. * bct
# Process `a = xt * 4.`:
xct += act * 4.
One can verify that this computation produces xct = 4. and yct = 3., which are the partial derivatives of the function f.
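The hand-derived cotangents can be cross-checked with jax.linear_transpose, which performs this kind of backward processing for any linear function. A sketch, assuming the tangent calculation above is written out as a Python function (f_tangent is an illustrative name); the 0. arguments only tell linear_transpose the input shapes and dtypes.

```python
import jax

def f(x, y):
  return x * y + y

def f_tangent(xt, yt):
  # The JVP tangent calculation of f at (x, y) = (2., 4.), written out by hand.
  a = xt * 4.
  b = 2. * yt
  c = a + b
  return c + yt

# The example arguments only fix the input shapes and dtypes (scalar floats).
f_tangent_transpose = jax.linear_transpose(f_tangent, 0., 0.)

# Seed the output cotangent fct = 1. and read off the input cotangents.
xct, yct = f_tangent_transpose(1.)
print(xct, yct)  # expected values: 4. and 3.

# The same numbers are the partial derivatives of f at (2., 4.).
print(jax.grad(f, argnums=(0, 1))(2., 4.))
```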
For each primitive that may appear in a JVP calculation, JAX knows how to transpose it. Conceptually, if the primitive p(x, y, z) is linear in the arguments y and z for a constant value of x, e.g., p(x, y, z) = y*cy + z*cz, then the transposition of the primitive is:
p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)
Notice that p_transpose takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition receives an undefined value (written _ above), and for the other arguments it receives the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value None returned for the constant arguments.
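A small numeric check of this rule: the sketch below picks hypothetical coefficients cy = 3x and cz = x + 1 (any functions of x alone would do), writes p_transpose as described, and compares it with the reverse-mode derivative that jax.vjp computes with x held constant. The None values passed for y and z merely stand in for the undefined values.

```python
import jax

def coefficients(x):
  # Hypothetical coefficients cy, cz that depend only on the constant x.
  return 3. * x, x + 1.

def p(x, y, z):
  # Linear in y and z once x is fixed: p(x, y, z) = y*cy + z*cz.
  cy, cz = coefficients(x)
  return y * cy + z * cz

def p_transpose(out_ct, x, y, z):
  # y and z are the linear (undefined) arguments, so their values are never
  # read; x is the constant, for which no cotangent is returned (None).
  cy, cz = coefficients(x)
  return None, out_ct * cy, out_ct * cz

x = 2.
# Reverse-mode derivative w.r.t. y and z only, with x held constant.
_, p_vjp = jax.vjp(lambda y, z: p(x, y, z), 0., 0.)
y_ct, z_ct = p_vjp(1.)
print(y_ct, z_ct)                      # cy = 6. and cz = 3.
print(p_transpose(1., x, None, None))  # (None, 6.0, 3.0)
```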
In particular:
add_transpose(out_ct, _, _) = (out_ct, out_ct)
mult_transpose(out_ct, x, _) = (None, x * out_ct)
mult_transpose(out_ct, _, y) = (out_ct * y, None)
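To make the backward processing concrete, here is a toy transposer in plain Python that applies the add and multiply rules above to the tangent program of f(x, y) = x * y + y at (2., 4.). It is an illustration only, not how JAX represents programs: the tuple-based program format, the UNDEF sentinel (standing in for JAX's undefined primal values), and all function names are hypothetical.

```python
# Hypothetical sentinel standing in for an undefined (linear) argument.
UNDEF = object()

def add_transpose(out_ct, x, y):
  # Both operands of add are linear inputs.
  return out_ct, out_ct

def mul_transpose(out_ct, x, y):
  # Exactly one operand is linear (UNDEF); the other is a constant.
  if x is UNDEF:
    return out_ct * y, None
  return None, x * out_ct

TRANSPOSE_RULES = {"add": add_transpose, "mul": mul_transpose}

# Tangent program of f at (2., 4.) as (output, primitive, arguments).
# Strings name linear variables; floats are constants.
program = [
    ("a", "mul", ("xt", 4.)),
    ("b", "mul", (2., "yt")),
    ("c", "add", ("a", "b")),
    ("ft", "add", ("c", "yt")),
]

def transpose_program(program, out_var, out_ct):
  cts = {out_var: out_ct}
  # Walk the program backwards, applying one transposition rule per equation.
  for out, prim, args in reversed(program):
    ct = cts.pop(out, 0.)
    rule_args = [UNDEF if isinstance(a, str) else a for a in args]
    arg_cts = TRANSPOSE_RULES[prim](ct, *rule_args)
    # Accumulate cotangents of linear variables; constants get None.
    for a, a_ct in zip(args, arg_cts):
      if isinstance(a, str) and a_ct is not None:
        cts[a] = cts.get(a, 0.) + a_ct
  return cts

print(transpose_program(program, "ft", 1.))  # {'yt': 3.0, 'xt': 4.0}
```

It reproduces xct = 4. and yct = 3., the same values as the hand-written cotangent computation.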
@trace("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
  """Evaluates the transpose of a linear primitive.

  This method is only used when computing the backward gradient following
  `value_and_jvp`, and is only needed for primitives that are used in the JVP
  calculation for some other primitive. You need a transposition for
  `multiply_add_prim`, because you have used `multiply_add_prim` in the
  computation of the `output_tangent` in `multiply_add_value_and_jvp`.

  In this case, multiply_add is not a linear primitive. However, it is used
  linearly w.r.t. tangents in `multiply_add_value_and_jvp`:
    `output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))`.

  One of the first two (multiplicative) arguments is always a constant.

  Args:
    ct: The cotangent of the output of the primitive.
    x, y, z: The values of the arguments. The arguments that are used linearly
      get an ad.UndefinedPrimal value. The other arguments get a constant
      value.

  Returns:
    A tuple with the cotangent of the inputs, with the value None
    corresponding to the constant arguments.
  """
  if not ad.is_undefined_primal(x):
    # This use of multiply_add is with a constant "x".
    assert ad.is_undefined_primal(y)
    ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.full_like(x, 0))
    res = None, ct_y, ct
  else:
    # This use of multiply_add is with a constant "y".
    assert ad.is_undefined_primal(x)
    ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.full_like(y, 0))
    res = ct_x, None, ct
  return res


ad.primitive_transposes[multiply_add_p] = multiply_add_transpose
Now you can complete the run of the grad:
assert api.grad(square_add_prim)(2., 10.) == 4.
call square_add_prim(GradTracer(primal=2.0, typeof(tangent)=f32[]), 10.0)
  call multiply_add_prim(GradTracer(primal=2.0, typeof(tangent)=f32[]), GradTracer(primal=2.0, typeof(tangent)=f32[]), 10.0)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<~float32[]>, Traced<~float32[]>, Zero(~float32[])))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, Traced<~float32[]>, 0.0)
        call multiply_add_abstract_eval(~float32[], ~float32[], ~float32[])
        |<- multiply_add_abstract_eval = float32[]
      |<- multiply_add_prim = Traced<float32[]>
      call multiply_add_prim(Traced<~float32[]>, 2.0, Traced<float32[]>)
        call multiply_add_abstract_eval(~float32[], ~float32[], float32[])
        |<- multiply_add_abstract_eval = float32[]
      |<- multiply_add_prim = Traced<float32[]>
    |<- multiply_add_value_and_jvp = (14.0, Traced<float32[]>)
  |<- multiply_add_prim = GradTracer(primal=14.0, typeof(tangent)=f32[])
|<- square_add_prim = GradTracer(primal=14.0, typeof(tangent)=f32[])
call multiply_add_transpose(1.0, UndefinedPrimal(~float32[]), 2.0, UndefinedPrimal(float32[]))
  call multiply_add_prim(1.0, 2.0, 0.0)
    call multiply_add_impl(1.0, 2.0, 0.0)
    |<- multiply_add_impl = 2.0
  |<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (2.0, None, 1.0)
call multiply_add_transpose(1.0, 2.0, UndefinedPrimal(~float32[]), 0.0)
  call multiply_add_prim(2.0, 1.0, 0.0)
    call multiply_add_impl(2.0, 1.0, 0.0)
    |<- multiply_add_impl = 2.0
  |<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (None, 2.0, 1.0)
Notice the two calls to multiply_add_transpose. They correspond to the two uses of multiply_add_prim in the computation of the output_tangent in multiply_add_value_and_jvp. The first call to transpose corresponds to the last use of multiply_add_prim: multiply_add_prim(xt, y, ...) where y is the constant 2.0.
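The cotangents returned by the two transpose calls can be reproduced without a custom primitive. The sketch below assumes a plain-JAX stand-in for multiply_add_prim (computing x * y + z) and transposes the output_tangent expression from the docstring of multiply_add_transpose at x = y = 2.0 using jax.linear_transpose. Because square_add_prim passes the same value as both x and y, its gradient is the sum of the x and y cotangents: 2.0 + 2.0 = 4.0.

```python
import jax

def multiply_add(x, y, z):
  # Plain-JAX stand-in for multiply_add_prim (an assumption of this sketch).
  return x * y + z

x, y = 2., 2.  # primal values at the differentiation point

def output_tangent(xt, yt, zt):
  # Same structure as in multiply_add_value_and_jvp; linear in xt, yt, zt.
  return multiply_add(xt, y, multiply_add(x, yt, zt))

transpose = jax.linear_transpose(output_tangent, 0., 0., 0.)
xct, yct, zct = transpose(1.)
print(xct, yct, zct)  # cotangents 2., 2. and 1.

# square_add(a, b) = a * a + b feeds `a` into both x and y, so its
# gradient w.r.t. `a` is xct + yct.
print(xct + yct, jax.grad(lambda a, b: a * a + b)(2., 10.))
```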
JIT of reverse differentiation#
Notice that under JIT, the evaluation of multiply_add_value_and_jvp uses only abstract values (JitTracer(~float32[]) in the trace below). In the absence of JIT, by contrast, the primal values were concrete (for example, 2.0).
assert api.jit(api.grad(square_add_prim))(2., 10.) == 4.
call square_add_prim(GradTracer(primal=JitTracer(~float32[]), typeof(tangent)=f32[]), JitTracer(~float32[]))
  call multiply_add_prim(GradTracer(primal=JitTracer(~float32[]), typeof(tangent)=f32[]), GradTracer(primal=JitTracer(~float32[]), typeof(tangent)=f32[]), JitTracer(~float32[]))
    call multiply_add_value_and_jvp((JitTracer(~float32[]), JitTracer(~float32[]), JitTracer(~float32[])), (Traced<~float32[]>, Traced<~float32[]>, Zero(~float32[])))
      Primal evaluation:
      call multiply_add_prim(JitTracer(~float32[]), JitTracer(~float32[]), JitTracer(~float32[]))
        call multiply_add_abstract_eval(~float32[], ~float32[], ~float32[])
        |<- multiply_add_abstract_eval = float32[]
      |<- multiply_add_prim = JitTracer(float32[])
      Tangent evaluation:
      call multiply_add_prim(JitTracer(~float32[]), Traced<~float32[]>, JitTracer(~float32[]))
        call multiply_add_abstract_eval(~float32[], ~float32[], ~float32[])
        |<- multiply_add_abstract_eval = float32[]
      |<- multiply_add_prim = Traced<float32[]>
      call multiply_add_prim(Traced<~float32[]>, JitTracer(~float32[]), Traced<float32[]>)
        call multiply_add_abstract_eval(~float32[], ~float32[], float32[])
        |<- multiply_add_abstract_eval = float32[]
      |<- multiply_add_prim = Traced<float32[]>
    |<- multiply_add_value_and_jvp = (JitTracer(float32[]), Traced<float32[]>)
  |<- multiply_add_prim = GradTracer(primal=JitTracer(float32[]), typeof(tangent)=f32[])
|<- square_add_prim = GradTracer(primal=JitTracer(float32[]), typeof(tangent)=f32[])
call multiply_add_transpose(JitTracer(float32[]), UndefinedPrimal(~float32[]), JitTracer(~float32[]), UndefinedPrimal(float32[]))
  call multiply_add_prim(JitTracer(float32[]), JitTracer(~float32[]), JitTracer(~float32[]))
    call multiply_add_abstract_eval(float32[], ~float32[], ~float32[])
    |<- multiply_add_abstract_eval = float32[]
  |<- multiply_add_prim = JitTracer(float32[])
|<- multiply_add_transpose = (JitTracer(float32[]), None, JitTracer(float32[]))
call multiply_add_transpose(JitTracer(float32[]), JitTracer(~float32[]), UndefinedPrimal(~float32[]), JitTracer(~float32[]))
  call multiply_add_prim(JitTracer(~float32[]), JitTracer(float32[]), JitTracer(~float32[]))
    call multiply_add_abstract_eval(~float32[], float32[], ~float32[])
    |<- multiply_add_abstract_eval = float32[]
  |<- multiply_add_prim = JitTracer(float32[])
|<- multiply_add_transpose = (None, JitTracer(float32[]), JitTracer(float32[]))
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7dd3148c0530>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7dd3148babb0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7dd3148ba880>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7dd314838750>, platforms=('cpu',), backend=<jaxlib._jax.Client object at 0x7dd32c287920>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(2), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7dd314892c90>, all_default_mem_kind=True, lowering_cache={}, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_to_location_cache=<jaxlib.mlir._mlir_libs._jax_mlir_ext.TracebackToLocationCache object at 0x7dd314889350>, canonical_name_cache={}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False, hoist_constants_as_args=False)), name_stack=NameStack(stack=()), traceback=None, primitive=multiply_add, avals_in=(ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)), avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7dd3148932f0>, tokens_out=None, const_lowering={}, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types=()), remove_size_one_mesh_axis=False, xla_metadata=None), platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1), Value(<block argument> of type 'tensor<f32>' at index: 2))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7dd3148c5930>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7dd3148c0530>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7dd3148babb0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7dd3148ba880>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7dd314838750>, platforms=('cpu',), backend=<jaxlib._jax.Client object at 0x7dd32c287920>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(2), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7dd314892c90>, all_default_mem_kind=True, lowering_cache={LoweringCacheKey(primitive=multiply_add, eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types=()), remove_size_one_mesh_axis=False, xla_metadata=None), avals_in=(ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)), effects=frozenset(), params=FrozenDict({}), platforms=('cpu',)): LoweringCacheValue(func=<jaxlib.mlir.dialects.func.FuncOp object at 0x7dd3148bb110>, output_types=[RankedTensorType(tensor<f32>)], const_args=(), const_arg_avals=(), inline=True)}, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_to_location_cache=<jaxlib.mlir._mlir_libs._jax_mlir_ext.TracebackToLocationCache object at 0x7dd314889350>, canonical_name_cache={}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False, hoist_constants_as_args=False)), name_stack=NameStack(stack=()), traceback=None, primitive=multiply_add, avals_in=(ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True)), avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7dd314893440>, tokens_out=None, const_lowering={}, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types=()), remove_size_one_mesh_axis=False, xla_metadata=None), platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1), Value(<block argument> of type 'tensor<f32>' at index: 2))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7dd3148c5bf0>]

Batching#
The batching transformation takes a point-wise computation and turns it into a computation on vectors. If you try it right now, you will get a NotImplementedError:
# The arguments are two vectors instead of two scalars.
with expectNotImplementedError():
  api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
                                                   np.array([10., 20.]))
call square_add_prim(Traced<float32[]>, Traced<float32[]>)
  call multiply_add_prim(Traced<float32[]>, Traced<float32[]>, Traced<float32[]>)
Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1924/1080163607.py", line 3, in <module>
    api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 197, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.12/site-packages/jax/_src/api.py", line 1217, in vmap_f
    out_flat, inferred_out_axes = batching.batch(
                                  ^^^^^^^^^^^^^^^
NotImplementedError: Batching rule for 'multiply_add' not implemented
You need to tell JAX how to evaluate the batched version of the primitive. In this particular case, multiply_add_prim already operates point-wise on input arrays of any dimension, so the batched version can reuse the same multiply_add_prim implementation.
from jax.interpreters import batching


@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
  """Computes the batched version of the primitive.

  This must be a JAX-traceable function.

  Since the `multiply_add` primitive already operates point-wise on arbitrary
  dimension tensors, to batch it you can use the primitive itself. This works as
  long as all the inputs have the same dimensions and are batched along the
  same axis. The result is batched along the same axis as the inputs.

  Args:
    vector_arg_values: A tuple of three arguments (x, y, z), each being a tensor
      of matching shape.
    batch_axes: The axes that are being batched. See vmap documentation.

  Returns:
    A tuple of the result, and the result axis that was batched.
  """
  assert batch_axes[0] == batch_axes[1]
  assert batch_axes[0] == batch_axes[2]
  _trace("Using multiply_add to compute the batch:")
  res = multiply_add_prim(*vector_arg_values)
  return res, batch_axes[0]


batching.primitive_batchers[multiply_add_p] = multiply_add_batch
assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
  np.array([2., 3.]),
  np.array([10., 20.])),
  [14., 29.])
call square_add_prim(Traced<float32[]>, Traced<float32[]>)
  call multiply_add_prim(Traced<float32[]>, Traced<float32[]>, Traced<float32[]>)
    call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))
      Using multiply_add to compute the batch:
      call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])
        call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])
        |<- multiply_add_impl = [14. 29.]
      |<- multiply_add_prim = [14. 29.]
    |<- multiply_add_batch = ([14. 29.], 0)
  |<- multiply_add_prim = Traced<float32[]>
|<- square_add_prim = Traced<float32[]>
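The batching rule above asserts that all three arguments are batched along the same axis. As a sketch of how that restriction could be lifted, the hypothetical multiply_add_batch_general below moves each batch axis to position 0, broadcasts any unbatched argument (JAX passes None as the batch axis for arguments that are not batched), and then reports 0 as the output batch axis. It uses a plain-JAX stand-in for multiply_add_prim and is checked directly against jax.vmap rather than registered in batching.primitive_batchers.

```python
import jax
import jax.numpy as jnp
import numpy as np

def multiply_add(x, y, z):
  # Plain-JAX stand-in for multiply_add_prim (an assumption of this sketch).
  return x * y + z

def multiply_add_batch_general(vector_arg_values, batch_axes):
  """Hypothetical batching rule allowing a different batch axis per argument.

  Each batched argument has its batch axis moved to position 0; an argument
  whose batch axis is None is broadcast along a new leading axis.
  Returns the result and its batch axis (always 0 here).
  """
  # Batch size taken from the first batched argument.
  size = next(v.shape[ax]
              for v, ax in zip(vector_arg_values, batch_axes)
              if ax is not None)
  moved = []
  for v, ax in zip(vector_arg_values, batch_axes):
    if ax is None:
      v = jnp.broadcast_to(v, (size,) + jnp.shape(v))
    else:
      v = jnp.moveaxis(v, ax, 0)
    moved.append(v)
  return multiply_add(*moved), 0

# Per-example shape (2,), batch size 3, with a different layout per argument.
xs = np.arange(6.).reshape(2, 3)       # batched along axis 1
ys = np.arange(6., 12.).reshape(3, 2)  # batched along axis 0
z = np.array([100., 200.])             # not batched

res, out_axis = multiply_add_batch_general((xs, ys, z), (1, 0, None))
expected = jax.vmap(multiply_add, in_axes=(1, 0, None))(xs, ys, z)
print(out_axis, res.shape, np.allclose(res, expected))  # 0 (3, 2) True
```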
JIT of batching#
Below is an example of applying JIT to batching:
assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))
                    (np.array([2., 3.]),
                     np.array([10., 20.])),
                    [14., 29.])
call square_add_prim(Traced<float32[]>, Traced<float32[]>)
  call multiply_add_prim(Traced<float32[]>, Traced<float32[]>, Traced<float32[]>)
    call multiply_add_batch((JitTracer(float32[2]), JitTracer(float32[2]), JitTracer(float32[2])), (0, 0, 0))
      Using multiply_add to compute the batch:
      call multiply_add_prim(JitTracer(float32[2]), JitTracer(float32[2]), JitTracer(float32[2]))
        call multiply_add_abstract_eval(float32[2], float32[2], float32[2])
        |<- multiply_add_abstract_eval = float32[2]
      |<- multiply_add_prim = JitTracer(float32[2])
    |<- multiply_add_batch = (JitTracer(float32[2]), 0)
  |<- multiply_add_prim = Traced<float32[]>
|<- square_add_prim = Traced<float32[]>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jax._src.interpreters.mlir.JaxIrContext object at 0x7dd3148c04d0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7dd3148bb920>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7dd3148bb8c0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7dd314893d80>, platforms=('cpu',), backend=<jaxlib._jax.Client object at 0x7dd32c287920>, axis_context=ShardingContext(num_devices=1, device_assignment=None, abstract_mesh=None), keepalives=[], channel_iterator=count(2), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7dd314893b90>, all_default_mem_kind=True, lowering_cache={}, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_to_location_cache=<jaxlib.mlir._mlir_libs._jax_mlir_ext.TracebackToLocationCache object at 0x7dd3148892c0>, canonical_name_cache={}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False, hoist_constants_as_args=False)), name_stack=NameStack(stack=()), traceback=None, primitive=multiply_add, avals_in=(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2])), avals_out=[ShapedArray(float32[2])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7dd3148c88f0>, tokens_out=None, const_lowering={}, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types=()), remove_size_one_mesh_axis=False, xla_metadata=None), platforms=None), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 1), Value(<block argument> of type 'tensor<2xf32>' at index: 2))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7dd3148c7cb0>]