jax.numpy.square
- jax.numpy.square(x, /)
Calculate element-wise square of the input array.
JAX implementation of numpy.square.
- Parameters:
x (ArrayLike) – input array or scalar.
- Returns:
An array containing the square of the elements of x.
- Return type:
Array
Note
jnp.square is equivalent to computing jnp.power(x, 2).
See also
- jax.numpy.sqrt(): Calculates the element-wise non-negative square root of the input array.
- jax.numpy.power(): Calculates the element-wise base x1 exponential of x2.
- jax.lax.integer_pow(): Computes element-wise power \(x^y\), where \(y\) is a fixed integer.
- jax.numpy.float_power(): Computes the first array raised to the power of second array, element-wise, by promoting to the inexact dtype.
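The relationships among the functions listed here can be checked numerically. The following is a minimal sketch, not part of the upstream documentation; the variable names are illustrative, and it assumes a default JAX configuration (32-bit floats).

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([3.0, -2.0, 5.3, 1.0])

squared = jnp.square(x)

# Three other ways to express the same element-wise operation.
via_power = jnp.power(x, 2)
via_integer_pow = lax.integer_pow(x, 2)  # exponent is a fixed Python int
via_multiply = x * x

same_as_power = bool(jnp.allclose(squared, via_power))
same_as_integer_pow = bool(jnp.allclose(squared, via_integer_pow))
same_as_multiply = bool(jnp.allclose(squared, via_multiply))

# sqrt returns the non-negative root, so it inverts square only up to sign:
# sqrt(square(x)) matches abs(x), not x itself, for negative entries.
roundtrip = jnp.sqrt(squared)
recovers_abs = bool(jnp.allclose(roundtrip, jnp.abs(x)))
recovers_x = bool(jnp.allclose(roundtrip, x))
```

Under these assumptions the first four flags should be true and `recovers_x` false, because `x` contains a negative entry.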
Examples
>>> x = jnp.array([3, -2, 5.3, 1])
>>> jnp.square(x)
Array([ 9.      ,  4.      , 28.090002,  1.      ], dtype=float32)
>>> jnp.power(x, 2)
Array([ 9.      ,  4.      , 28.090002,  1.      ], dtype=float32)
For integer inputs:
>>> x1 = jnp.array([2, 4, 5, 6])
>>> jnp.square(x1)
Array([ 4, 16, 25, 36], dtype=int32)
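As the integer example suggests, jnp.square keeps the integer dtype of its input, whereas jax.numpy.float_power is described above as promoting to an inexact dtype. A small sketch contrasting the two, with illustrative variable names and assuming the default JAX configuration:

```python
import jax.numpy as jnp

x1 = jnp.array([2, 4, 5, 6])

int_result = jnp.square(x1)            # stays in the input's integer dtype
float_result = jnp.float_power(x1, 2)  # promoted to a floating-point dtype

keeps_integer_dtype = int_result.dtype == x1.dtype
promoted_to_inexact = bool(jnp.issubdtype(float_result.dtype, jnp.inexact))

# The values agree numerically even though the dtypes differ.
values_agree = bool(jnp.allclose(int_result, float_result))
```

This may matter when the result feeds into code that expects integers, for example indexing or exact counting.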
For complex-valued inputs:
>>> x2 = jnp.array([1-3j, -1j, 2])
>>> jnp.square(x2)
Array([-8.-6.j, -1.+0.j,  4.+0.j], dtype=complex64)
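For complex inputs, the result is the algebraic square z * z rather than the squared magnitude |z|^2, which is why the output above has negative real parts. A brief sketch of the difference, with illustrative variable names:

```python
import jax.numpy as jnp

x2 = jnp.array([1 - 3j, -1j, 2])

algebraic_square = jnp.square(x2)  # z * z, generally complex

# Squared magnitude |z|^2 = z * conj(z), which is real and non-negative.
squared_magnitude = jnp.real(x2 * jnp.conj(x2))

matches_product = bool(jnp.allclose(algebraic_square, x2 * x2))

# Mathematically, |1-3j|^2 = 10, |-1j|^2 = 1, and |2|^2 = 4.
matches_expected_magnitude = bool(
    jnp.allclose(squared_magnitude, jnp.array([10.0, 1.0, 4.0]))
)
```

If the squared magnitude is what is needed, jnp.abs(x2) ** 2 is an alternative to the conjugate product shown here.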
