Reputation: 11
I'm working on a batchable, loop- and recursion-free PyTorch utility concat_aggregate for grouping rows of an input tensor x according to labels given by an index tensor. It should pad rows so that the resulting tensor is rectangular. For example,
x = torch.tensor([[5, 50], [6, 60], [7, 70], [8, 80], [9, 90], [10, 100], [11, 110], [12, 120]])
index = torch.tensor([3, 3, 1, 1, 1, 2, 3, 3])
concat_aggregate(x, index)
should output:
torch.tensor([
    [[0, 0], [0, 0], [0, 0], [0, 0]],
    [[7, 70], [8, 80], [9, 90], [0, 0]],
    [[10, 100], [0, 0], [0, 0], [0, 0]],
    [[5, 50], [6, 60], [11, 110], [12, 120]]
])
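To see where the 4 x 4 x 2 shape of that expected output comes from, here is a minimal sketch using torch.bincount. The variable names (group_sizes, max_rows, padded_shape) are only illustrative and not part of any existing API.

```python
import torch

index = torch.tensor([3, 3, 1, 1, 1, 2, 3, 3])
num_features = 2  # each row of x in the example has two entries

# Rows per label; label 0 never occurs in index, so it gets a count of 0
group_sizes = torch.bincount(index)

num_groups = group_sizes.numel()    # largest label + 1
max_rows = int(group_sizes.max())   # every group is padded to this many rows

padded_shape = (num_groups, max_rows, num_features)
print(group_sizes.tolist(), padded_shape)
```

Group 0 is empty, so its slice of the result is all padding.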
I hacked my way to this function:
import torch

def cat_aggregate(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Number of groups and the number of features in each row of x
    num_groups = index.max().item() + 1
    num_features = x.size(1)

    # Count the number of elements in each group
    group_sizes = torch.zeros(num_groups, dtype=torch.long, device=x.device)
    group_sizes.index_add_(0, index, torch.ones_like(index, dtype=torch.long))

    # Prepare the output tensor, padded with zeros up to the largest group size
    max_num_elements = group_sizes.max().item()
    result = torch.zeros(num_groups, max_num_elements, num_features, dtype=x.dtype, device=x.device)

    # Current fill position in each group
    positions = torch.zeros_like(group_sizes)

    # Fill the tensor row by row
    for i in range(x.size(0)):
        group_id = index[i]
        result[group_id, positions[group_id]] = x[i]
        positions[group_id] += 1

    return result
which returns the correct results for 1 and 2D tensors. However, it requires iterating over x.size(0) in a Python loop, making it at least linear in the length of x. I'm not sure if what I have is idiomatic. Does anyone here see any possible efficiency/complexity improvements or an obvious way to extend it to 2D tensors? I'm surprised such a function is missing from the PyTorch API.
Upvotes: 0
Views: 43
Reputation: 564
This should be equivalent to your function without using a for loop:
import torch

def cat_aggregate(x, index):
    # Rows per group; groups that never appear in index get a count of 0
    index_count = torch.bincount(index)
    max_count = int(index_count.max())
    # Number of zero rows each group needs to reach max_count
    fill_count = max_count - index_count
    # Padding rows with the same trailing shape as x, so x can have arbitrary shape
    fill_zeros = x.new_zeros(int(fill_count.sum()), *x.shape[1:])
    # Group label for every padding row (torch.range is deprecated and returns floats)
    fill_index = torch.arange(index_count.shape[0], device=index.device).repeat_interleave(fill_count)
    index_ = torch.cat([index, fill_index], dim=0)
    x_ = torch.cat([x, fill_zeros], dim=0)
    # A stable sort keeps each group's rows in input order, with the padding rows last
    _, order = torch.sort(index_, stable=True)
    # Reshape using the trailing dimensions of x instead of -1 to support arbitrary shapes
    return x_[order].view(index_count.shape[0], max_count, *x.shape[1:])
Output:
tensor([[[  0,   0],
         [  0,   0],
         [  0,   0],
         [  0,   0]],

        [[  7,  70],
         [  8,  80],
         [  9,  90],
         [  0,   0]],

        [[ 10, 100],
         [  0,   0],
         [  0,   0],
         [  0,   0]],

        [[  5,  50],
         [  6,  60],
         [ 11, 110],
         [ 12, 120]]])
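As a possible variation on the same idea, the rows could be written directly into a pre-allocated zero tensor instead of concatenating padding rows and sorting everything. The following is only a sketch: the name cat_aggregate_scatter is hypothetical, and it assumes index is a non-empty 1-D tensor of non-negative integers.

```python
import torch

def cat_aggregate_scatter(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper; assumes index is 1-D, non-empty, and non-negative
    counts = torch.bincount(index)
    num_groups, max_count = counts.numel(), int(counts.max())

    # Stable sort brings rows of the same group together without reordering them
    sorted_index, order = torch.sort(index, stable=True)

    # Offset of each group's first row in the sorted order
    starts = torch.cumsum(counts, dim=0) - counts

    # Position of every sorted row inside its own group
    slot = torch.arange(index.numel(), device=index.device) - starts[sorted_index]

    # Zero-filled output; one indexed assignment replaces the Python loop
    result = x.new_zeros(num_groups, max_count, *x.shape[1:])
    result[sorted_index, slot] = x[order]
    return result

x = torch.tensor([[5, 50], [6, 60], [7, 70], [8, 80],
                  [9, 90], [10, 100], [11, 110], [12, 120]])
index = torch.tensor([3, 3, 1, 1, 1, 2, 3, 3])
out = cat_aggregate_scatter(x, index)
print(out.shape)
```

This only sorts the original rows rather than the rows plus padding, which may matter when one group is much larger than the others. I have not benchmarked it against the version above.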
Upvotes: 0