I have a 6-dimensional, all-zero PyTorch tensor lrel_w that I want to fill with 1s at every position where the indices of the first three dimensions match the indices of the last three dimensions. I'm currently doing this naively with three nested for loops:
import torch

lrel_w = torch.zeros(
    input_size[0], input_size[1], input_size[2],
    input_size[0], input_size[1], input_size[2]
)
for c in range(input_size[0]):
    for x in range(input_size[1]):
        for y in range(input_size[2]):
            lrel_w[c, x, y, c, x, y] = 1
I'm sure there must be a more efficient way of doing this, but I haven't been able to figure it out.
You can avoid the loops by building index tensors with torch.meshgrid and assigning through advanced indexing:
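import torch

c, m, n = input_size[0], input_size[1], input_size[2]
t = torch.zeros(c, m, n, c, m, n)
# Index grids covering every (c, x, y) combination.
# indexing="ij" (PyTorch >= 1.10) keeps the old default and silences the warning.
i, j, k = torch.meshgrid(torch.arange(c), torch.arange(m), torch.arange(n), indexing="ij")
i = i.flatten()
j = j.flatten()
k = k.flatten()
# One advanced-indexing assignment sets all c*m*n "matching" positions at once.
t[i, j, k, i, j, k] = 1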
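To sanity-check the indexing approach, here is a small self-contained comparison against the loop version from the question. It also includes one alternative that is not part of the original answer: if you flatten (c, x, y) into a single index on both halves, the "indices match" condition becomes row == column, so the tensor is just an identity matrix reshaped to six dimensions. The function names and the example shape (2, 3, 4) are hypothetical, and indexing="ij" assumes PyTorch 1.10 or newer.

```python
import torch

def identity_6d_loops(c, m, n):
    # Reference: the three nested loops from the question.
    t = torch.zeros(c, m, n, c, m, n)
    for a in range(c):
        for x in range(m):
            for y in range(n):
                t[a, x, y, a, x, y] = 1
    return t

def identity_6d_meshgrid(c, m, n):
    # Advanced indexing with meshgrid, as in the answer above.
    t = torch.zeros(c, m, n, c, m, n)
    i, j, k = torch.meshgrid(
        torch.arange(c), torch.arange(m), torch.arange(n), indexing="ij"
    )
    t[i, j, k, i, j, k] = 1  # 3-D index tensors work without flattening
    return t

def identity_6d_eye(c, m, n):
    # Row-major flattening maps (a, x, y) -> a*m*n + x*n + y on both halves,
    # so "first three indices == last three indices" means row == column.
    size = c * m * n
    return torch.eye(size).reshape(c, m, n, c, m, n)

c, m, n = 2, 3, 4  # hypothetical example shape
ref = identity_6d_loops(c, m, n)
print(torch.equal(ref, identity_6d_meshgrid(c, m, n)))  # True
print(torch.equal(ref, identity_6d_eye(c, m, n)))  # True
```

The torch.eye version allocates the same amount of memory as the other two, so it is mainly a more compact way to express the same tensor.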
If you need a reference on how meshgrid works, see the PyTorch documentation for torch.meshgrid.
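As a quick illustration of what meshgrid produces, here is a two-dimensional sketch with small example sizes 2 and 3 (chosen only for readability; the answer uses the three-dimensional case in the same way):

```python
import torch

# With indexing="ij", output tensor d varies along dimension d
# and is constant along all other dimensions.
i, j = torch.meshgrid(torch.arange(2), torch.arange(3), indexing="ij")
print(i.tolist())  # [[0, 0, 0], [1, 1, 1]]
print(j.tolist())  # [[0, 1, 2], [0, 1, 2]]

# Flattening lines the grids up so every (i, j) combination appears exactly once,
# which is why t[i, j, k, i, j, k] hits each "matching" position.
pairs = list(zip(i.flatten().tolist(), j.flatten().tolist()))
print(pairs)  # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```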