Forward- and reverse-mode autodiff in JAX#

Jacobian-Vector products (JVPs, a.k.a. forward-mode autodiff)#

JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar jax.grad() function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background.

JVPs in math#

Mathematically, given a function \(f : \mathbb{R}^n \to \mathbb{R}^m\), the Jacobian of \(f\) evaluated at an input point \(x \in \mathbb{R}^n\), denoted \(\partial f(x)\), is often thought of as an \(m \times n\) matrix:

\(\qquad \partial f(x) \in \mathbb{R}^{m \times n}\).

But you can also think of \(\partial f(x)\) as a linear map, which maps the tangent space of the domain of \(f\) at the point \(x\) (which is just another copy of \(\mathbb{R}^n\)) to the tangent space of the codomain of \(f\) at the point \(f(x)\) (a copy of \(\mathbb{R}^m\)):

\(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\).

This map is called the pushforward map of \(f\) at \(x\). The Jacobian matrix is just the matrix for this linear map on a standard basis.

If you don’t commit to one specific input point \(x\), then you can think of the function \(\partial f\) as first taking an input point and returning the Jacobian linear map at that input point:

\(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m\).

In particular, you can uncurry things so that given an input point \(x \in \mathbb{R}^n\) and a tangent vector \(v \in \mathbb{R}^n\), you get back an output tangent vector in \(\mathbb{R}^m\). We call that mapping, from \((x, v)\) pairs to output tangent vectors, the Jacobian-vector product, and write it as:

\(\qquad (x, v) \mapsto \partial f(x) v\)

JVPs in JAX code#

Back in Python code, JAX’s jax.jvp() function models this transformation. Given a Python function that evaluates \(f\), JAX’s jax.jvp() is a way to get a Python function for evaluating \((x, v) \mapsto (f(x), \partial f(x) v)\).

import jax
import jax.numpy as jnp

key = jax.random.key(0)

# Initialize random model coefficients
key, W_key, b_key = jax.random.split(key, 3)
W = jax.random.normal(W_key, (3,))
b = jax.random.normal(b_key, ())

# Define a sigmoid function.
def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

key, subkey = jax.random.split(key)
v = jax.random.normal(subkey, W.shape)

# Push forward the vector `v` along `f` evaluated at `W`
y, u = jax.jvp(f, (W,), (v,))

In terms of Haskell-like type signatures, you could write:

jvp :: (a -> b) -> a -> T a -> (b, T b)

where T a is used to denote the type of the tangent space for a.

In other words, jvp takes as arguments a function of type a -> b, a value of type a, and a tangent vector value of type T a. It gives back a pair consisting of a value of type b and an output tangent vector of type T b.
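To make the signature concrete, here is a minimal sketch that checks the tangent output of jax.jvp() against an explicitly materialized Jacobian. The toy function `f_toy` and the values of `x` and `v` are assumptions chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def f_toy(x):
    # Hypothetical toy function from R^3 to R^2, for illustration only.
    return jnp.array([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])

x = jnp.array([1.0, 2.0, 3.0])   # primal point (type a)
v = jnp.array([0.1, -0.2, 0.3])  # tangent vector (type T a), same shape as x

# jax.jvp returns the pair (f(x), ∂f(x) v), i.e. (b, T b).
y, tangent_out = jax.jvp(f_toy, (x,), (v,))

# Cross-check against the dense Jacobian, which has shape (2, 3).
J = jax.jacfwd(f_toy)(x)
expected = J @ v

assert jnp.allclose(y, f_toy(x))
assert jnp.allclose(tangent_out, expected, atol=1e-5)
```

Materializing `J` is only for checking; the point of jax.jvp() is that it never builds the full matrix.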

The jvp-transformed function is evaluated much like the original function, but paired up with each primal value of type a it pushes along tangent values of type T a. For each primitive numerical operation that the original function would have applied, the jvp-transformed function executes a “JVP rule” for that primitive that both evaluates the primitive on the primals and applies the primitive’s JVP at those primal values.

That evaluation strategy has some immediate implications about computational complexity. Since we evaluate JVPs as we go, we don’t need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the jvp-transformed function is about 3x the cost of just evaluating the function (one unit of work for evaluating the original function, for example sin(x); one unit for linearizing, like cos(x); and one unit for applying the linearized function to a vector, like cos_x * v). Put another way, for a fixed primal point \(x\), we can evaluate \(v \mapsto \partial f(x) \cdot v\) for about the same marginal cost as evaluating \(f\).
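The three units of work in the sin(x) example can be written out by hand. This is a sketch only: `sin_jvp_by_hand` is a hypothetical helper and does not reflect how JAX registers its JVP rules internally.

```python
import jax
import jax.numpy as jnp

def sin_jvp_by_hand(x, v):
    # Hypothetical hand-written analogue of a JVP rule for sin.
    y = jnp.sin(x)        # unit 1: evaluate the primitive on the primal
    cos_x = jnp.cos(x)    # unit 2: linearize at the primal
    return y, cos_x * v   # unit 3: apply the linearization to the tangent

x = jnp.array(0.7)
v = jnp.array(2.0)

y_hand, t_hand = sin_jvp_by_hand(x, v)
y_jax, t_jax = jax.jvp(jnp.sin, (x,), (v,))

assert jnp.allclose(y_hand, y_jax)
assert jnp.allclose(t_hand, t_jax)
```

Nothing from the forward computation has to be kept once the tangent has been produced, which is why memory does not grow with depth.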

That memory complexity sounds pretty compelling! So why don’t we see forward-mode very often in machine learning?

To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with “tall” Jacobians, but inefficient for “wide” Jacobians.
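The column-at-a-time construction can be sketched as follows, assuming a small toy function with a tall Jacobian. The names `f_tall` and `jacobian_by_columns` are hypothetical, and jax.jacfwd() is used only as a reference for comparison.

```python
import jax
import jax.numpy as jnp

def f_tall(x):
    # Hypothetical toy function from R^3 to R^4 (a "tall" Jacobian).
    return jnp.array([x[0] * x[1],
                      jnp.sin(x[2]),
                      x[0] + x[2] ** 2,
                      x[1] ** 3])

def jacobian_by_columns(f, x):
    # One JVP per input dimension: each one-hot tangent reveals one column.
    n = x.shape[0]
    columns = []
    for i in range(n):
        one_hot = jnp.zeros(n).at[i].set(1.0)
        _, col = jax.jvp(f, (x,), (one_hot,))
        columns.append(col)
    return jnp.stack(columns, axis=1)

x = jnp.array([1.0, 2.0, 3.0])
J_cols = jacobian_by_columns(f_tall, x)
J_ref = jax.jacfwd(f_tall)(x)

assert J_cols.shape == (4, 3)
assert jnp.allclose(J_cols, J_ref, atol=1e-5)
```

Here three JVP calls recover all twelve entries. For a scalar loss over \(n\) parameters, the same loop would need \(n\) calls to recover a single row.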

If you’re doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in \(\mathbb{R}^n\) to a scalar loss value in \(\mathbb{R}\). That means the Jacobian of this function is a very wide matrix: \(\partial f(x) \in \mathbb{R}^{1 \times n}\), which we often identify with the gradient vector \(\nabla f(x) \in \mathbb{R}^n\). Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluate the original function, sure seems inefficient! In particular, for training neural networks, where \(f\) is a training loss function and \(n\) can be in the millions or billions, this approach just won’t scale.

To do better for functions like this, you just need to use reverse-mode.

Vector-Jacobian products (VJPs, a.k.a. reverse-mode autodiff)#

Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time.

VJPs in math#

Let’s again consider a function \(f : \mathbb{R}^n \to \mathbb{R}^m\). Starting from our notation for JVPs, the notation for VJPs is pretty simple:

\(\qquad (x, v) \mapsto v \partial f(x)\),

where \(v\) is an element of the cotangent space of \(f\) at \(x\) (isomorphic to another copy of \(\mathbb{R}^m\)). When being rigorous, we should think of \(v\) as a linear map \(v : \mathbb{R}^m \to \mathbb{R}\), and when we write \(v \partial f(x)\) we mean function composition \(v \circ \partial f(x)\), where the types work out because \(\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\). But in the common case we can identify \(v\) with a vector in \(\mathbb{R}^m\) and use the two almost interchangeably, just like we might sometimes flip between “column vectors” and “row vectors” without much comment.

With that identification, we can alternatively think of the linear part of a VJP as the transpose (or adjoint conjugate) of the linear part of a JVP:

\(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v\).

For a given point \(x\), we can write the signature as

\(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n\).

The corresponding map on cotangent spaces is often called the pullback of \(f\) at \(x\). The key for our purposes is that it goes from something that looks like the output of \(f\) to something that looks like the input of \(f\), just like we might expect from a transposed linear function.
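As a small numerical check of the transpose view, the sketch below compares the pullback returned by jax.vjp() with \(\partial f(x)^\mathsf{T} u\) computed from a dense Jacobian. The toy function `f_toy` and the values of `x` and `u` are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def f_toy(x):
    # Hypothetical toy function from R^3 to R^2, for illustration only.
    return jnp.array([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])

x = jnp.array([1.0, 2.0, 3.0])
u = jnp.array([0.5, -1.5])  # cotangent: looks like an output of f_toy

# jax.vjp returns f(x) and a function that pulls cotangents back.
y, vjp_fun = jax.vjp(f_toy, x)
(pulled_back,) = vjp_fun(u)  # looks like an input of f_toy

# Reference: transpose of the dense (2, 3) Jacobian applied to u.
J = jax.jacrev(f_toy)(x)
expected = J.T @ u

assert pulled_back.shape == x.shape
assert jnp.allclose(pulled_back, expected, atol=1e-5)
```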

VJPs in JAX code#

Switching from math back to Python, the JAX function vjp can take a Python function for evaluating \(f\) and give us back a Python function for evaluating the VJP \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\).

from jax import vjp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

y, vjp_fun = vjp(f, W)

key, subkey = jax.random.split(key)
u = jax.random.normal(subkey, y.shape)

# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)

In terms of Haskell-like type signatures, we could write

vjp :: (a -> b) -> a -> (b, CT b -> CT a)

where we use CT a to denote the type for the cotangent space for a. In words, vjp takes as arguments a function of type a -> b and a point of type a, and gives back a pair consisting of a value of type b and a linear map of type CT b -> CT a.

This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) is only about three times the cost of evaluating \(f\). In particular, if we want the gradient of a function \(f : \mathbb{R}^n \to \mathbb{R}\), we can do it in just one call. That’s how jax.grad() is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
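The "one call" claim can be illustrated with a scalar-valued function: pulling back the cotangent 1.0 yields the whole gradient. This is a sketch under assumed inputs; `loss` is a hypothetical stand-in for a training loss, and the comparison with jax.grad() is a check on the result, not a description of how jax.grad() is implemented.

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Hypothetical scalar-valued function from R^5 to R.
    return jnp.sum(jnp.tanh(w) ** 2)

w = jnp.linspace(-1.0, 1.0, 5)

# One forward pass plus one pullback of the scalar cotangent 1.0.
value, vjp_fun = jax.vjp(loss, w)
(grad_from_vjp,) = vjp_fun(jnp.ones_like(value))

grad_ref = jax.grad(loss)(w)

assert grad_from_vjp.shape == w.shape
assert jnp.allclose(grad_from_vjp, grad_ref, atol=1e-6)
```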

There’s a cost, though: while the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that’s a story for a future notebook!).

For more on how reverse-mode works, check out this tutorial video from the Deep Learning Summer School in 2017.

Vector-valued gradients with VJPs#

If you’re interested in taking vector-valued gradients (like tf.gradients):

def vgrad(f, x):
    y, vjp_fn = jax.vjp(f, x)
    return vjp_fn(jnp.ones(y.shape))[0]

print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))

[[6. 6.]
 [6. 6.]]
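Pulling back a cotangent of all ones is the same as differentiating the sum of the outputs. The sketch below checks that equivalence on the example above; `g` is a hypothetical name for the elementwise function used there.

```python
import jax
import jax.numpy as jnp

def vgrad(f, x):
    y, vjp_fn = jax.vjp(f, x)
    return vjp_fn(jnp.ones(y.shape))[0]

g = lambda x: 3 * x ** 2  # hypothetical name for the elementwise example
x = jnp.ones((2, 2))

via_vjp = vgrad(g, x)
via_grad_of_sum = jax.grad(lambda x: jnp.sum(g(x)))(x)

assert jnp.allclose(via_vjp, via_grad_of_sum)
```

For elementwise functions this gives the per-element derivative. For functions that mix inputs, each entry is a sum over all outputs that depend on it.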

Hessian-vector products using both forward- and reverse-mode#

In a previous section, you implemented a Hessian-vector product function just using reverse-mode (assuming continuous second derivatives):

def hvp(f, x, v):
    return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x)

That’s efficient, but you can do even better and save some memory by using forward-mode together with reverse-mode.

Mathematically, given a function \(f : \mathbb{R}^n \to \mathbb{R}\) to differentiate, a point \(x \in \mathbb{R}^n\) at which to linearize the function, and a vector \(v \in \mathbb{R}^n\), the Hessian-vector product function we want is:

\((x, v) \mapsto \partial^2 f(x) v\)

Consider the helper function \(g : \mathbb{R}^n \to \mathbb{R}^n\) defined to be the derivative (or gradient) of \(f\), namely \(g(x) = \partial f(x)\). All you need is its JVP, since that will give us:

\((x, v) \mapsto \partial g(x) v = \partial^2 f(x) v\).

We can translate that almost directly into code:

# forward-over-reverse
def hvp(f, primals, tangents):
    return jax.jvp(jax.grad(f), primals, tangents)[1]

Even better, since you didn’t have to call jnp.dot() directly, this hvp function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn’t even have a dependence on jax.numpy.

Here’s an example of how to use it:

def f(X):
    return jnp.sum(jnp.tanh(X)**2)

key, subkey1, subkey2 = jax.random.split(key, 3)
X = jax.random.normal(subkey1, (30, 40))
V = jax.random.normal(subkey2, (30, 40))

def hessian(f):
    return jax.jacfwd(jax.jacrev(f))

ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)

print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True
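The container-type claim can be checked with parameters stored in a dict. This is a minimal sketch: the function `f_dict`, the keys `"w"` and `"b"`, and all values are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

# forward-over-reverse, as defined above
def hvp(f, primals, tangents):
    return jax.jvp(jax.grad(f), primals, tangents)[1]

def f_dict(params):
    # Hypothetical scalar function of a dict of arrays.
    return jnp.sum(jnp.tanh(params["w"]) ** 2) + jnp.sum(params["b"] ** 3)

params = {"w": jnp.array([0.1, -0.4, 0.7]), "b": jnp.array([1.0, 2.0])}
tangent = {"w": jnp.array([1.0, 0.0, -1.0]), "b": jnp.array([0.5, 0.5])}

out = hvp(f_dict, (params,), (tangent,))

# The result has the same container structure as the inputs.
assert set(out.keys()) == {"w", "b"}
# For the cubic term the Hessian is diag(6 * b), so the product is 6 * b * t.
assert jnp.allclose(out["b"], 6.0 * params["b"] * tangent["b"])
```

The tangent must have the same tree structure and leaf shapes as the primal, which is why both are given as dicts with matching keys.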

Another way you might consider writing this is using reverse-over-forward:

# Reverse-over-forward
def hvp_revfwd(f, primals, tangents):
    g = lambda primals: jax.jvp(f, primals, tangents)[1]
    return jax.grad(g)(primals)

That’s not quite as good, though, because forward-mode has less overhead than reverse-mode, and since the outer differentiation operator here has to differentiate a larger computation than the inner one, keeping forward-mode on the outside works best:

# Reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
    x, = primals
    v, = tangents
    return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x)

print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))

print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(jax.hessian(f)(X), V, 2)

Forward over reverse
3.41 ms ± 143 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
The slowest run took 6.83 times longer than the fastest. This could mean that an intermediate result is being cached.
16.2 ms ± 15 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
The slowest run took 4.57 times longer than the fastest. This could mean that an intermediate result is being cached.
18.3 ms ± 14 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
42.4 ms ± 972 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Composing VJPs, JVPs, and jax.vmap#

Jacobian-Matrix and Matrix-Jacobian products#

Now that you have jax.jvp() and jax.vjp() transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX’s jax.vmap() transformation to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products:

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
    y, vjp_fun = jax.vjp(f, x)
    # vjp_fun returns a one-element tuple, so unpack it before stacking.
    return jnp.vstack([vjp_fun(mi)[0] for mi in M])

# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
    y, vjp_fun = jax.vjp(f, x)
    outs, = jax.vmap(vjp_fun)(M)
    return outs

key = jax.random.key(0)
num_covecs = 128
U = jax.random.normal(key, (num_covecs,) + y.shape)

loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'

Non-vmapped Matrix-Jacobian product
88.3 ms ± 139 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Matrix-Jacobian product
4.34 ms ± 48.2 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
def loop_jmp(f, W, M):
    # jvp immediately returns the primal and tangent values as a tuple,
    # so we'll compute and select the tangents in a list comprehension
    return jnp.vstack([jax.jvp(f, (W,), (mi,))[1] for mi in M])

def vmap_jmp(f, W, M):
    _jvp = lambda s: jax.jvp(f, (W,), (s,))[1]
    return jax.vmap(_jvp)(M)

num_vecs = 128
S = jax.random.normal(key, (num_vecs,) + W.shape)

loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'

Non-vmapped Jacobian-Matrix product
85.4 ms ± 241 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Jacobian-Matrix product
1.41 ms ± 36.5 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

The implementation of jax.jacfwd and jax.jacrev#

Now that we’ve seen fast Jacobian-matrix and matrix-Jacobian products, it’s not hard to guess how to write jax.jacfwd() and jax.jacrev(). We just use the same technique to push-forward or pull-back an entire standard basis (isomorphic to an identity matrix) at once.

from jax import jacrev as builtin_jacrev

def our_jacrev(f):
    def jacfun(x):
        y, vjp_fun = jax.vjp(f, x)
        # Use vmap to do a matrix-Jacobian product.
        # Here, the matrix is the Euclidean basis, so we get all
        # entries in the Jacobian at once.
        J, = jax.vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
        return J
    return jacfun

assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'

from jax import jacfwd as builtin_jacfwd

def our_jacfwd(f):
    def jacfun(x):
        _jvp = lambda s: jax.jvp(f, (x,), (s,))[1]
        Jt = jax.vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
        return jnp.transpose(Jt)
    return jacfun

assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'

Interestingly, the Autograd library couldn’t do this. The implementation of reverse-mode jacobian in Autograd had to pull back one vector at a time with an outer-loop map. Pushing one vector at a time through the computation is much less efficient than batching it all together with jax.vmap().

Another thing that Autograd couldn’t do is jax.jit(). Interestingly, no matter how much Python dynamism you use in your function to be differentiated, you can always use jax.jit() on the linear part of the computation. For example:

def f(x):
    try:
        if x < 3:
            return 2 * x ** 3
        else:
            raise ValueError
    except ValueError:
        return jnp.pi * x

y, f_vjp = jax.vjp(f, 4.)
print(jax.jit(f_vjp)(1.))
(Array(3.1415927, dtype=float32, weak_type=True),)
