
jax.numpy.arange

jax.numpy.arange(start, stop=None, step=None, dtype=None, *, device=None, out_sharding=None)

Create an array of evenly-spaced values.

JAX implementation of numpy.arange(), implemented in terms of jax.lax.iota().
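To illustrate how an evenly spaced range can be built from jax.lax.iota(), the sketch below computes start + step * iota(n), where n is the number of elements. The helper arange_via_iota is hypothetical and handles integer arguments only; it shows the idea and is not JAX's actual internal implementation.

```python
import jax.numpy as jnp
from jax import lax


def arange_via_iota(start, stop, step):
    """Hypothetical helper: integer-only arange built on lax.iota."""
    # Integer ceiling division: ceil((stop - start) / step), clamped at 0
    # so that a step pointing away from stop yields an empty array.
    n = max(0, -(-(stop - start) // step))
    # lax.iota(dtype, size) returns [0, 1, ..., size - 1].
    return start + step * lax.iota(jnp.int32, n)


print(jnp.arange(2, 11, 3).tolist(), arange_via_iota(2, 11, 3).tolist())  # [2, 5, 8] [2, 5, 8]
```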

Similar to Python’s range() function, this can be called with a few different positional signatures:

  • jnp.arange(stop): generate values from 0 to stop, stepping by 1.

  • jnp.arange(start, stop): generate values from start to stop, stepping by 1.

  • jnp.arange(start, stop, step): generate values from start to stop, stepping by step.

As with Python’s range() function, the start value is inclusive and the stop value is exclusive.
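One way to check the three signatures and the inclusive start / exclusive stop behavior is to compare the results against Python's own range(). This sketch assumes integer arguments, where the two are expected to agree element for element.

```python
import jax.numpy as jnp

# One argument tuple per positional signature: (stop,), (start, stop),
# and (start, stop, step). In each case start is included and stop is not.
signatures = [(5,), (2, 7), (1, 10, 3)]

for args in signatures:
    values = jnp.arange(*args).tolist()
    # Compare against the equivalent Python range() call.
    print(args, values, values == list(range(*args)))
```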

Parameters:
  • start (ArrayLike | DimSize) – start of the interval, inclusive.

  • stop (ArrayLike | DimSize | None) – optional end of the interval, exclusive. If not specified, then (start, stop) = (0, start).

  • step (ArrayLike | None) – optional step size for the interval. Default = 1.

  • dtype (DTypeLike | None) – optional dtype for the returned array; if not specified, it will be determined via type promotion of start, stop, and step.

  • device (xc.Device | Sharding | None) – (optional) Device or Sharding to which the created array will be committed.

  • out_sharding (NamedSharding | P | None) – (optional) NamedSharding or P to which the created array will be committed. Use the out_sharding argument when using explicit sharding (see https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html).
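The stop and dtype parameters can be illustrated as follows. This sketch does not touch device or out_sharding. The exact integer and float widths (for example int32 versus int64) depend on whether JAX's 64-bit mode is enabled, so the checks here only look at the dtype kind, except where dtype is passed explicitly.

```python
import jax.numpy as jnp

# stop omitted: (start, stop) is treated as (0, start).
same = jnp.array_equal(jnp.arange(3), jnp.arange(0, 3))

# dtype omitted: inferred via type promotion of start, stop, and step.
ints = jnp.arange(0, 5)         # all-integer arguments -> integer dtype
halves = jnp.arange(0, 5, 0.5)  # a float step promotes the result to a float dtype

# dtype given explicitly: it is used instead of the inferred dtype.
floats = jnp.arange(0, 5, dtype=jnp.float32)

print(bool(same), ints.dtype, halves.dtype, floats.dtype)
```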

Returns:

Array of evenly-spaced values from start to stop, separated by step.

Return type:

Array
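The length of the returned array can be predicted from the arguments. The sketch below assumes the NumPy convention max(0, ceil((stop - start) / step)) and checks it against jnp.arange for a few cases, including a negative step and an empty result. The helper expected_len is hypothetical. Float steps are limited here to ones that are exactly representable, because inexact steps can make the ceiling sensitive to rounding.

```python
import math

import jax.numpy as jnp


def expected_len(start, stop, step=1):
    """Hypothetical helper: predicted number of elements in arange."""
    return max(0, math.ceil((stop - start) / step))


cases = [(0, 10, 3), (1, 6, 1), (0, 2, 0.5), (5, 0, -2), (3, 1, 1)]
for start, stop, step in cases:
    out = jnp.arange(start, stop, step)
    print((start, stop, step), out.shape[0], expected_len(start, stop, step))
```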

Note

Using arange with a floating-point step argument can lead to unexpected results due to accumulation of floating-point errors, especially with lower-precision data types like float8_* and bfloat16. To avoid precision errors, consider generating a range of integers and scaling it to the desired range. For example, instead of this:

jnp.arange(-1,1,0.01,dtype='bfloat16')

it can be more accurate to generate a sequence of integers, and scale them:

(jnp.arange(-100,100)*0.01).astype('bfloat16')
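The two patterns in the note can be compared numerically against a float64 reference grid computed with NumPy. The helper max_abs_error is hypothetical. It compares only the overlapping prefix in case the float-step result does not have exactly 200 elements. Error values depend on the JAX version and backend, so no output is shown. For the integer-then-scale pattern, each value is rounded to bfloat16 only once, so its error should stay within bfloat16 rounding (about 0.002 for values of magnitude below 1).

```python
import numpy as np
import jax.numpy as jnp

# Pattern the note warns about: a floating-point step in low precision.
float_step = jnp.arange(-1, 1, 0.01, dtype='bfloat16')

# Pattern the note recommends: integers first, then scale and cast once.
scaled = (jnp.arange(-100, 100) * 0.01).astype('bfloat16')

# Reference grid in float64.
ideal = np.arange(-100, 100) * 0.01


def max_abs_error(x):
    """Hypothetical helper: largest deviation from the float64 reference."""
    n = min(x.shape[0], ideal.shape[0])
    # Cast bfloat16 -> float32 in JAX (exact), then to float64 in NumPy (exact).
    values = np.asarray(x[:n].astype(jnp.float32), dtype=np.float64)
    return float(np.max(np.abs(values - ideal[:n])))


print('float step  :', float_step.shape, max_abs_error(float_step))
print('int + scale :', scaled.shape, max_abs_error(scaled))
```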

Examples

Single-argument version specifies only the stop value:

>>> jnp.arange(4)
Array([0, 1, 2, 3], dtype=int32)

Passing a floating-point stop value leads to a floating-point result:

>>> jnp.arange(4.0)
Array([0., 1., 2., 3.], dtype=float32)

Two-argument version specifies start and stop, with step=1:

>>> jnp.arange(1, 6)
Array([1, 2, 3, 4, 5], dtype=int32)

Three-argument version specifies start, stop, and step:

>>> jnp.arange(0, 2, 0.5)
Array([0. , 0.5, 1. , 1.5], dtype=float32)
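As with Python's range(), the step may also be negative. The sketch below counts downward and shows that a step pointing away from stop produces an empty array rather than an error.

```python
import jax.numpy as jnp

# A negative step counts downward; stop is still excluded.
down = jnp.arange(5, 0, -1)
print(down.tolist())  # [5, 4, 3, 2, 1]

# A step pointing away from stop yields an empty array.
empty = jnp.arange(0, 5, -1)
print(empty.shape)  # (0,)
```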

