Wowee

Reputation: 11

Group PyTorch feature tensors according to labels by concatenation

I'm working on a batchable, loop-free and recursion-free PyTorch utility, cat_aggregate, that groups the rows of an input tensor x according to the labels in an index tensor. Groups with fewer rows should be zero-padded so that the resulting tensor is rectangular. For example,

x = torch.tensor([[5, 50], [6, 60], [7, 70], [8, 80], [9, 90], [10, 100], [11, 110], [12, 120]])
index = torch.tensor([3, 3, 1, 1, 1, 2, 3, 3])
cat_aggregate(x, index)

should output:

torch.tensor([
    [[0, 0], [0, 0], [0, 0], [0, 0]],
    [[7, 70], [8, 80], [9, 90], [0, 0]],
    [[10, 100], [0, 0], [0, 0], [0, 0]],
    [[5, 50], [6, 60], [11, 110], [12, 120]]
])

I hacked my way to this function:

import torch

def cat_aggregate(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Number of groups and the number of features in each row of x
    num_groups = index.max().item() + 1
    num_features = x.size(1)
    # Count the number of rows in each group
    group_sizes = torch.zeros(num_groups, dtype=torch.long, device=x.device)
    group_sizes.index_add_(0, index, torch.ones_like(index, dtype=torch.long))
    # Prepare the output tensor, padded with zeros up to the largest group size
    max_num_elements = group_sizes.max().item()
    result = torch.zeros(num_groups, max_num_elements, num_features, dtype=x.dtype, device=x.device)
    # Current fill position in each group
    positions = torch.zeros_like(group_sizes)
    # Fill the tensor one row at a time
    for i in range(x.size(0)):
        group_id = index[i]
        result[group_id, positions[group_id]] = x[i]
        positions[group_id] += 1
    return result

which returns the correct result for 2D inputs such as the example above. However, it requires a Python loop over x.size(0), which makes it at least linear in the length of x and processes one row at a time. I'm not sure whether what I have is idiomatic. Does anyone here see any possible efficiency or complexity improvements, or an obvious way to extend it to tensors with more dimensions? I'm surprised such a function is missing from the PyTorch API.

Upvotes: 0

Views: 43

Answers (1)

MinhNH

Reputation: 564

This should be equivalent to your function, without using a for loop:

import torch

def cat_aggregate(x, index):
    # Rows per group, and how many zero rows each group needs as padding
    index_count = torch.bincount(index)
    max_count = int(index_count.max())
    fill_count = max_count - index_count
    # Zero rows with the same trailing shape as x, so x can have an arbitrary shape
    fill_zeros = x.new_zeros((int(fill_count.sum()), *x.shape[1:]))
    # Group label of each padding row (torch.arange replaces the deprecated torch.range)
    fill_index = torch.arange(index_count.shape[0], device=index.device).repeat_interleave(fill_count)
    index_ = torch.cat([index, fill_index], dim=0)
    x_ = torch.cat([x, fill_zeros], dim=0)
    # A stable sort keeps the original row order within each group, with the padding rows last
    order = torch.argsort(index_, stable=True)
    # x.shape[1:] instead of -1 keeps the trailing dimensions for inputs of arbitrary shape
    return x_[order].view(index_count.shape[0], max_count, *x.shape[1:])

Output:

tensor([[[  0,   0],
         [  0,   0],
         [  0,   0],
         [  0,   0]],

        [[  7,  70],
         [  8,  80],
         [  9,  90],
         [  0,   0]],

        [[ 10, 100],
         [  0,   0],
         [  0,   0],
         [  0,   0]],

        [[  5,  50],
         [  6,  60],
         [ 11, 110],
         [ 12, 120]]])
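
As a possible variation on the idea above, the padding rows do not have to be materialised and sorted: each row's slot inside its group can be computed directly and the rows written into a pre-allocated zero tensor. This is only a sketch under some assumptions: the name cat_aggregate_indexed is made up for illustration, index is assumed to be a non-empty 1D tensor of non-negative integers, and the stable keyword of torch.argsort is assumed to be available in the installed PyTorch version.

```python
import torch

def cat_aggregate_indexed(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper name; a sketch, not part of the PyTorch API.
    counts = torch.bincount(index)  # number of rows in each group
    num_groups, max_count = counts.shape[0], int(counts.max())
    # A stable sort keeps the original row order inside each group.
    order = torch.argsort(index, stable=True)
    sorted_index = index[order]
    # Offset at which each group starts in the sorted order (exclusive cumulative sum).
    starts = torch.cumsum(counts, dim=0) - counts
    # Slot of each sorted row inside its own group: 0, 1, 2, ...
    slots = torch.arange(index.shape[0], device=index.device) - starts[sorted_index]
    # Zero-initialised output; trailing dimensions of x are carried over unchanged.
    result = x.new_zeros((num_groups, max_count, *x.shape[1:]))
    result[sorted_index, slots] = x[order]
    return result

x = torch.tensor([[5, 50], [6, 60], [7, 70], [8, 80], [9, 90], [10, 100], [11, 110], [12, 120]])
index = torch.tensor([3, 3, 1, 1, 1, 2, 3, 3])
out = cat_aggregate_indexed(x, index)
```

For the example from the question, out should match the tensor printed above. The sort is still O(n log n), so this is not asymptotically better than the loop, but every step is a vectorized tensor operation.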

Upvotes: 0
