jax.numpy.unpackbits
- jax.numpy.unpackbits(a, axis=None, count=None, bitorder='big')
Unpack the bits in a uint8 array.
JAX implementation of numpy.unpackbits().
- Parameters:
a (ArrayLike) – N-dimensional array of type uint8.
axis (int | None) – optional axis along which to unpack. If not specified, a will be flattened.
count (int | None) – specify the number of bits to unpack (if positive) or the number of bits to trim from the end (if negative).
bitorder (str) – "big" (default) or "little": specify whether the bit order is big-endian or little-endian.
- Returns:
a uint8 array of unpacked bits.
- Return type:
Array
See also
jax.numpy.packbits(): the inverse of unpackbits.
Examples
Unpacking bits from a scalar:
>>> import jax.numpy as jnp
>>> jnp.unpackbits(jnp.uint8(27))  # big-endian by default
Array([0, 0, 0, 1, 1, 0, 1, 1], dtype=uint8)
>>> jnp.unpackbits(jnp.uint8(27), bitorder="little")
Array([1, 1, 0, 1, 1, 0, 0, 0], dtype=uint8)
Compare this to the Python binary representation:
>>> 0b00011011
27
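The two bit orders can also be reproduced by hand with shifts and masks. The following is only an illustrative sketch of what the bitorder argument means, not a description of how the function is implemented internally; the names shifts, manual_big, and manual_little are introduced here for illustration.

```python
import jax.numpy as jnp

x = jnp.uint8(27)

# Shift amounts 7, 6, ..., 0 select bits from most to least significant.
shifts = jnp.arange(8, dtype=jnp.uint8)[::-1]
manual_big = (x >> shifts) & 1

# Reversing that sequence gives the least-significant-bit-first layout.
manual_little = manual_big[::-1]

print(manual_big)
print(manual_little)
```

Under these assumptions, manual_big should match the default output of unpackbits and manual_little should match the bitorder="little" output.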
Unpacking bits along an axis:
>>> vals = jnp.array([[154],
...                   [49]], dtype='uint8')
>>> bits = jnp.unpackbits(vals, axis=1)
>>> bits
Array([[1, 0, 0, 1, 1, 0, 1, 0],
       [0, 0, 1, 1, 0, 0, 0, 1]], dtype=uint8)
Using packbits() to invert this:
>>> jnp.packbits(bits, axis=1)
Array([[154],
       [ 49]], dtype=uint8)
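To make the effect of the axis argument more concrete, the sketch below compares the output shapes for axis=None, axis=0, and axis=1 on the same two-byte input. The variable names flat, rows, and cols are chosen here for illustration.

```python
import jax.numpy as jnp

vals = jnp.array([[154], [49]], dtype=jnp.uint8)  # shape (2, 1)

flat = jnp.unpackbits(vals)          # axis=None: the input is flattened first
rows = jnp.unpackbits(vals, axis=0)  # each byte expands to 8 entries along axis 0
cols = jnp.unpackbits(vals, axis=1)  # each byte expands to 8 entries along axis 1

print(flat.shape)  # (16,)
print(rows.shape)  # (16, 1)
print(cols.shape)  # (2, 8)
```

In every case the unpacked axis grows by a factor of 8, and the flattened result holds the same bits as cols read row by row.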
The count keyword lets unpackbits serve as an inverse of packbits in cases where not all bits are present:
>>> bits = jnp.array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1])  # 11 bits
>>> vals = jnp.packbits(bits)
>>> vals
Array([219, 96], dtype=uint8)
>>> jnp.unpackbits(vals)  # 16 zero-padded bits
Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0], dtype=uint8)
>>> jnp.unpackbits(vals, count=11)  # specify 11 output bits
Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=uint8)
>>> jnp.unpackbits(vals, count=-5)  # specify 5 bits to be trimmed
Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=uint8)
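Because the packed bytes do not record how many bits were originally present, a round trip needs the bit count to be stored separately. One possible pattern is sketched below; pack_with_length and unpack_with_length are hypothetical helpers written for this example and are not part of JAX.

```python
import jax.numpy as jnp

def pack_with_length(bits):
    """Hypothetical helper: pack a 1-D bit array and return its original length."""
    bits = jnp.asarray(bits, dtype=jnp.uint8)
    return jnp.packbits(bits), bits.shape[0]

def unpack_with_length(packed, n_bits):
    """Hypothetical helper: unpack and keep only the first n_bits bits."""
    return jnp.unpackbits(packed, count=n_bits)

bits = jnp.array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=jnp.uint8)  # 11 bits
packed, n_bits = pack_with_length(bits)
restored = unpack_with_length(packed, n_bits)

print(packed)
print(restored)
```

Passing the stored length as count drops the zero padding that packbits adds to fill the final byte.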
