NumPy provides the very useful tensordot function. It allows you to compute the product of two ndarrays along any axes (whose sizes match). I'm having a hard time finding anything similar in PyTorch: torch.mm works only with 2D tensors, and torch.matmul has some undesirable broadcasting properties.
Am I missing something? Am I really meant to reshape the arrays to mimic the products I want using mm?
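For reference, here is a minimal sketch of the reshape approach the question describes. It permutes the contracted axes of a to the back and those of b to the front, flattens them into a single dimension, and calls torch.mm. The helper name tensordot_via_mm is hypothetical and exists only for illustration. It assumes the paired axes have matching sizes, as tensordot itself requires.

```python
import torch


def tensordot_via_mm(a, b, axes_a, axes_b):
    """Hypothetical helper: contract a's axes_a with b's axes_b using torch.mm."""
    # Axes that are NOT contracted survive into the output.
    free_a = [d for d in range(a.dim()) if d not in axes_a]
    free_b = [d for d in range(b.dim()) if d not in axes_b]

    # a: free axes first, contracted axes last.
    # b: contracted axes first (same pairing order), free axes last.
    a_perm = a.permute(*free_a, *axes_a)
    b_perm = b.permute(*axes_b, *free_b)

    # Output shape is the free axes of a followed by the free axes of b.
    out_shape = [a.shape[d] for d in free_a] + [b.shape[d] for d in free_b]

    # k = total number of elements summed over.
    k = 1
    for d in axes_a:
        k *= a.shape[d]

    # (free_a_size, k) @ (k, free_b_size); reshape copies if non-contiguous.
    a2 = a_perm.reshape(-1, k)
    b2 = b_perm.reshape(k, -1)
    return torch.mm(a2, b2).reshape(out_shape)


a = torch.arange(36.).reshape(3, 4, 3)
b = torch.arange(24.).reshape(4, 3, 2)
c = tensordot_via_mm(a, b, [1, 0], [0, 1])
print(c)  # values: [[2640, 2838], [2772, 2982], [2904, 3126]]
```

This works, but the bookkeeping (permute order, flattened size, output shape) is exactly what a built-in tensordot saves you from writing by hand.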
The original answer is totally correct, but as an update: PyTorch now supports tensordot natively. The call signature is the same as NumPy's, except that the axes argument is named dims.
import torch
import numpy as np
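a = np.arange(36.).reshape(3,4,3)  # shape (3, 4, 3)
b = np.arange(24.).reshape(4,3,2)  # shape (4, 3, 2)
# Contract a's axis 1 with b's axis 0 and a's axis 0 with b's axis 1 -> shape (3, 2)
c = np.tensordot(a, b, axes=([1,0],[0,1]))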
print(c)
# [[ 2640. 2838.] [ 2772. 2982.] [ 2904. 3126.]]
a = torch.from_numpy(a)
b = torch.from_numpy(b)
# Same contraction; PyTorch calls the argument dims instead of axes
c = torch.tensordot(a, b, dims=([1,0],[0,1]))
print(c)
# tensor([[ 2640., 2838.], [ 2772., 2982.], [ 2904., 3126.]], dtype=torch.float64)
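As with NumPy's axes, dims can also be a single integer N, meaning "contract the last N axes of the first tensor with the first N axes of the second". A small sketch (the shapes are arbitrary, chosen only for illustration) checking this against an explicit torch.einsum and, for 2D inputs, against torch.mm:

```python
import torch

torch.manual_seed(0)

# Arbitrary shapes: the last two axes of x (4, 5) match the first two of y.
x = torch.randn(3, 4, 5)
y = torch.randn(4, 5, 6)

# dims=2: sum over x's last 2 axes paired with y's first 2 axes -> shape (3, 6)
z = torch.tensordot(x, y, dims=2)
print(z.shape)  # torch.Size([3, 6])

# The same contraction spelled out with einsum.
z_ref = torch.einsum("ijk,jkl->il", x, y)
print(torch.allclose(z, z_ref, atol=1e-5))

# For 2D inputs, dims=1 is ordinary matrix multiplication.
m = torch.randn(3, 4)
n = torch.randn(4, 2)
print(torch.allclose(torch.tensordot(m, n, dims=1), torch.mm(m, n), atol=1e-6))
```

The integer form covers the common "last axes of one, first axes of the other" case; the pair-of-lists form is needed when the contracted axes are elsewhere or in a different order, as in the example above.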
As mentioned by @McLawrence, this feature is currently being discussed (issue thread).
In the meantime, you could consider torch.einsum(), e.g.:
import torch
import numpy as np
a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0],[0,1]))
print(c)
# [[ 2640. 2838.] [ 2772. 2982.] [ 2904. 3126.]]
a = torch.from_numpy(a)
b = torch.from_numpy(b)
# i and j appear in both inputs but not after "->", so they are summed over
c = torch.einsum("ijk,jil->kl", a, b)
print(c)
# tensor([[ 2640., 2838.], [ 2772., 2982.], [ 2904., 3126.]], dtype=torch.float64)
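The einsum subscripts encode the same pairing as axes=([1,0],[0,1]): a is labelled ijk and b is labelled jil, so a's axis 1 (j) meets b's axis 0, a's axis 0 (i) meets b's axis 1, and the labels missing from the output kl are summed over. As a sketch, that translation can be automated. The helper tensordot_via_einsum below is hypothetical, written only to show how tensordot-style axis lists map onto an einsum equation; it assumes at most 26 axes in total (one lowercase letter each).

```python
import string

import torch


def tensordot_via_einsum(a, b, axes_a, axes_b):
    """Hypothetical helper: build the einsum equation equivalent to
    tensordot(a, b, dims=(axes_a, axes_b)) and evaluate it."""
    letters = iter(string.ascii_lowercase)  # assumes <= 26 axes in total
    # Every axis of a gets its own label.
    labels_a = [next(letters) for _ in range(a.dim())]
    # Contracted axes of b reuse the label of their partner axis in a.
    labels_b = [None] * b.dim()
    for da, db in zip(axes_a, axes_b):
        labels_b[db] = labels_a[da]
    # Remaining (free) axes of b get fresh labels.
    for d in range(b.dim()):
        if labels_b[d] is None:
            labels_b[d] = next(letters)
    # Output: free axes of a in order, then free axes of b in order.
    out = [labels_a[d] for d in range(a.dim()) if d not in axes_a]
    out += [labels_b[d] for d in range(b.dim()) if d not in axes_b]
    eq = "".join(labels_a) + "," + "".join(labels_b) + "->" + "".join(out)
    return eq, torch.einsum(eq, a, b)


a = torch.arange(36.).reshape(3, 4, 3)
b = torch.arange(24.).reshape(4, 3, 2)
eq, c = tensordot_via_einsum(a, b, [1, 0], [0, 1])
print(eq)  # abc,bad->cd  (same structure as "ijk,jil->kl")
```

The generated equation differs from the hand-written one only in the letters chosen, which einsum does not care about.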