jax.numpy.cross
- jax.numpy.cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None)
Compute the (batched) cross product of two arrays.
JAX implementation of numpy.cross().
This computes the 2-dimensional or 3-dimensional cross product,
\[c = a \times b\]
In 3 dimensions, c is a length-3 array. In 2 dimensions, c is a scalar.
- Parameters:
a – N-dimensional array. a.shape[axisa] indicates the dimension of the cross product, and must be 2 or 3.
b – N-dimensional array. Must have b.shape[axisb] == a.shape[axisa], and other dimensions of a and b must be broadcast compatible.
axisa (int) – specify the axis of a along which to compute the cross product.
axisb (int) – specify the axis of b along which to compute the cross product.
axisc (int) – specify the axis of c along which the cross product result will be stored.
axis (int | None) – if specified, this overrides axisa, axisb, and axisc with a single value.
- Returns:
The array c containing the (batched) cross product of a and b along the specified axes.
See also
jax.numpy.linalg.cross(): an array API compatible function for computing cross products over 3-vectors.
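For 3-vectors stored along the last axis, jnp.cross and jax.numpy.linalg.cross should produce the same result. The sketch below assumes a JAX version that provides jax.numpy.linalg.cross with its default axis=-1; since that function is described as operating on 3-vectors, only length-3 inputs are used here.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3],
               [3, 4, 3]])
b = jnp.array([[2, 3, 2],
               [4, 5, 6]])

# Batched 3-D cross product along the last axis, computed both ways.
c_numpy_style = jnp.cross(a, b)
# Assumption: jnp.linalg.cross is available in the installed JAX version.
c_array_api = jnp.linalg.cross(a, b)

print(bool(jnp.array_equal(c_numpy_style, c_array_api)))  # True
```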
Examples
A 2-dimensional cross product returns a scalar:
>>> a = jnp.array([1, 2])
>>> b = jnp.array([3, 4])
>>> jnp.cross(a, b)
Array(-2, dtype=int32)
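The scalar returned in the 2-dimensional case is a0*b1 - a1*b0. It also equals the z-component of the 3-dimensional cross product after padding both vectors with a zero z-coordinate. A small sketch checking both views on the inputs above (the padding with jnp.append is only for illustration):

```python
import jax.numpy as jnp

a = jnp.array([1, 2])
b = jnp.array([3, 4])

# Scalar 2-D cross product written out by hand: a0*b1 - a1*b0.
manual = a[0] * b[1] - a[1] * b[0]

# Embed both vectors in the xy-plane (z = 0) and take the z-component
# of their 3-D cross product.
a3 = jnp.append(a, 0)
b3 = jnp.append(b, 0)
z_component = jnp.cross(a3, b3)[2]

result = jnp.cross(a, b)
print(int(result), int(manual), int(z_component))  # -2 -2 -2
```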
A 3-dimensional cross product returns a length-3 vector:
>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.cross(a, b)
Array([-3,  6, -3], dtype=int32)
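The 3-dimensional result follows the usual component formula. The sketch below writes it out in a hypothetical helper, cross3_manual (not part of JAX), compares it with jnp.cross, and checks that the result is orthogonal to both inputs:

```python
import jax.numpy as jnp

def cross3_manual(a, b):
    # Hypothetical helper: component-wise 3-D cross product over the last axis.
    return jnp.stack([
        a[..., 1] * b[..., 2] - a[..., 2] * b[..., 1],
        a[..., 2] * b[..., 0] - a[..., 0] * b[..., 2],
        a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0],
    ], axis=-1)

a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
c = jnp.cross(a, b)

print(bool(jnp.array_equal(c, cross3_manual(a, b))))  # True
# The cross product is orthogonal to both of its inputs.
print(int(jnp.dot(c, a)), int(jnp.dot(c, b)))  # 0 0
```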
With multi-dimensional inputs, the cross product is computed along the last axis by default. Here’s a batched 3-dimensional cross product, operating on the rows of the inputs:
>>> a = jnp.array([[1, 2, 3],
...                [3, 4, 3]])
>>> b = jnp.array([[2, 3, 2],
...                [4, 5, 6]])
>>> jnp.cross(a, b)
Array([[-5,  4, -1],
       [ 9, -6, -1]], dtype=int32)
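One way to read the batched behavior: it should match mapping the unbatched cross product over the rows with jax.vmap. Because non-vector dimensions only need to be broadcast compatible, a single 3-vector can also be crossed with every row at once. A sketch of both, using the same inputs (the unit vector z is an illustrative choice):

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 2, 3],
               [3, 4, 3]])
b = jnp.array([[2, 3, 2],
               [4, 5, 6]])

# Batched call versus explicitly mapping the 1-D cross product over rows.
batched = jnp.cross(a, b)
mapped = jax.vmap(jnp.cross)(a, b)
print(bool(jnp.array_equal(batched, mapped)))  # True

# Broadcasting: cross every row of `a` with the same vector z = (0, 0, 1).
z = jnp.array([0, 0, 1])
broadcast = jnp.cross(a, z)
print(broadcast.shape)  # (2, 3)
```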
Specifying axis=0 makes this a batched 2-dimensional cross product, operating on the columns of the inputs:
>>> jnp.cross(a, b, axis=0)
Array([-2, -2, 12], dtype=int32)
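Operating on columns with axis=0 should give the same result as transposing both inputs, so that the columns become rows, and using the default last axis. A quick check of that equivalence:

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3],
               [3, 4, 3]])
b = jnp.array([[2, 3, 2],
               [4, 5, 6]])

# 2-D cross products of the three columns (each column has length 2).
by_axis = jnp.cross(a, b, axis=0)
# Same computation: transpose so columns become rows, then use axis=-1.
by_transpose = jnp.cross(a.T, b.T)

print(bool(jnp.array_equal(by_axis, by_transpose)))  # True
```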
Equivalently, we can independently specify the axis of the inputs a and b and the output c:
>>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
Array([-2, -2, 12], dtype=int32)
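The output axis matters when the result is itself a 3-vector. In the sketch below, the input vectors are stored as columns (axisa=0, axisb=0), and axisc decides whether the result vectors are laid out as rows or as columns. The unit-vector inputs are chosen only so the results are easy to verify by hand:

```python
import jax.numpy as jnp

# Two 3-vectors stored as the columns of each input (vector axis = 0).
a = jnp.array([[1, 0],
               [0, 1],
               [0, 0]])  # columns: x-hat, y-hat
b = jnp.array([[0, 0],
               [1, 0],
               [0, 1]])  # columns: y-hat, z-hat

# Read the vectors along axis 0 but store each result vector along the last axis.
c_rows = jnp.cross(a, b, axisa=0, axisb=0, axisc=-1)
print(c_rows.shape)  # (2, 3)

# With axisc=0 the result vectors are stored as columns instead.
c_cols = jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
print(c_cols.shape)  # (3, 2)
```

Here x-hat × y-hat = z-hat and y-hat × z-hat = x-hat, so c_rows is [[0, 0, 1], [1, 0, 0]] and c_cols is its transpose.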
