Reputation: 31
I want to write a function f1(arg_tensor) that takes a PyTorch tensor as its argument. Inside it I use another function, f2(tensor_row_1, tensor_row_2), which takes two rows of the tensor as arguments and returns a scalar. f2(..) should be applied to every ordered pair of the tensor's rows, i.e. to the row indices [0,0], [0,1], [0,2] ... [0,n-1] ... [n-1,0] ... [n-1,n-1].
The output of f1(..) should be a tensor in which element [0,0] holds the value of f2(tensor_rows[0], tensor_rows[0]), and so on: element [i,j] holds f2(tensor_rows[i], tensor_rows[j]).
Is there a way to do this efficiently, without a double for loop?
Upvotes: 3
Views: 1622
Reputation: 24681
Yes, one can do it with a simple broadcasting trick:
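import torch

def f1(tensor):
    # (n, d) -> (d, n) so the feature dimension comes first
    tensor = tensor.permute(1, 0)
    # (d, n, 1) and (d, 1, n) broadcast to (d, n, n): every pair of rows
    return torch.nn.functional.kl_div(
        tensor.unsqueeze(dim=2), tensor.unsqueeze(dim=1), reduction="none"
    ).mean(dim=0)  # average over features, matching the default reduction below

def manual_f1(tensor):
    # Reference double loop, for verification only
    result = []
    for row1 in tensor:
        for row2 in tensor:
            result.append(torch.nn.functional.kl_div(row1, row2))
    return torch.stack(result).reshape(tensor.shape[0], -1)

# Positive values keep log(target) inside kl_div well defined
data = torch.rand(5, 7)
result = f1(data)
manual_result = manual_f1(data)
print(torch.all(result == manual_result).item())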
Please note that with more rows the two results may differ slightly because of floating-point error. You can:
- print the values and inspect them manually
- use torch.isclose to verify that they are similar
In the second case, the last print would become:
print(torch.all(torch.isclose(result, manual_result)).item())
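The same broadcasting pattern is not specific to kl_div. As a minimal sketch, assuming your f2 can be written with elementwise operations followed by a reduction over the last dimension, the helper below evaluates f2 on every pair of rows at once and checks the result against the double loop. The names pairwise_apply, pairwise_loop and sq_dist are hypothetical and chosen only for illustration; they are not part of PyTorch.

```python
import torch


def sq_dist(a, b):
    # Hypothetical f2: squared Euclidean distance, reduced over the last dim,
    # so it works on single rows and on broadcast batches of rows alike.
    return ((a - b) ** 2).sum(dim=-1)


def pairwise_apply(f2, tensor):
    # tensor: (n, d). rows_i: (n, 1, d), rows_j: (1, n, d).
    # Broadcasting expands both to (n, n, d); f2 reduces the last dim,
    # so out[i, j] == f2(tensor[i], tensor[j]).
    rows_i = tensor.unsqueeze(1)
    rows_j = tensor.unsqueeze(0)
    return f2(rows_i, rows_j)


def pairwise_loop(f2, tensor):
    # Reference implementation with the double loop, for checking only.
    n = tensor.shape[0]
    out = torch.empty(n, n, dtype=tensor.dtype)
    for i in range(n):
        for j in range(n):
            out[i, j] = f2(tensor[i], tensor[j])
    return out


torch.manual_seed(0)
data = torch.randn(5, 7)
fast = pairwise_apply(sq_dist, data)
slow = pairwise_loop(sq_dist, data)
print(fast.shape)
print(torch.allclose(fast, slow, atol=1e-5))
```

Keep in mind that the broadcast version materializes an intermediate of shape (n, n, d), so for a large number of rows it trades memory for speed.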
Upvotes: 2