skydfy
skydfy

Reputation: 23

Extracting tensor data with index in pytorch

I would like to have the tensor indexed a certain way.

Suppose my data, tensor X shaped (1, 3, 16, 9) is

tensor([[[[ 0.,  0.,  0.,  0.,  1.,  2.,  0.,  5.,  6.],
      [ 0.,  0.,  0.,  1.,  2.,  3.,  5.,  6.,  7.],
      [ 0.,  0.,  0.,  2.,  3.,  4.,  6.,  7.,  8.],
      [ 0.,  0.,  0.,  3.,  4.,  0.,  7.,  8.,  0.],
      [ 0.,  1.,  2.,  0.,  5.,  6.,  0.,  9., 10.],
      [ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
      [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
      [ 3.,  4.,  0.,  7.,  8.,  0., 11., 12.,  0.],
      [ 0.,  5.,  6.,  0.,  9., 10.,  0., 13., 14.],
      [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
      [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
      [ 7.,  8.,  0., 11., 12.,  0., 15., 16.,  0.],
      [ 0.,  9., 10.,  0., 13., 14.,  0.,  0.,  0.],
      [ 9., 10., 11., 13., 14., 15.,  0.,  0.,  0.],
      [10., 11., 12., 14., 15., 16.,  0.,  0.,  0.],
      [11., 12.,  0., 15., 16.,  0.,  0.,  0.,  0.]],

     [[ 0.,  0.,  0.,  0., 17., 18.,  0., 21., 22.],
      [ 0.,  0.,  0., 17., 18., 19., 21., 22., 23.],
      [ 0.,  0.,  0., 18., 19., 20., 22., 23., 24.],
      [ 0.,  0.,  0., 19., 20.,  0., 23., 24.,  0.],
      [ 0., 17., 18.,  0., 21., 22.,  0., 25., 26.],
      [17., 18., 19., 21., 22., 23., 25., 26., 27.],
      [18., 19., 20., 22., 23., 24., 26., 27., 28.],
      [19., 20.,  0., 23., 24.,  0., 27., 28.,  0.],
      [ 0., 21., 22.,  0., 25., 26.,  0., 29., 30.],
      [21., 22., 23., 25., 26., 27., 29., 30., 31.],
      [22., 23., 24., 26., 27., 28., 30., 31., 32.],
      [23., 24.,  0., 27., 28.,  0., 31., 32.,  0.],
      [ 0., 25., 26.,  0., 29., 30.,  0.,  0.,  0.],
      [25., 26., 27., 29., 30., 31.,  0.,  0.,  0.],
      [26., 27., 28., 30., 31., 32.,  0.,  0.,  0.],
      [27., 28.,  0., 31., 32.,  0.,  0.,  0.,  0.]],

     [[ 0.,  0.,  0.,  0., 33., 34.,  0., 37., 38.],
      [ 0.,  0.,  0., 33., 34., 35., 37., 38., 39.],
      [ 0.,  0.,  0., 34., 35., 36., 38., 39., 40.],
      [ 0.,  0.,  0., 35., 36.,  0., 39., 40.,  0.],
      [ 0., 33., 34.,  0., 37., 38.,  0., 41., 42.],
      [33., 34., 35., 37., 38., 39., 41., 42., 43.],
      [34., 35., 36., 38., 39., 40., 42., 43., 44.],
      [35., 36.,  0., 39., 40.,  0., 43., 44.,  0.],
      [ 0., 37., 38.,  0., 41., 42.,  0., 45., 46.],
      [37., 38., 39., 41., 42., 43., 45., 46., 47.],
      [38., 39., 40., 42., 43., 44., 46., 47., 48.],
      [39., 40.,  0., 43., 44.,  0., 47., 48.,  0.],
      [ 0., 41., 42.,  0., 45., 46.,  0.,  0.,  0.],
      [41., 42., 43., 45., 46., 47.,  0.,  0.,  0.],
      [42., 43., 44., 46., 47., 48.,  0.,  0.,  0.],
      [43., 44.,  0., 47., 48.,  0.,  0.,  0.,  0.]]]]

I would like to have those rows where (row_index % n) == i (say n = 4 and i = 0 to 3) is saved in another tensor Y.

For example, for the data X[0][0]:

[[ 0.,  0.,  0.,  0.,  1.,  2.,  0.,  5.,  6.],
 [ 0.,  0.,  0.,  1.,  2.,  3.,  5.,  6.,  7.],
 [ 0.,  0.,  0.,  2.,  3.,  4.,  6.,  7.,  8.],
 [ 0.,  0.,  0.,  3.,  4.,  0.,  7.,  8.,  0.],
 [ 0.,  1.,  2.,  0.,  5.,  6.,  0.,  9., 10.],
 [ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
 [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
 [ 3.,  4.,  0.,  7.,  8.,  0., 11., 12.,  0.],
 [ 0.,  5.,  6.,  0.,  9., 10.,  0., 13., 14.],
 [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
 [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
 [ 7.,  8.,  0., 11., 12.,  0., 15., 16.,  0.],
 [ 0.,  9., 10.,  0., 13., 14.,  0.,  0.,  0.],
 [ 9., 10., 11., 13., 14., 15.,  0.,  0.,  0.],
 [10., 11., 12., 14., 15., 16.,  0.,  0.,  0.],      
 [11., 12.,  0., 15., 16.,  0.,  0.,  0.,  0.]]

I would like to have a tensor containing the following data, which is basically collection of the rows where row_index % 4 == 0 (here i = 0):

[[ 0.,  0.,  0.,  0.,  1.,  2.,  0.,  5.,  6.],
 [ 0.,  1.,  2.,  0.,  5.,  6.,  0.,  9., 10.],
 [ 0.,  5.,  6.,  0.,  9., 10.,  0., 13., 14.],
 [ 0.,  9., 10.,  0., 13., 14.,  0.,  0.,  0.]]

Similarly, where i = 1, row_index % 4 == i will look like:

[[ 0.,  0.,  0.,  1.,  2.,  3.,  5.,  6.,  7.],
 [ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
 [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
 [ 9., 10., 11., 13., 14., 15.,  0.,  0.,  0.]]

when i = 2, row_index % 4 == i:

[[ 0.,  0.,  0.,  2.,  3.,  4.,  6.,  7.,  8.],
 [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
 [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
 [10., 11., 12., 14., 15., 16.,  0.,  0.,  0.]]

when i = 3, row_index % 4 == i:

[[ 0.,  0.,  0.,  3.,  4.,  0.,  7.,  8.,  0.],
 [ 3.,  4.,  0.,  7.,  8.,  0., 11., 12.,  0.],
 [ 7.,  8.,  0., 11., 12.,  0., 15., 16.,  0.],
 [11., 12.,  0., 15., 16.,  0.,  0.,  0.,  0.]]

I have tried hard coding it and it doesn't seem practical when the data becomes larger and the size becomes dynamic and I assume that there would be a better way to come about it.

temp0 = data[0][0][0][:] 
temp1 = data[0][0][4][:]
temp2 = data[0][0][8][:]
temp3 = data[0][0][12][:]
temp = torch.stack([temp0,temp1,temp2,temp3],dim = 0)

Also, it would be great if the result can come back in one tensor like :

tensor Y = ([[[ 0.,  0.,  0.,  0.,  1.,  2.,  0.,  5.,  6.],
              [ 0.,  1.,  2.,  0.,  5.,  6.,  0.,  9., 10.],
              [ 0.,  5.,  6.,  0.,  9., 10.,  0., 13., 14.],
              [ 0.,  9., 10.,  0., 13., 14.,  0.,  0.,  0.]], 

             [[ 0.,  0.,  0.,  1.,  2.,  3.,  5.,  6.,  7.],
              [ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
              [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
              [ 9., 10., 11., 13., 14., 15.,  0.,  0.,  0.]], 
   
             [[ 0.,  0.,  0.,  2.,  3.,  4.,  6.,  7.,  8.],
              [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
              [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
              [10., 11., 12., 14., 15., 16.,  0.,  0.,  0.]], 

             [[ 0.,  0.,  0.,  3.,  4.,  0.,  7.,  8.,  0.],
              [ 3.,  4.,  0.,  7.,  8.,  0., 11., 12.,  0.],
              [ 7.,  8.,  0., 11., 12.,  0., 15., 16.,  0.],
              [11., 12.,  0., 15., 16.,  0.,  0.,  0.,  0.]]])

Upvotes: 2

Views: 834

Answers (2)

Ivan
Ivan

Reputation: 40628

You can achieve this by first constructing a tensor containing the selected rows, then using torch.gather to assemble the final tensor.

Assuming we two lists I and N containing the values of i and n respectively:

I = [0, 1, 2, 3]
N = [4, 4, 4, 4]

First we construct the index tensor:

>>> index = torch.stack([(torch.arange(16) % n == i).nonzero() for i, n in zip(I, N)])
tensor([[[ 0],
         [ 4],
         [ 8],
         [12]],

        [[ 1],
         [ 5],
         [ 9],
         [13]],

        [[ 2],
         [ 6],
         [10],
         [14]],

        [[ 3],
         [ 7],
         [11],
         [15]]])

Then some expanding and reshaping is required:

>>> index_ = index[None].flatten(1,2).expand(X.size(0), -1, X.size(-1))
tensor([[[ 0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 4,  4,  4,  4,  4,  4,  4,  4,  4],
         [ 8,  8,  8,  8,  8,  8,  8,  8,  8],
         [12, 12, 12, 12, 12, 12, 12, 12, 12],
         [ 1,  1,  1,  1,  1,  1,  1,  1,  1],
         [ 5,  5,  5,  5,  5,  5,  5,  5,  5],
         [ 9,  9,  9,  9,  9,  9,  9,  9,  9],
         [13, 13, 13, 13, 13, 13, 13, 13, 13],
         [ 2,  2,  2,  2,  2,  2,  2,  2,  2],
         [ 6,  6,  6,  6,  6,  6,  6,  6,  6],
         [10, 10, 10, 10, 10, 10, 10, 10, 10],
         [14, 14, 14, 14, 14, 14, 14, 14, 14],
         [ 3,  3,  3,  3,  3,  3,  3,  3,  3],
         [ 7,  7,  7,  7,  7,  7,  7,  7,  7],
         [11, 11, 11, 11, 11, 11, 11, 11, 11],
         [15, 15, 15, 15, 15, 15, 15, 15, 15]]])

As a rule of thumb, we want index_ to have the same number of dimensions as X.

Now we can apply torch.gather and reshape to the final form:

>>> X.gather(1, index_).reshape(len(X), *index.shape[:2], -1)
tensor([[[[ 0.,  0.,  0.,  0.,  1.,  2.,  0.,  5.,  6.],
          [ 0.,  1.,  2.,  0.,  5.,  6.,  0.,  9., 10.],
          [ 0.,  5.,  6.,  0.,  9., 10.,  0., 13., 14.],
          [ 0.,  9., 10.,  0., 13., 14.,  0.,  0.,  0.]],

         [[ 0.,  0.,  0.,  1.,  2.,  3.,  5.,  6.,  7.],
          [ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
          [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
          [ 9., 10., 11., 13., 14., 15.,  0.,  0.,  0.]],

         [[ 0.,  0.,  0.,  2.,  3.,  4.,  6.,  7.,  8.],
          [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
          [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
          [10., 11., 12., 14., 15., 16.,  0.,  0.,  0.]],

         [[ 0.,  0.,  0.,  3.,  4.,  0.,  7.,  8.,  0.],
          [ 3.,  4.,  0.,  7.,  8.,  0., 11., 12.,  0.],
          [ 7.,  8.,  0., 11., 12.,  0., 15., 16.,  0.],
          [11., 12.,  0., 15., 16.,  0.,  0.,  0.,  0.]]]])

This method can be extended to batch tensors:

>>> index = torch.stack([(torch.arange(16) % n == i).nonzero() for i, n in zip(I, N)])
>>> index_  = index[None,None].flatten(2,3).expand(X.size(0), X.size(1), -1, X.size(-1))

>>> X.gather(2, index_).reshape(*X.shape[:2], *index.shape[:2], -1)

Upvotes: 1

A. Maman
A. Maman

Reputation: 972

First, to get each patrition you can try this:

import torch

data = torch.tensor([[[[0., 0., 0., 0., 1., 2., 0., 5., 6.],
                       [0., 0., 0., 1., 2., 3., 5., 6., 7.],
                       [0., 0., 0., 2., 3., 4., 6., 7., 8.],
                       [0., 0., 0., 3., 4., 0., 7., 8., 0.],
                       [0., 1., 2., 0., 5., 6., 0., 9., 10.],
                       [1., 2., 3., 5., 6., 7., 9., 10., 11.],
                       [2., 3., 4., 6., 7., 8., 10., 11., 12.],
                       [3., 4., 0., 7., 8., 0., 11., 12., 0.],
                       [0., 5., 6., 0., 9., 10., 0., 13., 14.],
                       [5., 6., 7., 9., 10., 11., 13., 14., 15.],
                       [6., 7., 8., 10., 11., 12., 14., 15., 16.],
                       [7., 8., 0., 11., 12., 0., 15., 16., 0.],
                       [0., 9., 10., 0., 13., 14., 0., 0., 0.],
                       [9., 10., 11., 13., 14., 15., 0., 0., 0.],
                       [10., 11., 12., 14., 15., 16., 0., 0., 0.],
                       [11., 12., 0., 15., 16., 0., 0., 0., 0.]],

                      [[0., 0., 0., 0., 17., 18., 0., 21., 22.],
                       [0., 0., 0., 17., 18., 19., 21., 22., 23.],
                       [0., 0., 0., 18., 19., 20., 22., 23., 24.],
                       [0., 0., 0., 19., 20., 0., 23., 24., 0.],
                       [0., 17., 18., 0., 21., 22., 0., 25., 26.],
                       [17., 18., 19., 21., 22., 23., 25., 26., 27.],
                       [18., 19., 20., 22., 23., 24., 26., 27., 28.],
                       [19., 20., 0., 23., 24., 0., 27., 28., 0.],
                       [0., 21., 22., 0., 25., 26., 0., 29., 30.],
                       [21., 22., 23., 25., 26., 27., 29., 30., 31.],
                       [22., 23., 24., 26., 27., 28., 30., 31., 32.],
                       [23., 24., 0., 27., 28., 0., 31., 32., 0.],
                       [0., 25., 26., 0., 29., 30., 0., 0., 0.],
                       [25., 26., 27., 29., 30., 31., 0., 0., 0.],
                       [26., 27., 28., 30., 31., 32., 0., 0., 0.],
                       [27., 28., 0., 31., 32., 0., 0., 0., 0.]],

                      [[0., 0., 0., 0., 33., 34., 0., 37., 38.],
                       [0., 0., 0., 33., 34., 35., 37., 38., 39.],
                       [0., 0., 0., 34., 35., 36., 38., 39., 40.],
                       [0., 0., 0., 35., 36., 0., 39., 40., 0.],
                       [0., 33., 34., 0., 37., 38., 0., 41., 42.],
                       [33., 34., 35., 37., 38., 39., 41., 42., 43.],
                       [34., 35., 36., 38., 39., 40., 42., 43., 44.],
                       [35., 36., 0., 39., 40., 0., 43., 44., 0.],
                       [0., 37., 38., 0., 41., 42., 0., 45., 46.],
                       [37., 38., 39., 41., 42., 43., 45., 46., 47.],
                       [38., 39., 40., 42., 43., 44., 46., 47., 48.],
                       [39., 40., 0., 43., 44., 0., 47., 48., 0.],
                       [0., 41., 42., 0., 45., 46., 0., 0., 0.],
                       [41., 42., 43., 45., 46., 47., 0., 0., 0.],
                       [42., 43., 44., 46., 47., 48., 0., 0., 0.],
                       [43., 44., 0., 47., 48., 0., 0., 0., 0.]]]])

print(data.shape)

n, i = 4, 0
indices = [index for index in range(data.shape[2]) if index % n == i]
print(data[0, 0, indices])

For the combination of those tensors you can try using:

n = 4
result = []
for i in range(n):
    indices = [index for index in range(data.shape[2]) if index % n == i]
    result.append(data[0, 0, indices])

final = torch.stack(result, dim=0)

Upvotes: 1

Related Questions