
jax.numpy.frompyfunc

jax.numpy.frompyfunc(func, /, nin, nout, *, identity=None)

Create a JAX ufunc from an arbitrary JAX-compatible scalar function.
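As a minimal sketch of what a "JAX-compatible scalar function" can look like, the hypothetical helper scalar_hypot below is written with jax.numpy operations on scalar arguments. Wrapping it with nin=2 and nout=1 should give a ufunc that applies it elementwise, with broadcasting, to array inputs. The function name and the input values are illustrative assumptions, not part of the JAX API.

```python
import jax.numpy as jnp


# Hypothetical scalar function: it receives one scalar per input and
# returns one scalar, using only JAX-traceable operations.
def scalar_hypot(a, b):
    return jnp.sqrt(a * a + b * b)


# Wrap it as a ufunc with two scalar inputs and one scalar output.
hypot = jnp.frompyfunc(scalar_hypot, nin=2, nout=1)

a = jnp.array([3.0, 5.0, 8.0])
b = jnp.array([4.0, 12.0, 15.0])

# Calling the ufunc applies scalar_hypot to each pair of elements.
print(hypot(a, b))

# Inputs broadcast against each other, so a Python scalar also works.
print(hypot(a, 0.0))
```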

Parameters:
  • func (Callable[..., Any]) – a callable that takes nin scalar arguments and returns nout outputs.

  • nin (int) – integer specifying the number of scalar inputs.

  • nout (int) – integer specifying the number of scalar outputs.

  • identity (Any) – (optional) a scalar specifying the identity element of the operation, if any (for example, 0 for addition).
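The identity, when one exists, is a value e with func(e, x) == x for every x: 0 for addition, 1 for multiplication. The sketch below uses two illustrative ufuncs, multiply and maximum (names chosen here, not taken from the original). For maximum there is no single identity value that suits every dtype (NumPy's np.maximum.identity is None as well), so identity is left at its default, and reduce is still expected to work on a non-empty array.

```python
import operator

import jax.numpy as jnp

# Multiplication has identity 1, since operator.mul(1, x) == x.
multiply = jnp.frompyfunc(operator.mul, nin=2, nout=1, identity=1)

# Hypothetical elementwise maximum; no identity is passed, so it stays None.
maximum = jnp.frompyfunc(lambda a, b: jnp.where(a > b, a, b), nin=2, nout=1)

x = jnp.array([3, 1, 4, 1, 5])

# reduce folds the binary function over the elements of x.
print(multiply.reduce(x))  # product of all elements
print(maximum.reduce(x))   # largest element
```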

Returns:

A jax.numpy.ufunc that wraps func.

Return type:

jax.numpy.ufunc
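Because the returned object is a jax.numpy.ufunc, it carries read-only attributes describing the wrapped function. The sketch below wraps a hypothetical one-input lambda (nin=1, nout=1) and inspects the result; the name square is an illustrative assumption.

```python
import jax.numpy as jnp

# Hypothetical unary scalar function wrapped as a ufunc.
square = jnp.frompyfunc(lambda v: v * v, nin=1, nout=1)

# The result is an instance of jax.numpy.ufunc...
print(isinstance(square, jnp.ufunc))

# ...that records its input and output counts and its identity (None here).
print(square.nin, square.nout, square.identity)

# Calling it applies the lambda to each element of the input array.
print(square(jnp.arange(4)))
```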

Examples

Here is an example of creating a ufunc similar to jax.numpy.add:

>>> import operator
>>> import jax.numpy as jnp
>>> add = jnp.frompyfunc(operator.add, nin=2, nout=1, identity=0)

Now all the standard jax.numpy.ufunc methods are available:

>>> x = jnp.arange(4)
>>> add(x, 10)
Array([10, 11, 12, 13], dtype=int32)
>>> add.outer(x, x)
Array([[0, 1, 2, 3],
       [1, 2, 3, 4],
       [2, 3, 4, 5],
       [3, 4, 5, 6]], dtype=int32)
>>> add.reduce(x)
Array(6, dtype=int32)
>>> add.accumulate(x)
Array([0, 1, 3, 6], dtype=int32)
>>> add.at(x, 1, 10, inplace=False)
Array([ 0, 11,  2,  3], dtype=int32)
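To connect these methods with more familiar operations, the sketch below (an illustration assembled for this page, not code from the original) checks each method of an addition ufunc against an equivalent jax.numpy expression and then calls one method inside jax.jit. The helper name running_total is hypothetical.

```python
import operator

import jax
import jax.numpy as jnp

add = jnp.frompyfunc(operator.add, nin=2, nout=1, identity=0)
x = jnp.arange(4)

# outer applies the function to every pair, like broadcasting x[:, None] + x[None, :].
print(jnp.array_equal(add.outer(x, x), x[:, None] + x[None, :]))
# reduce folds along an axis; for addition this matches jnp.sum.
print(jnp.array_equal(add.reduce(x), jnp.sum(x)))
# accumulate keeps the running results; for addition this matches jnp.cumsum.
print(jnp.array_equal(add.accumulate(x), jnp.cumsum(x)))
# at with inplace=False returns an updated copy, like x.at[1].add(10).
print(jnp.array_equal(add.at(x, 1, 10, inplace=False), x.at[1].add(10)))


# The ufunc methods are traceable, so they can be used inside jax.jit.
@jax.jit
def running_total(v):
    return add.accumulate(v)


print(running_total(x))
```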