jax.numpy.convolve
jax.numpy.convolve(a, v, mode='full', *, precision=None, preferred_element_type=None)
Convolution of two one-dimensional arrays.
JAX implementation of numpy.convolve().
Convolution of one-dimensional arrays is defined as:
\[c_k = \sum_j a_{k - j} v_j\]
- Parameters:
a (ArrayLike) – left-hand input to the convolution. Must have a.ndim == 1.
v (ArrayLike) – right-hand input to the convolution. Must have v.ndim == 1.
mode (str) – controls the size of the output. Available operations are:
  - "full": (default) output the full convolution of the inputs.
  - "same": return a centered portion of the "full" output which is the same size as a.
  - "valid": return the portion of the "full" output which does not depend on padding at the array edges.
precision (lax.PrecisionLike) – Specify the precision of the computation. Refer to jax.lax.Precision for a description of available values.
preferred_element_type (DTypeLike | None) – A datatype, indicating to accumulate results to and return a result with that datatype. Default is None, which means the default accumulation type for the input types.
- Returns:
Array containing the convolved result.
- Return type:
Array
See also
jax.scipy.signal.convolve(): ND convolution
jax.numpy.correlate(): 1D correlation
Examples
A few 1D convolution examples:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([4, 1, 2])
jax.numpy.convolve, by default, returns the full convolution using implicit zero-padding at the edges:
>>> jnp.convolve(x, y)
Array([ 4.,  9., 16., 15., 12.,  5.,  2.], dtype=float32)
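As a cross-check of the definition \(c_k = \sum_j a_{k - j} v_j\), the sum can be evaluated directly and compared with the "full" output. This is only a sketch: direct_convolve is a hypothetical helper written for illustration, not part of JAX, and it treats out-of-range indices of a as zero, which is how the implicit zero-padding is assumed to behave.

```python
import numpy as np
import jax.numpy as jnp


def direct_convolve(a, v):
    """Hypothetical reference helper: evaluates c_k = sum_j a[k - j] * v[j]."""
    a = np.asarray(a, dtype=np.float32)
    v = np.asarray(v, dtype=np.float32)
    out = np.zeros(len(a) + len(v) - 1, dtype=np.float32)
    for k in range(len(out)):
        for j in range(len(v)):
            # Terms whose index k - j falls outside a contribute zero.
            if 0 <= k - j < len(a):
                out[k] += a[k - j] * v[j]
    return out


x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([4, 1, 2])

reference = direct_convolve(x, y)
result = jnp.convolve(x, y)  # mode='full' is the default

# True when both computations agree.
print(np.allclose(np.asarray(result), reference))
```

The double loop is quadratic and meant purely as a readable reference; jnp.convolve is the call to use in practice.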
Specifying mode='same' returns a centered convolution the same size as the first input:
>>> jnp.convolve(x, y, mode='same')
Array([ 9., 16., 15., 12.,  5.], dtype=float32)
Specifying mode='valid' returns only the portion where the two arrays fully overlap:
>>> jnp.convolve(x, y, mode='valid')
Array([16., 15., 12.], dtype=float32)
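The three modes can also be read as different slices of the same "full" result. The sketch below illustrates this for the arrays used above; the slice offsets are an assumption that holds when the first input is at least as long as the second, and the variable names are illustrative only.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3, 2, 1])
v = jnp.array([4, 1, 2])
n, m = a.shape[0], v.shape[0]  # this sketch assumes n >= m

full = jnp.convolve(a, v, mode="full")  # length n + m - 1

# "same": the centered window of length n inside the full output.
start = (m - 1) // 2
same_from_full = full[start:start + n]

# "valid": the entries computed without any zero-padding, length n - m + 1.
valid_from_full = full[m - 1:n]

print(jnp.array_equal(same_from_full, jnp.convolve(a, v, mode="same")))
print(jnp.array_equal(valid_from_full, jnp.convolve(a, v, mode="valid")))
```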
For complex-valued inputs:
>>> x1 = jnp.array([3+1j, 2, 4-3j])
>>> y1 = jnp.array([1, 2-3j, 4+5j])
>>> jnp.convolve(x1, y1)
Array([ 3. +1.j, 11. -7.j, 15.+10.j,  7. -8.j, 31. +8.j], dtype=complex64)
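The integer examples above return float32 because no preferred_element_type is given. The following sketch shows how the preferred_element_type parameter might be used to keep an integer result instead. It assumes a JAX version that supports this keyword and a backend that supports integer convolutions (the CPU backend is assumed here); behavior on other backends may differ.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 1], dtype=jnp.int32)
y = jnp.array([4, 1, 2], dtype=jnp.int32)

# Default: integer inputs are promoted to a floating-point type.
default_out = jnp.convolve(x, y)

# Request accumulation and output in int32 instead.
int_out = jnp.convolve(x, y, preferred_element_type=jnp.int32)

print(default_out.dtype, int_out.dtype)
```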
