# Control autodiff’s saved values with jax.checkpoint (aka jax.remat)
```python
import jax
import jax.numpy as jnp
```
## Summary
Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.
Don’t miss the practical notes for a discussion about how `jax.checkpoint` interacts with `jax.jit`.
Without using `jax.checkpoint`, the forward pass of `jax.grad(f)(x)` saves, for use on the backward pass, the values of Jacobian coefficients and other intermediates. We call these saved values *residuals*:
```python
def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect the 'residual' values to be saved on the forward pass
# if we were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
```
```
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[5] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[7] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
```
By applying `jax.checkpoint` to sub-functions, as a decorator or at specific application sites, we force JAX not to save any of that sub-function’s residuals. Instead, only the inputs of a `jax.checkpoint`-decorated function might be saved, and any residuals consumed on the backward pass are recomputed from those inputs as needed:
```python
def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
```
```
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
```
Here the values of two `sin` applications are saved because they are arguments in subsequent applications of the `jax.checkpoint`-decorated `g` function, and inputs to a `jax.checkpoint`-decorated function may be saved. But no values of `cos` applications are saved.
To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization *policy*. Here is an example that saves only the results of `dot` operations with no batch dimensions (since they are often FLOP-bound, and hence worth saving rather than recomputing):
```python
f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
```
```
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[6] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[7] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
```
You can also use policies to refer to intermediate values you name using `jax.ad_checkpoint.checkpoint_name`:
```python
from jax.ad_checkpoint import checkpoint_name

def f4(W1, W2, W3, x):
  x = checkpoint_name(g(W1, x), name='a')
  x = checkpoint_name(g(W2, x), name='b')
  x = checkpoint_name(g(W3, x), name='c')
  return x

f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
```
```
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] named 'a' from <ipython-input-7-fc0ed1c14b8d>:4 (f4)
```
When playing around with these toy examples, we can get a closer look at what’s going on using the `print_fwd_bwd` utility defined in this notebook:
```python
from jax.tree_util import tree_flatten, tree_unflatten

from rich.console import Console
from rich.table import Table
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

  y, f_vjp = jax.vjp(f_, *args)
  res, in_tree = tree_flatten(f_vjp)

  def g_(*args):
    *res, y = args
    f_vjp = tree_unflatten(in_tree, res)
    return f_vjp(y)

  bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

  table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
  table.add_row("[bold green]forward computation:",
                "[bold green]backward computation:")
  table.add_row(rich.text.Text.from_ansi(str(fwd)),
                rich.text.Text.from_ansi(str(bwd)))
  console = Console(width=240, force_jupyter=True)
  console.print(table)

def _renderable_repr(self):
  return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
```
```python
# no use of jax.checkpoint:
print_fwd_bwd(f, W1, W2, W3, x)
```
[output: the forward and backward jaxprs printed side by side. The forward jaxpr computes the `dot_general`, `sin`, and `cos` values for each of the three layers and returns them all as residuals; the backward jaxpr consists only of `mul` and `dot_general` operations that consume those residuals, with no recomputation.]
```python
# using jax.checkpoint with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
```
[output: the forward and backward jaxprs printed side by side. The forward jaxpr saves only the `dot_general` outputs (plus the arguments); the backward jaxpr contains a `remat2` call, with `policy=<function dot_with_no_batch_dims ...>` and `prevent_cse=True`, that recomputes the `sin` and `cos` values before consuming them.]
## Let’s think step by step
You might want to first (re)read the Autodiff Cookbook Part 1.
## Fundamentals of `jax.checkpoint`
In both `jax.linearize` and `jax.vjp` there is flexibility in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with `jax.checkpoint`.
One such choice is whether to perform Jacobian coefficient computations on the forward pass, as soon as the inputs are available, or on the backward pass, just before the coefficients are needed. Consider the example of `sin_vjp`:
```python
def sin_vjp(x):
  y = jnp.sin(x)
  cos_x = jnp.cos(x)
  return y, lambda y_bar: cos_x * y_bar
```
Another valid implementation would compute the value of `jnp.cos(x)` on the backward pass rather than on the forward pass:
```python
def sin_vjp2(x):
  y = jnp.sin(x)
  return y, lambda y_bar: jnp.cos(x) * y_bar
```
For this particular function, the amount of memory used by the two versions is the same, though we’ve reduced the FLOPs for the primal computation (i.e. the forward pass) and increased the FLOPs for the cotangent computation (i.e. the backward pass).
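As a quick sanity check (a sketch added here, not part of the original notebook), both hand-written rules agree with each other and with `jax.vjp` applied to `jnp.sin` directly:

```python
x0 = 2.0
y1, bwd1 = sin_vjp(x0)
y2, bwd2 = sin_vjp2(x0)
y_ref, vjp_ref = jax.vjp(jnp.sin, x0)  # reference VJP from JAX itself
print(jnp.allclose(bwd1(1.0), bwd2(1.0)))        # True: same cotangent
print(jnp.allclose(bwd1(1.0), vjp_ref(1.0)[0]))  # True: matches jax.vjp
```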
There’s another choice when it comes to function composition. Recall our VJP rule for a composition of two functions:
```python
def f(x):
  y = g(x)
  z = h(y)
  return z

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd
```
An alternative is:
```python
def f_vjp_checkpoint(x):
  y = g(x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2
```
In words, this alternative implementation doesn’t compute `g_vjp`, or the residual values in its closure, on the forward pass. Instead it only computes them in the backward pass `f_bwd2`. That means `f_vjp_checkpoint` requires less memory: if `g` and `h` each required similar amounts of memory for their residuals, each much larger than `x`, then the function produced by `f_vjp_checkpoint(x)` requires half the memory of `f_vjp(x)`!
The cost we pay is redundant work: in `f_bwd2` we must re-evaluate `g(x)` as part of `jax.vjp(g, x)` just to discard its value (in the underscore variable on the line `_, g_vjp = jax.vjp(g, x)`).
We can get this VJP behavior in autodiff — without having to write VJP functions directly — by instead using `jax.checkpoint` in an alternative definition of the original function `f`:
```python
def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z
```
In other words, we apply `jax.checkpoint` to `g`, the first stage of `f`, rather than to `f` itself. This way, when we evaluate `jax.grad(f_checkpoint)(x)`, we’d get a computation like:
1. run the forward pass of `g`, discarding residual values;
2. run the forward pass of `h`, saving residuals;
3. run the backward pass of `h`, consuming residuals from step 2;
4. re-run the forward pass of `g`, saving residuals;
5. run the backward pass of `g`, consuming residuals from step 4.
That is, by evaluating `jax.grad(f_checkpoint)(x)` we’d get the same computation as:
```python
def f_checkpoint_grad(x):
  y = g(x)                  # step 1
  _, h_vjp = jax.vjp(h, y)  # step 2
  y_bar, = h_vjp(1.0)       # step 3
  _, g_vjp = jax.vjp(g, x)  # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar
```
In general, `jax.checkpoint(foo)` is a new function which has the same input-output behavior as `foo`, but behaves differently under autodiff, particularly under `jax.linearize` and `jax.vjp` (and their wrappers, like `jax.grad`) but not `jax.jvp`. When differentiated, only the input to a `jax.checkpoint`-decorated function is stored on the forward pass; on the backward pass, residuals (i.e. intermediates from `foo` and its Jacobian coefficient values needed for the backward pass) are recomputed.
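To see this concretely, here is a small sketch (not in the original notebook) using `jnp.sin` and `jnp.cos` as illustrative stand-ins for the abstract `g` and `h` above: the checkpointed version matches the plain one in both value and gradient; only what gets saved differs.

```python
g_demo, h_demo = jnp.sin, jnp.cos            # illustrative stand-ins for g and h
f_plain = lambda x: h_demo(g_demo(x))
f_ckpt = lambda x: h_demo(jax.checkpoint(g_demo)(x))

print(jnp.allclose(f_plain(3.), f_ckpt(3.)))                      # same outputs
print(jnp.allclose(jax.grad(f_plain)(3.), jax.grad(f_ckpt)(3.)))  # same gradients
```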
Notice that if `f = lambda x: h(g(x))` is the function we want to differentiate, i.e. if we want to apply `jax.grad(f)`, we don’t get any memory savings by applying `jax.checkpoint` to `f` itself. That’s because evaluating `jax.grad(jax.checkpoint(f))(x)` would lead to a computation like:
1. run the forward pass, discarding all residuals;
2. immediately re-run the forward pass, saving residuals;
3. run the backward pass, consuming residuals from step 2.
That is, in code we’d have something like:
```python
def f_grad_bad1(x):
  _ = f(x)                  # step 1
  _, f_vjp = jax.vjp(f, x)  # step 2
  x_bar, = f_vjp(1.0)       # step 3
  return x_bar
```
We also wouldn’t get any memory savings by applying `jax.checkpoint` to `h`, the second stage of `f`. That’s because evaluating `jax.grad(lambda x: jax.checkpoint(h)(g(x)))` would lead to a computation like:
1. run the forward pass of `g`, saving residuals;
2. run the forward pass of `h`, discarding residuals;
3. immediately re-run the forward pass of `h`, saving residuals;
4. run the backward pass of `h`, consuming residuals from step 3;
5. run the backward pass of `g`, consuming residuals from step 1.
That is, in code we’d have something like:
```python
def f_grad_bad2(x):
  y, g_vjp = jax.vjp(g, x)  # step 1
  z = h(y)                  # step 2
  _, h_vjp = jax.vjp(h, y)  # step 3
  y_bar, = h_vjp(1.0)       # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar
```
Slightly more generally, if we had a chain composition of functions, like `f = lambda x: f3(f2(f1(x)))`, and we were interested in evaluating `jax.grad(f)`, we could say that:
- we shouldn’t apply `jax.checkpoint` to the whole function `f`, since that wouldn’t save any memory (and will perform wasteful recomputation);
- we shouldn’t apply `jax.checkpoint` to the last sub-function `f3`, since that wouldn’t save any memory (and will perform wasteful recomputation);
- we could apply `jax.checkpoint` to `f1`, `f2`, or their composition `lambda x: f2(f1(x))`, since any of those might save memory and would express different memory/recompute tradeoffs (a sketch follows below).
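Here is a sketch of that last option, using `jnp.sin` as an illustrative stand-in for each stage (the chain `f1`, `f2`, `f3` is left abstract above):

```python
s1 = s2 = s3 = jnp.sin  # illustrative stand-ins for f1, f2, f3

f_chain = lambda x: s3(s2(s1(x)))
f_chain_ckpt = lambda x: s3(jax.checkpoint(lambda y: s2(s1(y)))(x))

jax.ad_checkpoint.print_saved_residuals(f_chain, 3.)       # residuals from every stage
jax.ad_checkpoint.print_saved_residuals(f_chain_ckpt, 3.)  # first two stages recomputed
print(jnp.allclose(jax.grad(f_chain)(3.), jax.grad(f_chain_ckpt)(3.)))  # True
```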
## Custom policies for what’s saveable
As shown so far, using `jax.checkpoint` switches from one extreme to another:
- without `jax.checkpoint`, JAX’s autodiff tends to compute everything possible on the forward pass and store it for the backward pass;
- with a `jax.checkpoint` decorator, we instead compute as little as possible on the forward pass and recompute values as needed on the backward pass.
To operate between these two extremes, saving some things and not others, we can carefully place `jax.checkpoint` decorators on sub-functions. But that requires editing the function to be differentiated, e.g. model code, which may be inconvenient. It can also be hard to experiment with variations.
So an alternative is to use the `policy` argument to `jax.checkpoint`. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first-order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on `jax.checkpoint_policies`, like `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`, since the API for writing custom policy callables is considered internal.
For example, consider this function to be differentiated:
```python
def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
  *Ws, Wlast = params
  for W in Ws:
    x = layer(W, x)
  x = jnp.dot(Wlast, x)
  return x

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))
```
```python
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
```
```python
print_saved_residuals(loss, params, x, y)
```
```
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
```
Instead of saving so many values on the forward pass, perhaps we only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). We can do that using the policy `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`:
```python
loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
```
```
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:8 (predict)
```
Notice also that by providing a policy, we didn’t need to edit the code defining `loss`, `predict`, or `layer`. That is particularly convenient if we want to experiment with policies in calling code (e.g. a training script) without changing library code (e.g. the neural network library).
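As a quick check (a sketch, not from the original notebook), the policy only changes what is saved between the passes, not the values computed:

```python
grads = jax.grad(loss)(params, x, y)                 # differentiate w.r.t. params
grads_ckpt = jax.grad(loss_checkpoint)(params, x, y)
print(all(jnp.allclose(g1, g2) for g1, g2 in zip(grads, grads_ckpt)))  # True
```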
Some policies can refer to values named with `jax.ad_checkpoint.checkpoint_name`:
```python
def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  x = jnp.dot(Wlast, x)
  return x
```
By itself, `checkpoint_name` is just an identity function. But because some policy functions know to look for them, we can use the names to control whether certain values output by `checkpoint_name` are considered saveable:
```python
print_saved_residuals(loss, params, x, y)
```
```
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer0_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer1_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
```
```python
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
```
```
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'
```
Another policy which refers to names is `jax.checkpoint_policies.save_only_these_names`.
A list of policies can be found here.
Policies only indicate what is saveable; a value is only saved if it’s actually needed by the backward pass.
## Advanced: recursive `jax.checkpoint`
By applying `jax.checkpoint` in the right way, there are many tradeoffs between memory usage and (re)computation that can be expressed. One surprising example is *recursive checkpointing*, where we apply `jax.checkpoint` to a function which itself calls `jax.checkpoint`-decorated functions in such a way that memory usage from the chain composition of \(D\) functions scales like \(\mathcal{O}(\log_2 D)\) rather than \(\mathcal{O}(D)\).
As a toy example, consider the chain composition of multiple `jnp.sin` functions:
```python
def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f

f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
```
```
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
```
In general, the number of stored residuals scales linearly with the length of the chain:
```python
f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
```
```
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
```
But we can apply `jax.checkpoint` recursively to improve the scaling:
```python
def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs)//2])
    f2 = recursive_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(jax.checkpoint(f2)(x))
```
```python
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
```
```
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
```
```python
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
```
```
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
```
The cost here, as usual, is recomputation: in particular, we end up performing \(\mathcal{O}(\log_2 D)\) times as many FLOPs:
```python
f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
```
[output: the forward and backward jaxprs printed side by side. The forward jaxpr computes `sin` and `cos` at each of the eight stages and returns all eight `cos` residuals; the backward jaxpr is a chain of eight `mul` operations consuming them.]
```python
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
```
[output: the forward and backward jaxprs printed side by side. Both contain nested `remat2` calls (`policy=None`, `prevent_cse=True`): the forward pass saves only a few residuals, and the backward pass recomputes the discarded `sin`/`cos` values inside the `remat2` sub-jaxprs before multiplying by the cotangents.]
## Practical notes
When differentiated functions are staged out to XLA for compilation, for example by applying `jax.jit` to a function which contains a `jax.grad` call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, `jax.checkpoint` often isn’t needed for differentiated functions under a `jax.jit`. XLA will optimize things for you.
One exception is when using staged-out control flow, like `jax.lax.scan`. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass `scan` and the corresponding backward-pass `scan`, typically aren’t as thorough. As a result, it’s often a good idea to use `jax.checkpoint` on the body function passed to `jax.lax.scan`.
For example, one common pattern in large Transformer models is to express the architecture as a `jax.lax.scan` over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:
```python
LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # weights, bias pair for a layer
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
  for W, b in params:
    x = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return x
```
We would instead iterate over the layer application with `jax.lax.scan`:
```python
StackedWeights = jnp.ndarray  # all weight matrices stacked together
StackedBiases = jnp.ndarray   # all bias vectors stacked together

all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net(all_weights, all_biases, x):
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x
```
This scan-over-layers version reduces compile times, but by foiling some compiler optimizations it can lead to inefficient computation of gradients. To mitigate the issue, we would use `jax.checkpoint` on the scanned function:
```python
from functools import partial

@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None
```
By using `jax.checkpoint` this way, we’re manually controlling which values JAX’s autodiff saves between the forward and backward passes, and hence not relying on XLA optimizations to choose for us.
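To see how the checkpointed `layer` plugs into the rest, here is a minimal end-to-end sketch (not from the original notebook): the shapes, the `jax.random` initialization, and the squared-error loss are illustrative assumptions, while `layer` is the checkpointed scan body defined just above.

```python
# Illustrative sizes and initialization (assumptions made for this sketch).
depth, dim = 4, 8
key = jax.random.PRNGKey(0)
stacked_weights = jax.random.normal(key, (depth, dim, dim))
stacked_biases = jnp.zeros((depth, dim))

def scan_net(stacked_weights, stacked_biases, x):
  # Scan the checkpointed layer over the stacked parameters.
  x, _ = jax.lax.scan(layer, x, (stacked_weights, stacked_biases))
  return x

def scan_loss(stacked_weights, stacked_biases, x, y):
  # A squared-error loss, assumed here only for illustration.
  return jnp.sum((scan_net(stacked_weights, stacked_biases, x) - y) ** 2)

x_in, y_tgt = jnp.ones(dim), jnp.zeros(dim)
grad_fn = jax.jit(jax.grad(scan_loss, argnums=(0, 1)))
dW, db = grad_fn(stacked_weights, stacked_biases, x_in, y_tgt)
```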
