jax.numpy.ufunc

class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)

Universal functions which operate element-by-element on arrays.

JAX implementation of numpy.ufunc.

This is a class for JAX-backed implementations of NumPy’s ufunc APIs. Most users will never need to instantiate ufunc, but rather will use the pre-defined ufuncs in jax.numpy.

For constructing your own ufuncs, see jax.numpy.frompyfunc().
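As a rough illustration of that route, the sketch below wraps a scalar function with jax.numpy.frompyfunc() to obtain a ufunc object. The names scalar_multiply, multiply, and x are made up for this example, and identity=1 is an assumption chosen because 1 is the identity of multiplication.

```python
import jax.numpy as jnp

# Hypothetical scalar function; any JAX-compatible binary scalar
# function could take its place.
def scalar_multiply(a, b):
    return a * b

# Wrap it as a ufunc with two inputs and one output.
# identity=1 is our choice here: 1 is the multiplicative identity.
multiply = jnp.frompyfunc(scalar_multiply, nin=2, nout=1, identity=1)

x = jnp.arange(1, 6)  # values 1..5

elementwise = multiply(x, 10)     # broadcasts like a pre-defined ufunc
product = multiply.reduce(x)      # 1 * 2 * 3 * 4 * 5
running = multiply.accumulate(x)  # cumulative products

print(elementwise, product, running)
```

The resulting object is an instance of jax.numpy.ufunc, so it carries the same attributes and methods as the pre-defined ufuncs.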

Examples

Universal functions are functions that apply element-wise to broadcasted arrays, but they also come with a number of extra attributes and methods.

As an example, consider the function jax.numpy.add. The object acts as a function that applies addition to broadcasted arrays in an element-wise manner:

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 4, 5])
>>> jnp.add(x, 1)
Array([2, 3, 4, 5, 6], dtype=int32)

Each ufunc object includes a number of attributes that describe its behavior:

>>> jnp.add.nin  # number of inputs
2
>>> jnp.add.nout  # number of outputs
1
>>> jnp.add.identity  # identity value, or None if no identity exists
0

Binary ufuncs like jax.numpy.add include a number of methods to apply the function to arrays in different manners.

The outer() method applies the function to the pair-wise outer-product of the input array values:

>>> jnp.add.outer(x, x)
Array([[ 2,  3,  4,  5,  6],
       [ 3,  4,  5,  6,  7],
       [ 4,  5,  6,  7,  8],
       [ 5,  6,  7,  8,  9],
       [ 6,  7,  8,  9, 10]], dtype=int32)

The ufunc.reduce() method performs a reduction over the array. For example, jnp.add.reduce() is equivalent to jnp.sum:

>>> jnp.add.reduce(x)
Array(15, dtype=int32)

The ufunc.accumulate() method performs a cumulative reduction over the array. For example, jnp.add.accumulate() is equivalent to jax.numpy.cumulative_sum():

>>> jnp.add.accumulate(x)
Array([ 1,  3,  6, 10, 15], dtype=int32)

The ufunc.at() method applies the function at particular indices in the array; for jnp.add the computation is similar to jax.lax.scatter_add():

>>> jnp.add.at(x, 0, 100, inplace=False)
Array([101,   2,   3,   4,   5], dtype=int32)

And the ufunc.reduceat() method performs a number of reduce operations between specified indices of an array; for jnp.add the operation is similar to jax.ops.segment_sum():

>>> jnp.add.reduceat(x, jnp.array([0, 2]))
Array([ 3, 12], dtype=int32)

In this case, the first element is x[0:2].sum(), and the second element is x[2:].sum().

Parameters:
  • func (Callable[..., Any])
  • nin (int)
  • nout (int)
  • name (str | None)
  • nargs (int | None)
  • identity (Any)
  • call (Callable[..., Any] | None)
  • reduce (Callable[..., Any] | None)
  • accumulate (Callable[..., Any] | None)
  • at (Callable[..., Any] | None)
  • reduceat (Callable[..., Any] | None)

__init__(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)
Parameters:
  • func (Callable[..., Any])
  • nin (int)
  • nout (int)
  • name (str | None)
  • nargs (int | None)
  • identity (Any)
  • call (Callable[..., Any] | None)
  • reduce (Callable[..., Any] | None)
  • accumulate (Callable[..., Any] | None)
  • at (Callable[..., Any] | None)
  • reduceat (Callable[..., Any] | None)

Methods

__init__(func, /, nin, nout, *[, name, ...])

accumulate(a[, axis, dtype, out])
    Accumulate operation derived from binary ufunc.

at(a, indices[, b, inplace])
    Update elements of an array via the specified unary or binary ufunc.

outer(A, B, /)
    Apply the function to all pairs of values in A and B.

reduce(a[, axis, dtype, out, keepdims, ...])
    Reduction operation derived from a binary function.

reduceat(a, indices[, axis, dtype, out])
    Reduce an array between specified indices via a binary ufunc.

Attributes

identity

nargs

nin

nout

