jax.numpy.arange
jax.numpy.arange(start, stop=None, step=None, dtype=None, *, device=None, out_sharding=None)
Create an array of evenly-spaced values.
JAX implementation of numpy.arange(), implemented in terms of jax.lax.iota().

Similar to Python’s range() function, this can be called with a few different positional signatures:

- jnp.arange(stop): generate values from 0 to stop, stepping by 1.
- jnp.arange(start, stop): generate values from start to stop, stepping by 1.
- jnp.arange(start, stop, step): generate values from start to stop, stepping by step.
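One way to see the correspondence with range() is to compare the outputs directly. This is a small sketch; the argument tuples, including the negative-step case, are illustrative choices rather than anything taken from the text above.

```python
import jax.numpy as jnp

# Each positional signature of jnp.arange mirrors Python's built-in range():
# (stop,), (start, stop), and (start, stop, step).
cases = [(5,), (2, 7), (1, 10, 3), (10, 0, -2)]

for args in cases:
    values = jnp.arange(*args).tolist()  # convert to plain Python ints
    expected = list(range(*args))
    print(args, values)
    assert values == expected
```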
As with Python’s range() function, the start value is inclusive and the stop value is exclusive.

- Parameters:
start (ArrayLike | DimSize) – start of the interval, inclusive.
stop (ArrayLike | DimSize | None) – optional end of the interval, exclusive. If not specified, then (start, stop) = (0, start).
step (ArrayLike | None) – optional step size for the interval. Default = 1.
dtype (DTypeLike | None) – optional dtype for the returned array; if not specified, it will be determined via type promotion of start, stop, and step.
device (xc.Device | Sharding | None) – (optional) Device or Sharding to which the created array will be committed.
out_sharding (NamedSharding | P | None) – (optional) NamedSharding or P to which the created array will be committed. Use the out_sharding argument if using explicit sharding (https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html).
- Returns:
Array of evenly-spaced values from start to stop, separated by step.
- Return type:
Array
Note
Using arange with a floating-point step argument can lead to unexpected results due to accumulation of floating-point errors, especially with lower-precision data types like float8_* and bfloat16. To avoid precision errors, consider generating a range of integers and scaling it to the desired range. For example, instead of this:

    jnp.arange(-1, 1, 0.01, dtype='bfloat16')

it can be more accurate to generate a sequence of integers and scale them:

    (jnp.arange(-100, 100) * 0.01).astype('bfloat16')
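To make the comparison above measurable, the following sketch computes how far each approach lands from the exact values -1.00, -0.99, ..., 0.99, using a float64 NumPy reference. It passes jnp.bfloat16 instead of the string 'bfloat16', and max_abs_error is a hypothetical helper written for this illustration. How large the error of the direct version is depends on how the installed JAX and NumPy versions fill a bfloat16 range, so only the scaled version is checked against a bound: half the bfloat16 spacing just below 1 is 2**-9, and the check allows 2**-8.

```python
import numpy as np
import jax.numpy as jnp

# Exact targets -1.00, -0.99, ..., 0.99 (200 values), computed in float64.
reference = np.arange(-100, 100) / 100.0

# Approach 1: floating-point step directly in bfloat16.
direct = jnp.arange(-1, 1, 0.01, dtype=jnp.bfloat16)

# Approach 2: integer range, scaled, then cast to bfloat16 at the end.
scaled = (jnp.arange(-100, 100) * 0.01).astype(jnp.bfloat16)

def max_abs_error(x, ref):
    """Largest absolute deviation from ref, over the elements both arrays share."""
    vals = np.asarray(x.astype(jnp.float32), dtype=np.float64)
    n = min(len(vals), len(ref))
    return float(np.max(np.abs(vals[:n] - ref[:n])))

print("direct:", direct.shape, max_abs_error(direct, reference))
print("scaled:", scaled.shape, max_abs_error(scaled, reference))
```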
Examples
Single-argument version specifies only the stop value:

>>> jnp.arange(4)
Array([0, 1, 2, 3], dtype=int32)

Passing a floating-point stop value leads to a floating-point result:

>>> jnp.arange(4.0)
Array([0., 1., 2., 3.], dtype=float32)

Two-argument version specifies start and stop, with step=1:

>>> jnp.arange(1, 6)
Array([1, 2, 3, 4, 5], dtype=int32)

Three-argument version specifies start, stop, and step:

>>> jnp.arange(0, 2, 0.5)
Array([0. , 0.5, 1. , 1.5], dtype=float32)
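The examples above all follow the counting rule that NumPy documents for numpy.arange: the result has ceil((stop - start) / step) elements, or none if that quantity is not positive. The sketch below checks this rule against a few calls, including a negative step and an empty range; expected_length is a hypothetical helper written for this illustration, and the argument tuples are arbitrary.

```python
import math
import jax.numpy as jnp

def expected_length(start, stop, step=1):
    """Element count under the arange rule: max(0, ceil((stop - start) / step))."""
    return max(0, math.ceil((stop - start) / step))

for args in [(0, 2, 0.5), (1, 6, 1), (5, 0, -1), (3, 1, 1)]:
    out = jnp.arange(*args)
    print(args, out)
    assert out.shape == (expected_length(*args),)
```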
See also
- jax.numpy.linspace(): generate a fixed number of evenly-spaced values.
- jax.lax.iota(): directly generate integer sequences in XLA.
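For a concrete contrast with the related functions listed above, this sketch shows that arange fixes the step while linspace fixes the number of samples, and that lax.iota(dtype, size) yields the same integer sequence as a single-argument arange. The specific arguments are illustrative choices.

```python
import jax
import jax.numpy as jnp

# arange fixes the step; linspace fixes the number of samples.
# With endpoint=False, linspace(start, stop, num) spaces samples by (stop - start) / num.
a = jnp.arange(0, 2, 0.5)
b = jnp.linspace(0, 2, 4, endpoint=False)
print(a, b)

# lax.iota(dtype, size) produces 0, 1, ..., size - 1 directly.
c = jax.lax.iota(jnp.int32, 4)
d = jnp.arange(4, dtype=jnp.int32)
print(c, d)
```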
