
jax.numpy.cross


jax.numpy.cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None)

Compute the (batched) cross product of two arrays.

JAX implementation of numpy.cross().

This computes the 2-dimensional or 3-dimensional cross product,

\[c = a \times b\]

In 3 dimensions, c is a length-3 array. In 2 dimensions, c is a scalar.
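As a rough illustration of the two cases, the sketch below compares jnp.cross against the textbook component formulas. The helper names cross3_by_components and cross2_by_components are hypothetical and exist only for this example; they are not part of JAX.

```python
import jax.numpy as jnp

def cross3_by_components(a, b):
    # Hypothetical helper: textbook component formula for the 3-D cross product.
    return jnp.array([
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ])

def cross2_by_components(a, b):
    # Hypothetical helper: the 2-D case keeps only the z-component, a scalar.
    return a[0] * b[1] - a[1] * b[0]

a3 = jnp.array([1.0, 2.0, 3.0])
b3 = jnp.array([4.0, 5.0, 6.0])
matches_3d = bool(jnp.allclose(cross3_by_components(a3, b3), jnp.cross(a3, b3)))

a2 = jnp.array([1.0, 2.0])
b2 = jnp.array([3.0, 4.0])
matches_2d = bool(jnp.allclose(cross2_by_components(a2, b2), jnp.cross(a2, b2)))

# The 2-D result is a zero-dimensional array, i.e. a scalar.
scalar_shape = jnp.cross(a2, b2).shape
```

Both comparisons should agree, and the 2-D result has an empty shape because the vector axis is consumed rather than kept.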

Parameters:
  • a – N-dimensional array. a.shape[axisa] indicates the dimension of the cross product, and must be 2 or 3.

  • b – N-dimensional array. Must have b.shape[axisb] == a.shape[axisa], and other dimensions of a and b must be broadcast compatible.

  • axisa (int) – specify the axis of a along which to compute the cross product.

  • axisb (int) – specify the axis of b along which to compute the cross product.

  • axisc (int) – specify the axis of c along which the cross product result will be stored.

  • axis (int | None) – if specified, this overrides axisa, axisb, and axisc with a single value.
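The axis parameters above can be sketched with a 3-dimensional cross product whose vector components lie along axis 0. The input values are arbitrary and chosen only to make the shapes visible; this is an illustration, not a statement about how the function is implemented.

```python
import jax.numpy as jnp

# Four 3-vectors stored as the columns of a (3, 4) array.
a = jnp.arange(12.0).reshape(3, 4)
b = jnp.ones((3, 4))

# Read vectors along axis 0 of both inputs; the result's vector axis
# defaults to the last axis (axisc=-1), giving shape (4, 3).
c_last = jnp.cross(a, b, axisa=0, axisb=0)

# Place the result's vector axis first instead, giving shape (3, 4).
c_first = jnp.cross(a, b, axisa=0, axisb=0, axisc=0)

# axis=0 sets axisa, axisb, and axisc to 0 in one argument.
c_axis = jnp.cross(a, b, axis=0)
```

Here c_axis matches c_first, and c_last is the same data with the vector axis moved to the end.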

Returns:

The array c containing the (batched) cross product of a and b along the specified axes.
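Because the non-vector dimensions only need to be broadcast compatible, a single vector can be crossed against a whole batch. The following is a minimal sketch with arbitrary values:

```python
import jax.numpy as jnp

# One 3-vector, broadcast against a batch of four 3-vectors.
a = jnp.array([1.0, 0.0, 0.0])
b = jnp.array([[0.0, 1.0, 0.0],
               [0.0, 0.0, 1.0],
               [1.0, 1.0, 0.0],
               [2.0, 3.0, 4.0]])

batched = jnp.cross(a, b)  # shape (4, 3)

# Each row should equal the cross product of a with the matching row of b.
rowwise = jnp.stack([jnp.cross(a, row) for row in b])
```

The batched call and the row-by-row loop should produce the same (4, 3) array.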

Examples

A 2-dimensional cross product returns a scalar:

>>> a = jnp.array([1, 2])
>>> b = jnp.array([3, 4])
>>> jnp.cross(a, b)
Array(-2, dtype=int32)

A 3-dimensional cross product returns a length-3 vector:

>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.cross(a, b)
Array([-3,  6, -3], dtype=int32)

With multi-dimensional inputs, the cross product is computed along the last axis by default. Here’s a batched 3-dimensional cross product, operating on the rows of the inputs:

>>> a = jnp.array([[1, 2, 3],
...                [3, 4, 3]])
>>> b = jnp.array([[2, 3, 2],
...                [4, 5, 6]])
>>> jnp.cross(a, b)
Array([[-5,  4, -1],
       [ 9, -6, -1]], dtype=int32)

Specifying axis=0 makes this a batched 2-dimensional cross product, operating on the columns of the inputs:

>>> jnp.cross(a, b, axis=0)
Array([-2, -2, 12], dtype=int32)

Equivalently, we can independently specify the axis of the inputs a and b and the output c:

>>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
Array([-2, -2, 12], dtype=int32)