jax.numpy.sign
- jax.numpy.sign(x, /)
Return an element-wise indication of the sign of the input.
JAX implementation of numpy.sign.

The sign of x for real-valued input is:

\[
\mathrm{sign}(x) = \begin{cases} 1, & x > 0\\ 0, & x = 0\\ -1, & x < 0 \end{cases}
\]

For complex-valued input, jnp.sign returns a unit vector representing the phase. In the general case, the sign of x is given by:

\[
\mathrm{sign}(x) = \begin{cases} \frac{x}{|x|}, & x \ne 0\\ 0, & x = 0 \end{cases}
\]

- Parameters:
x (ArrayLike) – input array or scalar.
- Returns:
An array with the same shape and dtype as x containing the sign indication.
- Return type:
Array
See also
jax.numpy.positive(): Returns element-wise positive values of the input.
jax.numpy.negative(): Returns element-wise negative values of the input.
Examples
For real-valued inputs:
>>> x = jnp.array([0., -3., 7.])
>>> jnp.sign(x)
Array([ 0., -1.,  1.], dtype=float32)
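The three-branch definition for real inputs can be mirrored directly with nested `jnp.where` calls. This is a minimal sketch for illustration, not JAX's actual implementation: `sign_real` is a hypothetical helper, and it assumes finite input (a NaN falls through to the 0 branch in this sketch). The integer case also illustrates the Returns field: the output keeps the input dtype.

```python
import jax.numpy as jnp

def sign_real(x):
    """Hypothetical helper: piecewise sign for real-valued, finite input."""
    x = jnp.asarray(x)
    one = jnp.ones_like(x)    # same dtype as x
    zero = jnp.zeros_like(x)  # same dtype as x
    # 1 where x > 0, -1 where x < 0, and 0 otherwise (x == 0; NaN not modeled)
    return jnp.where(x > 0, one, jnp.where(x < 0, -one, zero))

x = jnp.array([0., -3., 7.])
print(sign_real(x))
print(jnp.array_equal(sign_real(x), jnp.sign(x)))

# Integer input: jnp.sign returns an integer array of the same dtype.
xi = jnp.array([-5, 0, 2])
print(jnp.sign(xi), jnp.sign(xi).dtype)
```

In practice jnp.sign is the function to call; the helper only makes the case analysis explicit.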
For complex inputs:
>>> x1 = jnp.array([1, 3+4j, 5j])
>>> jnp.sign(x1)
Array([1. +0.j , 0.6+0.8j, 0. +1.j ], dtype=complex64)
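The general formula, x / |x| for nonzero x and 0 at x = 0, can be sketched the same way. Here `sign_complex` is a hypothetical helper written only to mirror the formula; it divides by a "safe" denominator so that x = 0 never evaluates 0/0, then selects 0 for those entries. The check confirms it agrees with jnp.sign on the inputs from the example (plus a zero), and that the nonzero results have unit magnitude.

```python
import jax.numpy as jnp

def sign_complex(x):
    """Hypothetical helper: x / |x| where x != 0, and 0 where x == 0."""
    x = jnp.asarray(x)
    mag = jnp.abs(x)                        # real-valued magnitude |x|
    nonzero = mag != 0
    # Replace zero magnitudes by 1 so the division never produces 0/0 = nan.
    safe_mag = jnp.where(nonzero, mag, 1)
    return jnp.where(nonzero, x / safe_mag, 0)

x1 = jnp.array([1, 3+4j, 5j, 0])
s = sign_complex(x1)
print(s)
print(jnp.allclose(s, jnp.sign(x1)))

# Nonzero entries lie on the unit circle: |sign(x)| == 1.
print(jnp.abs(s))
```

Dividing by the guarded magnitude, rather than dividing first and masking afterwards, keeps NaN out of the intermediate values entirely.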