jax.nn module
Common functions for neural network libraries.
Activation functions
| Function | Description |
| --- | --- |
| `relu` | Rectified linear unit activation function. |
| `relu6` | Rectified Linear Unit 6 activation function. |
| `sigmoid` | Sigmoid activation function. |
| `softplus` | Softplus activation function. |
| `sparse_plus` | Sparse plus function. |
| `sparse_sigmoid` | Sparse sigmoid activation function. |
| `soft_sign` | Soft-sign activation function. |
| `silu` | SiLU (aka swish) activation function. |
| `swish` | SiLU (aka swish) activation function (alias of `silu`). |
| `log_sigmoid` | Log-sigmoid activation function. |
| `leaky_relu` | Leaky rectified linear unit activation function. |
| `hard_sigmoid` | Hard Sigmoid activation function. |
| `hard_silu` | Hard SiLU (swish) activation function. |
| `hard_swish` | Hard SiLU (swish) activation function (alias of `hard_silu`). |
| `hard_tanh` | Hard \(\mathrm{tanh}\) activation function. |
| `tanh` | Calculate element-wise hyperbolic tangent of input. |
| `elu` | Exponential linear unit activation function. |
| `celu` | Continuously-differentiable exponential linear unit activation. |
| `selu` | Scaled exponential linear unit activation. |
| `gelu` | Gaussian error linear unit activation function. |
| `glu` | Gated linear unit activation function. |
| `squareplus` | Squareplus activation function. |
| `mish` | Mish activation function. |
| `identity` | Identity activation function. |
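
For orientation, a minimal usage sketch: these activations apply elementwise to arrays (except `glu`, which halves the chosen axis). The input values and shapes below are purely illustrative.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 7)  # example inputs spanning negative and positive values

# Elementwise activations: each maps an array to an array of the same shape.
print(jax.nn.relu(x))        # max(x, 0)
print(jax.nn.sigmoid(x))     # 1 / (1 + exp(-x))
print(jax.nn.gelu(x))        # Gaussian error linear unit (tanh approximation by default)
print(jax.nn.leaky_relu(x, negative_slope=0.01))

# glu is the exception: it splits its input in half along an axis,
# so that axis must have even length.
y = jnp.ones((4, 8))
print(jax.nn.glu(y, axis=-1).shape)  # (4, 4)
```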
Other functions
| Function | Description |
| --- | --- |
| `softmax` | Softmax function. |
| `log_softmax` | Log-Softmax function. |
| `logmeanexp` | Log-mean-exp reduction. |
| `logsumexp` | Log-sum-exp reduction. |
| `standardize` | Standardizes input to zero mean and unit variance. |
| `one_hot` | One-hot encodes the given indices. |
| `dot_product_attention` | Scaled dot product attention function. |
| `scaled_matmul` | Scaled matrix multiplication function. |
| `get_scaled_dot_general_config` | Get quantization configs for `scaled_dot_general`. |
| `scaled_dot_general` | Scaled dot general operation. |
| `log1mexp` | Numerically stable calculation of \(\log(1 - \exp(-x))\). |
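
A minimal sketch of how a few of these compose in a typical classification-plus-attention setting. The array shapes and values are illustrative only, and the `(batch, seq_len, num_heads, head_dim)` layout for `dot_product_attention` follows the convention described in its documentation.

```python
import jax
import jax.numpy as jnp

# Classification-style usage: logits -> probabilities, labels -> one-hot targets.
logits = jnp.array([[2.0, 1.0, 0.1],
                    [0.5, 2.5, 0.0]])
probs = jax.nn.softmax(logits, axis=-1)          # rows sum to 1
log_probs = jax.nn.log_softmax(logits, axis=-1)  # numerically stable log of softmax
labels = jax.nn.one_hot(jnp.array([0, 1]), num_classes=3)
loss = -jnp.sum(labels * log_probs, axis=-1)     # per-example cross entropy

# logsumexp reduces over an axis without overflowing exp().
print(jax.nn.logsumexp(logits, axis=-1))

# dot_product_attention takes query, key, value arrays of shape
# (batch, seq_len, num_heads, head_dim); the sizes here are arbitrary.
B, T, N, H = 2, 5, 4, 16
rng = jax.random.key(0)
q, k, v = (jax.random.normal(jax.random.fold_in(rng, i), (B, T, N, H)) for i in range(3))
out = jax.nn.dot_product_attention(q, k, v)
print(out.shape)  # (2, 5, 4, 16)
```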
