jax.Array.view
- abstract Array.view(dtype=None, type=None)
Return a bitwise copy of the array, viewed as a new dtype.
This is a fuller-featured wrapper around jax.lax.bitcast_convert_type(). If the source and target dtypes have the same bitwidth, the result has the same shape as the input array. If the target bitwidth differs from the source bitwidth, the size of the last axis of the result is scaled by the ratio of the source bitwidth to the target bitwidth: viewing int16 as int8 doubles the last axis, and viewing int8 as int16 halves it.
>>> jnp.zeros([1, 2, 3], dtype=jnp.int16).view(jnp.int8).shape
(1, 2, 6)
>>> jnp.zeros([1, 2, 4], dtype=jnp.int8).view(jnp.int16).shape
(1, 2, 2)
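The sketch below illustrates two points from the paragraph above: `view` reinterprets bits rather than converting values, and it handles a change of bitwidth differently from calling `jax.lax.bitcast_convert_type()` directly, which adds or removes a trailing axis instead of resizing the last one. The variable names are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

# Bitwise reinterpretation: 1.0 in IEEE 754 single precision is 0x3F800000.
ones = jnp.array([1.0, 1.0], dtype=jnp.float32)
bits = ones.view(jnp.uint32)
print(hex(int(bits[0])))  # 0x3f800000

# A value conversion, by contrast, keeps the numeric value.
print(ones.astype(jnp.uint32).tolist())  # [1, 1]

# Narrowing int16 -> int8: bitcast_convert_type appends a trailing axis of
# size 2, while view folds that factor into the existing last axis.
x = jnp.zeros((2, 3), dtype=jnp.int16)
print(lax.bitcast_convert_type(x, jnp.int8).shape)  # (2, 3, 2)
print(x.view(jnp.int8).shape)                       # (2, 6)

# Widening int8 -> int16: bitcast_convert_type needs a trailing axis of
# size 2 (which it removes); view only needs the last axis to be even.
y = jnp.zeros((2, 4), dtype=jnp.int8)
print(lax.bitcast_convert_type(y.reshape(2, 2, 2), jnp.int16).shape)  # (2, 2)
print(y.view(jnp.int16).shape)                                          # (2, 2)
```

In practice this means `view` keeps the array rank unchanged, which is usually what NumPy-style code expects.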
Conversions involving booleans are not well-defined in all situations. With regard to the shape of the result as explained above, booleans are treated as having a bitwidth of 8. However, when converting to a boolean array, the input should contain only 0 or 1 bytes. Otherwise, results may be unpredictable or may change depending on how the result is used.
This conversion is guaranteed and safe:
>>> jnp.array([1, 0, 1], dtype=jnp.int8).view(jnp.bool_)
Array([ True, False,  True], dtype=bool)
However, there are no guarantees about the results of any expression involving a view such as this:
jnp.array([1, 2, 3], dtype=jnp.int8).view(jnp.bool_)
In particular, the results may change between JAX releases and across platforms. To safely convert such an array to a boolean array, compare it with 0:
>>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0
Array([ True,  True, False], dtype=bool)
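A minimal sketch of the boolean rules above: booleans count as 8 bits when the output shape is computed, and arrays that may hold bytes other than 0 and 1 are converted with a comparison rather than with `view`. The names `flags`, `raw`, and `as_bool` are illustrative.

```python
import jax.numpy as jnp

# For the shape rule, booleans count as 8 bits wide.
flags = jnp.array([True, False, True, True])
print(flags.view(jnp.int8).shape)   # (4,)  same bitwidth, same shape
print(flags.view(jnp.int16).shape)  # (2,)  two 8-bit booleans per int16

# Viewing a boolean array as int8 gives 1 for True and 0 for False.
print(flags.view(jnp.int8).tolist())  # [1, 0, 1, 1]

# Bytes other than 0 and 1 have no guaranteed boolean view, so convert
# by comparing with zero instead of calling .view(jnp.bool_).
raw = jnp.array([1, 2, 0], dtype=jnp.int8)
as_bool = raw != 0
print(as_bool.tolist())  # [True, True, False]
```

The comparison form works for any integer input, so it is the safer default whenever the byte values are not known in advance.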
- Parameters:
dtype (DTypeLike | None) – An optional output dtype. If not specified, the output dtype is the same as the input dtype.
type (None) – Not implemented; accepted for NumPy compatibility.
self (Array)
- Returns:
The array, viewed as the new dtype. Unlike NumPy, the array may or may not be a copy of the input array.
- Return type:
Array
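Whether the result is a copy has little practical consequence in JAX, because JAX arrays are immutable. The sketch below contrasts this with NumPy, where `ndarray.view` shares memory with its source; the JAX side uses the `.at[...].set(...)` functional update, which returns a new array. Variable names are illustrative.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: view() shares memory, so writing through the view changes the source.
np_x = np.zeros(4, dtype=np.int8)
np_v = np_x.view(np.uint8)
np_v[0] = 255              # byte 0xFF, read back as int8 -1
print(np_x.tolist())       # [-1, 0, 0, 0]

# JAX: arrays are immutable, so there is nothing to write through.
# An "update" builds a new array; the source and its view are unchanged.
jx = jnp.zeros(4, dtype=jnp.int8)
jv = jx.view(jnp.uint8)
jv2 = jv.at[0].set(255)
print(jx.tolist())         # [0, 0, 0, 0]
print(jv2.tolist())        # [255, 0, 0, 0]
```

Code ported from NumPy that relies on mutating a view to change the original array therefore has to be rewritten to carry the updated array forward explicitly.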