jax.numpy.right_shift
jax.numpy.right_shift(x1, x2, /)
Right shift the bits of x1 by the amount specified in x2.

JAX implementation of numpy.right_shift.

- Parameters:
x1 (ArrayLike) – Input array; must have an integer dtype, signed or unsigned.
x2 (ArrayLike) – The number of bits by which to shift each element of x1 to the right; must have an integer dtype.
- Returns:
An array containing the elements of x1 shifted right by the amount specified in x2, with the same shape as the broadcast shape of x1 and x2.
- Return type:
Array
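The broadcasting rule in the return description can be checked directly. The following is a small illustrative sketch; the input values are arbitrary and chosen only so the shifted results are easy to read. A column of shape (2, 1) is shifted by a row of shift amounts with shape (3,).

```python
import jax.numpy as jnp

# x1 has shape (2, 1) and x2 has shape (3,); they broadcast to (2, 3).
x1 = jnp.array([[16], [32]])
x2 = jnp.array([1, 2, 3])

result = jnp.right_shift(x1, x2)

# The output shape equals the broadcast shape of the two inputs.
print(jnp.broadcast_shapes(x1.shape, x2.shape))  # (2, 3)
print(result.shape)                              # (2, 3)

# Row i holds x1[i, 0] shifted right by 1, 2 and 3 bits.
print(result.tolist())                           # [[8, 4, 2], [16, 8, 4]]
```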
Note
If x1.shape != x2.shape, they must be compatible for broadcasting to a shared shape; this shared shape is also the shape of the output. Right shifting a scalar x1 by a scalar x2 is equivalent to x1 // 2**x2.

Examples
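The equivalence with floor division stated in the Note can be spot-checked numerically. This is a quick sketch, not a proof; the input range and shift amounts are arbitrary. It deliberately includes negative int32 values: for signed integer dtypes JAX performs an arithmetic (sign-propagating) shift, so the result rounds toward negative infinity, exactly as `//` does.

```python
import jax.numpy as jnp

# Signed 32-bit inputs, including negative values.
x = jnp.arange(-64, 64, dtype=jnp.int32)

for shift in range(6):
    shifted = jnp.right_shift(x, shift)
    # Floor division by 2**shift rounds toward negative infinity,
    # matching the arithmetic shift used for signed dtypes.
    expected = x // 2**shift
    assert jnp.array_equal(shifted, expected), shift

# A single negative scalar: -5 >> 1 == -5 // 2 == -3.
print(int(jnp.right_shift(jnp.int32(-5), 1)))  # -3
```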
>>> import jax.numpy as jnp
>>> def print_binary(x):
...   return [bin(int(val)) for val in x]
>>> x1 = jnp.array([1, 2, 4, 8])
>>> print_binary(x1)
['0b1', '0b10', '0b100', '0b1000']
>>> x2 = 1
>>> result = jnp.right_shift(x1, x2)
>>> result
Array([0, 1, 2, 4], dtype=int32)
>>> print_binary(result)
['0b0', '0b1', '0b10', '0b100']
>>> x1 = 16
>>> print_binary([x1])
['0b10000']
>>> x2 = jnp.array([1, 2, 3, 4])
>>> result = jnp.right_shift(x1, x2)
>>> result
Array([8, 4, 2, 1], dtype=int32)
>>> print_binary(result)
['0b1000', '0b100', '0b10', '0b1']
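The examples above use the default int32 dtype. The sketch below, with uint8 and int8 values picked only for illustration, shows that an unsigned input keeps its dtype when shifted by a Python int, and contrasts how the same bit pattern 0b10000000 is shifted when read as unsigned (zeros shifted in) versus signed (sign bit copied in).

```python
import jax.numpy as jnp

# Unsigned 8-bit input; the Python int shift amount does not change the dtype.
unsigned = jnp.array([255, 128, 1], dtype=jnp.uint8)
u_shifted = jnp.right_shift(unsigned, 4)
print(u_shifted.dtype)     # uint8
print(u_shifted.tolist())  # [15, 8, 0]

# The bit pattern 0b10000000 is 128 as uint8 and -128 as int8.
# Unsigned inputs are shifted logically, signed inputs arithmetically.
logical = jnp.right_shift(jnp.array([128], dtype=jnp.uint8), 4)
arithmetic = jnp.right_shift(jnp.array([-128], dtype=jnp.int8), 4)
print(logical.tolist())     # [8]
print(arithmetic.tolist())  # [-8]
```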
