jax.experimental.jet module
Jet is an experimental module for higher-order automatic differentiation that does not rely on repeated first-order automatic differentiation.
How? Through the propagation of truncated Taylor polynomials. Consider a function \(f = g \circ h\), some point \(x\), and some offset \(v\). First-order automatic differentiation (such as jax.jvp()) computes the pair \((f(x), \partial f(x)[v])\) from the pair \((h(x), \partial h(x)[v])\).
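The sketch below illustrates the first-order case by calling jax.jvp() twice. The first call differentiates the composition \(f\) directly. The second pushes the pair \((h(x), \partial h(x)[v])\) through \(g\). The choices \(g = \sin\), \(h(z) = z^3\), \(x = 0.5\) and \(v = 1\) are assumptions made only for illustration. Both routes should agree up to floating-point rounding.

```python
import jax
import jax.numpy as jnp

# Illustrative choices (assumptions): inner h(z) = z**3, outer g = sin.
def h(z):
    return z ** 3

def g(y):
    return jnp.sin(y)

def f(z):
    return g(h(z))

x, v = 0.5, 1.0

# Route 1: first-order AD of the whole composition gives (f(x), df(x)[v]).
f_x, df_v = jax.jvp(f, (x,), (v,))

# Route 2: compute (h(x), dh(x)[v]) first, then push that pair through g.
h_x, dh_v = jax.jvp(h, (x,), (v,))
f_x2, df_v2 = jax.jvp(g, (h_x,), (dh_v,))

print(f_x, f_x2)    # primal values from both routes
print(df_v, df_v2)  # directional derivatives from both routes
```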
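jet() implements the higher-order analogue. Given the tuple
\[
(h_0, \ldots, h_K) := (h(x), \partial h(x)[v], \partial^2 h(x)[v, v], \ldots, \partial^K h(x)[v, \ldots, v]),
\]
which represents a \(K\)-th order Taylor approximation of \(h\) at \(x\), jet() returns a \(K\)-th order Taylor approximation of \(f\) at \(x\),
\[
(f_0, \ldots, f_K) := (f(x), \partial f(x)[v], \partial^2 f(x)[v, v], \ldots, \partial^K f(x)[v, \ldots, v]).
\]
More specifically, jet() computes
\[
f_0, (f_1, \ldots, f_K) = \texttt{jet}(g, h_0, (h_1, \ldots, h_K))
\]
and can thus be used for high-order automatic differentiation of \(f\). Details are explained in these notes.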
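This sketch shows one common way to get plain higher-order derivatives of a scalar function from jet(). It is an assumption of the sketch rather than a recipe stated above. Take \(h(t) = x + t\), so \(h_1 = 1\) and \(h_k = 0\) for \(k \ge 2\). With the default factorial_scaled=True, the returned terms are then \(f'(x), f''(x), \ldots\), and they can be checked against repeated first-order differentiation with jax.grad(). The test function \(\exp(\sin x)\) and the point \(x = 0.7\) are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp
from jax.experimental.jet import jet

def f(x):
    # Arbitrary smooth test function (illustrative assumption).
    return jnp.exp(jnp.sin(x))

x = 0.7
K = 3  # number of higher-order terms requested

# Input polynomial h(t) = x + t: first-order term 1, all higher terms 0.
# With factorial_scaled=True (the default), output term k is then f^(k)(x).
series_in = [1.0] + [0.0] * (K - 1)
f0, f_terms = jet(f, (x,), (series_in,))

# Reference values from repeated first-order AD (nested jax.grad).
d1 = jax.grad(f)
d2 = jax.grad(d1)
d3 = jax.grad(d2)
reference = [d1(x), d2(x), d3(x)]

for k, (a, b) in enumerate(zip(f_terms, reference), start=1):
    print(k, a, b)  # jet term vs. nested-grad derivative of order k
```

A single jet() call yields all \(K\) terms at once. The nested jax.grad() reference instead builds one derivative on top of another.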
Note
Help improve jet() by contributing outstanding primitive rules.
API
- jax.experimental.jet.jet(fun, primals, series, factorial_scaled=True, **_)
Taylor-mode higher-order automatic differentiation.
- Parameters:
fun – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
primals – The primal values at which the Taylor approximation of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fun.
series – Higher-order Taylor series coefficients. Together, primals and series make up a truncated Taylor polynomial. Should be either a tuple or a list of tuples or lists, with one inner sequence per element of primals. The common length of these inner sequences dictates the degree of the truncated Taylor polynomial.
factorial_scaled – If True (the default), each term in both the input and output series is scaled by the factorial of its order. As a result, the n-th order term in the input and output series is the n-th order derivative of the function. If False, the input and output series are the Taylor coefficients without factorial scaling (i.e., the constant coefficients of each term in the Taylor series).
- Returns:
A (primals_out, series_out) pair, where primals_out is fun(*primals), and together, primals_out and series_out are a truncated Taylor polynomial of \(f(h(\cdot))\). The primals_out value has the same Python tree structure as the output of fun. The series_out value has that same structure, with each leaf replaced by the list of its higher-order terms.
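The sketch below checks how the returned pair is laid out when fun returns a tuple. In this case primals_out is a tuple of the two outputs, and series_out is a tuple that holds one list of higher-order terms per output. The function fun and the point \(x = 0.3\) are illustrative assumptions.

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

def fun(x):
    # Hypothetical tuple-valued function: two outputs (two pytree leaves).
    return jnp.sin(x), jnp.cos(x)

x = 0.3
# h(t) = x + t with K = 2, so the terms are first and second derivatives.
primals_out, series_out = jet(fun, (x,), ((1.0, 0.0),))

# primals_out mirrors the output of fun: (sin(x), cos(x)).
s0, c0 = primals_out
# series_out mirrors it too, with each leaf replaced by its list of terms:
# ([sin'(x), sin''(x)], [cos'(x), cos''(x)]).
s_terms, c_terms = series_out
print(s0, c0)
print(s_terms, c_terms)
```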
For example:
>>> import jax
>>> import jax.numpy as np
>>> from jax.experimental.jet import jet
Consider the function \(h(z) = z^3\), \(x = 0.5\), and the first few Taylor coefficients \(h_0 = x^3\), \(h_1 = 3x^2\), and \(h_2 = 6x\). Let \(f(y) = \sin(y)\).
>>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
>>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)
jet() returns the Taylor coefficients of \(f(h(z)) = \sin(z^3)\) according to Faà di Bruno’s formula:
>>> f0, (f1, f2) = jet(f, (h0,), ((h1, h2),))
>>> print(f0, f(h0))
0.12467473 0.12467473
>>> print(f1, df(h0) * h1)
0.74414825 0.74414825
>>> print(f2, ddf(h0) * h1**2 + df(h0) * h2)
2.9064636 2.9064634
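The same check extends to third order. The sketch below reuses \(h(z) = z^3\) at \(x = 0.5\) and adds the third derivative \(h_3 = 6\). It then compares the last term returned by jet() with the third-order case of Faà di Bruno's formula, \((f \circ h)''' = f'''(h)\,h'^3 + 3 f''(h)\,h' h'' + f'(h)\,h'''\). The names dddf and f3_formula are introduced for illustration only.

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

# Same setup as the example above, extended by one order:
# h(z) = z**3 at x = 0.5 has derivatives h1 = 3x**2, h2 = 6x, h3 = 6.
x = 0.5
h0, h1, h2, h3 = x**3, 3 * x**2, 6 * x, 6.0

# Derivatives of the outer function f = sin.
df = jnp.cos
ddf = lambda y: -jnp.sin(y)
dddf = lambda y: -jnp.cos(y)  # hypothetical helper: third derivative of sin

f0, (f1, f2, f3) = jet(jnp.sin, (h0,), ((h1, h2, h3),))

# Third-order Faa di Bruno: f'''(h) h'^3 + 3 f''(h) h' h'' + f'(h) h'''.
f3_formula = dddf(h0) * h1**3 + 3 * ddf(h0) * h1 * h2 + df(h0) * h3
print(f3, f3_formula)
```

The third-order term mixes all three input terms, which is why jet() takes the whole truncated series of \(h\) rather than only its first derivative.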
