gspr

Reputation: 11227

Product of PyTorch tensors along arbitrary axes à la NumPy's `tensordot`

NumPy provides the very useful `tensordot` function. It computes the product of two ndarrays summed over any chosen pairs of axes, as long as the paired axes have matching sizes. I'm having a hard time finding anything similar in PyTorch: `mm` works only with 2D tensors, and `matmul` has some undesirable broadcasting properties.

Am I missing something? Am I really meant to reshape the tensors to mimic the products I want using `mm`?

Upvotes: 3

Views: 7218

Answers (2)

Jacob Stern

Reputation: 4587

The original answer is correct, but as an update: PyTorch now supports `tensordot` natively as `torch.tensordot`. The call signature is the same as NumPy's, except that the `axes` argument is called `dims`.

```python
import torch
import numpy as np

a = np.arange(36.).reshape(3, 4, 3)
b = np.arange(24.).reshape(4, 3, 2)
c = np.tensordot(a, b, axes=([1, 0], [0, 1]))
print(c)
# [[ 2640.  2838.] [ 2772.  2982.] [ 2904.  3126.]]

a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.tensordot(a, b, dims=([1, 0], [0, 1]))
print(c)
# tensor([[ 2640.,  2838.], [ 2772.,  2982.], [ 2904.,  3126.]], dtype=torch.float64)
```
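Like NumPy's `axes`, `dims` also accepts a single integer `d`, which contracts the last `d` axes of the first tensor with the first `d` axes of the second. A small sketch comparing the integer form with the explicit list form and with NumPy; the shapes here are illustrative, chosen so that the trailing axes of `a` match the leading axes of `b`:

```python
import numpy as np
import torch

a = torch.arange(24.0).reshape(2, 3, 4)
b = torch.arange(24.0).reshape(3, 4, 2)

# dims=2: contract the last two axes of a (sizes 3, 4) with the first two axes of b (sizes 3, 4).
c_int = torch.tensordot(a, b, dims=2)

# The same contraction written with explicit axis lists.
c_lists = torch.tensordot(a, b, dims=([1, 2], [0, 1]))

print(c_int.shape)  # torch.Size([2, 2])
print(torch.equal(c_int, c_lists))  # True
# NumPy's integer axes argument has the same meaning.
print(np.allclose(np.tensordot(a.numpy(), b.numpy(), axes=2), c_int.numpy()))  # True
```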

Upvotes: 4

benjaminplanche

Reputation: 15119

As mentioned by @McLawrence, this feature is currently being discussed (issue thread).

In the meantime, you could consider `torch.einsum()`, e.g.:

```python
import torch
import numpy as np

a = np.arange(36.).reshape(3, 4, 3)
b = np.arange(24.).reshape(4, 3, 2)
c = np.tensordot(a, b, axes=([1, 0], [0, 1]))
print(c)
# [[ 2640.  2838.] [ 2772.  2982.] [ 2904.  3126.]]

a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.einsum("ijk,jil->kl", a, b)
print(c)
# tensor([[ 2640.,  2838.], [ 2772.,  2982.], [ 2904.,  3126.]], dtype=torch.float64)
```
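In the einsum equation, a letter shared by both operands is summed over: `j` pairs axis 1 of `a` with axis 0 of `b`, `i` pairs axis 0 of `a` with axis 1 of `b`, and the output keeps the free axes `k` and `l`. That is exactly the pairing in `axes=([1,0],[0,1])`. The equation can also be generated from tensordot-style axis lists. `tensordot_via_einsum` here is a hypothetical helper (not part of PyTorch), a minimal sketch assuming the two tensors have at most 26 axes combined and that paired axes have matching sizes:

```python
import string

import torch


def tensordot_via_einsum(a, b, axes):
    """Hypothetical helper: emulate torch.tensordot(a, b, dims=axes) with einsum.

    `axes` is a pair of equal-length lists (a_axes, b_axes); negative indices are allowed.
    """
    # Normalize negative axis indices.
    a_axes = [d % a.dim() for d in axes[0]]
    b_axes = [d % b.dim() for d in axes[1]]
    # Give every axis of a and b its own letter first.
    letters = iter(string.ascii_lowercase)
    a_sub = [next(letters) for _ in range(a.dim())]
    b_sub = [next(letters) for _ in range(b.dim())]
    # Paired axes share a letter, so einsum sums over them.
    for da, db in zip(a_axes, b_axes):
        b_sub[db] = a_sub[da]
    # Output: free axes of a, then free axes of b, in their original order (as tensordot does).
    out = [a_sub[d] for d in range(a.dim()) if d not in a_axes]
    out += [b_sub[d] for d in range(b.dim()) if d not in b_axes]
    equation = "".join(a_sub) + "," + "".join(b_sub) + "->" + "".join(out)
    return torch.einsum(equation, a, b)


a = torch.arange(36.0).reshape(3, 4, 3)
b = torch.arange(24.0).reshape(4, 3, 2)
c = tensordot_via_einsum(a, b, ([1, 0], [0, 1]))  # builds the equation "abc,baf->cf"
print(c.tolist())  # [[2640.0, 2838.0], [2772.0, 2982.0], [2904.0, 3126.0]]
print(torch.equal(c, torch.tensordot(a, b, dims=([1, 0], [0, 1]))))  # True
```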

Upvotes: 3
