jax.numpy.poly
jax.numpy.poly(seq_of_zeros)
Returns the coefficients of a polynomial for the given sequence of roots.
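JAX implementation of numpy.poly().

- Parameters:
seq_of_zeros (ArrayLike) – A scalar, a 1-D array of roots of shape (M,), or a square matrix of shape (M, M). For a square matrix, its eigenvalues are used as the roots, so the result is the coefficients of the matrix's characteristic polynomial.
- Returns:
An array containing the coefficients of the polynomial, ordered from the highest power down to the constant term. The dtype of the output is always promoted to inexact.
- Return type:
Array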
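The returned coefficients are those of the expanded product (x - r_1)(x - r_2)...(x - r_M). The sketch below performs that expansion one linear factor at a time: multiplying a coefficient vector by (x - r) is the same as convolving it with [1, -r]. The helper name `poly_from_roots` is hypothetical, and this is an illustration of the math, not a claim about how `jax.numpy.poly` is implemented internally.

```python
import jax.numpy as jnp


def poly_from_roots(roots):
    """Hypothetical helper: expand prod_k (x - roots[k]) into coefficients.

    Coefficients are ordered from the highest power down to the constant term,
    the same ordering that jnp.poly uses.
    """
    roots = jnp.atleast_1d(jnp.asarray(roots))
    # Promote to an inexact (floating or complex) dtype.
    dtype = jnp.result_type(roots.dtype, jnp.float32)
    coeffs = jnp.ones((1,), dtype=dtype)  # start from the constant polynomial 1
    for r in roots:
        # Multiplying by the linear factor (x - r) convolves with [1, -r].
        factor = jnp.array([1, -r], dtype=dtype)
        coeffs = jnp.convolve(coeffs, factor)
    return coeffs


print(poly_from_roots(jnp.array([1, 2, 3])))
print(jnp.poly(jnp.array([1, 2, 3])))
```

The Python loop unrolls M convolutions, which is fine for the small root counts typical of this function.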
Note
jax.numpy.poly() differs from numpy.poly():
- When the input is a scalar, np.poly raises a TypeError, whereas jnp.poly treats scalars the same as length-1 arrays.
- For complex-valued or square-shaped inputs, jnp.poly always returns complex coefficients, whereas np.poly may return real or complex coefficients depending on the input values.
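A small sketch of the jnp.poly side of these differences (NumPy's behavior is not exercised here). It shows a scalar root giving the same result as a length-1 array, and conjugate-pair roots still producing a complex dtype. The helper `real_if_negligible` and its tolerance are hypothetical choices, not part of the JAX API, for readers who want real coefficients when the imaginary parts cancel.

```python
import jax.numpy as jnp

# A scalar root is handled like a length-1 array of roots.
scalar_coeffs = jnp.poly(3.0)
array_coeffs = jnp.poly(jnp.array([3.0]))
print(scalar_coeffs, array_coeffs)

# Complex-valued roots give complex coefficients, even when the roots come in
# conjugate pairs so that the imaginary parts cancel.
coeffs = jnp.poly(jnp.array([2, 1 + 2j, 1 - 2j]))
print(coeffs.dtype)


def real_if_negligible(c, atol=1e-5):
    """Hypothetical helper: drop the imaginary part when it is negligible.

    The branch depends on array values, so this is meant for eager use,
    not inside jax.jit.
    """
    if jnp.allclose(c.imag, 0.0, atol=atol):
        return c.real
    return c


real_coeffs = real_if_negligible(coeffs)
print(real_coeffs)
```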
See also
- jax.numpy.polyfit(): Least squares polynomial fit.
- jax.numpy.polyval(): Evaluate a polynomial at specific values.
- jax.numpy.roots(): Compute the roots of a polynomial for given coefficients.
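jnp.roots and jnp.poly go in opposite directions, and jnp.polyval can confirm that the computed roots are zeros of the polynomial. A minimal round-trip sketch, assuming a polynomial with well-separated real roots so that float32 rounding stays small; the tolerances used in checking it are arbitrary.

```python
import jax.numpy as jnp

# Coefficients of x**3 - 6x**2 + 11x - 6, whose roots are 1, 2 and 3.
p = jnp.array([1.0, -6.0, 11.0, -6.0])

# Coefficients -> roots.
r = jnp.roots(p)

# Roots -> coefficients: recovers p up to rounding. The result is complex
# because the roots computed by jnp.roots are complex-valued.
p_rebuilt = jnp.poly(r)

# Evaluating the original polynomial at each root should give values near 0.
residuals = jnp.polyval(p, r)

print(r)
print(p_rebuilt)
print(residuals)
```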
Examples
Scalar inputs:
>>> import jax.numpy as jnp
>>> jnp.poly(1)
Array([ 1., -1.], dtype=float32)
Input array with integer values:
>>> x = jnp.array([1, 2, 3])
>>> jnp.poly(x)
Array([ 1., -6., 11., -6.], dtype=float32)
Input array with complex conjugates:
>>> x = jnp.array([2, 1+2j, 1-2j])
>>> jnp.poly(x)
Array([ 1.+0.j, -4.+0.j, 9.+0.j, -10.+0.j], dtype=complex64)
Input array as a square matrix with real-valued entries:
>>> x = jnp.array([[2, 1, 5],
...                [3, 4, 7],
...                [1, 3, 5]])
>>> jnp.round(jnp.poly(x))
Array([ 1.+0.j, -11.-0.j, 9.+0.j, -15.+0.j], dtype=complex64)
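For a square matrix the roots are its eigenvalues, so the result is the characteristic polynomial. The sketch below cross-checks this two ways for the matrix above: against jnp.poly applied to jnp.linalg.eigvals, and against the standard closed form for 3x3 matrices, x**3 - trace(A) x**2 + c1 x - det(A), where c1 is the sum of the principal 2x2 minors. The variable names are illustrative, and note that jnp.linalg.eigvals for non-symmetric matrices may only be supported on CPU.

```python
import jax.numpy as jnp

A = jnp.array([[2., 1., 5.],
               [3., 4., 7.],
               [1., 3., 5.]])

# Route 1: pass the matrix directly. Route 2: pass its eigenvalues as roots.
from_matrix = jnp.poly(A)
from_eigvals = jnp.poly(jnp.linalg.eigvals(A))

# Route 3: closed-form characteristic polynomial of a 3x3 matrix.
# c1 is the sum of the three principal 2x2 minors.
c1 = (A[0, 0] * A[1, 1] - A[0, 1] * A[1, 0]
      + A[0, 0] * A[2, 2] - A[0, 2] * A[2, 0]
      + A[1, 1] * A[2, 2] - A[1, 2] * A[2, 1])
closed_form = jnp.array([1.0, -jnp.trace(A), c1, -jnp.linalg.det(A)])

print(from_matrix)
print(from_eigvals)
print(closed_form)
```

Here trace(A) = 11, c1 = 5 + 5 - 1 = 9 and det(A) = 15, which matches the rounded coefficients [1, -11, 9, -15] in the example above.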
