jax.numpy.convolve

jax.numpy.convolve(a, v, mode='full', *, precision=None, preferred_element_type=None)

Convolution of two one-dimensional arrays.

JAX implementation of numpy.convolve().
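As a quick sanity check, the values returned by jnp.convolve can be compared against NumPy's reference implementation. This is a minimal sketch that assumes NumPy is installed alongside JAX. Under JAX's default 32-bit mode, integer inputs come back as float32 while NumPy keeps an integer dtype, so only the values are compared.

```python
import numpy as np
import jax.numpy as jnp

a = [1, 2, 3, 2, 1]
v = [4, 1, 2]

# NumPy reference result (integer dtype for integer inputs).
expected = np.convolve(a, v)

# JAX result (integer inputs are promoted to an inexact dtype by default).
result = jnp.convolve(jnp.array(a), jnp.array(v))

# Compare values only, since the two libraries return different dtypes.
print(np.allclose(np.asarray(result), expected))  # True
```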

Convolution of one-dimensional arrays is defined as:

\[c_k = \sum_j a_{k - j} v_j\]
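The sum runs over every index j for which both a_{k-j} and v_j exist; all other terms are treated as zero, so for inputs of lengths N and M the "full" result has N + M - 1 entries. The loop below is a naive, illustration-only sketch of that definition; convolve_direct is a hypothetical helper, not part of JAX, and it is far slower than the library routine.

```python
import jax.numpy as jnp

def convolve_direct(a, v):
    """Naive 'full' convolution: c_k = sum_j a[k - j] * v[j] (hypothetical helper)."""
    a = jnp.asarray(a, dtype=jnp.float32)
    v = jnp.asarray(v, dtype=jnp.float32)
    n, m = a.shape[0], v.shape[0]
    out = []
    for k in range(n + m - 1):
        total = 0.0
        for j in range(m):
            # Terms where a[k - j] falls outside a's bounds are implicitly zero.
            if 0 <= k - j < n:
                total = total + a[k - j] * v[j]
        out.append(total)
    return jnp.array(out)

c = convolve_direct([1, 2, 3, 2, 1], [4, 1, 2])
print(jnp.allclose(c, jnp.convolve(jnp.array([1, 2, 3, 2, 1]), jnp.array([4, 1, 2]))))  # True
```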
Parameters:
  • a (ArrayLike) – left-hand input to the convolution. Must have a.ndim == 1.

  • v (ArrayLike) – right-hand input to the convolution. Must have v.ndim == 1.

  • mode (str) –

    Controls the size of the output. Available modes are:

    • "full": (default) output the full convolution of the inputs.

    • "same": return a centered portion of the "full" output which is the same size as a.

    • "valid": return only the portion of the "full" output which does not depend on padding at the array edges.

  • precision (lax.PrecisionLike) – Specify the precision of the computation. Refer to jax.lax.Precision for a description of available values.

  • preferred_element_type (DTypeLike | None) – The dtype in which to accumulate results and in which to return the result. Default is None, which means the default accumulation type for the input types.
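A short sketch of the two keyword-only arguments, using illustrative inputs. Here precision is set to jax.lax.Precision.HIGHEST and preferred_element_type to jnp.float32; with float32 inputs this should match the default result. The dtype comments assume JAX's default 32-bit mode, in which integer inputs are promoted to float32 when preferred_element_type is None (as the examples below also show).

```python
import jax
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0, 2.0, 1.0], dtype=jnp.float32)
v = jnp.array([4.0, 1.0, 2.0], dtype=jnp.float32)

# Default: accumulation and output dtype follow the (promoted) input dtype.
default = jnp.convolve(a, v)

# Request the highest-precision algorithm and an explicit float32 result.
explicit = jnp.convolve(
    a, v,
    precision=jax.lax.Precision.HIGHEST,
    preferred_element_type=jnp.float32,
)

# Integer inputs with preferred_element_type=None are promoted to an inexact dtype.
int_result = jnp.convolve(jnp.array([1, 2, 3]), jnp.array([1, 1]))

print(default.dtype, explicit.dtype)   # float32 float32
print(jnp.allclose(default, explicit))  # True
print(int_result.dtype)                 # float32
```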

Returns:

Array containing the convolved result.

Return type:

Array

See also

Examples

A few 1D convolution examples:

>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([4, 1, 2])

By default, jax.numpy.convolve returns the full convolution, using implicit zero-padding at the edges:

>>> jnp.convolve(x, y)
Array([ 4.,  9., 16., 15., 12.,  5.,  2.], dtype=float32)

Specifying mode='same' returns a centered convolution the same size as the first input:

>>> jnp.convolve(x, y, mode='same')
Array([ 9., 16., 15., 12.,  5.], dtype=float32)

Specifying mode='valid' returns only the portion where the two arrays fully overlap:

>>> jnp.convolve(x, y, mode='valid')
Array([16., 15., 12.], dtype=float32)
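The three modes are different windows onto the same "full" result. The sketch below reproduces "same" and "valid" by slicing the "full" output, assuming len(a) >= len(v) and the odd-length v used in these examples. It only illustrates the relationship between the modes; it is not a description of how JAX computes them internally.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3, 2, 1])
v = jnp.array([4, 1, 2])
n, m = a.shape[0], v.shape[0]  # assumes n >= m

full = jnp.convolve(a, v, mode='full')  # length n + m - 1

# 'same': a centered window of length n.
start = (m - 1) // 2
same_from_full = full[start:start + n]

# 'valid': only positions where v lies entirely inside a (length n - m + 1).
valid_from_full = full[m - 1:n]

print(jnp.allclose(same_from_full, jnp.convolve(a, v, mode='same')))    # True
print(jnp.allclose(valid_from_full, jnp.convolve(a, v, mode='valid')))  # True
```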

For complex-valued inputs:

>>> x1 = jnp.array([3+1j, 2, 4-3j])
>>> y1 = jnp.array([1, 2-3j, 4+5j])
>>> jnp.convolve(x1, y1)
Array([ 3. +1.j, 11. -7.j, 15.+10.j,  7. -8.j, 31. +8.j], dtype=complex64)
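Because both inputs must be one-dimensional, a batch of signals cannot be passed in directly. One option, sketched here under the assumption that every row should be convolved with the same kernel, is to map the 1-D function over the leading axis with jax.vmap. The batch values are made up for illustration.

```python
import jax
import jax.numpy as jnp

signals = jnp.array([[1, 2, 3, 2, 1],
                     [0, 1, 0, 1, 0]])  # shape (2, 5): two 1-D signals
kernel = jnp.array([4, 1, 2])           # one shared 1-D kernel

# Map over the rows of `signals` (axis 0) and pass the kernel unchanged (None).
batched_convolve = jax.vmap(jnp.convolve, in_axes=(0, None))

out = batched_convolve(signals, kernel)
print(out.shape)  # (2, 7): one 'full' result per signal
```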