Autodidax: JAX core from scratch

Ever want to learn how JAX works, but the implementation seemed impenetrable? Well, you’re in luck! By reading this tutorial, you’ll learn every big idea in JAX’s core system. You’ll even get clued into our weird jargon!

This is a work-in-progress draft. There are some important ingredients missing, still to come in parts 5 and 6 (and more?). There are also some simplifications here that we haven’t yet applied to the main system, but we will.

Part 1: Transformations as interpreters: standard evaluation, jvp, and vmap

We want to transform functions that look like this:

    def f(x):
        y = sin(x) * 2.
        z = - y + x
        return z

Think of functions like sin and the arithmetic operations underlying the infix operators (mul, add, and neg) as primitive operations, meaning atomic units of processing rather than compositions.

“Transform” means “interpret differently.” Instead of standard interpretation where we apply primitive operations to numerical inputs to produce numerical outputs, we want to override primitive application and let different values flow through our program. For example, we might want to replace the application of every primitive with an application of its JVP rule, and let primal-tangent pairs flow through our program. Moreover, we want to be able to compose multiple transformations, leading to stacks of interpreters.
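To make “interpret differently” concrete, here is a minimal standalone sketch, separate from the machinery built below. The function is written against a hypothetical table of primitive implementations called `ops`, so swapping the table swaps the interpretation: first plain floats, then (primal, tangent) pairs. The names `ops`, `eval_ops`, `jvp_ops`, and `lift` are illustrative only.

```python
import math

# The user function calls primitives through a table instead of directly.
def f(x, ops):
    y = ops["mul"](ops["sin"](x), 2.0)
    z = ops["add"](ops["neg"](y), x)
    return z

# Standard evaluation: primitives act on plain floats.
eval_ops = {
    "sin": math.sin,
    "mul": lambda a, b: a * b,
    "neg": lambda a: -a,
    "add": lambda a, b: a + b,
}

def lift(v):
    # Treat a plain constant as a pair with zero tangent.
    return v if isinstance(v, tuple) else (v, 0.0)

def jvp_mul(a, b):
    (x, dx), (y, dy) = lift(a), lift(b)
    return (x * y, dx * y + x * dy)  # product rule

def jvp_add(a, b):
    (x, dx), (y, dy) = lift(a), lift(b)
    return (x + y, dx + dy)

# JVP interpretation: values are (primal, tangent) pairs and each primitive
# is replaced by its JVP rule.
jvp_ops = {
    "sin": lambda a: (math.sin(a[0]), math.cos(a[0]) * a[1]),
    "mul": jvp_mul,
    "neg": lambda a: (-a[0], -a[1]),
    "add": jvp_add,
}

print(f(3.0, eval_ops))        # plain value
print(f((3.0, 1.0), jvp_ops))  # (value, derivative)
```

The second call returns the value together with the derivative 1 - 2cos(3), without any change to `f`. The machinery below gets the same effect without passing a table explicitly, by intercepting every primitive through `bind`.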

JAX core machinery

We can implement stacks of interpreters and even have them all discharge on the fly as we execute the Python function to be transformed. To start, let’s define these primitives so that we can intercept their application:

    from typing import NamedTuple

    import numpy as np  # used by reduce_sum below

    class Primitive(NamedTuple):
        name: str

    add_p = Primitive('add')
    mul_p = Primitive('mul')
    neg_p = Primitive("neg")
    sin_p = Primitive("sin")
    cos_p = Primitive("cos")
    reduce_sum_p = Primitive("reduce_sum")
    greater_p = Primitive("greater")
    less_p = Primitive("less")
    transpose_p = Primitive("transpose")
    broadcast_p = Primitive("broadcast")

    def add(x, y): return bind1(add_p, x, y)
    def mul(x, y): return bind1(mul_p, x, y)
    def neg(x): return bind1(neg_p, x)
    def sin(x): return bind1(sin_p, x)
    def cos(x): return bind1(cos_p, x)
    def greater(x, y): return bind1(greater_p, x, y)
    def less(x, y): return bind1(less_p, x, y)
    def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
    def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)

    def reduce_sum(x, axis=None):
        if axis is None:
            axis = tuple(range(np.ndim(x)))
        if type(axis) is int:
            axis = (axis,)
        return bind1(reduce_sum_p, x, axis=axis)

    def bind1(prim, *args, **params):
        out, = bind(prim, *args, **params)
        return out

We’ll set up array data types and infix operator methods in a moment.

A Primitive is just an object with a name, to which we attach our interpretation rules (one for each transformation). The bind function is our interception point: it’ll figure out which transformation rule to apply, based on how the arguments are boxed in tracers and what interpreters are active.

The functions that user code calls, like add and sin, are just wrappers around calls to bind. These wrappers let us control how arguments are passed to bind, and in particular we follow a handy internal convention: when we call bind, we pass values representing array data as positional arguments, and we pass metadata like the axis argument to reduce_sum_p via keyword. This calling convention simplifies some core logic (since e.g. instances of the Tracer class to be defined below can only occur in positional arguments to bind). The wrappers can also provide docstrings!

We represent active interpreters as a stack. The stack is just a simple list, and each element is a container with an integer level (corresponding to the element’s height in the stack), an interpreter type (which we’ll call a trace_type), and an optional field for any global data the interpreter needs. We call each element a MainTrace, though maybe “Interpreter” would be more descriptive.

    from collections.abc import Sequence
    from contextlib import contextmanager
    from typing import Any

    class MainTrace(NamedTuple):
        level: int
        trace_type: type['Trace']
        global_data: Any | None

    trace_stack: list[MainTrace] = []
    dynamic_trace: MainTrace | None = None  # to be employed in Part 3

    @contextmanager
    def new_main(trace_type: type['Trace'], global_data=None):
        level = len(trace_stack)
        main = MainTrace(level, trace_type, global_data)
        trace_stack.append(main)

        try:
            yield main
        finally:
            trace_stack.pop()

When we’re about to apply a transformation, we’ll push another interpreter onto the stack using new_main. Then, as we apply primitives in the function, we can think of each bind as first being interpreted by the trace at the top of the stack (i.e. the one with the highest level). If that first interpreter itself binds other primitives in its interpretation rule for the primitive, like how the JVP rule of sin_p might bind cos_p and mul_p, then those bind calls will be handled by the interpreter at the next level down.

What goes at the bottom of the interpreter stack? At the bottom, we know all the transformation interpreters are finished, and we just want to do standard evaluation. So at the bottom we’ll put an evaluation interpreter.
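A stripped-down, self-contained sketch of this stack discipline may help. It assumes string labels in place of `Trace` subclasses and pre-seeds level 0 with a placeholder for the bottom-of-stack evaluator. It shows nested `new_main` calls receiving increasing levels and the stack being restored on exit.

```python
from contextlib import contextmanager
from typing import Any, NamedTuple

class MainTrace(NamedTuple):
    level: int
    trace_type: str   # simplified: a label instead of a Trace subclass
    global_data: Any

# Level 0 is reserved for the evaluation interpreter at the bottom.
trace_stack: list[MainTrace] = [MainTrace(0, "eval", None)]

@contextmanager
def new_main(trace_type: str, global_data: Any = None):
    main = MainTrace(len(trace_stack), trace_type, global_data)
    trace_stack.append(main)
    try:
        yield main
    finally:
        trace_stack.pop()  # always restore the stack, even on exceptions

# Nesting `with` blocks nests interpreters: the innermost one is on top.
with new_main("jvp") as outer:
    with new_main("vmap", 3) as inner:
        levels = [m.level for m in trace_stack]

print(levels)            # [0, 1, 2]
print(len(trace_stack))  # 1: only the bottom entry remains
```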

Let’s sketch out the interface for interpreters, which is based on the Trace and Tracer base classes. A Tracer represents a boxed-up value, perhaps carrying some extra context data used by the interpreter. A Trace handles boxing up values into Tracers and also handles primitive application.

    class Trace:
        main: MainTrace

        def __init__(self, main: MainTrace) -> None:
            self.main = main

        def pure(self, val): assert False  # must override
        def lift(self, val): assert False  # must override

        def process_primitive(self, primitive, tracers, params):
            assert False  # must override

The first two methods are about boxing up values in Tracers, which are the objects that flow through the Python programs we transform. The last method is the callback we’ll use to interpret primitive application.

The Trace itself doesn’t contain any data, other than a reference to its corresponding MainTrace instance. In fact, multiple instances of a Trace might be created and discarded during an application of a transformation, whereas only a single MainTrace instance is created per application of a transformation.

As for Tracers themselves, each one carries an abstract value (and forwards infix operators to it), and the rest is up to the transformation. (The relationship between Tracers and AbstractValues is that there’s one Tracer subclass per transformation, and at least one AbstractValue class per base type, like arrays.)

    import numpy as np

    class Tracer:
        _trace: Trace

        __array_priority__ = 1000

        @property
        def aval(self):
            assert False  # must override

        def full_lower(self):
            return self  # default implementation

        def __neg__(self): return self.aval._neg(self)
        def __add__(self, other): return self.aval._add(self, other)
        def __radd__(self, other): return self.aval._radd(self, other)
        def __mul__(self, other): return self.aval._mul(self, other)
        def __rmul__(self, other): return self.aval._rmul(self, other)
        def __gt__(self, other): return self.aval._gt(self, other)
        def __lt__(self, other): return self.aval._lt(self, other)
        def __bool__(self): return self.aval._bool(self)
        def __nonzero__(self): return self.aval._nonzero(self)

        def __getattr__(self, name):
            try:
                return getattr(self.aval, name)
            except AttributeError:
                raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")

    def swap(f): return lambda x, y: f(y, x)

    class ShapedArray:
        array_abstraction_level = 1
        shape: tuple[int, ...]
        dtype: np.dtype

        def __init__(self, shape, dtype):
            self.shape = shape
            self.dtype = dtype

        @property
        def ndim(self):
            return len(self.shape)

        _neg = staticmethod(neg)
        _add = staticmethod(add)
        _radd = staticmethod(swap(add))
        _mul = staticmethod(mul)
        _rmul = staticmethod(swap(mul))
        _gt = staticmethod(greater)
        _lt = staticmethod(less)

        @staticmethod
        def _bool(tracer):
            raise Exception("ShapedArray can't be unambiguously converted to bool")

        @staticmethod
        def _nonzero(tracer):
            raise Exception("ShapedArray can't be unambiguously converted to bool")

        def str_short(self):
            return f'{self.dtype.name}[{",".join(str(d) for d in self.shape)}]'

        def __hash__(self):
            return hash((self.shape, self.dtype))

        def __eq__(self, other):
            return (type(self) is type(other) and
                    self.shape == other.shape and self.dtype == other.dtype)

        def __repr__(self):
            return f"ShapedArray(shape={self.shape}, dtype={self.dtype})"

    class ConcreteArray(ShapedArray):
        array_abstraction_level = 2
        val: np.ndarray

        def __init__(self, val):
            self.val = val
            self.shape = val.shape
            self.dtype = val.dtype

        @staticmethod
        def _bool(tracer):
            return bool(tracer.aval.val)

        @staticmethod
        def _nonzero(tracer):
            return bool(tracer.aval.val)

    def get_aval(x):
        if isinstance(x, Tracer):
            return x.aval
        elif type(x) in jax_types:
            return ConcreteArray(np.asarray(x))
        else:
            raise TypeError(x)

    jax_types = {bool, int, float,
                 np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}

Notice that we actually have two AbstractValues for arrays, representing different levels of abstraction. A ShapedArray represents the set of all possible arrays with a given shape and dtype. A ConcreteArray represents a singleton set consisting of a single array value.

Now that we’ve set up the interpreter stack, the Trace/Tracer API for interpreters, and abstract values, we can come back to implement bind:

    def bind(prim, *args, **params):
        top_trace = find_top_trace(args)
        tracers = [full_raise(top_trace, arg) for arg in args]
        outs = top_trace.process_primitive(prim, tracers, params)
        return [full_lower(out) for out in outs]

The main action is that we call find_top_trace to figure out which interpreter should handle this primitive application. We then call that top trace’s process_primitive so that the trace can apply its interpretation rule. The calls to full_raise just ensure that the inputs are boxed in the top trace’s Tracer instances, and the call to full_lower is an optional optimization so that we unbox values out of Tracers as much as possible.

    import operator as op

    def find_top_trace(xs) -> Trace:
        top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
                       default=trace_stack[0], key=op.attrgetter('level'))
        if dynamic_trace and dynamic_trace.level > top_main.level:
            top_main = dynamic_trace
        return top_main.trace_type(top_main)

In words, ignoring the dynamic_trace step until Part 3, find_top_trace returns the highest-level interpreter associated with the Tracers on its inputs, and otherwise returns the interpreter at the bottom of the stack (which is always an evaluation trace, at least for now). This is a deviation from the description above, where we always start by running the interpreter at the top of the stack and then work our way down, applying every interpreter in the stack. Instead, we’re only applying an interpreter when the input arguments to a primitive bind are boxed in a Tracer corresponding to that interpreter. This optimization lets us skip irrelevant transformations, but bakes in an assumption that transformations mostly follow data dependence (except for the special bottom-of-the-stack interpreter, which interprets everything).
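The selection rule can be exercised in isolation. In this sketch `FakeTracer` and `find_top_main` are hypothetical stand-ins that keep only the `MainTrace` reference. The selection logic mirrors `find_top_trace` without the `dynamic_trace` step.

```python
import operator as op
from typing import NamedTuple

class MainTrace(NamedTuple):
    level: int
    name: str

class FakeTracer:
    """Hypothetical stand-in for Tracer: only remembers its MainTrace."""
    def __init__(self, main: MainTrace):
        self.main = main

bottom = MainTrace(0, "eval")

def find_top_main(xs) -> MainTrace:
    # Only arguments boxed in tracers vote; plain constants are ignored.
    # With no tracers at all, fall back to the bottom-of-stack evaluator.
    return max((x.main for x in xs if isinstance(x, FakeTracer)),
               default=bottom, key=op.attrgetter("level"))

jvp_main, vmap_main = MainTrace(1, "jvp"), MainTrace(2, "vmap")
print(find_top_main([1.0, 2.0]).name)                                     # eval
print(find_top_main([FakeTracer(jvp_main), 2.0]).name)                    # jvp
print(find_top_main([FakeTracer(jvp_main), FakeTracer(vmap_main)]).name)  # vmap
```

A primitive applied only to constants goes straight to the bottom evaluator, even while other interpreters are on the stack. That is the data-dependence assumption described above.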

An alternative would be to have every interpreter in the stack interpret every operation. That’s worth exploring! JAX is designed around data dependence in large part because that’s so natural for automatic differentiation, and JAX’s roots are in autodiff. But it may be over-fit.

    def full_lower(val: Any):
        if isinstance(val, Tracer):
            return val.full_lower()
        else:
            return val

    def full_raise(trace: Trace, val: Any) -> Tracer:
        if not isinstance(val, Tracer):
            assert type(val) in jax_types
            return trace.pure(val)
        level = trace.main.level
        if val._trace.main is trace.main:
            return val
        elif val._trace.main.level < level:
            return trace.lift(val)
        elif val._trace.main.level > level:
            raise Exception(f"Can't lift level {val._trace.main.level} to {level}.")
        else:  # val._trace.main.level == level
            raise Exception(f"Different traces at same level: {val._trace}, {trace}.")

The logic in full_raise serves to box values into Tracers for a particular Trace, calling different methods on the Trace based on context: Trace.pure is called on non-Tracer constants, and Trace.lift is called for values that are already Tracers from a lower-level interpreter. These two methods could share the same implementation, but by distinguishing them in the core logic we can provide more information to the Trace subclass.
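As a compact summary, the case analysis in `full_raise` can be written as a hypothetical helper that only reports which action would be taken. It takes the level of the value’s trace (or `None` for a non-`Tracer` constant), the level of the target trace, and whether both belong to the very same `MainTrace`:

```python
def raise_action(val_level, trace_level, same_main=False):
    """Hypothetical helper: which branch of full_raise would run?

    val_level is None for a non-Tracer constant; otherwise it is the level
    of the MainTrace that boxed the value.
    """
    if val_level is None:
        return "pure"        # constant: box it with Trace.pure
    if same_main:
        return "identity"    # already boxed by this very trace
    if val_level < trace_level:
        return "lift"        # tracer from a lower-level interpreter
    if val_level > trace_level:
        raise ValueError(f"Can't lift level {val_level} to {trace_level}.")
    raise ValueError("Different traces at same level.")

print(raise_action(None, 2))                # pure
print(raise_action(1, 2))                   # lift
print(raise_action(2, 2, same_main=True))   # identity
```

The two error branches correspond to situations that the data-dependence rule in find_top_trace should never produce. Seeing one usually means a tracer has leaked out of the transformation that created it.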

That’s it for the JAX core! Now we can start adding interpreters.

Evaluation interpreter

We’ll start with the simplest interpreter: the evaluation interpreter that will sit at the bottom of the interpreter stack.

    class EvalTrace(Trace):
        pure = lift = lambda self, x: x  # no boxing in Tracers needed

        def process_primitive(self, primitive, tracers, params):
            return impl_rules[primitive](*tracers, **params)

    trace_stack.append(MainTrace(0, EvalTrace, None))  # special bottom of the stack

    # NB: in JAX, instead of a dict we attach impl rules to the Primitive instance
    impl_rules = {}

    impl_rules[add_p] = lambda x, y: [np.add(x, y)]
    impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]
    impl_rules[neg_p] = lambda x: [np.negative(x)]
    impl_rules[sin_p] = lambda x: [np.sin(x)]
    impl_rules[cos_p] = lambda x: [np.cos(x)]
    impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
    impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
    impl_rules[less_p] = lambda x, y: [np.less(x, y)]
    impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]

    def broadcast_impl(x, *, shape, axes):
        for axis in sorted(axes):
            x = np.expand_dims(x, axis)
        return [np.broadcast_to(x, shape)]
    impl_rules[broadcast_p] = broadcast_impl

With this interpreter, we can evaluate user functions:

    def f(x):
        y = sin(x) * 2.
        z = - y + x
        return z

    print(f(3.0))

    2.7177599838802657

Woo! Like going around in a big circle. But the point of this indirection is that now we can add some real transformations.

Forward-mode autodiff with jvp

First, a few helper functions:

    import builtins

    def zeros_like(val):
        aval = get_aval(val)
        return np.zeros(aval.shape, aval.dtype)

    def unzip2(pairs):
        lst1, lst2 = [], []
        for x1, x2 in pairs:
            lst1.append(x1)
            lst2.append(x2)
        return lst1, lst2

    # List-returning versions of map and zip; zip also checks that all
    # of its arguments have the same length.
    def map(f, *xs):
        return list(builtins.map(f, *xs))

    def zip(*args):
        fst, *rest = args = map(list, args)
        n = len(fst)
        for arg in rest:
            assert len(arg) == n
        return list(builtins.zip(*args))

The Tracer for forward-mode autodiff carries a primal-tangent pair. The Trace applies JVP rules.

    class JVPTracer(Tracer):
        def __init__(self, trace, primal, tangent):
            self._trace = trace
            self.primal = primal
            self.tangent = tangent

        @property
        def aval(self):
            return get_aval(self.primal)

    class JVPTrace(Trace):
        pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))

        def process_primitive(self, primitive, tracers, params):
            primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
            jvp_rule = jvp_rules[primitive]
            primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params)
            return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]

    jvp_rules = {}

Notice both pure and lift package a value into a JVPTracer with the minimal amount of context, which is a zero tangent value.

Let’s add some JVP rules for primitives:

    def add_jvp(primals, tangents):
        (x, y), (x_dot, y_dot) = primals, tangents
        return [x + y], [x_dot + y_dot]
    jvp_rules[add_p] = add_jvp

    def mul_jvp(primals, tangents):
        (x, y), (x_dot, y_dot) = primals, tangents
        return [x * y], [x_dot * y + x * y_dot]
    jvp_rules[mul_p] = mul_jvp

    def sin_jvp(primals, tangents):
        (x,), (x_dot,) = primals, tangents
        return [sin(x)], [cos(x) * x_dot]
    jvp_rules[sin_p] = sin_jvp

    def cos_jvp(primals, tangents):
        (x,), (x_dot,) = primals, tangents
        return [cos(x)], [-sin(x) * x_dot]
    jvp_rules[cos_p] = cos_jvp

    def neg_jvp(primals, tangents):
        (x,), (x_dot,) = primals, tangents
        return [neg(x)], [neg(x_dot)]
    jvp_rules[neg_p] = neg_jvp

    def reduce_sum_jvp(primals, tangents, *, axis):
        (x,), (x_dot,) = primals, tangents
        return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]
    jvp_rules[reduce_sum_p] = reduce_sum_jvp

    def greater_jvp(primals, tangents):
        (x, y), _ = primals, tangents
        out_primal = greater(x, y)
        return [out_primal], [zeros_like(out_primal)]
    jvp_rules[greater_p] = greater_jvp

    def less_jvp(primals, tangents):
        (x, y), _ = primals, tangents
        out_primal = less(x, y)
        return [out_primal], [zeros_like(out_primal)]
    jvp_rules[less_p] = less_jvp
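One way to sanity-check JVP rules like these is to compare them with a central finite difference. The sketch below copies `mul_jvp` and `sin_jvp` into a standalone form that calls NumPy directly instead of the traced `sin`/`cos` wrappers. That substitution is an assumption made so the check runs without the interpreter machinery, and `central_difference` is a hypothetical helper.

```python
import numpy as np

# Standalone copies of two JVP rules, operating on plain NumPy values.
def mul_jvp(primals, tangents):
    (x, y), (x_dot, y_dot) = primals, tangents
    return [x * y], [x_dot * y + x * y_dot]

def sin_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return [np.sin(x)], [np.cos(x) * x_dot]

def central_difference(fn, x, v, eps=1e-6):
    # Directional derivative of fn at x along v, estimated numerically.
    return (fn(x + eps * v) - fn(x - eps * v)) / (2 * eps)

x, y = np.array([0.5, -1.0]), np.array([2.0, 3.0])
x_dot, y_dot = np.array([1.0, 0.0]), np.array([0.5, 2.0])

_, (t_mul,) = mul_jvp((x, y), (x_dot, y_dot))
# Perturb both inputs at once by stacking them along a new leading axis.
fd_mul = central_difference(lambda p: p[0] * p[1],
                            np.stack([x, y]), np.stack([x_dot, y_dot]))

_, (t_sin,) = sin_jvp((x,), (x_dot,))
fd_sin = central_difference(np.sin, x, x_dot)

print(np.allclose(t_mul, fd_mul), np.allclose(t_sin, fd_sin))
```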

Finally, we add a transformation API to kick off the trace:

    def jvp_v1(f, primals, tangents):
        with new_main(JVPTrace) as main:
            trace = JVPTrace(main)
            tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
            out = f(*tracers_in)
            tracer_out = full_raise(trace, out)
            primal_out, tangent_out = tracer_out.primal, tracer_out.tangent
        return primal_out, tangent_out

And with that, we can differentiate!

    x = 3.0
    y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))
    print(sin_deriv_at_3)
    print(cos(3.0))

    -0.9899924966004454
    -0.9899924966004454

    def f(x):
        y = sin(x) * 2.
        z = - y + x
        return z

    x, xdot = 3., 1.
    y, ydot = jvp_v1(f, (x,), (xdot,))
    print(y)
    print(ydot)

    2.7177599838802657
    2.979984993200891

    def deriv(f):
        return lambda x: jvp_v1(f, (x,), (1.,))[1]

    print(deriv(sin)(3.))
    print(deriv(deriv(sin))(3.))
    print(deriv(deriv(deriv(sin)))(3.))
    print(deriv(deriv(deriv(deriv(sin))))(3.))

    -0.9899924966004454
    -0.1411200080598672
    0.9899924966004454
    0.1411200080598672

    def f(x):
        if x > 0.:  # Python control flow
            return 2. * x
        else:
            return x

    print(deriv(f)(3.))
    print(deriv(f)(-3.))

    2.0
    1.0

Pytrees and flattening user functions’ inputs and outputs

A limitation with jvp_v1 is that it assumes the user function accepts arrays as positional arguments and produces a single array as output. What if it produced a list as output? Or accepted nested containers as inputs? It would be a pain to deal with all the possible containers in inputs and outputs at every layer of the stack. Instead, we can wrap the user function so that the wrapped version accepts arrays as inputs and returns a flat list of arrays as output. The wrapper just needs to unflatten its input, call the user function, and flatten the output.

Here’s how we’d like to write jvp, assuming the user always gives us functions that take arrays as inputs and produce a flat list of arrays as outputs:

    def jvp_flat(f, primals, tangents):
        with new_main(JVPTrace) as main:
            trace = JVPTrace(main)
            tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
            outs = f(*tracers_in)
            tracers_out = [full_raise(trace, out) for out in outs]
            primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)
        return primals_out, tangents_out

To support user functions that have arbitrary containers in the inputs and outputs, here’s how we’d write the user-facing jvp wrapper:

    def jvp(f, primals, tangents):
        primals_flat, in_tree = tree_flatten(primals)
        tangents_flat, in_tree2 = tree_flatten(tangents)
        if in_tree != in_tree2:
            raise TypeError
        f, out_tree = flatten_fun(f, in_tree)
        primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)
        primals_out = tree_unflatten(out_tree(), primals_out_flat)
        tangents_out = tree_unflatten(out_tree(), tangents_out_flat)
        return primals_out, tangents_out

Notice that we had to plumb the tree structure of the user function output back to the caller of flatten_fun. That information isn’t available until we actually run the user function, so flatten_fun just returns a reference to a mutable cell, represented as a thunk. These side-effects are safe because we always run the user function exactly once. (This safe regime is the reason for the “linear” name in linear_util.py, in the sense of linear types.)
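A minimal sketch of this thunk pattern, independent of pytrees, is shown below. The hypothetical `record_output_length` wrapper stores a fact about the output (here just its length, standing in for the output tree structure) in a mutable cell. That fact can only be read after the wrapped function has run, and the wrapper refuses to run twice.

```python
def record_output_length(f):
    """Hypothetical wrapper: returns (wrapped_f, thunk).

    The thunk reports len() of f's output, which is only known after
    wrapped_f has run, mirroring how flatten_fun exposes out_tree.
    """
    cell = []  # mutable cell shared by the wrapper and the thunk

    def wrapped(*args):
        out = f(*args)
        assert not cell, "wrapped function may only be run once"
        cell.append(len(out))
        return out

    def thunk():
        assert cell, "output structure not available until wrapped runs"
        return cell[0]

    return wrapped, thunk

g, out_len = record_output_length(lambda x: [x, x + 1, x + 2])
g(10)
print(out_len())  # 3
```

The single-run assertion is the “linear” discipline in miniature: because the wrapped function runs exactly once, there is exactly one value for the thunk to report.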

All that remains is to write tree_flatten, tree_unflatten, and flatten_fun.

    def flatten_fun(f, in_tree):
        store = Store()

        def flat_fun(*args_flat):
            pytree_args = tree_unflatten(in_tree, args_flat)
            out = f(*pytree_args)
            out_flat, out_tree = tree_flatten(out)
            store.set_value(out_tree)
            return out_flat

        return flat_fun, store

    class Empty: pass
    empty = Empty()

    class Store:
        val = empty

        def set_value(self, val):
            assert self.val is empty
            self.val = val

        def __call__(self):
            return self.val

    from collections.abc import Callable, Hashable, Iterable, Iterator
    import itertools as it

    class NodeType(NamedTuple):
        name: str
        to_iterable: Callable
        from_iterable: Callable

    def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable
                             ) -> None:
        node_types[ty] = NodeType(str(ty), to_iter, from_iter)

    node_types: dict[type, NodeType] = {}
    register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs))
    register_pytree_node(list,  lambda l: (None, l), lambda _, xs:  list(xs))
    register_pytree_node(dict,
                         lambda d: map(tuple, unzip2(sorted(d.items()))),
                         lambda keys, vals: dict(zip(keys, vals)))

    class PyTreeDef(NamedTuple):
        node_type: NodeType
        node_metadata: Hashable
        child_treedefs: tuple['PyTreeDef', ...]

    class Leaf: pass
    leaf = Leaf()

    def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
        children_iter, treedef = _tree_flatten(x)
        return list(children_iter), treedef

    def _tree_flatten(x: Any) -> tuple[Iterable, PyTreeDef]:
        node_type = node_types.get(type(x))
        if node_type:
            node_metadata, children = node_type.to_iterable(x)
            children_flat, child_trees = unzip2(map(_tree_flatten, children))
            flattened = it.chain.from_iterable(children_flat)
            return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees))
        else:
            return [x], leaf

    def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any:
        return _tree_unflatten(treedef, iter(xs))

    def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any:
        if treedef is leaf:
            return next(xs)
        else:
            children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)
            return treedef.node_type.from_iterable(treedef.node_metadata, children)

With this pytree-handling jvp implementation, we can now handle arbitrary input and output containers. That’ll come in handy with future transformations too!

    def f(x):
        y = sin(x) * 2.
        z = - y + x
        return {'hi': z, 'there': [x, y]}

    x, xdot = 3., 1.
    y, ydot = jvp(f, (x,), (xdot,))
    print(y)
    print(ydot)

    {'hi': np.float64(2.7177599838802657), 'there': [3.0, np.float64(0.2822400161197344)]}
    {'hi': np.float64(2.979984993200891), 'there': [1.0, np.float64(-1.9799849932008908)]}

Vectorized batching with vmap

First, a couple of helper functions: one for producing mapped abstract values from unmapped ones (by removing an axis), and one for moving batch dimensions around:

    def mapped_aval(batch_dim, aval):
        shape = list(aval.shape)
        del shape[batch_dim]
        return ShapedArray(tuple(shape), aval.dtype)

    def move_batch_axis(axis_size, src, dst, x):
        if src is not_mapped:
            target_shape = list(np.shape(x))
            target_shape.insert(dst, axis_size)
            return broadcast(x, target_shape, [dst])
        elif src == dst:
            return x
        else:
            return moveaxis(x, src, dst)

    def moveaxis(x, src: int, dst: int):
        perm = [i for i in range(np.ndim(x)) if i != src]
        perm.insert(dst, src)
        return transpose(x, perm)
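To check the axis bookkeeping outside the tracing machinery, here is a NumPy-only sketch of the same two helpers. It assumes `None` as the “no batch axis” sentinel. It uses `np.expand_dims` plus `np.broadcast_to` in place of the `broadcast` primitive (mirroring `broadcast_impl`) and `np.transpose` in place of `transpose`.

```python
import numpy as np

not_mapped = None  # assumption: None as the "no batch axis" sentinel here

def moveaxis(x, src: int, dst: int):
    # Permutation that pulls axis `src` out and reinserts it at `dst`.
    perm = [i for i in range(np.ndim(x)) if i != src]
    perm.insert(dst, src)
    return np.transpose(x, perm)

def move_batch_axis(axis_size, src, dst, x):
    if src is not_mapped:
        # Unbatched value: materialize a batch axis of length axis_size at dst.
        target_shape = list(np.shape(x))
        target_shape.insert(dst, axis_size)
        return np.broadcast_to(np.expand_dims(x, dst), target_shape)
    elif src == dst:
        return x
    else:
        return moveaxis(x, src, dst)

x = np.arange(24.).reshape(2, 3, 4)
print(move_batch_axis(3, 1, 0, x).shape)           # (3, 2, 4)
print(move_batch_axis(5, not_mapped, 2, x).shape)  # (2, 3, 5, 4)
```

The permutation built by `moveaxis` agrees with `np.moveaxis`, which makes a convenient reference when testing.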

The Tracer for vectorized batching carries a batched value and an optional integer indicating which axis (if any) is the batch axis.

    from typing import Union

    class NotMapped: pass
    not_mapped = NotMapped()

    BatchAxis = Union[NotMapped, int]

    class BatchTracer(Tracer):
        def __init__(self, trace, val, batch_dim: BatchAxis):
            self._trace = trace
            self.val = val
            self.batch_dim = batch_dim

        @property
        def aval(self):
            if self.batch_dim is not_mapped:
                return get_aval(self.val)
            else:
                return mapped_aval(self.batch_dim, get_aval(self.val))

        def full_lower(self):
            if self.batch_dim is not_mapped:
                return full_lower(self.val)
            else:
                return self

    class BatchTrace(Trace):
        pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)

        def process_primitive(self, primitive, tracers, params):
            vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)
            vmap_rule = vmap_rules[primitive]
            val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params)
            return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)]

        @property
        def axis_size(self):
            return self.main.global_data

    vmap_rules = {}

Here we’ve implemented the optional Tracer.full_lower method, which lets us peel off a batching tracer if it’s not needed because it doesn’t represent a batched value.

For BatchTrace, analogous to JVPTrace, the methods pure and lift just box a value in a BatchTracer with the minimal amount of context, which in this case is a batch_dim taking the sentinel value not_mapped. Notice we use the MainTrace’s interpreter-global data field to store the batch axis size.

Next we can define batching interpreter rules for each primitive:

    from functools import partial

    def binop_batching_rule(op, axis_size, vals_in, dims_in):
        (x, y), (x_bdim, y_bdim) = vals_in, dims_in
        if x_bdim != y_bdim:
            if x_bdim is not_mapped:
                x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
                x_bdim = y_bdim
            else:
                y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
        return [op(x, y)], [x_bdim]
    vmap_rules[add_p] = partial(binop_batching_rule, add)
    vmap_rules[mul_p] = partial(binop_batching_rule, mul)

    def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
        (x,), (x_bdim,) = vals_in, dims_in
        return [op(x)], [x_bdim]
    vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)
    vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)
    vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)

    def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):
        (x,), (x_bdim,) = vals_in, dims_in
        new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
        out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
        return [reduce_sum(x, new_axis)], [out_bdim]
    vmap_rules[reduce_sum_p] = reduce_sum_batching_rule
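The index arithmetic in `reduce_sum_batching_rule` is the subtlest part. In this NumPy-only sketch, the hypothetical `reduce_sum_batched` applies the same two formulas. The loop then compares the result with summing each example separately and stacking the per-example results along the reported output batch axis:

```python
import numpy as np

def reduce_sum_batched(x, x_bdim: int, axis: tuple[int, ...]):
    # `axis` is given in the coordinates of one unbatched example; every
    # axis at or after the batch dimension shifts right by one.
    new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
    # Each reduced axis in front of the batch dim moves the batch dim left.
    out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
    return np.sum(x, new_axis), out_bdim

x = np.random.default_rng(0).normal(size=(2, 3, 4))  # batch dim 1, size 3
for axis in [(0,), (1,), (0, 1)]:
    out, out_bdim = reduce_sum_batched(x, 1, axis)
    # Reference: sum each example separately, then stack along out_bdim.
    ref = np.stack([np.sum(x[:, b, :], axis) for b in range(3)], axis=out_bdim)
    print(axis, out_bdim, np.allclose(out, ref))
```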

Finally, we add a transformation API to kick off the trace:

    def vmap_flat(f, in_axes, *args):
        axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes)
                      if ax is not not_mapped}
        with new_main(BatchTrace, axis_size) as main:
            trace = BatchTrace(main)
            tracers_in = [BatchTracer(trace, x, ax) if ax is not not_mapped else x
                          for x, ax in zip(args, in_axes)]
            outs = f(*tracers_in)
            tracers_out = [full_raise(trace, out) for out in outs]
            vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)
        outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out)
                           for val_out, bdim in zip(vals_out, bdims_out)]
        return outs_transposed

    def vmap(f, in_axes):
        def batched_f(*args):
            args_flat, in_tree = tree_flatten(args)
            in_axes_flat, in_tree2 = tree_flatten(in_axes)
            if in_tree != in_tree2:
                raise TypeError
            f_flat, out_tree = flatten_fun(f, in_tree)
            outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)
            return tree_unflatten(out_tree(), outs_flat)
        return batched_f

    def add_one_to_a_scalar(scalar):
        assert np.ndim(scalar) == 0
        return 1 + scalar

    vector_in = np.arange(3.)
    vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in)

    print(vector_in)
    print(vector_out)

    [0. 1. 2.]
    [1. 2. 3.]

    def jacfwd(f, x):
        pushfwd = lambda v: jvp(f, (x,), (v,))[1]
        vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)
        return vmap(pushfwd, (0,))(vecs_in)

    def f(x):
        return sin(x)

    jacfwd(f, np.arange(3.))

    array([[ 1.        ,  0.        , -0.        ],
           [ 0.        ,  0.54030231, -0.        ],
           [ 0.        ,  0.        , -0.41614684]])

That’s it for jvp and vmap!

Part 2: Jaxprs

The next transformations on the horizon are jit for just-in-time compilation and vjp for reverse-mode autodiff. (grad is just a small wrapper around vjp.) Whereas jvp and vmap only needed each Tracer to carry a little bit of extra context, for both jit and vjp we need much richer context: we need to represent programs. That is, we need jaxprs!

Jaxprs are JAX’s internal intermediate representation of programs. They are explicitly typed, functional, first-order, and in A-normal form (ANF). We need a program representation for jit because the purpose of jit is to stage computation out of Python. For any computation we want to stage out, we need to be able to represent it as data, and build it up as we trace a Python function. Similarly, vjp needs a way to represent the computation for the backward pass of reverse-mode autodiff. We use the same jaxpr program representation for both needs.
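For a sense of what ANF means here, the sketch below encodes the running example `f` by hand as a list of equations. Every intermediate has a name, and every primitive takes only variables or literals. The tuple encoding and `eval_anf` are hypothetical simplifications, not the `Jaxpr` classes defined below.

```python
import math

# f(x) = -(sin(x) * 2.) + x in ANF.
# Hypothetical encoding: (out_var, primitive_name, [input atoms]);
# strings are variables and anything else is a literal.
eqns = [
    ("b", "sin", ["a"]),
    ("c", "mul", ["b", 2.0]),
    ("d", "neg", ["c"]),
    ("e", "add", ["d", "a"]),
]
in_binders, outs = ["a"], ["e"]

impls = {"sin": math.sin, "mul": lambda x, y: x * y,
         "neg": lambda x: -x, "add": lambda x, y: x + y}

def eval_anf(args):
    env = dict(zip(in_binders, args))
    read = lambda atom: env[atom] if isinstance(atom, str) else atom
    for out, prim, inputs in eqns:
        assert out not in env  # single assignment
        env[out] = impls[prim](*map(read, inputs))
    return [read(v) for v in outs]

print(eval_anf([3.0]))
```

Because the program is plain data, other interpreters can walk the same equation list, for example to differentiate it or to compile it.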

(Building a program representation is the most free kind of trace-transformation, and so except for issues around handling native Python control flow, any transformation could be implemented by first tracing to a jaxpr and then interpreting the jaxpr.)

Jaxpr data structures

The jaxpr term syntax is roughly:

    jaxpr ::=
      { lambda <binder> , ... .
        let <eqn>
            ...
        in ( <atom> , ... ) }

    binder ::= <var>:<array_type>
    var ::= a | b | c | ...
    atom ::= <var> | <literal>
    literal ::= <int32> | <int64> | <float32> | <float64>

    eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ...

The syntax of types is:

    jaxpr_type ::= [ <array_type> , ... ] -> [ <array_type> , ... ]
    array_type ::= <dtype>[<shape>]
    dtype ::= f32 | f64 | i32 | i64
    shape ::= <int> , ...

How do we represent these as Python data structures? We reuse ShapedArrays to represent types, and we can represent the term syntax with a few Python structs:

    class Var:
        aval: ShapedArray
        def __init__(self, aval): self.aval = aval

    class Lit:
        val: Any
        aval: ShapedArray

        def __init__(self, val):
            self.aval = aval = raise_to_shaped(get_aval(val))
            self.val = np.array(val, aval.dtype)

    Atom = Union[Var, Lit]

    class JaxprEqn(NamedTuple):
        primitive: Primitive
        inputs: list[Atom]
        params: dict[str, Any]
        out_binders: list[Var]

    class Jaxpr(NamedTuple):
        in_binders: list[Var]
        eqns: list[JaxprEqn]
        outs: list[Atom]

        def __hash__(self): return id(self)
        __eq__ = op.is_

    def raise_to_shaped(aval):
        return ShapedArray(aval.shape, aval.dtype)

Type-checking a jaxpr involves checking that there are no unbound variables, that variables are only bound once, and that for each equation the type of the primitive application matches the type of the output binders.

    class JaxprType(NamedTuple):
        in_types: list[ShapedArray]
        out_types: list[ShapedArray]

        def __repr__(self):
            in_types = ', '.join(aval.str_short() for aval in self.in_types)
            out_types = ', '.join(aval.str_short() for aval in self.out_types)
            return f'({in_types}) -> ({out_types})'

    def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:
        env: set[Var] = set()

        for v in jaxpr.in_binders:
            if v in env:
                raise TypeError
            env.add(v)

        for eqn in jaxpr.eqns:
            in_types = [typecheck_atom(env, x) for x in eqn.inputs]
            out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)
            for out_binder, out_type in zip(eqn.out_binders, out_types):
                if not out_type == out_binder.aval:
                    raise TypeError
            for out_binder in eqn.out_binders:
                if out_binder in env:
                    raise TypeError
                env.add(out_binder)

        in_types = [v.aval for v in jaxpr.in_binders]
        out_types = [typecheck_atom(env, x) for x in jaxpr.outs]
        return JaxprType(in_types, out_types)

    def typecheck_atom(env: set[Var], x: Atom) -> ShapedArray:
        if isinstance(x, Var):
            if x not in env:
                raise TypeError("unbound variable")
            return x.aval
        elif isinstance(x, Lit):
            return raise_to_shaped(get_aval(x.val))
        else:
            assert False
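The scoping half of these checks does not depend on types at all. This standalone sketch uses a hypothetical tuple encoding, `(out_var, primitive_name, input_atoms)`, with string variables and non-string literals. It checks only that every variable is bound exactly once and bound before use. The type comparison against `abstract_eval_rules` is omitted.

```python
def check_bindings(in_binders, eqns, outs):
    """Hypothetical scope checker: raises TypeError on a scoping error.

    Non-string atoms are treated as literals; types are not checked.
    """
    env = set()

    def use(x):
        if isinstance(x, str) and x not in env:
            raise TypeError(f"unbound variable {x}")

    def bind_var(v):
        if v in env:
            raise TypeError(f"{v} bound twice")
        env.add(v)

    for v in in_binders:
        bind_var(v)
    for out, _prim, inputs in eqns:
        for x in inputs:
            use(x)      # inputs must already be in scope
        bind_var(out)   # then the output enters scope, exactly once
    for x in outs:
        use(x)
    return True

ok = check_bindings(["a"], [("b", "sin", ["a"]), ("c", "add", ["b", 1.0])], ["c"])
print(ok)  # True
```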

We can apply the function represented by a jaxpr to arguments with a simple interpreter.

    def eval_jaxpr(jaxpr: Jaxpr, args: list[Any]) -> list[Any]:
        env: dict[Var, Any] = {}

        def read(x: Atom) -> Any:
            return env[x] if type(x) is Var else x.val

        def write(v: Var, val: Any) -> None:
            assert v not in env  # single-assignment
            env[v] = val

        # `map` is the list-returning version defined earlier, so these
        # calls run eagerly rather than producing lazy iterators.
        map(write, jaxpr.in_binders, args)
        for eqn in jaxpr.eqns:
            in_vals = map(read, eqn.inputs)
            outs = bind(eqn.primitive, *in_vals, **eqn.params)
            map(write, eqn.out_binders, outs)
        return map(read, jaxpr.outs)

    def jaxpr_as_fun(jaxpr: Jaxpr):
        return lambda *args: eval_jaxpr(jaxpr, args)

By using bind in the interpreter, this interpreter itself is traceable.

Building jaxprs with tracing

Now that we have jaxprs as a data structure, we need ways to produce these from tracing Python code. In general there are two variants of how we trace to a jaxpr; jit uses one and vjp uses the other. We’ll start with the one used by jit, which is also used by control flow primitives like lax.cond, lax.while_loop, and lax.scan.

    def split_list(lst: list[Any], n: int) -> tuple[list[Any], list[Any]]:
        assert 0 <= n <= len(lst)
        return lst[:n], lst[n:]

    def partition_list(bs: list[bool], l: list[Any]) -> tuple[list[Any], list[Any]]:
        assert len(bs) == len(l)
        lists = lst1, lst2 = [], []
        for b, x in zip(bs, l):
            lists[b].append(x)  # False selects lst1, True selects lst2
        return lst1, lst2

    # NB: the analogous class in JAX is called 'DynamicJaxprTracer'
    class JaxprTracer(Tracer):
        __slots__ = ['aval']
        aval: ShapedArray

        def __init__(self, trace, aval):
            self._trace = trace
            self.aval = aval

    # NB: the analogous class in JAX is called 'DynamicJaxprTrace'
    class JaxprTrace(Trace):
        def new_arg(self, aval: ShapedArray) -> JaxprTracer:
            aval = raise_to_shaped(aval)
            tracer = self.builder.new_tracer(self, aval)
            self.builder.tracer_to_var[id(tracer)] = Var(aval)
            return tracer

        def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:
            tracer = self.builder.const_tracers.get(id(val))
            if tracer is None:
                tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val)))
                self.builder.add_const(tracer, val)
            return tracer
        pure = lift = get_or_make_const_tracer

        def process_primitive(self, primitive, tracers, params):
            avals_in = [t.aval for t in tracers]
            avals_out = abstract_eval_rules[primitive](*avals_in, **params)
            out_tracers = [self.builder.new_tracer(self, a) for a in avals_out]
            inputs = [self.builder.getvar(t) for t in tracers]
            outvars = [self.builder.add_var(t) for t in out_tracers]
            self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars))
            return out_tracers

        @property
        def builder(self):
            return self.main.global_data

    # NB: in JAX, we instead attach abstract eval rules to Primitive instances
    abstract_eval_rules = {}

Notice that we keep as interpreter-global data a builder object, which keeps track of variables, constants, and eqns as we build up the jaxpr.

classJaxprBuilder:eqns:list[JaxprEqn]tracer_to_var:dict[int,Var]const_tracers:dict[int,JaxprTracer]constvals:dict[Var,Any]tracers:list[JaxprTracer]def__init__(self):self.eqns=[]self.tracer_to_var={}self.const_tracers={}self.constvals={}self.tracers=[]defnew_tracer(self,trace:JaxprTrace,aval:ShapedArray)->JaxprTracer:tracer=JaxprTracer(trace,aval)self.tracers.append(tracer)returntracerdefadd_eqn(self,eqn:JaxprEqn)->None:self.eqns.append(eqn)defadd_var(self,tracer:JaxprTracer)->Var:assertid(tracer)notinself.tracer_to_varvar=self.tracer_to_var[id(tracer)]=Var(tracer.aval)returnvardefgetvar(self,tracer:JaxprTracer)->Var:var=self.tracer_to_var.get(id(tracer))assertvarisnotNonereturnvardefadd_const(self,tracer:JaxprTracer,val:Any)->Var:var=self.add_var(tracer)self.const_tracers[id(val)]=tracerself.constvals[var]=valreturnvardefbuild(self,in_tracers:list[JaxprTracer],out_tracers:list[JaxprTracer])->tuple[Jaxpr,list[Any]]:constvars,constvals=unzip2(self.constvals.items())t2v=lambdat:self.tracer_to_var[id(t)]in_binders=constvars+[t2v(t)fortinin_tracers]out_vars=[t2v(t)fortinout_tracers]jaxpr=Jaxpr(in_binders,self.eqns,out_vars)typecheck_jaxpr(jaxpr)jaxpr,constvals=_inline_literals(jaxpr,constvals)returnjaxpr,constvals
```python
def _inline_literals(jaxpr: Jaxpr, consts: list[Any]) -> tuple[Jaxpr, list[Any]]:
  const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
  scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]
  new_const_binders, lit_binders = partition_list(scalars, const_binders)
  new_consts, lit_vals = partition_list(scalars, consts)
  literals = dict(zip(lit_binders, map(Lit, lit_vals)))
  new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],
                       eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]
  new_outs = [literals.get(x, x) for x in jaxpr.outs]
  new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)
  typecheck_jaxpr(new_jaxpr)
  return new_jaxpr, new_consts
```

The rules we need for JaxprTrace.process_primitive are essentially typing rules for primitive applications: given the primitive, its parameters, and types for the inputs, the rule must produce a type for the output, which is then packaged with the output JaxprTracer. We can use abstract evaluation rules for this same purpose, even though they can be more general (since abstract evaluation rules must accept ConcreteArray inputs, and since they need only return an upper bound on the set of possible outputs, they can produce ConcreteArray outputs as well). We'll reuse these abstract evaluation rules for the other jaxpr-producing trace machinery, where the potential extra generality is useful.

```python
def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
    raise TypeError
  if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError
  return [ShapedArray(x.shape, x.dtype)]

abstract_eval_rules[add_p] = binop_abstract_eval
abstract_eval_rules[mul_p] = binop_abstract_eval

def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
    raise TypeError
  if x.shape != y.shape: raise TypeError
  return [ShapedArray(x.shape, np.dtype('bool'))]
abstract_eval_rules[greater_p] = compare_abstract_eval
abstract_eval_rules[less_p] = compare_abstract_eval

def vectorized_unop_abstract_eval(x: ShapedArray) -> list[ShapedArray]:
  return [ShapedArray(x.shape, x.dtype)]

abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval

def reduce_sum_abstract_eval(x: ShapedArray, *, axis: tuple[int, ...]
                             ) -> list[ShapedArray]:
  axis_ = set(axis)
  new_shape = [d for i, d in enumerate(x.shape) if i not in axis_]
  return [ShapedArray(tuple(new_shape), x.dtype)]
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval

def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],
                            axes: Sequence[int]) -> list[ShapedArray]:
  return [ShapedArray(tuple(shape), x.dtype)]
abstract_eval_rules[broadcast_p] = broadcast_abstract_eval
```
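To see what these typing rules compute in isolation, here is a standalone sketch. It uses a hypothetical minimal Aval NamedTuple in place of ShapedArray (no raise_to_shaped, no ConcreteArray), and the function names reduce_sum_shape_rule and compare_shape_rule are illustrative only; they mirror the shape logic of reduce_sum_abstract_eval and compare_abstract_eval above:

```python
from typing import NamedTuple

import numpy as np

class Aval(NamedTuple):
  # Hypothetical stand-in for ShapedArray: just a shape and a dtype.
  shape: tuple
  dtype: np.dtype

def reduce_sum_shape_rule(x: Aval, *, axis: tuple) -> list:
  # Drop every reduced axis; the dtype is unchanged.
  axis_ = set(axis)
  new_shape = [d for i, d in enumerate(x.shape) if i not in axis_]
  return [Aval(tuple(new_shape), x.dtype)]

def compare_shape_rule(x: Aval, y: Aval) -> list:
  # Comparisons require matching shapes and always produce booleans.
  if x.shape != y.shape:
    raise TypeError(f'shape mismatch: {x.shape} vs {y.shape}')
  return [Aval(x.shape, np.dtype('bool'))]

f32 = np.dtype('float32')
out, = reduce_sum_shape_rule(Aval((2, 3, 4), f32), axis=(0, 2))
print(out)  # Aval(shape=(3,), dtype=dtype('float32'))
cmp, = compare_shape_rule(Aval((5,), f32), Aval((5,), f32))
```

No array data is touched: each rule maps input types to output types, which is all JaxprTrace.process_primitive needs to create the output tracers.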

To check our implementation of jaxprs, we can add a make_jaxpr transformation and a pretty-printer:

```python
from functools import lru_cache

@lru_cache  # ShapedArrays are hashable
def make_jaxpr_v1(f, *avals_in):
  avals_in, in_tree = tree_flatten(avals_in)
  f, out_tree = flatten_fun(f, in_tree)

  builder = JaxprBuilder()
  with new_main(JaxprTrace, builder) as main:
    trace = JaxprTrace(main)
    tracers_in = [trace.new_arg(aval) for aval in avals_in]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    jaxpr, consts = builder.build(tracers_in, tracers_out)
  return jaxpr, consts, out_tree()
```

```python
from collections import defaultdict
import string

class PPrint:
  lines: list[tuple[int, str]]

  def __init__(self, lines):
    self.lines = lines

  def indent(self, indent: int) -> 'PPrint':
    return PPrint([(indent + orig_indent, s) for orig_indent, s in self.lines])

  def __add__(self, rhs: 'PPrint') -> 'PPrint':
    return PPrint(self.lines + rhs.lines)

  def __rshift__(self, rhs: 'PPrint') -> 'PPrint':
    if not rhs.lines: return self
    if not self.lines: return rhs
    indent, s = self.lines[-1]
    indented_block = rhs.indent(indent + len(s))
    common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]
    return PPrint(self.lines[:-1]
                  + [(indent, common_line)]
                  + indented_block.lines[1:])

  def __str__(self) -> str:
    return '\n'.join(' ' * indent + s for indent, s in self.lines)

def pp(s: Any) -> PPrint:
  return PPrint([(0, line) for line in str(s).splitlines()])

def vcat(ps: list[PPrint]) -> PPrint:
  return sum(ps, pp(''))

def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:
  namegen = (''.join(s) for r in it.count(1)
             for s in it.permutations(string.ascii_lowercase, r))
  names = defaultdict(lambda: next(namegen))
  in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders)
  eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns])
  outs = ', '.join(names[v] if isinstance(v, Var) else str(v.val)
                   for v in jaxpr.outs)
  return (pp(f'{{ lambda {in_binders} .') +
          ((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))

def var_str(names: defaultdict[Var, str], v: Var) -> str:
  return f'{names[v]}:{v.aval.str_short()}'

def pp_eqn(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  rule = pp_rules.get(eqn.primitive)
  if rule:
    return rule(names, eqn)
  else:
    lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
    rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>
           pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
                       for x in eqn.inputs)))
    return lhs >> pp(' = ') >> rhs

def pp_params(params: dict[str, Any]) -> PPrint:
  items = sorted(params.items())
  if items:
    return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ')
  else:
    return pp(' ')

Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))
pp_rules: dict[Primitive, Callable[..., PPrint]] = {}
```
```python
jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))
print(jaxpr)
print(typecheck_jaxpr(jaxpr))
```

```
{ lambda a:float64[] .
  let b:float64[] = mul 2.0 a
  in ( b ) }
(float64[]) -> (float64[])
```

But there’s a limitation here: because of how find_top_trace operates by data dependence, make_jaxpr_v1 can’t stage out all the primitive operations performed by the Python callable it’s given. For example:

```python
jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.))
print(jaxpr)
```

```
{ lambda  .
  let
  in ( 4.0 ) }
```

This is precisely the issue that omnistaging fixed. We want to ensure that the JaxprTrace started by make_jaxpr is always applied, regardless of whether any inputs to bind are boxed in corresponding JaxprTracer instances. We can achieve this by employing the dynamic_trace global defined in Part 1:

```python
@contextmanager
def new_dynamic(main: MainTrace):
  global dynamic_trace
  prev_dynamic_trace, dynamic_trace = dynamic_trace, main
  try:
    yield
  finally:
    dynamic_trace = prev_dynamic_trace

@lru_cache
def make_jaxpr(f: Callable, *avals_in: ShapedArray
               ) -> tuple[Jaxpr, list[Any], PyTreeDef]:
  avals_in, in_tree = tree_flatten(avals_in)
  f, out_tree = flatten_fun(f, in_tree)

  builder = JaxprBuilder()
  with new_main(JaxprTrace, builder) as main:
    with new_dynamic(main):
      trace = JaxprTrace(main)
      tracers_in = [trace.new_arg(aval) for aval in avals_in]
      outs = f(*tracers_in)
      tracers_out = [full_raise(trace, out) for out in outs]
      jaxpr, consts = builder.build(tracers_in, tracers_out)
  return jaxpr, consts, out_tree()

jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.))
print(jaxpr)
```

```
{ lambda  .
  let a:float64[] = mul 2.0 2.0
  in ( a ) }
```

Using dynamic_trace this way is conceptually the same as stashing the current interpreter stack and starting a new one with the JaxprTrace at the bottom. That is, no interpreters lower in the stack than the dynamic_trace are applied (since JaxprTrace.process_primitive doesn’t call bind), though if the Python callable being traced to a jaxpr itself uses transformations then those can be pushed onto the interpreter stack above the JaxprTrace. But temporarily stashing the interpreter stack would break up the system state. The dynamic_trace tag achieves the same goals while keeping the system state simpler.
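The difference between make_jaxpr_v1 and make_jaxpr can be reproduced with a toy that has nothing else from this tutorial in it. In this hypothetical sketch, ToyTracer marks a staged value, toy_mul stages an equation only when an input is a ToyTracer (data-dependence dispatch), and a module-level flag always_stage plays the role of dynamic_trace being set, forcing every operation to be staged:

```python
from contextlib import contextmanager

class ToyTracer:
  # Hypothetical staged value; not the tutorial's Tracer class.
  def __init__(self, name: str):
    self.name = name

staged: list = []      # recorded equations: (primitive, x, y, out)
always_stage = False   # stands in for dynamic_trace being set

def toy_mul(x, y):
  # Stage out if an input is a ToyTracer, or if staging is forced.
  if always_stage or isinstance(x, ToyTracer) or isinstance(y, ToyTracer):
    out = ToyTracer(f'v{len(staged)}')
    staged.append(('mul', x, y, out))
    return out
  return x * y  # evaluated in Python: the result becomes a constant

@contextmanager
def stage_everything():
  # Like new_dynamic: set the flag, and restore the previous value on exit.
  global always_stage
  prev, always_stage = always_stage, True
  try:
    yield
  finally:
    always_stage = prev

def toy_make_jaxpr(f, dynamic: bool):
  staged.clear()
  if dynamic:
    with stage_everything():
      out = f()
  else:
    out = f()
  return list(staged), out

eqns_v1, out_v1 = toy_make_jaxpr(lambda: toy_mul(2., 2.), dynamic=False)
eqns_v2, out_v2 = toy_make_jaxpr(lambda: toy_mul(2., 2.), dynamic=True)
```

As with make_jaxpr_v1, the first call folds the multiplication into the constant 4.0 and records no equations; as with make_jaxpr, the second records one mul equation even though neither input is a tracer.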

That’s it for jaxprs! With jaxprs in hand, we can implement the remaining major JAX features.

Part 3: jit, simplified

While jit has a transformation-like API in that it accepts a Python callable as an argument, under the hood it’s really a higher-order primitive rather than a transformation. A primitive is higher-order when it’s parameterized by a function.

On-the-fly (“final style”) and staged (“initial style”) processing

There are two options for how to handle higher-order primitives. Each requires a different approach to tracing and engenders different tradeoffs:

  1. On-the-fly processing, where bind takes a Python callable as an argument. We defer forming a jaxpr until as late as possible, namely until we’re running the final interpreter at the bottom of the interpreter stack. That way we can swap a JaxprTrace in at the bottom of the interpreter stack and thus stage out rather than execute all primitive operations. With this approach, transformations in the stack get applied as we execute the Python callable as usual. This approach can be very tricky to implement, but it’s as general as possible because it allows higher-order primitives not to raise the abstraction level of their arguments and thus allows data-dependent Python control flow. We refer to this approach as using a “final-style higher-order primitive” employing the discharge-at-tracing-time “final-style transformations” we’ve used so far.

  2. Staged processing, where bind takes a jaxpr as an argument. Before we call bind, in the primitive wrapper we can just use make_jaxpr to form a jaxpr up-front and be done with the Python callable entirely. In this case, make_jaxpr puts its JaxprTrace at the top of the interpreter stack, and no transformations lower in the stack, which might enter via closed-over Tracers, are applied to the Python callable as we trace it. (Transformations applied within the Python callable are applied as usual, being added to the stack above the JaxprTrace.) Instead, the transformations lower in the stack are later applied to the call primitive, and the call primitive’s rules must then transform the jaxpr itself. Because we trace to a jaxpr up-front, this approach can’t support data-dependent Python control flow, but it is more straightforward to implement. We refer to this kind of higher-order primitive as an “initial-style higher-order primitive”, and say that its jaxpr-processing transformation rules are “initial-style transformation rules.”

The latter approach suits jit because we don’t need to support data-dependent Python control flow in the user-provided Python callable, as the whole purpose of jit is to stage computation out of Python to be executed by XLA. (In contrast, custom_jvp is a higher-order primitive in which we want to support data-dependent Python control flow.)

Historically, we started using the “initial-style” and “final-style” terminology after reading the typed tagless final interpreters paper, and jokingly referring to JAX as an implementation of “untyped tagful final interpreters.” We don’t claim to carry over (or understand) any deep meaning behind these terms; we loosely use “initial style” to mean “build an AST and then transform it”, and we use “final style” to mean “transform as we trace.” But it’s just imprecise yet sticky jargon.

With the initial-style approach, here’s the user-facing jit wrapper:

```python
def jit(f):
  def f_jitted(*args):
    avals_in = [raise_to_shaped(get_aval(x)) for x in args]
    jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)
    outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts))
    return tree_unflatten(out_tree, outs)
  return f_jitted

xla_call_p = Primitive('xla_call')
```

With any new primitive, we need to give it transformation rules, starting with its evaluation rule. When we evaluate an application of the xla_call primitive, we want to stage out the computation to XLA. That involves translating the jaxpr to an XLA HLO program, transferring the argument values to the XLA device, executing the XLA program, and transferring back the results. We’ll cache the XLA HLO compilation so that for each jitted function it only needs to be performed once per argument shape and dtype signature.
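The caching granularity described above can be sketched on its own. In this hypothetical example, compile_for stands in for lowering and compilation and merely counts how often it runs, and the cache key is built from each argument's shape and dtype rather than its value, so new values with the same signature reuse the earlier result:

```python
from functools import lru_cache

import numpy as np

compile_count = 0

def signature(*args):
  # The cache key: shape and dtype of every argument, not its value.
  return tuple((np.shape(x), np.result_type(x).name) for x in args)

@lru_cache
def compile_for(fn_name: str, sig: tuple):
  # Hypothetical stand-in for lowering plus XLA compilation.
  global compile_count
  compile_count += 1
  return f'executable for {fn_name} at {sig}'

def call(fn_name, *args):
  return compile_for(fn_name, signature(*args))

call('f', 3., 4.)
call('f', 4., 5.)          # same shapes and dtypes: cache hit
call('f', np.ones(3), 4.)  # new shape: compiles again
```

In the tutorial the same effect comes from two lru_cache layers: make_jaxpr is keyed on the input avals, and xla_callable is keyed on the identity of the resulting jaxpr and its constants.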

First, some utilities.

```python
class IDHashable:
  val: Any

  def __init__(self, val):
    self.val = val

  def __hash__(self) -> int:
    return id(self.val)

  def __eq__(self, other):
    return type(other) is IDHashable and id(self.val) == id(other.val)
```
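IDHashable exists because lru_cache needs hashable arguments, while NumPy arrays (for example, jaxpr constants) are unhashable. A standalone sketch, with the class repeated so it runs on its own and a hypothetical cached_total function in place of xla_callable, shows that the cache now keys on object identity: a fresh wrapper around the same array hits the cache, while an equal but distinct array misses it:

```python
from functools import lru_cache

import numpy as np

class IDHashable:
  # Same as above: hash and compare by identity of the wrapped object.
  def __init__(self, val):
    self.val = val
  def __hash__(self) -> int:
    return id(self.val)
  def __eq__(self, other):
    return type(other) is IDHashable and id(self.val) == id(other.val)

calls = 0

@lru_cache
def cached_total(h: IDHashable) -> float:
  # Hypothetical cached function; counts how often it actually runs.
  global calls
  calls += 1
  return float(np.sum(h.val))

a = np.arange(3.)
b = np.arange(3.)  # equal contents, but a different object

try:
  hash(a)
except TypeError:
  print('ndarray is unhashable')

cached_total(IDHashable(a))
cached_total(IDHashable(a))  # fresh wrapper, same object: cache hit
cached_total(IDHashable(b))  # different object: cache miss
```

Because lru_cache holds a strong reference to each key, a cached object cannot be garbage collected, and so its id cannot be reused by another object, while its entry remains in the cache.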

Next, we’ll define the evaluation rule for xla_call:

```python
import io

from jax.extend.mlir import ir
from jax.extend.mlir.dialects import func
from jax.extend.mlir.dialects import stablehlo as hlo
from jax._src import xla_bridge as xb

class MlirContext(NamedTuple):
  module: ir.Module
  symbol_table: ir.SymbolTable

def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
  consts, args = args[:num_consts], args[num_consts:]
  hashable_consts = tuple(map(IDHashable, consts))
  execute = xla_callable(IDHashable(jaxpr), hashable_consts)
  return execute(*args)
impl_rules[xla_call_p] = xla_call_impl

@lru_cache
def xla_callable(hashable_jaxpr: IDHashable,
                 hashable_consts: tuple[IDHashable, ...]):
  jaxpr: Jaxpr = hashable_jaxpr.val
  typecheck_jaxpr(jaxpr)
  consts = [x.val for x in hashable_consts]
  in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]

  with ir.Context() as ctx, ir.Location.unknown(ctx):
    hlo.register_dialect(ctx)
    m = ir.Module.create()
    c = MlirContext(m, ir.SymbolTable(m.operation))

    with ir.InsertionPoint(c.module.body):
      @func.func(*(aval_to_ir_type(aval) for aval in in_avals))
      def main(*params):
        return jaxpr_subcomp(c, jaxpr, _hlo_consts(consts) + params)

    # Serialize the module and hand it to the backend for compilation.
    output = io.StringIO()
    c.module.operation.print(file=output)
    backend = xb.get_backend(None)
    compiled = backend.compile_and_load(output.getvalue(), backend.devices()[:1])

  return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])

def _mlir_dtype(dtype: np.dtype) -> ir.Type:
  if np.issubdtype(dtype, np.signedinteger):
    return ir.IntegerType.get_signless(np.iinfo(dtype).bits)
  elif dtype == np.float32:
    return ir.F32Type.get()
  elif dtype == np.float64:
    return ir.F64Type.get()
  else:
    raise NotImplementedError(f"MLIR conversion not implemented for {dtype}")

def aval_to_ir_type(aval: ShapedArray) -> ir.Type:
  return ir.RankedTensorType.get(aval.shape, _mlir_dtype(aval.dtype))

def _hlo_const(x: Any) -> ir.Value:
  a = np.asarray(x)
  if a.dtype == np.bool_:
    return hlo.constant(ir.DenseElementsAttr.get(
      np.array(a, np.bool_), type=ir.IntegerType.get_signless(1),
      shape=a.shape))
  else:
    return hlo.constant(ir.DenseElementsAttr.get(a))

def _hlo_consts(consts: list[Any]) -> tuple[ir.Value, ...]:
  # Emit each distinct constant object once, then reuse its ir.Value.
  unique_consts = {id(cnst): cnst for cnst in consts}
  ir_consts = {id_: _hlo_const(cnst) for id_, cnst in unique_consts.items()}
  return tuple(ir_consts[id(cnst)] for cnst in consts)
```

The main action is in xla_callable, which compiles a jaxpr into an XLA HLO program using jaxpr_subcomp, then returns a callable which executes the compiled program:

```python
def jaxpr_subcomp(c: MlirContext, jaxpr: Jaxpr, args: list[ir.Value]) -> list[ir.Value]:
  env: dict[Var, ir.Value] = {}

  def read(x: Atom) -> ir.Value:
    return env[x] if type(x) is Var else _hlo_const(np.asarray(x.val))

  def write(v: Var, val: ir.Value) -> None:
    env[v] = val

  # NB: `map` here is the list-returning version defined earlier, so
  # in_vals can be iterated more than once below.
  map(write, jaxpr.in_binders, args)
  for eqn in jaxpr.eqns:
    in_avals = [x.aval for x in eqn.inputs]
    in_vals = map(read, eqn.inputs)
    out_avals = [x.aval for x in eqn.out_binders]
    rule = hlo_translations[eqn.primitive]
    assert all(isinstance(v, ir.Value) for v in in_vals), in_vals
    out_vals = rule(c, in_avals, out_avals, in_vals, **eqn.params)
    assert all(isinstance(v, ir.Value) for v in out_vals), out_vals
    map(write, eqn.out_binders, out_vals)
  return map(read, jaxpr.outs)

def execute_compiled(compiled, out_avals, *args):
  input_bufs = [input_handlers[type(x)](x) for x in args]
  out_bufs = compiled.execute(input_bufs)
  return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)]

default_input_handler = xb.get_backend(None).buffer_from_pyval
input_handlers = {ty: default_input_handler for ty in
                  [bool, int, float, np.ndarray, np.float64, np.float32]}

def handle_result(aval: ShapedArray, buf):
  del aval  # Unused for now
  return np.asarray(buf)

hlo_translations = {}
```
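jaxpr_subcomp walks the equations in order with an environment that maps each Var to a value. The same read/write loop can be tried out standalone; this sketch evaluates with NumPy instead of emitting HLO, and ToyVar, ToyLit, ToyEqn, and eval_simple_jaxpr are hypothetical simplifications of the tutorial's Var, Lit, JaxprEqn, and interpreter:

```python
from typing import Any, NamedTuple

import numpy as np

class ToyVar:
  # Compared and hashed by identity, like the tutorial's Var.
  pass

class ToyLit(NamedTuple):
  val: Any

class ToyEqn(NamedTuple):
  prim: str
  inputs: list
  out: ToyVar  # one output per equation, for simplicity

impls = {'sin': np.sin, 'cos': np.cos, 'mul': np.multiply, 'neg': np.negative}

def eval_simple_jaxpr(in_binders, eqns, outs, args):
  env = {}
  def read(x):
    return env[x] if isinstance(x, ToyVar) else x.val
  def write(v, val):
    env[v] = val
  for v, a in zip(in_binders, args):
    write(v, a)
  for eqn in eqns:
    write(eqn.out, impls[eqn.prim](*[read(x) for x in eqn.inputs]))
  return [read(x) for x in outs]

# { lambda a . let b = sin a; c = mul b 2.0 in ( c ) }
a, b, c = ToyVar(), ToyVar(), ToyVar()
eqns = [ToyEqn('sin', [a], b), ToyEqn('mul', [b, ToyLit(2.)], c)]
result, = eval_simple_jaxpr([a], eqns, [c], [3.])
```

jaxpr_subcomp has the same shape, except that the environment holds ir.Value handles and each rule emits StableHLO operations rather than computing numbers.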

Notice that jaxpr_subcomp has the structure of a simple interpreter. That’s a common pattern: the way we process jaxprs is usually with an interpreter. And as with any interpreter, we need an interpretation rule for each primitive:

```python
def direct_translation(op, c, in_avals, out_avals, in_vals):
  del c, in_avals, out_avals
  return [op(*in_vals)]

hlo_translations[add_p] = partial(direct_translation, hlo.add)
hlo_translations[mul_p] = partial(direct_translation, hlo.multiply)
hlo_translations[neg_p] = partial(direct_translation, hlo.negate)
hlo_translations[sin_p] = partial(direct_translation, hlo.sine)
hlo_translations[cos_p] = partial(direct_translation, hlo.cosine)

def compare_translation(op, c, in_avals, out_avals, in_vals):
  del c, out_avals
  return [hlo.compare(*in_vals, hlo.ComparisonDirectionAttr.get(op))]

hlo_translations[greater_p] = partial(compare_translation, "GT")
hlo_translations[less_p] = partial(compare_translation, "LT")

def reduce_sum_translation(c, in_avals, out_avals, in_vals, *, axis):
  del c
  (x_aval,), (out_aval,), (x,) = in_avals, out_avals, in_vals
  op = hlo.ReduceOp(
    [aval_to_ir_type(out_aval)], [x], [_hlo_const(np.array(0, x_aval.dtype))],
    axis)
  scalar_type = aval_to_ir_type(ShapedArray((), x_aval.dtype))
  reducer_region = op.body.blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(reducer_region):
    hlo.return_([hlo.add(*reducer_region.arguments)])
  return op.results

hlo_translations[reduce_sum_p] = reduce_sum_translation

def broadcast_translation(c, in_avals, out_avals, in_vals, *, shape, axes):
  del c
  (x,), (out_aval,) = in_vals, out_avals
  dims_complement = [i for i in range(len(shape)) if i not in axes]
  return [hlo.broadcast_in_dim(aval_to_ir_type(out_aval), x, dims_complement)]
hlo_translations[broadcast_p] = broadcast_translation
```
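The dims_complement list in broadcast_translation is the least obvious part of these rules. The broadcast primitive's axes parameter names the newly inserted output dimensions, whereas broadcast_in_dim takes, for each operand dimension, the output dimension it maps to; the complement of axes supplies exactly that list. A small NumPy sketch (broadcast_via_numpy is a hypothetical name) checks the correspondence:

```python
import numpy as np

def broadcast_via_numpy(x, shape, axes):
  # Output dimensions not in `axes` receive x's dimensions, in order; this is
  # the list broadcast_translation passes to hlo.broadcast_in_dim.
  dims_complement = [i for i in range(len(shape)) if i not in axes]
  assert len(dims_complement) == np.ndim(x)
  # NumPy equivalent: insert size-1 axes at `axes`, then broadcast.
  out = np.broadcast_to(np.expand_dims(x, tuple(axes)), shape)
  return out, dims_complement

x = np.array([1., 2., 3.])
rows, dims_rows = broadcast_via_numpy(x, (2, 3), (0,))  # x lands on dim 1
cols, dims_cols = broadcast_via_numpy(x, (3, 2), (1,))  # x lands on dim 0
```

With axes=(0,), the vector is copied along a new leading dimension and becomes each row; with axes=(1,), it is copied along a new trailing dimension and becomes each column.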

With that, we can now use jit to stage out, compile, and execute programs with XLA!

```python
@jit
def f(x, y):
  print('tracing!')
  return sin(x) * cos(y)
```

```python
z = f(3., 4.)  # 'tracing!' prints the first time
print(z)
```

```
tracing!
-0.09224219304455371
```

```python
z = f(4., 5.)  # 'tracing!' doesn't print, compilation cache hit!
print(z)
```

```
-0.21467624978306993
```
```python
@jit
def f(x):
  return reduce_sum(x, axis=0)

print(f(np.array([1., 2., 3.])))
```

```
6.0
```

```python
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

def deriv(f):
  return lambda x: jvp(f, (x,), (1.,))[1]

print(deriv(deriv(f))(3.))
print(jit(deriv(deriv(f)))(3.))
```

```
0.2822400161197344
0.2822400161197344
```

Instead of implementing jit to first trace to a jaxpr and then to lower the jaxpr to XLA HLO, it might appear that we could have skipped the jaxpr step and just lowered to HLO while tracing. That is, perhaps we could have instead implemented jit with a Trace and Tracer that appended to the XLA HLO graph incrementally on each primitive bind. That’s correct for now, but won’t be possible when we introduce compiled SPMD computations because there we must know the number of replicas needed before compiling the program.

We haven’t yet defined any transformation rules for xla_call_p other than its evaluation rule. That is, we can’t yet do vmap-of-jit or jvp-of-jit or even jit-of-jit. Instead jit has to be at the “top level.” Let’s fix that!

```python
def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
  del num_consts  # Unused
  new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
  outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  n = len(outs) // 2
  primals_out, tangents_out = outs[:n], outs[n:]
  return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule

@lru_cache
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
  def jvp_traceable(*primals_and_tangents):
    n = len(primals_and_tangents) // 2
    primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]
    return jvp(jaxpr_as_fun(jaxpr), primals, tangents)

  in_avals = [v.aval for v in jaxpr.in_binders]
  new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals)
  return new_jaxpr, new_consts
```
```python
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
  del num_consts  # Unused
  new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
  outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule

@lru_cache
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
               ) -> tuple[Jaxpr, list[Any]]:
  vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
  in_avals = [unmapped_aval(axis_size, d, v.aval)
              for v, d in zip(jaxpr.in_binders, bdims_in)]
  new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals)
  return new_jaxpr, new_consts

def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
                  ) -> ShapedArray:
  if batch_dim is not_mapped:
    return aval
  else:
    shape = list(aval.shape)
    shape.insert(batch_dim, axis_size)
    return ShapedArray(tuple(shape), aval.dtype)
```
```python
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
  del num_consts  # Unused
  jaxpr_type = typecheck_jaxpr(jaxpr)
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError
  return jaxpr_type.out_types
abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule

def xla_call_translation(c, in_avals, out_avals, in_vals, *, jaxpr, num_consts):
  del num_consts, out_avals
  # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.
  with ir.InsertionPoint(c.module.body):
    @func.func(*(aval_to_ir_type(aval) for aval in in_avals))
    def inner_xla_call(*params):
      return jaxpr_subcomp(c, jaxpr, params)
    c.symbol_table.insert(inner_xla_call.func_op)
  return func.CallOp(inner_xla_call.func_op, in_vals).results
hlo_translations[xla_call_p] = xla_call_translation
```
```python
@jit
def f(x):
  print('tracing!')
  y = sin(x) * 2.
  z = - y + x
  return z

x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
```

```
tracing!
2.7177599838802657
2.979984993200891
```

```python
y, ydot = jvp(f, (x,), (xdot,))  # 'tracing!' not printed
```

```python
ys = vmap(f, (0,))(np.arange(3.))
print(ys)
```

```
[ 0.         -0.68294197  0.18140515]
```

One piece missing is device memory persistence for arrays. That is, we’ve defined handle_result to transfer results back to CPU memory as NumPy arrays, but it’s often preferable to avoid transferring results just to transfer them back for the next operation. We can do that by introducing an Array class, which can wrap XLA buffers and otherwise duck-type numpy.ndarrays:

```python
def handle_result(aval: ShapedArray, buf):  # noqa: F811
  return Array(aval, buf)

class Array:
  buf: Any
  aval: ShapedArray

  def __init__(self, aval, buf):
    self.aval = aval
    self.buf = buf

  dtype = property(lambda self: self.aval.dtype)
  shape = property(lambda self: self.aval.shape)
  ndim  = property(lambda self: self.aval.ndim)

  def __array__(self): return np.asarray(self.buf)
  def __repr__(self):  return repr(np.asarray(self.buf))
  def __str__(self):   return str(np.asarray(self.buf))

  _neg = staticmethod(neg)
  _add = staticmethod(add)
  _radd = staticmethod(add)
  _mul = staticmethod(mul)
  _rmul = staticmethod(mul)
  _gt = staticmethod(greater)
  _lt = staticmethod(less)
input_handlers[Array] = lambda x: x.buf

jax_types.add(Array)
```
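The duck-typing rests mostly on __array__, which NumPy calls whenever it needs a real ndarray, so the transfer back to host memory happens only on demand. A standalone sketch with a hypothetical HostBackedArray, whose "buffer" is just a NumPy array rather than an XLA device buffer, shows NumPy functions accepting such an object:

```python
import numpy as np

class HostBackedArray:
  # Hypothetical stand-in for Array: here the buffer is a NumPy array,
  # whereas in the tutorial it is an XLA device buffer.
  def __init__(self, buf):
    self.buf = buf

  shape = property(lambda self: np.shape(self.buf))
  dtype = property(lambda self: np.asarray(self.buf).dtype)

  def __array__(self, dtype=None, copy=None):
    # Called by np.asarray and by ufuncs; conversion happens only here.
    return np.asarray(self.buf, dtype=dtype)

x_host = HostBackedArray(np.arange(3.))
y_host = np.add(x_host, 1.)              # ufuncs convert via __array__
z_host = np.asarray(x_host, np.float32)  # dtype request is forwarded
```

Accepting dtype and copy keeps the method compatible with how newer NumPy versions call __array__. The tutorial's Array additionally registers an input handler, so that passing an Array back into a jitted function hands its buffer straight to XLA without a host round trip.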
```python
@jit
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
```

```
2.7177599838802657
2.979984993200891
```

```python
def pprint_xla_call(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
  params_without_jaxpr = {k: v for k, v in eqn.params.items() if k != 'jaxpr'}
  rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>
         pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
                     for x in eqn.inputs)))
  return vcat([lhs >> pp(' = ') >> rhs,
               pp_jaxpr(eqn.params['jaxpr']).indent(2)])
pp_rules[xla_call_p] = pprint_xla_call
```

Part 4: linearize and vjp (and grad!)

The linearize and vjp autodiff functions are built on jvp, but involve jaxprs as well. That’s because both involve staging out, or delaying, computation.

linearize

In the case of linearize, we want to stage out the linear part of a jvp computation. That is, in terms of Haskell-like type signatures, if we have jvp : (a -> b) -> (a, T a) -> (b, T b), then we write linearize : (a -> b) -> a -> (b, T a -o T b), using T a to mean “the tangent type of a” and using the “lollipop” -o rather than the arrow -> to indicate a linear function. We define the semantics of linearize in terms of jvp too:

```python
y, f_lin = linearize(f, x)
y_dot = f_lin(x_dot)
```

gives the same result for (y, y_dot) as

```python
y, y_dot = jvp(f, (x,), (x_dot,))
```

where the application of f_lin does not redo any of the linearization work. We’ll represent the delayed linear part f_lin : T a -o T b as a jaxpr.
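To make these semantics concrete before any partial evaluation machinery, here is a closure-based sketch. This is a different technique from the jaxpr-based approach this part builds, and Lin, lin_sin, lin_scale, lin_neg, lin_add, and toy_linearize are hypothetical names. The nonlinear work, such as computing cos(x), happens once inside toy_linearize, and the returned linear map only scales and adds:

```python
import math

class Lin:
  # Hypothetical: a primal value plus a linear map from the input tangent
  # to this value's tangent, stored as a closure.
  def __init__(self, primal, tangent_map):
    self.primal = primal
    self.tangent_map = tangent_map

def lin_sin(x):
  c = math.cos(x.primal)  # nonlinear work happens once, here
  return Lin(math.sin(x.primal), lambda t: c * x.tangent_map(t))

def lin_scale(x, k):
  return Lin(x.primal * k, lambda t: k * x.tangent_map(t))

def lin_neg(x):
  return Lin(-x.primal, lambda t: -x.tangent_map(t))

def lin_add(x, y):
  return Lin(x.primal + y.primal,
             lambda t: x.tangent_map(t) + y.tangent_map(t))

def toy_linearize(f, x):
  out = f(Lin(x, lambda t: t))  # the input's tangent map is the identity
  return out.primal, out.tangent_map

# The same function as before: z = -(sin(x) * 2.) + x
f_toy = lambda x: lin_add(lin_neg(lin_scale(lin_sin(x), 2.)), x)
y_toy, f_lin_toy = toy_linearize(f_toy, 3.)
y_dot_toy = f_lin_toy(1.)
```

The results agree with the jvp values printed earlier for this function at x = 3 with tangent 1, and f_lin_toy can be called with any number of tangents without re-evaluating sin or cos.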

Tangentially, now that we have linear arrows -o, we can provide a slightly more informative type for jvp:

```
jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)
```

Here we’re writing UnrestrictedUse just to indicate that we have a special pair where the first element can be used in an unrestricted (nonlinear) way. In conjunction with the linear arrow, this notation is just meant to express that the function jvp f uses its first input in a nonlinear way but its second input in a linear way, producing a corresponding nonlinear output (which can be used in a nonlinear way) paired with a linear output. This more refined type signature encodes the data dependencies in jvp f, which are useful for partial evaluation.

To build the f_lin jaxpr from a JVP, we need to perform partial evaluation: we evaluate all the primal values as we trace, but stage the tangent computations into a jaxpr. This is our second way to build jaxprs. But where make_jaxpr and its underlying JaxprTrace/JaxprTracer interpreters aim to stage out every primitive bind, this second approach stages out only those primitive binds with a data dependence on tangent inputs.

First, some utilities:

```python
def split_half(lst: list[Any]) -> tuple[list[Any], list[Any]]:
  assert not len(lst) % 2
  return split_list(lst, len(lst) // 2)

def merge_lists(which: list[bool], l1: list[Any], l2: list[Any]) -> list[Any]:
  l1, l2 = iter(l1), iter(l2)
  out = [next(l2) if b else next(l1) for b in which]
  assert next(l1, None) is next(l2, None) is None
  return out
```
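For reference, here is how these helpers behave. The definitions are repeated so the snippet runs on its own, and split_list is assumed to split off the first n items, as its use above implies. merge_lists takes items from l2 wherever which is True and from l1 elsewhere, so it inverts a known/unknown partition:

```python
from typing import Any

def split_list(lst: list, n: int) -> tuple:
  # Assumed helper: the first n items, then the rest.
  assert 0 <= n <= len(lst)
  return lst[:n], lst[n:]

def split_half(lst: list) -> tuple:
  assert not len(lst) % 2
  return split_list(lst, len(lst) // 2)

def merge_lists(which: list, l1: list, l2: list) -> list:
  l1, l2 = iter(l1), iter(l2)
  out = [next(l2) if b else next(l1) for b in which]
  assert next(l1, None) is next(l2, None) is None  # both fully consumed
  return out

primals, tangents = split_half(['p1', 'p2', 't1', 't2'])

vals = ['x', 'y', 'z']
is_unknown = [False, True, False]
known_vals = [v for v, u in zip(vals, is_unknown) if not u]
unknown_vals = [v for v, u in zip(vals, is_unknown) if u]
restored = merge_lists(is_unknown, known_vals, unknown_vals)
```

split_half relies on the convention, used by f_jvp below, that a flat argument list holds all primals first and then all tangents.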

Next, we’ll write linearize by combining jvp together with a general partial evaluation transformation, to be added next:

```python
def linearize_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]
  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])
  return primals_out, f_lin

def linearize(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)

  def f_lin(*tangents_in):
    tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
    if in_tree != in_tree2: raise TypeError
    tangents_out_flat = f_lin_flat(*tangents_in_flat)
    return tree_unflatten(out_tree(), tangents_out_flat)

  return primals_out, f_lin

def vspace(aval: ShapedArray) -> ShapedArray:
  return raise_to_shaped(aval)  # TODO handle integers?
```

Now we turn to the general partial evaluation transformation. The goal is to accept a Python callable and a list of inputs, some known and some unknown, and to produce (1) all the outputs which can be computed from the known inputs, together with (2) a jaxpr representing the part of the Python callable’s computation which can only be performed after the remaining inputs are known.

This transformation is tricky to summarize in a type signature. If we assume the input function’s type signature is (a1, a2) -> (b1, b2), where a1 and a2 represent the known and unknown inputs, respectively, and where b1 only has a data dependency on a1 while b2 has some data dependency on a2, then we might write

```
partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)
```

In words, given values for the inputs of type a1, partial_eval produces the outputs of type b1 along with “residual” values of existentially-quantified type r representing the intermediates required to complete the computation in the second stage. It also produces a function of type (r, a2) -> b2 which accepts the residual values as well as the remaining inputs and produces the remaining outputs.

We like to think of partial evaluation as “unzipping” one computation into two. For example, consider this jaxpr:

```
{ lambda a:float64[] .
  let b:float64[] = sin a
      c:float64[] = neg b
  in ( c ) }
```

A jaxpr for the JVP would look like:

```
{ lambda a:float64[], b:float64[] .
  let c:float64[] = sin a
      d:float64[] = cos a
      e:float64[] = mul d b
      f:float64[] = neg c
      g:float64[] = neg e
  in ( f, g ) }
```

If we imagine applying partial evaluation to this jaxpr with the first input known and the second unknown, we end up ‘unzipping’ the JVP jaxpr into primal and tangent jaxprs:

```
{ lambda a:float64[] .
  let c:float64[] = sin a
      d:float64[] = cos a
      f:float64[] = neg c
  in ( f, d ) }
```
```
{ lambda d:float64[], b:float64[] .
  let e:float64[] = mul d b
      g:float64[] = neg e
  in ( g ) }
```

This second jaxpr represents the linear computation that we want from linearize.
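The unzipping can be checked numerically by transcribing the three jaxprs into Python (jvp_full, primal_part, and tangent_part are hypothetical names). Running the primal part and feeding its residual d to the tangent part reproduces the full JVP, and the tangent part never calls sin or cos:

```python
import math

def jvp_full(a, b):      # the JVP jaxpr: inputs a and b
  c = math.sin(a)
  d = math.cos(a)
  e = d * b
  f = -c
  g = -e
  return f, g

def primal_part(a):      # known input only; d is the residual
  c = math.sin(a)
  d = math.cos(a)
  f = -c
  return f, d

def tangent_part(d, b):  # residual plus unknown input; linear in b
  e = d * b
  g = -e
  return g

checks = []
for a, b in [(0.5, 1.0), (3.0, -2.0)]:
  f_val, d = primal_part(a)
  checks.append((f_val, tangent_part(d, b)) == jvp_full(a, b))
```

Here d plays the role of the residual r in the partial_eval type, and tangent_part has type (r, a2) -> b2.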

However, unlike in this jaxpr example, we want the computation on known values to occur while evaluating the input Python callable. That is, rather than forming a jaxpr for the entire function (a1, a2) -> (b1, b2), staging all operations out of Python first before sorting out what can be evaluated now and what must be delayed, we want only to form a jaxpr for those operations that must be delayed due to a dependence on unknown inputs. In the context of automatic differentiation, this is the feature that ultimately enables us to handle functions like grad(lambda x: x**2 if x > 0 else 0.). Python control flow works because partial evaluation keeps the primal computation in Python. As a consequence, our Trace and Tracer subclasses must on the fly sort out what can be evaluated and what must be staged out into a jaxpr.

First, we start with a PartialVal class, which represents a value that can be either known or unknown:

```python
class PartialVal(NamedTuple):
  aval: ShapedArray
  const: Any | None

  @classmethod
  def known(cls, val: Any):
    return PartialVal(get_aval(val), val)

  @classmethod
  def unknown(cls, aval: ShapedArray):
    return PartialVal(aval, None)

  is_known   = property(lambda self: self.const is not None)
  is_unknown = property(lambda self: self.const is None)
```

Partial evaluation will take a list of PartialVals representing inputs, and return a list of PartialVal outputs along with a jaxpr representing the delayed computation:

```python
def partial_eval_flat(f: Callable, pvals_in: list[PartialVal]
                      ) -> tuple[Jaxpr, list[PartialVal], list[Any]]:
  with new_main(PartialEvalTrace) as main:
    trace = PartialEvalTrace(main)
    tracers_in = [trace.new_arg(pval) for pval in pvals_in]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    pvals_out = [t.pval for t in tracers_out]
    unk_tracers_in  = [t for t in tracers_in  if t.pval.is_unknown]
    unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]
    jaxpr, consts = tracers_to_jaxpr(unk_tracers_in, unk_tracers_out)
  return jaxpr, pvals_out, consts
```

Next we need to implement PartialEvalTrace and its PartialEvalTracer. This interpreter will build a jaxpr on the fly while tracking data dependencies. To do so, it builds a bipartite directed acyclic graph (DAG) between PartialEvalTracer nodes, representing staged-out values, and JaxprRecipe nodes, representing formulas for how to compute some values from others. One kind of recipe is a JaxprEqnRecipe, corresponding to a JaxprEqn’s primitive application, but we also have recipe types for constants and lambda binders:

```python
from typing import Union
from weakref import ref, ReferenceType

class LambdaBindingRecipe(NamedTuple):
  pass

class ConstRecipe(NamedTuple):
  val: Any

class JaxprEqnRecipe(NamedTuple):
  prim: Primitive
  tracers_in: list['PartialEvalTracer']
  params: dict[str, Any]
  avals_out: list[ShapedArray]
  tracer_refs_out: list['ReferenceType[PartialEvalTracer]']

JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
```
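JaxprEqnRecipe holds its output tracers only through weak references, while each output tracer (defined next) holds its recipe strongly. A standalone sketch with hypothetical RecipeStub and TracerStub classes shows the effect: once nothing else refers to an output tracer, it is freed and the weak reference returns None, which is the case recipe_to_eqn later checks for with t_ref() is not None:

```python
import gc
import weakref

class RecipeStub:
  # Hypothetical stand-in for JaxprEqnRecipe: weak refs to its outputs.
  def __init__(self):
    self.out_refs = []

class TracerStub:
  # Hypothetical stand-in for PartialEvalTracer: a strong ref to its recipe.
  def __init__(self, recipe):
    self.recipe = recipe

stub_recipe = RecipeStub()
stub_tracer = TracerStub(stub_recipe)
stub_recipe.out_refs.append(weakref.ref(stub_tracer))
alive_before = stub_recipe.out_refs[0]() is stub_tracer

del stub_tracer  # no strong reference cycle, so the tracer can be freed
gc.collect()     # not needed in CPython; included for other implementations
alive_after = stub_recipe.out_refs[0]() is not None
```

Had the recipe stored strong references to its outputs, each tracer and its recipe would form a reference cycle that only the cyclic garbage collector could reclaim.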
```python
class PartialEvalTracer(Tracer):
  pval: PartialVal
  recipe: JaxprRecipe | None

  def __init__(self, trace, pval, recipe):
    self._trace = trace
    self.pval = pval
    self.recipe = recipe

  aval = property(lambda self: self.pval.aval)

  def full_lower(self):
    if self.pval.is_known:
      return full_lower(self.pval.const)
    return self
```

The PartialEvalTrace contains the logic for constructing the graph of JaxprRecipes and PartialEvalTracers. Each argument corresponds to a LambdaBindingRecipe leaf node, and each constant is a ConstRecipe leaf node holding a reference to the constant. All other tracers and recipes come from process_primitive, which forms tracers with JaxprEqnRecipes.

For most primitives, the process_primitive logic is straightforward: if all inputs are known then we can bind the primitive on the known values (evaluating it in Python) and avoid forming tracers corresponding to the output. If any input is unknown, we instead stage out into a JaxprEqnRecipe representing the primitive application. To build the tracers representing unknown outputs, we need avals, which we get from the abstract eval rules. (Notice that tracers reference JaxprEqnRecipes, and JaxprEqnRecipes reference tracers; we avoid circular garbage by having each recipe hold only weakrefs to its output tracers.)

That process_primitive logic applies to most primitives, but xla_call_p requires recursive treatment. So we special-case its rule in a partial_eval_rules dict.

```python
class PartialEvalTrace(Trace):
  def new_arg(self, pval: PartialVal) -> Any:
    return PartialEvalTracer(self, pval, LambdaBindingRecipe())

  def lift(self, val: Any) -> PartialEvalTracer:
    return PartialEvalTracer(self, PartialVal.known(val), None)
  pure = lift

  def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
    if tracer.pval.is_unknown:
      return tracer
    else:
      pval = PartialVal.unknown(raise_to_shaped(tracer.aval))
      return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))

  def process_primitive(self, primitive, tracers, params):
    if all(t.pval.is_known for t in tracers):
      return bind(primitive, *map(full_lower, tracers), **params)
    rule = partial_eval_rules.get(primitive)
    if rule: return rule(self, tracers, **params)
    tracers_in = [self.instantiate_const(t) for t in tracers]
    avals_in = [t.aval for t in tracers_in]
    avals_out = abstract_eval_rules[primitive](*avals_in, **params)
    tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
                   for aval in avals_out]
    eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,
                         map(ref, tracers_out))
    for t in tracers_out: t.recipe = eqn
    return tracers_out

partial_eval_rules = {}
```
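The known/unknown split in process_primitive can be illustrated without tracers. In the hypothetical toy below, values are either plain floats (known) or Unknown placeholders; a primitive whose inputs are all known runs immediately in Python, and anything touching an Unknown is appended to a list of staged equations. The known residual cos(3.) ends up embedded as a constant in the staged mul, much as instantiate_const wraps known inputs in a ConstRecipe:

```python
import math
from typing import Any, NamedTuple

class Unknown(NamedTuple):
  # Hypothetical placeholder for a value that is not known yet.
  name: str

staged_eqns: list = []  # (out_name, prim, inputs)

def pe_apply(prim: str, impl, *args: Any):
  if not any(isinstance(a, Unknown) for a in args):
    return impl(*args)                         # all known: evaluate now
  out = Unknown(f'v{len(staged_eqns)}')
  staged_eqns.append((out.name, prim, args))   # some unknown: stage out
  return out

def pe_sin(x): return pe_apply('sin', math.sin, x)
def pe_cos(x): return pe_apply('cos', math.cos, x)
def pe_mul(x, y): return pe_apply('mul', lambda a, b: a * b, x, y)
def pe_neg(x): return pe_apply('neg', lambda a: -a, x)

def jvp_neg_sin(x, x_dot):
  # Hand-written JVP of x -> -sin(x)
  return pe_neg(pe_sin(x)), pe_neg(pe_mul(pe_cos(x), x_dot))

primal, tangent = jvp_neg_sin(3., Unknown('x_dot'))

def run_staged(eqns, env):
  # Replay the staged equations once the unknown inputs are supplied.
  impls = {'mul': lambda a, b: a * b, 'neg': lambda a: -a}
  for out, prim, args in eqns:
    vals = [env[a.name] if isinstance(a, Unknown) else a for a in args]
    env[out] = impls[prim](*vals)
  return env

env = run_staged(staged_eqns, {'x_dot': 2.})
```

The primal is computed immediately, and only the tangent computation, a mul by a known constant followed by a neg, is staged, matching the tangent jaxpr in the unzipping example.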

Now that we can build graph representations of jaxprs with PartialEvalTrace, we need a mechanism to convert the graph representation to a standard jaxpr. The jaxpr corresponds to a topological sort of the graph.

def tracers_to_jaxpr(tracers_in: list[PartialEvalTracer],
                     tracers_out: list[PartialEvalTracer]):
  tracer_to_var: dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
                                   for t in tracers_in}
  constvar_to_val: dict[Var, Any] = {}
  constid_to_var: dict[int, Var] = {}
  processed_eqns: set[int] = set()
  eqns: list[JaxprEqn] = []
  for t in toposort(tracers_out, tracer_parents):
    if isinstance(t.recipe, LambdaBindingRecipe):
      assert id(t) in set(map(id, tracers_in))
    elif isinstance(t.recipe, ConstRecipe):
      val = t.recipe.val
      var = constid_to_var.get(id(val))
      if var is None:
        aval = raise_to_shaped(get_aval(val))
        var = constid_to_var[id(val)] = Var(aval)
        constvar_to_val[var] = val
      tracer_to_var[id(t)] = var
    elif isinstance(t.recipe, JaxprEqnRecipe):
      if id(t.recipe) not in processed_eqns:
        eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
        processed_eqns.add(id(t.recipe))
    else:
      raise TypeError(t.recipe)

  constvars, constvals = unzip2(constvar_to_val.items())
  in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
  out_vars = [tracer_to_var[id(t)] for t in tracers_out]
  jaxpr = Jaxpr(in_binders, eqns, out_vars)
  typecheck_jaxpr(jaxpr)
  return jaxpr, constvals

def recipe_to_eqn(tracer_to_var: dict[int, Var], recipe: JaxprEqnRecipe
                  ) -> JaxprEqn:
  inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
  out_binders = [Var(aval) for aval in recipe.avals_out]
  for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
    if t_ref() is not None: tracer_to_var[id(t_ref())] = var
  return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)

def tracer_parents(t: PartialEvalTracer) -> list[PartialEvalTracer]:
  return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []


def toposort(out_nodes: list[Any], parents: Callable[[Any], list[Any]]):
  if not out_nodes: return []
  out_nodes = remove_duplicates(out_nodes)

  child_counts = {}
  stack = list(out_nodes)
  while stack:
    node = stack.pop()
    if id(node) in child_counts:
      child_counts[id(node)] += 1
    else:
      child_counts[id(node)] = 1
      stack.extend(parents(node))
  for node in out_nodes:
    child_counts[id(node)] -= 1

  sorted_nodes = []
  childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
  while childless_nodes:
    node = childless_nodes.pop()
    sorted_nodes.append(node)
    for parent in parents(node):
      if child_counts[id(parent)] == 1:
        childless_nodes.append(parent)
      else:
        child_counts[id(parent)] -= 1

  sorted_nodes = sorted_nodes[::-1]
  check_toposort(sorted_nodes, parents)
  return sorted_nodes

def remove_duplicates(lst):
  seen = set()
  return [x for x in lst if id(x) not in seen and not seen.add(id(x))]

def check_toposort(nodes: list[Any], parents: Callable[[Any], list[Any]]):
  seen = set()
  for node in nodes:
    assert all(id(parent) in seen for parent in parents(node))
    seen.add(id(node))
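To see the child-counting scheme in isolation, here is a sketch that specializes toposort to hashable node names. toposort_names is a hypothetical helper: it keys the counts by name rather than id(), and it skips the duplicate removal and the final check_toposort pass.

```python
def toposort_names(out_nodes, parents):
  """Child-counting topological sort over hashable node names.

  Returns nodes ordered so that every node comes after all of its parents.
  """
  # Pass 1: walk backward from the outputs, counting each node's children.
  child_counts = {}
  stack = list(out_nodes)
  while stack:
    node = stack.pop()
    if node in child_counts:
      child_counts[node] += 1
    else:
      child_counts[node] = 1
      stack.extend(parents(node))
  for node in out_nodes:
    child_counts[node] -= 1  # the initial push of an output is not a child edge

  # Pass 2: repeatedly emit nodes none of whose children remain unemitted.
  order = []
  childless = [n for n in out_nodes if not child_counts[n]]
  while childless:
    node = childless.pop()
    order.append(node)
    for parent in parents(node):
      if child_counts[parent] == 1:
        childless.append(parent)
      else:
        child_counts[parent] -= 1
  return order[::-1]  # reverse: inputs first, outputs last

# Diamond-shaped graph: d depends on b and c, which both depend on a.
diamond = {'d': ['b', 'c'], 'b': ['a'], 'c': ['a'], 'a': []}
diamond_order = toposort_names(['d'], lambda n: diamond[n])
```

On the diamond, a is emitted only after both of its children b and c have been processed, so after the final reversal it comes first and d comes last.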

Now we can linearize!

y, sin_lin = linearize(sin, 3.)
print(y, sin(3.))
print(sin_lin(1.), cos(3.))

0.1411200080598672 0.1411200080598672
-0.9899924966004454 -0.9899924966004454

To handle linearize-of-jit, we still need to write a partial evaluation rule for xla_call_p. Other than tracer bookkeeping, the main task is to perform partial evaluation of a jaxpr, ‘unzipping’ it into two jaxprs.
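As a rough illustration of the ‘unzipping’, here is a hypothetical toy version of partial_eval_jaxpr over equations written as (out_var, prim_name, input_vars) tuples with string variable names. It is only a sketch under those assumptions: it ignores literals, custom partial_eval_jaxpr_rules, and the instantiate argument.

```python
def toy_unzip_jaxpr(eqns, in_binders, in_unknowns, outs):
  """Split a toy jaxpr into a known jaxpr and an unknown jaxpr.

  Each returned jaxpr is an (in_binders, eqns, outs) triple.
  """
  unknown = dict(zip(in_binders, in_unknowns))
  eqns1, eqns2, residuals = [], [], []
  for out, prim, ins in eqns:
    if any(unknown[v] for v in ins):
      # Staged out: every known value it reads must be passed as a residual.
      for v in ins:
        if not unknown[v] and v not in residuals:
          residuals.append(v)
      eqns2.append((out, prim, ins))
      unknown[out] = True
    else:
      eqns1.append((out, prim, ins))
      unknown[out] = False
  ins1 = [v for v in in_binders if not unknown[v]]
  ins2 = [v for v in in_binders if unknown[v]]
  outs1 = [v for v in outs if not unknown[v]]
  outs2 = [v for v in outs if unknown[v]]
  known_jaxpr = (ins1, eqns1, outs1 + residuals)    # a1 -> (b1, res)
  unknown_jaxpr = (residuals + ins2, eqns2, outs2)  # (res, a2) -> b2
  return known_jaxpr, unknown_jaxpr

# JVP of sin: the primal x is known, the tangent x_dot is unknown.
jvp_sin_eqns = [('y', 'sin', ['x']),
                ('c', 'cos', ['x']),
                ('y_dot', 'mul', ['c', 'x_dot'])]
known_jaxpr, unknown_jaxpr = toy_unzip_jaxpr(
    jvp_sin_eqns, ['x', 'x_dot'], [False, True], ['y', 'y_dot'])
```

Here cos(x) is computed in the known jaxpr and handed to the unknown jaxpr as a residual, which is the shape of a linearized function: all primal work happens up front, and only the linear multiply remains.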

There are actually two rules to write: one for trace-time partial evaluation, which we’ll call xla_call_partial_eval, and one for partial evaluation of jaxprs, which we’ll call xla_call_peval_eqn.

def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
  del num_consts  # Unused
  in_unknowns = [not t.pval.is_known for t in tracers]
  jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)
  outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res)
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in jaxpr2.outs]
  eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,
                       dict(jaxpr=jaxpr2, num_consts=0),
                       [v.aval for v in jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_unknowns, outs1, outs2)
partial_eval_rules[xla_call_p] = xla_call_partial_eval

def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
                       instantiate: list[bool] | None = None,
                       ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
  env: dict[Var, bool] = {}
  residuals: set[Var] = set()

  def read(x: Atom) -> bool:
    return type(x) is Var and env[x]

  def write(unk: bool, v: Var) -> None:
    env[v] = unk

  def new_res(x: Atom) -> Atom:
    if type(x) is Var: residuals.add(x)
    return x

  eqns1, eqns2 = [], []
  map(write, in_unknowns, jaxpr.in_binders)
  for eqn in jaxpr.eqns:
    unks_in = map(read, eqn.inputs)
    rule = partial_eval_jaxpr_rules.get(eqn.primitive)
    if rule:
      eqn1, eqn2, unks_out, res = rule(unks_in, eqn)
      eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)
      map(write, unks_out, eqn.out_binders)
    elif any(unks_in):
      inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]
      eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))
      map(partial(write, True), eqn.out_binders)
    else:
      eqns1.append(eqn)
      map(partial(write, False), eqn.out_binders)
  out_unknowns = map(read, jaxpr.outs)
  if instantiate is not None:
    for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
      if inst and not uk: new_res(v)
    out_unknowns = map(op.or_, out_unknowns, instantiate)

  residuals, num_res = list(residuals), len(residuals)
  assert all(type(v) is Var for v in residuals), residuals

  ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
  outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)

  jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)
  jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)
  typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)

  return jaxpr1, jaxpr2, out_unknowns, num_res

def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
  jaxprty = typecheck_jaxpr(jaxpr)    # (a1,  a2) -> (b1, b2 )
  jaxpr1ty = typecheck_jaxpr(jaxpr1)  #  a1       -> (b1, res)
  jaxpr2ty = typecheck_jaxpr(jaxpr2)  # (res, a2) -> b2

  a1, a2 = partition_list(unks_in, jaxprty.in_types)
  b1, b2 = partition_list(unks_out, jaxprty.out_types)
  b1_, res = split_list(jaxpr1ty.out_types, len(b1))
  res_, a2_ = split_list(jaxpr2ty.in_types, len(res))
  b2_ = jaxpr2ty.out_types

  if jaxpr1ty.in_types != a1: raise TypeError
  if jaxpr2ty.out_types != b2: raise TypeError
  if b1 != b1_: raise TypeError
  if res != res_: raise TypeError
  if a2 != a2_: raise TypeError
  if b2 != b2_: raise TypeError

partial_eval_jaxpr_rules = {}

def xla_call_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                       ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Var]]:
  jaxpr = eqn.params['jaxpr']
  jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
  ins1, ins2 = partition_list(unks_in, eqn.inputs)
  out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
  residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]
  eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),
                  out_binders1 + residuals)
  eqn2 = JaxprEqn(xla_call_p, residuals + ins2,
                  dict(jaxpr=jaxpr2, num_consts=0), out_binders2)
  return eqn1, eqn2, unks_out, residuals
partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn

With that, we can compose linearize and jit however we like:

@jit
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
2.7177599838802657 2.979984993200891
@jit
def f(x):
  y = sin(x) * 2.
  z = g(x, y)
  return z

@jit
def g(x, y):
  return cos(x) + y

y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
-0.7077524804807109 -2.121105001260758

vjp and grad

The vjp transformation works a lot like linearize. Its type signature is analogous:

linearize : (a -> b) -> a -> (b, T a -o T b)
vjp :       (a -> b) -> a -> (b, T b -o T a)

The only difference is that we transpose the linear part of the computation before returning it, so that it goes from type T a -o T b to type T b -o T a. That is, we’ll implement vjp as, essentially,

def vjp(f, x):
  y, f_lin = linearize(f, x)
  f_vjp = lambda y_bar: transpose(f_lin)(y_bar)
  return y, f_vjp
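In the simplest case, where the linear part is a matrix-vector product, transposing it amounts to multiplying by the transposed matrix. The NumPy sketch below (toy_f_lin and toy_f_lin_transposed are illustrative names, and the random matrix is arbitrary) checks the defining inner-product identity of the transpose.

```python
import numpy as np

rng = np.random.default_rng(0)
toy_matrix = rng.standard_normal((3, 4))  # an arbitrary linear map R^4 -> R^3

def toy_f_lin(x_dot):             # plays the role of T a -o T b
  return toy_matrix @ x_dot

def toy_f_lin_transposed(y_bar):  # plays the role of T b -o T a
  return toy_matrix.T @ y_bar

x_dot_sample = rng.standard_normal(4)
y_bar_sample = rng.standard_normal(3)
# Defining property of the transpose:
#   <y_bar, f_lin(x_dot)> == <f_lin^T(y_bar), x_dot>
inner_out = y_bar_sample @ toy_f_lin(x_dot_sample)
inner_in = toy_f_lin_transposed(y_bar_sample) @ x_dot_sample
```

The transposed map takes a cotangent of the output shape (3,) back to the input shape (4,), which is the direction reverse mode needs.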

Since we have the linear computation as a jaxpr, not just a Python callable, we can implement the transpose transformation as a jaxpr interpreter.

def vjp_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
  primal_pvals_in, tangent_pvals_in = split_half(pvals_in)
  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]
  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)  # linearize
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in]
  f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)
  return primals_out, f_vjp

def vjp(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)

  def f_vjp(*cotangents_out):
    cotangents_out_flat, _ = tree_flatten(cotangents_out)
    cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)
    return tree_unflatten(in_tree, cotangents_in_flat)

  return primals_out, f_vjp

class UndefPrimal(NamedTuple):
  aval: ShapedArray

register_pytree_node(UndefPrimal,
                     lambda u: (u.aval, ()),
                     lambda aval, _: UndefPrimal(aval))

We use UndefPrimal instances to indicate the arguments with respect to which we want to transpose. These arise because, in general, being explicit about closed-over values, we want to transpose functions of type a -> b -o c to functions of type a -> c -o b. Even more generally, the inputs with respect to which the function is linear could be scattered through the argument list. So we indicate the linear positions using UndefPrimal. We register UndefPrimal as a pytree node because the pytree mechanism gives a handy way to prune these placeholders out of argument lists.
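For example, toy_f(x, w, y) = w * (x + y) is linear in its first and third arguments once w is fixed, so the linear positions are scattered. The following hypothetical sketch uses a ToyUndef marker in the role of UndefPrimal to say which positions are linear, and returns one cotangent per marked position.

```python
import numpy as np

class ToyUndef:
  """Hypothetical marker playing the role of UndefPrimal for a linear input."""
  def __init__(self, shape):
    self.shape = shape

def toy_f(x, w, y):
  # Linear in x and y (positions 0 and 2) once the non-linear input w is fixed.
  return w * (x + y)

def toy_f_transpose(args, c_bar):
  """Transpose of toy_f with respect to its ToyUndef-marked arguments."""
  x, w, y = args
  assert isinstance(x, ToyUndef) and isinstance(y, ToyUndef)
  assert not isinstance(w, ToyUndef)  # w is a known value, not transposed
  # One cotangent per marked position, in argument order; w gets none.
  return [w * c_bar, w * c_bar]

w_val = np.array([1., 2., 3.])
x_val = np.array([4., 5., 6.])
y_val = np.array([-1., 0., 1.])
c_bar_val = np.array([.5, .25, 2.])
x_bar_val, y_bar_val = toy_f_transpose(
    [ToyUndef(x_val.shape), w_val, ToyUndef(y_val.shape)], c_bar_val)
# Inner-product identity summed over both linear inputs:
#   <c_bar, f(x, w, y)> == <x_bar, x> + <y_bar, y>
lhs_e = c_bar_val @ toy_f(x_val, w_val, y_val)
rhs_e = x_bar_val @ x_val + y_bar_val @ y_val
```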

Next, we can write eval_jaxpr_transposed, along with transpose rules for all primitives that can be linear in at least one argument:

# NB: the analogous function in JAX is called 'backward_pass'
def eval_jaxpr_transposed(jaxpr: Jaxpr, args: list[Any], cotangents: list[Any]
                          ) -> list[Any]:
  primal_env: dict[Var, Any] = {}
  ct_env: dict[Var, Any] = {}

  def read_primal(x: Atom) -> Any:
    return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val

  def write_primal(v: Var, val: Any) -> None:
    if type(val) is not UndefPrimal:
      primal_env[v] = val

  def read_cotangent(v: Var) -> Any:
    return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype))

  def write_cotangent(x: Atom, val: Any):
    if type(x) is Var and val is not None:
      ct_env[x] = add(ct_env[x], val) if x in ct_env else val

  map(write_primal, jaxpr.in_binders, args)
  map(write_cotangent, jaxpr.outs, cotangents)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.inputs)
    cts_in = map(read_cotangent, eqn.out_binders)
    rule = transpose_rules[eqn.primitive]
    cts_out = rule(cts_in, *primals_in, **eqn.params)
    map(write_cotangent, eqn.inputs, cts_out)

  return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args)
          if type(x) is UndefPrimal]

transpose_rules = {}

def mul_transpose_rule(cts, x, y):
  z_bar, = cts
  assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal)
  return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)]
transpose_rules[mul_p] = mul_transpose_rule

def neg_transpose_rule(cts, x):
  ybar, = cts
  assert type(x) is UndefPrimal
  return [neg(ybar)]
transpose_rules[neg_p] = neg_transpose_rule

def add_transpose_rule(cts, x, y):
  z_bar, = cts
  return [z_bar, z_bar]
transpose_rules[add_p] = add_transpose_rule

def reduce_sum_transpose_rule(cts, x, *, axis):
  y_bar, = cts
  return [broadcast(y_bar, x.aval.shape, axis)]
transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule

def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
  del num_consts  # Unused
  undef_primals = [type(x) is UndefPrimal for x in invals]
  transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))
  residuals, _ = partition_list(undef_primals, invals)
  outs = bind(xla_call_p, *new_consts, *residuals, *cts,
              jaxpr=transposed_jaxpr, num_consts=len(new_consts))
  outs = iter(outs)
  return [next(outs) if undef else None for undef in undef_primals]
transpose_rules[xla_call_p] = xla_call_transpose_rule

@lru_cache
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
                    ) -> tuple[Jaxpr, list[Any]]:
  avals_in, avals_out = typecheck_jaxpr(jaxpr)
  traceable = partial(eval_jaxpr_transposed, jaxpr)
  args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]
  trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))
  typecheck_jaxpr(trans_jaxpr)
  return trans_jaxpr, consts
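Stripped of primal handling and UndefPrimal, the core of eval_jaxpr_transposed is a reverse walk that pops each output's cotangent and accumulates contributions into the inputs. Here is a hypothetical scalar-only sketch over a tiny linear language; the equation format (out_var, prim, input_vars, param) and the scale primitive are illustrative, not Autodidax's.

```python
def toy_backward_pass(eqns, linear_inputs, outs, cts_out):
  """Transpose a toy linear program by walking its equations in reverse."""
  ct_env = {}

  def write_cotangent(v, ct):
    # A variable used more than once receives the sum of its cotangents.
    ct_env[v] = ct_env[v] + ct if v in ct_env else ct

  for v, ct in zip(outs, cts_out):
    write_cotangent(v, ct)
  for out, prim, ins, param in reversed(eqns):
    ct = ct_env.pop(out, 0.)  # an output nobody consumed has zero cotangent
    if prim == 'scale':       # out = param * x  ->  x_bar += param * ct
      write_cotangent(ins[0], param * ct)
    elif prim == 'add':       # out = x + y      ->  x_bar += ct, y_bar += ct
      write_cotangent(ins[0], ct)
      write_cotangent(ins[1], ct)
    elif prim == 'neg':       # out = -x         ->  x_bar += -ct
      write_cotangent(ins[0], -ct)
    else:
      raise NotImplementedError(prim)
  return [ct_env.get(v, 0.) for v in linear_inputs]

# z = -(t + 3 * t), i.e. z = -4 * t; t fans out to two uses.
toy_linear_eqns = [('u', 'scale', ['t'], 3.),
                   ('w', 'add', ['t', 'u'], None),
                   ('z', 'neg', ['w'], None)]
t_bar, = toy_backward_pass(toy_linear_eqns, ['t'], ['z'], [1.])
```

Because t is used twice, its cotangent is the sum of the contributions arriving through add and through scale, giving -4, the transpose of multiplication by -4.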

Now that we can linearize and transpose, we can finally write grad:

def grad(f):
  def gradfun(x, *xs):
    y, f_vjp = vjp(f, x, *xs)
    if np.shape(y) != (): raise TypeError
    x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))
    return x_bar
  return gradfun

y, f_vjp = vjp(sin, 3.)
print(f_vjp(1.), cos(3.))

(np.float64(-0.9899924966004454),) -0.9899924966004454

def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

print(grad(f)(3.))

2.979984993200891

@jit
def f(x):
  y = x * 2.
  z = g(y)
  return z

@jit
def g(x):
  return cos(x) * 2.

print(grad(f)(3.))

1.1176619927957034
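The value printed for grad(f)(3.) with the un-jitted f can be checked independently of Autodidax. A minimal NumPy sketch compares a central finite difference (central_difference is an illustrative helper) against the analytic derivative 1 - 2 cos x:

```python
import numpy as np

def f_numpy(x):
  # The same computation as the un-jitted f above, written with NumPy.
  y = np.sin(x) * 2.
  return -y + x

def central_difference(fun, x, eps=1e-6):
  # Second-order accurate estimate of fun'(x).
  return (fun(x + eps) - fun(x - eps)) / (2 * eps)

fd_estimate = central_difference(f_numpy, 3.)
analytic = 1. - 2. * np.cos(3.)  # d/dx (x - 2 sin x)
```

Both agree with the 2.979984993200891 printed by grad(f)(3.) to within the finite-difference error.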

Here’s something of a compositionality stress test:

# from core_test.py fun_with_nested_calls_2
def foo(x):
  @jit
  def bar(y):
    def baz(w):
      q = jit(lambda x: y)(x)
      q = q + jit(lambda: y)()
      q = q + jit(lambda y: w + y)(y)
      q = jit(lambda w: jit(sin)(x) * y)(1.0) + q
      return q
    p, t = jvp(baz, (x + 1.0,), (y,))
    return t + (x * p)
  return bar(x)

def assert_allclose(*vals):
  for v1, v2 in zip(vals[:-1], vals[1:]):
    np.testing.assert_allclose(v1, v2)

# NB: the checks below exercise `f` (the jitted function from the previous
# cell); `foo` is defined above but not called here.
ans1 = f(3.)
ans2 = jit(f)(3.)
ans3, _ = jvp(f, (3.,), (5.,))
ans4, _ = jvp(jit(f), (3.,), (5.,))
assert_allclose(ans1, ans2, ans3, ans4)

deriv1 = grad(f)(3.)
deriv2 = grad(jit(f))(3.)
deriv3 = jit(grad(jit(f)))(3.)
_, deriv4 = jvp(f, (3.,), (1.,))
_, deriv5 = jvp(jit(f), (3.,), (1.,))
assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5)

hess1 = grad(grad(f))(3.)
hess2 = grad(grad(jit(f)))(3.)
hess3 = grad(jit(grad(f)))(3.)
hess4 = jit(grad(grad(f)))(3.)
_, hess5 = jvp(grad(f), (3.,), (1.,))
_, hess6 = jvp(jit(grad(f)), (3.,), (1.,))
_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)

Part 5: the control flow primitives cond

Next we’ll add higher-order primitives for staged-out control flow. These resemble jit from Part 3, another higher-order primitive, but differ in that they are parameterized by multiple callables rather than just one.

Adding cond

We introduce a cond primitive to represent conditional application of one function or another inside a jaxpr. We write the type of cond as Bool -> (a -> b) -> (a -> b) -> a -> b. In words, cond takes a boolean representing the predicate and two functions of equal types. Depending on the value of the predicate, it applies one function or the other to its final argument.
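As a point of reference, that type corresponds to the following eager Python function (cond_reference is a hypothetical name, not the primitive defined below). Only the selected branch runs here; the staged version below instead traces both branches to jaxprs up front, so it can check that their types agree before the choice between them is made.

```python
def cond_reference(pred, true_fn, false_fn, *operands):
  """Eager semantics for Bool -> (a -> b) -> (a -> b) -> a -> b."""
  return true_fn(*operands) if pred else false_fn(*operands)

branch_calls = []  # records which branch bodies actually ran

def plus_one(x):
  branch_calls.append('true_fn')
  return x + 1

def minus_one(x):
  branch_calls.append('false_fn')
  return x - 1

ref_out = cond_reference(False, plus_one, minus_one, 10)
```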

In Python, we represent it as a function which itself takes two functions as arguments. As with jit, the first step is to call make_jaxpr on its callable arguments to turn them into jaxprs:

def cond(pred, true_fn, false_fn, *operands):
  avals_in = [raise_to_shaped(get_aval(x)) for x in operands]
  true_jaxpr, true_consts, out_tree = make_jaxpr(true_fn, *avals_in)
  false_jaxpr, false_consts, out_tree_ = make_jaxpr(false_fn, *avals_in)
  if out_tree != out_tree_: raise TypeError
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  if typecheck_jaxpr(true_jaxpr) != typecheck_jaxpr(false_jaxpr):
    raise TypeError
  outs = bind_cond(pred, *true_consts, *false_consts, *operands,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return tree_unflatten(out_tree, outs)
cond_p = Primitive('cond')

def _join_jaxpr_consts(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                       ) -> tuple[Jaxpr, Jaxpr]:
  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  assert jaxpr1_type.in_types[n1:] == jaxpr2_type.in_types[n2:]
  consts1, rest1 = split_list(jaxpr1.in_binders, n1)
  consts2, rest2 = split_list(jaxpr2.in_binders, n2)
  new_jaxpr1 = Jaxpr(consts1 + consts2 + rest1, jaxpr1.eqns, jaxpr1.outs)
  new_jaxpr2 = Jaxpr(consts1 + consts2 + rest2, jaxpr2.eqns, jaxpr2.outs)
  return new_jaxpr1, new_jaxpr2

def bind_cond(pred, *args, true_jaxpr, false_jaxpr):
  assert len(args) == len(true_jaxpr.in_binders) == len(false_jaxpr.in_binders)
  return bind(cond_p, pred, *args, true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)

We require true_jaxpr and false_jaxpr to have the same type, but because they might close over different constants (and because jaxprs can only represent closed terms, i.e. can’t have free variables and are instead closure-converted) we need to use the helper _join_jaxpr_consts to make the input binder lists of the two jaxprs consistent. (To be more economical we could try to identify pairs of constants with the same shapes, but instead we just concatenate the lists of constants.)
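The same bookkeeping can be sketched with ordinary Python callables standing in for jaxprs. In this hypothetical join_consts, each joined function accepts the concatenated constants of both branches and simply ignores the other branch's, which is the effect _join_jaxpr_consts has on the input binders:

```python
def join_consts(fn1, fn2, n1, n2):
  """Hypothetical analogue of _join_jaxpr_consts for Python callables.

  fn1 takes (*consts1, *xs) with len(consts1) == n1, and fn2 takes
  (*consts2, *xs) with len(consts2) == n2. Both returned callables take
  (*consts1, *consts2, *xs).
  """
  def new_fn1(*args):
    return fn1(*args[:n1], *args[n1 + n2:])         # drop consts2
  def new_fn2(*args):
    return fn2(*args[n1:n1 + n2], *args[n1 + n2:])  # drop consts1
  return new_fn1, new_fn2

true_branch = lambda c, x: c * x              # closes over one constant
false_branch = lambda c1, c2, x: c1 + c2 + x  # closes over two constants
true_joined, false_joined = join_consts(true_branch, false_branch, 1, 2)
joined_args = (10., 1., 2., 3.)  # consts1=(10.,), consts2=(1., 2.), x=3.
```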

Next we can turn to adding interpreter rules for cond. Its evaluation rule is simple:

def cond_impl(pred, *operands, true_jaxpr, false_jaxpr):
  if pred:
    return eval_jaxpr(true_jaxpr, operands)
  else:
    return eval_jaxpr(false_jaxpr, operands)
impl_rules[cond_p] = cond_impl

out = cond(True, lambda: 3, lambda: 4)
print(out)

3

For its JVP and vmap rules, we only need to call the same jvp_jaxpr and vmap_jaxpr utilities we created for jit, followed by another pass of _join_jaxpr_consts:

def cond_jvp_rule(primals, tangents, *, true_jaxpr, false_jaxpr):
  pred, *primals = primals
  _, *tangents = tangents
  true_jaxpr, true_consts = jvp_jaxpr(true_jaxpr)
  false_jaxpr, false_consts = jvp_jaxpr(false_jaxpr)
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
  outs = bind_cond(pred, *true_consts, *false_consts, *primals, *tangents,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  primals_out, tangents_out = split_half(outs)
  return primals_out, tangents_out
jvp_rules[cond_p] = cond_jvp_rule

out, out_tan = jvp(lambda x: cond(True, lambda: x * x, lambda: 0.), (1.,), (1.,))
print(out_tan)

2.0
def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr):
  pred, *vals_in = vals_in
  pred_dim, *dims_in = dims_in
  if pred_dim is not not_mapped: raise NotImplementedError  # TODO
  true_jaxpr, true_consts = vmap_jaxpr(true_jaxpr, axis_size, tuple(dims_in))
  false_jaxpr, false_consts = vmap_jaxpr(false_jaxpr, axis_size, tuple(dims_in))
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
  outs = bind_cond(pred, *true_consts, *false_consts, *vals_in,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return outs, [0] * len(outs)
vmap_rules[cond_p] = cond_vmap_rule

xs = np.array([1., 2., 3])
out = vmap(lambda x: cond(True, lambda: x + 1., lambda: 0.), (0,))(xs)
print(out)

[2. 3. 4.]
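Mathematically, the JVP rule amounts to differentiating whichever branch the predicate selects, with no tangent for the predicate itself. An eager sketch of that idea (jvp_cond_reference, square_jvp, and zero_jvp are illustrative names) reproduces the 2.0 printed for the x * x example at x = 1; cond_jvp_rule instead transforms both jaxprs so that the choice can still be made when the jaxpr runs.

```python
def jvp_cond_reference(pred, true_jvp, false_jvp, primals, tangents):
  """Eager sketch: the JVP of a conditional is the JVP of the chosen branch.

  The predicate is not differentiated; it only picks which branch's
  (primal_out, tangent_out) pair is returned.
  """
  return true_jvp(primals, tangents) if pred else false_jvp(primals, tangents)

def square_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return x * x, 2. * x * x_dot

def zero_jvp(primals, tangents):
  return 0., 0.

primal_out, tangent_out = jvp_cond_reference(
    True, square_jvp, zero_jvp, (1.,), (1.,))
```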

Notice that we’re not currently supporting the case where the predicate value itself is batched. In mainline JAX, we handle this case by transforming the conditional to a select primitive. That transformation is semantically correct so long as true_fn and false_fn do not involve any side-effecting primitives, since a select evaluates both branches and only then picks between their results.
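To see why side effects matter, here is a NumPy sketch of the select rewrite (batched_cond_via_select is a hypothetical helper, and np.where plays the role of select). Both branches run on the whole batch, as the recorded calls show, and only then are results picked per element:

```python
import numpy as np

def batched_cond_via_select(preds, true_fn, false_fn, xs):
  """Sketch of a batched-predicate cond rewritten as an elementwise select."""
  # Both branches are evaluated on every element before selecting.
  return np.where(preds, true_fn(xs), false_fn(xs))

select_calls = []  # a stand-in for an observable side effect

def add_one(xs):
  select_calls.append('true_fn')
  return xs + 1.

def zero_out(xs):
  select_calls.append('false_fn')
  return xs * 0.

select_out = batched_cond_via_select(np.array([True, False, True]),
                                     add_one, zero_out, np.array([1., 2., 3.]))
```

With an unbatched predicate only one branch would run; here both side effects are recorded even though each element keeps just one branch's result.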

Another thing not represented here, but present in mainline JAX, is that applying transformations to two jaxprs of equal type might result in jaxprs of different types. For example, applying the mainline JAX version of vmap_jaxpr to the identity-function jaxpr

{ lambda a:float32[] .
  let
  in ( a ) }

would result in a jaxpr with a batched output, of type [float32[10]] -> [float32[10]] if the batch size were 10, while applying it to the zero-function jaxpr

{ lambda a:float32[] .
  let
  in ( 0. ) }

would result in a jaxpr with an unbatched output, of type [float32[10]] -> [float32[]]. This is an optimization, aimed at not batching values unnecessarily. But it means that in cond we’d need an extra step of joining the two transformed jaxprs to have consistent output types. We don’t need this step here because we chose vmap_jaxpr always to batch all outputs over the leading axis.
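That extra joining step would look roughly like the following hypothetical helper, which broadcasts an unbatched output along a new leading axis so that both branches produce the same output type. It is a NumPy sketch of the idea, not mainline JAX's implementation:

```python
import numpy as np

def match_batched_output(val, is_batched, axis_size):
  """Hypothetical join step: give an unbatched output a leading batch axis."""
  if is_batched:
    return val
  return np.broadcast_to(val, (axis_size,) + np.shape(val))

identity_out = np.arange(10, dtype=np.float32)  # batched: float32[10]
zero_out_unbatched = np.float32(0.)             # unbatched: float32[]
joined = [match_batched_output(identity_out, True, 10),
          match_batched_output(zero_out_unbatched, False, 10)]
```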

Next we can turn to abstract evaluation and XLA lowering rules:

def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
  if pred_type != ShapedArray((), np.dtype('bool')): raise TypeError
  jaxpr_type = typecheck_jaxpr(true_jaxpr)
  if jaxpr_type != typecheck_jaxpr(false_jaxpr):
    raise TypeError
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError
  return jaxpr_type.out_types
abstract_eval_rules[cond_p] = cond_abstract_eval

def cond_translation(c, in_avals, out_avals, in_vals, *, true_jaxpr, false_jaxpr):
  del in_avals  # Unused
  pred, *in_vals = in_vals

  op = hlo.IfOp([aval_to_ir_type(aval) for aval in out_avals], pred)
  with ir.InsertionPoint(op.true_branch.blocks.append()):
    hlo.return_(jaxpr_subcomp(c, true_jaxpr, in_vals))
  with ir.InsertionPoint(op.false_branch.blocks.append()):
    hlo.return_(jaxpr_subcomp(c, false_jaxpr, in_vals))
  return op.results

hlo_translations[cond_p] = cond_translation

out = jit(lambda: cond(False, lambda: 1, lambda: 2))()
print(out)

2

Finally, to support reverse-mode automatic differentiation, we need partial evaluation and transposition rules. For partial evaluation, we need to introduce another jaxpr-munging utility, _join_jaxpr_res, to handle the fact that applying partial evaluation to true_fn and false_fn will in general result in distinct residuals. We use _join_jaxpr_res to make the output types of the transformed jaxprs consistent (while _join_jaxpr_consts dealt with input types).
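In miniature, the residual-joining layout looks like this. The residual types and branch bodies below are made up for illustration: each branch returns its own outputs and residuals, plus zero placeholders with the other branch's residual types, so both return values of identical shapes and dtypes in the same order (outputs, true-branch residuals, false-branch residuals).

```python
import numpy as np

# Hypothetical residual types (shape, dtype) of each branch, as tracing would report.
TRUE_RES_TYPES = [((), np.float64)]
FALSE_RES_TYPES = [((3,), np.float32), ((), np.float64)]

def zeros_for(types):
  return [np.zeros(shape, dtype) for shape, dtype in types]

def true_branch_known(x):
  outs, res = [x * 2.], [np.cos(x)]
  # Own residuals first, then zero placeholders for the false branch's.
  return outs + res + zeros_for(FALSE_RES_TYPES)

def false_branch_known(x):
  outs, res = [x * 0.], [np.ones(3, np.float32), np.float64(x)]
  # Zero placeholders for the true branch's residuals, then its own.
  return outs + zeros_for(TRUE_RES_TYPES) + res

def signature(vals):
  # (shape, dtype) of each returned value, for comparing branch types.
  return [(np.shape(v), np.asarray(v).dtype) for v in vals]
```

The unknown-side jaxprs then take all of these residuals as leading inputs, and each branch ignores the placeholders meant for the other.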

def cond_partial_eval(trace, tracers, *, true_jaxpr, false_jaxpr):
  pred_tracer, *tracers = tracers
  assert pred_tracer.pval.is_known
  pred = pred_tracer.pval.const
  in_uks = [not t.pval.is_known for t in tracers]

  *jaxprs, out_uks, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, in_uks)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs

  known_tracers, unknown_tracers = partition_list(in_uks, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  outs1_res = bind_cond(pred, *known_vals,
                        true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1)
  outs1, res = split_list(outs1_res, len(outs1_res) - num_res)
  pred_tracer_ = trace.instantiate_const(full_raise(trace, pred_tracer))
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in t_jaxpr2.outs]
  eqn = JaxprEqnRecipe(cond_p, [pred_tracer_, *res_tracers, *unknown_tracers],
                       dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                       [v.aval for v in t_jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_uks, outs1, outs2)
partial_eval_rules[cond_p] = cond_partial_eval

def _cond_partial_eval(true_jaxpr: Jaxpr, false_jaxpr: Jaxpr, in_uks: list[bool]
                       ) -> tuple[Jaxpr, Jaxpr, Jaxpr, Jaxpr, list[bool], int]:
  _, _, t_out_uks, _ = partial_eval_jaxpr(true_jaxpr, in_uks)
  _, _, f_out_uks, _ = partial_eval_jaxpr(false_jaxpr, in_uks)
  out_uks = map(op.or_, t_out_uks, f_out_uks)

  t_jaxpr1, t_jaxpr2, _, t_nres = partial_eval_jaxpr(true_jaxpr, in_uks, out_uks)
  f_jaxpr1, f_jaxpr2, _, f_nres = partial_eval_jaxpr(false_jaxpr, in_uks, out_uks)

  t_jaxpr1, f_jaxpr1 = _join_jaxpr_res(t_jaxpr1, f_jaxpr1, t_nres, f_nres)
  t_jaxpr2, f_jaxpr2 = _join_jaxpr_consts(t_jaxpr2, f_jaxpr2, t_nres, f_nres)
  assert typecheck_jaxpr(t_jaxpr1) == typecheck_jaxpr(f_jaxpr1)
  assert typecheck_jaxpr(t_jaxpr2) == typecheck_jaxpr(f_jaxpr2)
  num_res = t_nres + f_nres

  return t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, num_res

def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                    ) -> tuple[Jaxpr, Jaxpr]:
  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  out_types1, _ = split_list(jaxpr1_type.out_types, len(jaxpr1.outs) - n1)
  out_types2, _ = split_list(jaxpr2_type.out_types, len(jaxpr2.outs) - n2)
  assert out_types1 == out_types2
  outs1, res1 = split_list(jaxpr1.outs, len(jaxpr1.outs) - n1)
  outs2, res2 = split_list(jaxpr2.outs, len(jaxpr2.outs) - n2)
  zeros_like1 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res1]
  zeros_like2 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res2]
  new_jaxpr1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zeros_like2)
  new_jaxpr2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zeros_like1 + res2)
  return new_jaxpr1, new_jaxpr2
_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)
out = f_lin(3.14)
print(out)

3.14
def cond_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                   ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Atom]]:
  pred_unk, *unks_in = unks_in
  assert not pred_unk
  true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
  *jaxprs, unks_out, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, unks_in)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
  ins1, ins2 = partition_list(unks_in, eqn.inputs[1:])
  outs1, outs2 = partition_list(unks_out, eqn.out_binders)
  residuals, _ = split_list(t_jaxpr2.in_binders, num_res)
  eqn1 = JaxprEqn(cond_p, [eqn.inputs[0], *ins1],
                  dict(true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1),
                  outs1 + residuals)
  eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
                  dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                  outs2)
  res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals
  return eqn1, eqn2, unks_out, res
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn

_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)
out = f_lin(3.14)
print(out)

3.14

Transposition is a fairly straightforward application of transpose_jaxpr:

def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):
  undef_primals = tuple(type(x) is UndefPrimal for x in invals)
  true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)
  false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  res = [x for x in invals if type(x) is not UndefPrimal]
  outs = bind_cond(pred, *true_consts, *false_consts, *res, *cts,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  outs = iter(outs)
  return [None] + [next(outs) if type(x) is UndefPrimal else None for x in invals]
transpose_rules[cond_p] = cond_transpose_rule

out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)

2.0


def pprint_cond(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
  new_params = {k: v for k, v in eqn.params.items() if not k.endswith('jaxpr')}
  lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
  rhs = (pp(eqn.primitive.name) >> pp_params(new_params) >>
         pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
                     for x in eqn.inputs)))
  return vcat([lhs >> pp(' = ') >> rhs,
               pp_jaxpr(true_jaxpr).indent(2),
               pp_jaxpr(false_jaxpr).indent(2)])
pp_rules[cond_p] = pprint_cond
