GoingMyWay

Reputation: 17468

PyTorch, apply different functions element-wise

I defined a tensor like this:

import torch

t_shape = [4, 1]
data = torch.rand(t_shape)

I want to apply different functions to each row.

funcs = [lambda x: x+1, lambda x: x**2, lambda x: x-1, lambda x: x*2]  # one function per row

I can do it with the following code:

d = torch.tensor([f(data[i]) for i, f in enumerate(funcs)])
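The one-liner above works for t_shape = [4, 1] because torch.tensor converts each one-element result to a Python scalar. A minimal check, assuming a hypothetical [4, 10, 10] input, of what happens when each row holds more than one element, compared with stacking the per-row results instead:

```python
import torch

funcs = [lambda x: x + 1, lambda x: x ** 2, lambda x: x - 1, lambda x: x * 2]
data = torch.rand(4, 10, 10)  # hypothetical input: each "row" is a 10x10 block

# torch.tensor must turn every list element into a Python scalar,
# which fails for multi-element tensors.
try:
    torch.tensor([f(data[i]) for i, f in enumerate(funcs)])
    failed = False
except ValueError as err:
    failed = True
    print("torch.tensor failed:", err)

# torch.stack joins the per-row results along a new leading dimension instead.
d = torch.stack([f(data[i]) for i, f in enumerate(funcs)], dim=0)
print(d.shape)  # torch.Size([4, 10, 10])
```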

Is there a more idiomatic way to do this with PyTorch's built-in APIs?

Upvotes: 4

Views: 5815

Answers (1)

Wasi Ahmad

Reputation: 37691

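I think your solution is fine, but it won't work for arbitrary tensor shapes: torch.tensor can only build a tensor from a list of one-element tensors, because it converts each one to a Python scalar. You can modify it slightly as follows: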

import torch

t_shape = [4, 10, 10]
data = torch.rand(t_shape)

funcs = [lambda x: x+1, lambda x: x**2, lambda x: x-1, lambda x: x*2]

# only these two lines change: build a list of per-row results, then stack them
d = [f(data[i]) for i, f in enumerate(funcs)]
d = torch.stack(d, dim=0)  # shape [4, 10, 10]
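A variant of the same idea, sketched under the assumption that gradients need to flow back to data: iterating over data.unbind(0) yields the same rows as data[i], and because torch.stack keeps the autograd graph (rebuilding with torch.tensor would not), the per-row derivatives come out as expected (1, 2x, 1 and 2).

```python
import torch

funcs = [lambda x: x + 1, lambda x: x ** 2, lambda x: x - 1, lambda x: x * 2]
data = torch.rand(4, 10, 10, requires_grad=True)

# unbind(0) splits along the first dimension; zip pairs each row with its function.
d = torch.stack([f(row) for f, row in zip(funcs, data.unbind(0))], dim=0)

# Backpropagate a simple scalar to confirm the graph is intact.
d.sum().backward()

# Expected derivatives per row: d/dx(x+1)=1, d/dx(x^2)=2x, d/dx(x-1)=1, d/dx(2x)=2.
expected = torch.stack([
    torch.ones(10, 10),
    2 * data.detach()[1],
    torch.ones(10, 10),
    torch.full((10, 10), 2.0),
])
print(torch.allclose(data.grad, expected))  # True
```

The Python loop itself is unavoidable here because each row uses a different function; stacking only avoids the scalar conversion and keeps autograd working.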

Upvotes: 5
