jax.numpy.roots
- jax.numpy.roots(p, *, strip_zeros=True)
Returns the roots of a polynomial given its coefficients p.
JAX implementation of numpy.roots().
- Parameters:
p (ArrayLike) – Rank-1 array of polynomial coefficients, ordered from the highest-degree term down to the constant term.
strip_zeros (bool) – If True (the default), leading zeros in the coefficients are stripped, as in numpy.roots(). If False, leading zeros are not stripped, and the corresponding undefined roots are represented by NaN values in the output. Because stripping zeros makes the length of the output depend on the values in p, strip_zeros must be set to False for the function to be compatible with jax.jit() and other JAX transformations.
- Returns:
An array containing the roots of the polynomial.
- Return type:
Array
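As a sketch of the jax.jit() compatibility described for strip_zeros, the function below wraps jnp.roots with strip_zeros=False, so the output length stays fixed at len(p) - 1 regardless of how many leading zeros p contains. The name roots_fixed is a hypothetical helper introduced only for this illustration; with the default strip_zeros=True, the same jitted call would raise an error during tracing instead.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper: strip_zeros=False keeps the output shape static
# (len(p) - 1 roots), which is what jax.jit needs to trace the function.
@jax.jit
def roots_fixed(p):
    return jnp.roots(p, strip_zeros=False)


coeffs = jnp.array([0.0, 1.0, 2.0])  # 0*x**2 + 1*x + 2, i.e. x + 2
r = roots_fixed(coeffs)
# r holds the root -2 followed by a NaN entry for the leading zero.
print(r)
```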
Note
Unlike np.roots, jnp.roots always returns the roots in a complex array, even when all of the roots are real.
See also
- jax.numpy.poly(): Finds the polynomial coefficients of the given sequence of roots.
- jax.numpy.polyfit(): Least squares polynomial fit to data.
- jax.numpy.polyval(): Evaluate a polynomial at specific values.
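To illustrate the note about complex output, the sketch below finds the roots of x**2 - 3*x + 2 = (x - 1)(x - 2), whose roots are both real, and checks them with jax.numpy.polyval() and jax.numpy.poly(). It also shows x**2 + 1, whose roots are genuinely complex. Taking the real parts before calling jnp.poly is an assumption that only holds here because the imaginary parts are zero.

```python
import jax.numpy as jnp

# x**2 - 3x + 2 = (x - 1)(x - 2): both roots are real.
coeffs = jnp.array([1.0, -3.0, 2.0])
r = jnp.roots(coeffs)

# Unlike np.roots, the result is complex even though the roots are real.
print(r.dtype)

# Evaluating the polynomial at each root should give (approximately) zero.
residual = jnp.abs(jnp.polyval(coeffs, r))
print(residual)

# The imaginary parts are zero here, so jnp.poly on the real parts
# rebuilds the original coefficients up to rounding error.
rebuilt = jnp.poly(r.real)
print(rebuilt)

# x**2 + 1 has the genuinely complex roots +1j and -1j.
r_complex = jnp.roots(jnp.array([1.0, 0.0, 1.0]))
print(r_complex)
```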
Examples
>>> import jax.numpy as jnp
>>> coeffs = jnp.array([0, 1, 2])
The default behavior matches NumPy and strips leading zeros:
>>> jnp.roots(coeffs)
Array([-2.+0.j], dtype=complex64)
With strip_zeros=False, extra roots are set to NaN:
>>> jnp.roots(coeffs, strip_zeros=False)
Array([-2. +0.j, nan+nanj], dtype=complex64)
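The fixed output shape from strip_zeros=False should also let jnp.roots run under other transformations. As a sketch assuming jax.vmap() is the transformation of interest, the code below computes the roots of a batch of quadratics, one of which is padded with a leading zero. Each row of the result has len(p) - 1 = 2 entries, with NaN in the slot of the undefined root.

```python
import jax
import jax.numpy as jnp

# Every row has the same length (3), so the batch has a regular shape
# that jax.vmap can map over.
batch = jnp.array([
    [1.0, -3.0, 2.0],  # (x - 1)(x - 2)
    [0.0, 1.0, 2.0],   # x + 2, padded with a leading zero
])

# strip_zeros=False keeps every output row at length 2.
batched_roots = jax.vmap(lambda p: jnp.roots(p, strip_zeros=False))(batch)
print(batched_roots.shape)  # (2, 2)
print(batched_roots)
```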
