Ruchit Patel

Reputation: 795

Multi-dimensional tensor dot product in pytorch

I have two tensors of shape (8, 1, 128), as follows.

q_s.shape
Out[161]: torch.Size([8, 1, 128])

p_s.shape
Out[162]: torch.Size([8, 1, 128])

The two tensors above each represent a batch of eight 128-dimensional vectors. I want the dot product of each vector in q_s with the corresponding vector in p_s. How can I do this? I tried the torch.tensordot function as follows. It produces the values I need, but it also does extra work that I don't want. See the following example.

dt = torch.tensordot(q_s, p_s, dims=([1,2], [1,2]))

dt
Out[176]: 
tensor([[0.9051, 0.9156, 0.7834, 0.8726, 0.8581, 0.7858, 0.7881, 0.8063],
        [1.0235, 1.5533, 1.2155, 1.2048, 1.3963, 1.1310, 1.1724, 1.0639],
        [0.8762, 1.3490, 1.2923, 1.0926, 1.4703, 0.9566, 0.9658, 0.8558],
        [0.8136, 1.0611, 0.9131, 1.1636, 1.0969, 0.9443, 0.9587, 0.8521],
        [0.6104, 0.9369, 0.9576, 0.8773, 1.3042, 0.7900, 0.8378, 0.6136],
        [0.8623, 0.9678, 0.8163, 0.9727, 1.1161, 1.6464, 0.9765, 0.7441],
        [0.6911, 0.8392, 0.6931, 0.7325, 0.8239, 0.7757, 1.0456, 0.6657],
        [0.8493, 0.8174, 0.8041, 0.9013, 0.8003, 0.7451, 0.7408, 1.1771]],
       grad_fn=<AsStridedBackward>)

dt.shape
Out[177]: torch.Size([8, 8])

As we can see, this produces a tensor of shape (8, 8) with the dot products I want lying on the diagonal. Is there a way to directly obtain the smaller tensor of shape (8, 1) that contains only the diagonal elements of the result above? To be clear, the diagonal elements are exactly the batch-wise dot products I need: the element at index [0][0] is the dot product of q_s[0] and p_s[0], the element at index [1][1] is the dot product of q_s[1] and p_s[1], and so on.

Is there a better way to obtain the desired dot product in pytorch?

Upvotes: 4

Views: 3537

Answers (1)

BlackBear

Reputation: 22989

You can do it directly:

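import torch

a = torch.rand(8, 1, 128)
b = torch.rand(8, 1, 128)

# Multiply elementwise, then sum over the last two axes: one dot product per batch entry
torch.sum(a * b, dim=(1, 2))
# tensor([29.6896, 30.4994, 32.9577, 30.2220, 33.9913, 35.1095, 32.3631, 30.9153])

# Same values as the diagonal of the full (8, 8) tensordot result
torch.diag(torch.tensordot(a, b, dims=([1, 2], [1, 2])))
# tensor([29.6896, 30.4994, 32.9577, 30.2220, 33.9913, 35.1095, 32.3631, 30.9153])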

If you set dim=2 in the sum instead, you will get a tensor with shape (8, 1).
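As a rough sketch of the (8, 1) variant, the snippet below reuses the names q_s and p_s from the question with random stand-in data (an assumption, since the original values are not given). It also shows two other ways of writing the same contraction, torch.einsum and torch.bmm, which are not part of the answer above and are included only for comparison.

```python
import torch

torch.manual_seed(0)
# Stand-in data with the shapes from the question (values are made up).
q_s = torch.rand(8, 1, 128)
p_s = torch.rand(8, 1, 128)

# Elementwise product, summed over the last (feature) axis only.
dots_sum = torch.sum(q_s * p_s, dim=2)                       # shape (8, 1)

# Same contraction as an einsum: keep b and i, sum over j.
dots_einsum = torch.einsum("bij,bij->bi", q_s, p_s)          # shape (8, 1)

# Batched matrix product: (8, 1, 128) @ (8, 128, 1) -> (8, 1, 1).
dots_bmm = torch.bmm(q_s, p_s.transpose(1, 2)).squeeze(-1)   # shape (8, 1)

# Reference: diagonal of the full (8, 8) tensordot result from the question.
reference = torch.diag(
    torch.tensordot(q_s, p_s, dims=([1, 2], [1, 2]))
).unsqueeze(1)                                               # shape (8, 1)
```

All three variants avoid computing the 56 off-diagonal entries that tensordot produces and then discards.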

Upvotes: 6
