Reputation: 25
I don't know how to explain it correctly, so the title might be misleading. What I want to do is move columns from one 3D tensor, t1, into another 3D tensor, t2, according to a set of indices. There's a dictionary td in which each (k, v) pair means that the k-th column of t1 (i.e. t1[:, :, k]) should become the v-th column of t2.
Currently, I'm doing it this way:
```python
for k, v in td.items():
    t2[:, :, v] = torch.select(t1, 2, k)
```
But it's very slow, because td has millions of entries. What is the fastest way to do this?
Upvotes: 2
Views: 300
Reputation: 22214
Assuming there are no repeated values in td, you can replace the loop with a single advanced-indexing assignment:
```python
t2[:, :, list(td.values())] = t1[:, :, list(td.keys())]
```
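A minimal sketch that checks the one-line assignment against the original loop, assuming a small hypothetical td and random shapes chosen only for illustration. It builds the index tensors once from the dict (keys and values iterate in matching order) and also shows an equivalent form using `index_select` and `index_copy_`. The helper function names are hypothetical.

```python
import torch

def move_columns_loop(t1, t2, td):
    # Original approach: one Python-level copy per (k, v) pair.
    for k, v in td.items():
        t2[:, :, v] = torch.select(t1, 2, k)
    return t2

def move_columns_indexing(t1, t2, td):
    # Build both index tensors once; dict keys/values iterate in the same order.
    src = torch.tensor(list(td.keys()), dtype=torch.long, device=t1.device)
    dst = torch.tensor(list(td.values()), dtype=torch.long, device=t2.device)
    # One advanced-indexing assignment copies every column at once.
    # Assumes the values in td are unique; with duplicates the winner is undefined.
    t2[:, :, dst] = t1[:, :, src]
    return t2

def move_columns_index_copy(t1, t2, td):
    # Same idea via index_select (gather source columns) + index_copy_ (scatter them).
    src = torch.tensor(list(td.keys()), dtype=torch.long, device=t1.device)
    dst = torch.tensor(list(td.values()), dtype=torch.long, device=t2.device)
    t2.index_copy_(2, dst, t1.index_select(2, src))
    return t2

torch.manual_seed(0)
t1 = torch.randn(2, 3, 6)
td = {0: 4, 2: 1, 5: 0, 3: 3}  # hypothetical mapping: column k of t1 -> column v of t2

a = move_columns_loop(t1, torch.zeros(2, 3, 5), td)
b = move_columns_indexing(t1, torch.zeros(2, 3, 5), td)
c = move_columns_index_copy(t1, torch.zeros(2, 3, 5), td)
print(torch.equal(a, b), torch.equal(a, c))  # True True
```

With millions of entries, most of the remaining cost is turning the dict into index tensors; if the mapping can be kept as two index tensors from the start, that conversion disappears entirely.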
Upvotes: 1