jax.dtypes module
bfloat16 floating-point values.
Convert from a dtype to a canonical dtype based on config.x64_enabled.
DType class corresponding to the scalar type and dtype of the same name.
Number of bits per element for the dtype.
Return True if the first argument is a typecode lower than or equal to the second in the type hierarchy.
Scalar class for PRNG key dtypes.
Convenience function to apply JAX argument dtype promotion.
Return the scalar type associated with a JAX value.
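Canonicalization and promotion are easiest to see side by side. The sketch below uses `jax.dtypes.canonicalize_dtype` and `jax.dtypes.result_type` and assumes JAX's default configuration, where 64-bit mode is disabled. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp
from jax import dtypes

# canonicalize_dtype maps a requested dtype to the dtype JAX will actually use.
# With 64-bit mode disabled (the default), float64 canonicalizes to float32.
canon = dtypes.canonicalize_dtype(np.float64)

# result_type applies JAX's own promotion lattice, which differs from NumPy's.
mixed = dtypes.result_type(np.int32, np.float32)        # int32 with float32
half = dtypes.result_type(dtypes.bfloat16, np.float16)  # bfloat16 with float16
weak = dtypes.result_type(np.int8, 1)                   # Python int is weakly typed

print(canon, mixed, half, weak)
```

Under the default configuration this should report float32 for the first three results. The weakly typed Python `1` does not widen `int8`, so the last result stays int8.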
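The type-hierarchy helpers can be exercised together. This sketch shows `jax.dtypes.issubdtype` applied to ordinary dtypes and to PRNG key dtypes via `jax.dtypes.prng_key`, and shows `jax.dtypes.scalar_type_of` mapping arrays to Python scalar types. It assumes a JAX version recent enough to provide typed keys through `jax.random.key`.

```python
import jax
import jax.numpy as jnp
from jax import dtypes

# JAX's hierarchy places bfloat16 under the floating category.
bf16_is_float = dtypes.issubdtype(dtypes.bfloat16, jnp.floating)
int_is_float = dtypes.issubdtype(jnp.int32, jnp.floating)

# Typed keys from jax.random.key have an extended dtype that is a subtype of
# dtypes.prng_key; legacy uint32 keys from jax.random.PRNGKey are not.
typed_is_key = dtypes.issubdtype(jax.random.key(0).dtype, dtypes.prng_key)
legacy_is_key = dtypes.issubdtype(jax.random.PRNGKey(0).dtype, dtypes.prng_key)

# scalar_type_of returns the Python scalar type for each array's dtype category.
kinds = [
    dtypes.scalar_type_of(x)
    for x in (jnp.zeros(3, jnp.int32), jnp.ones(2), jnp.array(True))
]
print(bf16_is_float, int_is_float, typed_is_key, legacy_is_key, kinds)
```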
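The two special dtypes, `jax.dtypes.bfloat16` and `jax.dtypes.float0`, are small enough to inspect directly. The sketch below reads bfloat16's precision through `jnp.finfo`. It also obtains a float0 value the way it usually arises in practice: as the gradient with respect to an integer argument when `jax.grad` is called with `allow_int=True`. The function `scale` is a hypothetical helper used only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import dtypes

# bfloat16 keeps float32's 8 exponent bits but only 7 explicit mantissa bits.
bf16_info = jnp.finfo(dtypes.bfloat16)
x = jnp.array([1.0, 2.5], dtype=dtypes.bfloat16)

# Hypothetical helper: a float weight multiplied by an integer count.
def scale(w, n):
    return w * n

# With allow_int=True, the gradient for the integer argument has dtype float0,
# a zero-size dtype representing a trivial tangent space.
grads = jax.grad(scale, argnums=(0, 1), allow_int=True)(2.0, 3)
print(bf16_info.bits, bf16_info.nmant, x.dtype, grads[0], grads[1].dtype)
```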
