Reputation: 168
This question is about the new paper Big Bird: Transformers for Longer Sequences, specifically the implementation of the sparse attention described in the supplementary material, part D. I am currently trying to implement it in PyTorch.
They propose a way to speed up the computation by blocking the original query and key matrices (see below).
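For reference, here is a minimal sketch of the blocking step. It assumes a single attention head with no batch dimension and a sequence length that is a multiple of the block size; the helper name `to_blocks` and the toy sizes are my own, not from the paper. Because the rows are contiguous, blocking is just a `reshape`:

```python
import torch

def to_blocks(x: torch.Tensor, block_size: int) -> torch.Tensor:
    """Split a (seq_len, dim) matrix into (seq_len // block_size, block_size, dim)."""
    seq_len, dim = x.shape
    if seq_len % block_size != 0:
        raise ValueError("seq_len must be a multiple of block_size")
    # Rows are contiguous, so block j holds original rows j*b .. j*b + b - 1.
    return x.reshape(seq_len // block_size, block_size, dim)

q = torch.randn(8, 4)  # toy sizes: 8 tokens, head dim 4
k = torch.randn(8, 4)
q_blocks = to_blocks(q, block_size=2)
k_blocks = to_blocks(k, block_size=2)
print(q_blocks.shape)  # torch.Size([4, 2, 4])
```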
When you do the matrix multiplication in step (b), you end up with something like this:
[Figure: result of the blocked multiplication in step (b)]
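A sketch of step (b) under the same assumptions (single head, toy sizes). Instead of looping over blocks, it keeps a hypothetical index table `block_idx` listing which key blocks each query block attends to, gathers those key blocks with advanced indexing, and computes all the small products in one `torch.einsum`. The 3-block sliding window with wrap-around at the edges is a simplification for illustration only; global and random blocks would just be extra columns in `block_idx`.

```python
import torch

def blocked_scores(q, k, block_idx, block_size):
    """Compute only the query-block x key-block products listed in block_idx.

    q, k:      (seq_len, dim)
    block_idx: (num_blocks, w) long tensor; row j lists the key blocks
               that query block j attends to (hypothetical layout)
    returns:   (num_blocks, w, block_size, block_size)
    """
    n, d = q.shape
    nb = n // block_size
    q_blocks = q.reshape(nb, block_size, d)
    k_blocks = k.reshape(nb, block_size, d)
    # Advanced indexing gathers the selected key blocks: (nb, w, block_size, d).
    k_gathered = k_blocks[block_idx]
    # One small product per (query block, selected key block) pair.
    return torch.einsum("jqd,jwkd->jwqk", q_blocks, k_gathered)

# Assumed pattern: sliding window over blocks j-1, j, j+1, wrapping at the edges.
nb = 4
window_idx = (torch.arange(nb)[:, None] + torch.tensor([-1, 0, 1])) % nb
q, k = torch.randn(8, 4), torch.randn(8, 4)
scores = blocked_scores(q, k, window_idx, block_size=2)
print(scores.shape)  # torch.Size([4, 3, 2, 2])
```

The result has shape (num_blocks, w, b, b): a stack of small tiles, which is the compact representation the question asks how to unpack.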
So I was wondering: how would you go from that representation (image above) to a sparse matrix in PyTorch? The paper just says "simply reshape the result", and I don't know an easy way to do that, especially when there are multiple blocks in different positions (see step (c) in the first image).
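One possible reading of "simply reshape the result", sketched under the same assumptions (single head, toy sizes, hypothetical names `blocks_to_dense` and `window_idx`, no repeated key block within a row of the index table): scatter the (b x b) tiles into an (nb, nb, b, b) grid of blocks using the index table, then permute and reshape the grid into a dense (n, n) matrix. Since the tile positions come from the index table, blocks in arbitrary positions are handled the same way. Note that this yields a dense matrix, not a sparse one.

```python
import torch

def blocks_to_dense(scores, block_idx, fill=float("-inf")):
    """Scatter blocked scores back into a dense (seq_len, seq_len) matrix.

    scores:    (nb, w, b, b); scores[j, i] is query block j times key block block_idx[j, i]
    block_idx: (nb, w) long tensor, assumed to have no repeats within a row
    Positions never computed get `fill` (-inf so a later softmax ignores them).
    """
    nb, w, b, _ = scores.shape
    # grid[j, kb] holds the (b x b) tile for query block j and key block kb.
    grid = scores.new_full((nb, nb, b, b), fill)
    rows = torch.arange(nb, device=scores.device).unsqueeze(1).expand(nb, w)
    grid[rows, block_idx] = scores
    # Reshape: (j, kb, qi, ki) -> (j, qi, kb, ki) -> (j*b + qi, kb*b + ki)
    return grid.permute(0, 2, 1, 3).reshape(nb * b, nb * b)

# Toy usage with an assumed wrap-around window of 3 blocks.
nb, b = 4, 2
q, k = torch.randn(nb * b, 4), torch.randn(nb * b, 4)
window_idx = (torch.arange(nb)[:, None] + torch.tensor([-1, 0, 1])) % nb
scores = torch.einsum("jqd,jwkd->jwqk", q.reshape(nb, b, 4), k.reshape(nb, b, 4)[window_idx])
dense = blocks_to_dense(scores, window_idx)
probs = torch.softmax(dense, dim=-1)  # every row has finite entries, so no NaNs
print(dense.shape)  # torch.Size([8, 8])
```

The permute before the final reshape is the key step: a plain reshape of the grid would interleave rows of different tiles.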
RESOLUTION: Hugging Face has an implementation of BigBird in PyTorch.
Upvotes: 2
Views: 2010
Reputation: 168
I ended up following the guidelines in the paper. To unpack the result, I use torch.sparse_coo_tensor.
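A sketch of that unpacking with `torch.sparse_coo_tensor`, again assuming a single head, toy sizes, and a hypothetical `block_idx` table with no repeated key block per row (duplicate coordinates would be summed when the tensor is coalesced). The global coordinates of every tile entry are computed by broadcasting: row `j*b + qi`, column `kb*b + ki`.

```python
import torch

def blocks_to_sparse(scores, block_idx):
    """Build a (seq_len, seq_len) COO sparse tensor from blocked scores.

    scores:    (nb, w, b, b)
    block_idx: (nb, w) long tensor of key-block indices (no repeats per row)
    """
    nb, w, b, _ = scores.shape
    dev = scores.device
    j = torch.arange(nb, device=dev).view(nb, 1, 1, 1)  # query block
    kb = block_idx.reshape(nb, w, 1, 1)                 # key block
    qi = torch.arange(b, device=dev).view(1, 1, b, 1)   # row inside the tile
    ki = torch.arange(b, device=dev).view(1, 1, 1, b)   # column inside the tile
    # Global coordinates of every stored score, broadcast to (nb, w, b, b).
    row = (j * b + qi).expand(nb, w, b, b)
    col = (kb * b + ki).expand(nb, w, b, b)
    indices = torch.stack([row.reshape(-1), col.reshape(-1)])
    return torch.sparse_coo_tensor(indices, scores.reshape(-1), (nb * b, nb * b))

# Toy usage with an assumed wrap-around window of 3 blocks.
nb, b = 4, 2
q, k = torch.randn(nb * b, 4), torch.randn(nb * b, 4)
window_idx = (torch.arange(nb)[:, None] + torch.tensor([-1, 0, 1])) % nb
scores = torch.einsum("jqd,jwkd->jwqk", q.reshape(nb, b, 4), k.reshape(nb, b, 4)[window_idx])
attn = blocks_to_sparse(scores, window_idx)
print(attn.coalesce().values().numel())  # 48
```

Unspecified entries behave as zeros under `to_dense()`, whereas `torch.sparse.softmax` treats them as -inf, which is what attention needs. Each stored value also carries two int64 indices, which is part of why the COO route stays memory-hungry.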
EDIT: Sparse tensors are still memory-hungry! The more efficient solution is described here
Upvotes: 2