jax.numpy.fft.rfftn
jax.numpy.fft.rfftn(a, s=None, axes=None, norm=None)
Compute a multidimensional discrete Fourier transform of a real-valued array.
JAX implementation of numpy.fft.rfftn().

Parameters:
- a (ArrayLike) – real-valued input array.
- s (Shape | None) – optional sequence of integers. Controls the effective size of the input along each specified axis: the input is zero-padded or truncated to these sizes before the transform. If not specified, it defaults to the size of the input along axes.
- axes (Sequence[int] | None) – optional sequence of integers, default=None. Specifies the axes along which the transform is computed. If not specified, the transform is computed along the last len(s) axes. If neither axes nor s is specified, the transform is computed along all axes.
- norm (str | None) – string, default="backward". The normalization mode. "backward", "ortho" and "forward" are supported.
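The three norm modes differ only by a constant scale factor. The sketch below assumes the NumPy convention, where that factor depends on n, the product of the transformed sizes: "backward" applies no scaling to the forward transform, "ortho" scales by 1/sqrt(n), and "forward" scales by 1/n. The input array is an arbitrary illustration.

```python
import jax.numpy as jnp
import numpy as np

# Arbitrary real-valued input; all three axes are transformed by default.
x = jnp.arange(24.0).reshape(2, 3, 4)
n = x.size  # product of the transformed axis lengths (2 * 3 * 4 = 24)

backward = jnp.fft.rfftn(x)                 # default: unscaled forward transform
ortho = jnp.fft.rfftn(x, norm="ortho")      # expected: scaled by 1 / sqrt(n)
forward = jnp.fft.rfftn(x, norm="forward")  # expected: scaled by 1 / n

print(np.allclose(ortho, backward / np.sqrt(n), atol=1e-4))
print(np.allclose(forward, backward / n, atol=1e-4))
```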
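Returns:
An array containing the multidimensional discrete Fourier transform of a. Along each axis in axes except axes[-1], the output has the size given in s (or the input size if s is not specified). Along axis axes[-1], the output size is s[-1]//2 + 1, because only the non-negative frequency terms of the real-input transform are kept. Axes not listed in axes keep their input size.

Return type:
Array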
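The shape rule above can be checked directly. This is a minimal sketch with an arbitrary random input and arbitrary s and axes values chosen for illustration; it also compares the result against numpy.fft.rfftn, which this function is documented to implement, using a float32 tolerance.

```python
import jax.numpy as jnp
import numpy as np

# Arbitrary float32 input of shape (4, 6, 5).
x = np.random.default_rng(0).standard_normal((4, 6, 5)).astype(np.float32)

# Transform axes 0 and 2. Axis 0 is padded to 8; axis 2 (axes[-1]) is
# padded to 7 and then keeps only 7 // 2 + 1 = 4 frequency terms.
# Axis 1 is not transformed and keeps its size of 6.
out = jnp.fft.rfftn(x, s=[8, 7], axes=[0, 2])
print(out.shape)  # (8, 6, 4)

# Compare against NumPy's reference implementation.
ref = np.fft.rfftn(x, s=[8, 7], axes=[0, 2])
print(np.allclose(out, ref, rtol=1e-4, atol=1e-3))
```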
See also
- jax.numpy.fft.rfft(): Computes a one-dimensional discrete Fourier transform of a real-valued array.
- jax.numpy.fft.rfft2(): Computes a two-dimensional discrete Fourier transform of a real-valued array.
- jax.numpy.fft.irfftn(): Computes a real-valued multidimensional inverse discrete Fourier transform.
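Because the last transformed axis shrinks to s[-1]//2 + 1, its original length cannot always be recovered from the spectrum alone. The sketch below, using the same input as the first example, shows a round trip through jax.numpy.fft.irfftn that passes the original shape explicitly as s; without it, irfftn would assume an even length for the last axis.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(1.0, 13.0).reshape(2, 2, 3)  # same values as the first example
spec = jnp.fft.rfftn(x)  # last axis: 3 // 2 + 1 = 2 entries

# Pass the original shape so irfftn restores a last axis of length 3
# rather than the default 2 * (2 - 1) = 2.
back = jnp.fft.irfftn(spec, s=x.shape)
print(np.allclose(back, x, atol=1e-4))
```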
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[[1, 3, 5],
...                 [2, 4, 6]],
...                [[7, 9, 11],
...                 [8, 10, 12]]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.rfftn(x)
Array([[[ 78.+0.j  , -12.+6.93j],
        [ -6.+0.j  ,   0.+0.j  ]],
       [[-36.+0.j  ,   0.+0.j  ],
        [  0.+0.j  ,   0.+0.j  ]]], dtype=complex64)
When s=[3, 3, 4], the transform is computed along the last three axes, which for this 3-D input are all of its axes. The input is zero-padded to shape (3, 3, 4), so the output has size (3, 3) along axes (-3, -2) and 4//2 + 1 = 3 along axis -1:

>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.rfftn(x, s=[3, 3, 4])
Array([[[ 78.  +0.j  , -16.  -26.j  ,  26.  +0.j  ],
        [ 15.  -36.37j, -16.12 +1.93j,   5.  -12.12j],
        [ 15.  +36.37j,   8.12-11.93j,   5.  +12.12j]],
       [[ -7.5 -49.36j, -20.45 +9.43j,  -2.5 -16.45j],
        [-25.5  -7.79j,  -0.6 +11.96j,  -8.5  -2.6j ],
        [ 19.5 -12.99j,  -8.33 -6.5j ,   6.5  -4.33j]],
       [[ -7.5 +49.36j,  12.45 -4.43j,  -2.5 +16.45j],
        [ 19.5 +12.99j,   0.33 -6.5j ,   6.5  +4.33j],
        [-25.5  +7.79j,   4.6  +5.04j,  -8.5  +2.6j ]]], dtype=complex64)
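Passing s is equivalent to zero-padding the input yourself before calling rfftn. This sketch pads the example array to shape (3, 3, 4) with jnp.pad (zeros appended at the end of each axis) and checks that both routes give the same result.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.array([[[1, 3, 5],
                [2, 4, 6]],
               [[7, 9, 11],
                [8, 10, 12]]])

# Append one zero along each axis: (2, 2, 3) -> (3, 3, 4).
padded = jnp.pad(x, [(0, 1), (0, 1), (0, 1)])

a = jnp.fft.rfftn(x, s=[3, 3, 4])  # padding done by rfftn
b = jnp.fft.rfftn(padded)          # padding done explicitly
print(a.shape)  # (3, 3, 3)
print(np.allclose(a, b, atol=1e-4))
```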
When s=[3, 5] and axes=[0, 1], the output has size 3 along axis 0 and 5//2 + 1 = 3 along axis 1; axis 2 is not transformed and keeps its input size of 3:

>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.rfftn(x, s=[3, 5], axes=[0, 1])
Array([[[ 18.  +0.j  ,  26.  +0.j  ,  34.  +0.j  ],
        [ 11.09 -9.51j,  16.33-13.31j,  21.56-17.12j],
        [ -0.09 -5.88j,   0.67 -8.23j,   1.44-10.58j]],
       [[ -4.5 -12.99j,  -2.5 -16.45j,  -0.5 -19.92j],
        [ -9.71 -6.3j , -10.05 -9.52j, -10.38-12.74j],
        [ -4.95 +0.72j,  -5.78 -0.2j ,  -6.61 -1.12j]],
       [[ -4.5 +12.99j,  -2.5 +16.45j,  -0.5 +19.92j],
        [  3.47+10.11j,   6.43+11.42j,   9.38+12.74j],
        [  3.19 +1.63j,   4.4  +1.38j,   5.61 +1.12j]]], dtype=complex64)
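One way to read this result: a real FFT along the last listed axis, followed by a complex FFT along each remaining listed axis. The sketch below assumes that decomposition (it is how numpy.fft.rfftn is structured) and checks it against rfftn using jax.numpy.fft.rfft and jax.numpy.fft.fft, each given the matching n and axis.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.array([[[1, 3, 5],
                [2, 4, 6]],
               [[7, 9, 11],
                [8, 10, 12]]])

out = jnp.fft.rfftn(x, s=[3, 5], axes=[0, 1])

# Step 1: real FFT along axes[-1] = 1, padded to 5 -> 5 // 2 + 1 = 3 terms.
step1 = jnp.fft.rfft(x, n=5, axis=1)
# Step 2: complex FFT along axis 0, padded to 3.
step2 = jnp.fft.fft(step1, n=3, axis=0)

print(out.shape)  # (3, 3, 3)
print(np.allclose(out, step2, atol=1e-4))
```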
For 1-D input:
>>> x1 = jnp.array([1, 2, 3, 4])
>>> jnp.fft.rfftn(x1)
Array([10.+0.j, -2.+2.j, -2.+0.j], dtype=complex64)
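For 1-D input, rfftn reduces to jax.numpy.fft.rfft, and each output term is X[k] = sum_n x[n] * exp(-2j*pi*k*n/N) for k = 0, ..., N//2. The sketch below evaluates that sum directly, as an illustrative O(N^2) check rather than how the library computes it, and compares it with both functions.

```python
import jax.numpy as jnp
import numpy as np

x1 = jnp.array([1.0, 2.0, 3.0, 4.0])
N = x1.shape[0]

k = jnp.arange(N // 2 + 1)[:, None]  # only the non-negative frequencies 0..N//2
n = jnp.arange(N)[None, :]           # sample indices 0..N-1
direct = jnp.sum(x1 * jnp.exp(-2j * jnp.pi * k * n / N), axis=1)

print(np.allclose(direct, jnp.fft.rfftn(x1), atol=1e-5))
print(np.allclose(jnp.fft.rfftn(x1), jnp.fft.rfft(x1)))
```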
