jax.numpy.fromfunction
jax.numpy.fromfunction(function, shape, *, dtype=<class 'float'>, **kwargs)
Create an array from a function applied over indices.
JAX implementation of numpy.fromfunction(). The JAX implementation differs in that it dispatches via jax.vmap(), so, unlike in NumPy, the function logically operates on scalar inputs and need not explicitly handle broadcasted inputs (see Examples below).
- Parameters:
function (Callable[..., Array]) – a function that takes N dynamic scalars and returns a scalar (or, more generally, a pytree of arrays; see Returns).
shape (Any) – a length-N tuple of integers specifying the output shape.
dtype (DTypeLike) – optionally specify the dtype of the inputs. Defaults to floating-point.
kwargs – additional keyword arguments are passed statically to function.
- Returns:
An array of shape shape if function returns a scalar or, in general, a pytree of arrays with leading dimensions shape, as determined by the output of function.
- Return type:
Array
See also
jax.vmap(): the core transformation that the fromfunction() API is built on.
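To illustrate how a vmap-based fromfunction can work, the sketch below builds one from nested jax.vmap() calls: each call maps one positional argument over a vector of indices and holds the others fixed. The name fromfunction_sketch and the helper scaled_sum are hypothetical, and binding kwargs with functools.partial is an assumption about what "passed statically" means; this is not the library's actual source.

```python
import functools

import jax
import jax.numpy as jnp


def fromfunction_sketch(function, shape, *, dtype=float, **kwargs):
  """Hypothetical vmap-based stand-in for jnp.fromfunction (illustration only)."""
  # Bind keyword arguments first so they are treated as constants;
  # jax.vmap would otherwise map keyword arguments along axis 0.
  fn = functools.partial(function, **kwargs)
  ndim = len(shape)
  # Wrap from the last axis to the first: the outermost vmap maps
  # argument 0, so argument 0 indexes the leading output axis.
  for axis in reversed(range(ndim)):
    in_axes = tuple(0 if i == axis else None for i in range(ndim))
    fn = jax.vmap(fn, in_axes=in_axes)
  # One 1-D vector of index values per output dimension.
  indices = [jnp.arange(n, dtype=dtype) for n in shape]
  return fn(*indices)


# Usage: the sketch agrees with the library on the multiplication table.
table = fromfunction_sketch(jnp.multiply, (3, 6), dtype=int)
print(jnp.array_equal(table, jnp.fromfunction(jnp.multiply, (3, 6), dtype=int)))  # True


def scaled_sum(x, y, *, scale):
  # x and y are scalar indices; scale is a fixed keyword argument.
  return scale * (x + y)


grid = fromfunction_sketch(scaled_sum, (2, 3), scale=10.0)
print(grid.shape)  # (2, 3)
```

Each vmap adds one leading batch axis to the output, which is why the loop runs from the last axis to the first.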
Examples
Generate a multiplication table of a given shape:
>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)
Array([[ 0,  0,  0,  0,  0,  0],
       [ 0,  1,  2,  3,  4,  5],
       [ 0,  2,  4,  6,  8, 10]], dtype=int32)
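Because each output element is the function evaluated at its own index pair, the same table can be written as an outer product of two index vectors using broadcasting. The check below compares the two forms, and also shows that leaving dtype at its default produces floating-point indices (float32 unless 64-bit mode is enabled).

```python
import jax.numpy as jnp

# fromfunction evaluates multiply(i, j) at every index pair (i, j).
table = jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)

# Equivalent broadcasting form: a (3, 1) column times a (1, 6) row.
rows = jnp.arange(3)[:, None]
cols = jnp.arange(6)[None, :]
outer = rows * cols

print(jnp.array_equal(table, outer))  # True

# Without dtype=int the indices, and hence the result, are floating point.
float_table = jnp.fromfunction(jnp.multiply, shape=(3, 6))
print(float_table.dtype)  # float32 with the default 32-bit configuration
```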
When function returns a non-scalar, the output will have leading dimensions given by shape:
>>> def f(x):
...   return (x + 1) * jnp.arange(3)
>>> jnp.fromfunction(f, shape=(2,))
Array([[0., 1., 2.],
       [0., 2., 4.]], dtype=float32)
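The same rule carries over to a multi-dimensional shape: the output shape is shape followed by the shape of each per-index result. In this sketch the hypothetical function features returns a length-4 vector for every index pair, so a (2, 3) grid yields a (2, 3, 4) array whose entry [i, j] is that vector.

```python
import jax.numpy as jnp


def features(x, y):
  # Hypothetical per-index function: x and y are scalars here,
  # so the result is a length-4 vector.
  return jnp.stack([x, y, x + y, x * y])


out = jnp.fromfunction(features, shape=(2, 3))
print(out.shape)  # (2, 3, 4)
print(out[1, 2])  # [1. 2. 3. 2.]
```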
function may return multiple results, in which case each is mapped independently:
>>> def f(x, y):
...   return x + y, x * y
>>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5))
>>> print(x_plus_y)
[[0. 1. 2. 3. 4.]
 [1. 2. 3. 4. 5.]
 [2. 3. 4. 5. 6.]]
>>> print(x_times_y)
[[0. 0. 0. 0. 0.]
 [0. 1. 2. 3. 4.]
 [0. 2. 4. 6. 8.]]
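Since the output follows whatever pytree structure function returns, multiple results need not be packed in a tuple. As a small sketch (the function stats and its dictionary keys are arbitrary names chosen for illustration), returning a dict yields a dict of arrays, each with leading dimensions shape:

```python
import jax.numpy as jnp


def stats(x, y):
  # Returns a dict (a pytree) of scalars for each index pair (x, y).
  return {"sum": x + y, "diff": x - y}


result = jnp.fromfunction(stats, shape=(2, 3))
print(sorted(result))                             # ['diff', 'sum']
print(result["sum"].shape, result["diff"].shape)  # (2, 3) (2, 3)
```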
The JAX implementation differs slightly from NumPy’s implementation. In numpy.fromfunction(), the function is expected to explicitly operate element-wise on the full grid of input values:
>>> def f(x, y):
...   print(f"{x.shape = }\n{y.shape = }")
...   return x + y
...
>>> np.fromfunction(f, (2, 3))
x.shape = (2, 3)
y.shape = (2, 3)
array([[0., 1., 2.],
       [1., 2., 3.]])
In jax.numpy.fromfunction(), the function is vectorized via jax.vmap(), and so is expected to operate on scalar values:
>>> jnp.fromfunction(f, (2, 3))
x.shape = ()
y.shape = ()
Array([[0., 1., 2.],
       [1., 2., 3.]], dtype=float32)
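One practical consequence of this scalar view is that function can build intermediate arrays from its arguments without broadcasting bookkeeping. The hypothetical radius function below stacks its two indices into a 2-vector and takes its norm. Given NumPy's whole-grid inputs, the stack would have shape (2, 2, 3)-style grid dimensions and the norm would collapse everything into one number; under jax.vmap() each call sees scalars and returns the distance of that index from the origin.

```python
import jax.numpy as jnp


def radius(x, y):
  # x and y are scalars here, so stacking them yields a length-2 vector.
  v = jnp.stack([x, y])
  return jnp.linalg.norm(v)


r = jnp.fromfunction(radius, shape=(3, 4))

# Broadcasting reference computed directly on the index grid.
i = jnp.arange(3.0)[:, None]
j = jnp.arange(4.0)[None, :]
expected = jnp.sqrt(i**2 + j**2)

print(r.shape)                    # (3, 4)
print(jnp.allclose(r, expected))  # True
```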
