
jax.experimental.jet module

Jet is an experimental module for higher-order automatic differentiation that does not rely on repeated first-order automatic differentiation.

How? Through the propagation of truncated Taylor polynomials. Consider a function \(f = g \circ h\), some point \(x\), and some offset \(v\). First-order automatic differentiation (such as jax.jvp()) computes the pair \((f(x), \partial f(x)[v])\) from the pair \((h(x), \partial h(x)[v])\).
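The first-order step can be checked with jax.jvp: pushing \((x, v)\) through \(h\) gives \((h(x), \partial h(x)[v])\), and pushing that pair through \(g\) gives \((f(x), \partial f(x)[v])\), the same pair obtained by differentiating \(f = g \circ h\) directly. This is a minimal sketch; the choices \(h(z) = z^3\), \(g = \sin\), \(x = 0.5\), and \(v = 1\) are illustrative assumptions, not part of the original text.

```python
import jax
import jax.numpy as jnp

# Illustrative choices (assumptions): inner h, outer g, and f = g o h.
def h(z):
    return z ** 3

def g(y):
    return jnp.sin(y)

def f(z):
    return g(h(z))

x, v = 0.5, 1.0

# (x, v) -> (h(x), dh(x)[v])
h_x, dh_v = jax.jvp(h, (x,), (v,))
# (h(x), dh(x)[v]) -> (f(x), df(x)[v]): the chain rule, one jvp at a time.
f_x, df_v = jax.jvp(g, (h_x,), (dh_v,))

# The same pair, computed by differentiating f directly.
f_x_direct, df_v_direct = jax.jvp(f, (x,), (v,))
```

Composing first-order pairs this way only ever carries one derivative along; the higher-order analogue below carries the whole truncated series instead.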

jet() implements the higher-order analogue: given the tuple

\[(h_0, \ldots, h_K) := (h(x), \partial h(x)[v], \partial^2 h(x)[v, v], \ldots, \partial^K h(x)[v, \ldots, v]),\]

which represents a \(K\)-th order Taylor approximation of \(h\) at \(x\), jet() returns a \(K\)-th order Taylor approximation of \(f\) at \(x\),

\[(f_0, \ldots, f_K) := (f(x), \partial f(x)[v], \partial^2 f(x)[v, v], \ldots, \partial^K f(x)[v, \ldots, v]).\]
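To make the definition of \((h_0, \ldots, h_K)\) concrete, each entry can be computed naively by nesting jax.jvp, since \(\partial^k h(x)[v, \ldots, v]\) is the derivative along \(v\) of \(\partial^{k-1} h(\cdot)[v, \ldots, v]\). This repeated first-order construction is exactly what jet avoids, so the sketch is for illustration only; the helper directional_derivatives is hypothetical and not part of jax.experimental.jet.

```python
import jax

def directional_derivatives(h, x, v, K):
    """Hypothetical helper: [h(x), ∂h(x)[v], ..., ∂^K h(x)[v, ..., v]].

    Built by nesting jax.jvp K times (the naive approach jet avoids).
    """
    fns = [h]
    for _ in range(K):
        prev = fns[-1]
        # One more derivative along v: d/dt prev(y + t v) at t = 0.
        fns.append(lambda y, prev=prev: jax.jvp(prev, (y,), (v,))[1])
    return [fn(x) for fn in fns]

# Illustrative inner function h(z) = z**3 at x = 0.5 with offset v = 1.
# Closed forms for comparison: x**3, 3 x**2 v, 6 x v**2, 6 v**3.
hs = directional_derivatives(lambda z: z ** 3, 0.5, 1.0, 3)
```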

More specifically, with \(g\) the outer function in \(f = g \circ h\), jet() computes

\[f_0, (f_1, \ldots, f_K) = \texttt{jet}(g, h_0, (h_1, \ldots, h_K))\]

and can thus be used for higher-order automatic differentiation of \(f\). Details are explained in these notes.
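Since the inner function \(h\) is arbitrary, one can take it to be the identity, \(h(z) = z\), with \(v = 1\). Then \(h_0 = x\), \(h_1 = 1\), and all higher \(h_k\) vanish, so the series that jet returns is simply the sequence of derivatives of the function passed to it. The sketch below checks this against three nested jax.grad calls; the test function \(\exp(\sin z)\) and the point \(x = 0.7\) are arbitrary assumptions.

```python
import jax
import jax.numpy as jnp
from jax.experimental.jet import jet

# Arbitrary test function (assumption).
def f(z):
    return jnp.exp(jnp.sin(z))

x = 0.7
# Identity inner function with v = 1: h_0 = x, h_1 = 1, h_2 = h_3 = 0.
f0, (f1, f2, f3) = jet(f, (x,), ((1.0, 0.0, 0.0),))

# Reference values from repeated first-order differentiation.
d1 = jax.grad(f)
d2 = jax.grad(d1)
d3 = jax.grad(d2)
reference = (f(x), d1(x), d2(x), d3(x))
```

The nested jax.grad reference differentiates a derivative program three times, whereas the single jet call propagates the truncated series through each primitive of f once.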

Note

Help improve jet() by contributing rules for primitives it does not yet support.

API

jax.experimental.jet.jet(fun, primals, series, factorial_scaled=True, **_)

Taylor-mode higher-order automatic differentiation.

Parameters:
  • fun – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.

  • primals – The primal values at which the Taylor approximation of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fun.

  • series – Higher-order Taylor series coefficients. Together, primals and series make up a truncated Taylor polynomial. Should be either a tuple or a list of tuples or lists, with one inner tuple or list per entry of primals; the length of each inner sequence, which must be the same for every argument, dictates the degree of the truncated Taylor polynomial.

  • factorial_scaled – If True (the default), each term in both the input and output series is scaled by the factorial of its order, so that the n-th order term is the n-th order derivative of the function rather than its n-th Taylor coefficient. If False, the input and output series are the unscaled Taylor coefficients (i.e., the constant coefficients of each term in the Taylor series).

Returns:

A (primals_out, series_out) pair, where primals_out is fun(*primals), and together, primals_out and series_out are a truncated Taylor polynomial of \(f(h(\cdot))\). The primals_out value has the Python tree structure of the output of fun rather than that of primals; in the single-output example below, primals_out is the scalar f0 and series_out is the sequence of terms (f1, f2).
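The layout of primals and series for a function with several positional parameters can be sketched with a hypothetical two-argument product, whose result is easy to check by hand against the Leibniz rule \((ab)'' = a''b + 2a'b' + ab''\). Each argument gets one inner tuple of two terms, so the truncated polynomial has degree 2, and the terms follow the default factorial-scaled convention (derivatives rather than Taylor coefficients). The function and numbers are arbitrary assumptions.

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

# Hypothetical two-argument function.
def product(a, b):
    return a * b

# One primal per positional parameter of `product`.
primals = (2.0, 7.0)
# One inner tuple per argument; each has length 2, so the degree is 2.
# Terms are derivatives along the offset: (a_1, a_2) and (b_1, b_2).
series = ((3.0, 5.0), (11.0, 13.0))

p0, (p1, p2) = jet(product, primals, series)

# Product (Leibniz) rule written out by hand for comparison.
a0, b0 = primals
(a1, a2), (b1, b2) = series
expected = (a0 * b0,
            a1 * b0 + a0 * b1,
            a2 * b0 + 2 * a1 * b1 + a0 * b2)
```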

For example:

>>> import jax
>>> import jax.numpy as np
>>> from jax.experimental.jet import jet

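Consider the function \(h(z) = z^3\), \(x = 0.5\), and the first few Taylor coefficients \(h_0 = x^3\), \(h_1 = 3x^2\), and \(h_2 = 6x\); with \(v = 1\) and the default factorial_scaled=True, these are the derivatives of \(h\) at \(x\). Let \(f(y) = \sin(y)\).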

>>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
>>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)

jet() returns the Taylor coefficients of \(f(h(z)) = \sin(z^3)\) according to Faà di Bruno’s formula:

>>> f0, (f1, f2) = jet(f, (h0,), ((h1, h2),))
>>> print(f0, f(h0))
0.12467473 0.12467473
>>> print(f1, df(h0) * h1)
0.74414825 0.74414825
>>> print(f2, ddf(h0) * h1**2 + df(h0) * h2)
2.9064636 2.9064634
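The same check extends to third order. As a sketch under the same assumptions as the example above (\(h(z) = z^3\), \(x = 0.5\), \(v = 1\), so additionally \(h_3 = 6\)), Faà di Bruno’s formula gives the third derivative of \(\sin(h(z))\) as \(-\cos(h_0)\,h_1^3 - 3\sin(h_0)\,h_1 h_2 + \cos(h_0)\,h_3\), which jet reproduces with one extra series term.

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

x = 0.5
# Derivatives of h(z) = z**3 at x with v = 1, now up to third order.
h0, h1, h2, h3 = x**3, 3 * x**2, 6 * x, 6.0

f0, (f1, f2, f3) = jet(jnp.sin, (h0,), ((h1, h2, h3),))

# Faà di Bruno, third order:
# (sin o h)''' = sin'''(h0) h1**3 + 3 sin''(h0) h1 h2 + sin'(h0) h3,
# with sin' = cos, sin'' = -sin, sin''' = -cos.
f3_expected = (-jnp.cos(h0) * h1**3
               - 3 * jnp.sin(h0) * h1 * h2
               + jnp.cos(h0) * h3)
```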