Thyrix Yang

Reputation: 105

How to vectorize the scatter-matmul operation

I have n matrices w1, w2, w3, ..., wn with shapes k*n1, k*n2, k*n3, ..., k*nn, and n matrices x1, x2, x3, ..., xn with shapes n1*m, n2*m, n3*m, ..., nn*m.
I want to get w1@x1, w2@x2, w3@x3 ... respectively.

The result is n matrices of shape k*m each, which can be concatenated into a single large matrix of shape (k*n)*m.

Multiplying them one by one will be slow. How can this operation be vectorized?

Note: the inputs can be given as a single k*(n1+n2+n3+...+nn) matrix and a single (n1+n2+n3+...+nn)*m matrix, together with a batch index that indicates which columns and rows belong to which submatrix.

This operation is related to the scatter operations implemented in pytorch_scatter, so I refer to it as "scatter_matmul".
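A minimal sketch of the setup, with made-up sizes (k=4, m=5, block widths ns=[2, 3, 6] are chosen only for illustration): the per-pair loop is the slow baseline, and the flattened inputs plus a batch index correspond to the representation described in the note above.

import torch

k, m = 4, 5
ns = [2, 3, 6]                          # hypothetical block widths n_i
ws = [torch.randn(k, n) for n in ns]    # w_i with shape k*n_i
xs = [torch.randn(n, m) for n in ns]    # x_i with shape n_i*m

# Slow baseline: n separate matmuls, stacked into a (k*n)*m result.
y_loop = torch.cat([w @ x for w, x in zip(ws, xs)], dim=0)

# Flattened representation from the note: one k*(n1+...+nn) matrix,
# one (n1+...+nn)*m matrix, and a batch index marking the submatrices.
W_flat = torch.cat(ws, dim=1)
X_flat = torch.cat(xs, dim=0)
batch = torch.repeat_interleave(torch.arange(len(ns)), torch.tensor(ns))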

Upvotes: 3

Views: 541

Answers (2)

Felipe Mello

Reputation: 395

Please take a look at this link. Apparently DGL is already working on something similar: https://github.com/dmlc/dgl/pull/3641

Upvotes: 0

Shai

Reputation: 114866

You can vectorize your operation by building a large block-diagonal matrix W of shape (k*n)*(n1+...+nn), with the w_i matrices as the blocks on the diagonal. Then vertically stack all x_i matrices into a matrix X of shape (n1+...+nn)*m. Multiplying the block-diagonal W by the vertical stack X:

Y = W @ X

results in Y of shape (k*n)*m, which is exactly the concatenated large matrix you are seeking.
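A quick sanity check of this construction, as a sketch with made-up sizes (torch.block_diag assembles the block-diagonal matrix):

import torch

k, m = 4, 5
ns = [2, 3, 6]                          # hypothetical block widths
ws = [torch.randn(k, n) for n in ns]
xs = [torch.randn(n, m) for n in ns]

W = torch.block_diag(*ws)               # shape (k*n) x (n1+n2+n3)
X = torch.cat(xs, dim=0)                # shape (n1+n2+n3) x m
Y = W @ X                               # shape (k*n) x m

# Equals the per-pair products stacked vertically.
assert torch.allclose(Y, torch.cat([w @ x for w, x in zip(ws, xs)]), atol=1e-6)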

If the block-diagonal matrix W is too large to fit into memory, you may consider making W sparse and computing the product with torch.sparse.mm.
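A sketch of that sparse variant, assuming the inputs are given as lists of dense blocks (the function name scatter_matmul_sparse is hypothetical). The COO indices of the block-diagonal W are built per block, so the dense W is never materialized:

import torch

def scatter_matmul_sparse(ws, xs):
    # ws: list of k*n_i matrices, xs: list of n_i*m matrices.
    k = ws[0].shape[0]
    rows, cols, vals = [], [], []
    row_off, col_off = 0, 0
    for w in ws:
        # Index grid for this diagonal block, shifted to its offset.
        r, c = torch.meshgrid(torch.arange(k), torch.arange(w.shape[1]),
                              indexing="ij")
        rows.append((r + row_off).reshape(-1))
        cols.append((c + col_off).reshape(-1))
        vals.append(w.reshape(-1))
        row_off += k
        col_off += w.shape[1]
    indices = torch.stack([torch.cat(rows), torch.cat(cols)])
    W = torch.sparse_coo_tensor(indices, torch.cat(vals),
                                size=(row_off, col_off))
    X = torch.cat(xs, dim=0)
    return torch.sparse.mm(W, X)        # shape (k*n) x m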

Upvotes: 2
