jax.numpy.pad
jax.numpy.pad(array, pad_width, mode='constant', **kwargs)
Add padding to an array.

JAX implementation of numpy.pad().

Parameters:
array (ArrayLike) – array to pad.
pad_width (PadValueLike[int | Array | np.ndarray]) – specify the pad width for each dimension of an array. Padding widths may be specified separately before and after the array. Options are:

- int or (int,): pad each array dimension with the same number of values both before and after.
- (before, after): pad each array dimension with before elements before and after elements after.
- ((before_1, after_1), (before_2, after_2), ..., (before_N, after_N)): specify distinct before and after values for each array dimension.
mode (str | Callable[..., Any]) – a string or callable. Supported pad modes are:

- 'constant' (default): pad with a constant value, which defaults to zero.
- 'empty': pad with empty values (i.e. zero).
- 'edge': pad with the edge values of the array.
- 'wrap': pad by wrapping the array.
- 'linear_ramp': pad with a linear ramp to the specified end_values.
- 'maximum': pad with the maximum value.
- 'mean': pad with the mean value.
- 'median': pad with the median value.
- 'minimum': pad with the minimum value.
- 'reflect': pad by reflection.
- 'symmetric': pad by symmetric reflection.
- <callable>: a callable function. See Notes below.
constant_values – referenced for mode='constant'. Specify the constant value to pad with.

stat_length – referenced for mode in ['maximum', 'mean', 'median', 'minimum']. An integer or tuple specifying the number of edge values to use when calculating the statistic.

end_values – referenced for mode='linear_ramp'. Specify the end values to ramp the padding values to.

reflect_type – referenced for mode in ['reflect', 'symmetric']. Specify whether to use even or odd reflection.
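A minimal sketch of how the pad_width forms listed above translate into output shapes: each form is normalized to one (before, after) pair per dimension, and dimension i grows by before_i + after_i. The padded_shape helper is hypothetical, written only for this illustration, and it assumes the normalization behaves like NumPy broadcasting to shape (ndim, 2).

```python
import numpy as np
import jax.numpy as jnp

x = jnp.ones((2, 3), dtype=jnp.int32)

def padded_shape(shape, pad_width):
    # Hypothetical helper (not part of JAX): broadcast pad_width to one
    # (before, after) pair per dimension, then add both widths to each size.
    pairs = np.broadcast_to(np.asarray(pad_width), (len(shape), 2))
    return tuple(int(n + before + after) for n, (before, after) in zip(shape, pairs))

for pad_width in [1, (1,), (1, 2), ((0, 1), (2, 0))]:
    out = jnp.pad(x, pad_width)
    # The actual output shape agrees with the per-dimension arithmetic.
    assert out.shape == padded_shape(x.shape, pad_width)
    print(pad_width, out.shape)
```

The scalar and (int,) forms give the same result here, (4, 5), because both broadcast to a width of 1 on every side.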
Returns:
A padded copy of array.

Return type:
Array
Notes
When mode is callable, it should have the following signature:

    def pad_func(row: Array, pad_width: tuple[int, int], iaxis: int, kwargs: dict) -> Array:
      ...
Here row is a 1D slice of the padded array along axis iaxis, with the pad values filled with zeros. pad_width is a tuple specifying the (before, after) padding sizes, and kwargs are any additional keyword arguments passed to the jax.numpy.pad() function.

Note that while in NumPy the function should modify row in-place, in JAX the function should return the modified row. In JAX, the custom padding function is applied to each 1D slice along the padded axis, using the jax.vmap() transformation to map it over the remaining axes.

See also

- jax.numpy.resize(): resize an array.
- jax.numpy.tile(): create a larger array by tiling a smaller array.
- jax.numpy.repeat(): create a larger array by repeating values of a smaller array.
Examples
Pad a 1-dimensional array with zeros:
>>> import jax.numpy as jnp
>>> x = jnp.array([10, 20, 30, 40])
>>> jnp.pad(x, 2)
Array([ 0,  0, 10, 20, 30, 40,  0,  0], dtype=int32)
>>> jnp.pad(x, (2, 4))
Array([ 0,  0, 10, 20, 30, 40,  0,  0,  0,  0], dtype=int32)
Pad a 1-dimensional array with specified values:
>>> jnp.pad(x, 2, constant_values=99)
Array([99, 99, 10, 20, 30, 40, 99, 99], dtype=int32)
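constant_values is described above as a single value. Assuming jnp.pad mirrors numpy.pad here, it should also accept a (before, after) pair that is broadcast across axes the same way as pad_width. The sketch below checks that assumption directly against NumPy.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([10, 20, 30, 40])

# One (before, after) pair of fill values, applied to every axis.
out = jnp.pad(x, 2, constant_values=(-1, 99))
print(out.tolist())  # [-1, -1, 10, 20, 30, 40, 99, 99]

# Cross-check against NumPy's reference behavior.
assert np.array_equal(out, np.pad(np.asarray(x), 2, constant_values=(-1, 99)))
```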
Pad a 1-dimensional array with the mean array value:
>>> jnp.pad(x, 2, mode='mean')
Array([25, 25, 10, 20, 30, 40, 25, 25], dtype=int32)
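By default the statistic modes use the whole axis. A short sketch of the stat_length keyword from the parameter list: with stat_length=2, each side is padded with the statistic of only the two values nearest to it, so the before and after pads can differ. NumPy is used only as a cross-check.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([10, 20, 30, 40])

# Mean of [10, 20] on the left, mean of [30, 40] on the right.
out_mean = jnp.pad(x, 2, mode='mean', stat_length=2)
print(out_mean.tolist())  # [15, 15, 10, 20, 30, 40, 35, 35]

# Maximum of the two nearest values on each side.
out_max = jnp.pad(x, 1, mode='maximum', stat_length=2)
print(out_max.tolist())   # [20, 10, 20, 30, 40, 40]

xn = np.asarray(x)
assert np.array_equal(out_mean, np.pad(xn, 2, mode='mean', stat_length=2))
assert np.array_equal(out_max, np.pad(xn, 1, mode='maximum', stat_length=2))
```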
Pad a 1-dimensional array with reflected values:
>>> jnp.pad(x, 2, mode='reflect')
Array([30, 20, 10, 20, 30, 40, 30, 20], dtype=int32)
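'reflect' mirrors about the edge value without repeating it, while 'symmetric' repeats the edge value. The reflect_type keyword selects even reflection (assumed here to be the default, as in NumPy) or odd reflection, which replaces each mirrored value v with 2 * edge - v. A small sketch comparing the three on the same array:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([10, 20, 30, 40])

reflect = jnp.pad(x, 2, mode='reflect')                  # edge value not repeated
symmetric = jnp.pad(x, 2, mode='symmetric')              # edge value repeated
odd = jnp.pad(x, 2, mode='reflect', reflect_type='odd')  # 2 * edge - mirrored value

print(reflect.tolist())    # [30, 20, 10, 20, 30, 40, 30, 20]
print(symmetric.tolist())  # [20, 10, 10, 20, 30, 40, 40, 30]
print(odd.tolist())        # [-10, 0, 10, 20, 30, 40, 50, 60]

# Cross-check against numpy.pad with the same arguments.
xn = np.asarray(x)
assert np.array_equal(symmetric, np.pad(xn, 2, mode='symmetric'))
assert np.array_equal(odd, np.pad(xn, 2, mode='reflect', reflect_type='odd'))
```

For the odd case, the left edge is 10 and the mirrored values are 20 and 30, giving 2 * 10 - 20 = 0 and 2 * 10 - 30 = -10.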
Pad a 2-dimensional array with different paddings in each dimension:
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.pad(x, ((1, 2), (3, 0)))
Array([[0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 2, 3],
       [0, 0, 0, 4, 5, 6],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]], dtype=int32)
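The same per-dimension widths work with the other string modes. A sketch with 'edge' and 'wrap' on the 2D array above, plus 'linear_ramp' with end_values on a 1D array, each cross-checked against numpy.pad:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])
widths = ((1, 2), (3, 0))

# 'edge': repeat the outermost row/column into the padding.
edge = jnp.pad(x, widths, mode='edge')
print(edge.tolist())
# [[1, 1, 1, 1, 2, 3], [1, 1, 1, 1, 2, 3], [4, 4, 4, 4, 5, 6],
#  [4, 4, 4, 4, 5, 6], [4, 4, 4, 4, 5, 6]]

# 'wrap': fill the padding from the opposite end of each axis.
wrap = jnp.pad(x, widths, mode='wrap')

# 'linear_ramp': ramp from end_values (0 on both sides) toward the edge value.
ramp = jnp.pad(jnp.array([10, 20, 30, 40]), 2, mode='linear_ramp', end_values=0)
print(ramp.tolist())  # [0, 5, 10, 20, 30, 40, 20, 0]

# Cross-check against NumPy's reference implementation.
xn = np.asarray(x)
assert np.array_equal(edge, np.pad(xn, widths, mode='edge'))
assert np.array_equal(wrap, np.pad(xn, widths, mode='wrap'))
assert np.array_equal(
    ramp, np.pad(np.array([10, 20, 30, 40]), 2, mode='linear_ramp', end_values=0))
```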
Pad a 1-dimensional array with a custom padding function:
>>> def custom_pad(row, pad_width, iaxis, kwargs):
...   # row represents a 1D slice of the zero-padded array.
...   before, after = pad_width
...   before_value = kwargs.get('before_value', 0)
...   after_value = kwargs.get('after_value', 0)
...   row = row.at[:before].set(before_value)
...   return row.at[len(row) - after:].set(after_value)
>>> x = jnp.array([2, 3, 4])
>>> jnp.pad(x, 2, custom_pad, before_value=-10, after_value=10)
Array([-10, -10,   2,   3,   4,  10,  10], dtype=int32)
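On arrays with more than one dimension, the callable is invoked once per axis. Assuming the axes are processed in order, as in NumPy, later axes overwrite corner cells written by earlier ones. The sketch uses a hypothetical mark_axis function that fills each pad region with iaxis + 1, and compares it with an in-place NumPy equivalent (mark_axis_np, also hypothetical) to reflect the differing conventions described in the Notes.

```python
import numpy as np
import jax.numpy as jnp

def mark_axis(row, pad_width, iaxis, kwargs):
    # Hypothetical pad function: fill both pad regions of this 1D slice
    # with iaxis + 1 and return the new row (JAX arrays are immutable).
    before, after = pad_width
    row = row.at[:before].set(iaxis + 1)
    return row.at[len(row) - after:].set(iaxis + 1)

def mark_axis_np(vector, pad_width, iaxis, kwargs):
    # NumPy convention: modify the slice in place and return nothing.
    before, after = pad_width
    vector[:before] = iaxis + 1
    vector[len(vector) - after:] = iaxis + 1

x = jnp.array([[5, 6],
               [7, 8]])
out = jnp.pad(x, 1, mode=mark_axis)
print(out.tolist())
# [[2, 1, 1, 2], [2, 5, 6, 2], [2, 7, 8, 2], [2, 1, 1, 2]]

# Axis 0 writes 1s into the top and bottom rows; axis 1 then writes 2s
# into the left and right columns, overwriting the corners.
assert np.array_equal(out, np.pad(np.asarray(x), 1, mode=mark_axis_np))
```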
