jax.numpy.frompyfunc
- jax.numpy.frompyfunc(func, /, nin, nout, *, identity=None)
Create a JAX ufunc from an arbitrary JAX-compatible scalar function.
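"JAX-compatible" suggests that func is traced by JAX rather than run on concrete Python values. It should therefore be written with jax.numpy operations, and value-dependent branching should use jnp.where instead of a Python if. The following is a minimal sketch built on that assumption. positive_diff is a hypothetical helper used only for illustration and is not part of JAX.

```python
import jax.numpy as jnp

# Hypothetical scalar function: the positive part of x - y.
# jnp.where is used instead of a Python `if`, since a traced value
# cannot be converted to a Python bool.
def positive_diff(x, y):
    d = x - y
    return jnp.where(d > 0, d, 0)

# Two scalar inputs, one scalar output.
positive_diff_ufunc = jnp.frompyfunc(positive_diff, nin=2, nout=1)

x = jnp.arange(5)
# Applied elementwise; the Python scalar 2 is broadcast against x.
result = positive_diff_ufunc(x, 2)

# Inputs broadcast like any other ufunc: shape (5, 1) with (5,) gives (5, 5).
table = positive_diff_ufunc(x[:, None], x)
```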
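- Parameters:
  - func: JAX-compatible scalar function to wrap. It should accept nin positional arguments and return nout outputs.
  - nin: integer number of scalar inputs accepted by func.
  - nout: integer number of scalar outputs returned by func.
  - identity: (optional) scalar identity element of the operation, if any; for example, 0 for addition.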
- Returns:
jax.numpy.ufunc wrapper of func.
- Return type:
jax.numpy.ufunc
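The sketch below shows how nin, nout and identity relate, assuming a JAX release that provides jnp.frompyfunc. The names square and multiply are illustrative only. identity=1 declares the identity element of multiplication, analogous to identity=0 for addition. Only the binary ufunc is used with reduce and accumulate, because a reduction needs a two-argument operation.

```python
import operator
import jax.numpy as jnp

# nin=1, nout=1: a unary ufunc built from a lambda (illustrative only).
square = jnp.frompyfunc(lambda v: v * v, nin=1, nout=1)

# nin=2, nout=1, with identity=1, the identity element of multiplication.
multiply = jnp.frompyfunc(operator.mul, nin=2, nout=1, identity=1)

a = jnp.arange(1, 5)                 # [1, 2, 3, 4]
squares = square(a)                  # elementwise squares
product = multiply.reduce(a)         # 1 * 2 * 3 * 4
running = multiply.accumulate(a)     # running products

m = jnp.arange(1, 7).reshape(2, 3)   # [[1, 2, 3], [4, 5, 6]]
row_products = multiply.reduce(m, axis=1)  # product of each row
```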
Examples
Here is an example of creating a ufunc similar to jax.numpy.add:
>>> import operator
>>> import jax.numpy as jnp
>>> add = jnp.frompyfunc(operator.add, nin=2, nout=1, identity=0)
Now all the standard jax.numpy.ufunc methods are available:
>>> x = jnp.arange(4)
>>> add(x, 10)
Array([10, 11, 12, 13], dtype=int32)
>>> add.outer(x, x)
Array([[0, 1, 2, 3],
       [1, 2, 3, 4],
       [2, 3, 4, 5],
       [3, 4, 5, 6]], dtype=int32)
>>> add.reduce(x)
Array(6, dtype=int32)
>>> add.accumulate(x)
Array([0, 1, 3, 6], dtype=int32)
>>> add.at(x, 1, 10, inplace=False)
Array([ 0, 11,  2,  3], dtype=int32)
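JAX arrays are immutable, which is why the at() call above passes inplace=False. The following small sketch reuses the same add ufunc. It shows that at() returns an updated copy and leaves x untouched. It also shows that the result matches JAX's own indexed-update syntax, x.at[1].add(10).

```python
import operator
import jax.numpy as jnp

add = jnp.frompyfunc(operator.add, nin=2, nout=1, identity=0)
x = jnp.arange(4)

# at() with inplace=False returns a new array; x itself is not modified.
y = add.at(x, 1, 10, inplace=False)

# The same update expressed with JAX's indexed-update syntax.
z = x.at[1].add(10)
```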
