jax.numpy.power
jax.numpy.power(x1, x2, /)
Calculate x1 raised to the power x2, element-wise.

JAX implementation of numpy.power.

- Parameters:
x1 (ArrayLike) – scalar or array. Specifies the bases.
x2 (ArrayLike) – scalar or array. Specifies the exponents.
x1 and x2 should either have the same shape or be broadcast compatible.
- Returns:
An array containing x1 raised to the power x2, element-wise, with the dtype obtained by promoting the dtypes of x1 and x2.
- Return type:
Array
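The broadcasting and dtype behavior described above can be checked with a minimal sketch, assuming a standard JAX install with the default configuration (64-bit types disabled, so integer arrays default to int32 and floats to float32). The array values are arbitrary illustrations; jnp.broadcast_shapes and jnp.result_type are used only to report the expected output shape and dtype.

```python
import jax.numpy as jnp

# Broadcasting: bases of shape (3,) and exponents of shape (2, 1) -> (2, 3).
bases = jnp.array([1.0, 2.0, 3.0])
exponents = jnp.array([[2.0], [3.0]])
print(jnp.broadcast_shapes(bases.shape, exponents.shape))

# Row i of the result holds every base raised to exponents[i, 0].
table = jnp.power(bases, exponents)
print(table)

# Dtype promotion: an integer array stays integer with a Python int exponent,
# but is promoted to a floating dtype when the exponent is a Python float.
int_bases = jnp.array([2, 3, 4])          # int32 under the default config
int_result = jnp.power(int_bases, 2)
float_result = jnp.power(int_bases, 0.5)
print(int_result, int_result.dtype)
print(float_result, float_result.dtype)

# jnp.result_type reports the promoted dtype without computing anything.
print(jnp.result_type(int_bases, 0.5))
```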
Note
- When x2 is a concrete integer scalar, jnp.power lowers to jax.lax.integer_pow().
- When x2 is a traced scalar or an array, jnp.power lowers to jax.lax.pow().
- jnp.power raises a TypeError when an integer-typed base is raised to a concrete negative integer power. For a non-concrete (traced) negative power, the operation is invalid and the returned value is implementation-defined.
- jnp.power returns nan when a negative base is raised to a non-integer power.
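The behavior described in the Note can be observed directly with a sketch like the following, assuming a recent JAX release. jax.make_jaxpr is used only to print the traced program; the exact jaxpr text can differ between versions, so only the primitive names integer_pow and pow are worth looking for. The sample values are arbitrary.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.5, 2.0, 3.0])

# Concrete Python int exponent: expected to lower to lax.integer_pow.
concrete_jaxpr = jax.make_jaxpr(lambda a: jnp.power(a, 3))(x)

# Exponent passed in as an argument is traced: expected to lower to lax.pow.
traced_jaxpr = jax.make_jaxpr(lambda a, b: jnp.power(a, b))(x, 3.0)

print(concrete_jaxpr)
print(traced_jaxpr)

# Integer base with a concrete negative integer exponent raises TypeError.
try:
    jnp.power(jnp.array([2, 3]), -1)
except TypeError as err:
    print("TypeError:", err)

# Negative base with a non-integer exponent gives nan.
print(jnp.power(-2.0, 0.5))
```

One practical consequence: inside jax.jit, an exponent passed as an ordinary argument is traced, so it falls under the traced-scalar case from the Note even when its value is an integer. Marking that argument static (for example with static_argnums) keeps it a concrete Python int.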
See also
- jax.lax.pow(): Computes element-wise power, \(x^y\).
- jax.lax.integer_pow(): Computes element-wise power \(x^y\), where \(y\) is a fixed integer.
- jax.numpy.float_power(): Computes the first array raised to the power of the second array, element-wise, by promoting to an inexact dtype.
- jax.numpy.pow(): Computes the first array raised to the power of the second array, element-wise.
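As a rough side-by-side of the related functions, the sketch below applies several of them to the same small input, assuming the default 32-bit configuration. jnp.pow is left out to keep the sketch independent of JAX version, and the ** operator on a JAX array is included only as a cross-check. jax.lax.pow is a lax-level operation, so it is given same-shape, same-dtype float operands here rather than relying on broadcasting.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3])  # integer bases (int32 under the default config)

p = jnp.power(x, 2)              # concrete int exponent: integer result
ip = jax.lax.integer_pow(x, 2)   # exponent must be a fixed Python int
fp = jnp.float_power(x, 2)       # promotes to an inexact (floating) dtype
op = x ** 2                      # the ** operator on a JAX array

# lax.pow on float32 operands of equal shape.
lp = jax.lax.pow(jnp.array([1.0, 2.0, 3.0]), jnp.array([2.0, 2.0, 2.0]))

for name, value in [("power", p), ("integer_pow", ip), ("float_power", fp),
                    ("**", op), ("lax.pow", lp)]:
    print(name, value, value.dtype)
```

The integer-preserving functions (jnp.power with an integer exponent, jax.lax.integer_pow) agree exactly, while jnp.float_power trades the integer dtype for a floating one.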
Examples
Inputs with scalar integers:
>>> import jax.numpy as jnp
>>> jnp.power(4, 3)
Array(64, dtype=int32, weak_type=True)
Inputs with same shape:
>>> x1 = jnp.array([2, 4, 5])
>>> x2 = jnp.array([3, 0.5, 2])
>>> jnp.power(x1, x2)
Array([ 8.,  2., 25.], dtype=float32)
Inputs with broadcast compatibility:
>>> x3 = jnp.array([-2, 3, 1])
>>> x4 = jnp.array([[4, 1, 6],
...                 [1.3, 3, 5]])
>>> jnp.power(x3, x4)
Array([[16.,  3.,  1.],
       [nan, 27.,  1.]], dtype=float32)
