ashered

Reputation: 79

How to efficiently repeat each tensor element a variable number of times in PyTorch?

For example, if I have a tensor A = [[1,1,1], [2,2,2], [3,3,3]] and B = [1,2,3], how do I get C = [[1,1,1], [2,2,2], [2,2,2], [3,3,3], [3,3,3], [3,3,3]], i.e. row i of A repeated B[i] times? And how do I do this batch-wise?

My current element-wise solution btw (takes forever...):

        def get_char_context(valid_embeds, words_lens):
            chars_contexts = []
            for ve, wl in zip(valid_embeds, words_lens):
                for idx, (e, l) in enumerate(zip(ve, wl)):
                    # repeat word embedding e once per character (l times)
                    if idx == 0:
                        chars_context = e.view(1, -1).repeat(l, 1)
                    else:
                        chars_context = torch.cat((chars_context, e.view(1, -1).repeat(l, 1)), 0)
                chars_contexts.append(chars_context)
            return chars_contexts
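For reference, here is the function above applied to the single-sample example (a self-contained repro; the function body is copied from above):

```python
import torch

def get_char_context(valid_embeds, words_lens):
    # For each sample, repeat word embedding e once per character (l times)
    # and concatenate the repeats into one char-level context tensor.
    chars_contexts = []
    for ve, wl in zip(valid_embeds, words_lens):
        for idx, (e, l) in enumerate(zip(ve, wl)):
            if idx == 0:
                chars_context = e.view(1, -1).repeat(l, 1)
            else:
                chars_context = torch.cat((chars_context, e.view(1, -1).repeat(l, 1)), 0)
        chars_contexts.append(chars_context)
    return chars_contexts

A = torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
C = get_char_context([A], [[1, 2, 3]])[0]
print(C)  # the desired 6 x 3 tensor
```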

I'm doing this to add BERT word embeddings to a char-level seq2seq task...

Upvotes: 2

Views: 1356

Answers (1)

swag2198

Reputation: 2696

Use this:

import torch
# A is your tensor
B = torch.tensor([1, 2, 3])
C = A.repeat_interleave(B, dim=0)
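For the tensors in the question, this gives exactly the C asked for (a quick self-contained check):

```python
import torch

A = torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
B = torch.tensor([1, 2, 3])        # row i of A is repeated B[i] times
C = A.repeat_interleave(B, dim=0)  # shape (B.sum(), 3) == (6, 3)
print(C)
```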

EDIT:

The above works fine if A is a single 2D tensor. To repeat all (2D) tensors in a batch in the same manner, this is a simple workaround:

A = torch.tensor([[[1, 1, 1], [2, 2, 2], [3, 3, 3]],
                  [[1, 2, 3], [4, 5, 6], [2, 2, 2]]]) # A has 2 tensors each of shape (3, 3)
B = torch.tensor([1, 2, 3]) # Repeats for each row of every tensor in the batch

A1 = A.reshape(1, -1, A.shape[2]).squeeze()  # flatten the batch: (2, 3, 3) -> (6, 3)
B1 = B.repeat(A.shape[0])                    # tile the repeat counts: (6,)
C = A1.repeat_interleave(B1, dim=0).reshape(A.shape[0], -1, A.shape[2])

C is:

tensor([[[1, 1, 1],
         [2, 2, 2],
         [2, 2, 2],
         [3, 3, 3],
         [3, 3, 3],
         [3, 3, 3]],

        [[1, 2, 3],
         [4, 5, 6],
         [4, 5, 6],
         [2, 2, 2],
         [2, 2, 2],
         [2, 2, 2]]])

As you can see, each inner tensor in the batch is repeated in the same manner.
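The same batch-wise result can also be written as a single expression; this is just the workaround above with the intermediate steps folded together (a sketch, not part of the original answer):

```python
import torch

A = torch.tensor([[[1, 1, 1], [2, 2, 2], [3, 3, 3]],
                  [[1, 2, 3], [4, 5, 6], [2, 2, 2]]])
B = torch.tensor([1, 2, 3])

# Flatten the batch, interleave every row, then restore the batch dimension.
C = (A.reshape(-1, A.shape[2])
      .repeat_interleave(B.repeat(A.shape[0]), dim=0)
      .reshape(A.shape[0], -1, A.shape[2]))
print(C.shape)  # (2, 6, 3)
```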

Upvotes: 4
