Autodidax2, part 1: JAX from scratch, again

If you want to understand how JAX works you could try reading the code. But the code is complicated, often for no good reason. This notebook presents a stripped-back version without the cruft. It’s a minimal version of JAX from first principles. Enjoy!

Main idea: context-sensitive interpretation

JAX is two things:

  1. a set of primitive operations (roughly the NumPy API)

  2. a set of interpreters over those primitives (compilation, AD, etc.)

In this minimal version of JAX we’ll start with just two primitive operations, addition and multiplication, and we’ll add interpreters one by one. Suppose we have a user-defined function like this:

def foo(x):
    return mul(x, add(x, 3.0))

We want to be able to interpret foo in different ways without changing its implementation: we want to evaluate it on concrete values, differentiate it, stage it out to an IR, compile it and so on.

Here’s how we’ll do it. For each of these interpretations we’ll define an Interpreter object with a rule for handling each primitive operation. We’ll keep track of the current interpreter using a global context variable. The user-facing functions add and mul will dispatch to the current interpreter. At the beginning of the program the current interpreter will be the “evaluating” interpreter which just evaluates the operations on ordinary concrete data. Here’s what this all looks like so far.

from enum import Enum, auto
from contextlib import contextmanager
from typing import Any

# The full (closed) set of primitive operations
class Op(Enum):
    add = auto()  # addition on floats
    mul = auto()  # multiplication on floats

# Interpreters have rules for handling each primitive operation.
class Interpreter:
    def interpret_op(self, op: Op, args: tuple[Any, ...]):
        raise NotImplementedError("subclass should implement this")

# Our first interpreter is the "evaluating interpreter" which performs ordinary
# concrete evaluation.
class EvalInterpreter(Interpreter):
    def interpret_op(self, op, args):
        assert all(isinstance(arg, float) for arg in args)
        match op:
            case Op.add:
                x, y = args
                return x + y
            case Op.mul:
                x, y = args
                return x * y
            case _:
                raise ValueError(f"Unrecognized primitive op: {op}")

# The current interpreter is initially the evaluating interpreter.
current_interpreter = EvalInterpreter()

# A context manager for temporarily changing the current interpreter
@contextmanager
def set_interpreter(new_interpreter):
    global current_interpreter
    prev_interpreter = current_interpreter
    try:
        current_interpreter = new_interpreter
        yield
    finally:
        current_interpreter = prev_interpreter

# The user-facing functions `mul` and `add` dispatch to the current interpreter.
def add(x, y):
    return current_interpreter.interpret_op(Op.add, (x, y))

def mul(x, y):
    return current_interpreter.interpret_op(Op.mul, (x, y))

At this point we can call foo with ordinary concrete inputs and see the results:

print(foo(2.0))
10.0
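To see the dispatch mechanism in isolation, here is a minimal sketch of an extra interpreter that is not part of this tutorial: a hypothetical LoggingInterpreter that records each primitive it sees and then delegates to the interpreter it wraps. The Context holder class is also an assumption of this sketch; it stands in for the global variable used above so that the snippet is self-contained.

```python
from contextlib import contextmanager
from enum import Enum, auto

class Op(Enum):
    add = auto()
    mul = auto()

class EvalInterpreter:
    def interpret_op(self, op, args):
        x, y = args
        return x + y if op is Op.add else x * y

class Context:
    # Assumption of this sketch: a holder for the current interpreter,
    # playing the role of the tutorial's global variable.
    interpreter = EvalInterpreter()

@contextmanager
def set_interpreter(new_interpreter):
    prev_interpreter = Context.interpreter
    try:
        Context.interpreter = new_interpreter
        yield
    finally:
        Context.interpreter = prev_interpreter

def add(x, y):
    return Context.interpreter.interpret_op(Op.add, (x, y))

def mul(x, y):
    return Context.interpreter.interpret_op(Op.mul, (x, y))

def foo(x):
    return mul(x, add(x, 3.0))

# Hypothetical interpreter: logs the name of every op, then hands the
# actual arithmetic to the wrapped interpreter.
class LoggingInterpreter:
    def __init__(self, inner):
        self.inner = inner
        self.log = []

    def interpret_op(self, op, args):
        self.log.append(op.name)
        return self.inner.interpret_op(op, args)

logger = LoggingInterpreter(Context.interpreter)
with set_interpreter(logger):
    logged_result = foo(2.0)

print(logged_result, logger.log)  # 10.0 ['add', 'mul']
print(foo(2.0))                   # 10.0, evaluated without logging
```

Note that foo itself is unchanged; only the interpreter that is current while it runs differs.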

Aside: forward-mode automatic differentiation

For our second interpreter we’re going to try forward-mode automatic differentiation (AD). Here’s a quick introduction to forward-mode AD in case this is the first time you’ve come across it. Otherwise skip ahead to the “JVP Interpreter” section.

Suppose we’re interested in the derivative of foo(x) evaluated at x = 2.0. We could approximate it with finite differences:

print((foo(2.00001)-foo(2.0))/0.00001)
7.000009999913458

The answer is close to 7.0 as expected. But computing it this way required two evaluations of the function (not to mention the roundoff error and truncation error). Here’s a funny thing though. We can almost get the answer with a single evaluation:

print(foo(2.00001))
10.0000700001

The answer we’re looking for, 7.0, is right there in the insignificant digits!

Here’s one way to think about what’s happening. The initial argument to foo, 2.00001, carries two pieces of data: a “primal” value, 2.0, and a “tangent” value, 1.0. The representation of this primal-tangent pair, 2.00001, is the sum of the two, with the tangent scaled by a small fixed epsilon, 1e-5. Ordinary evaluation of foo(2.00001) propagates this primal-tangent pair, producing 10.0000700001 as the result. The primal and tangent components are well separated in scale so we can visually interpret the result as the primal-tangent pair (10.0, 7.0), ignoring the ~1e-10 truncation error at the end.
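The “eyeballing” step can be written out as a rough sketch. The helpers pack and unpack are hypothetical names introduced here, and the decoding assumes the primal has at most three decimal digits, so that rounding separates the two scales. This only works for well-behaved values like the ones in this example.

```python
EPS = 1e-5

def foo(x):
    return x * (x + 3.0)

# Hypothetical helper: fold a primal-tangent pair into a single float.
def pack(primal, tangent):
    return primal + tangent * EPS

# Hypothetical helper: read the pair back out, assuming the primal has
# at most three decimal digits and the tangent is of moderate size.
def unpack(packed):
    primal = round(packed, 3)
    tangent = round((packed - primal) / EPS, 3)
    return primal, tangent

packed_out = foo(pack(2.0, 1.0))
print(packed_out)          # close to 10.0000700001
print(unpack(packed_out))  # (10.0, 7.0)
```

The rounding in unpack is exactly the fragile part that the explicit pair representation described next avoids.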

The idea with forward-mode differentiation is to do the same thing but exactly and explicitly (eyeballing floats doesn’t really scale). We’ll represent the primal-tangent pair as an actual pair instead of folding them both into a single floating point number. For each primitive operation we’ll have a rule that describes how to propagate these primal-tangent pairs. Let’s work out the rules for our two primitives.

Addition is easy. Consider x + y where x = xp + xt * eps and y = yp + yt * eps (“p” for “primal”, “t” for “tangent”):

x + y = (xp + xt * eps) + (yp + yt * eps)
      =   (xp + yp)             # primal component
        + (xt + yt) * eps       # tangent component

The result is a first-order polynomial in eps and we can read off the primal-tangent pair as (xp + yp, xt + yt).

Multiplication is more interesting:

x * y = (xp + xt * eps) * (yp + yt * eps)
      =    (xp * yp)                        # primal component
         + (xp * yt + xt * yp) * eps        # tangent component
         + (xt * yt)           * eps * eps  # quadratic component, vanishes in the eps->0 limit

Now we have a second-order polynomial. But as epsilon goes to zero the quadratic term vanishes and our primal-tangent pair is just (xp * yp, xp * yt + xt * yp). (In our earlier example with finite eps, this term not vanishing is why we had the 1e-10 “truncation error”.)
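As a quick numerical check of that claim, the sketch below compares the tangent read off from a finite eps with the exact tangent xp * yt + xt * yp. The values xp = 2.0 and yp = 5.0 are chosen to match foo at x = 2.0; the variable names are otherwise just for illustration. The error should track the quadratic coefficient xt * yt times eps.

```python
xp, xt = 2.0, 1.0  # x = 2.0 with tangent 1.0
yp, yt = 5.0, 1.0  # y = x + 3.0 with tangent 1.0

exact_tangent = xp * yt + xt * yp  # 7.0

errors = {}
for eps in (1e-1, 1e-2, 1e-3):
    perturbed = (xp + xt * eps) * (yp + yt * eps)
    finite_eps_tangent = (perturbed - xp * yp) / eps
    # What is left over is the quadratic term divided by eps: xt * yt * eps.
    errors[eps] = finite_eps_tangent - exact_tangent
    print(eps, finite_eps_tangent, errors[eps])
```

Each error is approximately eps itself here (because xt * yt = 1.0), so it shrinks linearly as eps goes to zero.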

Putting this into code, we can write down the forward-AD rules for addition and multiplication and express foo in terms of these:

from dataclasses import dataclass

# A primal-tangent pair is conventionally called a "dual number"
@dataclass
class DualNumber:
    primal: float
    tangent: float

def add_dual(x: DualNumber, y: DualNumber) -> DualNumber:
    return DualNumber(x.primal + y.primal, x.tangent + y.tangent)

def mul_dual(x: DualNumber, y: DualNumber) -> DualNumber:
    return DualNumber(x.primal * y.primal, x.primal * y.tangent + x.tangent * y.primal)

def foo_dual(x: DualNumber) -> DualNumber:
    return mul_dual(x, add_dual(x, DualNumber(3.0, 0.0)))

print(foo_dual(DualNumber(2.0, 1.0)))
DualNumber(primal=10.0, tangent=7.0)

That works! But rewriting foo to use the _dual versions of addition and multiplication was a bit tedious. Let’s get back to the main program and use our interpretation machinery to do the rewrite automatically.
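Before moving on, note that the same recipe applies to other primitives. As an illustration only (sine is not one of this tutorial’s two primitives), a hypothetical sin_dual rule follows from the first-order expansion sin(xp + xt * eps) = sin(xp) + cos(xp) * xt * eps + O(eps^2):

```python
import math
from dataclasses import dataclass

@dataclass
class DualNumber:
    primal: float
    tangent: float

def mul_dual(x, y):
    return DualNumber(x.primal * y.primal,
                      x.primal * y.tangent + x.tangent * y.primal)

# Hypothetical rule, not part of the tutorial: the tangent is the
# derivative of sin at the primal, scaled by the incoming tangent.
def sin_dual(x):
    return DualNumber(math.sin(x.primal), math.cos(x.primal) * x.tangent)

# bar(x) = x * sin(x), so bar'(x) = sin(x) + x * cos(x)
def bar_dual(x):
    return mul_dual(x, sin_dual(x))

out = bar_dual(DualNumber(1.0, 1.0))
expected_tangent = math.sin(1.0) + 1.0 * math.cos(1.0)
print(out.tangent, expected_tangent)
```

The two printed numbers should agree up to floating-point rounding.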

JVP Interpreter

We’ll set up a new interpreter called JVPInterpreter (“JVP” for “Jacobian-vector product”) which propagates these dual numbers instead of ordinary values. The JVPInterpreter has rules for add and mul that operate on dual numbers. It casts constant arguments to dual numbers as needed by calling JVPInterpreter.lift. In our manually rewritten version above we did that by replacing the literal 3.0 with DualNumber(3.0, 0.0).

# This is like DualNumber above except that it also has a pointer to the
# interpreter it belongs to, which is needed to avoid "perturbation confusion"
# in higher order differentiation.
@dataclass
class TaggedDualNumber:
    interpreter: Interpreter
    primal: float
    tangent: float

class JVPInterpreter(Interpreter):
    def __init__(self, prev_interpreter: Interpreter):
        # We keep a pointer to the interpreter that was current when this
        # interpreter was first invoked. That's the context in which our
        # rules should run.
        self.prev_interpreter = prev_interpreter

    def interpret_op(self, op, args):
        # Lift every argument once, up front.
        args = tuple(self.lift(arg) for arg in args)
        with set_interpreter(self.prev_interpreter):
            match op:
                case Op.add:
                    # Notice that we use `add` and `mul` here, which are the
                    # interpreter-dispatching functions defined earlier.
                    x, y = args
                    return self.dual_number(
                        add(x.primal, y.primal),
                        add(x.tangent, y.tangent))
                case Op.mul:
                    x, y = args
                    return self.dual_number(
                        mul(x.primal, y.primal),
                        add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))
                case _:
                    raise ValueError(f"Unrecognized primitive op: {op}")

    def dual_number(self, primal, tangent):
        return TaggedDualNumber(self, primal, tangent)

    # Lift a constant value (constant with respect to this interpreter) to
    # a TaggedDualNumber.
    def lift(self, x):
        if isinstance(x, TaggedDualNumber) and x.interpreter is self:
            return x
        else:
            return self.dual_number(x, 0.0)

def jvp(f, primal, tangent):
    jvp_interpreter = JVPInterpreter(current_interpreter)
    dual_number_in = jvp_interpreter.dual_number(primal, tangent)
    with set_interpreter(jvp_interpreter):
        result = f(dual_number_in)
    dual_number_out = jvp_interpreter.lift(result)
    return dual_number_out.primal, dual_number_out.tangent

# Let's try it out:
print(jvp(foo, 2.0, 1.0))

# Because we were careful to consider nesting interpreters, higher-order AD
# works out of the box:
def derivative(f, x):
    _, tangent = jvp(f, x, 1.0)
    return tangent

def nth_order_derivative(n, f, x):
    if n == 0:
        return f(x)
    else:
        return derivative(lambda x: nth_order_derivative(n - 1, f, x), x)
(10.0, 7.0)
print(nth_order_derivative(0, foo, 2.0))
10.0
print(nth_order_derivative(1, foo, 2.0))
7.0
print(nth_order_derivative(2, foo, 2.0))
2.0
# The rest are zero because `foo` is only a second-order polynomial
print(nth_order_derivative(3, foo, 2.0))
0.0
print(nth_order_derivative(4, foo, 2.0))
0.0

There are some subtleties worth discussing. First, how do you tell if something is constant with respect to differentiation? It’s tempting to say “it’s a constant if and only if it’s not a dual number”. But actually dual numbers created by a different JVPInterpreter also need to be considered constants with respect to the JVPInterpreter we’re currently handling. That’s why we need the “x.interpreter is self” check in JVPInterpreter.lift. This comes up in higher-order differentiation when there are multiple JVPInterpreters in scope. The sort of bug where you accidentally interpret a dual number from a different interpreter as non-constant is sometimes called “perturbation confusion” in the literature. Here’s an example program that would have given the wrong answer if we hadn’t had the “and x.interpreter is self” check in JVPInterpreter.lift.

def f(x):
    # g is constant in its (ignored) argument `y`. Its derivative should be zero
    # but our AD will mess it up if we don't distinguish perturbations from
    # different interpreters.
    def g(y):
        return x
    should_be_zero = derivative(g, 0.0)
    return mul(x, should_be_zero)

print(derivative(f, 0.0))
0.0
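To make the failure mode concrete, here is a compact, self-contained sketch that re-creates the machinery with a hypothetical check_owner switch (an assumption of this sketch, not something the tutorial’s JVPInterpreter has). With the switch off, lift accepts dual numbers from any interpreter and the example above returns the wrong derivative. The Context holder again stands in for the global current interpreter.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any

class Op(Enum):
    add = auto()
    mul = auto()

class EvalInterpreter:
    def interpret_op(self, op, args):
        x, y = args
        return x + y if op is Op.add else x * y

class Context:
    interpreter: Any = EvalInterpreter()  # stand-in for the global

@contextmanager
def set_interpreter(new_interpreter):
    prev_interpreter = Context.interpreter
    try:
        Context.interpreter = new_interpreter
        yield
    finally:
        Context.interpreter = prev_interpreter

def add(x, y):
    return Context.interpreter.interpret_op(Op.add, (x, y))

def mul(x, y):
    return Context.interpreter.interpret_op(Op.mul, (x, y))

@dataclass
class TaggedDualNumber:
    interpreter: Any
    primal: Any
    tangent: Any

class JVPInterpreter:
    def __init__(self, prev_interpreter, check_owner):
        self.prev_interpreter = prev_interpreter
        self.check_owner = check_owner  # hypothetical switch, demo only

    def lift(self, x):
        # With check_owner=False, any dual number is (wrongly) treated as ours.
        ours = isinstance(x, TaggedDualNumber) and (
            x.interpreter is self or not self.check_owner)
        return x if ours else TaggedDualNumber(self, x, 0.0)

    def interpret_op(self, op, args):
        x, y = (self.lift(arg) for arg in args)
        with set_interpreter(self.prev_interpreter):
            if op is Op.add:
                return TaggedDualNumber(
                    self, add(x.primal, y.primal), add(x.tangent, y.tangent))
            return TaggedDualNumber(
                self, mul(x.primal, y.primal),
                add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))

def derivative(f, x, check_owner):
    interp = JVPInterpreter(Context.interpreter, check_owner)
    with set_interpreter(interp):
        result = f(TaggedDualNumber(interp, x, 1.0))
    return interp.lift(result).tangent

def make_f(check_owner):
    def f(x):
        def g(y):
            return x  # constant in y, so its derivative should be zero
        should_be_zero = derivative(g, 0.0, check_owner)
        return mul(x, should_be_zero)
    return f

print(derivative(make_f(True), 0.0, True))    # 0.0, correct
print(derivative(make_f(False), 0.0, False))  # 1.0, perturbation confusion
```

In the second call the inner interpreter mistakes the outer interpreter’s perturbation for its own, so should_be_zero comes back as 1.0 instead of 0.0.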

Another subtlety: the add and mul rules in JVPInterpreter.interpret_op describe addition and multiplication on dual numbers in terms of addition and multiplication on the primal and tangent components. But we don’t use ordinary + and * for this. Instead we use our own add and mul functions which dispatch to the current interpreter. Before calling them we set the current interpreter to be the previous interpreter, i.e. the interpreter that was current when JVPInterpreter was first invoked. If we didn’t do this we’d have an infinite recursion, with add and mul dispatching to JVPInterpreter endlessly. The advantage of using our own add and mul instead of ordinary + and * is that it means we can nest these interpreters and do higher-order AD.

At this point you might be wondering: have we just reinvented operator overloading? Python overloads the infix ops + and * to dispatch to the argument’s __add__ and __mul__. Could we have just used that mechanism instead of this whole interpreter business? Yes, actually. Indeed, the earlier automatic differentiation (AD) literature uses the term “operator overloading” to describe this style of AD implementation. One detail is that we can’t rely exclusively on Python built-in overloading because that only lets us overload a handful of built-in infix ops whereas we eventually want to overload numpy-level operations like sin and cos. So we need our own mechanism.

But there’s a more important difference: our dispatch is based on context whereas traditional Python-style overloading is based on data. This is actually a recent development for JAX. The earliest versions of JAX looked more like traditional data-based overloading. An interpreter (a “trace” in JAX jargon) for an operation would be chosen based on data attached to the arguments to that operation. We’ve gradually made the interpreter-dispatch decision rely more and more on context rather than data (omnistaging [link], stackless [link]). The reason to prefer context-based interpretation over data-based interpretation is that it makes the implementation much simpler.

All that said, we do also want to take advantage of Python’s built-in overloading mechanism. That way we get the syntactic convenience of using infix operators + and * instead of writing out add(..) and mul(..). But we’ll put that aside for now.
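For the curious, here is one way such a convenience layer might look. This is only a sketch under the assumption that the dunder methods simply forward to the context-dispatching add and mul; it is not how the tutorial or JAX itself defines them. The Context holder stands in for the global current interpreter, and foo_infix is a hypothetical infix version of foo.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any

class Op(Enum):
    add = auto()
    mul = auto()

class EvalInterpreter:
    def interpret_op(self, op, args):
        x, y = args
        return x + y if op is Op.add else x * y

class Context:
    interpreter: Any = EvalInterpreter()  # stand-in for the global

@contextmanager
def set_interpreter(new_interpreter):
    prev_interpreter = Context.interpreter
    try:
        Context.interpreter = new_interpreter
        yield
    finally:
        Context.interpreter = prev_interpreter

def add(x, y):
    return Context.interpreter.interpret_op(Op.add, (x, y))

def mul(x, y):
    return Context.interpreter.interpret_op(Op.mul, (x, y))

@dataclass
class TaggedDualNumber:
    interpreter: Any
    primal: Any
    tangent: Any

    # Assumed convenience layer: infix operators are pure syntax that
    # forwards to the context-based dispatch functions.
    def __add__(self, other): return add(self, other)
    def __radd__(self, other): return add(other, self)
    def __mul__(self, other): return mul(self, other)
    def __rmul__(self, other): return mul(other, self)

class JVPInterpreter:
    def __init__(self, prev_interpreter):
        self.prev_interpreter = prev_interpreter

    def lift(self, x):
        if isinstance(x, TaggedDualNumber) and x.interpreter is self:
            return x
        return TaggedDualNumber(self, x, 0.0)

    def interpret_op(self, op, args):
        x, y = (self.lift(arg) for arg in args)
        with set_interpreter(self.prev_interpreter):
            if op is Op.add:
                return TaggedDualNumber(
                    self, add(x.primal, y.primal), add(x.tangent, y.tangent))
            return TaggedDualNumber(
                self, mul(x.primal, y.primal),
                add(mul(x.primal, y.tangent), mul(x.tangent, y.primal)))

def jvp(f, primal, tangent):
    interp = JVPInterpreter(Context.interpreter)
    with set_interpreter(interp):
        result = f(TaggedDualNumber(interp, primal, tangent))
    out = interp.lift(result)
    return out.primal, out.tangent

def foo_infix(x):
    return x * (x + 3.0)

print(foo_infix(2.0))            # 10.0, plain float arithmetic
print(jvp(foo_infix, 2.0, 1.0))  # (10.0, 7.0)
```

The choice of interpreter is still made by context inside add and mul; the operators only decide that those functions get called.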

Staging to an untyped IR

The two program transformations we’ve seen so far – evaluation and JVP – both traverse the input program from top to bottom. They visit the operations one by one in the same order as ordinary evaluation. A convenient thing about top-to-bottom transformations is that they can be implemented eagerly, or “online”, meaning that we can evaluate the program from top to bottom and perform the necessary transformations as we go. We never look at the entire program at once.

But not all transformations work this way. For example, dead-code elimination requires traversing from bottom to top, collecting usage statistics on the way up and eliminating pure operations whose results have no uses. Another bottom-to-top transformation is AD transposition, which we use to implement reverse-mode AD. For these we need to first “stage” the program into an IR (intermediate representation), a data structure representing the program, which we can then traverse in any order we like. Building this IR from a Python program will be the goal of our third and final interpreter.

First, let’s define the IR. We’ll do an untyped ANF (A-normal form) IR to start. A function (we call IR functions “jaxprs” in JAX) will have a list of formal parameters, a list of operations, and a return value. Each argument to an operation must be an “atom”, which is either a variable or a literal. The return value of the function is also an atom.

Var = str           # Variables are just strings in this untyped IR
Atom = Var | float  # Atoms (arguments to operations) can be variables or (float) literals

# Equation - a single line in our IR like `z = mul(x, y)`
@dataclass
class Equation:
    var: Var                # The variable name of the result
    op: Op                  # The primitive operation we're applying
    args: tuple[Atom, ...]  # The arguments we're applying the primitive operation to

# We call an IR function a "Jaxpr", for "JAX expression"
@dataclass
class Jaxpr:
    parameters: list[Var]      # The function's formal parameters (arguments)
    equations: list[Equation]  # The body of the function, a list of instructions/equations
    return_val: Atom           # The function's return value

    def __str__(self):
        lines = []
        lines.append(', '.join(b for b in self.parameters) + ' ->')
        for eqn in self.equations:
            args_str = ', '.join(str(arg) for arg in eqn.args)
            lines.append(f'  {eqn.var} = {eqn.op}({args_str})')
        # The return value may be a float literal, so convert it explicitly.
        lines.append(str(self.return_val))
        return '\n'.join(lines)
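To get a feel for the data structure, the sketch below builds the IR for foo by hand. The variable names x, t and out are arbitrary choices for this illustration. Notice how the nested call mul(x, add(x, 3.0)) becomes two equations so that every argument is an atom, which is what the ANF restriction requires.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Union

class Op(Enum):
    add = auto()
    mul = auto()

Var = str
Atom = Union[Var, float]

@dataclass
class Equation:
    var: Var
    op: Op
    args: tuple

@dataclass
class Jaxpr:
    parameters: list
    equations: list
    return_val: Atom

    def __str__(self):
        lines = [', '.join(self.parameters) + ' ->']
        for eqn in self.equations:
            args_str = ', '.join(str(arg) for arg in eqn.args)
            lines.append(f'  {eqn.var} = {eqn.op}({args_str})')
        lines.append(str(self.return_val))
        return '\n'.join(lines)

# foo(x) = mul(x, add(x, 3.0)), written out one primitive per equation.
foo_jaxpr = Jaxpr(
    parameters=['x'],
    equations=[
        Equation('t', Op.add, ('x', 3.0)),   # inner call gets its own name
        Equation('out', Op.mul, ('x', 't')),
    ],
    return_val='out',
)
print(foo_jaxpr)
```

Writing this by hand is tedious, which is the motivation for generating it automatically.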

To build the IR from a Python function we define a StagingInterpreter that takes each operation and adds it to a growing list of all the operations we’ve seen so far:

class StagingInterpreter(Interpreter):
    def __init__(self):
        self.equations = []    # A mutable list of all the ops we've seen so far
        self.name_counter = 0  # Counter for generating unique names

    def fresh_var(self):
        self.name_counter += 1
        return "v_" + str(self.name_counter)

    def interpret_op(self, op, args):
        binder = self.fresh_var()
        self.equations.append(Equation(binder, op, args))
        return binder

def build_jaxpr(f, num_args):
    interpreter = StagingInterpreter()
    # A list, to match the declared type of Jaxpr.parameters
    parameters = [interpreter.fresh_var() for _ in range(num_args)]
    with set_interpreter(interpreter):
        result = f(*parameters)
    return Jaxpr(parameters, interpreter.equations, result)

Now we can construct an IR for a Python program and print it out:

print(build_jaxpr(foo,1))
v_1 ->
  v_2 = Op.add(v_1, 3.0)
  v_3 = Op.mul(v_1, v_2)
v_3

We can also evaluate our IR by writing an explicit interpreter that traverses the operations one by one:

def eval_jaxpr(jaxpr, args):
    # An environment mapping variables to values
    env = dict(zip(jaxpr.parameters, args))

    def eval_atom(x):
        return env[x] if isinstance(x, Var) else x

    for eqn in jaxpr.equations:
        args = tuple(eval_atom(x) for x in eqn.args)
        env[eqn.var] = current_interpreter.interpret_op(eqn.op, args)
    return eval_atom(jaxpr.return_val)

print(eval_jaxpr(build_jaxpr(foo, 1), (2.0,)))
10.0

We’ve written this interpreter in terms of current_interpreter.interpret_op which means we’ve done a full round-trip: interpretable Python program to IR to interpretable Python program. Since the result is “interpretable” we can differentiate it again, or stage it out or anything we like:

print(jvp(lambda x: eval_jaxpr(build_jaxpr(foo, 1), (x,)), 2.0, 1.0))
(10.0, 7.0)
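With a staged IR in hand, a bottom-to-top pass like the dead-code elimination mentioned earlier becomes straightforward. The following is a minimal sketch, not part of the tutorial: eliminate_dead_code is a hypothetical helper, and it assumes every primitive is pure (true for add and mul). The Equation and Jaxpr classes are repeated in simplified form to keep the snippet self-contained.

```python
from dataclasses import dataclass
from enum import Enum, auto

class Op(Enum):
    add = auto()
    mul = auto()

@dataclass
class Equation:
    var: str
    op: Op
    args: tuple

@dataclass
class Jaxpr:
    parameters: list
    equations: list
    return_val: object  # a variable name (str) or a float literal

# Hypothetical pass: walk the equations from last to first, keeping only
# those whose result is used by the return value or by a kept equation.
def eliminate_dead_code(jaxpr):
    live = set()
    if isinstance(jaxpr.return_val, str):
        live.add(jaxpr.return_val)
    kept = []
    for eqn in reversed(jaxpr.equations):
        if eqn.var in live:
            kept.append(eqn)
            # Variables read by a live equation become live themselves.
            live.update(arg for arg in eqn.args if isinstance(arg, str))
    kept.reverse()
    return Jaxpr(jaxpr.parameters, kept, jaxpr.return_val)

jaxpr = Jaxpr(
    parameters=['v_1'],
    equations=[
        Equation('v_2', Op.add, ('v_1', 3.0)),
        Equation('v_3', Op.mul, ('v_1', 'v_1')),  # only used by v_4
        Equation('v_4', Op.mul, ('v_3', 2.0)),    # never used
        Equation('v_5', Op.mul, ('v_1', 'v_2')),
    ],
    return_val='v_5',
)
pruned = eliminate_dead_code(jaxpr)
print([eqn.var for eqn in pruned.equations])  # ['v_2', 'v_5']
```

Because the traversal runs backwards, v_4 is dropped before v_3 is visited, so v_3 is never marked live and is dropped too. An eager, top-to-bottom interpreter could not make that decision when it first sees v_3.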

Up next…

That’s it for part one of this tutorial. We’ve done two primitives, three interpreters and the tracing mechanism that weaves them together. In the next part we’ll add types other than floats, error handling, compilation, reverse-mode AD and higher-order primitives. Note that the second part is structured differently. Rather than trying to have a top-to-bottom order that obeys both code dependencies (e.g. data structures need to be defined before they’re used) and pedagogical dependencies (concepts need to be introduced before they’re implemented) we’re going with a single file that can be approached in any order.

