Stepochkin

Reputation: 303

How can I sum parts of a PyTorch tensor of variable sizes?

Let's consider an example. I have a tensor of size (10, 3). I want to sum the first 3 rows, the next 2 rows, and the next 5 rows along axis 0. For example, from:

t = torch.ones([10, 3])

I want to get:

[
    [3.0, 3.0, 3.0],
    [2.0, 2.0, 2.0],
    [5.0, 5.0, 5.0],
]

I want to specify a tensor with values, a tensor with part sizes, and possibly an axis, and get back a tensor summed along that axis in parts of the specified sizes. How can I achieve that?

Upvotes: 1

Views: 1628

Answers (3)

DerekG

Reputation: 3938

A non-looping* approach is:

  1. Break the tensor into subtensors based on the splitting list
  2. Pad (with 0s) and stack the subtensors along a new dimension
  3. Sum the stack along the desired dimension

The asterisk is because we can do the pad-and-stack with torch.nn.utils.rnn.pad_sequence. This function is implemented for quick computation, making use of C-style threading for parallelism, so it is likely much faster than a Python for loop (though it's possible the C implementation of this function actually does use some looping). The one downside to this approach is that you very temporarily need to allocate memory for the padding zeros, so if the largest and smallest chunk sizes are vastly different, this could cause a memory issue.
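To illustrate, here is a minimal sketch (using the chunk sizes from the question, and assuming torch is imported) of what the pad-and-stack step produces:

chunks = torch.split(torch.ones([10, 3]), [3, 2, 5], dim=0)  # tensors of size [3,3], [2,3], [5,3]
padded = torch.nn.utils.rnn.pad_sequence(chunks, batch_first=True)
# padded has size [3, 5, 3]: each chunk is zero-padded up to the longest chunk (5 rows)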

Ok, that's the explanation, here's the code:

import torch

def chunk_sum(X, chunk_lengths, dim=0):
    # X has size [A,B,C...]
    # The sum of chunk_lengths must equal X.shape[dim]

    # X_chunk is a list of tensors of size [*varies*,B,C...] if dim=0, and so on
    X_chunk = torch.split(X, chunk_lengths, dim=dim)

    # X_chunk_pad is size [len(chunk_lengths),max(chunk_lengths),B,C,...] if dim=0, and so on
    X_chunk_pad = torch.nn.utils.rnn.pad_sequence(X_chunk, batch_first=True)

    # X_sum is size [len(chunk_lengths),B,C,...]
    X_sum = X_chunk_pad.sum(dim=1 + dim)  # add one because we added the batch dimension first

    # lastly, we need to permute dimensions so that batch (currently dimension 0) replaces dim
    X_sum = torch.transpose(X_sum, 0, dim)

    return X_sum
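A quick sanity check with the example from the question:

t = torch.ones([10, 3])
chunk_sum(t, [3, 2, 5], dim=0)
# tensor([[3., 3., 3.],
#         [2., 2., 2.],
#         [5., 5., 5.]])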

Upvotes: 2

Stepochkin

Reputation: 303

Following the great idea of @ben-grossmann, I modified it a little to use a sparse tensor and make it more efficient, and implemented it as a function:

import torch

def sum_var_parts(t, lens):
    # t has size [m, ...]; lens is a 1-D integer tensor whose entries sum to m
    t_size_0 = t.size(0)
    # ind_x[j] is the index of the part that row j of t belongs to
    ind_x = torch.repeat_interleave(torch.arange(lens.size(0)), lens)
    # (part, row) index pairs for the nonzero entries of the selection matrix
    indices = torch.cat(
        [
            torch.unsqueeze(ind_x, dim=0),
            torch.unsqueeze(torch.arange(t_size_0), dim=0)
        ],
        dim=0
    )
    # sparse [len(lens), m] matrix with a 1 where row j of t belongs to part i
    M = torch.sparse_coo_tensor(
        indices,
        torch.ones(t_size_0, dtype=torch.float32),
        size=[lens.size(0), t_size_0]
    )
    return M @ t
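For example, with the tensor from the question (note that lens must be an integer tensor here, since it is fed to torch.repeat_interleave):

t = torch.ones([10, 3])
lens = torch.tensor([3, 2, 5])
sum_var_parts(t, lens)
# tensor([[3., 3., 3.],
#         [2., 2., 2.],
#         [5., 5., 5.]])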

Upvotes: 2

Ben Grossmann

Reputation: 4772

One approach is to use matrix multiplication. For example, consider the following.

import torch

lens = [3,2,5]
t = torch.ones([10, 3])

m,_ = t.size()
M = torch.zeros([len(lens),m])
ind_x = [i for i,L in enumerate(lens) for _ in range(L)]
M[ind_x,range(m)] = 1
ans = M@t

The resulting tensor ans:

tensor([[3., 3., 3.],
        [2., 2., 2.],
        [5., 5., 5.]])
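For intuition, M here is a 3×10 indicator matrix with one row per part, so M@t sums the corresponding rows of t:

print(M)
# tensor([[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
#         [0., 0., 0., 1., 1., 0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.]])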

Another approach:

lens = [3,2,5]
t = torch.ones([10, 3])

# turn the part sizes into slice boundaries, then sum each slice
lens = torch.cumsum(torch.tensor([0]+lens),dim=0)
ans = torch.stack([torch.sum(t[a:b],axis = 0) for a,b in zip(lens,lens[1:])])
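
which again gives:

tensor([[3., 3., 3.],
        [2., 2., 2.],
        [5., 5., 5.]])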

Upvotes: 0
