Complex numbers and differentiation


JAX is great at complex numbers and differentiation. To support both holomorphic and non-holomorphic differentiation, it helps to think in terms of JVPs and VJPs.

Consider a complex-to-complex function \(f: \mathbb{C} \to \mathbb{C}\) and identify it with a corresponding function \(g: \mathbb{R}^2 \to \mathbb{R}^2\),

import jax.numpy as jnp

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return u(x, y) + v(x, y) * 1j

def g(x, y):
  return (u(x, y), v(x, y))

That is, we’ve decomposed \(f(z) = u(x, y) + v(x, y) i\) where \(z = x + y i\), and identified \(\mathbb{C}\) with \(\mathbb{R}^2\) to get \(g\).
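As a concrete instance of this decomposition, the sketch below assumes \(f(z) = z^2\) (an illustrative choice, not one made in the text), for which \(u(x, y) = x^2 - y^2\) and \(v(x, y) = 2xy\). It only shows that the complex view and the real-pair view carry the same information.

```python
import jax.numpy as jnp

# Illustrative choice (an assumption for this example): f(z) = z**2,
# so u(x, y) = x**2 - y**2 and v(x, y) = 2*x*y.
def u(x, y):
  return x**2 - y**2

def v(x, y):
  return 2 * x * y

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return u(x, y) + v(x, y) * 1j

def g(x, y):
  return (u(x, y), v(x, y))

z = 3. + 4j
w = f(z)                      # complex view of the output
re_part, im_part = g(3., 4.)  # real-pair view of the same output
print(w, re_part, im_part)
```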

Since \(g\) only involves real inputs and outputs, we already know how to write a Jacobian-vector product for it, say given a tangent vector \((c, d) \in \mathbb{R}^2\), namely:

\(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix}\begin{bmatrix} c \\ d \end{bmatrix}\).

To get a JVP for the original function \(f\) applied to a tangent vector \(c + di \in \mathbb{C}\), we just use the same definition and identify the result as another complex number,

\(\partial f(x + y i)(c + d i) =\begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix}\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix}\begin{bmatrix} c \\ d \end{bmatrix}\).

That’s our definition of the JVP of a \(\mathbb{C} \to \mathbb{C}\) function! Notice it doesn’t matter whether or not \(f\) is holomorphic: the JVP is unambiguous.

Here’s a check:

from jax import random, grad, jvp

def check(seed):
  key = random.key(seed)

  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # tangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_dot = c + d * 1j

  # check jvp
  _, ans = jvp(fun, (z,), (z_dot,))
  expected = (grad(u, 0)(x, y) * c +
              grad(u, 1)(x, y) * d +
              grad(v, 0)(x, y) * c * 1j +
              grad(v, 1)(x, y) * d * 1j)
  print(jnp.allclose(ans, expected))

check(0)
check(1)
check(2)

True
True
True

What about VJPs? We do something pretty similar: for a cotangent vector \(c + di \in \mathbb{C}\) we define the VJP of \(f\) as

\((c + di)^* \; \partial f(x + y i) =\begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix}\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix}\begin{bmatrix} 1 \\ -i \end{bmatrix}\).

What’s with the negatives? They’re just to take care of complex conjugation, and the fact that we’re working with covectors.

Here’s a check of the VJP rules:

from jax import vjp

def check(seed):
  key = random.key(seed)

  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # cotangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_bar = jnp.array(c + d * 1j)  # for dtype control

  # check vjp
  _, fun_vjp = vjp(fun, z)
  ans, = fun_vjp(z_bar)
  expected = (grad(u, 0)(x, y) * c +
              grad(v, 0)(x, y) * (-d) +
              grad(u, 1)(x, y) * c * (-1j) +
              grad(v, 1)(x, y) * (-d) * (-1j))
  assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)

check(0)
check(1)
check(2)
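One consequence of the two matrix definitions is worth spelling out. For a tangent \(t\) and a cotangent \(s\), both \(\mathrm{Re}(s \cdot \mathrm{JVP}(t))\) and \(\mathrm{Re}(\mathrm{VJP}(s) \cdot t)\) reduce to the same real number, the row vector \([c_s, -d_s]\) times the 2x2 Jacobian times the column \([c_t, d_t]^T\). The sketch below checks this numerically. The function `f` and the values of `z`, `t`, and `s` are arbitrary choices made for illustration.

```python
import jax
import jax.numpy as jnp

# Illustrative non-holomorphic function (an assumption for this example).
def f(z):
  return z * jnp.conj(z) + jnp.sin(z)

z = jnp.array(0.3 + 0.4j)
t = jnp.array(0.5 - 1.5j)    # tangent
s = jnp.array(2.0 + 0.25j)   # cotangent

_, jvp_out = jax.jvp(f, (z,), (t,))
_, f_vjp = jax.vjp(f, z)
(vjp_out,) = f_vjp(s)

# Both pairings should agree up to floating-point error.
lhs = jnp.real(s * jvp_out)
rhs = jnp.real(vjp_out * t)
print(lhs, rhs)
```

Note that neither product involves a conjugate; the conjugation is already built into the VJP definition through the negative signs.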

What about convenience wrappers like jax.grad(), jax.jacfwd(), and jax.jacrev()?

For \(\mathbb{R} \to \mathbb{R}\) functions, recall we defined grad(f)(x) as being vjp(f, x)[1](1.0), which works because applying a VJP to a 1.0 value reveals the gradient (i.e. Jacobian, or derivative). We can do the same thing for \(\mathbb{C} \to \mathbb{R}\) functions: we can still use 1.0 as the cotangent vector, and we just get out a complex number result summarizing the full Jacobian:

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return x**2 + y**2

z = 3. + 4j
grad(f)(z)

Array(6.-8.j, dtype=complex64)
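To see what that complex number summarizes, apply the VJP definition above with cotangent 1.0 to a real-valued function (so \(v = 0\)): the result is \(\partial_0 u - i \, \partial_1 u\). The sketch below compares this against ordinary real partial derivatives. The helper `f_real` is introduced here for illustration only.

```python
import jax
import jax.numpy as jnp

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return x**2 + y**2

# The same function viewed as a real function of (x, y); a helper
# introduced for this example only.
def f_real(x, y):
  return f(x + y * 1j)

x, y = 3., 4.
df_dx = jax.grad(f_real, 0)(x, y)   # partial derivative in x
df_dy = jax.grad(f_real, 1)(x, y)   # partial derivative in y
g = jax.grad(f)(x + y * 1j)

# grad(f)(z) should equal df/dx - i * df/dy.
print(g, df_dx - 1j * df_dy)
```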

For general \(\mathbb{C} \to \mathbb{C}\) functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can’t hope to represent all of them within a complex number. But we can for holomorphic functions! A holomorphic function is precisely a \(\mathbb{C} \to \mathbb{C}\) function with the special property that its derivative can be represented as a single complex number. (The Cauchy-Riemann equations ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to vjp with a covector of 1.0.
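The scale-and-rotate structure can be checked numerically. The sketch below uses \(\sin\) as an example of a holomorphic function (the evaluation point is arbitrary) and compares its real 2x2 Jacobian with the matrix built from the known derivative \(\cos(z) = a + bi\), namely \(\begin{bmatrix} a & -b \\ b & a \end{bmatrix}\). It then recovers the same complex number from one VJP call with a covector of 1.0.

```python
import jax
import jax.numpy as jnp

# Real view of sin: (x, y) -> (Re sin(z), Im sin(z)), with z = x + y i.
def g(xy):
  w = jnp.sin(xy[0] + 1j * xy[1])
  return jnp.stack([jnp.real(w), jnp.imag(w)])

xy = jnp.array([0.5, 0.25])
J = jax.jacfwd(g)(xy)   # 2x2 real Jacobian

# Known derivative of sin, used only as a reference value.
z = jnp.array(0.5 + 0.25j)
dfdz = jnp.cos(z)
scale_rotate = jnp.array([[dfdz.real, -dfdz.imag],
                          [dfdz.imag,  dfdz.real]])

# A single VJP with covector 1.0 recovers the same complex number.
_, sin_vjp = jax.vjp(jnp.sin, z)
(from_vjp,) = sin_vjp(jnp.ones_like(z))
print(J, dfdz, from_vjp)
```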

Because this only works for holomorphic functions, to use this trick we need to promise JAX that our function is holomorphic; otherwise, JAX will raise an error when jax.grad() is used for a complex-output function:

def f(z):
  return jnp.sin(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)

Array(-27.034946-3.8511534j, dtype=complex64, weak_type=True)
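The sketch below shows both sides of that promise. It assumes the error raised without `holomorphic=True` is a `TypeError`; treat the exact exception type and message as an assumption that may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

def f(z):
  return jnp.sin(z)

z = 3. + 4j

# Without the promise, a complex-valued output is rejected.
# (Assumption: the exception type is TypeError.)
try:
  jax.grad(f)(z)
  raised = False
except TypeError as err:
  raised = True
  print("grad refused:", err)

# With the promise, the result matches the known derivative cos(z).
dfdz = jax.grad(f, holomorphic=True)(z)
print(dfdz, jnp.cos(z))
```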

All the holomorphic=True promise does is disable the error when the output is complex-valued. We can still write holomorphic=True when the function isn’t holomorphic, but the answer we get out won’t represent the full Jacobian. Instead, it’ll be the Jacobian of the function where we just discard the imaginary part of the output:

def f(z):
  return jnp.conjugate(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)  # f is not actually holomorphic!

Array(1.-0.j, dtype=complex64, weak_type=True)
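The "discard the imaginary part" reading can be checked directly. The sketch below compares the promised gradient with the ordinary gradient of the real part of the output. The function (the square of the conjugate) is an arbitrary non-holomorphic example chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Illustrative non-holomorphic function (an assumption for this example).
def f(z):
  return jnp.conjugate(z) ** 2

z = 3. + 4j

# Gradient under a (false) holomorphic promise ...
promised = jax.grad(f, holomorphic=True)(z)
# ... versus the plain gradient of the real part of the output.
real_part_only = jax.grad(lambda z: jnp.real(f(z)))(z)
print(promised, real_part_only)
```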

There are some useful upshots for how jax.grad() works here:

  1. We can use jax.grad() on holomorphic \(\mathbb{C} \to \mathbb{C}\) functions.

  2. We can use jax.grad() to optimize \(f : \mathbb{C} \to \mathbb{R}\) functions, like real-valued loss functions of complex parameters x, by taking steps in the direction of the conjugate of grad(f)(x).

  3. If we have an \(\mathbb{R} \to \mathbb{R}\) function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions) then jax.grad() still works and we get the same result that an implementation using only real values would have given.
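As a minimal sketch of the second point, the loop below runs plain gradient descent on a real-valued loss of one complex parameter. The loss, the `target` value, the step size, and the iteration count are all hypothetical choices made for illustration. Each update moves against the conjugate of the gradient, since the goal here is to decrease the loss.

```python
import jax
import jax.numpy as jnp

target = 1.0 - 2.0j   # hypothetical optimum, chosen for illustration

def loss(z):
  d = z - target
  return jnp.real(d) ** 2 + jnp.imag(d) ** 2   # |z - target|^2, real-valued

grad_loss = jax.jit(jax.grad(loss))

z = jnp.array(0.0 + 0.0j)
step_size = 0.1
for _ in range(100):
  # Descent step: move against the conjugate of the gradient.
  z = z - step_size * jnp.conj(grad_loss(z))

print(z, loss(z))
```

Dropping the conjugate would flip the sign of the imaginary component of each step, so the imaginary part of `z` would move away from the target instead of toward it.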

In any case, JVPs and VJPs are always unambiguous. And if we wanted to compute the full Jacobian matrix of a non-holomorphic\(\mathbb{C} \to \mathbb{C}\) function, we can do it with JVPs or VJPs!
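One way to assemble that full Jacobian, sketched under the conventions above: push the tangents \(1\) and \(i\) through `jax.jvp` to get the two columns of the real 2x2 matrix. The helper name `real_jacobian` and the test function are hypothetical, introduced only for this example. The result is cross-checked against `jax.jacfwd` applied to the real-pair view of the same function.

```python
import jax
import jax.numpy as jnp

# Illustrative non-holomorphic function (an assumption for this example).
def f(z):
  return z * jnp.conj(z) + jnp.conj(z) ** 2

# Hypothetical helper: real 2x2 Jacobian of a C -> C function from two JVPs.
def real_jacobian(fun, z):
  _, col_x = jax.jvp(fun, (z,), (jnp.ones_like(z),))        # tangent 1
  _, col_y = jax.jvp(fun, (z,), (1j * jnp.ones_like(z),))   # tangent i
  return jnp.array([[col_x.real, col_y.real],
                    [col_x.imag, col_y.imag]])

z = jnp.array(3.0 + 4.0j)
J = real_jacobian(f, z)

# Cross-check against the real-pair view g: R^2 -> R^2.
def g(xy):
  w = f(xy[0] + 1j * xy[1])
  return jnp.stack([jnp.real(w), jnp.imag(w)])

J_ref = jax.jacfwd(g)(jnp.array([3.0, 4.0]))
print(J)
print(J_ref)
```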

You should expect complex numbers to work everywhere in JAX. Here’s differentiating through a Cholesky decomposition of a complex matrix:

A = jnp.array([[5.,    2.+3j,    5j],
               [2.-3j,   7.,  1.+7j],
               [-5j,  1.-7j,    12.]])

def f(X):
  L = jnp.linalg.cholesky(X)
  return jnp.sum((L - jnp.sin(L))**2)

grad(f, holomorphic=True)(A)

Array([[-0.7534186  +0.j       , -3.0509028 -10.940544j ,
         5.9896846  +3.5423026j],
       [-3.0509028 +10.940544j , -8.904491   +0.j       ,
        -5.1351523  -6.559373j ],
       [ 5.9896846  -3.5423026j, -5.1351523  +6.559373j ,
         0.01320427 +0.j       ]], dtype=complex64)
