jax.numpy.cov
- jax.numpy.cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None, dtype=None)
Estimate the weighted sample covariance.
JAX implementation of numpy.cov().

The covariance \(C_{ij}\) between variable i and variable j is defined as

\[cov[X_i, X_j] = E[(X_i - E[X_i])(X_j - E[X_j])]\]

Given an array of N observations of the variables \(X_i\) and \(X_j\), this can be estimated via the sample covariance:

\[C_{ij} = \frac{1}{N - 1} \sum_{n=1}^N (X_{in} - \overline{X_i})(X_{jn} - \overline{X_j})\]

where \(\overline{X_i} = \frac{1}{N} \sum_{k=1}^N X_{ik}\) is the mean of the observations.
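The sample covariance formula above can be checked by hand. The following is a minimal sketch, assuming a small made-up data array `X` with variables in rows (the default `rowvar=True` layout); the names `X`, `dev`, `C_manual`, and `C_jax` are illustrative only.

```python
import jax.numpy as jnp

# Two variables (rows), four observations each (columns). Illustrative data.
X = jnp.array([[1.0, 2.0, 4.0, 7.0],
               [2.0, 1.0, 5.0, 3.0]])
N = X.shape[1]

# Mean of each variable over its N observations.
mean = X.mean(axis=1, keepdims=True)

# Centered observations: dev[i, n] = X[i, n] - mean[i].
dev = X - mean

# C[i, j] = sum_n dev[i, n] * dev[j, n] / (N - 1)
C_manual = dev @ dev.T / (N - 1)

# Should agree with the library result up to floating-point error.
C_jax = jnp.cov(X)
print(jnp.allclose(C_manual, C_jax, atol=1e-5))
```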
- Parameters:
  - m (ArrayLike) – array of shape (M, N) (if rowvar is True), or (N, M) (if rowvar is False) representing N observations of M variables. m may also be one-dimensional, representing N observations of a single variable.
  - y (ArrayLike | None) – optional set of additional observations, with the same form as m. If specified, then y is combined with m, i.e. for the default rowvar=True case, m becomes jnp.vstack([m, y]).
  - rowvar (bool) – if True (default) then each row of m represents a variable. If False, then each column represents a variable.
  - bias (bool) – if False (default) then normalize the covariance by N-1. If True, then normalize the covariance by N.
  - ddof (int | None) – specify the degrees of freedom. Defaults to 1 if bias is False, or to 0 if bias is True.
  - fweights (ArrayLike | None) – optional array of integer frequency weights of shape (N,). This is an absolute weight specifying the number of times each observation is included in the computation.
  - aweights (ArrayLike | None) – optional array of observation weights of shape (N,). This is a relative weight specifying the “importance” of each observation. In the ddof=0 case, it is equivalent to assigning probabilities to each observation.
  - dtype (DTypeLike | None) – optional data type of the result. Must be a float or complex type; if not specified, it will be determined based on the dtype of the input.
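The parameter descriptions above imply a few equivalences that can be checked directly. This is a hedged sketch on made-up data: `X`, `fw`, and `p` are illustrative names and values, not part of the API. The probability-weighted comparison assumes that `p` sums to one.

```python
import jax.numpy as jnp

# Two variables (rows), four observations (columns). Illustrative data.
X = jnp.array([[1.0, 2.0, 4.0, 7.0],
               [2.0, 1.0, 5.0, 3.0]])
default = jnp.cov(X)

# rowvar=False treats columns as variables, so the transposed array
# should give the same result as the default layout.
from_columns = jnp.cov(X.T, rowvar=False)

# bias=True normalizes by N, which corresponds to ddof=0.
biased = jnp.cov(X, bias=True)
ddof0 = jnp.cov(X, ddof=0)

# Integer frequency weights count how often each observation occurs,
# so they should match explicitly repeating the columns.
fw = jnp.array([1, 2, 1, 3])
freq_weighted = jnp.cov(X, fweights=fw)
repeated = jnp.cov(jnp.repeat(X, fw, axis=1))

# With ddof=0, aweights act like probabilities assigned to observations.
p = jnp.array([0.1, 0.2, 0.3, 0.4])
prob_weighted = jnp.cov(X, aweights=p, ddof=0)
mu = X @ p                      # probability-weighted mean of each variable
dev = X - mu[:, None]
expected = (dev * p) @ dev.T    # E_p[(X_i - mu_i)(X_j - mu_j)]
```

Each pair (`default`/`from_columns`, `biased`/`ddof0`, `freq_weighted`/`repeated`, `prob_weighted`/`expected`) is expected to agree up to floating-point error.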
- Returns:
A covariance matrix of shape (M, M), or a scalar with shape () if M=1.
- Return type:
See also
jax.numpy.corrcoef(): compute the correlation coefficient, a normalized version of the covariance matrix.
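To illustrate what "normalized version of the covariance matrix" means, the sketch below divides each covariance entry by the product of the two standard deviations and compares the result with `jax.numpy.corrcoef()`. The data and the names `std` and `R_manual` are illustrative assumptions.

```python
import jax.numpy as jnp

# Two variables (rows), four observations (columns). Illustrative data.
X = jnp.array([[1.0, 2.0, 4.0, 7.0],
               [2.0, 1.0, 5.0, 3.0]])

C = jnp.cov(X)

# Standard deviation of each variable, taken from the diagonal of C.
std = jnp.sqrt(jnp.diag(C))

# R[i, j] = C[i, j] / (std[i] * std[j])
R_manual = C / jnp.outer(std, std)

R = jnp.corrcoef(X)
print(jnp.allclose(R_manual, R, atol=1e-5))
```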
Examples
Consider these observations of two variables that correlate perfectly. The covariance matrix in this case is a 2x2 matrix of ones:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([[0, 1, 2],
...                [0, 1, 2]])
>>> jnp.cov(x)
Array([[1., 1.],
       [1., 1.]], dtype=float32)
Now consider these observations of two variables that are perfectly anti-correlated. The covariance matrix in this case has -1 in the off-diagonal:

>>> x = jnp.array([[-1, 0, 1],
...                [1, 0, -1]])
>>> jnp.cov(x)
Array([[ 1., -1.],
       [-1.,  1.]], dtype=float32)
Equivalently, these sequences can be specified as separate arguments, in which case they are stacked before continuing the computation.

>>> x = jnp.array([-1, 0, 1])
>>> y = jnp.array([1, 0, -1])
>>> jnp.cov(x, y)
Array([[ 1., -1.],
       [-1.,  1.]], dtype=float32)
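The stacking behaviour can be confirmed with a short check. This sketch assumes the default `rowvar=True` case, where passing `y` is described as equivalent to `jnp.vstack([m, y])`; the variable names `separate` and `stacked` are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([-1.0, 0.0, 1.0])
y = jnp.array([1.0, 0.0, -1.0])

# Passing y as a second argument ...
separate = jnp.cov(x, y)

# ... should match stacking the two sequences as rows first.
stacked = jnp.cov(jnp.vstack([x, y]))

print(jnp.allclose(separate, stacked))
```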
In general, the entries of the covariance matrix may be any positive or negative real value. For example, here is the covariance of 100 points drawn from a 3-dimensional standard normal distribution:

>>> key = jax.random.key(0)
>>> x = jax.random.normal(key, shape=(3, 100))
>>> with jnp.printoptions(precision=2):
...   print(jnp.cov(x))
[[0.9  0.03 0.1 ]
 [0.03 1.   0.01]
 [0.1  0.01 0.85]]
