jax.numpy.ufunc
class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)
Universal functions which operate element-by-element on arrays.
JAX implementation of numpy.ufunc.
This is a class for JAX-backed implementations of NumPy's ufunc APIs. Most users will never need to instantiate ufunc, but rather will use the pre-defined ufuncs in jax.numpy. For constructing your own ufuncs, see jax.numpy.frompyfunc().
Examples
Universal functions are functions that apply element-wise to broadcasted arrays, but they also come with a number of extra attributes and methods.
As an example, consider the function jax.numpy.add. The object acts as a function that applies addition to broadcasted arrays in an element-wise manner:
>>> x = jnp.array([1, 2, 3, 4, 5])
>>> jnp.add(x, 1)
Array([2, 3, 4, 5, 6], dtype=int32)
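The doctest above broadcasts a scalar against a vector. As a further sketch (the shapes below are chosen purely for illustration), the same call combines a column and a row into a 2-D grid, following the usual NumPy broadcasting rules:

```python
import jax.numpy as jnp

col = jnp.arange(3).reshape(3, 1)  # shape (3, 1)
row = jnp.arange(4)                # shape (4,)

# jnp.add broadcasts (3, 1) with (4,) to (3, 4), then adds element-wise.
grid = jnp.add(col, row)
print(grid.shape)  # (3, 4)
print(grid)
```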
Each ufunc object includes a number of attributes that describe its behavior:
>>> jnp.add.nin  # number of inputs
2
>>> jnp.add.nout  # number of outputs
1
>>> jnp.add.identity  # identity value, or None if no identity exists
0
Binary ufuncs like jax.numpy.add include a number of methods to apply the function to arrays in different ways. The outer() method applies the function to the pairwise outer product of the input array values:
>>> jnp.add.outer(x, x)
Array([[ 2,  3,  4,  5,  6],
       [ 3,  4,  5,  6,  7],
       [ 4,  5,  6,  7,  8],
       [ 5,  6,  7,  8,  9],
       [ 6,  7,  8,  9, 10]], dtype=int32)
The ufunc.reduce() method performs a reduction over the array. For example, jnp.add.reduce() is equivalent to jnp.sum:
>>> jnp.add.reduce(x)
Array(15, dtype=int32)
The ufunc.accumulate() method performs a cumulative reduction over the array. For example, jnp.add.accumulate() is equivalent to jax.numpy.cumulative_sum():
>>> jnp.add.accumulate(x)
Array([ 1,  3,  6, 10, 15], dtype=int32)
The ufunc.at() method applies the function at particular indices in the array; for jnp.add the computation is similar to jax.lax.scatter_add():
>>> jnp.add.at(x, 0, 100, inplace=False)
Array([101,   2,   3,   4,   5], dtype=int32)
And the ufunc.reduceat() method performs a number of reduce operations between specified indices of an array; for jnp.add the operation is similar to jax.ops.segment_sum():
>>> jnp.add.reduceat(x, jnp.array([0, 2]))
Array([ 3, 12], dtype=int32)
In this case, the first element is x[0:2].sum(), and the second element is x[2:].sum().
Parameters:
__init__(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)
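Rather than calling this constructor directly, the page points to jax.numpy.frompyfunc() for building custom ufuncs. A minimal sketch, where the function larger is a hypothetical element-wise maximum written only for illustration; since no identity is given, it assumes reduce() and accumulate() fall back to JAX's generic implementations and start from the first element:

```python
import jax.numpy as jnp

def larger(a, b):
    # Hypothetical scalar kernel: return the larger of two values.
    return jnp.where(a >= b, a, b)

# Wrap the binary function (2 inputs, 1 output) as a ufunc; no identity is set.
my_max = jnp.frompyfunc(larger, nin=2, nout=1)

v = jnp.array([3, 1, 4, 1, 5])
print(my_max(v, 2))           # element-wise with broadcasting
print(my_max.reduce(v))       # overall maximum
print(my_max.accumulate(v))   # running maximum
```

The resulting object is an instance of jnp.ufunc, so it exposes the same attributes and methods described above.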
Methods
__init__(func, /, nin, nout, *[, name, ...])
accumulate(a[, axis, dtype, out])  Accumulate operation derived from binary ufunc.
at(a, indices[, b, inplace])  Update elements of an array via the specified unary or binary ufunc.
outer(A, B, /)  Apply the function to all pairs of values in A and B.
reduce(a[, axis, dtype, out, keepdims, ...])  Reduction operation derived from a binary function.
reduceat(a, indices[, axis, dtype, out])  Reduce an array between specified indices via a binary ufunc.
Attributes
identity
nargs
nin
nout
