
jax.linearize

jax.linearize(fun: Callable, *primals, has_aux: Literal[False] = False) → tuple[Any, Callable]
jax.linearize(fun: Callable, *primals, has_aux: Literal[True]) → tuple[Any, Callable, Any]

Produces a linear approximation to fun using jvp() and partial evaluation.

Parameters:
  • fun – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.

  • primals – The primal values at which the Jacobian of fun should be evaluated, passed as positional arguments. Each should be an array, scalar, or standard Python container thereof. The number of primal values should equal the number of positional parameters of fun.

  • has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be linearized, and the second is auxiliary data. Default False.
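To illustrate how several primals map onto the positional parameters of fun, here is a minimal sketch. The two-argument function g and its inputs are hypothetical, chosen only so the result can be checked by hand: the returned linear function takes one tangent per primal, and for g(x, w) = tanh(w·x) its output should equal (1 − tanh²(w·x))·(w·dx + x·dw).

```python
import jax
import jax.numpy as jnp

# Hypothetical two-argument function, used only for illustration.
def g(x, w):
    return jnp.tanh(w * x)

x = jnp.array([0.5, -1.0, 2.0], dtype=jnp.float32)
w = jnp.float32(1.5)

# One primal value per positional parameter of g.
y, g_jvp = jax.linearize(g, x, w)

# The linearized function takes one tangent per primal, with matching shapes/dtypes.
dx = jnp.ones_like(x)
dw = jnp.float32(0.1)
out_tangent = g_jvp(dx, dw)

# Hand-derived JVP of tanh(w * x) for comparison.
sech2 = 1.0 - jnp.tanh(w * x) ** 2
expected = sech2 * (w * dx + x * dw)
```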

Returns:

If has_aux is False, returns a pair where the first element is the value of fun(*primals) and the second element is a function that evaluates the (forward-mode) Jacobian-vector product of fun evaluated at primals without re-doing the linearization work. If has_aux is True, returns a (primals_out, lin_fn, aux) tuple where aux is the auxiliary data returned by fun.
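A short sketch of the has_aux=True case follows; the function loss_with_aux and its dictionary of auxiliary data are hypothetical. Only the first element of the returned pair is linearized, and the auxiliary data comes back unchanged as the third element of the result.

```python
import jax
import jax.numpy as jnp

# Hypothetical function returning (output to linearize, auxiliary data).
def loss_with_aux(x):
    y = jnp.sum(x ** 2)           # linearized output
    aux = {"mean": jnp.mean(x)}   # returned as-is, not differentiated
    return y, aux

x = jnp.array([1.0, 2.0, 3.0])
y, f_jvp, aux = jax.linearize(loss_with_aux, x, has_aux=True)

# Directional derivative of sum(x**2) along t is 2 * dot(x, t).
t = jnp.array([1.0, 0.0, 0.0])
dy = f_jvp(t)
```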

In terms of values computed, linearize() behaves much like a curried jvp(), where these two code blocks compute the same values:

```python
y, out_tangent = jax.jvp(f, (x,), (in_tangent,))
```

```python
y, f_jvp = jax.linearize(f, x)
out_tangent = f_jvp(in_tangent)
```
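The equivalence between the curried and one-shot forms can be checked numerically. In this sketch the function f and the tangents t1 and t2 are arbitrary illustrative choices; it also confirms that f_jvp is linear in its tangent argument, which is what "linear approximation" means here.

```python
import jax
import jax.numpy as jnp

# Illustrative function; any differentiable f would do.
def f(x):
    return jnp.sin(x) * x

x = jnp.array([0.0, 1.0, 2.0])
t1 = jnp.array([1.0, 0.5, -1.0])
t2 = jnp.array([0.0, 2.0, 1.0])

# Curried form: linearize once, then apply.
y_lin, f_jvp = jax.linearize(f, x)

# One-shot form.
y_jvp, out_jvp = jax.jvp(f, (x,), (t1,))

# f_jvp is a linear map: f_jvp(a*t1 + b*t2) == a*f_jvp(t1) + b*f_jvp(t2).
combo = f_jvp(2.0 * t1 + 3.0 * t2)
expected = 2.0 * f_jvp(t1) + 3.0 * f_jvp(t2)
```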

However, the difference is that linearize() uses partial evaluation so that the function f is not re-linearized on calls to f_jvp. In general that means the intermediate values needed by f_jvp are stored, so memory usage scales with the size of the computation, much like in reverse mode. (Indeed, linearize() has a similar signature to vjp()!)
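The signature similarity with vjp() can be made concrete: both return the primal output together with a function, where f_jvp computes J·t and f_vjp computes Jᵀ·v (wrapped in a tuple). A minimal sketch, assuming an arbitrary illustrative f from R³ to R², checks the adjoint identity ⟨v, J·t⟩ = ⟨Jᵀ·v, t⟩ between the two.

```python
import jax
import jax.numpy as jnp

# Illustrative function R^3 -> R^2.
def f(x):
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 0.5])

y1, f_jvp = jax.linearize(f, x)  # f_jvp(t) computes J @ t
y2, f_vjp = jax.vjp(f, x)        # f_vjp(v) returns (J.T @ v,)

t = jnp.array([1.0, -1.0, 2.0])  # input-space tangent
v = jnp.array([0.5, 3.0])        # output-space cotangent

# Adjoint identity: <v, J t> == <J^T v, t>
lhs = jnp.dot(v, f_jvp(t))
(vjp_out,) = f_vjp(v)
rhs = jnp.dot(vjp_out, t)
```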

This function is mainly useful if you want to apply f_jvp multiple times, i.e. to evaluate a pushforward for many different input tangent vectors at the same linearization point. Moreover, if all the input tangent vectors are known at once, it can be more efficient to vectorize using vmap(), as in:

```python
from functools import partial
from jax import jvp, vmap

pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
```

By using vmap() and jvp() together like this, we avoid the stored-linearization memory cost that scales with the depth of the computation, which is incurred by both linearize() and vjp().
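The two strategies can be compared side by side. This sketch uses an illustrative elementwise f and takes the rows of an identity matrix as the batch of tangents, so each output row is J·eᵢ, a column of the Jacobian; jax.jacfwd() is brought in only as an independent reference, with the batched result matching its transpose.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Illustrative function.
def f(x):
    return jnp.sin(x) ** 2 + x

x = jnp.array([0.3, 1.2, -0.7])
in_tangents = jnp.eye(3)  # batch of 3 tangent vectors, one per row

# Option 1: linearize once, then apply the stored linear map per tangent.
y_lin, f_jvp = jax.linearize(f, x)
out_loop = jnp.stack([f_jvp(t) for t in in_tangents])

# Option 2: batch the JVP with vmap; no linearization is stored.
pushfwd = partial(jax.jvp, f, (x,))
y_vmap, out_vmap = jax.vmap(pushfwd, out_axes=(None, 0))((in_tangents,))

# Row i of out_vmap is J @ e_i, i.e. column i of the Jacobian.
jac = jax.jacfwd(f)(x)
```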

Here’s a more complete example of using linearize():

```python
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(Array(3.2681944, dtype=float32, weak_type=True), Array(-5.007528, dtype=float32, weak_type=True))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
>>> print(f_jvp(3.))
-5.007528
>>> print(f_jvp(4.))
-6.676704
```
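For a scalar function, the JVP is simply f′(x) times the tangent. Differentiating by hand gives f′(x) = 3 cos(x) − ½ sin(x/2), so f′(2) ≈ −1.669176, and 3·f′(2) ≈ −5.007528 and 4·f′(2) ≈ −6.676704, matching the outputs above. The sketch below checks this numerically; df is a hypothetical helper holding the hand-derived derivative.

```python
import jax
import jax.numpy as jnp

def f(x):
    return 3. * jnp.sin(x) + jnp.cos(x / 2.)

# Hypothetical helper: hand-derived derivative f'(x) = 3 cos(x) - (1/2) sin(x / 2).
def df(x):
    return 3. * jnp.cos(x) - 0.5 * jnp.sin(x / 2.)

y, f_jvp = jax.linearize(f, 2.)

# Reuse the same linearization for several tangents; each result is f'(2) * t.
tangents = [3., 4.]
jvp_values = [f_jvp(t) for t in tangents]
analytic_values = [df(2.) * t for t in tangents]
```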