Reputation: 303
Let's consider an example. I have a tensor of size (10, 3). I want to sum the first 3 rows, the next 2 rows, and the next 5 rows along axis 0. For example, from:
t = torch.ones([10, 3])
I want to get:
[
[3.0, 3.0, 3.0],
[2.0, 2.0, 2.0],
[5.0, 5.0, 5.0],
]
I want to specify a tensor of values and a tensor of part sizes (and possibly an axis), and get back a tensor summed along that axis in parts of the specified sizes. How can I achieve that?
Upvotes: 1
Views: 1628
Reputation: 3938
A non-looping* approach is:
The asterisk is because we can do the pad and stack with torch.nn.utils.rnn.pad_sequence. This function is implemented for fast computation, using C-level threading for parallelism, so it is likely much quicker than a Python for loop (though it's possible the C implementation of this function does use some looping internally). The one downside to this approach is that you temporarily need to allocate memory for the padding zeros, so if the largest and smallest chunk sizes are vastly different, this could cause a memory issue.
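For concreteness, here is what the pad-and-stack step looks like on the example from the question (a minimal sketch using only torch.split and pad_sequence):

import torch
from torch.nn.utils.rnn import pad_sequence

chunks = torch.split(torch.ones([10, 3]), [3, 2, 5], dim=0)  # sizes (3, 3), (2, 3), (5, 3)
# each chunk is zero-padded to the longest length (5) and stacked into one batch
padded = pad_sequence(chunks, batch_first=True)
print(padded.shape)       # torch.Size([3, 5, 3])
print(padded.sum(dim=1))  # the per-chunk sums from the question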
OK, that's the explanation; here's the full function:
import torch

def chunk_sum(X, chunk_lengths, dim=0):
    # X has size [A, B, C, ...]
    # The sum of chunk_lengths must equal X.shape[dim]
    # X_chunk is a tuple of tensors of size [*varies*, B, C, ...] if dim=0, and so on
    X_chunk = torch.split(X, chunk_lengths, dim=dim)
    # X_chunk_pad has size [len(chunk_lengths), max(chunk_lengths), B, C, ...] if dim=0, and so on
    X_chunk_pad = torch.nn.utils.rnn.pad_sequence(X_chunk, batch_first=True)
    # X_sum has size [len(chunk_lengths), B, C, ...]
    X_sum = X_chunk_pad.sum(dim=1 + dim)  # add one because we added the batch dimension first
    # lastly, permute dimensions so that batch (currently dimension 0) replaces dim
    X_sum = torch.transpose(X_sum, 0, dim)
    return X_sum
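Applied to the example from the question, this gives:

t = torch.ones([10, 3])
print(chunk_sum(t, [3, 2, 5]))
# tensor([[3., 3., 3.],
#         [2., 2., 2.],
#         [5., 5., 5.]])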
Upvotes: 2
Reputation: 303
Following the great idea of @ben-grossmann, I modified it a little to use a sparse tensor and make it more efficient, and implemented it as a function:
import torch

def sum_var_parts(t, lens):
    # t has size [t_size_0, ...]; lens is a 1-D tensor of part sizes summing to t_size_0
    t_size_0 = t.size(0)
    # index of the part that each row of t belongs to, e.g. [0, 0, 0, 1, 1, 2, 2, 2, 2, 2]
    ind_x = torch.repeat_interleave(torch.arange(lens.size(0)), lens)
    indices = torch.cat(
        [
            torch.unsqueeze(ind_x, dim=0),
            torch.unsqueeze(torch.arange(t_size_0), dim=0)
        ],
        dim=0
    )
    # sparse selector matrix: M[i, j] = 1 iff row j of t belongs to part i
    M = torch.sparse_coo_tensor(
        indices,
        torch.ones(t_size_0, dtype=torch.float32),
        size=[lens.size(0), t_size_0]
    )
    return M @ t
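For example (note that lens must be a tensor here, not a list, since the function calls lens.size(0)):

t = torch.ones([10, 3])
lens = torch.tensor([3, 2, 5])
print(sum_var_parts(t, lens))
# tensor([[3., 3., 3.],
#         [2., 2., 2.],
#         [5., 5., 5.]])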
Upvotes: 2
Reputation: 4772
One approach is to use matrix multiplication. For example, consider the following.
import torch

lens = [3, 2, 5]
t = torch.ones([10, 3])
m, _ = t.size()
# M[i, j] = 1 iff row j of t belongs to part i
M = torch.zeros([len(lens), m])
ind_x = [i for i, L in enumerate(lens) for _ in range(L)]
M[ind_x, range(m)] = 1
ans = M @ t
The resulting tensor ans:
tensor([[3., 3., 3.],
[2., 2., 2.],
[5., 5., 5.]])
Another approach:
lens = [3, 2, 5]
t = torch.ones([10, 3])
# cumulative part boundaries: [0, 3, 5, 10]
bounds = torch.cumsum(torch.tensor([0] + lens), dim=0)
ans = torch.stack([torch.sum(t[a:b], dim=0) for a, b in zip(bounds, bounds[1:])])
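which produces the same result:

tensor([[3., 3., 3.],
        [2., 2., 2.],
        [5., 5., 5.]])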
Upvotes: 0