
jax.numpy.isscalar


jax.numpy.isscalar(element)

Return True if the input is a scalar.

JAX implementation of numpy.isscalar(). JAX’s implementation differs from NumPy’s in that it considers zero-dimensional arrays to be scalars; see the Note below for more details.

Parameters:

element (Any) – input object to check; any type is valid input.

Returns:

True if element is a scalar value or an array-like object with zero dimensions, False otherwise.

Return type:

bool

Note

JAX and NumPy differ in their representation of scalar values. NumPy has special scalar objects (e.g. np.int32(0)) which are distinct from zero-dimensional arrays (e.g. np.array(0)), and numpy.isscalar() returns True for the former and False for the latter.

JAX does not define special scalar objects, but rather represents scalars as zero-dimensional arrays. As such, jax.numpy.isscalar() returns True for both scalar objects (e.g. 0.0 or np.float32(0.0)) and array-like objects with zero dimensions (e.g. jnp.array(0.0), np.array(0.0)).

One reason for the different conventions in isscalar is to maintain JIT-invariance: i.e. the property that the result of a function should not change when it is JIT-compiled. Because scalar inputs are cast to zero-dimensional JAX arrays at JIT boundaries, the semantics of numpy.isscalar() are such that the result changes under JIT:

>>> np.isscalar(1.0)
True
>>> jax.jit(np.isscalar)(1.0)
Array(False, dtype=bool)

By treating zero-dimensional arrays as scalars, jax.numpy.isscalar() avoids this issue:

>>> jnp.isscalar(1.0)
True
>>> jax.jit(jnp.isscalar)(1.0)
Array(True, dtype=bool)

Examples

In JAX, both scalars and zero-dimensional array-like objects are considered scalars:

>>> jnp.isscalar(1.0)
True
>>> jnp.isscalar(1+1j)
True
>>> jnp.isscalar(jnp.array(1))  # zero-dimensional JAX array
True
>>> jnp.isscalar(jnp.int32(1))  # JAX scalar constructor
True
>>> jnp.isscalar(np.array(1.0))  # zero-dimensional NumPy array
True
>>> jnp.isscalar(np.int32(1))  # NumPy scalar type
True

Arrays with one or more dimensions are not considered scalars:

>>> jnp.isscalar(jnp.array([1]))
False
>>> jnp.isscalar(np.array([1]))
False
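Since the answer depends only on the number of dimensions, which is static under tracing, jnp.isscalar() returns a plain Python bool even for traced inputs, so it can drive ordinary Python control flow inside a jitted function. This sketch relies on that behavior; the helper `scale_rows` is hypothetical and exists only to illustrate the pattern.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scale_rows(x, factor):
    """Hypothetical helper: scale all of x, or each row by its own factor."""
    # The condition is a Python bool resolved at trace time, so no traced
    # boolean is ever used in an `if` statement.
    if jnp.isscalar(factor):
        return x * factor
    # Otherwise assume factor is 1-D with one entry per row of x.
    return x * factor[:, None]


x = jnp.ones((2, 3))
print(scale_rows(x, 2.0))                     # every entry scaled by 2.0
print(scale_rows(x, jnp.array([1.0, 10.0])))  # row 0 by 1.0, row 1 by 10.0
```

Each distinct input shape triggers a separate trace, so the branch is chosen once per shape rather than on every call; a zero-dimensional array such as jnp.array(2.0) takes the same branch as a Python float.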

Compare this to numpy.isscalar(), which returns True for scalar-typed objects, and False for all arrays, even those with zero dimensions:

>>> np.isscalar(np.int32(1))  # scalar object
True
>>> np.isscalar(np.array(1))  # zero-dimensional array
False

In JAX, as in NumPy, objects which are not array-like are not considered scalars:

>>> jnp.isscalar(None)
False
>>> jnp.isscalar([1])
False
>>> jnp.isscalar(())
False
>>> jnp.isscalar(slice(10))
False
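A shape-based check such as `np.ndim(x) == 0` is not a substitute for jnp.isscalar(): np.ndim() converts arbitrary objects with np.asarray(), so None and slice(10) come out as zero-dimensional object arrays and pass the shape test. A small comparison, using the same non-array-like inputs (the `candidates` dictionary is purely illustrative):

```python
import jax.numpy as jnp
import numpy as np

candidates = {"None": None, "[1]": [1], "()": (), "slice(10)": slice(10)}

for name, obj in candidates.items():
    # np.ndim falls back to np.asarray(obj).ndim for objects without .ndim,
    # so it reports a shape even for objects that are not array-like.
    print(name, np.ndim(obj) == 0, jnp.isscalar(obj))
```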