jax.numpy.issubdtype
jax.numpy.issubdtype(arg1, arg2)
Return True if arg1 is equal or lower than arg2 in the type hierarchy.
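The sketch below makes "equal or lower" concrete. It checks one concrete dtype against itself, against several abstract types above it, and against types on other branches of the hierarchy. It assumes only a standard JAX installation and uses `jnp.issubdtype` together with the abstract scalar types available in `jax.numpy` (`jnp.floating`, `jnp.inexact`, `jnp.number`, `jnp.integer`).

```python
import jax.numpy as jnp

# "Equal": a concrete dtype is a sub-dtype of itself.
print(jnp.issubdtype(jnp.float32, jnp.float32))  # True

# "Lower": it is also a sub-dtype of every abstract type above it.
print(jnp.issubdtype(jnp.float32, jnp.floating))  # True
print(jnp.issubdtype(jnp.float32, jnp.inexact))   # True
print(jnp.issubdtype(jnp.float32, jnp.number))    # True

# A different branch of the hierarchy is not related.
print(jnp.issubdtype(jnp.float32, jnp.integer))   # False

# Two distinct concrete dtypes are not sub-dtypes of each other.
print(jnp.issubdtype("float16", "float32"))       # False
```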
JAX implementation of numpy.issubdtype(). The main difference in JAX’s implementation is that it correctly places dtype extensions such as bfloat16 in the type hierarchy; for example, bfloat16 is treated as a subtype of jnp.floating.
- Parameters:
  - arg1 (DTypeLike) – dtype-like object. In typical usage, this will be a dtype specifier, such as "float32" (i.e. a string), np.dtype('int32') (i.e. an instance of numpy.dtype), jnp.complex64 (i.e. a JAX scalar constructor), or np.uint8 (i.e. a NumPy scalar type).
  - arg2 (DTypeLike) – dtype-like object. In typical usage, this will be a generic scalar type, such as jnp.integer, jnp.floating, or jnp.complexfloating.
- Returns:
True if arg1 represents a dtype that is equal or lower in the type hierarchy than arg2.
- Return type:
bool
See also
jax.numpy.isdtype(): similar function aligning with the array API standard.
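The two functions express the same kind of question differently: `jnp.issubdtype` compares against a scalar-type class in the hierarchy, while `jnp.isdtype` compares against an array API "kind" string (or a tuple of kinds). The following is a minimal side-by-side sketch, assuming a JAX version recent enough to provide `jnp.isdtype`.

```python
import jax.numpy as jnp

# issubdtype: the second argument is an abstract scalar type.
print(jnp.issubdtype(jnp.int32, jnp.signedinteger))    # True
print(jnp.issubdtype(jnp.int32, jnp.unsignedinteger))  # False

# isdtype: the second argument is an array API kind string.
print(jnp.isdtype(jnp.int32, "signed integer"))        # True
print(jnp.isdtype(jnp.int32, "integral"))              # True
print(jnp.isdtype(jnp.int32, "unsigned integer"))      # False

# isdtype also accepts a tuple of kinds and returns True if any kind matches.
print(jnp.isdtype(jnp.float32, ("integral", "real floating")))  # True
```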
Examples
>>> jnp.issubdtype('uint32', jnp.unsignedinteger)
True
>>> jnp.issubdtype(np.int32, jnp.integer)
True
>>> jnp.issubdtype(jnp.bfloat16, jnp.floating)
True
>>> jnp.issubdtype(np.dtype('complex64'), jnp.complexfloating)
True
>>> jnp.issubdtype('complex64', jnp.integer)
False
Be aware that while this function is very similar to numpy.issubdtype(), the two give different results for JAX’s custom floating-point types:
>>> np.issubdtype('bfloat16', np.floating)
False
>>> jnp.issubdtype('bfloat16', jnp.floating)
True
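This difference matters in dtype-dispatching code. The sketch below uses a hypothetical helper, `default_tolerance` (not part of JAX), that picks an equality tolerance from a dtype. Because it tests with `jnp.issubdtype`, bfloat16 takes the floating-point branch and gets its machine epsilon from `jnp.finfo`.

```python
import numpy as np
import jax.numpy as jnp


def default_tolerance(dtype) -> float:
    """Hypothetical helper: choose an equality tolerance based on a dtype."""
    # Normalize strings, NumPy scalar types, and JAX scalar constructors.
    dtype = np.dtype(dtype)
    if jnp.issubdtype(dtype, jnp.floating):
        # jnp.issubdtype places bfloat16 under jnp.floating, so it lands here.
        return float(jnp.finfo(dtype).eps)
    if jnp.issubdtype(dtype, jnp.integer):
        # Integers are compared exactly, so no tolerance is needed.
        return 0.0
    raise TypeError(f"unsupported dtype: {dtype}")


print(default_tolerance(jnp.bfloat16))  # 0.0078125 (2**-7)
print(default_tolerance("float32"))     # 1.1920928955078125e-07 (2**-23)
print(default_tolerance(np.int8))       # 0.0
```

Had the helper used np.issubdtype instead, bfloat16 would match neither branch, as the comparison above shows, and the call would raise TypeError.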
