jax.numpy.fft.fftn
jax.numpy.fft.fftn(a, s=None, axes=None, norm=None)
Compute a multidimensional discrete Fourier transform along given axes.
JAX implementation of numpy.fft.fftn().
- Parameters:
  a (ArrayLike) – input array
  s (Shape | None) – sequence of integers. Specifies the shape of the result. If not specified, it will default to the shape of a along the specified axes.
  axes (Sequence[int] | None) – sequence of integers, default=None. Specifies the axes along which the transform is computed.
  norm (str | None) – string. The normalization mode. “backward”, “ortho” and “forward” are supported.
- Returns:
  An array containing the multidimensional discrete Fourier transform of a.
- Return type:
See also
- jax.numpy.fft.fft(): Computes a one-dimensional discrete Fourier transform.
- jax.numpy.fft.ifft(): Computes a one-dimensional inverse discrete Fourier transform.
- jax.numpy.fft.ifftn(): Computes a multidimensional inverse discrete Fourier transform.
Examples
jnp.fft.fftn computes the transform along all the axes by default when the axes argument is None.

>>> x = jnp.array([[1, 2, 5, 6],
...                [4, 1, 3, 7],
...                [5, 9, 2, 1]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.fftn(x)
Array([[ 46.  +0.j  ,   0.  +2.j  ,  -6.  +0.j  ,   0.  -2.j  ],
       [ -2.  +1.73j,   6.12+6.73j,   0.  -1.73j, -18.12-3.27j],
       [ -2.  -1.73j, -18.12+3.27j,   0.  +1.73j,   6.12-6.73j]], dtype=complex64)
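The default behaviour can be cross-checked against the one-dimensional transform listed under See also. The following is a minimal sketch, not part of the original examples: it relies on the standard separability property of the multidimensional DFT and applies jnp.fft.fft once per axis. The variable names (full, step, same) and the tolerance are illustrative choices.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 5, 6],
               [4, 1, 3, 7],
               [5, 9, 2, 1]])

# Transform over every axis, which is what axes=None selects.
full = jnp.fft.fftn(x)

# Build the same result from 1-D transforms, one axis at a time.
step = jnp.fft.fft(x, axis=0)
step = jnp.fft.fft(step, axis=1)

# Results are complex64, so compare with a loose absolute tolerance
# (an illustrative value, chosen to absorb float32 round-off).
same = bool(jnp.allclose(full, step, atol=1e-4))
print(same)
```

The order in which the axes are processed should not matter here, apart from floating-point round-off.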
When s=[2], the dimension of the transform along axis -1 will be 2 and the dimension along the other axes will be the same as that of the input.

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.numpy.fft.fftn(x, s=[2]))
[[ 3.+0.j -1.+0.j]
 [ 5.+0.j  3.+0.j]
 [14.+0.j -4.+0.j]]
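One way to read the s=[2] case: with axes left as None, only the last axis is transformed, and the input is cropped to its first two entries along that axis. The sketch below checks this reading against an explicit slice followed by a 1-D transform. The names cropped, manual and matches are illustrative and do not appear in the original examples.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 5, 6],
               [4, 1, 3, 7],
               [5, 9, 2, 1]])

# s has one entry, so only the last axis is transformed, at length 2.
cropped = jnp.fft.fftn(x, s=[2])

# Manual equivalent: keep the first two columns, then transform axis -1.
manual = jnp.fft.fft(x[:, :2], axis=-1)

print(cropped.shape)
matches = bool(jnp.allclose(cropped, manual, atol=1e-4))
print(matches)
```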
When s=[2] and axes=[0], the dimension of the transform along axis 0 will be 2 and the dimension along the other axes will be the same as that of the input.

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.numpy.fft.fftn(x, s=[2], axes=[0]))
[[ 5.+0.j  3.+0.j  8.+0.j 13.+0.j]
 [-3.+0.j  1.+0.j  2.+0.j -1.+0.j]]
When s=[2, 3], the shape of the transform will be (2, 3).

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.numpy.fft.fftn(x, s=[2, 3]))
[[16. +0.j   -0.5+4.33j -0.5-4.33j]
 [ 0. +0.j   -4.5+0.87j -4.5-0.87j]]
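The examples above all pass an s smaller than the input. The sketch below covers the opposite case, under the assumption that jnp.fft.fftn follows the numpy.fft.fftn convention of zero-padding the input when an entry of s exceeds the corresponding input dimension. The target shape (4, 6) and the names padded, x_zero_padded, manual and matches are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 5, 6],
               [4, 1, 3, 7],
               [5, 9, 2, 1]])

# Request a (4, 6) transform from a (3, 4) input.
padded = jnp.fft.fftn(x, s=[4, 6])

# Manual equivalent under the zero-padding assumption:
# append one zero row and two zero columns, then transform.
x_zero_padded = jnp.pad(x, ((0, 1), (0, 2)))
manual = jnp.fft.fftn(x_zero_padded)

print(padded.shape)
matches = bool(jnp.allclose(padded, manual, atol=1e-4))
print(matches)
```

Zero-padding leaves the sum of the input unchanged, so the zero-frequency term padded[0, 0] should still equal the sum of x.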
jnp.fft.ifftn can be used to reconstruct x from the result of jnp.fft.fftn.

>>> x_fftn = jnp.fft.fftn(x)
>>> jnp.allclose(x, jnp.fft.ifftn(x_fftn))
Array(True, dtype=bool)
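The norm parameter is not exercised by the examples above. The following sketch compares the three supported modes, assuming the scaling convention documented for numpy.fft: “backward” leaves the forward transform unscaled, “ortho” divides it by sqrt(n), and “forward” divides it by n, where n is the number of transformed elements. All variable names are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 5, 6],
               [4, 1, 3, 7],
               [5, 9, 2, 1]])

n = x.size  # number of elements transformed when axes=None

backward = jnp.fft.fftn(x, norm="backward")
ortho = jnp.fft.fftn(x, norm="ortho")
forward = jnp.fft.fftn(x, norm="forward")

# Under the assumed convention, the modes differ only by a constant factor.
ortho_ok = bool(jnp.allclose(ortho, backward / n ** 0.5, atol=1e-4))
forward_ok = bool(jnp.allclose(forward, backward / n, atol=1e-4))

# The same mode must be passed to the inverse to recover x.
x_back = jnp.fft.ifftn(ortho, norm="ortho")
roundtrip_ok = bool(jnp.allclose(x, x_back, atol=1e-4))

print(ortho_ok, forward_ok, roundtrip_ok)
```

Passing mismatched modes to fftn and ifftn would leave a constant scale factor in the reconstructed array.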
