jax.numpy.fft.fft2
jax.numpy.fft.fft2(a, s=None, axes=(-2, -1), norm=None)
Compute a two-dimensional discrete Fourier transform along the given axes.

JAX implementation of numpy.fft.fft2().

- Parameters:
  - a (ArrayLike) – input array. Must have a.ndim >= 2.
  - s (Shape | None) – optional length-2 sequence of integers. Specifies the size of the output along each specified axis. If not specified, it defaults to the size of a along the specified axes. Along each axis, the input is zero-padded if s is larger than the input size and truncated if it is smaller.
  - axes (Sequence[int]) – optional length-2 sequence of integers, default=(-2, -1). Specifies the axes along which the transform is computed.
  - norm (str | None) – string, default="backward". The normalization mode. "backward", "ortho" and "forward" are supported: "backward" leaves the forward transform unscaled, "ortho" scales it by 1/sqrt(n), and "forward" scales it by 1/n, where n is the product of the transform lengths along axes.
- Returns:
  An array containing the two-dimensional discrete Fourier transform of a along the given axes.
- Return type:
  Array
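The three normalization modes differ only by a constant factor applied to the forward transform. The sketch below checks that relationship numerically on a small illustrative input; the array x and the names backward, ortho and forward are chosen for this example only, and it assumes the NumPy-compatible conventions ("ortho" scales by 1/sqrt(n), "forward" by 1/n, with n the product of the transform lengths).

```python
import jax.numpy as jnp

# Illustrative 2-D input; the transform length is n = 2 * 3 = 6.
x = jnp.arange(6.0).reshape(2, 3)
n = x.shape[-2] * x.shape[-1]

backward = jnp.fft.fft2(x)                 # default: unscaled forward transform
ortho = jnp.fft.fft2(x, norm="ortho")      # assumed scaling: 1/sqrt(n)
forward = jnp.fft.fft2(x, norm="forward")  # assumed scaling: 1/n

# Each mode should differ from "backward" only by a constant factor.
print(jnp.allclose(ortho, backward / n ** 0.5))  # True
print(jnp.allclose(forward, backward / n))       # True
```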
See also
- jax.numpy.fft.fft(): Computes a one-dimensional discrete Fourier transform.
- jax.numpy.fft.fftn(): Computes a multidimensional discrete Fourier transform.
- jax.numpy.fft.ifft2(): Computes a two-dimensional inverse discrete Fourier transform.
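The related functions listed above can be used to cross-check fft2: a two-dimensional transform over the last two axes should agree with fftn restricted to those axes, and with two successive one-dimensional fft calls, one per axis. A minimal sketch on an illustrative 3-D array; the atol is loosened to absorb float32 rounding, since the separate code paths may round differently.

```python
import jax.numpy as jnp

x = jnp.arange(24.0).reshape(2, 3, 4)

# fft2 is fftn restricted to two axes (here the default last two).
via_fftn = jnp.fft.fftn(x, axes=(-2, -1))

# Equivalently, two successive one-dimensional transforms, one per axis.
via_fft = jnp.fft.fft(jnp.fft.fft(x, axis=-1), axis=-2)

result = jnp.fft.fft2(x)
print(jnp.allclose(result, via_fftn, atol=1e-4))  # True
print(jnp.allclose(result, via_fft, atol=1e-4))   # True
```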
Examples
jnp.fft.fft2 computes the transform along the last two axes by default.

>>> import jax.numpy as jnp
>>> x = jnp.array([[[1, 3],
...                 [2, 4]],
...                [[5, 7],
...                 [6, 8]]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.fft2(x)
Array([[[10.+0.j, -4.+0.j],
        [-2.+0.j,  0.+0.j]],
       [[26.+0.j, -4.+0.j],
        [-2.+0.j,  0.+0.j]]], dtype=complex64)
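To make the definition concrete, the following sketch computes the same transform directly from the DFT sum X[k, l] = sum over m, n of x[m, n] * exp(-2j*pi*(k*m/M + l*n/N)), written as multiplication by DFT matrices on both sides. The helpers dft_matrix and naive_fft2 are hypothetical names for illustration; this approach costs O(M*N*(M + N)) operations and is only meant for checking results, not as a replacement for jnp.fft.fft2.

```python
import jax.numpy as jnp

def dft_matrix(n):
    # Hypothetical helper: W[k, m] = exp(-2j * pi * k * m / n),
    # the unnormalized forward DFT matrix of size n x n.
    k = jnp.arange(n)
    return jnp.exp(-2j * jnp.pi * jnp.outer(k, k) / n)

def naive_fft2(x):
    # Hypothetical helper: left-multiplying transforms axis -2; because W is
    # symmetric, right-multiplying by W_N transforms axis -1. Leading axes
    # are treated as batch dimensions by matmul broadcasting.
    m, n = x.shape[-2], x.shape[-1]
    return dft_matrix(m) @ x @ dft_matrix(n)

x = jnp.array([[[1, 3], [2, 4]],
               [[5, 7], [6, 8]]])
print(jnp.allclose(naive_fft2(x), jnp.fft.fft2(x), atol=1e-4))  # True
```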
When s=[2, 3], the transform along axes (-2, -1) has shape (2, 3), and the sizes along all other axes are the same as in the input.

>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.fft2(x, s=[2, 3])
Array([[[10.  +0.j  , -0.5 -6.06j, -0.5 +6.06j],
        [-2.  +0.j  , -0.5 +0.87j, -0.5 -0.87j]],
       [[26.  +0.j  ,  3.5-12.99j,  3.5+12.99j],
        [-2.  +0.j  , -0.5 +0.87j, -0.5 -0.87j]]], dtype=complex64)
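The s argument behaves like explicit zero-padding (when s is larger than the input) or truncation (when it is smaller) along the transformed axes, following the NumPy-compatible semantics. The sketch below checks both cases against jnp.pad and slicing; the names padded and cropped are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([[[1, 3], [2, 4]],
               [[5, 7], [6, 8]]])

# s=[2, 3]: the last axis grows from 2 to 3, so one zero is appended along it.
padded = jnp.pad(x, ((0, 0), (0, 0), (0, 1)))
print(jnp.allclose(jnp.fft.fft2(x, s=[2, 3]), jnp.fft.fft2(padded), atol=1e-5))  # True

# s=[1, 2]: the second-to-last axis shrinks from 2 to 1, so the input is truncated.
cropped = x[:, :1, :]
print(jnp.allclose(jnp.fft.fft2(x, s=[1, 2]), jnp.fft.fft2(cropped), atol=1e-5))  # True
```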
When s=[2, 3] and axes=(0, 1), the transform along axes (0, 1) has shape (2, 3), and the sizes along all other axes are the same as in the input.

>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.fft2(x, s=[2, 3], axes=(0, 1))
Array([[[14.  +0.j  , 22.  +0.j  ],
        [ 2.  -6.93j,  4. -10.39j],
        [ 2.  +6.93j,  4. +10.39j]],
       [[-8.  +0.j  , -8.  +0.j  ],
        [-2.  +3.46j, -2.  +3.46j],
        [-2.  -3.46j, -2.  -3.46j]]], dtype=complex64)
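Transforming along non-default axes is equivalent to moving those axes to the end, applying the default transform, and moving the remaining axis back into place. A sketch of that equivalence for the axes=(0, 1) example, using jnp.moveaxis; the intermediate names moved, y and direct are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([[[1, 3], [2, 4]],
               [[5, 7], [6, 8]]])

# Move axis 2 to the front so that the original axes (0, 1) become the last two.
moved = jnp.moveaxis(x, 2, 0)  # axis order: (axis2, axis0, axis1)

# Default fft2 on the last two axes, then move the untouched axis back to position 2.
y = jnp.moveaxis(jnp.fft.fft2(moved, s=[2, 3]), 0, 2)

direct = jnp.fft.fft2(x, s=[2, 3], axes=(0, 1))
print(direct.shape)                        # (2, 3, 2)
print(jnp.allclose(direct, y, atol=1e-5))  # True
```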
jnp.fft.ifft2 can be used to reconstruct x from the result of jnp.fft.fft2.

>>> x_fft2 = jnp.fft.fft2(x)
>>> jnp.allclose(x, jnp.fft.ifft2(x_fft2))
Array(True, dtype=bool)
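Two further properties follow from the normalization modes: with norm="ortho" the transform is unitary, so it preserves the sum of squared magnitudes over the transformed axes (Parseval's theorem), and the ifft2 round trip holds for any mode as long as both calls use the same norm. A short sketch checking both on the example input; the names X, mode and restored are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([[[1, 3], [2, 4]],
               [[5, 7], [6, 8]]], dtype=jnp.float32)

# norm="ortho" makes the transform unitary, so energy is preserved.
X = jnp.fft.fft2(x, norm="ortho")
print(jnp.allclose(jnp.sum(jnp.abs(x) ** 2), jnp.sum(jnp.abs(X) ** 2)))  # True

# The round trip holds for every mode when fft2 and ifft2 use the same norm.
for mode in ("backward", "ortho", "forward"):
    restored = jnp.fft.ifft2(jnp.fft.fft2(x, norm=mode), norm=mode)
    print(mode, jnp.allclose(x, restored, atol=1e-5))  # True for each mode
```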
