
torch.addmm

torch.addmm(input, mat1, mat2, out_dtype=None, *, beta=1, alpha=1, out=None) → Tensor

Performs a matrix multiplication of the matrices mat1 and mat2. The matrix input is added to the final result.

If mat1 is a (n × m) tensor and mat2 is a (m × p) tensor, then input must be broadcastable with a (n × p) tensor and out will be a (n × p) tensor.

alpha and beta are scaling factors on the matrix-matrix product between mat1 and mat2 and the added matrix input, respectively.

out = β input + α (mat1 @ mat2)
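A minimal illustration of this formula (the tensor names below are made up for the example, not part of the API): with explicit beta and alpha, the call matches the scaled sum computed by hand.

import torch

# out = beta * input + alpha * (mat1 @ mat2)
inp = torch.randn(2, 4)
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 4)

out = torch.addmm(inp, mat1, mat2, beta=0.5, alpha=2.0)
reference = 0.5 * inp + 2.0 * (mat1 @ mat2)
print(torch.allclose(out, reference))  # True, up to floating-point rounding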

If beta is 0, then the content of input will be ignored, and nan and inf in it will not be propagated.
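A short sketch of the beta=0 case described above: even though input below is filled with nan, it is ignored and the result is just the mat1 @ mat2 product.

import torch

# input is all nan, but beta=0 means its content never reaches the result
inp = torch.full((2, 4), float("nan"))
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 4)

out = torch.addmm(inp, mat1, mat2, beta=0)
print(torch.isfinite(out).all())  # tensor(True): no nan propagated from input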

For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha must be real numbers; otherwise they should be integers.

This operation has support for arguments with sparse layouts. If input is sparse, the result will have the same layout, and if out is provided it must have the same layout as input.

Warning

Sparse support is a beta feature and some layout/dtype/device combinations may not be supported, or may not have autograd support. If you notice missing functionality please open a feature request.
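A minimal sparse sketch, assuming the combination below (a sparse CSR mat1 with dense input and mat2 on CPU) is supported in your build; per the warning above, not every layout/dtype/device combination is.

import torch

# Sparse CSR left operand times a dense right operand, plus a dense input.
# Since `input` is dense (strided), the result is dense as well.
inp = torch.randn(2, 4)
mat1 = torch.randn(2, 3).relu().to_sparse_csr()
mat2 = torch.randn(3, 4)

out = torch.addmm(inp, mat1, mat2)
print(out.layout)  # torch.strided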

This operator supports TensorFloat32.

On certain ROCm devices, when using float16 inputs this operator will use different precision for backward.

Parameters
  • input (Tensor) – matrix to be added

  • mat1 (Tensor) – the first matrix to be matrix multiplied

  • mat2 (Tensor) – the second matrix to be matrix multiplied

  • out_dtype (dtype, optional) – the dtype of the output tensor. Supported only on CUDA, and only for a torch.float32 output given torch.float16/torch.bfloat16 input dtypes (a short sketch appears after the example below).

Keyword Arguments
  • beta (Number, optional) – multiplier for input (β)

  • alpha (Number, optional) – multiplier for mat1 @ mat2 (α)

  • out (Tensor, optional) – the output tensor.

Example:

>>> M = torch.randn(2, 3)
>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.addmm(M, mat1, mat2)
tensor([[-4.8716,  1.4671, -1.3746],
        [ 0.7573, -3.9555, -2.8681]])
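A sketch of the out_dtype parameter referenced above, assuming a CUDA device is available; as noted in the parameter description, it is supported only on CUDA and only for a torch.float32 output from torch.float16/torch.bfloat16 inputs.

import torch

if torch.cuda.is_available():
    # float16 inputs, float32 output requested via out_dtype
    M = torch.randn(2, 3, device="cuda", dtype=torch.float16)
    mat1 = torch.randn(2, 3, device="cuda", dtype=torch.float16)
    mat2 = torch.randn(3, 3, device="cuda", dtype=torch.float16)

    out = torch.addmm(M, mat1, mat2, out_dtype=torch.float32)
    print(out.dtype)  # torch.float32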