torch.diag_embed

torch.diag_embed(input, offset=0, dim1=-2, dim2=-1) → Tensor

Creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) are filled by input. To facilitate creating batched diagonal matrices, the 2D planes formed by the last two dimensions of the returned tensor are chosen by default.
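As a minimal sketch of the default behaviour described above, the snippet below treats a 2×3 input as a batch of two length-3 vectors and checks that each output plane matches what torch.diag would build from the corresponding row. The variable names (x, out, same) are illustrative only.

```python
import torch

# A batch of 2 vectors, each of length 3.
x = torch.arange(6.0).reshape(2, 3)

# With the default dim1=-2, dim2=-1, each row of x becomes the main
# diagonal of its own 3x3 matrix in the last two dimensions.
out = torch.diag_embed(x)
print(out.shape)  # torch.Size([2, 3, 3])

# Each batch entry should equal the 2D diagonal matrix built by torch.diag.
same = all(torch.equal(out[i], torch.diag(x[i])) for i in range(x.shape[0]))
print(same)
```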

The argument offset controls which diagonal to consider:

  • If offset = 0, it is the main diagonal.

  • If offset > 0, it is above the main diagonal.

  • If offset < 0, it is below the main diagonal.

The size of the new matrix is calculated so that the specified diagonal has the size of the last input dimension. Note that for offset other than 0, the order of dim1 and dim2 matters. Exchanging them is equivalent to changing the sign of offset.
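The following sketch illustrates both points: a nonzero offset enlarges each plane so that the chosen diagonal still holds every element of the last input dimension, and swapping dim1 and dim2 has the same effect as negating offset. The values and names (x, n, k, up, a, b) are chosen here for illustration.

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
n = x.shape[-1]  # length of the last input dimension
k = 2            # illustrative offset

# The k-th superdiagonal must hold n elements, so each plane
# grows to (n + |k|) x (n + |k|).
up = torch.diag_embed(x, offset=k)
print(up.shape)  # torch.Size([2, 5, 5])

# Exchanging dim1 and dim2 is equivalent to flipping the sign of offset.
a = torch.diag_embed(x, offset=1, dim1=-2, dim2=-1)
b = torch.diag_embed(x, offset=-1, dim1=-1, dim2=-2)
swap_equals_sign_flip = torch.equal(a, b)
print(swap_equals_sign_flip)
```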

Applying torch.diagonal() to the output of this function with the same arguments yields a matrix identical to input. However, torch.diagonal() has different default dimensions, so those need to be explicitly specified.
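A short sketch of this round trip, under the assumption that torch.diagonal() keeps its documented defaults of dim1=0 and dim2=1: passing the same offset, dim1, and dim2 recovers the input, while relying on the defaults of torch.diagonal() reads a different pair of dimensions.

```python
import torch

x = torch.arange(6.0).reshape(2, 3)

# Round trip with explicit, matching arguments.
out = torch.diag_embed(x, offset=1, dim1=0, dim2=2)
back = torch.diagonal(out, offset=1, dim1=0, dim2=2)
round_trip_ok = torch.equal(back, x)

# With diag_embed's defaults the planes are the last two dimensions,
# but torch.diagonal defaults to dim1=0, dim2=1, so the dimensions
# must be passed explicitly to undo the embedding.
out_default = torch.diag_embed(x)
wrong = torch.diagonal(out_default)                   # reads dims 0 and 1
right = torch.diagonal(out_default, dim1=-2, dim2=-1)  # reads the last two dims

print(round_trip_ok, wrong.shape, torch.equal(right, x))
```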

Parameters:
  • input (Tensor) – the input tensor. Must be at least 1-dimensional.

  • offset (int, optional) – which diagonal to consider. Default: 0 (main diagonal).

  • dim1 (int, optional) – first dimension with respect to which to take diagonal. Default: -2.

  • dim2 (int, optional) – second dimension with respect to which to take diagonal. Default: -1.

Example:

>>> a = torch.randn(2, 3)
>>> torch.diag_embed(a)
tensor([[[ 1.5410,  0.0000,  0.0000],
         [ 0.0000, -0.2934,  0.0000],
         [ 0.0000,  0.0000, -2.1788]],

        [[ 0.5684,  0.0000,  0.0000],
         [ 0.0000, -1.0845,  0.0000],
         [ 0.0000,  0.0000, -1.3986]]])

>>> torch.diag_embed(a, offset=1, dim1=0, dim2=2)
tensor([[[ 0.0000,  1.5410,  0.0000,  0.0000],
         [ 0.0000,  0.5684,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000, -0.2934,  0.0000],
         [ 0.0000,  0.0000, -1.0845,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000, -2.1788],
         [ 0.0000,  0.0000,  0.0000, -1.3986]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]])