
jax.numpy.fromfunction


jax.numpy.fromfunction(function, shape, *, dtype=<class 'float'>, **kwargs)

Create an array from a function applied over indices.

JAX implementation of numpy.fromfunction(). The JAX implementation differs in that it dispatches via jax.vmap(), so, unlike in NumPy, the function logically operates on scalar inputs and need not explicitly handle broadcasted inputs (see Examples below).
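The vmap-based dispatch described above can be sketched roughly as follows. This is a minimal illustration, not JAX's actual source: the name fromfunction_sketch is hypothetical, and binding kwargs with functools.partial is an assumption about what "passed statically" means. Each output axis gets its own jax.vmap that maps exactly one index argument, so by the innermost call every index argument is a scalar.

```python
import functools

import jax
import jax.numpy as jnp


def fromfunction_sketch(function, shape, *, dtype=float, **kwargs):
    """Hypothetical re-implementation of jnp.fromfunction, for illustration only."""
    # Assumption: extra keyword arguments are bound up front so that
    # vmap never tries to batch over them.
    fn = functools.partial(function, **kwargs)
    n = len(shape)
    # Wrap from the last axis outward: each vmap maps one index argument
    # (axis 0 of its 1-D index vector) and passes the others through
    # unchanged, so axis 0 of the result ends up outermost.
    for i in reversed(range(n)):
        in_axes = tuple(0 if j == i else None for j in range(n))
        fn = jax.vmap(fn, in_axes=in_axes)
    # One 1-D index vector per output axis, e.g. arange(3) and arange(6).
    return fn(*(jnp.arange(s, dtype=dtype) for s in shape))


# Should reproduce the multiplication table from the Examples section.
table = fromfunction_sketch(jnp.multiply, (3, 6), dtype=int)
print(table)
```

Because vmap handles pytree outputs natively, this shape of implementation also covers functions that return tuples or dicts.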

Parameters:
  • function (Callable[..., Array]) – a function that takes N dynamic scalars and outputs a scalar.

  • shape (Any) – a length-N tuple of integers specifying the output shape.

  • dtype (DTypeLike) – optionally specify the dtype of the inputs. Defaults to floating-point.

  • kwargs – additional keyword arguments are passed statically to function.
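A fixed, non-index parameter can also be bound ahead of time with functools.partial, so that only the N index arguments are vectorized. The helper scaled_sum and its scale parameter are hypothetical, chosen only to illustrate the pattern.

```python
import functools

import jax.numpy as jnp


def scaled_sum(x, y, *, scale):
    # x and y arrive as scalar indices; scale is an ordinary Python float
    # captured by functools.partial, so it is never mapped over.
    return scale * (x + y)


grid = jnp.fromfunction(functools.partial(scaled_sum, scale=10.0), (2, 3))
print(grid)
```

Here each entry is 10 * (i + j) for row index i and column index j.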

Returns:

An array whose shape is given by shape if function returns a scalar; more generally, a pytree of arrays whose leading dimensions are given by shape, as determined by the output of function.

Return type:

Array
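To illustrate the pytree case, a function can return a dict whose leaves have different ranks; each leaf gains the leading dimensions shape in front of its own per-point shape. The function coords_and_norm is a hypothetical example.

```python
import jax.numpy as jnp


def coords_and_norm(x, y):
    # x and y arrive as scalars; "coords" is a length-2 vector and
    # "norm" is a scalar, so the two leaves have different ranks.
    coords = jnp.stack([x, y])
    return {"coords": coords, "norm": jnp.sqrt(x**2 + y**2)}


out = jnp.fromfunction(coords_and_norm, (2, 3))
print(out["coords"].shape)  # (2, 3, 2): shape (2, 3) plus the per-point (2,)
print(out["norm"].shape)    # (2, 3): shape (2, 3) plus the per-point ()
```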


Examples

Generate a multiplication table of a given shape:

>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)
Array([[ 0,  0,  0,  0,  0,  0],
       [ 0,  1,  2,  3,  4,  5],
       [ 0,  2,  4,  6,  8, 10]], dtype=int32)

When function returns a non-scalar, the output will have leading dimensions given by shape:

>>> def f(x):
...   return (x + 1) * jnp.arange(3)
>>> jnp.fromfunction(f, shape=(2,))
Array([[0., 1., 2.],
       [0., 2., 4.]], dtype=float32)

function may return multiple results, in which case each is mapped independently:

>>> def f(x, y):
...   return x + y, x * y
>>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5))
>>> print(x_plus_y)
[[0. 1. 2. 3. 4.]
 [1. 2. 3. 4. 5.]
 [2. 3. 4. 5. 6.]]
>>> print(x_times_y)
[[0. 0. 0. 0. 0.]
 [0. 1. 2. 3. 4.]
 [0. 2. 4. 6. 8.]]

The JAX implementation differs slightly from NumPy’s implementation. In numpy.fromfunction(), the function is expected to operate explicitly element-wise on the full grid of input values:

>>> def f(x, y):
...   print(f"{x.shape = }\n{y.shape = }")
...   return x + y
...
>>> np.fromfunction(f, (2, 3))
x.shape = (2, 3)
y.shape = (2, 3)
array([[0., 1., 2.],
       [1., 2., 3.]])

In jax.numpy.fromfunction(), the function is vectorized via jax.vmap(), and so is expected to operate on scalar values:

>>> jnp.fromfunction(f, (2, 3))
x.shape = ()
y.shape = ()
Array([[0., 1., 2.],
       [1., 2., 3.]], dtype=float32)
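Because the JAX version hands the function scalars, the function can build small per-point vectors from its inputs. The sketch below, using a hypothetical weights vector and helper weighted_index, compares that scalar style with an explicit full-grid computation built from jnp.indices, which mirrors NumPy's calling convention. The scalar-style helper would likely not broadcast correctly if x and y were full (2, 3) grids, since jnp.dot would then contract the wrong axis.

```python
import jax.numpy as jnp

# Hypothetical per-axis weights: entry [i, j] becomes 1.0 * i + 2.0 * j.
weights = jnp.array([1.0, 2.0])


def weighted_index(x, y):
    # Under jnp.fromfunction, x and y are scalars, so stacking them
    # yields a length-2 vector that can be dotted with `weights`.
    return jnp.dot(jnp.stack([x, y]), weights)


scalar_style = jnp.fromfunction(weighted_index, (2, 3))

# NumPy-style equivalent: materialize the full index grids, shape (2, 2, 3),
# and contract `weights` against the leading (which-index) axis.
grid = jnp.indices((2, 3), dtype=jnp.float32)
grid_style = jnp.tensordot(weights, grid, axes=1)  # shape (2, 3)

print(scalar_style)
print(grid_style)
```

The scalar style keeps the per-point logic simple, while the jnp.indices version makes the broadcasting explicit at the cost of materializing the index grids.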
