Reputation: 105
I have many matrices w1, w2, w3...wn with shapes (k*n1, k*n2, k*n3...k*nn) and x1, x2, x3...xn with shapes (n1*m, n2*m, n3*m...nn*m). I want to compute w1@x1, w2@x2, w3@x3... respectively. The result is n matrices of shape k*m each, which can be concatenated into one large matrix of shape (k*n)*m.
Multiplying them one by one is slow. How can I vectorize this operation?
Note: the input can be given as a single k*(n1+n2+n3+...+nn) matrix and a single (n1+n2+n3+...+nn)*m matrix, together with a batch index that indicates the submatrices. This operation is related to the scatter operations implemented in pytorch_scatter, so I refer to it as "scatter_matmul".
Upvotes: 3
Views: 541
Reputation: 395
Please take a look at this link; apparently DGL is already working on something similar: https://github.com/dmlc/dgl/pull/3641
Upvotes: 0
Reputation: 114866
You can vectorize your operation by creating a large block-diagonal matrix W of shape (n*k) x (n1+..+nn), where the w_i matrices are the blocks on the diagonal. Then vertically stack all x matrices into an X matrix of shape (n1+..+nn) x m. Multiplying the block-diagonal W with the vertical stack of all x matrices, X:
Y = W @ X
results in Y of shape (k*n) x m, which is exactly the concatenated large matrix you are seeking.
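For example, a minimal dense sketch using torch.block_diag (the concrete names and sizes are only for illustration):

import torch

k, m = 4, 5
sizes = [2, 3, 6]                                  # n1, n2, n3
w_list = [torch.randn(k, n_i) for n_i in sizes]
x_list = [torch.randn(n_i, m) for n_i in sizes]

W = torch.block_diag(*w_list)      # shape (n*k, n1+...+nn)
X = torch.cat(x_list, dim=0)       # shape (n1+...+nn, m)
Y = W @ X                          # shape (n*k, m)

# Same result as multiplying the pairs one by one and concatenating.
assert torch.allclose(Y, torch.cat([w @ x for w, x in zip(w_list, x_list)], dim=0), atol=1e-6)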
If the block-diagonal matrix W is too large to fit into memory, you may consider making W sparse and computing the product using torch.sparse.mm.
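One possible way to do that (a sketch under my own assumptions, building the block-diagonal directly in COO format so the zero blocks are never materialized):

import torch

k, m = 4, 5
sizes = [2, 3, 6]
w_list = [torch.randn(k, n_i) for n_i in sizes]
x_list = [torch.randn(n_i, m) for n_i in sizes]

rows, cols, vals = [], [], []
row_off, col_off = 0, 0
for w in w_list:
    r, c = w.shape
    # Row/column indices of block w inside the big block-diagonal matrix.
    rows.append(torch.arange(r).repeat_interleave(c) + row_off)
    cols.append(torch.arange(c).repeat(r) + col_off)
    vals.append(w.reshape(-1))
    row_off += r
    col_off += c

W_sparse = torch.sparse_coo_tensor(
    torch.stack([torch.cat(rows), torch.cat(cols)]),
    torch.cat(vals),
    size=(row_off, col_off),
)

X = torch.cat(x_list, dim=0)
Y = torch.sparse.mm(W_sparse, X)   # dense result of shape (n*k, m)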
Upvotes: 2