jax.numpy.copy
jax.numpy.copy(a, order=None)
Return a copy of the array.
JAX implementation of numpy.copy().

- Parameters:
  a (ArrayLike) – array-like object to copy
  order (str | None) – not implemented in JAX
- Returns:
  a copy of the input array a.
- Return type:
  Array
See also
- jax.numpy.array(): create an array with or without a copy.
- jax.Array.copy(): same function accessed as an array method.
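The three spellings listed above can be compared side by side. This is a small sketch, assuming the copy=True argument of jnp.array; each call returns a new array equal to the input.

```python
import jax.numpy as jnp

x = jnp.array([1.5, 2.5, 3.5])

# Three ways to request a copy of `x`; all produce arrays equal to `x`.
a = jnp.copy(x)               # function form documented on this page
b = x.copy()                  # method form (jax.Array.copy)
c = jnp.array(x, copy=True)   # general constructor with an explicit copy flag

for result in (a, b, c):
    print(result)
```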
Examples
Since JAX arrays are immutable, explicit array copies are unnecessary in most cases. One exception is when using a function with donated arguments (see the donate_argnums argument to jax.jit()).

>>> import jax
>>> import jax.numpy as jnp
>>> f = jax.jit(lambda x: 2 * x, donate_argnums=0)
>>> x = jnp.arange(4)
>>> y = f(x)
>>> print(y)
[0 2 4 6]
Because we marked x as being donated, the original array is no longer available:

>>> print(x)
Traceback (most recent call last):
RuntimeError: Array has been deleted with shape=int32[4].
In situations like this, an explicit copy lets you keep access to the original buffer:

>>> x = jnp.arange(4)
>>> y = f(x.copy())
>>> print(y)
[0 2 4 6]
>>> print(x)
[0 1 2 3]
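When a jitted function donates a whole pytree of arrays (for example, a dictionary of parameters), the same idea extends by mapping jnp.copy over every leaf with jax.tree_util.tree_map. The sketch below uses a hypothetical scale_params function and parameter dictionary; they are illustrative, not part of JAX. On backends where donation is unsupported, JAX may only emit a warning and leave the inputs intact, but copying first keeps params usable either way.

```python
import jax
import jax.numpy as jnp

def _scale(params):
    # Double every leaf of the parameter pytree.
    return jax.tree_util.tree_map(lambda v: 2 * v, params)

# Hypothetical jitted step that donates its first argument (the whole pytree).
scale_params = jax.jit(_scale, donate_argnums=0)

params = {"w": jnp.arange(4.0), "b": jnp.zeros(2)}

# Copy every leaf before the call so the donated buffers are the copies,
# leaving the original `params` usable afterwards.
params_copy = jax.tree_util.tree_map(jnp.copy, params)
new_params = scale_params(params_copy)

print(new_params["w"])  # [0. 2. 4. 6.]
print(params["w"])      # [0. 1. 2. 3.]
```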
