jax.numpy.result_type
- jax.numpy.result_type(*args)
Return the result of applying JAX promotion rules to the inputs.
JAX implementation of numpy.result_type(). JAX’s dtype promotion behavior is described in Type promotion semantics.
- Parameters:
args (Any) – one or more arrays or dtype-like objects.
- Returns:
A numpy.dtype instance representing the result of type promotion for the inputs.
- Return type:
DType
Examples
Inputs can be dtype specifiers:
>>> jnp.result_type('int32', 'float32')
dtype('float32')
>>> jnp.result_type(np.uint16, np.dtype('int32'))
dtype('int32')
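Because JAX applies its own promotion rules, the result can differ from what numpy.result_type() returns for the same inputs. The following is a minimal sketch comparing the two for the dtype pairs used above; the variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# The same dtype pairs as in the examples above.
pairs = [("int32", "float32"), (np.uint16, np.dtype("int32"))]

# Collect (NumPy result, JAX result) for each pair.
results = [(np.result_type(a, b), jnp.result_type(a, b)) for a, b in pairs]

for (a, b), (np_dtype, jax_dtype) in zip(pairs, results):
    print(a, b, "-> numpy:", np_dtype, "| jax:", jax_dtype)
```

For the int32/float32 pair, NumPy promotes to float64 while JAX stays at float32; the uint16/int32 pair gives int32 in both.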
Inputs may also be scalars or arrays:
>>> jnp.result_type(1.0, jnp.bfloat16(2))
dtype(bfloat16)
>>> jnp.result_type(jnp.arange(4), jnp.zeros(4))
dtype('float32')
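One possible use of this function is to predict the dtype of a binary operation before running it. The sketch below assumes that jnp.add and jnp.multiply follow the same promotion rules as jnp.result_type, and checks the prediction against the dtype of the computed result.

```python
import jax.numpy as jnp

x = jnp.arange(4)  # integer array
y = jnp.zeros(4)   # floating-point array

# Predicted dtype versus the dtype actually produced by x + y.
predicted = jnp.result_type(x, y)
actual = (x + y).dtype
print(predicted, actual)

# A Python scalar does not widen a bfloat16 array.
small = jnp.ones(4, dtype=jnp.bfloat16)
scalar_predicted = jnp.result_type(1.0, small)
scalar_actual = (small * 1.0).dtype
print(scalar_predicted, scalar_actual)
```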
Be aware that the result type will be canonicalized based on the state of the jax_enable_x64 configuration flag, meaning that 64-bit types may be downcast to 32-bit:
>>> jnp.result_type('float64')
dtype('float32')
For details on 64-bit values, refer to Sharp bits - double precision.
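As a rough illustration of the canonicalization described above, the following sketch toggles the jax_enable_x64 flag with jax.config.update and queries the same dtype under each setting. Changing the flag mid-program is done here only for demonstration; in practice it is usually set once at startup.

```python
import jax
import jax.numpy as jnp

# With 64-bit mode disabled, float64 is canonicalized to float32.
jax.config.update("jax_enable_x64", False)
without_x64 = jnp.result_type("float64")

# With 64-bit mode enabled, float64 is preserved.
jax.config.update("jax_enable_x64", True)
with_x64 = jnp.result_type("float64")

# Restore the 32-bit setting so later code is unaffected.
jax.config.update("jax_enable_x64", False)

print(without_x64, with_x64)
```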
