Serzhev

Reputation: 41

Add a row and a column of zeros to a multidimensional torch.tensor (a kind of padding)

I want to extend the tensor with a row of zeros at the bottom and a column of zeros on the right side.

My solution is below. Is there a better (ideally simpler) one?

Input (the ones are only for illustration; my actual tensor has exactly the same shape but contains real values):

tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])

Expected output:

tensor([[[[1., 1., 1., 0.],
          [1., 1., 1., 0.],
          [1., 1., 1., 0.],
          [0., 0., 0., 0.]],

         [[1., 1., 1., 0.],
          [1., 1., 1., 0.],
          [1., 1., 1., 0.],
          [0., 0., 0., 0.]],

         [[1., 1., 1., 0.],
          [1., 1., 1., 0.],
          [1., 1., 1., 0.],
          [0., 0., 0., 0.]]]])

My current solution:

import torch

x1 = torch.ones(1, 3, 3, 3)

# For each channel: append a zero row (dim 0), then a zero column (dim 1) -> 4x4
z2 = torch.cat((torch.cat((x1[0, 0, :], torch.zeros(1, 3)), 0), torch.zeros(4, 1)), 1)
z3 = torch.cat((torch.cat((x1[0, 1, :], torch.zeros(1, 3)), 0), torch.zeros(4, 1)), 1)
z4 = torch.cat((torch.cat((x1[0, 2, :], torch.zeros(1, 3)), 0), torch.zeros(4, 1)), 1)

# Allocate the padded result and copy each padded channel into it
output_t = torch.zeros(1, 3, 4, 4)

output_t[0, 0, :] = z2
output_t[0, 1, :] = z3
output_t[0, 2, :] = z4

output_t
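For comparison, the same preallocate-and-copy idea can be written without one torch.cat per channel: allocate the larger zero tensor once and copy the input into its top-left corner with a single slice assignment. The ellipsis covers the batch and channel dimensions, so this sketch works for any number of channels. torch.rand is only a stand-in for the real-valued input.

```python
import torch

x1 = torch.rand(1, 3, 3, 3)  # stand-in for the real-valued input

# Allocate a zero tensor one row and one column larger in the last two dims,
# with the same dtype and device as x1.
output_t = x1.new_zeros((*x1.shape[:-2], x1.shape[-2] + 1, x1.shape[-1] + 1))

# Copy the original values into the top-left corner of every channel;
# the last row and last column stay zero.
output_t[..., :-1, :-1] = x1

print(output_t.shape)  # torch.Size([1, 3, 4, 4])
```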

Upvotes: 2

Views: 6899

Answers (1)

Nerveless_child

Reputation: 1412

You can do this with PyTorch's torch.nn.ConstantPad1d/2d/3d modules.

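import torch
from torch import nn

x1 = torch.ones(1, 3, 3, 3)

pad_value = 0
# padding = (left, right, top, bottom), applied to the last two dimensions
pad_func = nn.ConstantPad2d((0, 1, 0, 1), pad_value)

output_t = pad_func(x1)

nn.ConstantPad2d is the natural fit here, because its documented padding tuple is (left, right, top, bottom) for the last two dimensions. nn.ConstantPad1d and nn.ConstantPad3d also did what you want with the same settings.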
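If you don't need a module object, the nn.ConstantPad* modules are thin wrappers around torch.nn.functional.pad, which you can call directly. A minimal sketch using the same (left, right, top, bottom) tuple; the pad tuple lists (before, after) pairs starting from the last dimension:

```python
import torch
import torch.nn.functional as F

x1 = torch.ones(1, 3, 3, 3)

# (0, 1) pads the last dim (width) with one column on the right,
# (0, 1) then pads the second-to-last dim (height) with one row at the bottom.
output_t = F.pad(x1, (0, 1, 0, 1), mode="constant", value=0)

print(output_t.shape)  # torch.Size([1, 3, 4, 4])
```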

There is also NumPy's np.pad, which takes one (before, after) pair per dimension:

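import numpy as np
import torch

x1 = torch.ones(1, 3, 3, 3)

pad_value = 0

# pad_width: no padding on batch/channel dims, one row/column after the last two dims
output_n = np.pad(x1.numpy(), ((0, 0), (0, 0), (0, 1), (0, 1)), "constant", constant_values=pad_value)
output_t = torch.from_numpy(output_n)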
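Since the real tensor holds arbitrary values rather than ones, here is a quick sketch that checks the NumPy route against nn.ConstantPad2d on random data (torch.rand is only a stand-in for the real input):

```python
import numpy as np
import torch
from torch import nn

x1 = torch.rand(1, 3, 3, 3)  # random values instead of ones

# NumPy route: one (before, after) pair for every dimension
via_numpy = torch.from_numpy(
    np.pad(x1.numpy(), ((0, 0), (0, 0), (0, 1), (0, 1)), mode="constant", constant_values=0)
)

# PyTorch route: (left, right, top, bottom) for the last two dimensions
via_torch = nn.ConstantPad2d((0, 1, 0, 1), 0)(x1)

print(torch.equal(via_numpy, via_torch))  # True
```

Note that x1.numpy() only works on CPU tensors that do not require grad (otherwise you need x1.detach().cpu().numpy()), and the result is no longer tracked by autograd, so the pure-PyTorch options are usually preferable inside a model.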

Upvotes: 2
