Reputation: 35
I'm trying to obtain a matrix, where each element is calculated as follows:
X = torch.ones(batch_size, dim)
X_ = torch.ones(batch_size, dim)
Y = torch.ones(batch_size, dim)
M = torch.zeros(batch_size, batch_size)
for i in range(batch_size):
    for j in range(batch_size):
        M[i, j] = ((X[i] - X_[i] * Y[j])**2).sum()
Computing M element by element is very slow. Is there a way to replace the for loops with matrix multiplication?
Thanks.
Upvotes: 1
Views: 583
Reputation: 114986
If you want to sum() over dim, you can "lift" your 2D problem to 3D and sum there:
M = ((X[:, None, :] - X_[:, None, :] * Y[None, ...])**2).sum(dim=2)
How it works:
X[:, None, :] and X_[:, None, :] are 3D tensors of size (batch_size, 1, dim), and Y[None, ...] is of size (1, batch_size, dim).
When multiplying X_[:, None, :] * Y[None, ...], PyTorch broadcasts each dimension of size 1 to match the other operand, giving a result of size (batch_size, batch_size, dim).
Finally, you sum() only over the last dimension (dim=2) to get an output M of size (batch_size, batch_size).
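To make the shapes concrete, here is a small sketch that prints the intermediate sizes and compares the broadcasted result against the original double loop. The sizes (batch_size=4, dim=3) and the random inputs are my own illustrative choices, not from the question.

```python
import torch

torch.manual_seed(0)
batch_size, dim = 4, 3  # illustrative sizes (assumption, not from the question)

X = torch.randn(batch_size, dim)
X_ = torch.randn(batch_size, dim)
Y = torch.randn(batch_size, dim)

# Indexing with None inserts a new axis of size 1
print(X[:, None, :].shape)  # torch.Size([4, 1, 3])
print(Y[None, ...].shape)   # torch.Size([1, 4, 3])

# Broadcasting expands the size-1 axes: (batch_size, batch_size, dim)
diff = X[:, None, :] - X_[:, None, :] * Y[None, ...]
print(diff.shape)           # torch.Size([4, 4, 3])

# Sum over the last axis to get (batch_size, batch_size)
M = (diff ** 2).sum(dim=2)

# Reference: the double loop from the question
M_loop = torch.zeros(batch_size, batch_size)
for i in range(batch_size):
    for j in range(batch_size):
        M_loop[i, j] = ((X[i] - X_[i] * Y[j]) ** 2).sum()

# Compare with a small tolerance, since float summation order may differ
print(torch.allclose(M, M_loop, atol=1e-5))
```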
The trick here is to take advantage of broadcasting.
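Since the question asked specifically about matrix multiplication, one possible alternative is to expand the square: (x - x'y)^2 = x^2 - 2xx'y + x'^2 y^2, summed over dim. The sketch below is my own rearrangement, not part of the answer above. It avoids allocating the (batch_size, batch_size, dim) intermediate, which can matter for large batches, at the cost of slightly different floating-point rounding. The sizes and random inputs are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
batch_size, dim = 4, 3  # illustrative sizes (assumption)

X = torch.randn(batch_size, dim)
X_ = torch.randn(batch_size, dim)
Y = torch.randn(batch_size, dim)

# Term 1: sum_d X[i, d]^2, shape (batch_size, 1), broadcast across columns
term1 = (X ** 2).sum(dim=1, keepdim=True)
# Term 2: sum_d X[i, d] * X_[i, d] * Y[j, d], shape (batch_size, batch_size)
term2 = (X * X_) @ Y.T
# Term 3: sum_d X_[i, d]^2 * Y[j, d]^2, shape (batch_size, batch_size)
term3 = (X_ ** 2) @ (Y ** 2).T

M_matmul = term1 - 2 * term2 + term3

# Compare against the broadcasted version
M_broadcast = ((X[:, None, :] - X_[:, None, :] * Y[None, ...]) ** 2).sum(dim=2)
print(torch.allclose(M_matmul, M_broadcast, atol=1e-4))
```

Because the expanded form subtracts terms of similar magnitude, it can lose a little precision compared with the direct difference, so compare with a tolerance rather than exact equality.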
Upvotes: 2