torch.tensordot

torch.tensordot(a, b, dims=2, out=None)
Returns a contraction of a and b over multiple dimensions.
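tensordot implements a generalized matrix product.

Parameters

- a (Tensor): left tensor to contract
- b (Tensor): right tensor to contract
- dims (int, or a pair of lists of ints): number of dimensions to contract, or explicit lists of dimensions to contract for a and b respectively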
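As a quick illustration of the "generalized matrix product" view, the sketch below checks three special cases on small matrices, using only torch.tensordot() and standard tensor operations: dims=1 reduces to an ordinary matrix product, dims=0 to an outer product, and dims=2 on two equally shaped matrices to a sum of all elementwise products. The tensors A and B are illustrative inputs, not part of the original page.

```python
import torch

# Small illustrative inputs: A is 2x3, B is 3x4.
A = torch.arange(6.).reshape(2, 3)
B = torch.arange(12.).reshape(3, 4)

# dims=1: contract the last dim of A with the first dim of B (ordinary matrix product).
mat = torch.tensordot(A, B, dims=1)
print(torch.equal(mat, A @ B))  # True

# dims=0: contract nothing, so the result is an outer product of shape (2, 3, 3, 4).
outer = torch.tensordot(A, B, dims=0)
print(outer.shape)

# dims=2 with A against itself: both dims are contracted, leaving a 0-dim tensor.
inner = torch.tensordot(A, A, dims=2)
print(inner.item())  # sum of the squares of 0..5, i.e. 55.0
```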
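When called with a non-negative integer argument dims = d, where a and b have m and n dimensions respectively, tensordot() contracts the last d dimensions of a with the first d dimensions of b and computes

$$
r_{i_0,\ldots,i_{m-d-1},\,j_d,\ldots,j_{n-1}} = \sum_{k_0,\ldots,k_{d-1}} a_{i_0,\ldots,i_{m-d-1},\,k_0,\ldots,k_{d-1}} \times b_{k_0,\ldots,k_{d-1},\,j_d,\ldots,j_{n-1}}.
$$

The result therefore has m + n - 2d dimensions.

When called with dims in list form, the dimensions listed for a and for b are contracted pairwise, in place of the last d dimensions of a and the first d dimensions of b. The sizes in these dimensions must match, but tensordot() will deal with broadcasted dimensions.

Examples: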
```python
>>> a = torch.arange(60.).reshape(3, 4, 5)
>>> b = torch.arange(24.).reshape(4, 3, 2)
>>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
tensor([[4400., 4730.],
        [4532., 4874.],
        [4664., 5018.],
        [4796., 5162.],
        [4928., 5306.]])

>>> # Requires a CUDA device; inputs (and hence outputs) are random.
>>> a = torch.randn(3, 4, 5, device='cuda')
>>> b = torch.randn(4, 5, 6, device='cuda')
>>> torch.tensordot(a, b, dims=2).cpu()
tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],
        [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],
        [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])

>>> a = torch.randn(3, 5, 4, 6)
>>> b = torch.randn(6, 4, 5, 3)
>>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0]))
tensor([[  7.7193,  -2.4867, -10.3204],
        [  1.5513, -14.4737,  -6.5113],
        [ -0.2850,   4.2573,  -3.5997]])
```
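To make the two forms of dims concrete, here is a minimal reference sketch that reproduces the contraction using only permute, reshape, and a single matrix multiplication. The function tensordot_reference is hypothetical and not part of PyTorch; it is one standard way to express a tensor contraction and is not claimed to mirror how torch.tensordot() is implemented internally. It assumes the contracted sizes match exactly, so it omits the broadcasted-dimension handling and the out argument.

```python
import math
import torch


def tensordot_reference(a, b, dims=2):
    """Hypothetical helper: contract a and b via permute + reshape + matmul.

    Assumes contracted sizes match exactly (no broadcasting) and
    does not support the `out` argument.
    """
    # Normalize dims into two explicit, paired lists of dimension indices.
    if isinstance(dims, int):
        dims_a = list(range(a.dim() - dims, a.dim()))  # last d dims of a
        dims_b = list(range(dims))                     # first d dims of b
    else:
        dims_a, dims_b = (list(d) for d in dims)
    dims_a = [d % a.dim() for d in dims_a]  # allow negative indices
    dims_b = [d % b.dim() for d in dims_b]
    if len(dims_a) != len(dims_b):
        raise ValueError("both dimension lists must have the same length")

    # Free (non-contracted) dims keep their original relative order.
    free_a = [i for i in range(a.dim()) if i not in dims_a]
    free_b = [i for i in range(b.dim()) if i not in dims_b]

    # Contracted dims of a go last and those of b go first, in paired order,
    # so flattening lines up matching index combinations k_0, ..., k_{d-1}.
    a_mat = a.permute(*free_a, *dims_a)
    b_mat = b.permute(*dims_b, *free_b)
    k = math.prod(a.shape[i] for i in dims_a)  # size of the flattened contracted axis

    # One 2-D matrix product performs the whole sum over the contracted indices.
    result = a_mat.reshape(-1, k) @ b_mat.reshape(k, -1)
    out_shape = [a.shape[i] for i in free_a] + [b.shape[i] for i in free_b]
    return result.reshape(out_shape)


# Usage: compare against torch.tensordot on the first documented example.
a = torch.arange(60.).reshape(3, 4, 5)
b = torch.arange(24.).reshape(4, 3, 2)
ref = tensordot_reference(a, b, dims=([1, 0], [0, 1]))
print(torch.equal(ref, torch.tensordot(a, b, dims=([1, 0], [0, 1]))))  # True
```

Reducing every case to one 2-D product also shows where the cost goes: the work is that of multiplying a matrix whose inner dimension is the product of all contracted sizes.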