Gradient checkpointing with jax.checkpoint (jax.remat)
In this tutorial, you will learn how to control JAX automatic differentiation’s saved values using jax.checkpoint() (also known as jax.remat()), which can be particularly helpful in machine learning.
If you are new to automatic differentiation (autodiff) or need to refresh your memory, JAX has an Automatic differentiation tutorial and several Advanced automatic differentiation guides.
TL;DR Use the jax.checkpoint() decorator (aliased as jax.remat()) with jax.grad() to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.
If you don’t use jax.checkpoint(), the jax.grad(f)(x) forward pass stores Jacobian coefficients and other intermediates to use during the backward pass. These saved values are called residuals.
Note: Don’t miss the Practical notes for a discussion about how jax.checkpoint() interacts with jax.jit().
import jax
import jax.numpy as jnp

def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect the 'residual' values to be saved on the forward pass
# if you were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1818/1857807639.py:6:9 (g)
f32[5] output of cos from /tmp/ipykernel_1818/1857807639.py:6:9 (g)
f32[6] output of sin from /tmp/ipykernel_1818/1857807639.py:6:9 (g)
f32[6] output of cos from /tmp/ipykernel_1818/1857807639.py:6:9 (g)
f32[7] output of cos from /tmp/ipykernel_1818/1857807639.py:6:9 (g)
By applying jax.checkpoint() to sub-functions, as a decorator or at specific application sites, you force JAX not to save any of that sub-function’s residuals. Instead, only the inputs of a jax.checkpoint()-decorated function might be saved, and any residuals consumed on the backward pass are re-computed from those inputs as needed:
def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of sin from /tmp/ipykernel_1818/1857807639.py:6:9 (g)
f32[6] output of sin from /tmp/ipykernel_1818/1857807639.py:6:9 (g)
Here, the values of two sin applications are saved because they are arguments in subsequent applications of the jax.checkpoint()-decorated g function, and inputs to a jax.checkpoint()-decorated function may be saved. But no values of cos applications are saved.
To control which values are saveable without having to edit the definition of the function to be differentiated, you can use a rematerialization policy. Here is an example that saves only the results of dot operations with no batch dimensions (since they are often FLOP-bound, and hence worth saving rather than recomputing):
f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1818/1857807639.py:5:6 (g)
f32[6] output of reduce_precision from /tmp/ipykernel_1818/1857807639.py:5:6 (g)
f32[7] output of reduce_precision from /tmp/ipykernel_1818/1857807639.py:5:6 (g)
You can also use policies to refer to intermediate values you name using jax.ad_checkpoint.checkpoint_name():
from jax.ad_checkpoint import checkpoint_name

def f4(W1, W2, W3, x):
  x = checkpoint_name(g(W1, x), name='a')
  x = checkpoint_name(g(W2, x), name='b')
  x = checkpoint_name(g(W3, x), name='c')
  return x

f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument W1
f32[6,5] from the argument W2
f32[7,6] from the argument W3
f32[4] from the argument x
f32[5] output of reduce_precision from /tmp/ipykernel_1818/3722338705.py:4:6 (f4)
When playing around with these toy examples, you can get a closer look at what’s going on using a custom print_fwd_bwd utility defined in this notebook:
from jax.tree_util import tree_flatten, tree_unflatten

from rich.console import Console
from rich.table import Table
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

  y, f_vjp = jax.vjp(f_, *args)
  res, in_tree = tree_flatten(f_vjp)

  def g_(*args):
    *res, y = args
    f_vjp = tree_unflatten(in_tree, res)
    return f_vjp(y)

  bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

  table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
  table.add_row("[bold green]forward computation:",
                "[bold green]backward computation:")
  table.add_row(rich.text.Text.from_ansi(str(fwd)),
                rich.text.Text.from_ansi(str(bwd)))
  console = Console(width=240, force_jupyter=True)
  console.print(table)

def _renderable_repr(self):
  return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# Without using `jax.checkpoint`:
print_fwd_bwd(f, W1, W2, W3, x)
(Output: the forward-pass and backward-pass jaxprs printed side by side. Without jax.checkpoint, the forward jaxpr saves the sin and cos intermediates, along with the arguments, as residuals, and the backward jaxpr consumes them directly with only mul and dot_general operations, performing no recomputation.)
# Using `jax.checkpoint` with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
(Output: the forward-pass and backward-pass jaxprs printed side by side. With the dots_with_no_batch_dims_saveable policy, the forward jaxpr saves only the dot_general results (tagged by reduce_precision) and the arguments, and the backward jaxpr contains a remat2 call that recomputes the needed sin and cos values from those saved dot results.)
Let’s think step by step
Note: It may help to check out the “Advanced automatic differentiation” guides prior to continuing here.
jax.checkpoint fundamentals
In both jax.linearize() and jax.vjp(), there is flexibility in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with jax.checkpoint().
One such choice is whether to perform Jacobian coefficient computations on the forward pass, as soon as the inputs are available, or on the backward pass, just before the coefficients are needed. Consider the example of sin_vjp:
def sin_vjp(x):
  y = jnp.sin(x)
  cos_x = jnp.cos(x)
  return y, lambda y_bar: cos_x * y_bar
Another valid implementation would compute the value of jnp.cos(x) on the backward pass rather than on the forward pass:
def sin_vjp2(x):
  y = jnp.sin(x)
  return y, lambda y_bar: jnp.cos(x) * y_bar
For this particular function, the amount of memory used by the two versions is the same, though you’ve reduced the FLOPs for the primal computation (the forward pass) and increased the FLOPs for the cotangent computation (the backward pass).
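Both versions compute the same primal output and the same cotangent; only when jnp.cos(x) is evaluated differs. A quick check (a sketch; x0 is just an arbitrary test point):

# Sketch: the two VJP implementations agree; they differ only in when cos runs.
x0 = 3.0
y1, bwd1 = sin_vjp(x0)
y2, bwd2 = sin_vjp2(x0)
assert y1 == y2
assert bwd1(1.0) == bwd2(1.0)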
There’s another choice when it comes to function composition. Recall the VJP rule for a composition of two functions:
def f(x):
  y = g(x)
  z = h(y)
  return z

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd
An alternative is:
def f_vjp_checkpoint(x):
  y = g(x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2
In words, this alternative implementation doesn’t compute g_vjp, or the residual values in its closure, on the forward pass. Instead, it only computes them in the backward pass f_bwd2. That means f_vjp_checkpoint requires less memory: if g and h each required similar amounts of memory for their residuals, each much larger than x, then the function produced by f_vjp_checkpoint(x) requires half the memory of f_vjp(x)!
The cost you pay is redundant work: in f_bwd2 you must re-evaluate g(x) as part of jax.vjp(g, x) just to discard its value (into the underscore variable on the line _, g_vjp = jax.vjp(g, x)).
You can get this VJP behavior in autodiff, without having to write VJP functions directly, by instead using jax.checkpoint() in an alternative definition of the original function f:
def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z
In other words, you apply jax.checkpoint() to g, the first stage of f, rather than to f itself. This way, when you evaluate jax.grad(f_checkpoint)(x), you’d get a computation like:
1. Run the forward pass of g, discarding residual values.
2. Run the forward pass of h, saving residuals.
3. Run the backward pass of h, consuming residuals from step 2.
4. Re-run the forward pass of g, saving residuals.
5. Run the backward pass of g, consuming residuals from step 4.
That is, by evaluating jax.grad(f_checkpoint)(x) you’d get the same computation as:
def f_checkpoint_grad(x):
  y = g(x)                  # step 1
  _, h_vjp = jax.vjp(h, y)  # step 2
  y_bar, = h_vjp(1.0)       # step 3
  _, g_vjp = jax.vjp(g, x)  # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar
In general, jax.checkpoint(foo) is a new function which has the same input-output behavior as foo, but behaves differently under autodiff, particularly under jax.linearize() and jax.vjp() (and their wrappers, like jax.grad()) but not jax.jvp(). When differentiated, only the input to a jax.checkpoint()-decorated function is stored on the forward pass. On the backward pass, the residuals (intermediates from foo and its Jacobian coefficient values needed for the backward pass) are recomputed.
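For instance, a quick check (a sketch reusing g, W1, and x from above) that wrapping a function in jax.checkpoint() changes nothing about the values it computes:

# Sketch: jax.checkpoint(g) matches g on outputs and gradients; only the
# residuals saved for the backward pass differ.
y_ref  = g(W1, x)
y_ckpt = jax.checkpoint(g)(W1, x)
assert jnp.allclose(y_ref, y_ckpt)

grad_ref  = jax.grad(lambda W: jnp.sum(g(W, x)))(W1)
grad_ckpt = jax.grad(lambda W: jnp.sum(jax.checkpoint(g)(W, x)))(W1)
assert jnp.allclose(grad_ref, grad_ckpt)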
Notice that if f = lambda x: h(g(x)) is the function you want to differentiate (in other words, if you want to apply jax.grad(f)), you don’t get any memory savings by applying jax.checkpoint() to f itself. That’s because evaluating jax.grad(jax.checkpoint(f))(x) would lead to a computation such as:
1. Run the forward pass, discarding all residuals.
2. Immediately re-run the forward pass, saving residuals.
3. Run the backward pass, consuming residuals from step 2.
In code, you’d have something like:
def f_grad_bad(x):
  _ = f(x)                  # step 1
  _, f_vjp = jax.vjp(f, x)  # step 2
  x_bar, = f_vjp(1.0)       # step 3
  return x_bar
You also wouldn’t get any memory savings by applying jax.checkpoint() to h, the second stage of f. That’s because evaluating jax.grad(lambda x: jax.checkpoint(h)(g(x))) would lead to a computation such as:
1. Run the forward pass of g, saving residuals.
2. Run the forward pass of h, discarding residuals.
3. Immediately re-run the forward pass of h, saving residuals.
4. Run the backward pass of h, consuming residuals from step 3.
5. Run the backward pass of g, consuming residuals from step 1.
In code you’d have something like:
def f_grad_bad2(x):
  y, g_vjp = jax.vjp(g, x)  # step 1
  z = h(y)                  # step 2
  _, h_vjp = jax.vjp(h, y)  # step 3
  y_bar, = h_vjp(1.0)       # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar
Slightly more generally, if you had a chain composition of functions, such as f = lambda x: f3(f2(f1(x))), and were interested in evaluating jax.grad(f), you could say that you:
- Shouldn’t apply jax.checkpoint() to the whole function f, since that wouldn’t save any memory (and would perform wasteful recomputation).
- Shouldn’t apply jax.checkpoint() to the last sub-function f3, since that wouldn’t save any memory (and would perform wasteful recomputation).
- Could apply jax.checkpoint() to f1, f2, or their composition lambda x: f2(f1(x)), since any of those might save memory and would express different memory/recompute tradeoffs, as in the sketch below.
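For concreteness, here is a sketch of those placements, with jnp.sin standing in for each of the hypothetical stages f1, f2, and f3:

# Sketch: three placements of jax.checkpoint in a chain f3(f2(f1(x))),
# with jnp.sin standing in for each (hypothetical) stage.
stage = jnp.sin  # stands in for f1, f2, and f3

f_plain    = lambda x: stage(stage(stage(x)))                               # default: save everything saveable
f_ckpt_f1  = lambda x: stage(stage(jax.checkpoint(stage)(x)))               # recompute f1's residuals
f_ckpt_f12 = lambda x: stage(jax.checkpoint(lambda x: stage(stage(x)))(x))  # recompute f1's and f2's residuals

for fn in (f_plain, f_ckpt_f1, f_ckpt_f12):
  print_saved_residuals(fn, 3.)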
Custom policies for what’s saveable
As shown so far, using jax.checkpoint() switches from one extreme to another:
- Without jax.checkpoint(), JAX’s autodiff tends to compute everything possible on the forward pass and store it for the backward pass.
- With a jax.checkpoint() decorator, you instead compute as little as possible on the forward pass and recompute values as needed on the backward pass (see the sketch below).
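Here is a quick sketch contrasting the two extremes, reusing f and its inputs from the start of this tutorial:

# Sketch: default autodiff saves many intermediates; wrapping the whole
# function in jax.checkpoint saves only its inputs.
print_saved_residuals(f, W1, W2, W3, x)                  # many residuals
print_saved_residuals(jax.checkpoint(f), W1, W2, W3, x)  # only the arguments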
To operate between these two extremes, saving some things and not others, you can carefully place jax.checkpoint() decorators on sub-functions. But that requires editing the function to be differentiated, e.g. model code, which may be inconvenient. It can also be hard to experiment with variations.
So an alternative is to use the policy argument to jax.checkpoint(). A policy is a callable (i.e. a function) which takes as input a type-level specification of a first-order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on jax.checkpoint_policies, like jax.checkpoint_policies.dots_with_no_batch_dims_saveable(), since the API for writing custom policy callables is considered internal.
For example, consider this function to be differentiated:
def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
  *Ws, Wlast = params
  for W in Ws:
    x = layer(W, x)
  x = jnp.dot(Wlast, x)
  return x

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of sin from /tmp/ipykernel_1818/4230705069.py:12:9 (layer)
f32[4] output of cos from /tmp/ipykernel_1818/4230705069.py:12:9 (layer)
f32[4] output of sin from /tmp/ipykernel_1818/4230705069.py:12:9 (layer)
f32[4] output of cos from /tmp/ipykernel_1818/4230705069.py:12:9 (layer)
f32[4] output of mul from /tmp/ipykernel_1818/4230705069.py:2:17 (loss)
Instead of saving so many values on the forward pass, perhaps you only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP- rather than memory-bound). You can do that using the policy jax.checkpoint_policies.dots_with_no_batch_dims_saveable():
loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y
f32[4] output of reduce_precision from /tmp/ipykernel_1818/4230705069.py:12:17 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1818/4230705069.py:12:17 (layer)
f32[4] output of reduce_precision from /tmp/ipykernel_1818/4230705069.py:8:6 (predict)
Notice also that by providing a policy, you didn’t need to edit the code defining loss, predict, or layer. That is particularly convenient if you want to experiment with policies in calling code (such as a training script) without changing library code (for example, the neural network library).
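The policy-wrapped function differentiates exactly like the original; for example (a usage sketch):

# Sketch: take gradients of the policy-wrapped loss with respect to params,
# just as you would with the original loss.
grads = jax.grad(loss_checkpoint)(params, x, y)
print([gr.shape for gr in grads])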
Some policies can refer to values named with jax.ad_checkpoint.checkpoint_name():
from jax.ad_checkpoint import checkpoint_name

def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  x = jnp.dot(Wlast, x)
  return x
By itself, jax.ad_checkpoint.checkpoint_name() is just an identity function. But because some policy functions know to look for these names, you can use them to control whether certain values output by jax.ad_checkpoint.checkpoint_name() are considered saveable:
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] output of cos from /tmp/ipykernel_1818/4230705069.py:12:9 (layer)
f32[4] named 'layer0_output' from /tmp/ipykernel_1818/178264713.py:7:8 (predict)
f32[4] output of cos from /tmp/ipykernel_1818/4230705069.py:12:9 (layer)
f32[4] named 'layer1_output' from /tmp/ipykernel_1818/178264713.py:7:8 (predict)
f32[4] output of mul from /tmp/ipykernel_1818/4230705069.py:2:17 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument params[0]
f32[4,4] from the argument params[1]
f32[4,4] from the argument params[2]
f32[4] from the argument x
f32[4] from the argument y
Another policy which refers to names is jax.checkpoint_policies.save_only_these_names.
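For example (a sketch; loss_checkpoint3 is just an illustrative name):

# Sketch: save only the value named 'layer1_output' and recompute the rest.
loss_checkpoint3 = jax.checkpoint(
    loss, policy=jax.checkpoint_policies.save_only_these_names('layer1_output'))
print_saved_residuals(loss_checkpoint3, params, x, y)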
Custom policies for offload
When checkpointing to save accelerator memory, you may consider offloading values to CPU memory instead of recomputing them. jax.checkpoint_policies.offload_dot_with_no_batch_dims can offload the results of matrix multiplications with no batch dimensions to the CPU.
import functools

from jax import checkpoint, lax

policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
    "device", "pinned_host")

@functools.partial(checkpoint, policy=policy)
def f(x):
  x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
  x = jnp.sin(x)
  x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
  x = jnp.sin(x)
  x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
  x = jnp.sin(x)
  x = jnp.sum(x)
  return x
One of JAX’s checkpoint policies allows specified checkpoint names to be offloaded to the CPU. This policy is implemented through jax.checkpoint_policies.save_and_offload_only_these_names, which has four arguments: names_which_can_be_saved, names_which_can_be_offloaded, and the offload source and destination. Names listed in names_which_can_be_saved are kept on the device, names listed in names_which_can_be_offloaded are moved to CPU memory, and other names or operations without names are recomputed. For example, if we have checkpoint names y, z, and w, then y can be saved on the device, z can be offloaded to CPU memory, and w can be recomputed.
import functools
import math

import numpy as np

from jax import checkpoint
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a single-device mesh so the snippet runs anywhere (the original
# example used an internal test utility to create a 2-device mesh).
mesh = Mesh(np.array(jax.devices()[:1]), ("x",))
shape = (256, 128)
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
s = NamedSharding(mesh, P("x"))
inp = jax.device_put(np_inp, s)

policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=["y"], names_which_can_be_offloaded=["z"],
    offload_src='device', offload_dst='pinned_host')

@functools.partial(checkpoint, policy=policy)
def f(x):
  def g(ys, _):
    y, _ = ys
    y = checkpoint_name(jnp.sin(y), "y")
    z = checkpoint_name(jnp.sin(y), "z")
    z = z.T
    w = checkpoint_name(jnp.sin(z), "w")
    return (w.T, jnp.sum(w)), None

  _, scan_out = jax.lax.scan(g, (x, np.array(1, dtype=np.float32)), [np_inp])[0]
  return scan_out
The code defines a function f that applies checkpointing with a custom policy. This policy determines which computations can be saved or offloaded during execution. Inside f, there is a nested function g that performs the core computations. The jax.lax.scan function is used to apply g repeatedly over the input data.
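As a usage sketch (an assumption about how you might exercise this example, and subject to your backend supporting host offloading):

# Sketch: when differentiating f, residuals named 'y' stay in device memory,
# residuals named 'z' are offloaded to pinned host memory, and 'w' is recomputed.
grad_f = jax.jit(jax.grad(lambda x: jnp.sum(f(x))))
print(grad_f(inp).shape)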
List of policies
The available policies are the attributes of jax.checkpoint_policies.
Policies only indicate what is saveable; a value is only saved if it’s actually needed by the backward pass.
Advanced: Recursive jax.checkpoint
By applying jax.checkpoint() in the right way, there are many tradeoffs between memory usage and (re)computation that can be expressed. One surprising example is recursive checkpointing, where you apply jax.checkpoint() to a function which itself calls jax.checkpoint()-decorated functions, so that memory usage from the chain composition of \(D\) functions scales like \(\mathcal{O}(\log_2 D)\) rather than \(\mathcal{O}(D)\).
As a toy example, consider the chain composition of multiple jax.numpy.sin() functions:
def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f

f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
In general, the number of stored residuals scales linearly with the length of the chain:
f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
f32[] output of cos from /tmp/ipykernel_1818/410288286.py:4:10 (chain_compose.<locals>.f)
But you can apply jax.checkpoint() recursively to improve the scaling:
def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs)//2])
    f2 = recursive_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(jax.checkpoint(f2)(x))
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1818/1943107544.py:6:21 (recursive_checkpoint.<locals>.<lambda>)
f32[] output of cos from /tmp/ipykernel_1818/1943107544.py:6:24 (recursive_checkpoint.<locals>.<lambda>)
f32[] output of cos from /tmp/ipykernel_1818/1943107544.py:6:21 (recursive_checkpoint.<locals>.<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument x
f32[] output of sin from /tmp/ipykernel_1818/1943107544.py:6:21 (recursive_checkpoint.<locals>.<lambda>)
f32[] output of sin from /tmp/ipykernel_1818/1943107544.py:6:21 (recursive_checkpoint.<locals>.<lambda>)
f32[] output of cos from /tmp/ipykernel_1818/1943107544.py:6:24 (recursive_checkpoint.<locals>.<lambda>)
f32[] output of cos from /tmp/ipykernel_1818/1943107544.py:6:21 (recursive_checkpoint.<locals>.<lambda>)
The cost here, as usual, is recomputation: in particular, you end up performing \(\mathcal{O}(\log_2 D)\) times as many FLOPs:
f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
(Output: the forward-pass and backward-pass jaxprs printed side by side. The plain chain’s forward jaxpr computes sin and cos at every step and saves all eight cos values as residuals; the backward jaxpr is simply a chain of mul operations over them.)
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
(Output: the forward-pass and backward-pass jaxprs printed side by side. The recursively checkpointed chain’s forward jaxpr saves only a handful of residuals, while the backward jaxpr contains nested remat2 calls that recompute the sin and cos values of each sub-chain as needed.)
Practical notes
When differentiated functions are staged out to XLA for compilation (for example, by applying jax.jit() to a function which contains a jax.grad() call), XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, jax.checkpoint() often isn’t needed for differentiated functions under a jax.jit(). XLA will optimize things for you.
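For example (a sketch; the small mlp function and its shapes are made up just to illustrate the pattern):

# Sketch: a jitted gradient with no jax.checkpoint; under jit, XLA decides
# which intermediates to store and which to rematerialize.
def mlp(W, x):
  return jnp.sin(jnp.dot(W, jnp.sin(jnp.dot(W, x))))

grad_step = jax.jit(jax.grad(lambda W, x: jnp.sum(mlp(W, x))))
dW = grad_step(jnp.ones((4, 4)), jnp.ones(4))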
One exception is when using staged-out control flow, like jax.lax.scan(). Automatic compiler optimizations across multiple control flow primitives (for example, across a forward-pass scan and the corresponding backward-pass scan) typically aren’t as thorough. As a result, it’s often a good idea to use jax.checkpoint() on the body function passed to jax.lax.scan().
For example, one common pattern in large Transformer models is to express the architecture as a jax.lax.scan() over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:
LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # Weights-bias pair for a layer.
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
  for W, b in params:
    x = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return x
You would instead iterate over the layer application with jax.lax.scan():
params = [(jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5])),
          (jnp.array([[0.5, 0.5], [1., 1.]]), jnp.array([0.5, 0.5]))]

all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net(all_weights, all_biases, x):
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x
This scan-over-layers version reduces compile times, but by foiling some compiler optimizations it can lead to inefficient computation of gradients. To mitigate the issue, you can use jax.checkpoint() on the scanned function:
from functools import partial

@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None
By using jax.checkpoint() this way, you’re manually controlling which values JAX’s autodiff saves between the forward and backward passes, and therefore not relying on XLA optimizations to choose for you.
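A final usage sketch (assuming a toy input sized to match the parameters above and an arbitrary scalar loss):

# Sketch: gradients flow through the scan, with the checkpointed layer body
# controlling what is saved between the forward and backward passes.
x_in = jnp.ones(2)  # matches the toy 2x2 layer parameters above
scalar_loss = lambda Ws, bs, x: jnp.sum(net(Ws, bs, x))
dWs, dbs = jax.grad(scalar_loss, argnums=(0, 1))(all_weights, all_biases, x_in)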
