Reputation: 17468
I defined a tensor like this:
import torch
t_shape = [4, 1]
data = torch.rand(t_shape)
I want to apply a different function to each row.
funcs = [lambda x: x+1, lambda x: x**2, lambda x: x-1, lambda x: x*2] # one function per row
I can do it with the following code:
d = torch.tensor([f(data[i]) for i, f in enumerate(funcs)])
Is there a more idiomatic way to do this with the higher-level APIs that PyTorch provides?
Upvotes: 4
Views: 5815
Reputation: 37691
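I think your solution is good, but it won't work for arbitrary tensor shapes: torch.tensor can only build a tensor from a list of tensors when each of them holds a single element, so it fails as soon as a row has more than one value. You can modify the solution slightly as follows.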
t_shape = [4, 10, 10]
data = torch.rand(t_shape)
funcs = [lambda x: x+1, lambda x: x**2, lambda x: x-1, lambda x: x*2]
# only change the following 2 lines
d = [f(data[i]) for i, f in enumerate(funcs)]
d = torch.stack(d, dim=0)
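If you need this in several places, the same idea can be wrapped in a small helper. This is only a sketch: the name apply_per_row is made up for illustration (it is not a PyTorch API), and it assumes every function returns a tensor of the same shape as its input row, since torch.stack requires all pieces to have equal shapes.

```python
import torch

def apply_per_row(data, funcs):
    """Hypothetical helper: apply funcs[i] to data[i], then stack along dim 0."""
    if data.shape[0] != len(funcs):
        raise ValueError("need exactly one function per row")
    # Iterating over a tensor yields its slices along the first dimension.
    return torch.stack([f(row) for f, row in zip(funcs, data)], dim=0)

funcs = [lambda x: x + 1, lambda x: x ** 2, lambda x: x - 1, lambda x: x * 2]

small = torch.rand(4, 1)        # the shape from the question
large = torch.rand(4, 10, 10)   # rows with more than one element

out_small = apply_per_row(small, funcs)
out_large = apply_per_row(large, funcs)
print(out_small.shape, out_large.shape)
```

Note that this is still a Python-level loop over the rows, so it is convenient rather than vectorized.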
Upvotes: 5