tachyon

Reputation: 479

Custom PyTorch layer to apply an LSTM to each group

I have an N × F tensor of features and an N × 1 tensor of group indices. I want to design a custom PyTorch layer that applies an LSTM to each group, with the features sorted within each group. I mention an LSTM over sorted group features as an example; hypothetically, it could be anything that supports variable-length inputs or sequences. Please refer to the image below for a visual interpretation of the problem.


[Image: Grouped Operation]


The obvious approach would be calling an LSTM layer for each unique group, but that would be inefficient. Is there a better way to do it?

Upvotes: 0

Views: 395

Answers (1)

KonstantinosKokos

Reputation: 3453

You can certainly parallelize the LSTM application -- the problem is indexing the feature tensor efficiently. The best thing I could come up with (I use something similar for my own stuff) would be to list-comprehend over the unique group ids to build a list of variable-length tensors, then pad them and run the LSTM on top.

In code:

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

n = 13        # number of samples
f = 77        # number of features per sample
n_groups = 3  # number of distinct groups

xs = torch.rand(n, f)                                 # (N, F) feature tensor
ids = torch.randint(low=0, high=n_groups, size=(n,))  # (N,) group index per row


def groupbyid(xs: Tensor, ids: Tensor, batch_first: bool = True,
              padding_value: float = 0.) -> Tensor:
    # One variable-length tensor per unique group id, padded to the
    # length of the largest group so they can be stacked into a batch.
    return pad_sequence([xs[ids == idx] for idx in ids.unique()],
                        batch_first=batch_first,
                        padding_value=padding_value)


grouped = groupbyid(xs, ids)
print(grouped.shape)
# e.g. torch.Size([3, 5, 77]): (n_groups, largest group size, f)

You can then apply your LSTM in parallel over the n_groups dimension on the grouped Tensor.
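A minimal sketch of that step, assuming an arbitrary hidden size of 32: packing the padded batch with pack_padded_sequence lets the LSTM skip the padded positions, and enforce_sorted=False accepts groups of different lengths in any order.

import torch
from torch.nn.utils.rnn import pack_padded_sequence

# True (unpadded) length of each group, in the order of ids.unique()
lengths = [int((ids == idx).sum()) for idx in ids.unique()]

lstm = torch.nn.LSTM(input_size=f, hidden_size=32, batch_first=True)

# Pack so the LSTM never computes on the padding
packed = pack_padded_sequence(grouped, lengths, batch_first=True,
                              enforce_sorted=False)
_, (h_n, _) = lstm(packed)
# h_n[-1] holds one final hidden state per group: shape (n_groups, 32)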

Note that you will also need to inspect the content of ids.unique() to assign each LSTM output to its corresponding group id, but this is easy to write and depends on your application.
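As a hypothetical illustration, building on the sketch above, one way to broadcast each group's final hidden state back to its rows:

# ids.unique() is sorted, so row g of h_n[-1] belongs to the
# g-th smallest group id
out = torch.empty(n, 32)  # 32 = the hidden size assumed above
for g, idx in enumerate(ids.unique()):
    out[ids == idx] = h_n[-1][g]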

Upvotes: 1
