jax.numpy.isscalar
jax.numpy.isscalar(element)
Return True if the input is a scalar.
JAX implementation of numpy.isscalar(). JAX’s implementation differs from NumPy’s in that it considers zero-dimensional arrays to be scalars; see the Note below for more details.
- Parameters:
element (Any) – input object to check; any type is valid input.
- Returns:
True if element is a scalar value or an array-like object with zero dimensions, False otherwise.
- Return type:
bool
Note
JAX and NumPy differ in their representation of scalar values. NumPy has special scalar objects (e.g. np.int32(0)) which are distinct from zero-dimensional arrays (e.g. np.array(0)), and numpy.isscalar() returns True for the former and False for the latter.
JAX does not define special scalar objects, but rather represents scalars as zero-dimensional arrays. As such, jax.numpy.isscalar() returns True for both scalar objects (e.g. 0.0 or np.float32(0.0)) and array-like objects with zero dimensions (e.g. jnp.array(0.0), np.array(0.0)).
One reason for the different conventions in isscalar is to maintain JIT-invariance, that is, the property that a function's result should not change when it is JIT-compiled. Because scalar inputs are cast to zero-dimensional JAX arrays at JIT boundaries, numpy.isscalar() gives a different result under JIT:
>>> np.isscalar(1.0)
True
>>> jax.jit(np.isscalar)(1.0)
Array(False, dtype=bool)
By treating zero-dimensional arrays as scalars, jax.numpy.isscalar() avoids this issue:
>>> jnp.isscalar(1.0)
True
>>> jax.jit(jnp.isscalar)(1.0)
Array(True, dtype=bool)
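JIT-invariance matters most when the check drives Python control flow. Under `jax.jit` the check runs once at trace time and returns a Python `bool`, so it selects a branch before compilation. The function `double_if_scalar` below is a hypothetical example. With `np.isscalar` it takes a different branch once compiled; with `jnp.isscalar` it does not.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def double_if_scalar(x, isscalar):
    # Hypothetical example function. `isscalar` is called at trace time,
    # so its Python bool result picks a branch before compilation.
    if isscalar(x):
        return jnp.asarray(x) * 2
    return jnp.asarray(x)


results = {}
for name, check in [("np.isscalar", np.isscalar), ("jnp.isscalar", jnp.isscalar)]:
    f = functools.partial(double_if_scalar, isscalar=check)
    eager = float(f(1.0))            # runs without JIT
    jitted = float(jax.jit(f)(1.0))  # 1.0 becomes a zero-dimensional tracer here
    results[name] = (eager, jitted)
    print(name, results[name])
# np.isscalar  (2.0, 1.0): the branch changes under JIT
# jnp.isscalar (2.0, 2.0): the branch is the same with and without JIT
```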
Examples
In JAX, both scalars and zero-dimensional array-like objects are considered scalars:
>>> jnp.isscalar(1.0)
True
>>> jnp.isscalar(1+1j)
True
>>> jnp.isscalar(jnp.array(1))  # zero-dimensional JAX array
True
>>> jnp.isscalar(jnp.int32(1))  # JAX scalar constructor
True
>>> jnp.isscalar(np.array(1.0))  # zero-dimensional NumPy array
True
>>> jnp.isscalar(np.int32(1))  # NumPy scalar type
True
Arrays with one or more dimensions are not considered scalars:
>>> jnp.isscalar(jnp.array([1]))
False
>>> jnp.isscalar(np.array([1]))
False
Compare this to numpy.isscalar(), which returns True for scalar-typed objects, and False for all arrays, even those with zero dimensions:
>>> np.isscalar(np.int32(1))  # scalar object
True
>>> np.isscalar(np.array(1))  # zero-dimensional array
False
In JAX, as in NumPy, objects which are not array-like are not considered scalars:
>>> jnp.isscalar(None)
False
>>> jnp.isscalar([1])
False
>>> jnp.isscalar(())
False
>>> jnp.isscalar(slice(10))
False
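Because `jnp.isscalar` returns a plain Python `bool` on concrete inputs, it can be mapped over a pytree to flag scalar leaves. The `state` dictionary below is a hypothetical example of training state, not something defined by JAX.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree mixing Python scalars and JAX arrays.
state = {
    "learning_rate": 0.1,       # Python float
    "step": jnp.array(0),       # zero-dimensional JAX array
    "weights": jnp.ones((3,)),  # one-dimensional JAX array
}

# Apply the check to every leaf; the result has the same dict structure.
scalar_leaves = jax.tree_util.tree_map(jnp.isscalar, state)
print(scalar_leaves)  # {'learning_rate': True, 'step': True, 'weights': False}
```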
