jax.numpy.sqrt
- jax.numpy.sqrt(x, /)
Calculates the element-wise non-negative square root of the input array.
JAX implementation of numpy.sqrt.
- Parameters:
x (ArrayLike) – input array or scalar.
- Returns:
An array containing the non-negative square root of the elements of x.
- Return type:
Array
Note
For real-valued negative inputs, jnp.sqrt produces a nan output.
For complex-valued negative inputs, jnp.sqrt produces a complex output.
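The Note above means the input dtype, not the sign of the value, decides whether a negative entry becomes nan or an imaginary root. As a minimal sketch (the sample values are arbitrary, chosen for illustration), casting a real array to a complex dtype such as jnp.complex64 before calling jnp.sqrt yields the principal complex root instead of nan:

```python
import jax.numpy as jnp

# Real-valued (float32) input: the negative entry maps to nan.
x = jnp.array([-4.0, 0.0, 9.0])
real_result = jnp.sqrt(x)

# Casting to a complex dtype first selects the complex branch,
# so -4 maps to the principal root 2j instead of nan.
complex_result = jnp.sqrt(x.astype(jnp.complex64))

print(real_result)
print(complex_result)
```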
See also
- jax.numpy.square(): Calculates the element-wise square of the input.
- jax.numpy.power(): Calculates the element-wise base x1 exponential of x2.
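To illustrate how jnp.sqrt relates to the functions listed under See also, the following sketch uses arbitrary non-negative sample values to check that jnp.square undoes jnp.sqrt and that jnp.power with an exponent of 0.5 gives matching results. Both comparisons are made up to float32 rounding rather than exactly:

```python
import jax.numpy as jnp

# Arbitrary non-negative inputs, so every square root is real.
x = jnp.array([0.25, 2.0, 16.0])

root = jnp.sqrt(x)

# For non-negative inputs, squaring undoes the square root (up to rounding).
roundtrip = jnp.square(root)

# The square root corresponds to the power function with exponent 0.5.
via_power = jnp.power(x, 0.5)

print(root)
print(roundtrip)
print(via_power)
```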
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([-8-6j, 1j, 4])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sqrt(x)
Array([1.   -3.j   , 0.707+0.707j, 2.   +0.j   ], dtype=complex64)
>>> jnp.sqrt(-1)
Array(nan, dtype=float32, weak_type=True)
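Because jnp.sqrt is described as the JAX implementation of numpy.sqrt, one way to sanity-check it is to compare against NumPy on a multi-dimensional input. This is a minimal sketch: the 2-by-3 input is an arbitrary choice, and agreement is checked with a floating-point tolerance rather than exact equality:

```python
import numpy as np
import jax.numpy as jnp

# A 2-by-3 float32 input; jnp.sqrt acts on each element independently,
# so the output keeps the input's shape.
x_np = np.arange(6, dtype=np.float32).reshape(2, 3)
x = jnp.asarray(x_np)

result = jnp.sqrt(x)

# Reference values computed with NumPy on the same data.
expected = np.sqrt(x_np)

print(result.shape)
print(np.allclose(np.asarray(result), expected))
```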
