obtuseFox

Reputation: 25

Reorder columns in a tensor according to a dictionary

I'm not sure how to explain this well, so the title might be misleading. I want to move columns from a 3D tensor t1 into another 3D tensor t2 according to a set of indices. There is a dictionary td, and a (k, v) pair in td means that the k-th column of t1 (along the last dimension) becomes the v-th column of t2.
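To make the mapping concrete, here is a tiny sketch with a hypothetical 1x1x4 tensor and a hypothetical td (neither comes from the question). Note that t1 may have more columns than td mentions; unmapped columns (column 2 here) are simply not copied.

```python
import torch

# Hypothetical toy input: a 1x1x4 tensor whose columns hold 0, 1, 2, 3.
t1 = torch.arange(4).reshape(1, 1, 4)
# Hypothetical mapping: column 0 of t1 -> column 2 of t2, 1 -> 0, 3 -> 1.
td = {0: 2, 1: 0, 3: 1}

t2 = torch.zeros(1, 1, 3, dtype=t1.dtype)
for k, v in td.items():
    t2[:, :, v] = t1[:, :, k]  # k-th column of t1 becomes v-th column of t2

print(t2)  # tensor([[[1, 3, 0]]])
```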

Currently, I'm doing it this way:

import torch

for k, v in td.items():
    t2[:, :, v] = torch.select(t1, 2, k)  # k-th column of t1 -> v-th column of t2

But it's very slow, since there are millions of (k, v) pairs and each one is a separate Python-level operation. What would be the best way to do this?

Upvotes: 2

Views: 300

Answers (1)

jodag

Reputation: 22214

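Assuming there are no repeated values in td, you can replace the whole loop with a single indexing operation (td.keys() and td.values() iterate in the same order, so each key stays paired with its value):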

t2[:,:,list(td.values())] = t1[:,:,list(td.keys())]
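A self-contained sketch that checks the one-line version against the original loop, with hypothetical shapes and a hypothetical td. It also converts the keys and values to index tensors once, which is worth doing when the same mapping is reused, and shows index_select/index_copy_ as an equivalent spelling. The duplicate check reflects the "no repeated values" assumption: dict keys are always unique, but if two keys mapped to the same v, it would be unspecified which source column ends up in that destination.

```python
import torch

# Hypothetical sizes and mapping, chosen only for illustration.
torch.manual_seed(0)
t1 = torch.randn(2, 3, 10)
t2 = torch.zeros(2, 3, 6)
td = {9: 0, 2: 5, 4: 1, 0: 3}  # source column k -> destination column v

# The vectorized form assumes every destination column appears only once.
assert len(set(td.values())) == len(td), "repeated destination columns"

# Build index tensors once; keys() and values() iterate in the same order,
# so src[i] pairs with dst[i].
src = torch.tensor(list(td.keys()), dtype=torch.long)
dst = torch.tensor(list(td.values()), dtype=torch.long)

# One advanced-indexing assignment replaces the Python loop.
t2_fast = t2.clone()
t2_fast[:, :, dst] = t1[:, :, src]

# Equivalent form: gather the source columns, then copy them into place.
t2_ic = t2.clone().index_copy_(2, dst, t1.index_select(2, src))

# Reference: the original per-column loop.
t2_loop = t2.clone()
for k, v in td.items():
    t2_loop[:, :, v] = torch.select(t1, 2, k)

print(torch.equal(t2_fast, t2_loop), torch.equal(t2_ic, t2_loop))  # True True
```

Both vectorized forms run as a handful of tensor operations no matter how many pairs td contains, which is where the speedup over the loop comes from.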

Upvotes: 1
