Automatic differentiation
In this section, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a general-purpose autodiff system. Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as:
1. Taking gradients with jax.grad
2. Computing gradients in a linear logistic regression
3. Differentiating with respect to nested lists, tuples, and dicts
4. Evaluating a function and its gradient using jax.value_and_grad
5. Checking against numerical differences
Make sure to also check out the Advanced automatic differentiation tutorial for more advanced topics.
While understanding how automatic differentiation works “under the hood” isn’t crucial for using JAX in most contexts, you are encouraged to check out this quite accessible video to get a deeper sense of what’s going on.
1. Taking gradients with jax.grad
In JAX, you can differentiate a scalar-valued function with the jax.grad() transformation:
import jax
import jax.numpy as jnp
from jax import grad

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816
jax.grad() takes a function and returns a function. If you have a Python function f that evaluates the mathematical function \(f\), then jax.grad(f) is a Python function that evaluates the mathematical function \(\nabla f\). That means grad(f)(x) represents the value \(\nabla f(x)\).
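As a small sanity check of this correspondence, the sketch below compares jax.grad against the closed-form derivative of tanh, \(1 - \tanh^2(x)\). The names f, df, autodiff_value, and analytic_value are illustrative choices, not part of the JAX API.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Mathematical function f(x) = tanh(x); its derivative is 1 - tanh(x)**2.
    return jnp.tanh(x)

# jax.grad returns a new Python function that evaluates the derivative of f.
df = jax.grad(f)

x = 2.0
autodiff_value = df(x)
analytic_value = 1.0 - jnp.tanh(x) ** 2

# The two values should agree up to floating-point rounding.
print(autodiff_value, analytic_value)
```

Because df is an ordinary Python function, it can be passed around, called repeatedly, or transformed again.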
Since jax.grad() operates on functions, you can apply it to its own output to differentiate as many times as you like:
print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405
JAX’s autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations. This can be illustrated in the single-variable case:
The derivative of \(f(x) = x^3 + 2x^2 - 3x + 1\) can be computed as:
f = lambda x: x**3 + 2*x**2 - 3*x + 1

dfdx = jax.grad(f)
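The first and higher-order derivatives of \(f(x) = x^3 + 2x^2 - 3x + 1\) are:
\[
\begin{aligned}
f'(x) &= 3x^2 + 4x - 3 \\
f''(x) &= 6x + 4 \\
f'''(x) &= 6 \\
f^{(4)}(x) &= 0
\end{aligned}
\]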
Computing any of these in JAX is as easy as chaining the jax.grad() function:
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)
Evaluating the derivatives at \(x=1\) gives:
\[
\begin{aligned}
f'(1) &= 4 \\
f''(1) &= 10 \\
f'''(1) &= 6 \\
f^{(4)}(1) &= 0
\end{aligned}
\]
Using JAX:
print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))
4.0
10.0
6.0
0.0
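The chaining pattern can also be written as a loop. The following is a minimal sketch; nth_derivative is a hypothetical helper introduced here for illustration and is not a JAX function.

```python
import jax

def nth_derivative(fun, n):
    """Hypothetical helper: apply jax.grad to `fun` n times (n = 0 returns `fun`)."""
    for _ in range(n):
        fun = jax.grad(fun)
    return fun

f = lambda x: x**3 + 2*x**2 - 3*x + 1

# Evaluate f and its first four derivatives at x = 1.
values = [float(nth_derivative(f, n)(1.0)) for n in range(5)]
print(values)
```

The entries for n = 1 through 4 should match the four values printed above. Each additional jax.grad adds tracing work, so this approach is best suited to low orders.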
2. Computing gradients in a linear logistic regression
The next example shows how to compute gradients with jax.grad() in a linear logistic regression model. First, the setup:
key = jax.random.key(0)

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# Initialize random model coefficients
key, W_key, b_key = jax.random.split(key, 3)
W = jax.random.normal(W_key, (3,))
b = jax.random.normal(b_key, ())
Use the jax.grad() function with its argnums argument to differentiate a function with respect to positional arguments.
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print(f'{W_grad=}')

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print(f'{W_grad=}')

# But you can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print(f'{b_grad=}')

# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print(f'{W_grad=}')
print(f'{b_grad=}')
W_grad=Array([-0.433146 , -0.7354605, -1.2598922], dtype=float32)
W_grad=Array([-0.433146 , -0.7354605, -1.2598922], dtype=float32)
b_grad=Array(-0.69001776, dtype=float32)
W_grad=Array([-0.433146 , -0.7354605, -1.2598922], dtype=float32)
b_grad=Array(-0.69001776, dtype=float32)
The jax.grad() API has a direct correspondence to the excellent notation in Spivak’s classic Calculus on Manifolds (1965), also used in Sussman and Wisdom’s Structure and Interpretation of Classical Mechanics (2015) and their Functional Differential Geometry (2013). Both books are open-access. See in particular the “Prologue” section of Functional Differential Geometry for a defense of this notation.
Essentially, when using the argnums argument, if f is a Python function for evaluating the mathematical function \(f\), then the Python expression jax.grad(f, i) evaluates to a Python function for evaluating \(\partial_i f\).
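To make the \(\partial_i f\) correspondence concrete, here is a small sketch on a function whose partial derivatives are easy to verify by hand. The function f(x, y) = x²y and the names d0f, d1f, and both are illustrative assumptions, not part of the example above.

```python
import jax

def f(x, y):
    # f(x, y) = x**2 * y, so d/dx = 2*x*y and d/dy = x**2.
    return x**2 * y

d0f = jax.grad(f, 0)        # partial derivative w.r.t. the first argument
d1f = jax.grad(f, 1)        # partial derivative w.r.t. the second argument
both = jax.grad(f, (0, 1))  # returns a tuple of both partial derivatives

x, y = 3.0, 5.0
print(d0f(x, y))   # by hand: 2 * 3 * 5 = 30
print(d1f(x, y))   # by hand: 3**2 = 9
print(both(x, y))
```

Passing a tuple to argnums returns the gradients as a tuple in the same order as the indices.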
3. Differentiating with respect to nested lists, tuples, and dicts
Due to JAX’s PyTree abstraction (see Pytrees), differentiating with respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.
Continuing the previous example:
def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.433146 , -0.7354605, -1.2598922], dtype=float32), 'b': Array(-0.69001776, dtype=float32)}

You can create custom pytree nodes that work not just with jax.grad() but also with other JAX transformations (jax.jit(), jax.vmap(), and so on).
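Arbitrary nesting of standard containers works the same way. The sketch below uses a dict holding a list of tuples. The function model_loss and its parameter layout are hypothetical and chosen only to show that the gradient has the same structure as the input.

```python
import jax
import jax.numpy as jnp

def model_loss(params):
    # `params` is a dict containing a list of (weight, bias) tuples and a scalar.
    h = jnp.array([1.0, 2.0])
    for w, b in params["layers"]:
        h = jnp.tanh(w * h + b)
    return params["scale"] * jnp.sum(h)

params = {
    "layers": [
        (jnp.array([0.5, -0.5]), jnp.array(0.1)),
        (jnp.array([1.0, 2.0]), jnp.array(-0.2)),
    ],
    "scale": jnp.array(2.0),
}

grads = jax.grad(model_loss)(params)

# The gradient is a pytree with the same structure as `params`.
same_structure = (
    jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params)
)
print(same_structure)
print(grads["layers"][0][0].shape, grads["scale"].shape)
```

Since the gradient mirrors the parameters leaf by leaf, an update rule can be applied to every leaf at once with jax.tree_util.tree_map.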
4. Evaluating a function and its gradient using jax.value_and_grad
Another convenient function is jax.value_and_grad(), which efficiently computes both a function’s value and its gradient in one pass.
Continuing the previous examples:
loss_value, Wb_grad = jax.value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 2.9729187
loss value 2.9729187
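A common place where both the value and the gradient are needed is a training loop, where the loss is logged and the gradient drives the update. The following is a minimal sketch of plain gradient descent on a toy quadratic. The objective, the learning rate of 0.1, and the step count are arbitrary assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def quadratic(w):
    # Toy objective, minimized at w = [1, -2].
    target = jnp.array([1.0, -2.0])
    return jnp.sum((w - target) ** 2)

value_and_grad_fn = jax.value_and_grad(quadratic)

w = jnp.zeros(2)
learning_rate = 0.1  # arbitrary choice for this sketch
losses = []
for step in range(50):
    # One call returns both the loss value and its gradient.
    value, g = value_and_grad_fn(w)
    losses.append(float(value))
    w = w - learning_rate * g

print(losses[0], losses[-1])
print(w)
```

The recorded losses should decrease toward zero while w approaches the minimizer, without a separate forward evaluation for logging.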
5. Checking against numerical differences
A great thing about derivatives is that they’re straightforward to check with finite differences.
Continuing the previous examples:
# Set a step size for finite differences calculations
eps = 1e-4

# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check W_grad with finite differences in a random direction
key, subkey = jax.random.split(key)
vec = jax.random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.6890297
b_grad_autodiff -0.69001776
W_dirderiv_numerical 1.3041496
W_dirderiv_autodiff 1.3006744
JAX provides a simple convenience function that does essentially the same thing, but checks up to any order of differentiation that you like:
from jax.test_util import check_grads

check_grads(loss, (W, b), order=2)  # check up to 2nd order derivatives
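For intuition about what such a check involves, a full gradient vector can also be compared coordinate by coordinate. This is a rough sketch only: numerical_grad is a hypothetical helper, the test function g is an arbitrary smooth example, and the step size of 1e-2 is an assumption chosen to keep float32 rounding error small.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Arbitrary smooth scalar-valued test function of a 1-D array.
    return jnp.sum(jnp.sin(x) * x**2)

def numerical_grad(fun, x, eps=1e-2):
    """Hypothetical helper: central differences along each coordinate axis."""
    basis = jnp.eye(x.shape[0])
    return jnp.stack([
        (fun(x + eps / 2 * e) - fun(x - eps / 2 * e)) / eps
        for e in basis
    ])

x = jnp.array([0.5, 1.5, -1.0])
numerical = numerical_grad(g, x)
autodiff = jax.grad(g)(x)

print(numerical)
print(autodiff)
print(jnp.max(jnp.abs(numerical - autodiff)))
```

This costs two function evaluations per input coordinate, which is why the example above checks a single random direction instead when the input is large.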
Next steps
The Advanced automatic differentiation tutorial provides more advanced and detailed explanations of how the ideas covered in this document are implemented in the JAX backend. Some features, such as Custom derivative rules for JAX-transformable Python functions, depend on understanding advanced automatic differentiation, so do check out that section in the Advanced automatic differentiation tutorial if you are interested.
