jax.numpy.pad

jax.numpy.pad(array, pad_width, mode='constant', **kwargs)

Add padding to an array.

JAX implementation of numpy.pad().
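Because jnp.pad is meant to mirror numpy.pad, one quick way to see this is to pad the same small array with NumPy and JAX for each string mode and compare the results. This is a minimal sketch, assuming NumPy is installed and JAX runs with its default float32 precision; 'empty' is left out because NumPy leaves those values uninitialized.

```python
import numpy as np
import jax.numpy as jnp

# A small 2D float32 array; float input avoids integer rounding in 'mean'.
x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
pad_width = ((1, 2), (2, 1))  # (before, after) for axis 0 and axis 1

modes = ['constant', 'edge', 'wrap', 'linear_ramp', 'maximum', 'mean',
         'median', 'minimum', 'reflect', 'symmetric']

for mode in modes:
    expected = np.pad(np.asarray(x), pad_width, mode=mode)
    actual = jnp.pad(x, pad_width, mode=mode)
    # Statistic-based modes can differ in the last float bits, so use a tolerance.
    np.testing.assert_allclose(np.asarray(actual), expected, rtol=1e-6, atol=1e-6)
    print(f"{mode:12s} matches NumPy, shape {actual.shape}")
```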

Parameters:
  • array (ArrayLike) – array to pad.

  • pad_width (PadValueLike[int | Array | np.ndarray]) –

    specify the pad width for each dimension of an array. Padding widths may be separately specified for before and after the array. Options are:

    • int or (int,): pad each array dimension with the same number of values both before and after.

    • (before, after): pad each array dimension with before elements before, and after elements after.

    • ((before_1, after_1), (before_2, after_2), ... (before_N, after_N)): specify distinct before and after values for each array dimension.

  • mode (str | Callable[..., Any]) –

    a string or callable. Supported pad modes are:

    • 'constant' (default): pad with a constant value, which defaults to zero.

    • 'empty': pad with empty values (i.e. zeros).

    • 'edge': pad with the edge values of the array.

    • 'wrap': pad by wrapping the array.

    • 'linear_ramp': pad with a linear ramp to the specified end_values.

    • 'maximum': pad with the maximum value.

    • 'mean': pad with the mean value.

    • 'median': pad with the median value.

    • 'minimum': pad with the minimum value.

    • 'reflect': pad by reflection about the edge values, without repeating them.

    • 'symmetric': pad by symmetric reflection, repeating the edge values.

    • <callable>: a callable function. See Notes below.

  • constant_values – referenced for mode='constant'. Specify the constant value to pad with.

  • stat_length – referenced for mode in ['maximum', 'mean', 'median', 'minimum']. An integer or tuple specifying the number of edge values to use when calculating the statistic.

  • end_values – referenced for mode='linear_ramp'. Specify the end values to ramp the padding values to.

  • reflect_type – referenced for mode in ['reflect', 'symmetric']. Specify whether to use even or odd reflection.
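The three accepted forms of pad_width can be compared by looking only at output shapes. This sketch uses an arbitrary 3D array of ones; each padded axis grows by before + after.

```python
import jax.numpy as jnp

x = jnp.ones((2, 3, 4))

# int or (int,): 1 before and 1 after on every axis -> (4, 5, 6)
print(jnp.pad(x, 1).shape)
print(jnp.pad(x, (1,)).shape)

# (before, after): the same pair on every axis -> (5, 6, 7)
print(jnp.pad(x, (1, 2)).shape)

# one (before, after) pair per axis -> (2, 4, 9)
print(jnp.pad(x, ((0, 0), (1, 0), (2, 3))).shape)
```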

Returns:

A padded copy of array.

Return type:

Array
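Two of the modes listed under Parameters, 'wrap' and 'linear_ramp' (with its end_values keyword), are easiest to understand on a concrete input. The sketch below pads a small integer array; the expected values follow NumPy's documented semantics for these modes.

```python
import jax.numpy as jnp

x = jnp.array([10, 20, 30, 40])

# 'wrap': take values from the opposite end, as if the array were periodic.
print(jnp.pad(x, 2, mode='wrap').tolist())
# → [30, 40, 10, 20, 30, 40, 10, 20]

# 'linear_ramp': ramp from end_values (default 0) toward the edge value.
print(jnp.pad(x, 2, mode='linear_ramp').tolist())
# → [0, 5, 10, 20, 30, 40, 20, 0]

# end_values=(before, after) sets a different ramp target on each side.
print(jnp.pad(x, 2, mode='linear_ramp', end_values=(0, 100)).tolist())
# → [0, 5, 10, 20, 30, 40, 70, 100]
```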

Notes

When mode is callable, it should have the following signature:

def pad_func(row: Array, pad_width: tuple[int, int], iaxis: int, kwargs: dict) -> Array: ...

Here row is a 1D slice of the padded array along axis iaxis, with the pad values filled with zeros. pad_width is a tuple specifying the (before, after) padding sizes, and kwargs are any additional keyword arguments passed to the jax.numpy.pad() function.

Note that while in NumPy the function should modify row in-place, in JAX the function should return the modified row. In JAX, the custom padding function will be mapped across the padded axis using the jax.vmap() transformation.
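Because the callable receives iaxis, a single function can treat each axis differently. In this sketch, mark_axis is a hypothetical helper (not part of JAX) that fills the pad region of every 1D slice with iaxis + 1. Corner cells lie in the pad region of both axes, so their final value depends on the order in which axes are processed; the sketch therefore only relies on the non-corner cells.

```python
import jax.numpy as jnp

def mark_axis(row, pad_width, iaxis, kwargs):
    # Hypothetical helper: write iaxis + 1 into the pad region of this slice.
    before, after = pad_width
    row = row.at[:before].set(iaxis + 1)
    return row.at[len(row) - after:].set(iaxis + 1)

x = jnp.full((2, 2), 9)
out = jnp.pad(x, 1, mode=mark_axis)
print(out)

# Non-corner cells: top/bottom pads come from axis 0, left/right pads from axis 1.
print(out[0, 1].item(), out[1, 0].item())  # → 1 2
```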

Examples

Pad a 1-dimensional array with zeros:

>>> import jax.numpy as jnp
>>> x = jnp.array([10, 20, 30, 40])
>>> jnp.pad(x, 2)
Array([ 0,  0, 10, 20, 30, 40,  0,  0], dtype=int32)
>>> jnp.pad(x, (2, 4))
Array([ 0,  0, 10, 20, 30, 40,  0,  0,  0,  0], dtype=int32)

Pad a 1-dimensional array with specified values:

>>> jnp.pad(x, 2, constant_values=99)
Array([99, 99, 10, 20, 30, 40, 99, 99], dtype=int32)
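constant_values also follows NumPy's pair convention, so a (before, after) tuple pads each side with a different constant. A minimal sketch on the same 1D array:

```python
import jax.numpy as jnp

x = jnp.array([10, 20, 30, 40])

# constant_values=(before, after): -1 on the left, 99 on the right.
print(jnp.pad(x, 2, constant_values=(-1, 99)).tolist())
# → [-1, -1, 10, 20, 30, 40, 99, 99]
```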

Pad a 1-dimensional array with the mean array value:

>>> jnp.pad(x, 2, mode='mean')
Array([25, 25, 10, 20, 30, 40, 25, 25], dtype=int32)
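The stat_length keyword restricts the statistic to the values nearest each edge instead of the whole axis. This sketch shows an integer stat_length with 'mean' and a (before, after) pair with 'minimum'.

```python
import jax.numpy as jnp

x = jnp.array([10, 20, 30, 40])

# Mean of only the 2 values nearest each edge: (10+20)/2 and (30+40)/2.
print(jnp.pad(x, 2, mode='mean', stat_length=2).tolist())
# → [15, 15, 10, 20, 30, 40, 35, 35]

# stat_length=(before, after): min of the first 1 value and of the last 3 values.
print(jnp.pad(x, 2, mode='minimum', stat_length=(1, 3)).tolist())
# → [10, 10, 10, 20, 30, 40, 20, 20]
```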

Pad a 1-dimensional array with reflected values:

>>> jnp.pad(x, 2, mode='reflect')
Array([30, 20, 10, 20, 30, 40, 30, 20], dtype=int32)
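For comparison with the 'reflect' output, 'symmetric' repeats the edge value, and reflect_type='odd' reflects values through the edge (each pad value becomes 2 * edge minus the mirrored value). A short sketch on the same array:

```python
import jax.numpy as jnp

x = jnp.array([10, 20, 30, 40])

# 'symmetric': mirror including the edge value itself.
print(jnp.pad(x, 2, mode='symmetric').tolist())
# → [20, 10, 10, 20, 30, 40, 40, 30]

# reflect_type='odd': 2 * edge - mirrored value, e.g. 2*10 - 30 = -10.
print(jnp.pad(x, 2, mode='reflect', reflect_type='odd').tolist())
# → [-10, 0, 10, 20, 30, 40, 50, 60]
```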

Pad a 2-dimensional array with different paddings in each dimension:

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.pad(x, ((1, 2), (3, 0)))
Array([[0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 2, 3],
       [0, 0, 0, 4, 5, 6],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]], dtype=int32)
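Per-dimension widths combine with any mode, not only 'constant'. As a sketch, padding the same 2D array with mode='edge' copies the nearest row or column into the pad region:

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# One row before (axis 0) and one column after (axis 1), copying edge values.
print(jnp.pad(x, ((1, 0), (0, 1)), mode='edge').tolist())
# → [[1, 2, 3, 3], [1, 2, 3, 3], [4, 5, 6, 6]]
```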

Pad a 1-dimensional array with a custom padding function:

>>> def custom_pad(row, pad_width, iaxis, kwargs):
...     # row represents a 1D slice of the zero-padded array.
...     before, after = pad_width
...     before_value = kwargs.get('before_value', 0)
...     after_value = kwargs.get('after_value', 0)
...     row = row.at[:before].set(before_value)
...     return row.at[len(row) - after:].set(after_value)
>>> x = jnp.array([2, 3, 4])
>>> jnp.pad(x, 2, custom_pad, before_value=-10, after_value=10)
Array([-10, -10,   2,   3,   4,  10,  10], dtype=int32)
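A callable can also read values from the unpadded part of row. As a sanity check, the hypothetical edge_pad below (not part of JAX) re-implements mode='edge' for one 1D slice; applied to a 2D array it should agree with the built-in mode, since each axis is padded using the slices produced for the other.

```python
import jax.numpy as jnp

def edge_pad(row, pad_width, iaxis, kwargs):
    # Hypothetical callable reproducing mode='edge' on a single 1D slice.
    before, after = pad_width
    n = len(row)
    row = row.at[:before].set(row[before])             # copy first original value
    return row.at[n - after:].set(row[n - after - 1])  # copy last original value

x = jnp.array([[1, 2, 3], [4, 5, 6]])
pad_width = ((1, 2), (2, 1))
custom = jnp.pad(x, pad_width, mode=edge_pad)
builtin = jnp.pad(x, pad_width, mode='edge')
print(bool((custom == builtin).all()))  # → True
```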